sqlite_graphrag/
tokenizer.rs1use crate::constants::PASSAGE_PREFIX;
7use crate::errors::AppError;
8use fastembed::{EmbeddingModel, TextEmbedding};
9use huggingface_hub::api::sync::ApiBuilder;
10use std::path::{Path, PathBuf};
11use std::sync::OnceLock;
12use tokenizers::Tokenizer;
13
14struct TokenizerRuntime {
15 tokenizer: Tokenizer,
16 model_max_length: usize,
17}
18
19static TOKENIZER_RUNTIME: OnceLock<TokenizerRuntime> = OnceLock::new();
20
21pub fn get_tokenizer(models_dir: &Path) -> Result<&'static Tokenizer, AppError> {
26 Ok(&get_runtime(models_dir)?.tokenizer)
27}
28
29pub fn get_model_max_length(models_dir: &Path) -> Result<usize, AppError> {
34 Ok(get_runtime(models_dir)?.model_max_length)
35}
36
37pub fn count_passage_tokens(tokenizer: &Tokenizer, text: &str) -> Result<usize, AppError> {
45 let prefixed = format!("{PASSAGE_PREFIX}{text}");
46 count_tokens(tokenizer, &prefixed)
47}
48
49pub fn passage_token_offsets(
57 tokenizer: &Tokenizer,
58 text: &str,
59) -> Result<Vec<(usize, usize)>, AppError> {
60 let prefixed = format!("{PASSAGE_PREFIX}{text}");
61 let prefix_len = PASSAGE_PREFIX.len();
62 let encoding = tokenizer
63 .encode(prefixed, true)
64 .map_err(|e| AppError::Embedding(e.to_string()))?;
65
66 let mut offsets = Vec::new();
67 for &(start, end) in encoding.get_offsets() {
68 if end <= start || end <= prefix_len {
69 continue;
70 }
71
72 let adjusted_start = start.saturating_sub(prefix_len).min(text.len());
73 let adjusted_end = end.saturating_sub(prefix_len).min(text.len());
74
75 if adjusted_end > adjusted_start
76 && text.is_char_boundary(adjusted_start)
77 && text.is_char_boundary(adjusted_end)
78 {
79 offsets.push((adjusted_start, adjusted_end));
80 }
81 }
82
83 if offsets.is_empty() && !text.is_empty() {
84 offsets.push((0, text.len()));
85 }
86
87 Ok(offsets)
88}
89
90fn count_tokens(tokenizer: &Tokenizer, text: &str) -> Result<usize, AppError> {
91 let encoding = tokenizer
92 .encode(text, true)
93 .map_err(|e| AppError::Embedding(e.to_string()))?;
94 Ok(encoding.len())
95}
96
97fn get_runtime(models_dir: &Path) -> Result<&'static TokenizerRuntime, AppError> {
98 if let Some(runtime) = TOKENIZER_RUNTIME.get() {
99 return Ok(runtime);
100 }
101
102 let runtime = load_runtime(models_dir)?;
103 let _ = TOKENIZER_RUNTIME.set(runtime);
104 Ok(TOKENIZER_RUNTIME
105 .get()
106 .expect("tokenizer runtime just initialized"))
107}
108
109fn load_runtime(models_dir: &Path) -> Result<TokenizerRuntime, AppError> {
110 let model_info = TextEmbedding::get_model_info(&EmbeddingModel::MultilingualE5Small)
111 .map_err(|e| AppError::Embedding(e.to_string()))?;
112
113 let cache_dir = std::env::var("HF_HOME")
114 .map(PathBuf::from)
115 .unwrap_or_else(|_| models_dir.to_path_buf());
116 let endpoint =
117 std::env::var("HF_ENDPOINT").unwrap_or_else(|_| "https://huggingface.co".to_string());
118
119 let api = ApiBuilder::new()
120 .with_cache_dir(cache_dir)
121 .with_endpoint(endpoint)
122 .with_progress(false)
123 .build()
124 .map_err(|e| AppError::Embedding(e.to_string()))?;
125 let repo = api.model(model_info.model_code.clone());
126
127 let tokenizer_bytes =
128 std::fs::read(repo.get("tokenizer.json").map_err(map_hf_err)?).map_err(AppError::Io)?;
129 let tokenizer_config_bytes =
130 std::fs::read(repo.get("tokenizer_config.json").map_err(map_hf_err)?)
131 .map_err(AppError::Io)?;
132
133 let tokenizer =
134 Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| AppError::Embedding(e.to_string()))?;
135 let tokenizer_config: serde_json::Value =
136 serde_json::from_slice(&tokenizer_config_bytes).map_err(AppError::Json)?;
137 let model_max_length = tokenizer_config["model_max_length"]
138 .as_u64()
139 .map(|n| n as usize)
140 .or_else(|| {
141 tokenizer_config["model_max_length"]
142 .as_f64()
143 .map(|n| n as usize)
144 })
145 .ok_or_else(|| {
146 AppError::Embedding("tokenizer_config.json missing model_max_length field".into())
147 })?;
148
149 Ok(TokenizerRuntime {
150 tokenizer,
151 model_max_length,
152 })
153}
154
155fn map_hf_err(err: huggingface_hub::api::sync::ApiError) -> AppError {
156 AppError::Embedding(err.to_string())
157}