1use crate::constants::{
8 EMBEDDING_DIM, EMBEDDING_MAX_TOKENS, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX,
9 REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS, REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS,
10};
11use crate::errors::AppError;
12use fastembed::{EmbeddingModel, ExecutionProviderDispatch, TextEmbedding, TextInitOptions};
13use ort::ep::CPU;
14use parking_lot::Mutex;
15use std::path::Path;
16use std::sync::OnceLock;
17
18static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
28
29pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
37 if let Some(m) = EMBEDDER.get() {
38 return Ok(m);
39 }
40
41 maybe_init_dynamic_ort(models_dir)?;
42
43 let cpu_ep: ExecutionProviderDispatch = CPU::default().with_arena_allocator(false).build();
59
60 let model = TextEmbedding::try_new(
61 TextInitOptions::new(EmbeddingModel::MultilingualE5Small)
62 .with_execution_providers(vec![cpu_ep])
63 .with_max_length(EMBEDDING_MAX_TOKENS)
64 .with_show_download_progress(true)
65 .with_cache_dir(models_dir.to_path_buf()),
66 )
67 .map_err(|e| AppError::Embedding(e.to_string()))?;
68 let _ = EMBEDDER.set(Mutex::new(model));
70 EMBEDDER.get().ok_or_else(|| {
71 AppError::Embedding(
72 "embedder OnceLock unexpectedly empty after set() (likely a racing initializer aborted before completion)"
73 .into(),
74 )
75 })
76}
77
/// Tries to initialise ONNX Runtime from a shared library on aarch64 GNU/Linux.
///
/// Candidate locations are probed in this order:
/// 1. the `ORT_DYLIB_PATH` environment variable, when set and non-empty;
/// 2. `libonnxruntime.so` beside the current executable;
/// 3. `lib/libonnxruntime.so` beside the current executable;
/// 4. `libonnxruntime.so` inside `models_dir`.
///
/// The first candidate that exists is written back to `ORT_DYLIB_PATH` and
/// handed to `ort::init_from`. When no candidate exists the function returns
/// `Ok(())` without initialising anything.
///
/// # Errors
///
/// Returns [`AppError::Embedding`] when `ort::init_from` rejects the library.
#[cfg(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu"))]
fn maybe_init_dynamic_ort(models_dir: &Path) -> Result<(), AppError> {
    use std::path::PathBuf;

    const DYLIB_NAME: &str = "libonnxruntime.so";

    let mut search_paths: Vec<PathBuf> = Vec::with_capacity(4);

    // An explicit, non-empty override always gets the first try.
    match std::env::var("ORT_DYLIB_PATH") {
        Ok(configured) if !configured.is_empty() => search_paths.push(PathBuf::from(configured)),
        _ => {}
    }

    if let Ok(exe) = std::env::current_exe() {
        if let Some(exe_dir) = exe.parent() {
            search_paths.extend([
                exe_dir.join(DYLIB_NAME),
                exe_dir.join("lib").join(DYLIB_NAME),
            ]);
        }
    }

    search_paths.push(models_dir.join(DYLIB_NAME));

    let Some(dylib) = search_paths.into_iter().find(|candidate| candidate.exists()) else {
        return Ok(());
    };

    std::env::set_var("ORT_DYLIB_PATH", &dylib);
    let builder = ort::init_from(&dylib).map_err(|e| AppError::Embedding(e.to_string()))?;
    // The outcome of `commit` is intentionally ignored, as before.
    let _ = builder.commit();
    Ok(())
}
111
112#[cfg(not(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu")))]
113fn maybe_init_dynamic_ort(_models_dir: &Path) -> Result<(), AppError> {
114 Ok(())
115}
116
117#[tracing::instrument(skip(embedder, text), fields(text_len = text.len()))]
122pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
123 let prefixed = format!("{PASSAGE_PREFIX}{text}");
124 let results = embedder
125 .lock()
126 .embed(vec![prefixed.as_str()], Some(1))
127 .map_err(|e| AppError::Embedding(e.to_string()))?;
128 let emb = results
129 .into_iter()
130 .next()
131 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
132 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
133 Ok(emb)
134}
135
136#[tracing::instrument(skip(embedder, text), fields(text_len = text.len()))]
141pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
142 let prefixed = format!("{QUERY_PREFIX}{text}");
143 let results = embedder
144 .lock()
145 .embed(vec![prefixed.as_str()], Some(1))
146 .map_err(|e| AppError::Embedding(e.to_string()))?;
147 let emb = results
148 .into_iter()
149 .next()
150 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
151 Ok(emb)
152}
153
154#[tracing::instrument(skip(embedder, texts), fields(batch_size = texts.len()))]
161pub fn embed_passages_batch(
162 embedder: &Mutex<TextEmbedding>,
163 texts: &[&str],
164 batch_size: usize,
165) -> Result<Vec<Vec<f32>>, AppError> {
166 let prefixed: Vec<String> = texts
167 .iter()
168 .map(|t| format!("{PASSAGE_PREFIX}{t}"))
169 .collect();
170 let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
171 let results = embedder
172 .lock()
173 .embed(strs, Some(batch_size.min(FASTEMBED_BATCH_SIZE)))
174 .map_err(|e| AppError::Embedding(e.to_string()))?;
175 for emb in &results {
176 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
177 }
178 Ok(results)
179}
180
181pub fn controlled_batch_count(token_counts: &[usize]) -> usize {
184 plan_controlled_batches(token_counts).len()
185}
186
187pub fn embed_passages_controlled(
195 embedder: &Mutex<TextEmbedding>,
196 texts: &[&str],
197 token_counts: &[usize],
198) -> Result<Vec<Vec<f32>>, AppError> {
199 if texts.len() != token_counts.len() {
200 return Err(AppError::Internal(anyhow::anyhow!(
201 "texts/token_counts length mismatch in controlled embedding"
202 )));
203 }
204
205 let mut results = Vec::with_capacity(texts.len());
206 for (start, end) in plan_controlled_batches(token_counts) {
207 if end - start == 1 {
208 results.push(embed_passage(embedder, texts[start])?);
209 continue;
210 }
211
212 results.extend(embed_passages_batch(
213 embedder,
214 &texts[start..end],
215 end - start,
216 )?);
217 }
218
219 Ok(results)
220}
221
222pub fn embed_passages_serial<'a, I>(
234 embedder: &Mutex<TextEmbedding>,
235 texts: I,
236) -> Result<Vec<Vec<f32>>, AppError>
237where
238 I: IntoIterator<Item = &'a str>,
239{
240 let iter = texts.into_iter();
241 let (lower, _) = iter.size_hint();
242 let mut results = Vec::with_capacity(lower);
243 for text in iter {
244 results.push(embed_passage(embedder, text)?);
245 }
246 Ok(results)
247}
248
249fn plan_controlled_batches(token_counts: &[usize]) -> Vec<(usize, usize)> {
250 let mut batches =
251 Vec::with_capacity((token_counts.len() / REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS).max(1));
252 let mut start = 0usize;
253
254 while start < token_counts.len() {
255 let mut end = start + 1;
256 let mut max_tokens = token_counts[start].max(1);
257
258 while end < token_counts.len() && end - start < REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS {
259 let candidate_max = max_tokens.max(token_counts[end].max(1));
260 let candidate_len = end + 1 - start;
261 if candidate_max * candidate_len > REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS {
262 break;
263 }
264 max_tokens = candidate_max;
265 end += 1;
266 }
267
268 batches.push((start, end));
269 start = end;
270 }
271
272 batches
273}
274
// `f32_to_bytes` exposes the in-memory representation of the floats, so the
// build is rejected outright on big-endian targets.
#[cfg(target_endian = "big")]
compile_error!(
    "sqlite-graphrag requires little-endian f32 layout for sqlite-vec compatibility. \
     Big-endian targets (PPC64, S390x) are not supported."
);

