1use std::sync::Arc;
12use std::{error::Error as StdError, fmt};
13
14pub use fastembed::EmbeddingModel as FastembedModel;
15use fastembed::{InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel};
16use rig_core::embeddings::{self, EmbeddingError};
17
18#[cfg(feature = "hf-hub")]
19use fastembed::InitOptions;
20#[cfg(feature = "hf-hub")]
21use rig_core::{Embed, embeddings::EmbeddingsBuilder};
22
/// Entry point for creating FastEmbed-backed embedding models.
///
/// The type is a stateless unit struct: it holds no configuration and exists
/// only so that models and embedding builders can be created through its
/// methods.
#[derive(Debug, Clone)]
pub struct Client;
28
29#[derive(Debug, Clone)]
30pub enum FastembedError {
31 UnknownModel(FastembedModel),
32 Initialization(String),
33 UnsupportedMake,
34}
35
36impl fmt::Display for FastembedError {
37 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38 match self {
39 FastembedError::UnknownModel(model) => {
40 write!(
41 f,
42 "Failed to resolve FastEmbed model metadata for {model:?}"
43 )
44 }
45 FastembedError::Initialization(message) => {
46 write!(f, "Failed to initialize FastEmbed model: {message}")
47 }
48 FastembedError::UnsupportedMake => write!(
49 f,
50 "`EmbeddingModel::make` is not supported for rig-fastembed; construct models via `Client::embedding_model` or `EmbeddingModel::new_from_user_defined`"
51 ),
52 }
53 }
54}
55
56impl StdError for FastembedError {}
57
58impl Default for Client {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl Client {
65 pub fn new() -> Self {
67 Self
68 }
69
70 #[cfg(feature = "hf-hub")]
84 pub fn embedding_model(
85 &self,
86 model: &FastembedModel,
87 ) -> Result<EmbeddingModel, FastembedError> {
88 let ndims = TextEmbedding::get_model_info(model)
89 .map(|info| info.dim)
90 .map_err(|_| FastembedError::UnknownModel(model.clone()))?;
91
92 EmbeddingModel::new(model, ndims)
93 }
94
95 #[cfg(feature = "hf-hub")]
119 pub fn embeddings<D: Embed>(
120 &self,
121 model: &fastembed::EmbeddingModel,
122 ) -> Result<EmbeddingsBuilder<EmbeddingModel, D>, FastembedError> {
123 Ok(EmbeddingsBuilder::new(self.embedding_model(model)?))
124 }
125}
126
127#[derive(Clone)]
128pub struct EmbeddingModel {
129 embedder: Option<Arc<TextEmbedding>>,
130 init_error: Option<FastembedError>,
131 pub model: FastembedModel,
132 ndims: usize,
133}
134
135impl EmbeddingModel {
136 #[cfg(feature = "hf-hub")]
137 pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Result<Self, FastembedError> {
138 let embedder = Arc::new(
139 TextEmbedding::try_new(
140 InitOptions::new(model.to_owned()).with_show_download_progress(true),
141 )
142 .map_err(|err| FastembedError::Initialization(err.to_string()))?,
143 );
144
145 Ok(Self {
146 embedder: Some(embedder),
147 init_error: None,
148 model: model.to_owned(),
149 ndims,
150 })
151 }
152
153 pub fn new_from_user_defined(
154 user_defined_model: UserDefinedEmbeddingModel,
155 ndims: usize,
156 model_info: &ModelInfo<FastembedModel>,
157 ) -> Result<Self, FastembedError> {
158 let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined(
159 user_defined_model,
160 InitOptionsUserDefined::default(),
161 )
162 .map_err(|err| FastembedError::Initialization(err.to_string()))?;
163
164 let embedder = Arc::new(fastembed_embedding_model);
165
166 Ok(Self {
167 embedder: Some(embedder),
168 init_error: None,
169 model: model_info.model.to_owned(),
170 ndims,
171 })
172 }
173}
174
175impl embeddings::EmbeddingModel for EmbeddingModel {
176 const MAX_DOCUMENTS: usize = 1024;
177
178 type Client = Client;
179
180 fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
181 Self {
182 embedder: None,
183 init_error: Some(FastembedError::UnsupportedMake),
184 model: FastembedModel::AllMiniLML6V2Q,
185 ndims: 0,
186 }
187 }
188
189 fn ndims(&self) -> usize {
190 self.ndims
191 }
192
193 async fn embed_texts(
194 &self,
195 documents: impl IntoIterator<Item = String>,
196 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
197 let Some(embedder) = &self.embedder else {
198 let message = self
199 .init_error
200 .as_ref()
201 .map(ToString::to_string)
202 .unwrap_or_else(|| "FastEmbed model initialization failed".to_string());
203 return Err(EmbeddingError::ProviderError(message));
204 };
205
206 let documents_as_strings: Vec<String> = documents.into_iter().collect();
207
208 let documents_as_vec = embedder
209 .embed(documents_as_strings.clone(), None)
210 .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;
211
212 let docs = documents_as_strings
213 .into_iter()
214 .zip(documents_as_vec)
215 .map(|(document, embedding)| embeddings::Embedding {
216 document,
217 vec: embedding.into_iter().map(|f| f as f64).collect(),
218 })
219 .collect::<Vec<embeddings::Embedding>>();
220
221 Ok(docs)
222 }
223}