1use std::sync::Arc;
8
9use sha2::{Digest, Sha256};
10
11use crate::frame::{MapCache, MemoryMapCache};
12use crate::map::{parse_hash, TokenizerMap, TokenizerMapError};
13
/// Options accepted by [`MapLoader::load`] and [`MapLoader::load_blocking`].
#[derive(Clone, Default)]
pub struct LoadOptions {
    /// URL the map is fetched from with an HTTP GET when the cache has no
    /// entry for the resolved cache key.
    pub url: String,
    /// Optional expected hash of the downloaded bytes. When present, the value
    /// is passed through `parse_hash` and compared (ASCII case-insensitively)
    /// with the hex SHA-256 digest of the response body; a mismatch aborts the
    /// load with [`LoadError::HashMismatch`].
    pub hash: Option<String>,
    /// Cache consulted before downloading and populated after a successful
    /// load. When `None`, a process-wide shared `MemoryMapCache` is used.
    pub cache: Option<Arc<dyn MapCache>>,
    /// Explicit cache key. When `None`, the key is `"{url}#{hash}"`, with an
    /// empty string standing in for a missing hash.
    pub cache_key: Option<String>,
}
27
/// Error reported when the SHA-256 digest of the map bytes differs from the
/// hash the caller asked for.
#[derive(Debug, thiserror::Error)]
#[error("TokenizerMap hash mismatch.\n expected: {expected}\n actual: {actual}")]
pub struct TokenizerMapHashMismatchError {
    /// The caller-supplied hash after it has been passed through `parse_hash`.
    pub expected: String,
    /// Hex-encoded SHA-256 digest of the bytes that were actually checked.
    pub actual: String,
}
35
/// Errors that can occur while loading a tokenizer map.
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
    /// A `reqwest` failure: building the client, sending the request, a
    /// non-success HTTP status (via `error_for_status`), or reading the body.
    #[error("http error: {0}")]
    Http(#[from] reqwest::Error),
    /// The digest of the bytes did not match the expected hash.
    #[error(transparent)]
    HashMismatch(#[from] TokenizerMapHashMismatchError),
    /// A `TokenizerMapError` from the map module, e.g. while turning the
    /// bytes into a `TokenizerMap`.
    #[error(transparent)]
    Map(#[from] TokenizerMapError),
}
46
/// Stateless entry point for loading tokenizer maps; all functionality is
/// exposed as associated functions.
pub struct MapLoader;
49
50fn default_cache() -> Arc<dyn MapCache> {
51 use std::sync::OnceLock;
52 static CACHE: OnceLock<Arc<dyn MapCache>> = OnceLock::new();
53 CACHE.get_or_init(|| Arc::new(MemoryMapCache::new())).clone()
54}
55
56fn build_blocking_client() -> Result<reqwest::blocking::Client, reqwest::Error> {
57 reqwest::blocking::Client::builder()
58 .user_agent("codec-rs/0.1")
59 .gzip(true)
60 .brotli(true)
61 .build()
62}
63
64fn build_async_client() -> Result<reqwest::Client, reqwest::Error> {
65 reqwest::Client::builder()
66 .user_agent("codec-rs/0.1")
67 .gzip(true)
68 .brotli(true)
69 .build()
70}
71
72impl MapLoader {
73 pub fn load_blocking(opts: LoadOptions) -> Result<Arc<TokenizerMap>, LoadError> {
75 let cache = opts.cache.unwrap_or_else(default_cache);
76 let cache_key = opts
77 .cache_key
78 .unwrap_or_else(|| format!("{}#{}", opts.url, opts.hash.as_deref().unwrap_or("")));
79
80 if let Some(hit) = cache.get(&cache_key) {
81 return Ok(hit);
82 }
83
84 let client = build_blocking_client()?;
85 let bytes = client.get(&opts.url).send()?.error_for_status()?.bytes()?;
86
87 if let Some(expected) = &opts.hash {
88 let want = parse_hash(expected);
89 let actual = sha256_hex(&bytes);
90 if !actual.eq_ignore_ascii_case(&want) {
91 return Err(LoadError::HashMismatch(TokenizerMapHashMismatchError {
92 expected: want,
93 actual,
94 }));
95 }
96 }
97
98 let map = TokenizerMap::from_json(&bytes)?;
99 let arc = Arc::new(map);
100 cache.set(&cache_key, Arc::clone(&arc));
101 Ok(arc)
102 }
103
104 pub async fn load(opts: LoadOptions) -> Result<Arc<TokenizerMap>, LoadError> {
106 let cache = opts.cache.unwrap_or_else(default_cache);
107 let cache_key = opts
108 .cache_key
109 .unwrap_or_else(|| format!("{}#{}", opts.url, opts.hash.as_deref().unwrap_or("")));
110
111 if let Some(hit) = cache.get(&cache_key) {
112 return Ok(hit);
113 }
114
115 let client = build_async_client()?;
116 let bytes = client
117 .get(&opts.url)
118 .send()
119 .await?
120 .error_for_status()?
121 .bytes()
122 .await?;
123
124 if let Some(expected) = &opts.hash {
125 let want = parse_hash(expected);
126 let actual = sha256_hex(&bytes);
127 if !actual.eq_ignore_ascii_case(&want) {
128 return Err(LoadError::HashMismatch(TokenizerMapHashMismatchError {
129 expected: want,
130 actual,
131 }));
132 }
133 }
134
135 let map = TokenizerMap::from_json(&bytes)?;
136 let arc = Arc::new(map);
137 cache.set(&cache_key, Arc::clone(&arc));
138 Ok(arc)
139 }
140
141 pub fn verify_and_parse(
145 bytes: &[u8],
146 expected_hash: Option<&str>,
147 ) -> Result<TokenizerMap, LoadError> {
148 if let Some(expected) = expected_hash {
149 let want = parse_hash(expected);
150 let actual = sha256_hex(bytes);
151 if !actual.eq_ignore_ascii_case(&want) {
152 return Err(LoadError::HashMismatch(TokenizerMapHashMismatchError {
153 expected: want,
154 actual,
155 }));
156 }
157 }
158 Ok(TokenizerMap::from_json(bytes)?)
159 }
160}
161
162fn sha256_hex(bytes: &[u8]) -> String {
163 let mut hasher = Sha256::new();
164 hasher.update(bytes);
165 hex::encode(hasher.finalize())
166}