// redis_vl/vectorizers/mod.rs
use std::sync::{Arc, OnceLock};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::error::Result;

#[cfg(feature = "anthropic")]
mod anthropic;
#[cfg(feature = "anthropic")]
pub use self::anthropic::{AnthropicConfig, AnthropicTextVectorizer};

#[cfg(feature = "azure-openai")]
mod azure_openai;
#[cfg(feature = "azure-openai")]
pub use self::azure_openai::{AzureOpenAIConfig, AzureOpenAITextVectorizer};

#[cfg(feature = "bedrock")]
mod bedrock;
#[cfg(feature = "bedrock")]
pub use self::bedrock::{BedrockConfig, BedrockTextVectorizer};

#[cfg(feature = "cohere")]
mod cohere;
#[cfg(feature = "cohere")]
pub use self::cohere::{CohereConfig, CohereTextVectorizer};

#[cfg(feature = "hf-local")]
mod hf_local;
#[cfg(feature = "hf-local")]
pub use self::hf_local::{HuggingFaceConfig, HuggingFaceTextVectorizer};

#[cfg(feature = "mistral")]
mod mistral;
#[cfg(feature = "mistral")]
pub use self::mistral::{MistralAITextVectorizer, MistralConfig};

#[cfg(feature = "voyageai")]
mod voyageai;
#[cfg(feature = "voyageai")]
pub use self::voyageai::{VoyageAIConfig, VoyageAITextVectorizer};

#[cfg(feature = "vertex-ai")]
mod vertex_ai;
#[cfg(feature = "vertex-ai")]
pub use self::vertex_ai::{VertexAIConfig, VertexAITextVectorizer};

