1use std::sync::Arc;
2
3pub use fastembed::EmbeddingModel as FastembedModel;
4use fastembed::{InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel};
5use rig::embeddings::{self, EmbeddingError};
6
7#[cfg(feature = "hf-hub")]
8use fastembed::InitOptions;
9#[cfg(feature = "hf-hub")]
10use rig::{Embed, embeddings::EmbeddingsBuilder};
11
/// Entry point for building fastembed-backed embedding models.
///
/// The client is a unit struct and carries no state, so cloning it is free.
#[derive(Clone, Debug)]
pub struct Client;
17
18impl Default for Client {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl Client {
25 pub fn new() -> Self {
27 Self
28 }
29
30 #[cfg(feature = "hf-hub")]
44 pub fn embedding_model(&self, model: &FastembedModel) -> EmbeddingModel {
45 let ndims = TextEmbedding::get_model_info(model).unwrap().dim;
46
47 EmbeddingModel::new(model, ndims)
48 }
49
50 #[cfg(feature = "hf-hub")]
67 pub fn embeddings<D: Embed>(
68 &self,
69 model: &fastembed::EmbeddingModel,
70 ) -> EmbeddingsBuilder<EmbeddingModel, D> {
71 EmbeddingsBuilder::new(self.embedding_model(model))
72 }
73}
74
75#[derive(Clone)]
76pub struct EmbeddingModel {
77 embedder: Arc<TextEmbedding>,
78 pub model: FastembedModel,
79 ndims: usize,
80}
81
82impl EmbeddingModel {
83 #[cfg(feature = "hf-hub")]
84 pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Self {
85 let embedder = Arc::new(
86 TextEmbedding::try_new(
87 InitOptions::new(model.to_owned()).with_show_download_progress(true),
88 )
89 .unwrap(),
90 );
91
92 Self {
93 embedder,
94 model: model.to_owned(),
95 ndims,
96 }
97 }
98
99 pub fn new_from_user_defined(
100 user_defined_model: UserDefinedEmbeddingModel,
101 ndims: usize,
102 model_info: &ModelInfo<FastembedModel>,
103 ) -> Self {
104 let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined(
105 user_defined_model,
106 InitOptionsUserDefined::default(),
107 )
108 .unwrap();
109
110 let embedder = Arc::new(fastembed_embedding_model);
111
112 Self {
113 embedder,
114 model: model_info.model.to_owned(),
115 ndims,
116 }
117 }
118}
119
120impl embeddings::EmbeddingModel for EmbeddingModel {
121 const MAX_DOCUMENTS: usize = 1024;
122
123 type Client = Client;
124
125 fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
127 panic!("Cannot create a fastembed model via `EmbeddingModel::make`")
128 }
129
130 fn ndims(&self) -> usize {
131 self.ndims
132 }
133
134 async fn embed_texts(
135 &self,
136 documents: impl IntoIterator<Item = String>,
137 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
138 let documents_as_strings: Vec<String> = documents.into_iter().collect();
139
140 let documents_as_vec = self
141 .embedder
142 .embed(documents_as_strings.clone(), None)
143 .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;
144
145 let docs = documents_as_strings
146 .into_iter()
147 .zip(documents_as_vec)
148 .map(|(document, embedding)| embeddings::Embedding {
149 document,
150 vec: embedding.into_iter().map(|f| f as f64).collect(),
151 })
152 .collect::<Vec<embeddings::Embedding>>();
153
154 Ok(docs)
155 }
156}