Skip to main content

sqlite_graphrag/
embedder.rs

1use crate::constants::{
2    EMBEDDING_DIM, EMBEDDING_MAX_TOKENS, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX,
3    REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS, REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS,
4};
5use crate::errors::AppError;
6use fastembed::{EmbeddingModel, ExecutionProviderDispatch, TextEmbedding, TextInitOptions};
7use ort::execution_providers::CPU;
8use std::path::Path;
9use std::sync::{Mutex, OnceLock};
10
/// Process-wide embedding model, created at most once and shared behind a mutex.
/// Populated by `get_embedder`.
static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
12
13/// Returns the process-wide singleton embedder, initializing it on first call.
14/// Subsequent calls return the cached instance regardless of `models_dir`.
15pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
16    if let Some(m) = EMBEDDER.get() {
17        return Ok(m);
18    }
19
20    maybe_init_dynamic_ort(models_dir)?;
21
22    // Desabilita arena allocator da EP CPU para reduzir retenção agressiva de memória
23    // entre inferências repetidas com shapes variáveis. O fastembed já desliga
24    // memory pattern em alguns cenários, mas não desliga a CPU arena por padrão.
25    let cpu_ep: ExecutionProviderDispatch = CPU::default().with_arena_allocator(false).build();
26
27    let model = TextEmbedding::try_new(
28        TextInitOptions::new(EmbeddingModel::MultilingualE5Small)
29            .with_execution_providers(vec![cpu_ep])
30            .with_max_length(EMBEDDING_MAX_TOKENS)
31            .with_show_download_progress(true)
32            .with_cache_dir(models_dir.to_path_buf()),
33    )
34    .map_err(|e| AppError::Embedding(e.to_string()))?;
35    // If another thread raced and won, discard our instance and return theirs.
36    let _ = EMBEDDER.set(Mutex::new(model));
37    Ok(EMBEDDER.get().expect("just set above"))
38}
39
/// Looks for a `libonnxruntime.so` and, if one exists, initializes `ort` from it.
///
/// Candidates are probed in this order, and the first existing path wins:
/// 1. the non-empty value of the `ORT_DYLIB_PATH` environment variable,
/// 2. `libonnxruntime.so` next to the current executable,
/// 3. `lib/libonnxruntime.so` next to the current executable,
/// 4. `libonnxruntime.so` inside `models_dir`.
///
/// When no candidate exists the function succeeds without doing anything.
///
/// # Errors
/// Returns [`AppError::Embedding`] when `ort::init_from` rejects the chosen path.
#[cfg(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu"))]
fn maybe_init_dynamic_ort(models_dir: &Path) -> Result<(), AppError> {
    let env_override = std::env::var("ORT_DYLIB_PATH")
        .ok()
        .filter(|value| !value.is_empty())
        .map(std::path::PathBuf::from);

    let mut search_order: Vec<std::path::PathBuf> = Vec::with_capacity(4);
    search_order.extend(env_override);

    let exe = std::env::current_exe().ok();
    if let Some(exe_dir) = exe.as_deref().and_then(Path::parent) {
        search_order.push(exe_dir.join("libonnxruntime.so"));
        search_order.push(exe_dir.join("lib").join("libonnxruntime.so"));
    }

    search_order.push(models_dir.join("libonnxruntime.so"));

    let Some(library) = search_order.into_iter().find(|candidate| candidate.exists()) else {
        return Ok(());
    };

    // Publish the chosen path through the environment before initializing ort.
    std::env::set_var("ORT_DYLIB_PATH", &library);
    // The value returned by `commit` is deliberately ignored.
    let _ = ort::init_from(&library)
        .map_err(|e| AppError::Embedding(e.to_string()))?
        .commit();
    Ok(())
}
73
/// Fallback used on every target other than aarch64 Linux (gnu): performs no
/// dynamic ONNX Runtime setup and always succeeds.
#[cfg(not(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu")))]
fn maybe_init_dynamic_ort(_models_dir: &Path) -> Result<(), AppError> {
    Ok(())
}
78
79pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
80    let prefixed = format!("{PASSAGE_PREFIX}{text}");
81    let results = embedder
82        .lock()
83        .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
84        .embed(vec![prefixed.as_str()], Some(1))
85        .map_err(|e| AppError::Embedding(e.to_string()))?;
86    let emb = results
87        .into_iter()
88        .next()
89        .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
90    assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
91    Ok(emb)
92}
93
94pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
95    let prefixed = format!("{QUERY_PREFIX}{text}");
96    let results = embedder
97        .lock()
98        .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
99        .embed(vec![prefixed.as_str()], Some(1))
100        .map_err(|e| AppError::Embedding(e.to_string()))?;
101    let emb = results
102        .into_iter()
103        .next()
104        .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
105    Ok(emb)
106}
107
108pub fn embed_passages_batch(
109    embedder: &Mutex<TextEmbedding>,
110    texts: &[&str],
111    batch_size: usize,
112) -> Result<Vec<Vec<f32>>, AppError> {
113    let prefixed: Vec<String> = texts
114        .iter()
115        .map(|t| format!("{PASSAGE_PREFIX}{t}"))
116        .collect();
117    let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
118    let results = embedder
119        .lock()
120        .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
121        .embed(strs, Some(batch_size.min(FASTEMBED_BATCH_SIZE)))
122        .map_err(|e| AppError::Embedding(e.to_string()))?;
123    for emb in &results {
124        assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
125    }
126    Ok(results)
127}
128
129pub fn controlled_batch_count(token_counts: &[usize]) -> usize {
130    plan_controlled_batches(token_counts).len()
131}
132
133pub fn embed_passages_controlled(
134    embedder: &Mutex<TextEmbedding>,
135    texts: &[&str],
136    token_counts: &[usize],
137) -> Result<Vec<Vec<f32>>, AppError> {
138    if texts.len() != token_counts.len() {
139        return Err(AppError::Internal(anyhow::anyhow!(
140            "texts/token_counts length mismatch in controlled embedding"
141        )));
142    }
143
144    let mut results = Vec::with_capacity(texts.len());
145    for (start, end) in plan_controlled_batches(token_counts) {
146        if end - start == 1 {
147            results.push(embed_passage(embedder, texts[start])?);
148            continue;
149        }
150
151        results.extend(embed_passages_batch(
152            embedder,
153            &texts[start..end],
154            end - start,
155        )?);
156    }
157
158    Ok(results)
159}
160
161/// Embed multiple passages serially.
162///
163/// This path intentionally avoids ONNX batch inference for robustness when
164/// real-world Markdown chunks trigger pathological runtime behavior.
165pub fn embed_passages_serial<'a, I>(
166    embedder: &Mutex<TextEmbedding>,
167    texts: I,
168) -> Result<Vec<Vec<f32>>, AppError>
169where
170    I: IntoIterator<Item = &'a str>,
171{
172    let iter = texts.into_iter();
173    let (lower, _) = iter.size_hint();
174    let mut results = Vec::with_capacity(lower);
175    for text in iter {
176        results.push(embed_passage(embedder, text)?);
177    }
178    Ok(results)
179}
180
181fn plan_controlled_batches(token_counts: &[usize]) -> Vec<(usize, usize)> {
182    let mut batches = Vec::new();
183    let mut start = 0usize;
184
185    while start < token_counts.len() {
186        let mut end = start + 1;
187        let mut max_tokens = token_counts[start].max(1);
188
189        while end < token_counts.len() && end - start < REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS {
190            let candidate_max = max_tokens.max(token_counts[end].max(1));
191            let candidate_len = end + 1 - start;
192            if candidate_max * candidate_len > REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS {
193                break;
194            }
195            max_tokens = candidate_max;
196            end += 1;
197        }
198
199        batches.push((start, end));
200        start = end;
201    }
202
203    batches
204}
205
/// Reinterprets a slice of `f32` as its underlying bytes for sqlite-vec storage.
///
/// The returned slice borrows from `v`, is `4 * v.len()` bytes long, and holds
/// each float in the platform's native byte order.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    let byte_len = v.len() * std::mem::size_of::<f32>();
    let first_byte = v.as_ptr().cast::<u8>();
    // SAFETY: `first_byte` points to `byte_len` initialized bytes owned by `v`;
    // `f32` has no padding, `u8` has alignment 1, and the returned slice shares
    // the lifetime of the borrow of `v`.
    unsafe { std::slice::from_raw_parts(first_byte, byte_len) }
}
212
#[cfg(test)]
mod testes {
    use super::*;
    use crate::constants::{EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX};