50#[derive(Debug, Clone, Serialize)]
52pub struct EmbeddingRequest<'a> {
53 pub model: &'a str,
55 pub input: Vec<&'a str>,
57}
58
59pub trait Vectorizer: Send + Sync {
61 fn embed(&self, text: &str) -> Result<Vec<f32>>;
63
64 fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
66 texts.iter().map(|text| self.embed(text)).collect()
67 }
68}
69
70#[async_trait]
72pub trait AsyncVectorizer: Send + Sync {
73 async fn embed(&self, text: &str) -> Result<Vec<f32>>;
75
76 async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
78 let mut embeddings = Vec::with_capacity(texts.len());
79 for text in texts {
80 embeddings.push(self.embed(text).await?);
81 }
82 Ok(embeddings)
83 }
84}
85
86#[derive(Debug, Clone)]
88pub struct OpenAICompatibleConfig {
89 pub base_url: url::Url,
91 pub api_key: String,
93 pub model: String,
95}
96
97impl OpenAICompatibleConfig {
98 pub fn new(
100 base_url: impl AsRef<str>,
101 api_key: impl Into<String>,
102 model: impl Into<String>,
103 ) -> Result<Self> {
104 Ok(Self {
105 base_url: url::Url::parse(base_url.as_ref())?,
106 api_key: api_key.into(),
107 model: model.into(),
108 })
109 }
110
111 fn embeddings_url(&self) -> Result<url::Url> {
112 Ok(self.base_url.join("embeddings")?)
113 }
114}
115
116#[derive(Debug, Clone)]
118pub struct OpenAITextVectorizer {
119 config: OpenAICompatibleConfig,
120 client: reqwest::Client,
121 blocking_client: reqwest::blocking::Client,
122}
123
124impl OpenAITextVectorizer {
125 pub fn new(config: OpenAICompatibleConfig) -> Self {
127 Self {
128 config,
129 client: reqwest::Client::new(),
130 blocking_client: reqwest::blocking::Client::new(),
131 }
132 }
133
134 async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
135 let response: EmbeddingResponse = self
136 .client
137 .post(self.config.embeddings_url()?)
138 .bearer_auth(&self.config.api_key)
139 .json(&EmbeddingRequest {
140 model: &self.config.model,
141 input: texts.to_vec(),
142 })
143 .send()
144 .await?
145 .error_for_status()?
146 .json()
147 .await?;
148 Ok(response
149 .data
150 .into_iter()
151 .map(|item| item.embedding)
152 .collect())
153 }
154}
155
156impl Vectorizer for OpenAITextVectorizer {
157 fn embed(&self, text: &str) -> Result<Vec<f32>> {
158 let response: EmbeddingResponse = self
159 .blocking_client
160 .post(self.config.embeddings_url()?)
161 .bearer_auth(&self.config.api_key)
162 .json(&EmbeddingRequest {
163 model: &self.config.model,
164 input: vec![text],
165 })
166 .send()?
167 .error_for_status()?
168 .json()?;
169 Ok(response
170 .data
171 .into_iter()
172 .next()
173 .map_or_else(Vec::new, |item| item.embedding))
174 }
175
176 fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
177 let response: EmbeddingResponse = self
178 .blocking_client
179 .post(self.config.embeddings_url()?)
180 .bearer_auth(&self.config.api_key)
181 .json(&EmbeddingRequest {
182 model: &self.config.model,
183 input: texts.to_vec(),
184 })
185 .send()?
186 .error_for_status()?
187 .json()?;
188 Ok(response
189 .data
190 .into_iter()
191 .map(|item| item.embedding)
192 .collect())
193 }
194}
195
196#[async_trait]
197impl AsyncVectorizer for OpenAITextVectorizer {
198 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
199 let mut embeddings = self.embed_many_inner(&[text]).await?;
200 Ok(embeddings.pop().unwrap_or_default())
201 }
202
203 async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
204 self.embed_many_inner(texts).await
205 }
206}
207
208#[derive(Debug, Clone)]
210pub struct LiteLLMTextVectorizer {
211 inner: OpenAITextVectorizer,
212}
213
214impl LiteLLMTextVectorizer {
215 pub fn new(config: OpenAICompatibleConfig) -> Self {
217 Self {
218 inner: OpenAITextVectorizer::new(config),
219 }
220 }
221}
222
223impl Vectorizer for LiteLLMTextVectorizer {
224 fn embed(&self, text: &str) -> Result<Vec<f32>> {
225 Vectorizer::embed(&self.inner, text)
226 }
227
228 fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
229 Vectorizer::embed_many(&self.inner, texts)
230 }
231}
232
233#[async_trait]
234impl AsyncVectorizer for LiteLLMTextVectorizer {
235 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
236 AsyncVectorizer::embed(&self.inner, text).await
237 }
238
239 async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
240 AsyncVectorizer::embed_many(&self.inner, texts).await
241 }
242}
243
/// Vectorizer that delegates embedding to a user-supplied closure.
///
/// The closure is stored behind an [`Arc`], so cloning the vectorizer shares the same closure
/// and does not require `F: Clone`. Trait bounds live on the `impl` blocks that need them,
/// not on the struct itself.
pub struct CustomTextVectorizer<F> {
    embedder: Arc<F>,
}

impl<F> Clone for CustomTextVectorizer<F> {
    /// Returns a vectorizer that shares this one's closure (a reference-count bump).
    fn clone(&self) -> Self {
        Self {
            embedder: Arc::clone(&self.embedder),
        }
    }
}

impl<F> std::fmt::Debug for CustomTextVectorizer<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Closures have no meaningful `Debug` output, so the embedder is elided.
        f.debug_struct("CustomTextVectorizer").finish_non_exhaustive()
    }
}

252impl<F> CustomTextVectorizer<F>
253where
254 F: Fn(&str) -> Result<Vec<f32>> + Send + Sync + 'static,
255{
256 pub fn new(embedder: F) -> Self {
258 Self {
259 embedder: Arc::new(embedder),
260 }
261 }
262}
263
264impl<F> Vectorizer for CustomTextVectorizer<F>
265where
266 F: Fn(&str) -> Result<Vec<f32>> + Send + Sync + 'static,
267{
268 fn embed(&self, text: &str) -> Result<Vec<f32>> {
269 (self.embedder)(text)
270 }
271}
272
273#[derive(Debug, Deserialize)]
274pub(crate) struct EmbeddingResponse {
275 pub(crate) data: Vec<EmbeddingDatum>,
276}
277
278#[derive(Debug, Deserialize)]
279pub(crate) struct EmbeddingDatum {
280 pub(crate) embedding: Vec<f32>,
281}