1use std::sync::Arc;
2use std::{error::Error as StdError, fmt};
3
4pub use fastembed::EmbeddingModel as FastembedModel;
5use fastembed::{InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel};
6use rig::embeddings::{self, EmbeddingError};
7
8#[cfg(feature = "hf-hub")]
9use fastembed::InitOptions;
10#[cfg(feature = "hf-hub")]
11use rig::{Embed, embeddings::EmbeddingsBuilder};
12
/// Entry point for constructing FastEmbed-backed [`EmbeddingModel`]s.
///
/// The client is a unit struct and holds no state of its own.
#[derive(Debug, Clone)]
pub struct Client;
18
19#[derive(Debug, Clone)]
20pub enum FastembedError {
21 UnknownModel(FastembedModel),
22 Initialization(String),
23 UnsupportedMake,
24}
25
26impl fmt::Display for FastembedError {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 match self {
29 FastembedError::UnknownModel(model) => {
30 write!(
31 f,
32 "Failed to resolve FastEmbed model metadata for {model:?}"
33 )
34 }
35 FastembedError::Initialization(message) => {
36 write!(f, "Failed to initialize FastEmbed model: {message}")
37 }
38 FastembedError::UnsupportedMake => write!(
39 f,
40 "`EmbeddingModel::make` is not supported for rig-fastembed; construct models via `Client::embedding_model` or `EmbeddingModel::new_from_user_defined`"
41 ),
42 }
43 }
44}
45
46impl StdError for FastembedError {}
47
48impl Default for Client {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl Client {
55 pub fn new() -> Self {
57 Self
58 }
59
60 #[cfg(feature = "hf-hub")]
74 pub fn embedding_model(
75 &self,
76 model: &FastembedModel,
77 ) -> Result<EmbeddingModel, FastembedError> {
78 let ndims = TextEmbedding::get_model_info(model)
79 .map(|info| info.dim)
80 .map_err(|_| FastembedError::UnknownModel(model.clone()))?;
81
82 EmbeddingModel::new(model, ndims)
83 }
84
85 #[cfg(feature = "hf-hub")]
109 pub fn embeddings<D: Embed>(
110 &self,
111 model: &fastembed::EmbeddingModel,
112 ) -> Result<EmbeddingsBuilder<EmbeddingModel, D>, FastembedError> {
113 Ok(EmbeddingsBuilder::new(self.embedding_model(model)?))
114 }
115}
116
117#[derive(Clone)]
118pub struct EmbeddingModel {
119 embedder: Option<Arc<TextEmbedding>>,
120 init_error: Option<FastembedError>,
121 pub model: FastembedModel,
122 ndims: usize,
123}
124
125impl EmbeddingModel {
126 #[cfg(feature = "hf-hub")]
127 pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Result<Self, FastembedError> {
128 let embedder = Arc::new(
129 TextEmbedding::try_new(
130 InitOptions::new(model.to_owned()).with_show_download_progress(true),
131 )
132 .map_err(|err| FastembedError::Initialization(err.to_string()))?,
133 );
134
135 Ok(Self {
136 embedder: Some(embedder),
137 init_error: None,
138 model: model.to_owned(),
139 ndims,
140 })
141 }
142
143 pub fn new_from_user_defined(
144 user_defined_model: UserDefinedEmbeddingModel,
145 ndims: usize,
146 model_info: &ModelInfo<FastembedModel>,
147 ) -> Result<Self, FastembedError> {
148 let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined(
149 user_defined_model,
150 InitOptionsUserDefined::default(),
151 )
152 .map_err(|err| FastembedError::Initialization(err.to_string()))?;
153
154 let embedder = Arc::new(fastembed_embedding_model);
155
156 Ok(Self {
157 embedder: Some(embedder),
158 init_error: None,
159 model: model_info.model.to_owned(),
160 ndims,
161 })
162 }
163}
164
165impl embeddings::EmbeddingModel for EmbeddingModel {
166 const MAX_DOCUMENTS: usize = 1024;
167
168 type Client = Client;
169
170 fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
171 Self {
172 embedder: None,
173 init_error: Some(FastembedError::UnsupportedMake),
174 model: FastembedModel::AllMiniLML6V2Q,
175 ndims: 0,
176 }
177 }
178
179 fn ndims(&self) -> usize {
180 self.ndims
181 }
182
183 async fn embed_texts(
184 &self,
185 documents: impl IntoIterator<Item = String>,
186 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
187 let Some(embedder) = &self.embedder else {
188 let message = self
189 .init_error
190 .as_ref()
191 .map(ToString::to_string)
192 .unwrap_or_else(|| "FastEmbed model initialization failed".to_string());
193 return Err(EmbeddingError::ProviderError(message));
194 };
195
196 let documents_as_strings: Vec<String> = documents.into_iter().collect();
197
198 let documents_as_vec = embedder
199 .embed(documents_as_strings.clone(), None)
200 .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;
201
202 let docs = documents_as_strings
203 .into_iter()
204 .zip(documents_as_vec)
205 .map(|(document, embedding)| embeddings::Embedding {
206 document,
207 vec: embedding.into_iter().map(|f| f as f64).collect(),
208 })
209 .collect::<Vec<embeddings::Embedding>>();
210
211 Ok(docs)
212 }
213}