    // --- f32_to_bytes tests (pure function, no model needed) ---

    #[test]
    fn f32_to_bytes_slice_vazio_retorna_vazio() {
        let v: Vec<f32> = vec![];
        assert_eq!(f32_to_bytes(&v), &[] as &[u8]);
    }

    #[test]
    fn f32_to_bytes_um_elemento_retorna_4_bytes() {
        let v = vec![1.0_f32];
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), 4);
        // round-trip: the 4 bytes must rebuild the original f32
        // (decoded as little-endian, so this assumes a little-endian target)
        let recovered = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        assert_eq!(recovered, 1.0_f32);
    }

    #[test]
    fn f32_to_bytes_comprimento_e_4x_elementos() {
        let v = vec![0.0_f32, 1.0, 2.0, 3.0];
        assert_eq!(f32_to_bytes(&v).len(), v.len() * 4);
    }

    #[test]
    fn f32_to_bytes_zero_codificado_como_4_zeros() {
        let v = vec![0.0_f32];
        assert_eq!(f32_to_bytes(&v), &[0u8, 0, 0, 0]);
    }

    #[test]
    fn f32_to_bytes_roundtrip_vetor_embedding_dim() {
        let v: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), EMBEDDING_DIM * 4);
        // rebuild and compare the first and last elements
        let first = f32::from_le_bytes(bytes[0..4].try_into().unwrap());
        assert!((first - 0.0_f32).abs() < 1e-6);
        let last_start = (EMBEDDING_DIM - 1) * 4;
        let last = f32::from_le_bytes(bytes[last_start..last_start + 4].try_into().unwrap());
        assert!((last - (EMBEDDING_DIM - 1) as f32 * 0.001).abs() < 1e-4);
    }

    // --- checks the prefixes used by the embedder (no model needed) ---

    #[test]
    fn passage_prefix_nao_vazio() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    #[test]
    fn query_prefix_nao_vazio() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    #[test]
    fn embedding_dim_e_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // --- tests with the real model (ignored in regular CI runs) ---

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passage_retorna_vetor_com_dimensao_correta() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_passage(embedder, "texto de teste").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_query_retorna_vetor_com_dimensao_correta() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_query(embedder, "consulta de teste").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passages_batch_retorna_um_vetor_por_texto() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let textos = ["primeiro", "segundo"];
        let results = embed_passages_batch(embedder, &textos, 2).unwrap();
        assert_eq!(results.len(), 2);
        for emb in &results {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }
    }

    // --- batch planning tests (pure functions, no model needed) ---

    #[test]
    fn controlled_batch_plan_respeita_orcamento() {
        assert_eq!(
            plan_controlled_batches(&[100, 100, 100, 100, 300, 300]),
            vec![(0, 4), (4, 5), (5, 6)]
        );
    }

    #[test]
    fn controlled_batch_count_retorna_um_para_chunk_unico() {
        assert_eq!(controlled_batch_count(&[350]), 1);
    }
}