Skip to main content

redis_vl/vectorizers/
mod.rs

//! Embedding provider abstractions and OpenAI-compatible adapters.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7
8use crate::error::Result;
9
10#[cfg(feature = "anthropic")]
11mod anthropic;
12#[cfg(feature = "anthropic")]
13pub use self::anthropic::{AnthropicConfig, AnthropicTextVectorizer};
14
15#[cfg(feature = "azure-openai")]
16mod azure_openai;
17#[cfg(feature = "azure-openai")]
18pub use azure_openai::{AzureOpenAIConfig, AzureOpenAITextVectorizer};
19
20#[cfg(feature = "bedrock")]
21mod bedrock;
22#[cfg(feature = "bedrock")]
23pub use self::bedrock::{BedrockConfig, BedrockTextVectorizer};
24
25#[cfg(feature = "cohere")]
26mod cohere;
27#[cfg(feature = "cohere")]
28pub use self::cohere::{CohereConfig, CohereTextVectorizer};
29
30#[cfg(feature = "hf-local")]
31mod hf_local;
32#[cfg(feature = "hf-local")]
33pub use self::hf_local::{HuggingFaceConfig, HuggingFaceTextVectorizer};
34
35#[cfg(feature = "mistral")]
36mod mistral;
37#[cfg(feature = "mistral")]
38pub use self::mistral::{MistralAITextVectorizer, MistralConfig};
39
40#[cfg(feature = "voyageai")]
41mod voyageai;
42#[cfg(feature = "voyageai")]
43pub use self::voyageai::{VoyageAIConfig, VoyageAITextVectorizer};
44
45#[cfg(feature = "vertex-ai")]
46mod vertex_ai;
47#[cfg(feature = "vertex-ai")]
48pub use self::vertex_ai::{VertexAIConfig, VertexAITextVectorizer};
49
50/// Shared embedding request payload.
51#[derive(Debug, Clone, Serialize)]
52pub struct EmbeddingRequest<'a> {
53    /// Model name.
54    pub model: &'a str,
55    /// Input texts.
56    pub input: Vec<&'a str>,
57}
58
59/// Synchronous vectorizer abstraction.
60pub trait Vectorizer: Send + Sync {
61    /// Embeds a single string.
62    fn embed(&self, text: &str) -> Result<Vec<f32>>;
63
64    /// Embeds many strings.
65    fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
66        texts.iter().map(|text| self.embed(text)).collect()
67    }
68}
69
70/// Asynchronous vectorizer abstraction.
71#[async_trait]
72pub trait AsyncVectorizer: Send + Sync {
73    /// Embeds a single string asynchronously.
74    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
75
76    /// Embeds many strings asynchronously.
77    async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
78        let mut embeddings = Vec::with_capacity(texts.len());
79        for text in texts {
80            embeddings.push(self.embed(text).await?);
81        }
82        Ok(embeddings)
83    }
84}
85
86/// Shared configuration for OpenAI-compatible embedding providers.
87#[derive(Debug, Clone)]
88pub struct OpenAICompatibleConfig {
89    /// Base URL for the provider.
90    pub base_url: url::Url,
91    /// API key used for authentication.
92    pub api_key: String,
93    /// Embedding model name.
94    pub model: String,
95}
96
97impl OpenAICompatibleConfig {
98    /// Creates a new OpenAI-compatible config.
99    pub fn new(
100        base_url: impl AsRef<str>,
101        api_key: impl Into<String>,
102        model: impl Into<String>,
103    ) -> Result<Self> {
104        Ok(Self {
105            base_url: url::Url::parse(base_url.as_ref())?,
106            api_key: api_key.into(),
107            model: model.into(),
108        })
109    }
110
111    fn embeddings_url(&self) -> Result<url::Url> {
112        Ok(self.base_url.join("embeddings")?)
113    }
114}
115
116/// OpenAI embedding adapter.
117#[derive(Debug, Clone)]
118pub struct OpenAITextVectorizer {
119    config: OpenAICompatibleConfig,
120    client: reqwest::Client,
121    blocking_client: reqwest::blocking::Client,
122}
123
124impl OpenAITextVectorizer {
125    /// Creates a new OpenAI adapter.
126    pub fn new(config: OpenAICompatibleConfig) -> Self {
127        Self {
128            config,
129            client: reqwest::Client::new(),
130            blocking_client: reqwest::blocking::Client::new(),
131        }
132    }
133
134    async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
135        let response: EmbeddingResponse = self
136            .client
137            .post(self.config.embeddings_url()?)
138            .bearer_auth(&self.config.api_key)
139            .json(&EmbeddingRequest {
140                model: &self.config.model,
141                input: texts.to_vec(),
142            })
143            .send()
144            .await?
145            .error_for_status()?
146            .json()
147            .await?;
148        Ok(response
149            .data
150            .into_iter()
151            .map(|item| item.embedding)
152            .collect())
153    }
154}
155
156impl Vectorizer for OpenAITextVectorizer {
157    fn embed(&self, text: &str) -> Result<Vec<f32>> {
158        let response: EmbeddingResponse = self
159            .blocking_client
160            .post(self.config.embeddings_url()?)
161            .bearer_auth(&self.config.api_key)
162            .json(&EmbeddingRequest {
163                model: &self.config.model,
164                input: vec![text],
165            })
166            .send()?
167            .error_for_status()?
168            .json()?;
169        Ok(response
170            .data
171            .into_iter()
172            .next()
173            .map_or_else(Vec::new, |item| item.embedding))
174    }
175
176    fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
177        let response: EmbeddingResponse = self
178            .blocking_client
179            .post(self.config.embeddings_url()?)
180            .bearer_auth(&self.config.api_key)
181            .json(&EmbeddingRequest {
182                model: &self.config.model,
183                input: texts.to_vec(),
184            })
185            .send()?
186            .error_for_status()?
187            .json()?;
188        Ok(response
189            .data
190            .into_iter()
191            .map(|item| item.embedding)
192            .collect())
193    }
194}
195
196#[async_trait]
197impl AsyncVectorizer for OpenAITextVectorizer {
198    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
199        let mut embeddings = self.embed_many_inner(&[text]).await?;
200        Ok(embeddings.pop().unwrap_or_default())
201    }
202
203    async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
204        self.embed_many_inner(texts).await
205    }
206}
207
208/// LiteLLM embedding adapter built on the same OpenAI-compatible transport.
209#[derive(Debug, Clone)]
210pub struct LiteLLMTextVectorizer {
211    inner: OpenAITextVectorizer,
212}
213
214impl LiteLLMTextVectorizer {
215    /// Creates a new LiteLLM adapter.
216    pub fn new(config: OpenAICompatibleConfig) -> Self {
217        Self {
218            inner: OpenAITextVectorizer::new(config),
219        }
220    }
221}
222
223impl Vectorizer for LiteLLMTextVectorizer {
224    fn embed(&self, text: &str) -> Result<Vec<f32>> {
225        Vectorizer::embed(&self.inner, text)
226    }
227
228    fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
229        Vectorizer::embed_many(&self.inner, texts)
230    }
231}
232
233#[async_trait]
234impl AsyncVectorizer for LiteLLMTextVectorizer {
235    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
236        AsyncVectorizer::embed(&self.inner, text).await
237    }
238
239    async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
240        AsyncVectorizer::embed_many(&self.inner, texts).await
241    }
242}
243
244/// Custom synchronous vectorizer backed by a user callback.
245pub struct CustomTextVectorizer<F>
246where
247    F: Fn(&str) -> Result<Vec<f32>> + Send + Sync + 'static,
248{
249    embedder: Arc<F>,
250}
251
252impl<F> CustomTextVectorizer<F>
253where
254    F: Fn(&str) -> Result<Vec<f32>> + Send + Sync + 'static,
255{
256    /// Creates a custom synchronous vectorizer.
257    pub fn new(embedder: F) -> Self {
258        Self {
259            embedder: Arc::new(embedder),
260        }
261    }
262}
263
264impl<F> Vectorizer for CustomTextVectorizer<F>
265where
266    F: Fn(&str) -> Result<Vec<f32>> + Send + Sync + 'static,
267{
268    fn embed(&self, text: &str) -> Result<Vec<f32>> {
269        (self.embedder)(text)
270    }
271}
272
273#[derive(Debug, Deserialize)]
274pub(crate) struct EmbeddingResponse {
275    pub(crate) data: Vec<EmbeddingDatum>,
276}
277
278#[derive(Debug, Deserialize)]
279pub(crate) struct EmbeddingDatum {
280    pub(crate) embedding: Vec<f32>,
281}