Skip to main content

sqlite_graphrag/
embedder.rs

1use crate::constants::{EMBEDDING_DIM, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX};
2use crate::errors::AppError;
3use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
4use std::path::Path;
5use std::sync::{Mutex, OnceLock};
6
7static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
8
9/// Returns the process-wide singleton embedder, initializing it on first call.
10/// Subsequent calls return the cached instance regardless of `models_dir`.
11pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
12    if let Some(m) = EMBEDDER.get() {
13        return Ok(m);
14    }
15    let model = TextEmbedding::try_new(
16        InitOptions::new(EmbeddingModel::MultilingualE5Small)
17            .with_show_download_progress(true)
18            .with_cache_dir(models_dir.to_path_buf()),
19    )
20    .map_err(|e| AppError::Embedding(e.to_string()))?;
21    // If another thread raced and won, discard our instance and return theirs.
22    let _ = EMBEDDER.set(Mutex::new(model));
23    Ok(EMBEDDER.get().expect("just set above"))
24}
25
26pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
27    let prefixed = format!("{PASSAGE_PREFIX}{text}");
28    let results = embedder
29        .lock()
30        .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
31        .embed(vec![prefixed.as_str()], Some(1))
32        .map_err(|e| AppError::Embedding(e.to_string()))?;
33    let emb = results
34        .into_iter()
35        .next()
36        .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
37    assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
38    Ok(emb)
39}
40
41pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
42    let prefixed = format!("{QUERY_PREFIX}{text}");
43    let results = embedder
44        .lock()
45        .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
46        .embed(vec![prefixed.as_str()], Some(1))
47        .map_err(|e| AppError::Embedding(e.to_string()))?;
48    let emb = results
49        .into_iter()
50        .next()
51        .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
52    Ok(emb)
53}
54
55pub fn embed_passages_batch(
56    embedder: &Mutex<TextEmbedding>,
57    texts: &[String],
58) -> Result<Vec<Vec<f32>>, AppError> {
59    let prefixed: Vec<String> = texts
60        .iter()
61        .map(|t| format!("{PASSAGE_PREFIX}{t}"))
62        .collect();
63    let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
64    let results = embedder
65        .lock()
66        .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
67        .embed(strs, Some(FASTEMBED_BATCH_SIZE))
68        .map_err(|e| AppError::Embedding(e.to_string()))?;
69    for emb in &results {
70        assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
71    }
72    Ok(results)
73}
74
75/// Embed multiple passages serially.
76///
77/// This path intentionally avoids ONNX batch inference for robustness when
78/// real-world Markdown chunks trigger pathological runtime behavior.
79pub fn embed_passages_serial<'a, I>(
80    embedder: &Mutex<TextEmbedding>,
81    texts: I,
82) -> Result<Vec<Vec<f32>>, AppError>
83where
84    I: IntoIterator<Item = &'a str>,
85{
86    let iter = texts.into_iter();
87    let (lower, _) = iter.size_hint();
88    let mut results = Vec::with_capacity(lower);
89    for text in iter {
90        results.push(embed_passage(embedder, text)?);
91    }
92    Ok(results)
93}
94
/// Reinterprets a slice of `f32` as its raw bytes for sqlite-vec storage.
///
/// No data is copied: the returned slice borrows the same memory as `v`, is
/// four bytes per element long, and uses the platform's native byte order.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    let byte_len = std::mem::size_of_val(v);
    let start = v.as_ptr().cast::<u8>();
    // SAFETY: `start` points to `byte_len` initialized bytes owned by `v`;
    // `f32` has no padding and every bit pattern is a valid `u8`; `u8` has
    // alignment 1; and the returned borrow shares `v`'s lifetime.
    unsafe { std::slice::from_raw_parts(start, byte_len) }
}
101
/// Unit tests for the embedder module.
///
/// The pure-function and constant tests run everywhere; the tests that load
/// the real model are `#[ignore]`d and must be requested explicitly.
#[cfg(test)]
mod testes {
    use super::*;
    use crate::constants::{EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX};

    // --- f32_to_bytes tests (pure function, no model needed) ---
    //
    // `f32_to_bytes` exposes the in-memory representation, so round-trips are
    // decoded with `from_ne_bytes` to stay correct on big-endian targets too.

    #[test]
    fn f32_to_bytes_slice_vazio_retorna_vazio() {
        let v: Vec<f32> = vec![];
        assert_eq!(f32_to_bytes(&v), &[] as &[u8]);
    }

    #[test]
    fn f32_to_bytes_um_elemento_retorna_4_bytes() {
        let v = vec![1.0_f32];
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), 4);
        // Round-trip: the 4 bytes must rebuild the original f32.
        let recovered = f32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        assert_eq!(recovered, 1.0_f32);
    }

    #[test]
    fn f32_to_bytes_comprimento_e_4x_elementos() {
        let v = vec![0.0_f32, 1.0, 2.0, 3.0];
        assert_eq!(f32_to_bytes(&v).len(), v.len() * 4);
    }

    #[test]
    fn f32_to_bytes_zero_codificado_como_4_zeros() {
        let v = vec![0.0_f32];
        assert_eq!(f32_to_bytes(&v), &[0u8, 0, 0, 0]);
    }

    #[test]
    fn f32_to_bytes_roundtrip_vetor_embedding_dim() {
        let v: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), EMBEDDING_DIM * 4);
        // Rebuild and compare the first and last elements.
        let first = f32::from_ne_bytes(bytes[0..4].try_into().unwrap());
        assert!((first - 0.0_f32).abs() < 1e-6);
        let last_start = (EMBEDDING_DIM - 1) * 4;
        let last = f32::from_ne_bytes(bytes[last_start..last_start + 4].try_into().unwrap());
        assert!((last - (EMBEDDING_DIM - 1) as f32 * 0.001).abs() < 1e-4);
    }

    // --- checks on the prefixes and dimension used by the embedder (no model needed) ---

    #[test]
    fn passage_prefix_nao_vazio() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    #[test]
    fn query_prefix_nao_vazio() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    #[test]
    fn embedding_dim_e_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // --- tests that use the real model (ignored in normal CI runs) ---

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passage_retorna_vetor_com_dimensao_correta() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_passage(embedder, "texto de teste").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_query_retorna_vetor_com_dimensao_correta() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_query(embedder, "consulta de teste").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passages_batch_retorna_um_vetor_por_texto() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let textos = vec!["primeiro".to_string(), "segundo".to_string()];
        let results = embed_passages_batch(embedder, &textos).unwrap();
        assert_eq!(results.len(), 2);
        for emb in &results {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }
    }
}