// sqlite_graphrag/tokenizer.rs
use std::path::{Path, PathBuf};
use std::sync::OnceLock;

use fastembed::{EmbeddingModel, TextEmbedding};
use huggingface_hub::api::sync::ApiBuilder;
use tokenizers::Tokenizer;

use crate::constants::PASSAGE_PREFIX;
use crate::errors::AppError;
8
/// Lazily-initialized tokenizer state shared by the whole process.
struct TokenizerRuntime {
    // Hugging Face tokenizer deserialized from the model's `tokenizer.json`.
    tokenizer: Tokenizer,
    // `model_max_length` value read from `tokenizer_config.json`.
    model_max_length: usize,
}

// Process-wide cache: loaded once on first use, then reused for the
// lifetime of the program (see `get_runtime`).
static TOKENIZER_RUNTIME: OnceLock<TokenizerRuntime> = OnceLock::new();
15
16pub fn get_tokenizer(models_dir: &Path) -> Result<&'static Tokenizer, AppError> {
17    Ok(&get_runtime(models_dir)?.tokenizer)
18}
19
20pub fn get_model_max_length(models_dir: &Path) -> Result<usize, AppError> {
21    Ok(get_runtime(models_dir)?.model_max_length)
22}
23
24pub fn count_passage_tokens(tokenizer: &Tokenizer, text: &str) -> Result<usize, AppError> {
25    let prefixed = format!("{PASSAGE_PREFIX}{text}");
26    count_tokens(tokenizer, &prefixed)
27}
28
29pub fn passage_token_offsets(
30    tokenizer: &Tokenizer,
31    text: &str,
32) -> Result<Vec<(usize, usize)>, AppError> {
33    let prefixed = format!("{PASSAGE_PREFIX}{text}");
34    let prefix_len = PASSAGE_PREFIX.len();
35    let encoding = tokenizer
36        .encode(prefixed, true)
37        .map_err(|e| AppError::Embedding(e.to_string()))?;
38
39    let mut offsets = Vec::new();
40    for &(start, end) in encoding.get_offsets() {
41        if end <= start || end <= prefix_len {
42            continue;
43        }
44
45        let adjusted_start = start.saturating_sub(prefix_len).min(text.len());
46        let adjusted_end = end.saturating_sub(prefix_len).min(text.len());
47
48        if adjusted_end > adjusted_start
49            && text.is_char_boundary(adjusted_start)
50            && text.is_char_boundary(adjusted_end)
51        {
52            offsets.push((adjusted_start, adjusted_end));
53        }
54    }
55
56    if offsets.is_empty() && !text.is_empty() {
57        offsets.push((0, text.len()));
58    }
59
60    Ok(offsets)
61}
62
63fn count_tokens(tokenizer: &Tokenizer, text: &str) -> Result<usize, AppError> {
64    let encoding = tokenizer
65        .encode(text, true)
66        .map_err(|e| AppError::Embedding(e.to_string()))?;
67    Ok(encoding.len())
68}
69
70fn get_runtime(models_dir: &Path) -> Result<&'static TokenizerRuntime, AppError> {
71    if let Some(runtime) = TOKENIZER_RUNTIME.get() {
72        return Ok(runtime);
73    }
74
75    let runtime = load_runtime(models_dir)?;
76    let _ = TOKENIZER_RUNTIME.set(runtime);
77    Ok(TOKENIZER_RUNTIME
78        .get()
79        .expect("tokenizer runtime just initialized"))
80}
81
82fn load_runtime(models_dir: &Path) -> Result<TokenizerRuntime, AppError> {
83    let model_info = TextEmbedding::get_model_info(&EmbeddingModel::MultilingualE5Small)
84        .map_err(|e| AppError::Embedding(e.to_string()))?;
85
86    let cache_dir = std::env::var("HF_HOME")
87        .map(PathBuf::from)
88        .unwrap_or_else(|_| models_dir.to_path_buf());
89    let endpoint =
90        std::env::var("HF_ENDPOINT").unwrap_or_else(|_| "https://huggingface.co".to_string());
91
92    let api = ApiBuilder::new()
93        .with_cache_dir(cache_dir)
94        .with_endpoint(endpoint)
95        .with_progress(false)
96        .build()
97        .map_err(|e| AppError::Embedding(e.to_string()))?;
98    let repo = api.model(model_info.model_code.clone());
99
100    let tokenizer_bytes =
101        std::fs::read(repo.get("tokenizer.json").map_err(map_hf_err)?).map_err(AppError::Io)?;
102    let tokenizer_config_bytes =
103        std::fs::read(repo.get("tokenizer_config.json").map_err(map_hf_err)?)
104            .map_err(AppError::Io)?;
105
106    let tokenizer =
107        Tokenizer::from_bytes(tokenizer_bytes).map_err(|e| AppError::Embedding(e.to_string()))?;
108    let tokenizer_config: serde_json::Value =
109        serde_json::from_slice(&tokenizer_config_bytes).map_err(AppError::Json)?;
110    let model_max_length = tokenizer_config["model_max_length"]
111        .as_u64()
112        .map(|n| n as usize)
113        .or_else(|| {
114            tokenizer_config["model_max_length"]
115                .as_f64()
116                .map(|n| n as usize)
117        })
118        .ok_or_else(|| AppError::Embedding("tokenizer_config.json sem model_max_length".into()))?;
119
120    Ok(TokenizerRuntime {
121        tokenizer,
122        model_max_length,
123    })
124}
125
126fn map_hf_err(err: huggingface_hub::api::sync::ApiError) -> AppError {
127    AppError::Embedding(err.to_string())
128}