fastembed/text_embedding/
init.rs

1//! Initialization options for the text embedding models.
2//!
3
4use crate::{
5    common::TokenizerFiles,
6    init::{HasMaxLength, InitOptionsWithLength},
7    pooling::Pooling,
8    EmbeddingModel, OutputKey, QuantizationMode,
9};
10use ort::{execution_providers::ExecutionProviderDispatch, session::Session};
11use tokenizers::Tokenizer;
12
13use super::DEFAULT_MAX_LENGTH;
14
// All built-in text embedding models share the module-wide default token cap.
impl HasMaxLength for EmbeddingModel {
    const MAX_LENGTH: usize = DEFAULT_MAX_LENGTH;
}
18
/// Options for initializing the TextEmbedding model
///
/// Alias of [`InitOptionsWithLength`] specialized to the built-in
/// [`EmbeddingModel`] variants.
pub type TextInitOptions = InitOptionsWithLength<EmbeddingModel>;
21
22/// Options for initializing UserDefinedEmbeddingModel
23///
24/// Model files are held by the UserDefinedEmbeddingModel struct
/// Options for initializing UserDefinedEmbeddingModel
///
/// Model files are held by the UserDefinedEmbeddingModel struct
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct InitOptionsUserDefined {
    /// ONNX Runtime execution providers to use when building the session.
    pub execution_providers: Vec<ExecutionProviderDispatch>,
    /// Maximum token length used when tokenizing input text.
    pub max_length: usize,
}
31
32impl InitOptionsUserDefined {
33    pub fn new() -> Self {
34        Self {
35            ..Default::default()
36        }
37    }
38
39    pub fn with_execution_providers(
40        mut self,
41        execution_providers: Vec<ExecutionProviderDispatch>,
42    ) -> Self {
43        self.execution_providers = execution_providers;
44        self
45    }
46
47    pub fn with_max_length(mut self, max_length: usize) -> Self {
48        self.max_length = max_length;
49        self
50    }
51}
52
53impl Default for InitOptionsUserDefined {
54    fn default() -> Self {
55        Self {
56            execution_providers: Default::default(),
57            max_length: DEFAULT_MAX_LENGTH,
58        }
59    }
60}
61
62/// Convert InitOptions to InitOptionsUserDefined
63///
64/// This is useful for when the user wants to use the same options for both the default and user-defined models
65impl From<TextInitOptions> for InitOptionsUserDefined {
66    fn from(options: TextInitOptions) -> Self {
67        InitOptionsUserDefined {
68            execution_providers: options.execution_providers,
69            max_length: options.max_length,
70        }
71    }
72}
73
74/// Struct for "bring your own" embedding models
75///
76/// The onnx_file and tokenizer_files are expecting the files' bytes
/// Struct for "bring your own" embedding models
///
/// The onnx_file and tokenizer_files are expecting the files' bytes
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserDefinedEmbeddingModel {
    /// Raw bytes of the ONNX model file.
    pub onnx_file: Vec<u8>,
    /// Raw bytes of the tokenizer configuration files.
    pub tokenizer_files: TokenizerFiles,
    /// Optional pooling strategy; `None` uses whatever the pipeline defaults to.
    pub pooling: Option<Pooling>,
    /// Quantization mode of the model (defaults to `QuantizationMode::None` via `new`).
    pub quantization: QuantizationMode,
    /// Optional key selecting which session output to read embeddings from.
    pub output_key: Option<OutputKey>,
}
85
86impl UserDefinedEmbeddingModel {
87    pub fn new(onnx_file: Vec<u8>, tokenizer_files: TokenizerFiles) -> Self {
88        Self {
89            onnx_file,
90            tokenizer_files,
91            quantization: QuantizationMode::None,
92            pooling: None,
93            output_key: None,
94        }
95    }
96
97    pub fn with_quantization(mut self, quantization: QuantizationMode) -> Self {
98        self.quantization = quantization;
99        self
100    }
101
102    pub fn with_pooling(mut self, pooling: Pooling) -> Self {
103        self.pooling = Some(pooling);
104        self
105    }
106}
107
/// Rust representation of the TextEmbedding model
pub struct TextEmbedding {
    /// Tokenizer used to convert input text into model inputs.
    pub tokenizer: Tokenizer,
    /// Optional pooling strategy applied to model outputs.
    pub(crate) pooling: Option<Pooling>,
    /// ONNX Runtime session holding the loaded model.
    pub(crate) session: Session,
    // NOTE(review): presumably true when the ONNX graph declares a
    // `token_type_ids` input — set by the constructor (not shown here); confirm.
    pub(crate) need_token_type_ids: bool,
    /// Quantization mode of the loaded model.
    pub(crate) quantization: QuantizationMode,
    /// Optional key selecting which session output holds the embeddings.
    pub(crate) output_key: Option<OutputKey>,
}