1use std::sync::Arc;
2
3pub use fastembed::EmbeddingModel as FastembedModel;
4use fastembed::{
5 InitOptions, InitOptionsUserDefined, ModelInfo, TextEmbedding, UserDefinedEmbeddingModel,
6};
7use rig::{
8 Embed,
9 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
10};
11
/// Entry point for building locally-run fastembed embedding models.
///
/// `Client` is a field-less unit struct: it carries no configuration, so its
/// `Default` value is simply `Client`.
#[derive(Clone, Default)]
pub struct Client;
23
24impl Client {
25 pub fn new() -> Self {
27 Self
28 }
29
30 pub fn embedding_model(&self, model: &FastembedModel) -> EmbeddingModel {
44 let ndims = TextEmbedding::get_model_info(model).unwrap().dim;
45
46 EmbeddingModel::new(model, ndims)
47 }
48
49 pub fn embeddings<D: Embed>(
66 &self,
67 model: &fastembed::EmbeddingModel,
68 ) -> EmbeddingsBuilder<EmbeddingModel, D> {
69 EmbeddingsBuilder::new(self.embedding_model(model))
70 }
71}
72
/// A fastembed text-embedding model exposed through rig's `embeddings::EmbeddingModel` trait.
///
/// Cloning is cheap: clones share the same underlying fastembed model through an `Arc`.
#[derive(Clone)]
pub struct EmbeddingModel {
    /// Shared handle to the initialised fastembed model used by `embed_texts`.
    embedder: Arc<TextEmbedding>,
    /// Fastembed model identifier (for user-defined models, copied from the supplied `ModelInfo`).
    pub model: FastembedModel,
    /// Dimension reported by `ndims()`; supplied by the caller at construction and not
    /// checked against the model's actual output size.
    ndims: usize,
}
79
80impl EmbeddingModel {
81 pub fn new(model: &fastembed::EmbeddingModel, ndims: usize) -> Self {
82 let embedder = Arc::new(
83 TextEmbedding::try_new(
84 InitOptions::new(model.to_owned()).with_show_download_progress(true),
85 )
86 .unwrap(),
87 );
88
89 Self {
90 embedder,
91 model: model.to_owned(),
92 ndims,
93 }
94 }
95
96 pub fn new_from_user_defined(
97 user_defined_model: UserDefinedEmbeddingModel,
98 ndims: usize,
99 model_info: &ModelInfo<FastembedModel>,
100 ) -> Self {
101 let fastembed_embedding_model = TextEmbedding::try_new_from_user_defined(
102 user_defined_model,
103 InitOptionsUserDefined::default(),
104 )
105 .unwrap();
106
107 let embedder = Arc::new(fastembed_embedding_model);
108
109 Self {
110 embedder,
111 model: model_info.model.to_owned(),
112 ndims,
113 }
114 }
115}
116
117impl embeddings::EmbeddingModel for EmbeddingModel {
118 const MAX_DOCUMENTS: usize = 1024;
119
120 fn ndims(&self) -> usize {
121 self.ndims
122 }
123
124 async fn embed_texts(
125 &self,
126 documents: impl IntoIterator<Item = String>,
127 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
128 let documents_as_strings: Vec<String> = documents.into_iter().collect();
129
130 let documents_as_vec = self
131 .embedder
132 .embed(documents_as_strings.clone(), None)
133 .map_err(|err| EmbeddingError::ProviderError(err.to_string()))?;
134
135 let docs = documents_as_strings
136 .into_iter()
137 .zip(documents_as_vec)
138 .map(|(document, embedding)| embeddings::Embedding {
139 document,
140 vec: embedding.into_iter().map(|f| f as f64).collect(),
141 })
142 .collect::<Vec<embeddings::Embedding>>();
143
144 Ok(docs)
145 }
146}