sqlite_graphrag/
embedder.rs1use crate::constants::{EMBEDDING_DIM, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX};
2use crate::errors::AppError;
3use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
4use std::path::Path;
5use std::sync::{Mutex, OnceLock};
6
7static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
8
9pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
12 if let Some(m) = EMBEDDER.get() {
13 return Ok(m);
14 }
15 let model = TextEmbedding::try_new(
16 InitOptions::new(EmbeddingModel::MultilingualE5Small)
17 .with_show_download_progress(true)
18 .with_cache_dir(models_dir.to_path_buf()),
19 )
20 .map_err(|e| AppError::Embedding(e.to_string()))?;
21 let _ = EMBEDDER.set(Mutex::new(model));
23 Ok(EMBEDDER.get().expect("just set above"))
24}
25
26pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
27 let prefixed = format!("{PASSAGE_PREFIX}{text}");
28 let results = embedder
29 .lock()
30 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
31 .embed(vec![prefixed.as_str()], Some(1))
32 .map_err(|e| AppError::Embedding(e.to_string()))?;
33 let emb = results
34 .into_iter()
35 .next()
36 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
37 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
38 Ok(emb)
39}
40
41pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
42 let prefixed = format!("{QUERY_PREFIX}{text}");
43 let results = embedder
44 .lock()
45 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
46 .embed(vec![prefixed.as_str()], Some(1))
47 .map_err(|e| AppError::Embedding(e.to_string()))?;
48 let emb = results
49 .into_iter()
50 .next()
51 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
52 Ok(emb)
53}
54
55pub fn embed_passages_batch(
56 embedder: &Mutex<TextEmbedding>,
57 texts: &[String],
58) -> Result<Vec<Vec<f32>>, AppError> {
59 let prefixed: Vec<String> = texts
60 .iter()
61 .map(|t| format!("{PASSAGE_PREFIX}{t}"))
62 .collect();
63 let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
64 let results = embedder
65 .lock()
66 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
67 .embed(strs, Some(FASTEMBED_BATCH_SIZE))
68 .map_err(|e| AppError::Embedding(e.to_string()))?;
69 for emb in &results {
70 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
71 }
72 Ok(results)
73}
74
75pub fn embed_passages_serial<'a, I>(
80 embedder: &Mutex<TextEmbedding>,
81 texts: I,
82) -> Result<Vec<Vec<f32>>, AppError>
83where
84 I: IntoIterator<Item = &'a str>,
85{
86 let iter = texts.into_iter();
87 let (lower, _) = iter.size_hint();
88 let mut results = Vec::with_capacity(lower);
89 for text in iter {
90 results.push(embed_passage(embedder, text)?);
91 }
92 Ok(results)
93}
94
/// Reinterprets a slice of `f32` as its underlying bytes without copying.
///
/// The returned slice borrows from `v` and is `4 * v.len()` bytes long. The
/// bytes are exposed exactly as they sit in memory, i.e. in the target's
/// native byte order.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    let byte_len = std::mem::size_of_val(v);
    let start = v.as_ptr().cast::<u8>();
    // SAFETY: `start` points to `byte_len` initialised bytes owned by `v`
    // (every bit pattern of an `f32` is a valid sequence of `u8`), `u8` has
    // alignment 1, and the returned borrow shares `v`'s lifetime so the
    // memory cannot be mutated or freed while the byte view is alive.
    unsafe { std::slice::from_raw_parts(start, byte_len) }
}
101
#[cfg(test)]
mod testes {
    use super::*;
    use crate::constants::{EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX};

    // --- f32_to_bytes -----------------------------------------------------
    // `f32_to_bytes` exposes the floats exactly as they sit in memory, so the
    // tests decode with native-endian helpers; hard-coding little-endian
    // would make them fail on big-endian targets.

    #[test]
    fn f32_to_bytes_slice_vazio_retorna_vazio() {
        let v: Vec<f32> = vec![];
        assert_eq!(f32_to_bytes(&v), &[] as &[u8]);
    }

    #[test]
    fn f32_to_bytes_um_elemento_retorna_4_bytes() {
        let v = vec![1.0_f32];
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), 4);
        let recovered = f32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        assert_eq!(recovered, 1.0_f32);
    }

    #[test]
    fn f32_to_bytes_comprimento_e_4x_elementos() {
        let v = vec![0.0_f32, 1.0, 2.0, 3.0];
        assert_eq!(f32_to_bytes(&v).len(), v.len() * 4);
    }

    #[test]
    fn f32_to_bytes_zero_codificado_como_4_zeros() {
        // +0.0 is all-zero bits, so this holds regardless of byte order.
        let v = vec![0.0_f32];
        assert_eq!(f32_to_bytes(&v), &[0u8, 0, 0, 0]);
    }

    #[test]
    fn f32_to_bytes_roundtrip_vetor_embedding_dim() {
        let v: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let bytes = f32_to_bytes(&v);
        assert_eq!(bytes.len(), EMBEDDING_DIM * 4);
        let first = f32::from_ne_bytes(bytes[0..4].try_into().unwrap());
        assert!((first - 0.0_f32).abs() < 1e-6);
        let last_start = (EMBEDDING_DIM - 1) * 4;
        let last = f32::from_ne_bytes(bytes[last_start..last_start + 4].try_into().unwrap());
        assert!((last - (EMBEDDING_DIM - 1) as f32 * 0.001).abs() < 1e-4);
    }

    // --- constants ----------------------------------------------------------
    // These pin the exact values the embedding functions rely on.

    #[test]
    fn passage_prefix_nao_vazio() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    #[test]
    fn query_prefix_nao_vazio() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    #[test]
    fn embedding_dim_e_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // --- model-backed tests -------------------------------------------------
    // Ignored by default because they need the real model on disk.

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passage_retorna_vetor_com_dimensao_correta() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_passage(embedder, "texto de teste").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_query_retorna_vetor_com_dimensao_correta() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let result = embed_query(embedder, "consulta de teste").unwrap();
        assert_eq!(result.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passages_batch_retorna_um_vetor_por_texto() {
        let dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(dir.path()).unwrap();
        let textos = vec!["primeiro".to_string(), "segundo".to_string()];
        let results = embed_passages_batch(embedder, &textos).unwrap();
        assert_eq!(results.len(), 2);
        for emb in &results {
            assert_eq!(emb.len(), EMBEDDING_DIM);
        }
    }
}