sqlite_graphrag/
embedder.rs1use crate::constants::{EMBEDDING_DIM, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX};
2use crate::errors::AppError;
3use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
4use std::path::Path;
5use std::sync::{Mutex, OnceLock};
6
7static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
8
9pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
12 if let Some(m) = EMBEDDER.get() {
13 return Ok(m);
14 }
15 let model = TextEmbedding::try_new(
16 InitOptions::new(EmbeddingModel::MultilingualE5Small)
17 .with_show_download_progress(true)
18 .with_cache_dir(models_dir.to_path_buf()),
19 )
20 .map_err(|e| AppError::Embedding(e.to_string()))?;
21 let _ = EMBEDDER.set(Mutex::new(model));
23 Ok(EMBEDDER.get().expect("just set above"))
24}
25
26pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
27 let prefixed = format!("{PASSAGE_PREFIX}{text}");
28 let results = embedder
29 .lock()
30 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
31 .embed(vec![prefixed.as_str()], Some(1))
32 .map_err(|e| AppError::Embedding(e.to_string()))?;
33 let emb = results
34 .into_iter()
35 .next()
36 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
37 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
38 Ok(emb)
39}
40
41pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
42 let prefixed = format!("{QUERY_PREFIX}{text}");
43 let results = embedder
44 .lock()
45 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
46 .embed(vec![prefixed.as_str()], Some(1))
47 .map_err(|e| AppError::Embedding(e.to_string()))?;
48 let emb = results
49 .into_iter()
50 .next()
51 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
52 Ok(emb)
53}
54
55pub fn embed_passages_batch(
56 embedder: &Mutex<TextEmbedding>,
57 texts: &[String],
58) -> Result<Vec<Vec<f32>>, AppError> {
59 let prefixed: Vec<String> = texts
60 .iter()
61 .map(|t| format!("{PASSAGE_PREFIX}{t}"))
62 .collect();
63 let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
64 let results = embedder
65 .lock()
66 .map_err(|e| AppError::Embedding(format!("lock poisoned: {e}")))?
67 .embed(strs, Some(FASTEMBED_BATCH_SIZE))
68 .map_err(|e| AppError::Embedding(e.to_string()))?;
69 for emb in &results {
70 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
71 }
72 Ok(results)
73}
74
/// Views a slice of `f32` values as its underlying bytes, without copying.
///
/// The returned slice borrows from `v` and is `size_of_val(v)` bytes long
/// (four bytes per element). The bytes are the in-memory representation of the
/// floats, so their order follows the target platform's native endianness.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    let byte_len = std::mem::size_of_val(v);
    let start = v.as_ptr().cast::<u8>();
    // SAFETY: `start` comes from a valid `&[f32]`, so it is non-null and points
    // to `byte_len` initialized bytes within a single allocation. `u8` has an
    // alignment of 1 and `f32` has no padding bytes, so every byte is a valid
    // `u8`. The returned slice shares the lifetime of `v`, which keeps the
    // memory alive and immutable for as long as the bytes are borrowed.
    unsafe { std::slice::from_raw_parts(start, byte_len) }
}
81
#[cfg(test)]
mod testes {
    use super::*;
    use crate::constants::{EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX};

    // --- f32_to_bytes: pure byte-view checks, no model required ---

    /// An empty float slice yields an empty byte slice.
    #[test]
    fn f32_to_bytes_slice_vazio_retorna_vazio() {
        let empty: [f32; 0] = [];
        assert!(f32_to_bytes(&empty).is_empty());
    }

    /// A single float yields four bytes that decode back to the same value.
    #[test]
    fn f32_to_bytes_um_elemento_retorna_4_bytes() {
        let values = [1.0_f32];
        let raw = f32_to_bytes(&values);
        assert_eq!(raw.len(), 4);
        let decoded = f32::from_le_bytes(raw.try_into().unwrap());
        assert_eq!(decoded, 1.0_f32);
    }

    /// The byte length is always four times the element count.
    #[test]
    fn f32_to_bytes_comprimento_e_4x_elementos() {
        let values = [0.0_f32, 1.0, 2.0, 3.0];
        assert_eq!(f32_to_bytes(&values).len(), 4 * values.len());
    }

    /// Positive zero is encoded as four zero bytes.
    #[test]
    fn f32_to_bytes_zero_codificado_como_4_zeros() {
        let zero = [0.0_f32];
        assert_eq!(f32_to_bytes(&zero), &[0u8; 4]);
    }

    /// A vector of `EMBEDDING_DIM` floats keeps its first and last values
    /// when decoded from the byte view.
    #[test]
    fn f32_to_bytes_roundtrip_vetor_embedding_dim() {
        let values: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let raw = f32_to_bytes(&values);
        assert_eq!(raw.len(), EMBEDDING_DIM * 4);

        // Decodes the float stored at element position `index`.
        let decode_at = |index: usize| -> f32 {
            let start = index * 4;
            f32::from_le_bytes(raw[start..start + 4].try_into().unwrap())
        };

        let first = decode_at(0);
        assert!((first - 0.0_f32).abs() < 1e-6);
        let last = decode_at(EMBEDDING_DIM - 1);
        assert!((last - (EMBEDDING_DIM - 1) as f32 * 0.001).abs() < 1e-4);
    }

    // --- constants: pin the values the embedding functions rely on ---

    #[test]
    fn passage_prefix_nao_vazio() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    #[test]
    fn query_prefix_nao_vazio() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    #[test]
    fn embedding_dim_e_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // --- model-backed tests: ignored by default ---

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passage_retorna_vetor_com_dimensao_correta() {
        let cache = tempfile::tempdir().unwrap();
        let model = get_embedder(cache.path()).unwrap();
        let vector = embed_passage(model, "texto de teste").unwrap();
        assert_eq!(vector.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_query_retorna_vetor_com_dimensao_correta() {
        let cache = tempfile::tempdir().unwrap();
        let model = get_embedder(cache.path()).unwrap();
        let vector = embed_query(model, "consulta de teste").unwrap();
        assert_eq!(vector.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passages_batch_retorna_um_vetor_por_texto() {
        let cache = tempfile::tempdir().unwrap();
        let model = get_embedder(cache.path()).unwrap();
        let inputs = ["primeiro", "segundo"].map(String::from);
        let vectors = embed_passages_batch(model, &inputs).unwrap();
        assert_eq!(vectors.len(), 2);
        assert!(vectors.iter().all(|emb| emb.len() == EMBEDDING_DIM));
    }
}