Skip to main content

retrieval_kit/
embeddings.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3use std::sync::{Arc, Mutex};
4
5use hf_hub::api::sync::ApiBuilder;
6use hf_hub::{Repo, RepoType};
7use ndarray::{Array2, ArrayView2, ArrayView3, Axis, Ix2, Ix3};
8use ort::session::Session;
9use ort::value::Tensor;
10use serde::Deserialize;
11use tokenizers::Tokenizer;
12use tokenizers::tokenizer::{PaddingParams, PaddingStrategy, TruncationParams};
13
/// Hugging Face repository resolved when no local model assets are configured.
const DEFAULT_MODEL_REPO: &str = "sentence-transformers/all-MiniLM-L12-v2";
/// Repository revision fetched from the Hub by default.
const DEFAULT_MODEL_REVISION: &str = "main";
/// Location of the ONNX model inside the repository.
const DEFAULT_MODEL_FILE: &str = "onnx/model.onnx";
/// Location of the serialized tokenizer inside the repository.
const DEFAULT_TOKENIZER_FILE: &str = "tokenizer.json";
/// Location of the sentence-transformers pooling configuration inside the repository.
const DEFAULT_POOLING_CONFIG_FILE: &str = "1_Pooling/config.json";
/// Location of the transformer configuration inside the repository.
const DEFAULT_TRANSFORMER_CONFIG_FILE: &str = "config.json";
/// Default tokenizer sequence length; the embedder caps it by the model's position limit.
const DEFAULT_MAX_LENGTH: usize = 128;
21type EncodedInputs = (Array2<i64>, Array2<i64>, Option<Array2<i64>>);
22
23#[allow(async_fn_in_trait)]
24/// Provider interface for generating embeddings from text batches.
25pub trait EmbeddingsProvider {
26    async fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
27
28    async fn embed(&mut self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
29        let mut embeddings = self.embed_batch(&[text.to_owned()]).await?;
30        embeddings.pop().ok_or(EmbeddingError::MissingOutput(
31            "no embeddings returned".to_string(),
32        ))
33    }
34}
35
/// Settings for the built-in ONNX Runtime sentence embedder.
///
/// Each asset is read from its `local_*` path when set; any asset without a local path is
/// fetched from `model_repo` at `model_revision` using the matching `*_file` repository path.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EmbeddingsConfig {
    /// Hub repository used for any asset without a local path.
    pub model_repo: String,
    /// Hub revision used for any asset without a local path.
    pub model_revision: String,
    /// Repository-relative path of the ONNX model.
    pub model_file: String,
    /// Repository-relative path of the tokenizer JSON.
    pub tokenizer_file: String,
    /// Repository-relative path of the pooling config.
    pub pooling_config_file: String,
    /// Repository-relative path of the transformer config.
    pub transformer_config_file: String,
    /// Requested tokenizer sequence length; lowered to the model's position limit when known.
    pub max_length: usize,
    /// L2-normalize every output embedding when `true`.
    pub normalize: bool,
    /// ONNX Runtime intra-op thread count; `None` keeps the runtime default.
    pub intra_threads: Option<usize>,
    /// Hub cache directory override.
    pub cache_dir: Option<PathBuf>,
    /// Local ONNX model file; bypasses the Hub for this asset.
    pub local_model_path: Option<PathBuf>,
    /// Local tokenizer file; bypasses the Hub for this asset.
    pub local_tokenizer_path: Option<PathBuf>,
    /// Local sentence-transformers pooling config; bypasses the Hub for this asset.
    pub local_pooling_config_path: Option<PathBuf>,
    /// Local transformer config; bypasses the Hub for this asset.
    pub local_transformer_config_path: Option<PathBuf>,
    /// Model input name to use instead of `input_ids`.
    pub input_ids_name: Option<String>,
    /// Model input name to use instead of `attention_mask`.
    pub attention_mask_name: Option<String>,
    /// Model input name to use instead of `token_type_ids`.
    pub token_type_ids_name: Option<String>,
    /// Model output to read; the first declared output is used when unset.
    pub output_name: Option<String>,
}
76
77impl Default for EmbeddingsConfig {
78    fn default() -> Self {
79        Self {
80            model_repo: DEFAULT_MODEL_REPO.to_string(),
81            model_revision: DEFAULT_MODEL_REVISION.to_string(),
82            model_file: DEFAULT_MODEL_FILE.to_string(),
83            tokenizer_file: DEFAULT_TOKENIZER_FILE.to_string(),
84            pooling_config_file: DEFAULT_POOLING_CONFIG_FILE.to_string(),
85            transformer_config_file: DEFAULT_TRANSFORMER_CONFIG_FILE.to_string(),
86            max_length: DEFAULT_MAX_LENGTH,
87            normalize: true,
88            intra_threads: None,
89            cache_dir: None,
90            local_model_path: None,
91            local_tokenizer_path: None,
92            local_pooling_config_path: None,
93            local_transformer_config_path: None,
94            input_ids_name: None,
95            attention_mask_name: None,
96            token_type_ids_name: None,
97            output_name: None,
98        }
99    }
100}
101
102#[derive(Debug)]
103/// ONNX Runtime sentence embedding provider.
104pub struct OrtEmbedder {
105    inner: Arc<Mutex<OrtEmbedderInner>>,
106    max_length: usize,
107}
108
109impl OrtEmbedder {
110    pub fn new(config: EmbeddingsConfig) -> Result<Self, EmbeddingError> {
111        let assets = resolve_model_assets(&config)?;
112        let pooling_config = read_json::<PoolingConfig>(&assets.pooling_config_path)?;
113        validate_pooling_config(&pooling_config)?;
114
115        let transformer_config = read_json::<TransformerConfig>(&assets.transformer_config_path)?;
116        let expected_embedding_size = pooling_config
117            .word_embedding_dimension
118            .or(transformer_config.hidden_size);
119        let max_length = transformer_config
120            .max_position_embeddings
121            .map(|value| value.min(config.max_length))
122            .unwrap_or(config.max_length);
123
124        let tokenizer = load_tokenizer(&assets.tokenizer_path, max_length)?;
125        let session = load_session(&assets.model_path, config.intra_threads)?;
126        let input_names = SessionInputNames::from_session(
127            &session,
128            config.input_ids_name.as_deref(),
129            config.attention_mask_name.as_deref(),
130            config.token_type_ids_name.as_deref(),
131        )?;
132        let output_name = select_output_name(&session, config.output_name.as_deref())?;
133
134        Ok(Self {
135            inner: Arc::new(Mutex::new(OrtEmbedderInner {
136                tokenizer,
137                session,
138                input_names,
139                output_name,
140                normalize: config.normalize,
141                expected_embedding_size,
142            })),
143            max_length,
144        })
145    }
146
147    pub fn max_length(&self) -> usize {
148        self.max_length
149    }
150
151    pub fn expected_embedding_size(&self) -> Option<usize> {
152        self.inner
153            .lock()
154            .ok()
155            .and_then(|inner| inner.expected_embedding_size)
156    }
157
158    pub fn chunk_text(
159        &self,
160        text: &str,
161        overlap_tokens: usize,
162    ) -> Result<Vec<String>, EmbeddingError> {
163        if text.trim().is_empty() {
164            return Ok(Vec::new());
165        }
166
167        let inner = self
168            .inner
169            .lock()
170            .map_err(|error| EmbeddingError::State(format!("embedder state poisoned: {error}")))?;
171        inner.chunk_text(text, self.max_length, overlap_tokens)
172    }
173}
174
175#[derive(Debug)]
176struct OrtEmbedderInner {
177    tokenizer: Tokenizer,
178    session: Session,
179    input_names: SessionInputNames,
180    output_name: Option<String>,
181    normalize: bool,
182    expected_embedding_size: Option<usize>,
183}
184
185impl OrtEmbedderInner {
186    fn chunk_text(
187        &self,
188        text: &str,
189        max_length: usize,
190        overlap_tokens: usize,
191    ) -> Result<Vec<String>, EmbeddingError> {
192        chunk_text_with_tokenizer(&self.tokenizer, text, max_length, overlap_tokens)
193    }
194
195    fn encode_inputs(&self, texts: &[String]) -> Result<EncodedInputs, EmbeddingError> {
196        let encodings = self
197            .tokenizer
198            .encode_batch(texts.iter().map(String::as_str).collect(), true)
199            .map_err(EmbeddingError::Tokenizer)?;
200
201        let batch_size = encodings.len();
202        let sequence_length = encodings
203            .first()
204            .map(|encoding| encoding.get_ids().len())
205            .unwrap_or(0);
206
207        let mut input_ids = Array2::<i64>::zeros((batch_size, sequence_length));
208        let mut attention_mask = Array2::<i64>::zeros((batch_size, sequence_length));
209        let mut token_type_ids = self
210            .input_names
211            .token_type_ids
212            .as_ref()
213            .map(|_| Array2::<i64>::zeros((batch_size, sequence_length)));
214
215        for (row_index, encoding) in encodings.iter().enumerate() {
216            for (column_index, token_id) in encoding.get_ids().iter().enumerate() {
217                input_ids[(row_index, column_index)] = i64::from(*token_id);
218            }
219
220            for (column_index, mask) in encoding.get_attention_mask().iter().enumerate() {
221                attention_mask[(row_index, column_index)] = i64::from(*mask);
222            }
223
224            if let Some(token_type_ids) = token_type_ids.as_mut() {
225                for (column_index, token_type_id) in encoding.get_type_ids().iter().enumerate() {
226                    token_type_ids[(row_index, column_index)] = i64::from(*token_type_id);
227                }
228            }
229        }
230
231        Ok((input_ids, attention_mask, token_type_ids))
232    }
233
234    fn run_inference(
235        &mut self,
236        input_ids: Array2<i64>,
237        attention_mask: Array2<i64>,
238        token_type_ids: Option<Array2<i64>>,
239    ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
240        let mut inputs = vec![
241            (
242                self.input_names.input_ids.clone(),
243                Tensor::from_array(input_ids)
244                    .map_err(|error| EmbeddingError::Ort(error.to_string()))?,
245            ),
246            (
247                self.input_names.attention_mask.clone(),
248                Tensor::from_array(attention_mask.clone())
249                    .map_err(|error| EmbeddingError::Ort(error.to_string()))?,
250            ),
251        ];
252
253        if let (Some(input_name), Some(token_type_ids)) =
254            (self.input_names.token_type_ids.as_ref(), token_type_ids)
255        {
256            inputs.push((
257                input_name.clone(),
258                Tensor::from_array(token_type_ids)
259                    .map_err(|error| EmbeddingError::Ort(error.to_string()))?,
260            ));
261        }
262
263        let outputs = self
264            .session
265            .run(inputs)
266            .map_err(|error| EmbeddingError::Ort(error.to_string()))?;
267        let output_value = match self.output_name.as_deref() {
268            Some(output_name) => &outputs[output_name],
269            None => {
270                if outputs.len() == 0 {
271                    return Err(EmbeddingError::MissingOutput(
272                        "model returned no outputs".to_string(),
273                    ));
274                }
275
276                &outputs[0]
277            }
278        };
279
280        let output_array = match output_value.try_extract_array::<f32>() {
281            Ok(array) => array,
282            Err(error) => return Err(EmbeddingError::Ort(error.to_string())),
283        };
284
285        let embeddings = match output_array.ndim() {
286            2 => collect_sentence_embeddings(
287                output_array
288                    .into_dimensionality::<Ix2>()
289                    .map_err(|_| EmbeddingError::InvalidOutputShape(vec![]))?,
290                self.normalize,
291            ),
292            3 => mean_pool_embeddings(
293                output_array
294                    .into_dimensionality::<Ix3>()
295                    .map_err(|_| EmbeddingError::InvalidOutputShape(vec![]))?,
296                attention_mask.view(),
297                self.normalize,
298            )?,
299            _ => {
300                return Err(EmbeddingError::InvalidOutputShape(
301                    output_array.shape().to_vec(),
302                ));
303            }
304        };
305
306        if let Some(expected_embedding_size) = self.expected_embedding_size {
307            for embedding in &embeddings {
308                if embedding.len() != expected_embedding_size {
309                    return Err(EmbeddingError::EmbeddingDimensionMismatch {
310                        expected: expected_embedding_size,
311                        actual: embedding.len(),
312                    });
313                }
314            }
315        }
316
317        Ok(embeddings)
318    }
319}
320
321fn chunk_text_with_tokenizer(
322    tokenizer: &Tokenizer,
323    text: &str,
324    max_length: usize,
325    overlap_tokens: usize,
326) -> Result<Vec<String>, EmbeddingError> {
327    if text.trim().is_empty() {
328        return Ok(Vec::new());
329    }
330
331    let max_content_tokens = max_length.saturating_sub(2).max(1);
332    let overlap_tokens = overlap_tokens.min(max_content_tokens.saturating_sub(1));
333    let encoding = tokenizer
334        .encode(text, false)
335        .map_err(EmbeddingError::Tokenizer)?;
336    let offsets = encoding
337        .get_offsets()
338        .iter()
339        .copied()
340        .filter(|(start, end)| start < end)
341        .collect::<Vec<_>>();
342
343    if offsets.is_empty() {
344        return Ok(Vec::new());
345    }
346
347    let mut chunks = Vec::new();
348    let mut start_token = 0;
349
350    while start_token < offsets.len() {
351        let end_token = (start_token + max_content_tokens).min(offsets.len());
352        let start_byte = offsets[start_token].0;
353        let end_byte = offsets[end_token - 1].1;
354        let chunk = text[start_byte..end_byte].trim();
355
356        if !chunk.is_empty() {
357            chunks.push(chunk.to_string());
358        }
359
360        if end_token >= offsets.len() {
361            break;
362        }
363
364        let next_start = end_token.saturating_sub(overlap_tokens);
365        start_token = if next_start <= start_token {
366            end_token
367        } else {
368            next_start
369        };
370    }
371
372    Ok(chunks)
373}
374
375impl EmbeddingsProvider for OrtEmbedder {
376    async fn embed_batch(&mut self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
377        self.embed_batch_shared(texts).await
378    }
379}
380
381impl OrtEmbedder {
382    pub async fn embed_batch_shared(
383        &self,
384        texts: &[String],
385    ) -> Result<Vec<Vec<f32>>, EmbeddingError> {
386        if texts.is_empty() {
387            return Ok(Vec::new());
388        }
389
390        let inner = Arc::clone(&self.inner);
391        let texts = texts.to_vec();
392        tokio::task::spawn_blocking(move || {
393            let mut inner = inner.lock().map_err(|error| {
394                EmbeddingError::State(format!("embedder state poisoned: {error}"))
395            })?;
396            let (input_ids, attention_mask, token_type_ids) = inner.encode_inputs(&texts)?;
397            inner.run_inference(input_ids, attention_mask, token_type_ids)
398        })
399        .await
400        .map_err(|error| EmbeddingError::BlockingTask(error.to_string()))?
401    }
402}
403
404#[derive(Debug)]
405pub enum EmbeddingError {
406    InvalidConfig(&'static str),
407    MissingAsset { asset: &'static str, path: PathBuf },
408    MissingModelInput(&'static str),
409    MissingOutput(String),
410    UnsupportedPooling(String),
411    InvalidOutputShape(Vec<usize>),
412    EmbeddingDimensionMismatch { expected: usize, actual: usize },
413    Hub(hf_hub::api::sync::ApiError),
414    Io(std::io::Error),
415    Json(serde_json::Error),
416    Ort(String),
417    State(String),
418    BlockingTask(String),
419    Tokenizer(tokenizers::Error),
420}
421
422impl std::fmt::Display for EmbeddingError {
423    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424        match self {
425            Self::InvalidConfig(message) => write!(f, "{message}"),
426            Self::MissingAsset { asset, path } => {
427                write!(f, "missing {asset} asset at {}", path.display())
428            }
429            Self::MissingModelInput(input_name) => {
430                write!(f, "model is missing required input `{input_name}`")
431            }
432            Self::MissingOutput(output_name) => write!(f, "model output not found: {output_name}"),
433            Self::UnsupportedPooling(message) => write!(f, "{message}"),
434            Self::InvalidOutputShape(shape) => {
435                write!(f, "unexpected model output shape: {shape:?}")
436            }
437            Self::EmbeddingDimensionMismatch { expected, actual } => write!(
438                f,
439                "embedding dimension mismatch: expected {expected}, got {actual}"
440            ),
441            Self::Hub(error) => write!(f, "{error}"),
442            Self::Io(error) => write!(f, "{error}"),
443            Self::Json(error) => write!(f, "{error}"),
444            Self::Ort(error) => write!(f, "{error}"),
445            Self::State(error) => write!(f, "{error}"),
446            Self::BlockingTask(error) => write!(f, "embedding task failed: {error}"),
447            Self::Tokenizer(error) => write!(f, "{error}"),
448        }
449    }
450}
451
452impl std::error::Error for EmbeddingError {}
453
454impl From<hf_hub::api::sync::ApiError> for EmbeddingError {
455    fn from(value: hf_hub::api::sync::ApiError) -> Self {
456        Self::Hub(value)
457    }
458}
459
460impl From<std::io::Error> for EmbeddingError {
461    fn from(value: std::io::Error) -> Self {
462        Self::Io(value)
463    }
464}
465
466impl From<serde_json::Error> for EmbeddingError {
467    fn from(value: serde_json::Error) -> Self {
468        Self::Json(value)
469    }
470}
471
/// Model input names resolved against the session's declared inputs.
#[derive(Debug)]
struct SessionInputNames {
    /// Name of the token-ID input.
    input_ids: String,
    /// Name of the attention-mask input.
    attention_mask: String,
    /// Name of the token-type input; `None` when the model does not declare one.
    token_type_ids: Option<String>,
}
478
479impl SessionInputNames {
480    fn from_session(
481        session: &Session,
482        input_ids_name: Option<&str>,
483        attention_mask_name: Option<&str>,
484        token_type_ids_name: Option<&str>,
485    ) -> Result<Self, EmbeddingError> {
486        let inputs = session.inputs();
487        let input_ids = resolve_required_name(inputs, input_ids_name, "input_ids")?;
488        let attention_mask = resolve_required_name(inputs, attention_mask_name, "attention_mask")?;
489        let token_type_ids = resolve_optional_name(inputs, token_type_ids_name, "token_type_ids")?;
490
491        Ok(Self {
492            input_ids,
493            attention_mask,
494            token_type_ids,
495        })
496    }
497}
498
/// Resolved filesystem locations of every asset the embedder loads.
#[derive(Debug)]
struct ModelAssets {
    /// ONNX model file.
    model_path: PathBuf,
    /// Tokenizer JSON file.
    tokenizer_path: PathBuf,
    /// Sentence-transformers pooling config file.
    pooling_config_path: PathBuf,
    /// Transformer config file.
    transformer_config_path: PathBuf,
}
506
507#[derive(Debug, Deserialize)]
508struct PoolingConfig {
509    #[serde(default)]
510    pooling_mode_cls_token: bool,
511    #[serde(default)]
512    pooling_mode_mean_tokens: bool,
513    #[serde(default)]
514    pooling_mode_max_tokens: bool,
515    #[serde(default)]
516    pooling_mode_mean_sqrt_len_tokens: bool,
517    #[serde(default)]
518    word_embedding_dimension: Option<usize>,
519}
520
521#[derive(Debug, Default, Deserialize)]
522struct TransformerConfig {
523    #[serde(default)]
524    hidden_size: Option<usize>,
525    #[serde(default)]
526    max_position_embeddings: Option<usize>,
527}
528
529fn resolve_model_assets(config: &EmbeddingsConfig) -> Result<ModelAssets, EmbeddingError> {
530    let use_hub = config.local_model_path.is_none()
531        || config.local_tokenizer_path.is_none()
532        || config.local_pooling_config_path.is_none()
533        || config.local_transformer_config_path.is_none();
534
535    let api = if use_hub {
536        let builder = match config.cache_dir.clone() {
537            Some(cache_dir) => ApiBuilder::new().with_cache_dir(cache_dir),
538            None => ApiBuilder::from_env(),
539        };
540        Some(builder.with_progress(false).build()?)
541    } else {
542        None
543    };
544
545    let repo = api.as_ref().map(|api| {
546        api.repo(Repo::with_revision(
547            config.model_repo.clone(),
548            RepoType::Model,
549            config.model_revision.clone(),
550        ))
551    });
552
553    Ok(ModelAssets {
554        model_path: resolve_asset_path(
555            config.local_model_path.as_deref(),
556            repo.as_ref(),
557            &config.model_file,
558            "model",
559        )?,
560        tokenizer_path: resolve_asset_path(
561            config.local_tokenizer_path.as_deref(),
562            repo.as_ref(),
563            &config.tokenizer_file,
564            "tokenizer",
565        )?,
566        pooling_config_path: resolve_asset_path(
567            config.local_pooling_config_path.as_deref(),
568            repo.as_ref(),
569            &config.pooling_config_file,
570            "pooling config",
571        )?,
572        transformer_config_path: resolve_asset_path(
573            config.local_transformer_config_path.as_deref(),
574            repo.as_ref(),
575            &config.transformer_config_file,
576            "transformer config",
577        )?,
578    })
579}
580
581fn resolve_asset_path(
582    local_path: Option<&Path>,
583    repo: Option<&hf_hub::api::sync::ApiRepo>,
584    remote_path: &str,
585    asset_name: &'static str,
586) -> Result<PathBuf, EmbeddingError> {
587    if let Some(local_path) = local_path {
588        return ensure_existing_path(local_path.to_path_buf(), asset_name);
589    }
590
591    let repo = repo.ok_or(EmbeddingError::InvalidConfig(
592        "remote model resolution requires a Hugging Face repository",
593    ))?;
594
595    let path = repo.get(remote_path)?;
596    ensure_existing_path(path, asset_name)
597}
598
599fn ensure_existing_path(
600    path: PathBuf,
601    asset_name: &'static str,
602) -> Result<PathBuf, EmbeddingError> {
603    if path.exists() {
604        Ok(path)
605    } else {
606        Err(EmbeddingError::MissingAsset {
607            asset: asset_name,
608            path,
609        })
610    }
611}
612
613fn read_json<T>(path: &Path) -> Result<T, EmbeddingError>
614where
615    T: for<'de> Deserialize<'de>,
616{
617    let contents = fs::read_to_string(path)?;
618    Ok(serde_json::from_str(&contents)?)
619}
620
621fn validate_pooling_config(pooling_config: &PoolingConfig) -> Result<(), EmbeddingError> {
622    if pooling_config.pooling_mode_mean_tokens
623        && !pooling_config.pooling_mode_cls_token
624        && !pooling_config.pooling_mode_max_tokens
625        && !pooling_config.pooling_mode_mean_sqrt_len_tokens
626    {
627        return Ok(());
628    }
629
630    Err(EmbeddingError::UnsupportedPooling(
631        "only mean-token pooling is currently supported".to_string(),
632    ))
633}
634
635fn load_tokenizer(path: &Path, max_length: usize) -> Result<Tokenizer, EmbeddingError> {
636    let mut tokenizer = Tokenizer::from_file(path).map_err(EmbeddingError::Tokenizer)?;
637    tokenizer
638        .with_truncation(Some(TruncationParams {
639            max_length,
640            ..Default::default()
641        }))
642        .map_err(EmbeddingError::Tokenizer)?;
643
644    let mut padding = tokenizer.get_padding().cloned().unwrap_or_default();
645    padding.strategy = PaddingStrategy::BatchLongest;
646    tokenizer.with_padding(Some(PaddingParams { ..padding }));
647
648    Ok(tokenizer)
649}
650
651fn load_session(path: &Path, intra_threads: Option<usize>) -> Result<Session, EmbeddingError> {
652    let builder = Session::builder().map_err(|error| EmbeddingError::Ort(error.to_string()))?;
653    let mut builder = if let Some(intra_threads) = intra_threads {
654        builder
655            .with_intra_threads(intra_threads)
656            .map_err(|error| EmbeddingError::Ort(error.to_string()))?
657    } else {
658        builder
659    };
660
661    builder
662        .commit_from_file(path)
663        .map_err(|error| EmbeddingError::Ort(error.to_string()))
664}
665
666fn resolve_required_name(
667    inputs: &[ort::value::Outlet],
668    configured_name: Option<&str>,
669    default_name: &'static str,
670) -> Result<String, EmbeddingError> {
671    if let Some(configured_name) = configured_name {
672        return inputs
673            .iter()
674            .find(|input| input.name() == configured_name)
675            .map(|input| input.name().to_string())
676            .ok_or(EmbeddingError::MissingModelInput(default_name));
677    }
678
679    inputs
680        .iter()
681        .find(|input| input.name() == default_name)
682        .map(|input| input.name().to_string())
683        .ok_or(EmbeddingError::MissingModelInput(default_name))
684}
685
686fn resolve_optional_name(
687    inputs: &[ort::value::Outlet],
688    configured_name: Option<&str>,
689    default_name: &'static str,
690) -> Result<Option<String>, EmbeddingError> {
691    if let Some(configured_name) = configured_name {
692        return inputs
693            .iter()
694            .find(|input| input.name() == configured_name)
695            .map(|input| Some(input.name().to_string()))
696            .ok_or(EmbeddingError::MissingModelInput(default_name));
697    }
698
699    Ok(inputs
700        .iter()
701        .find(|input| input.name() == default_name)
702        .map(|input| input.name().to_string()))
703}
704
705fn select_output_name(
706    session: &Session,
707    configured_name: Option<&str>,
708) -> Result<Option<String>, EmbeddingError> {
709    if let Some(configured_name) = configured_name {
710        return session
711            .outputs()
712            .iter()
713            .find(|output| output.name() == configured_name)
714            .map(|output| Some(output.name().to_string()))
715            .ok_or_else(|| EmbeddingError::MissingOutput(configured_name.to_string()));
716    }
717
718    Ok(session
719        .outputs()
720        .first()
721        .map(|output| output.name().to_string()))
722}
723
724fn mean_pool_embeddings(
725    token_embeddings: ArrayView3<'_, f32>,
726    attention_mask: ArrayView2<'_, i64>,
727    normalize: bool,
728) -> Result<Vec<Vec<f32>>, EmbeddingError> {
729    let batch_size = token_embeddings.len_of(Axis(0));
730    let sequence_length = token_embeddings.len_of(Axis(1));
731    let embedding_size = token_embeddings.len_of(Axis(2));
732
733    if attention_mask.shape() != [batch_size, sequence_length] {
734        return Err(EmbeddingError::InvalidOutputShape(vec![
735            batch_size,
736            sequence_length,
737            embedding_size,
738        ]));
739    }
740
741    let mut sentence_embeddings = Vec::with_capacity(batch_size);
742    for batch_index in 0..batch_size {
743        let mut pooled = vec![0.0_f32; embedding_size];
744        let mut token_count = 0.0_f32;
745
746        for token_index in 0..sequence_length {
747            let mask = attention_mask[(batch_index, token_index)] as f32;
748            if mask <= 0.0 {
749                continue;
750            }
751
752            token_count += mask;
753            for embedding_index in 0..embedding_size {
754                pooled[embedding_index] +=
755                    token_embeddings[(batch_index, token_index, embedding_index)] * mask;
756            }
757        }
758
759        if token_count > 0.0 {
760            for value in &mut pooled {
761                *value /= token_count;
762            }
763        }
764
765        if normalize {
766            l2_normalize(&mut pooled);
767        }
768
769        sentence_embeddings.push(pooled);
770    }
771
772    Ok(sentence_embeddings)
773}
774
775fn collect_sentence_embeddings(embeddings: ArrayView2<'_, f32>, normalize: bool) -> Vec<Vec<f32>> {
776    embeddings
777        .axis_iter(Axis(0))
778        .map(|row| {
779            let mut embedding = row.to_vec();
780            if normalize {
781                l2_normalize(&mut embedding);
782            }
783            embedding
784        })
785        .collect()
786}
787
/// Scales `values` in place to unit Euclidean length.
///
/// A zero-length (or NaN-length) vector is left untouched.
fn l2_normalize(values: &mut [f32]) {
    let squared_sum: f32 = values.iter().map(|component| component * component).sum();
    let magnitude = squared_sum.sqrt();
    if magnitude > 0.0 {
        values.iter_mut().for_each(|component| *component /= magnitude);
    }
}
796
#[cfg(test)]
mod tests {
    use super::{
        DEFAULT_MAX_LENGTH, DEFAULT_MODEL_FILE, DEFAULT_MODEL_REPO, DEFAULT_MODEL_REVISION,
        DEFAULT_POOLING_CONFIG_FILE, DEFAULT_TOKENIZER_FILE, DEFAULT_TRANSFORMER_CONFIG_FILE,
        EmbeddingError, EmbeddingsConfig, PoolingConfig, TransformerConfig,
        chunk_text_with_tokenizer, collect_sentence_embeddings, ensure_existing_path,
        l2_normalize, mean_pool_embeddings, read_json, resolve_asset_path,
        validate_pooling_config,
    };
    use ahash::AHashMap;
    use ndarray::{Array2, Array3, array};
    use std::fs;
    use tempfile::tempdir;
    use tokenizers::Tokenizer;
    use tokenizers::models::wordlevel::WordLevel;
    use tokenizers::pre_tokenizers::whitespace::Whitespace;
    use tokenizers::processors::bert::BertProcessing;

    #[test]
    fn uses_expected_default_embedding_config() {
        let config = EmbeddingsConfig::default();

        assert_eq!(config.model_repo, DEFAULT_MODEL_REPO);
        assert_eq!(config.model_revision, DEFAULT_MODEL_REVISION);
        assert_eq!(config.model_file, DEFAULT_MODEL_FILE);
        assert_eq!(config.tokenizer_file, DEFAULT_TOKENIZER_FILE);
        assert_eq!(config.pooling_config_file, DEFAULT_POOLING_CONFIG_FILE);
        assert_eq!(
            config.transformer_config_file,
            DEFAULT_TRANSFORMER_CONFIG_FILE
        );
        assert_eq!(config.max_length, DEFAULT_MAX_LENGTH);
        assert!(config.normalize);
        assert!(config.cache_dir.is_none());
    }

    #[test]
    fn prefers_local_asset_override_when_present() {
        let temp_dir = tempdir().unwrap();
        let model_path = temp_dir.path().join("model.onnx");
        fs::write(&model_path, b"model").unwrap();

        let resolved = resolve_asset_path(Some(&model_path), None, "ignored", "model").unwrap();

        assert_eq!(resolved, model_path);
    }

    #[test]
    fn rejects_missing_local_asset_override() {
        // A path inside a fresh temp dir cannot collide with a real file and does not assume
        // a Unix `/tmp`.
        let temp_dir = tempdir().unwrap();
        let missing_path = temp_dir.path().join("retrieval-kit-missing-model.onnx");

        let error = ensure_existing_path(missing_path.clone(), "model").unwrap_err();

        match error {
            EmbeddingError::MissingAsset { asset, path } => {
                assert_eq!(asset, "model");
                assert_eq!(path, missing_path);
            }
            other => panic!("unexpected error: {other}"),
        }
    }

    #[test]
    fn remote_resolution_without_repository_is_rejected() {
        let error = resolve_asset_path(None, None, "model.onnx", "model").unwrap_err();

        assert!(matches!(error, EmbeddingError::InvalidConfig(_)));
    }

    #[test]
    fn mean_pooling_respects_attention_mask() {
        let token_embeddings =
            Array3::from_shape_vec((1, 3, 2), vec![1.0, 0.0, 3.0, 4.0, 100.0, 100.0]).unwrap();
        let attention_mask = array![[1_i64, 1, 0]];

        let embeddings =
            mean_pool_embeddings(token_embeddings.view(), attention_mask.view(), false).unwrap();

        assert_eq!(embeddings, vec![vec![2.0, 2.0]]);
    }

    #[test]
    fn sentence_embeddings_are_normalized_when_requested() {
        let embeddings = Array2::from_shape_vec((1, 2), vec![3.0_f32, 4.0]).unwrap();

        let normalized = collect_sentence_embeddings(embeddings.view(), true);

        assert!((normalized[0][0] - 0.6).abs() < 1e-6);
        assert!((normalized[0][1] - 0.8).abs() < 1e-6);
    }

    #[test]
    fn l2_normalize_leaves_zero_vector_unchanged() {
        let mut values = vec![0.0_f32; 3];

        l2_normalize(&mut values);

        assert_eq!(values, vec![0.0_f32; 3]);
    }

    #[test]
    fn pooling_validation_accepts_mean_only_config() {
        let config: PoolingConfig = serde_json::from_str(
            r#"{"word_embedding_dimension":384,"pooling_mode_mean_tokens":true}"#,
        )
        .unwrap();

        assert!(validate_pooling_config(&config).is_ok());
    }

    #[test]
    fn pooling_validation_rejects_combined_modes() {
        let config: PoolingConfig = serde_json::from_str(
            r#"{"pooling_mode_mean_tokens":true,"pooling_mode_cls_token":true}"#,
        )
        .unwrap();

        assert!(matches!(
            validate_pooling_config(&config),
            Err(EmbeddingError::UnsupportedPooling(_))
        ));
    }

    #[test]
    fn reads_transformer_config_from_local_file() {
        let temp_dir = tempdir().unwrap();
        let config_path = temp_dir.path().join("config.json");
        fs::write(
            &config_path,
            r#"{"hidden_size":384,"max_position_embeddings":256}"#,
        )
        .unwrap();

        let config: TransformerConfig = read_json(&config_path).unwrap();

        assert_eq!(config.hidden_size, Some(384));
        assert_eq!(config.max_position_embeddings, Some(256));
    }

    #[test]
    fn tokenizer_fixture_saves_to_local_json() {
        let temp_dir = tempdir().unwrap();
        let tokenizer_path = temp_dir.path().join("tokenizer.json");

        build_test_tokenizer().save(&tokenizer_path, false).unwrap();

        assert!(tokenizer_path.exists());
    }

    #[test]
    fn token_chunking_respects_model_length_with_overlap() {
        let tokenizer = build_test_tokenizer();
        let chunks =
            chunk_text_with_tokenizer(&tokenizer, "hello world hello world hello world", 5, 1)
                .unwrap();

        assert_eq!(
            chunks,
            vec!["hello world hello", "hello world hello", "hello world"]
        );
        for chunk in chunks {
            let encoding = tokenizer.encode(chunk.as_str(), true).unwrap();
            assert!(encoding.len() <= 5);
        }
    }

    #[test]
    fn token_chunking_returns_no_chunks_for_blank_text() {
        let tokenizer = build_test_tokenizer();

        let chunks = chunk_text_with_tokenizer(&tokenizer, " \n\t ", 5, 1).unwrap();

        assert!(chunks.is_empty());
    }

    /// Word-level tokenizer over a tiny vocabulary with BERT-style `[CLS]`/`[SEP]` wrapping.
    fn build_test_tokenizer() -> Tokenizer {
        let vocab = AHashMap::from_iter([
            ("[UNK]".to_string(), 0),
            ("[PAD]".to_string(), 1),
            ("[CLS]".to_string(), 2),
            ("[SEP]".to_string(), 3),
            ("hello".to_string(), 4),
            ("world".to_string(), 5),
        ]);

        let model = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".to_string())
            .build()
            .unwrap();
        let mut tokenizer = Tokenizer::new(model);
        tokenizer.with_pre_tokenizer(Some(Whitespace));
        tokenizer.with_post_processor(Some(BertProcessing::new(
            ("[SEP]".to_string(), 3),
            ("[CLS]".to_string(), 2),
        )));
        tokenizer
    }
}