/// Views a slice of `f32` values as its underlying bytes without copying.
///
/// The returned slice borrows from `v` and is `4 * v.len()` bytes long, in the
/// target's native byte order.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    let byte_len = std::mem::size_of_val(v);
    let first_byte = v.as_ptr().cast::<u8>();
    // SAFETY: `first_byte` points at the start of `v`, which is valid for reads
    // of `byte_len` bytes for as long as `v` is borrowed; `u8` has alignment 1,
    // `f32` has no padding, and the result's lifetime is tied to `v`.
    unsafe { std::slice::from_raw_parts(first_byte, byte_len) }
}
296
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::{EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX};

    // ---- f32_to_bytes -------------------------------------------------------

    #[test]
    fn f32_to_bytes_empty_slice_returns_empty() {
        let empty: Vec<f32> = Vec::new();
        let expected: &[u8] = &[];
        assert_eq!(f32_to_bytes(&empty), expected);
    }

    #[test]
    fn f32_to_bytes_one_element_returns_4_bytes() {
        let input = vec![1.0_f32];
        let raw = f32_to_bytes(&input);
        assert_eq!(raw.len(), 4);
        // Decode as little-endian: the only byte order this crate supports.
        let decoded = f32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]);
        assert_eq!(decoded, 1.0_f32);
    }

    #[test]
    fn f32_to_bytes_length_is_4x_elements() {
        let input = vec![0.0_f32, 1.0, 2.0, 3.0];
        let raw = f32_to_bytes(&input);
        assert_eq!(raw.len(), input.len() * 4);
    }

    #[test]
    fn f32_to_bytes_zero_encoded_as_4_zeros() {
        let input = vec![0.0_f32];
        assert_eq!(f32_to_bytes(&input), &[0u8, 0, 0, 0]);
    }

    #[test]
    fn f32_to_bytes_roundtrip_vector_embedding_dim() {
        let input: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let raw = f32_to_bytes(&input);
        assert_eq!(raw.len(), EMBEDDING_DIM * 4);

        // Reads the little-endian f32 stored at the given byte offset.
        let read_at = |offset: usize| -> f32 {
            f32::from_le_bytes(raw[offset..offset + 4].try_into().unwrap())
        };

        let first = read_at(0);
        assert!((first - 0.0_f32).abs() < 1e-6);

        let last = read_at((EMBEDDING_DIM - 1) * 4);
        assert!((last - (EMBEDDING_DIM - 1) as f32 * 0.001).abs() < 1e-4);
    }

    // ---- constants ----------------------------------------------------------

    #[test]
    fn passage_prefix_not_empty() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    #[test]
    fn query_prefix_not_empty() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    #[test]
    fn embedding_dim_is_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // ---- model-backed tests (ignored by default) ----------------------------

    #[test]
    #[ignore = "requires ~600 MB model on disk; run with --include-ignored"]
    fn embed_passage_returns_vector_with_correct_dimension() {
        let models_dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(models_dir.path()).unwrap();
        let embedding = embed_passage(embedder, "test text").unwrap();
        assert_eq!(embedding.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requires ~600 MB model on disk; run with --include-ignored"]
    fn embed_query_returns_vector_with_correct_dimension() {
        let models_dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(models_dir.path()).unwrap();
        let embedding = embed_query(embedder, "test query").unwrap();
        assert_eq!(embedding.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requires ~600 MB model on disk; run with --include-ignored"]
    fn embed_passages_batch_returns_one_vector_per_text() {
        let models_dir = tempfile::tempdir().unwrap();
        let embedder = get_embedder(models_dir.path()).unwrap();
        let inputs = ["primeiro", "segundo"];
        let embeddings = embed_passages_batch(embedder, &inputs, 2).unwrap();
        assert_eq!(embeddings.len(), 2);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), EMBEDDING_DIM);
        }
    }

    // ---- batch planning -----------------------------------------------------

    #[test]
    fn controlled_batch_plan_respects_budget() {
        let plan = plan_controlled_batches(&[100, 100, 100, 100, 300, 300]);
        assert_eq!(plan, vec![(0, 4), (4, 5), (5, 6)]);
    }

    #[test]
    fn controlled_batch_count_returns_one_for_single_chunk() {
        assert_eq!(controlled_batch_count(&[350]), 1);
    }
}