1use std::fs;
2use std::path::{Path, PathBuf};
3use std::sync::{Arc, Mutex};
4
5use hf_hub::api::sync::ApiBuilder;
6use hf_hub::{Repo, RepoType};
7use ndarray::{Array2, ArrayView2, ArrayView3, Axis, Ix2, Ix3};
8use ort::session::Session;
9use ort::value::Tensor;
10use serde::Deserialize;
11use tokenizers::Tokenizer;
12use tokenizers::tokenizer::{PaddingParams, PaddingStrategy, TruncationParams};
13
// Remote model selection.
/// Hugging Face repository used for any asset without a local override.
const DEFAULT_MODEL_REPO: &str = "sentence-transformers/all-MiniLM-L12-v2";
/// Revision of [`DEFAULT_MODEL_REPO`] to fetch.
const DEFAULT_MODEL_REVISION: &str = "main";

// Asset locations, relative to the repository root.
/// ONNX export of the transformer.
const DEFAULT_MODEL_FILE: &str = "onnx/model.onnx";
/// Serialized tokenizer definition.
const DEFAULT_TOKENIZER_FILE: &str = "tokenizer.json";
/// Pooling-module settings (pooling mode and embedding width).
const DEFAULT_POOLING_CONFIG_FILE: &str = "1_Pooling/config.json";
/// Transformer settings (hidden size and positional capacity).
const DEFAULT_TRANSFORMER_CONFIG_FILE: &str = "config.json";

