Skip to main content

redis_vl/vectorizers/
voyageai.rs

//! VoyageAI embedding adapter.
//!
//! Enabled by the `voyageai` feature flag. VoyageAI has its own REST API shape
//! at `https://api.voyageai.com/v1/embeddings`.
5
use std::fmt;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use super::{AsyncVectorizer, Vectorizer};
use crate::error::Result;
11
/// Configuration for the VoyageAI embedding provider.
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// never written into logs, error chains, or panic messages; the remaining
/// fields are printed as usual.
#[derive(Clone)]
pub struct VoyageAIConfig {
    /// API key for VoyageAI.
    pub api_key: String,
    /// Embedding model name (e.g. `voyage-3-large`).
    pub model: String,
    /// The VoyageAI `input_type` to use (e.g. `"document"`, `"query"`).
    pub input_type: Option<String>,
}

impl fmt::Debug for VoyageAIConfig {
    /// Formats the config like a derived `Debug` would, but with the API key
    /// replaced by a fixed placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("VoyageAIConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("input_type", &self.input_type)
            .finish()
    }
}
22
23impl VoyageAIConfig {
24    /// Creates a new VoyageAI config.
25    pub fn new(
26        api_key: impl Into<String>,
27        model: impl Into<String>,
28        input_type: Option<String>,
29    ) -> Self {
30        Self {
31            api_key: api_key.into(),
32            model: model.into(),
33            input_type,
34        }
35    }
36
37    /// Constructs from `VOYAGE_API_KEY` environment variable.
38    pub fn from_env(model: impl Into<String>, input_type: Option<String>) -> Result<Self> {
39        let api_key = std::env::var("VOYAGE_API_KEY")
40            .map_err(|_| crate::error::Error::InvalidInput("VOYAGE_API_KEY not set".into()))?;
41        Ok(Self::new(api_key, model, input_type))
42    }
43}
44
/// Endpoint that every embedding request in this module is POSTed to.
const VOYAGEAI_EMBED_URL: &str = "https://api.voyageai.com/v1/embeddings";
46
47#[derive(Serialize)]
48struct VoyageAIEmbedRequest<'a> {
49    model: &'a str,
50    input: Vec<&'a str>,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    input_type: Option<&'a str>,
53}
54
55#[derive(Deserialize)]
56struct VoyageAIEmbedResponse {
57    data: Vec<VoyageAIEmbedDatum>,
58}
59
60#[derive(Deserialize)]
61struct VoyageAIEmbedDatum {
62    embedding: Vec<f32>,
63}
64
65/// VoyageAI embedding adapter.
66///
67/// Uses the VoyageAI `/v1/embeddings` REST API.
68#[derive(Debug, Clone)]
69pub struct VoyageAITextVectorizer {
70    config: VoyageAIConfig,
71    client: reqwest::Client,
72    blocking_client: reqwest::blocking::Client,
73}
74
75impl VoyageAITextVectorizer {
76    /// Creates a new VoyageAI adapter.
77    pub fn new(config: VoyageAIConfig) -> Self {
78        Self {
79            config,
80            client: reqwest::Client::new(),
81            blocking_client: reqwest::blocking::Client::new(),
82        }
83    }
84
85    fn build_request<'a>(&'a self, texts: &[&'a str]) -> VoyageAIEmbedRequest<'a> {
86        VoyageAIEmbedRequest {
87            model: &self.config.model,
88            input: texts.to_vec(),
89            input_type: self.config.input_type.as_deref(),
90        }
91    }
92
93    async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
94        let resp: VoyageAIEmbedResponse = self
95            .client
96            .post(VOYAGEAI_EMBED_URL)
97            .bearer_auth(&self.config.api_key)
98            .json(&self.build_request(texts))
99            .send()
100            .await?
101            .error_for_status()?
102            .json()
103            .await?;
104        Ok(resp.data.into_iter().map(|d| d.embedding).collect())
105    }
106}
107
108impl Vectorizer for VoyageAITextVectorizer {
109    fn embed(&self, text: &str) -> Result<Vec<f32>> {
110        let resp: VoyageAIEmbedResponse = self
111            .blocking_client
112            .post(VOYAGEAI_EMBED_URL)
113            .bearer_auth(&self.config.api_key)
114            .json(&self.build_request(&[text]))
115            .send()?
116            .error_for_status()?
117            .json()?;
118        Ok(resp
119            .data
120            .into_iter()
121            .next()
122            .map_or_else(Vec::new, |d| d.embedding))
123    }
124}
125
126#[async_trait]
127impl AsyncVectorizer for VoyageAITextVectorizer {
128    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
129        let mut v = self.embed_many_inner(&[text]).await?;
130        Ok(v.pop().unwrap_or_default())
131    }
132
133    async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
134        self.embed_many_inner(texts).await
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// `VoyageAIConfig::new` stores each argument in the matching field.
    #[test]
    fn voyageai_config_stores_fields() {
        let cfg = VoyageAIConfig::new("key", "voyage-3-large", Some("document".into()));
        assert_eq!(cfg.api_key, "key");
        assert_eq!(cfg.model, "voyage-3-large");
        assert_eq!(cfg.input_type.as_deref(), Some("document"));
    }

    /// A configured `input_type` appears in the serialized request body.
    #[test]
    fn voyageai_request_serializes_with_input_type() {
        let cfg = VoyageAIConfig::new("k", "voyage-3-large", Some("query".into()));
        let v = VoyageAITextVectorizer::new(cfg);
        let body = v.build_request(&["hello"]);
        let json = serde_json::to_value(&body).unwrap();
        assert_eq!(json["model"], "voyage-3-large");
        assert_eq!(json["input"], serde_json::json!(["hello"]));
        assert_eq!(json["input_type"], "query");
    }

    /// With no `input_type`, the key is absent rather than serialized as null.
    #[test]
    fn voyageai_request_omits_none_input_type() {
        let cfg = VoyageAIConfig::new("k", "voyage-3-large", None);
        let v = VoyageAITextVectorizer::new(cfg);
        let body = v.build_request(&["hello"]);
        let json = serde_json::to_value(&body).unwrap();
        assert!(json.get("input_type").is_none());
    }

    /// Compile-time check that the vectorizer can be shared across threads.
    #[test]
    fn voyageai_vectorizer_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<VoyageAITextVectorizer>();
    }
}