1use crate::constants::{
7 EMBEDDING_DIM, EMBEDDING_MAX_TOKENS, FASTEMBED_BATCH_SIZE, PASSAGE_PREFIX, QUERY_PREFIX,
8 REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS, REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS,
9};
10use crate::errors::AppError;
11use fastembed::{EmbeddingModel, ExecutionProviderDispatch, TextEmbedding, TextInitOptions};
12use ort::execution_providers::CPU;
13use std::path::Path;
14use std::sync::{Mutex, OnceLock};
15
16static EMBEDDER: OnceLock<Mutex<TextEmbedding>> = OnceLock::new();
17
18pub fn get_embedder(models_dir: &Path) -> Result<&'static Mutex<TextEmbedding>, AppError> {
21 if let Some(m) = EMBEDDER.get() {
22 return Ok(m);
23 }
24
25 maybe_init_dynamic_ort(models_dir)?;
26
27 let cpu_ep: ExecutionProviderDispatch = CPU::default().with_arena_allocator(false).build();
43
44 let model = TextEmbedding::try_new(
45 TextInitOptions::new(EmbeddingModel::MultilingualE5Small)
46 .with_execution_providers(vec![cpu_ep])
47 .with_max_length(EMBEDDING_MAX_TOKENS)
48 .with_show_download_progress(true)
49 .with_cache_dir(models_dir.to_path_buf()),
50 )
51 .map_err(|e| AppError::Embedding(e.to_string()))?;
52 let _ = EMBEDDER.set(Mutex::new(model));
54 EMBEDDER.get().ok_or_else(|| {
55 AppError::Embedding(
56 "embedder OnceLock unexpectedly empty after set() (likely a racing initializer aborted before completion)"
57 .into(),
58 )
59 })
60}
61
/// Looks for a `libonnxruntime.so` to load dynamically and initializes `ort` from it.
///
/// Locations are probed in this order and the first one that exists is used:
/// 1. the non-empty value of the `ORT_DYLIB_PATH` environment variable,
/// 2. `libonnxruntime.so` next to the current executable,
/// 3. `lib/libonnxruntime.so` next to the current executable,
/// 4. `libonnxruntime.so` inside `models_dir`.
///
/// When a location is chosen, `ORT_DYLIB_PATH` is overwritten with it before
/// `ort::init_from` is called. When none exists the function does nothing.
///
/// # Errors
///
/// Returns [`AppError::Embedding`] if `ort::init_from` fails for the chosen
/// path; the remaining locations are not tried in that case.
#[cfg(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu"))]
fn maybe_init_dynamic_ort(models_dir: &Path) -> Result<(), AppError> {
    const LIB_NAME: &str = "libonnxruntime.so";

    let from_env = std::env::var("ORT_DYLIB_PATH")
        .ok()
        .filter(|raw| !raw.is_empty())
        .map(std::path::PathBuf::from);

    let beside_exe = std::env::current_exe().ok().and_then(|exe| {
        exe.parent()
            .map(|dir| [dir.join(LIB_NAME), dir.join("lib").join(LIB_NAME)])
    });

    // `find` probes lazily, so the filesystem is touched in priority order and
    // only until the first hit.
    let found = from_env
        .into_iter()
        .chain(beside_exe.into_iter().flatten())
        .chain(std::iter::once(models_dir.join(LIB_NAME)))
        .find(|candidate| candidate.exists());

    let Some(library) = found else {
        return Ok(());
    };

    std::env::set_var("ORT_DYLIB_PATH", &library);
    // The value returned by `commit()` is intentionally discarded.
    let _ = ort::init_from(&library)
        .map_err(|e| AppError::Embedding(e.to_string()))?
        .commit();
    Ok(())
}
95
96#[cfg(not(all(target_arch = "aarch64", target_os = "linux", target_env = "gnu")))]
97fn maybe_init_dynamic_ort(_models_dir: &Path) -> Result<(), AppError> {
98 Ok(())
99}
100
101pub fn embed_passage(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
102 let prefixed = format!("{PASSAGE_PREFIX}{text}");
103 let results = embedder
104 .lock()
105 .map_err(|e| AppError::Embedding(format!("embedder mutex poisoned: {e}")))?
106 .embed(vec![prefixed.as_str()], Some(1))
107 .map_err(|e| AppError::Embedding(e.to_string()))?;
108 let emb = results
109 .into_iter()
110 .next()
111 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
112 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
113 Ok(emb)
114}
115
116pub fn embed_query(embedder: &Mutex<TextEmbedding>, text: &str) -> Result<Vec<f32>, AppError> {
117 let prefixed = format!("{QUERY_PREFIX}{text}");
118 let results = embedder
119 .lock()
120 .map_err(|e| AppError::Embedding(format!("embedder mutex poisoned: {e}")))?
121 .embed(vec![prefixed.as_str()], Some(1))
122 .map_err(|e| AppError::Embedding(e.to_string()))?;
123 let emb = results
124 .into_iter()
125 .next()
126 .ok_or_else(|| AppError::Embedding("empty embedding result".into()))?;
127 Ok(emb)
128}
129
130pub fn embed_passages_batch(
131 embedder: &Mutex<TextEmbedding>,
132 texts: &[&str],
133 batch_size: usize,
134) -> Result<Vec<Vec<f32>>, AppError> {
135 let prefixed: Vec<String> = texts
136 .iter()
137 .map(|t| format!("{PASSAGE_PREFIX}{t}"))
138 .collect();
139 let strs: Vec<&str> = prefixed.iter().map(String::as_str).collect();
140 let results = embedder
141 .lock()
142 .map_err(|e| AppError::Embedding(format!("embedder mutex poisoned: {e}")))?
143 .embed(strs, Some(batch_size.min(FASTEMBED_BATCH_SIZE)))
144 .map_err(|e| AppError::Embedding(e.to_string()))?;
145 for emb in &results {
146 assert_eq!(emb.len(), EMBEDDING_DIM, "unexpected embedding dimension");
147 }
148 Ok(results)
149}
150
151pub fn controlled_batch_count(token_counts: &[usize]) -> usize {
152 plan_controlled_batches(token_counts).len()
153}
154
155pub fn embed_passages_controlled(
156 embedder: &Mutex<TextEmbedding>,
157 texts: &[&str],
158 token_counts: &[usize],
159) -> Result<Vec<Vec<f32>>, AppError> {
160 if texts.len() != token_counts.len() {
161 return Err(AppError::Internal(anyhow::anyhow!(
162 "texts/token_counts length mismatch in controlled embedding"
163 )));
164 }
165
166 let mut results = Vec::with_capacity(texts.len());
167 for (start, end) in plan_controlled_batches(token_counts) {
168 if end - start == 1 {
169 results.push(embed_passage(embedder, texts[start])?);
170 continue;
171 }
172
173 results.extend(embed_passages_batch(
174 embedder,
175 &texts[start..end],
176 end - start,
177 )?);
178 }
179
180 Ok(results)
181}
182
183pub fn embed_passages_serial<'a, I>(
188 embedder: &Mutex<TextEmbedding>,
189 texts: I,
190) -> Result<Vec<Vec<f32>>, AppError>
191where
192 I: IntoIterator<Item = &'a str>,
193{
194 let iter = texts.into_iter();
195 let (lower, _) = iter.size_hint();
196 let mut results = Vec::with_capacity(lower);
197 for text in iter {
198 results.push(embed_passage(embedder, text)?);
199 }
200 Ok(results)
201}
202
203fn plan_controlled_batches(token_counts: &[usize]) -> Vec<(usize, usize)> {
204 let mut batches = Vec::new();
205 let mut start = 0usize;
206
207 while start < token_counts.len() {
208 let mut end = start + 1;
209 let mut max_tokens = token_counts[start].max(1);
210
211 while end < token_counts.len() && end - start < REMEMBER_MAX_CONTROLLED_BATCH_CHUNKS {
212 let candidate_max = max_tokens.max(token_counts[end].max(1));
213 let candidate_len = end + 1 - start;
214 if candidate_max * candidate_len > REMEMBER_MAX_CONTROLLED_BATCH_PADDED_TOKENS {
215 break;
216 }
217 max_tokens = candidate_max;
218 end += 1;
219 }
220
221 batches.push((start, end));
222 start = end;
223 }
224
225 batches
226}
227
/// Reinterprets a slice of `f32` as its raw bytes without copying.
///
/// The bytes are in the platform's native byte order and the returned slice
/// borrows from `v`; its length is `4 * v.len()`.
pub fn f32_to_bytes(v: &[f32]) -> &[u8] {
    let byte_len = v.len() * std::mem::size_of::<f32>();
    let first_byte = v.as_ptr().cast::<u8>();
    // SAFETY: `first_byte` comes from a live `&[f32]`, so it is non-null and
    // valid for reads of `byte_len` bytes; `u8` has alignment 1 and every byte
    // of an `f32` is initialized. The result shares `v`'s lifetime, so the
    // backing memory cannot be freed or mutated while it is in use.
    unsafe { std::slice::from_raw_parts(first_byte, byte_len) }
}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::{EMBEDDING_DIM, PASSAGE_PREFIX, QUERY_PREFIX};

    // ---- f32_to_bytes: pure byte-view checks, no model required ----

    #[test]
    fn f32_to_bytes_empty_slice_returns_empty() {
        let empty: [f32; 0] = [];
        assert!(f32_to_bytes(&empty).is_empty());
    }

    #[test]
    fn f32_to_bytes_one_element_returns_4_bytes() {
        let single = [1.0_f32];
        let bytes = f32_to_bytes(&single);
        assert_eq!(bytes.len(), 4);

        let mut raw = [0u8; 4];
        raw.copy_from_slice(&bytes[..4]);
        assert_eq!(f32::from_le_bytes(raw), 1.0_f32);
    }

    #[test]
    fn f32_to_bytes_length_is_4x_elements() {
        let values = [0.0_f32, 1.0, 2.0, 3.0];
        assert_eq!(f32_to_bytes(&values).len(), 4 * values.len());
    }

    #[test]
    fn f32_to_bytes_zero_encoded_as_4_zeros() {
        let zero = [0.0_f32];
        assert_eq!(f32_to_bytes(&zero), &[0u8; 4]);
    }

    #[test]
    fn f32_to_bytes_roundtrip_vector_embedding_dim() {
        let values: Vec<f32> = (0..EMBEDDING_DIM).map(|i| i as f32 * 0.001).collect();
        let bytes = f32_to_bytes(&values);
        assert_eq!(bytes.len(), EMBEDDING_DIM * 4);

        // Decodes the `index`-th float back out of the byte view.
        let decode = |index: usize| {
            let offset = index * 4;
            f32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
        };

        assert!((decode(0) - 0.0_f32).abs() < 1e-6);
        let last_index = EMBEDDING_DIM - 1;
        assert!((decode(last_index) - last_index as f32 * 0.001).abs() < 1e-4);
    }

    // ---- constants this module relies on ----

    #[test]
    fn passage_prefix_not_empty() {
        assert_eq!(PASSAGE_PREFIX, "passage: ");
    }

    #[test]
    fn query_prefix_not_empty() {
        assert_eq!(QUERY_PREFIX, "query: ");
    }

    #[test]
    fn embedding_dim_is_384() {
        assert_eq!(EMBEDDING_DIM, 384);
    }

    // ---- model-backed tests: ignored by default because they need the model ----

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passage_returns_vector_with_correct_dimension() {
        let cache_dir = tempfile::tempdir().unwrap();
        let model = get_embedder(cache_dir.path()).unwrap();
        let vector = embed_passage(model, "texto de teste").unwrap();
        assert_eq!(vector.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_query_returns_vector_with_correct_dimension() {
        let cache_dir = tempfile::tempdir().unwrap();
        let model = get_embedder(cache_dir.path()).unwrap();
        let vector = embed_query(model, "consulta de teste").unwrap();
        assert_eq!(vector.len(), EMBEDDING_DIM);
    }

    #[test]
    #[ignore = "requer modelo ~600 MB em disco; executar com --include-ignored"]
    fn embed_passages_batch_returns_one_vector_per_text() {
        let cache_dir = tempfile::tempdir().unwrap();
        let model = get_embedder(cache_dir.path()).unwrap();
        let inputs = ["primeiro", "segundo"];
        let vectors = embed_passages_batch(model, &inputs, 2).unwrap();
        assert_eq!(vectors.len(), 2);
        for vector in &vectors {
            assert_eq!(vector.len(), EMBEDDING_DIM);
        }
    }

    // ---- controlled batch planning ----

    #[test]
    fn controlled_batch_plan_respects_budget() {
        let token_counts = [100, 100, 100, 100, 300, 300];
        let expected = vec![(0, 4), (4, 5), (5, 6)];
        assert_eq!(plan_controlled_batches(&token_counts), expected);
    }

    #[test]
    fn controlled_batch_count_returns_one_for_single_chunk() {
        let single_chunk = [350];
        assert_eq!(controlled_batch_count(&single_chunk), 1);
    }
}