// codec_rs/map_loader.rs
// SPDX-License-Identifier: MIT
//! Fetch, verify, and cache tokenizer maps.
//!
//! Gated behind the `http` feature (enabled by default). Pulls in `reqwest`
//! with both its blocking and async APIs.

7use std::sync::Arc;
8
9use sha2::{Digest, Sha256};
10
11use crate::frame::{MapCache, MemoryMapCache};
12use crate::map::{parse_hash, TokenizerMap, TokenizerMapError};
13
14/// Options for [`MapLoader::load_blocking`] / [`MapLoader::load`].
15#[derive(Clone, Default)]
16pub struct LoadOptions {
17    /// URL to fetch the map from.
18    pub url: String,
19    /// Optional sha256 hex digest to verify the fetched map against.
20    /// Accepts `sha256:<hex>` or bare `<hex>`. If omitted, no verification.
21    pub hash: Option<String>,
22    /// Pluggable cache. Defaults to a process-wide in-memory cache.
23    pub cache: Option<Arc<dyn MapCache>>,
24    /// Cache key. Defaults to `{url}#{hash}`.
25    pub cache_key: Option<String>,
26}
27
28/// Thrown when a fetched map doesn't match the expected hash.
29#[derive(Debug, thiserror::Error)]
30#[error("TokenizerMap hash mismatch.\n  expected: {expected}\n  actual:   {actual}")]
31pub struct TokenizerMapHashMismatchError {
32    pub expected: String,
33    pub actual: String,
34}
35
36/// Errors raised by [`MapLoader`].
37#[derive(Debug, thiserror::Error)]
38pub enum LoadError {
39    #[error("http error: {0}")]
40    Http(#[from] reqwest::Error),
41    #[error(transparent)]
42    HashMismatch(#[from] TokenizerMapHashMismatchError),
43    #[error(transparent)]
44    Map(#[from] TokenizerMapError),
45}
46
/// Fetch, verify, and cache tokenizer maps.
///
/// A stateless marker type: all functionality is exposed through associated
/// functions such as [`MapLoader::load_blocking`] and [`MapLoader::load`].
#[derive(Debug, Clone, Copy, Default)]
pub struct MapLoader;
49
50fn default_cache() -> Arc<dyn MapCache> {
51    use std::sync::OnceLock;
52    static CACHE: OnceLock<Arc<dyn MapCache>> = OnceLock::new();
53    CACHE.get_or_init(|| Arc::new(MemoryMapCache::new())).clone()
54}
55
56fn build_blocking_client() -> Result<reqwest::blocking::Client, reqwest::Error> {
57    reqwest::blocking::Client::builder()
58        .user_agent("codec-rs/0.1")
59        .gzip(true)
60        .brotli(true)
61        .build()
62}
63
64fn build_async_client() -> Result<reqwest::Client, reqwest::Error> {
65    reqwest::Client::builder()
66        .user_agent("codec-rs/0.1")
67        .gzip(true)
68        .brotli(true)
69        .build()
70}
71
72impl MapLoader {
73    /// Synchronous fetch + verify + cache.
74    pub fn load_blocking(opts: LoadOptions) -> Result<Arc<TokenizerMap>, LoadError> {
75        let cache = opts.cache.unwrap_or_else(default_cache);
76        let cache_key = opts
77            .cache_key
78            .unwrap_or_else(|| format!("{}#{}", opts.url, opts.hash.as_deref().unwrap_or("")));
79
80        if let Some(hit) = cache.get(&cache_key) {
81            return Ok(hit);
82        }
83
84        let client = build_blocking_client()?;
85        let bytes = client.get(&opts.url).send()?.error_for_status()?.bytes()?;
86
87        if let Some(expected) = &opts.hash {
88            let want = parse_hash(expected);
89            let actual = sha256_hex(&bytes);
90            if !actual.eq_ignore_ascii_case(&want) {
91                return Err(LoadError::HashMismatch(TokenizerMapHashMismatchError {
92                    expected: want,
93                    actual,
94                }));
95            }
96        }
97
98        let map = TokenizerMap::from_json(&bytes)?;
99        let arc = Arc::new(map);
100        cache.set(&cache_key, Arc::clone(&arc));
101        Ok(arc)
102    }
103
104    /// Async fetch + verify + cache. Requires a Tokio runtime.
105    pub async fn load(opts: LoadOptions) -> Result<Arc<TokenizerMap>, LoadError> {
106        let cache = opts.cache.unwrap_or_else(default_cache);
107        let cache_key = opts
108            .cache_key
109            .unwrap_or_else(|| format!("{}#{}", opts.url, opts.hash.as_deref().unwrap_or("")));
110
111        if let Some(hit) = cache.get(&cache_key) {
112            return Ok(hit);
113        }
114
115        let client = build_async_client()?;
116        let bytes = client
117            .get(&opts.url)
118            .send()
119            .await?
120            .error_for_status()?
121            .bytes()
122            .await?;
123
124        if let Some(expected) = &opts.hash {
125            let want = parse_hash(expected);
126            let actual = sha256_hex(&bytes);
127            if !actual.eq_ignore_ascii_case(&want) {
128                return Err(LoadError::HashMismatch(TokenizerMapHashMismatchError {
129                    expected: want,
130                    actual,
131                }));
132            }
133        }
134
135        let map = TokenizerMap::from_json(&bytes)?;
136        let arc = Arc::new(map);
137        cache.set(&cache_key, Arc::clone(&arc));
138        Ok(arc)
139    }
140
141    /// Verify-only helper exposed for tests / callers that fetched bytes
142    /// out-of-band (e.g. local file). Returns the map on success or a
143    /// hash-mismatch error.
144    pub fn verify_and_parse(
145        bytes: &[u8],
146        expected_hash: Option<&str>,
147    ) -> Result<TokenizerMap, LoadError> {
148        if let Some(expected) = expected_hash {
149            let want = parse_hash(expected);
150            let actual = sha256_hex(bytes);
151            if !actual.eq_ignore_ascii_case(&want) {
152                return Err(LoadError::HashMismatch(TokenizerMapHashMismatchError {
153                    expected: want,
154                    actual,
155                }));
156            }
157        }
158        Ok(TokenizerMap::from_json(bytes)?)
159    }
160}
161
162fn sha256_hex(bytes: &[u8]) -> String {
163    let mut hasher = Sha256::new();
164    hasher.update(bytes);
165    hex::encode(hasher.finalize())
166}