/// Requested token budget per sequence; further capped by the model's
/// `max_position_embeddings` when the transformer config declares it.
const DEFAULT_MAX_LENGTH: usize = 128;
/// Tokenizer output for one batch: `(input_ids, attention_mask, token_type_ids)`,
/// each a `(batch, sequence)` matrix. `token_type_ids` is `None` when the model
/// does not take that input.
type EncodedInputs = (Array2<i64>, Array2<i64>, Option<Array2<i64>>);
22
23#[allow(async_fn_in_trait)]
24pub trait EmbeddingsProvider {
26 async fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
27
28 async fn embed(&mut self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
29 let mut embeddings = self.embed_batch(&[text.to_owned()]).await?;
30 embeddings.pop().ok_or(EmbeddingError::MissingOutput(
31 "no embeddings returned".to_string(),
32 ))
33 }
34}
35
/// Settings for building an [`OrtEmbedder`].
///
/// Each asset is taken from its `local_*_path` override when set; if any
/// override is missing, a Hugging Face Hub client is created and the missing
/// assets are fetched from `model_repo` at `model_revision`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EmbeddingsConfig {
    /// Hugging Face repository id to download assets from.
    pub model_repo: String,
    /// Revision of `model_repo` to fetch.
    pub model_revision: String,
    /// Path of the ONNX model inside the repository.
    pub model_file: String,
    /// Path of the tokenizer file inside the repository.
    pub tokenizer_file: String,
    /// Path of the pooling config inside the repository.
    pub pooling_config_file: String,
    /// Path of the transformer config inside the repository.
    pub transformer_config_file: String,
    /// Requested maximum tokens per sequence; capped by the transformer's
    /// `max_position_embeddings` when present.
    pub max_length: usize,
    /// Whether to L2-normalize each output embedding.
    pub normalize: bool,
    /// ONNX Runtime intra-op thread count; `None` leaves the runtime default.
    pub intra_threads: Option<usize>,
    /// Hub cache directory; `None` uses the builder returned by `ApiBuilder::from_env`.
    pub cache_dir: Option<PathBuf>,
    /// Local ONNX model to use instead of downloading `model_file`.
    pub local_model_path: Option<PathBuf>,
    /// Local tokenizer to use instead of downloading `tokenizer_file`.
    pub local_tokenizer_path: Option<PathBuf>,
    /// Local pooling config to use instead of downloading `pooling_config_file`.
    pub local_pooling_config_path: Option<PathBuf>,
    /// Local transformer config to use instead of downloading `transformer_config_file`.
    pub local_transformer_config_path: Option<PathBuf>,
    /// Graph input name for token ids; `None` means `"input_ids"`.
    pub input_ids_name: Option<String>,
    /// Graph input name for the attention mask; `None` means `"attention_mask"`.
    pub attention_mask_name: Option<String>,
    /// Graph input name for token type ids; `None` means `"token_type_ids"`
    /// (used only if the model declares it).
    pub token_type_ids_name: Option<String>,
    /// Graph output to read; `None` uses the model's first declared output.
    pub output_name: Option<String>,
}
76
77impl Default for EmbeddingsConfig {
78 fn default() -> Self {
79 Self {
80 model_repo: DEFAULT_MODEL_REPO.to_string(),
81 model_revision: DEFAULT_MODEL_REVISION.to_string(),
82 model_file: DEFAULT_MODEL_FILE.to_string(),
83 tokenizer_file: DEFAULT_TOKENIZER_FILE.to_string(),
84 pooling_config_file: DEFAULT_POOLING_CONFIG_FILE.to_string(),
85 transformer_config_file: DEFAULT_TRANSFORMER_CONFIG_FILE.to_string(),
86 max_length: DEFAULT_MAX_LENGTH,
87 normalize: true,
88 intra_threads: None,
89 cache_dir: None,
90 local_model_path: None,
91 local_tokenizer_path: None,
92 local_pooling_config_path: None,
93 local_transformer_config_path: None,
94 input_ids_name: None,
95 attention_mask_name: None,
96 token_type_ids_name: None,
97 output_name: None,
98 }
99 }
100}
101
102#[derive(Debug)]
103pub struct OrtEmbedder {
105 inner: Arc<Mutex<OrtEmbedderInner>>,
106 max_length: usize,
107}
108
109impl OrtEmbedder {
110 pub fn new(config: EmbeddingsConfig) -> Result<Self, EmbeddingError> {
111 let assets = resolve_model_assets(&config)?;
112 let pooling_config = read_json::<PoolingConfig>(&assets.pooling_config_path)?;
113 validate_pooling_config(&pooling_config)?;
114
115 let transformer_config = read_json::<TransformerConfig>(&assets.transformer_config_path)?;
116 let expected_embedding_size = pooling_config
117 .word_embedding_dimension
118 .or(transformer_config.hidden_size);
119 let max_length = transformer_config
120 .max_position_embeddings
121 .map(|value| value.min(config.max_length))
122 .unwrap_or(config.max_length);
123
124 let tokenizer = load_tokenizer(&assets.tokenizer_path, max_length)?;
125 let session = load_session(&assets.model_path, config.intra_threads)?;
126 let input_names = SessionInputNames::from_session(
127 &session,
128 config.input_ids_name.as_deref(),
129 config.attention_mask_name.as_deref(),
130 config.token_type_ids_name.as_deref(),
131 )?;
132 let output_name = select_output_name(&session, config.output_name.as_deref())?;
133
134 Ok(Self {
135 inner: Arc::new(Mutex::new(OrtEmbedderInner {
136 tokenizer,
137 session,
138 input_names,
139 output_name,
140 normalize: config.normalize,
141 expected_embedding_size,
142 })),
143 max_length,
144 })
145 }
146
147 pub fn max_length(&self) -> usize {
148 self.max_length
149 }
150
151 pub fn expected_embedding_size(&self) -> Option<usize> {
152 self.inner
153 .lock()
154 .ok()
155 .and_then(|inner| inner.expected_embedding_size)
156 }
157
158 pub fn chunk_text(
159 &self,
160 text: &str,
161 overlap_tokens: usize,
162 ) -> Result<Vec<String>, EmbeddingError> {
163 if text.trim().is_empty() {
164 return Ok(Vec::new());
165 }
166
167 let inner = self
168 .inner
169 .lock()
170 .map_err(|error| EmbeddingError::State(format!("embedder state poisoned: {error}")))?;
171 inner.chunk_text(text, self.max_length, overlap_tokens)
172 }
173}
174
/// Inference state shared behind [`OrtEmbedder`]'s `Arc<Mutex<_>>`.
#[derive(Debug)]
struct OrtEmbedderInner {
    /// Tokenizer configured by `load_tokenizer` (truncation to the effective
    /// max length, batch-longest padding).
    tokenizer: Tokenizer,
    /// ONNX Runtime session for the embedding model.
    session: Session,
    /// Graph input names resolved at construction.
    input_names: SessionInputNames,
    /// Output to read; `None` falls back to the first output returned by the session.
    output_name: Option<String>,
    /// Whether each embedding is L2-normalized.
    normalize: bool,
    /// Required embedding width, if known; mismatching outputs are rejected.
    expected_embedding_size: Option<usize>,
}
184
185impl OrtEmbedderInner {
186 fn chunk_text(
187 &self,
188 text: &str,
189 max_length: usize,
190 overlap_tokens: usize,
191 ) -> Result<Vec<String>, EmbeddingError> {
192 chunk_text_with_tokenizer(&self.tokenizer, text, max_length, overlap_tokens)
193 }
194
195 fn encode_inputs(&self, texts: &[String]) -> Result<EncodedInputs, EmbeddingError> {
196 let encodings = self
197 .tokenizer
198 .encode_batch(texts.iter().map(String::as_str).collect(), true)
199 .map_err(EmbeddingError::Tokenizer)?;
200
201 let batch_size = encodings.len();
202 let sequence_length = encodings
203 .first()
204 .map(|encoding| encoding.get_ids().len())
205 .unwrap_or(0);
206
207 let mut input_ids = Array2::<i64>::zeros((batch_size, sequence_length));
208 let mut attention_mask = Array2::<i64>::zeros((batch_size, sequence_length));
209 let mut token_type_ids = self
210 .input_names
211 .token_type_ids
212 .as_ref()
213 .map(|_| Array2::<i64>::zeros((batch_size, sequence_length)));
214
215 for (row_index, encoding) in encodings.iter().enumerate() {
216 for (column_index, token_id) in encoding.get_ids().iter().enumerate() {
217 input_ids[(row_index, column_index)] = i64::from(*token_id);
218 }
219
220 for (column_index, mask) in encoding.get_attention_mask().iter().enumerate() {
221 attention_mask[(row_index, column_index)] = i64::from(*mask);
222 }
223
224 if let Some(token_type_ids) = token_type_ids.as_mut() {
225 for (column_index, token_type_id) in encoding.get_type_ids().iter().enumerate() {
226 token_type_ids[(row_index, column_index)] = i64::from(*token_type_id);
227 }
228 }
229 }
230
231 Ok((input_ids, attention_mask, token_type_ids))
232 }
233
234 fn run_inference(
235 &mut self,
236 input_ids: Array2<i64>,
237 attention_mask: Array2<i64>,
238 token_type_ids: Option<Array2<i64>>,
239 ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
240 let mut inputs = vec![
241 (
242 self.input_names.input_ids.clone(),
243 Tensor::from_array(input_ids)
244 .map_err(|error| EmbeddingError::Ort(error.to_string()))?,
245 ),
246 (
247 self.input_names.attention_mask.clone(),
248 Tensor::from_array(attention_mask.clone())
249 .map_err(|error| EmbeddingError::Ort(error.to_string()))?,
250 ),
251 ];
252
253 if let (Some(input_name), Some(token_type_ids)) =
254 (self.input_names.token_type_ids.as_ref(), token_type_ids)
255 {
256 inputs.push((
257 input_name.clone(),
258 Tensor::from_array(token_type_ids)
259 .map_err(|error| EmbeddingError::Ort(error.to_string()))?,
260 ));
261 }
262
263 let outputs = self
264 .session
265 .run(inputs)
266 .map_err(|error| EmbeddingError::Ort(error.to_string()))?;
267 let output_value = match self.output_name.as_deref() {
268 Some(output_name) => &outputs[output_name],
269 None => {
270 if outputs.len() == 0 {
271 return Err(EmbeddingError::MissingOutput(
272 "model returned no outputs".to_string(),
273 ));
274 }
275
276 &outputs[0]
277 }
278 };
279
280 let output_array = match output_value.try_extract_array::<f32>() {
281 Ok(array) => array,
282 Err(error) => return Err(EmbeddingError::Ort(error.to_string())),
283 };
284
285 let embeddings = match output_array.ndim() {
286 2 => collect_sentence_embeddings(
287 output_array
288 .into_dimensionality::<Ix2>()
289 .map_err(|_| EmbeddingError::InvalidOutputShape(vec![]))?,
290 self.normalize,
291 ),
292 3 => mean_pool_embeddings(
293 output_array
294 .into_dimensionality::<Ix3>()
295 .map_err(|_| EmbeddingError::InvalidOutputShape(vec![]))?,
296 attention_mask.view(),
297 self.normalize,
298 )?,
299 _ => {
300 return Err(EmbeddingError::InvalidOutputShape(
301 output_array.shape().to_vec(),
302 ));
303 }
304 };
305
306 if let Some(expected_embedding_size) = self.expected_embedding_size {
307 for embedding in &embeddings {
308 if embedding.len() != expected_embedding_size {
309 return Err(EmbeddingError::EmbeddingDimensionMismatch {
310 expected: expected_embedding_size,
311 actual: embedding.len(),
312 });
313 }
314 }
315 }
316
317 Ok(embeddings)
318 }
319}
320
321fn chunk_text_with_tokenizer(
322 tokenizer: &Tokenizer,
323 text: &str,
324 max_length: usize,
325 overlap_tokens: usize,
326) -> Result<Vec<String>, EmbeddingError> {
327 if text.trim().is_empty() {
328 return Ok(Vec::new());
329 }
330
331 let max_content_tokens = max_length.saturating_sub(2).max(1);
332 let overlap_tokens = overlap_tokens.min(max_content_tokens.saturating_sub(1));
333 let encoding = tokenizer
334 .encode(text, false)
335 .map_err(EmbeddingError::Tokenizer)?;
336 let offsets = encoding
337 .get_offsets()
338 .iter()
339 .copied()
340 .filter(|(start, end)| start < end)
341 .collect::<Vec<_>>();
342
343 if offsets.is_empty() {
344 return Ok(Vec::new());
345 }
346
347 let mut chunks = Vec::new();
348 let mut start_token = 0;
349
350 while start_token < offsets.len() {
351 let end_token = (start_token + max_content_tokens).min(offsets.len());
352 let start_byte = offsets[start_token].0;
353 let end_byte = offsets[end_token - 1].1;
354 let chunk = text[start_byte..end_byte].trim();
355
356 if !chunk.is_empty() {
357 chunks.push(chunk.to_string());
358 }
359
360 if end_token >= offsets.len() {
361 break;
362 }
363
364 let next_start = end_token.saturating_sub(overlap_tokens);
365 start_token = if next_start <= start_token {
366 end_token
367 } else {
368 next_start
369 };
370 }
371
372 Ok(chunks)
373}
374
375impl EmbeddingsProvider for OrtEmbedder {
376 async fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
377 self.embed_batch_shared(texts).await
378 }
379}
380
381impl OrtEmbedder {
382 pub async fn embed_batch_shared(
383 &self,
384 texts: &[String],
385 ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
386 if texts.is_empty() {
387 return Ok(Vec::new());
388 }
389
390 let inner = Arc::clone(&self.inner);
391 let texts = texts.to_vec();
392 tokio::task::spawn_blocking(move || {
393 let mut inner = inner.lock().map_err(|error| {
394 EmbeddingError::State(format!("embedder state poisoned: {error}"))
395 })?;
396 let (input_ids, attention_mask, token_type_ids) = inner.encode_inputs(&texts)?;
397 inner.run_inference(input_ids, attention_mask, token_type_ids)
398 })
399 .await
400 .map_err(|error| EmbeddingError::BlockingTask(error.to_string()))?
401 }
402}
403
/// Errors produced while loading the embedding model or computing embeddings.
#[derive(Debug)]
pub enum EmbeddingError {
    /// The configuration cannot be satisfied; the message says why.
    InvalidConfig(&'static str),
    /// A resolved asset path does not exist on disk.
    MissingAsset { asset: &'static str, path: PathBuf },
    /// The ONNX graph lacks a required input. The default input name is
    /// reported even when a custom name was configured.
    MissingModelInput(&'static str),
    /// A configured output was not found, or the model produced no outputs.
    MissingOutput(String),
    /// The pooling configuration requests something other than mean-token pooling.
    UnsupportedPooling(String),
    /// The model output's shape cannot be converted into embeddings.
    InvalidOutputShape(Vec<usize>),
    /// An embedding's width differs from the expected dimension.
    EmbeddingDimensionMismatch { expected: usize, actual: usize },
    /// Hugging Face Hub client or download failure.
    Hub(hf_hub::api::sync::ApiError),
    /// Filesystem I/O failure.
    Io(std::io::Error),
    /// JSON config parsing failure.
    Json(serde_json::Error),
    /// ONNX Runtime failure, stringified.
    Ort(String),
    /// Internal state failure, e.g. a poisoned embedder lock.
    State(String),
    /// The Tokio blocking task running inference could not be joined.
    BlockingTask(String),
    /// Tokenizer loading or encoding failure.
    Tokenizer(tokenizers::Error),
}
421
422impl std::fmt::Display for EmbeddingError {
423 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424 match self {
425 Self::InvalidConfig(message) => write!(f, "{message}"),
426 Self::MissingAsset { asset, path } => {
427 write!(f, "missing {asset} asset at {}", path.display())
428 }
429 Self::MissingModelInput(input_name) => {
430 write!(f, "model is missing required input `{input_name}`")
431 }
432 Self::MissingOutput(output_name) => write!(f, "model output not found: {output_name}"),
433 Self::UnsupportedPooling(message) => write!(f, "{message}"),
434 Self::InvalidOutputShape(shape) => {
435 write!(f, "unexpected model output shape: {shape:?}")
436 }
437 Self::EmbeddingDimensionMismatch { expected, actual } => write!(
438 f,
439 "embedding dimension mismatch: expected {expected}, got {actual}"
440 ),
441 Self::Hub(error) => write!(f, "{error}"),
442 Self::Io(error) => write!(f, "{error}"),
443 Self::Json(error) => write!(f, "{error}"),
444 Self::Ort(error) => write!(f, "{error}"),
445 Self::State(error) => write!(f, "{error}"),
446 Self::BlockingTask(error) => write!(f, "embedding task failed: {error}"),
447 Self::Tokenizer(error) => write!(f, "{error}"),
448 }
449 }
450}
451
452impl std::error::Error for EmbeddingError {}
453
454impl From<hf_hub::api::sync::ApiError> for EmbeddingError {
455 fn from(value: hf_hub::api::sync::ApiError) -> Self {
456 Self::Hub(value)
457 }
458}
459
460impl From<std::io::Error> for EmbeddingError {
461 fn from(value: std::io::Error) -> Self {
462 Self::Io(value)
463 }
464}
465
466impl From<serde_json::Error> for EmbeddingError {
467 fn from(value: serde_json::Error) -> Self {
468 Self::Json(value)
469 }
470}
471
/// ONNX graph input names used when feeding tensors to the session.
#[derive(Debug)]
struct SessionInputNames {
    /// Input that receives token ids.
    input_ids: String,
    /// Input that receives the attention mask.
    attention_mask: String,
    /// Input that receives token type ids; `None` when the model has no such input.
    token_type_ids: Option<String>,
}
478
479impl SessionInputNames {
480 fn from_session(
481 session: &Session,
482 input_ids_name: Option<&str>,
483 attention_mask_name: Option<&str>,
484 token_type_ids_name: Option<&str>,
485 ) -> Result<Self, EmbeddingError> {
486 let inputs = session.inputs();
487 let input_ids = resolve_required_name(inputs, input_ids_name, "input_ids")?;
488 let attention_mask = resolve_required_name(inputs, attention_mask_name, "attention_mask")?;
489 let token_type_ids = resolve_optional_name(inputs, token_type_ids_name, "token_type_ids")?;
490
491 Ok(Self {
492 input_ids,
493 attention_mask,
494 token_type_ids,
495 })
496 }
497}
498
/// On-disk locations of every asset needed to build an [`OrtEmbedder`].
#[derive(Debug)]
struct ModelAssets {
    /// ONNX model loaded into the session.
    model_path: PathBuf,
    /// Tokenizer file loaded with `Tokenizer::from_file`.
    tokenizer_path: PathBuf,
    /// Pooling configuration JSON.
    pooling_config_path: PathBuf,
    /// Transformer configuration JSON.
    transformer_config_path: PathBuf,
}
506
/// Pooling settings read from the pooling config file
/// (`1_Pooling/config.json` by default).
///
/// Absent flags default to `false`; other keys in the file are ignored.
/// `validate_pooling_config` accepts only pure mean-token pooling.
#[derive(Debug, Deserialize)]
struct PoolingConfig {
    /// Use the first (CLS) token's embedding.
    #[serde(default)]
    pooling_mode_cls_token: bool,
    /// Average token embeddings (the only supported mode).
    #[serde(default)]
    pooling_mode_mean_tokens: bool,
    /// Element-wise max over token embeddings.
    #[serde(default)]
    pooling_mode_max_tokens: bool,
    /// Mean scaled by sqrt of the token count.
    #[serde(default)]
    pooling_mode_mean_sqrt_len_tokens: bool,
    /// Declared output width; preferred over the transformer's hidden size.
    #[serde(default)]
    word_embedding_dimension: Option<usize>,
}
520
/// Fields read from the transformer config JSON; absent keys become `None`.
#[derive(Debug, Default, Deserialize)]
struct TransformerConfig {
    /// Hidden size; fallback for the expected embedding width.
    #[serde(default)]
    hidden_size: Option<usize>,
    /// Positional capacity; caps the effective `max_length`.
    #[serde(default)]
    max_position_embeddings: Option<usize>,
}
528
529fn resolve_model_assets(config: &EmbeddingsConfig) -> Result<ModelAssets, EmbeddingError> {
530 let use_hub = config.local_model_path.is_none()
531 || config.local_tokenizer_path.is_none()
532 || config.local_pooling_config_path.is_none()
533 || config.local_transformer_config_path.is_none();
534
535 let api = if use_hub {
536 let builder = match config.cache_dir.clone() {
537 Some(cache_dir) => ApiBuilder::new().with_cache_dir(cache_dir),
538 None => ApiBuilder::from_env(),
539 };
540 Some(builder.with_progress(false).build()?)
541 } else {
542 None
543 };
544
545 let repo = api.as_ref().map(|api| {
546 api.repo(Repo::with_revision(
547 config.model_repo.clone(),
548 RepoType::Model,
549 config.model_revision.clone(),
550 ))
551 });
552
553 Ok(ModelAssets {
554 model_path: resolve_asset_path(
555 config.local_model_path.as_deref(),
556 repo.as_ref(),
557 &config.model_file,
558 "model",
559 )?,
560 tokenizer_path: resolve_asset_path(
561 config.local_tokenizer_path.as_deref(),
562 repo.as_ref(),
563 &config.tokenizer_file,
564 "tokenizer",
565 )?,
566 pooling_config_path: resolve_asset_path(
567 config.local_pooling_config_path.as_deref(),
568 repo.as_ref(),
569 &config.pooling_config_file,
570 "pooling config",
571 )?,
572 transformer_config_path: resolve_asset_path(
573 config.local_transformer_config_path.as_deref(),
574 repo.as_ref(),
575 &config.transformer_config_file,
576 "transformer config",
577 )?,
578 })
579}
580
581fn resolve_asset_path(
582 local_path: Option<&Path>,
583 repo: Option<&hf_hub::api::sync::ApiRepo>,
584 remote_path: &str,
585 asset_name: &'static str,
586) -> Result<PathBuf, EmbeddingError> {
587 if let Some(local_path) = local_path {
588 return ensure_existing_path(local_path.to_path_buf(), asset_name);
589 }
590
591 let repo = repo.ok_or(EmbeddingError::InvalidConfig(
592 "remote model resolution requires a Hugging Face repository",
593 ))?;
594
595 let path = repo.get(remote_path)?;
596 ensure_existing_path(path, asset_name)
597}
598
599fn ensure_existing_path(
600 path: PathBuf,
601 asset_name: &'static str,
602) -> Result<PathBuf, EmbeddingError> {
603 if path.exists() {
604 Ok(path)
605 } else {
606 Err(EmbeddingError::MissingAsset {
607 asset: asset_name,
608 path,
609 })
610 }
611}
612
613fn read_json<T>(path: &Path) -> Result<T, EmbeddingError>
614where
615 T: for<'de> Deserialize<'de>,
616{
617 let contents = fs::read_to_string(path)?;
618 Ok(serde_json::from_str(&contents)?)
619}
620
621fn validate_pooling_config(pooling_config: &PoolingConfig) -> Result<(), EmbeddingError> {
622 if pooling_config.pooling_mode_mean_tokens
623 && !pooling_config.pooling_mode_cls_token
624 && !pooling_config.pooling_mode_max_tokens
625 && !pooling_config.pooling_mode_mean_sqrt_len_tokens
626 {
627 return Ok(());
628 }
629
630 Err(EmbeddingError::UnsupportedPooling(
631 "only mean-token pooling is currently supported".to_string(),
632 ))
633}
634
635fn load_tokenizer(path: &Path, max_length: usize) -> Result<Tokenizer, EmbeddingError> {
636 let mut tokenizer = Tokenizer::from_file(path).map_err(EmbeddingError::Tokenizer)?;
637 tokenizer
638 .with_truncation(Some(TruncationParams {
639 max_length,
640 ..Default::default()
641 }))
642 .map_err(EmbeddingError::Tokenizer)?;
643
644 let mut padding = tokenizer.get_padding().cloned().unwrap_or_default();
645 padding.strategy = PaddingStrategy::BatchLongest;
646 tokenizer.with_padding(Some(PaddingParams { ..padding }));
647
648 Ok(tokenizer)
649}
650
651fn load_session(path: &Path, intra_threads: Option<usize>) -> Result<Session, EmbeddingError> {
652 let builder = Session::builder().map_err(|error| EmbeddingError::Ort(error.to_string()))?;
653 let mut builder = if let Some(intra_threads) = intra_threads {
654 builder
655 .with_intra_threads(intra_threads)
656 .map_err(|error| EmbeddingError::Ort(error.to_string()))?
657 } else {
658 builder
659 };
660
661 builder
662 .commit_from_file(path)
663 .map_err(|error| EmbeddingError::Ort(error.to_string()))
664}
665
666fn resolve_required_name(
667 inputs: &[ort::value::Outlet],
668 configured_name: Option<&str>,
669 default_name: &'static str,
670) -> Result<String, EmbeddingError> {
671 if let Some(configured_name) = configured_name {
672 return inputs
673 .iter()
674 .find(|input| input.name() == configured_name)
675 .map(|input| input.name().to_string())
676 .ok_or(EmbeddingError::MissingModelInput(default_name));
677 }
678
679 inputs
680 .iter()
681 .find(|input| input.name() == default_name)
682 .map(|input| input.name().to_string())
683 .ok_or(EmbeddingError::MissingModelInput(default_name))
684}
685
686fn resolve_optional_name(
687 inputs: &[ort::value::Outlet],
688 configured_name: Option<&str>,
689 default_name: &'static str,
690) -> Result<Option<String>, EmbeddingError> {
691 if let Some(configured_name) = configured_name {
692 return inputs
693 .iter()
694 .find(|input| input.name() == configured_name)
695 .map(|input| Some(input.name().to_string()))
696 .ok_or(EmbeddingError::MissingModelInput(default_name));
697 }
698
699 Ok(inputs
700 .iter()
701 .find(|input| input.name() == default_name)
702 .map(|input| input.name().to_string()))
703}
704
705fn select_output_name(
706 session: &Session,
707 configured_name: Option<&str>,
708) -> Result<Option<String>, EmbeddingError> {
709 if let Some(configured_name) = configured_name {
710 return session
711 .outputs()
712 .iter()
713 .find(|output| output.name() == configured_name)
714 .map(|output| Some(output.name().to_string()))
715 .ok_or_else(|| EmbeddingError::MissingOutput(configured_name.to_string()));
716 }
717
718 Ok(session
719 .outputs()
720 .first()
721 .map(|output| output.name().to_string()))
722}
723
724fn mean_pool_embeddings(
725 token_embeddings: ArrayView3<'_, f32>,
726 attention_mask: ArrayView2<'_, i64>,
727 normalize: bool,
728) -> Result<Vec<Vec<f32>>, EmbeddingError> {
729 let batch_size = token_embeddings.len_of(Axis(0));
730 let sequence_length = token_embeddings.len_of(Axis(1));
731 let embedding_size = token_embeddings.len_of(Axis(2));
732
733 if attention_mask.shape() != [batch_size, sequence_length] {
734 return Err(EmbeddingError::InvalidOutputShape(vec![
735 batch_size,
736 sequence_length,
737 embedding_size,
738 ]));
739 }
740
741 let mut sentence_embeddings = Vec::with_capacity(batch_size);
742 for batch_index in 0..batch_size {
743 let mut pooled = vec![0.0_f32; embedding_size];
744 let mut token_count = 0.0_f32;
745
746 for token_index in 0..sequence_length {
747 let mask = attention_mask[(batch_index, token_index)] as f32;
748 if mask <= 0.0 {
749 continue;
750 }
751
752 token_count += mask;
753 for embedding_index in 0..embedding_size {
754 pooled[embedding_index] +=
755 token_embeddings[(batch_index, token_index, embedding_index)] * mask;
756 }
757 }
758
759 if token_count > 0.0 {
760 for value in &mut pooled {
761 *value /= token_count;
762 }
763 }
764
765 if normalize {
766 l2_normalize(&mut pooled);
767 }
768
769 sentence_embeddings.push(pooled);
770 }
771
772 Ok(sentence_embeddings)
773}
774
775fn collect_sentence_embeddings(embeddings: ArrayView2<'_, f32>, normalize: bool) -> Vec<Vec<f32>> {
776 embeddings
777 .axis_iter(Axis(0))
778 .map(|row| {
779 let mut embedding = row.to_vec();
780 if normalize {
781 l2_normalize(&mut embedding);
782 }
783 embedding
784 })
785 .collect()
786}
787
/// Scales `values` in place to unit Euclidean length.
///
/// If the norm is not strictly positive (empty, all zeros, or NaN), the
/// values are left untouched.
fn l2_normalize(values: &mut [f32]) {
    let squared_sum: f32 = values.iter().map(|&component| component * component).sum();
    let magnitude = squared_sum.sqrt();

    // Nothing to scale by: no direction to preserve.
    if magnitude <= 0.0 || magnitude.is_nan() {
        return;
    }

    for component in values.iter_mut() {
        *component /= magnitude;
    }
}
796
#[cfg(test)]
mod tests {
    use super::{
        DEFAULT_MAX_LENGTH, DEFAULT_MODEL_FILE, DEFAULT_MODEL_REPO, DEFAULT_MODEL_REVISION,
        DEFAULT_POOLING_CONFIG_FILE, DEFAULT_TOKENIZER_FILE, DEFAULT_TRANSFORMER_CONFIG_FILE,
        EmbeddingError, EmbeddingsConfig, TransformerConfig, chunk_text_with_tokenizer,
        collect_sentence_embeddings, ensure_existing_path, l2_normalize, mean_pool_embeddings,
        read_json, resolve_asset_path,
    };
    use ahash::AHashMap;
    use ndarray::{Array2, Array3, array};
    use std::fs;
    use std::path::PathBuf;
    use tempfile::tempdir;
    use tokenizers::Tokenizer;
    use tokenizers::models::wordlevel::WordLevel;
    use tokenizers::pre_tokenizers::whitespace::Whitespace;
    use tokenizers::processors::bert::BertProcessing;

    #[test]
    fn uses_expected_default_embedding_config() {
        let config = EmbeddingsConfig::default();

        assert_eq!(config.model_repo, DEFAULT_MODEL_REPO);
        assert_eq!(config.model_revision, DEFAULT_MODEL_REVISION);
        assert_eq!(config.model_file, DEFAULT_MODEL_FILE);
        assert_eq!(config.tokenizer_file, DEFAULT_TOKENIZER_FILE);
        assert_eq!(config.pooling_config_file, DEFAULT_POOLING_CONFIG_FILE);
        assert_eq!(
            config.transformer_config_file,
            DEFAULT_TRANSFORMER_CONFIG_FILE
        );
        assert_eq!(config.max_length, DEFAULT_MAX_LENGTH);
        assert!(config.normalize);
        assert!(config.cache_dir.is_none());
    }

    #[test]
    fn prefers_local_asset_override_when_present() {
        let temp_dir = tempdir().unwrap();
        let model_path = temp_dir.path().join("model.onnx");
        fs::write(&model_path, b"model").unwrap();

        let resolved = resolve_asset_path(Some(&model_path), None, "ignored", "model").unwrap();

        assert_eq!(resolved, model_path);
    }

    #[test]
    fn rejects_missing_local_asset_override() {
        // A path inside a fresh temp dir is guaranteed not to exist and is
        // portable, unlike a fixed `/tmp/...` location.
        let temp_dir = tempdir().unwrap();
        let missing_path: PathBuf = temp_dir.path().join("retrieval-kit-missing-model.onnx");

        let error = ensure_existing_path(missing_path.clone(), "model").unwrap_err();

        match error {
            EmbeddingError::MissingAsset { asset, path } => {
                assert_eq!(asset, "model");
                assert_eq!(path, missing_path);
            }
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn mean_pooling_respects_attention_mask() {
        let token_embeddings =
            Array3::from_shape_vec((1, 3, 2), vec![1.0, 0.0, 3.0, 4.0, 100.0, 100.0]).unwrap();
        let attention_mask = array![[1_i64, 1, 0]];

        let embeddings =
            mean_pool_embeddings(token_embeddings.view(), attention_mask.view(), false).unwrap();

        assert_eq!(embeddings, vec![vec![2.0, 2.0]]);
    }

    #[test]
    fn sentence_embeddings_are_normalized_when_requested() {
        let embeddings = Array2::from_shape_vec((1, 2), vec![3.0_f32, 4.0]).unwrap();

        let normalized = collect_sentence_embeddings(embeddings.view(), true);

        assert!((normalized[0][0] - 0.6).abs() < 1e-6);
        assert!((normalized[0][1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn l2_normalize_leaves_zero_vector_unchanged() {
        let mut values = vec![0.0_f32, 0.0, 0.0];

        l2_normalize(&mut values);

        assert_eq!(values, vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn reads_transformer_config_from_local_file() {
        let temp_dir = tempdir().unwrap();
        let config_path = temp_dir.path().join("config.json");
        fs::write(
            &config_path,
            r#"{"hidden_size":384,"max_position_embeddings":256}"#,
        )
        .unwrap();

        let config: TransformerConfig = read_json(&config_path).unwrap();

        assert_eq!(config.hidden_size, Some(384));
        assert_eq!(config.max_position_embeddings, Some(256));
    }

    #[test]
    fn tokenizer_fixture_saves_to_local_json() {
        let temp_dir = tempdir().unwrap();
        let tokenizer_path = temp_dir.path().join("tokenizer.json");

        build_test_tokenizer().save(&tokenizer_path, false).unwrap();

        assert!(tokenizer_path.exists());
    }

    #[test]
    fn token_chunking_respects_model_length_with_overlap() {
        let tokenizer = build_test_tokenizer();
        let chunks =
            chunk_text_with_tokenizer(&tokenizer, "hello world hello world hello world", 5, 1)
                .unwrap();

        assert_eq!(
            chunks,
            vec!["hello world hello", "hello world hello", "hello world"]
        );
        for chunk in chunks {
            let encoding = tokenizer.encode(chunk.as_str(), true).unwrap();
            assert!(encoding.len() <= 5);
        }
    }

    /// Word-level tokenizer over a tiny vocabulary with BERT-style
    /// `[CLS] ... [SEP]` post-processing.
    fn build_test_tokenizer() -> Tokenizer {
        let vocab = AHashMap::from_iter([
            ("[UNK]".to_string(), 0),
            ("[PAD]".to_string(), 1),
            ("[CLS]".to_string(), 2),
            ("[SEP]".to_string(), 3),
            ("hello".to_string(), 4),
            ("world".to_string(), 5),
        ]);

        let model = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".to_string())
            .build()
            .unwrap();
        let mut tokenizer = Tokenizer::new(model);
        tokenizer.with_pre_tokenizer(Some(Whitespace));
        tokenizer.with_post_processor(Some(BertProcessing::new(
            ("[SEP]".to_string(), 3),
            ("[CLS]".to_string(), 2),
        )));
        tokenizer
    }
}