redis_vl/vectorizers/
voyageai.rs1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9use super::{AsyncVectorizer, Vectorizer};
10use crate::error::Result;
11
/// Connection and model settings for the VoyageAI embeddings API.
///
/// `Debug` is implemented by hand (instead of derived) so that formatting a
/// config with `{:?}` never writes the API key into logs or panic messages.
#[derive(Clone)]
pub struct VoyageAIConfig {
    /// Secret token sent as the bearer credential on every request.
    pub api_key: String,
    /// Model identifier sent as the request's `model` field (e.g. `"voyage-3-large"`).
    pub model: String,
    /// Optional `input_type` hint; left out of the request body when `None`.
    pub input_type: Option<String>,
}

impl std::fmt::Debug for VoyageAIConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VoyageAIConfig")
            // Never print the secret itself.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("input_type", &self.input_type)
            .finish()
    }
}
22
23impl VoyageAIConfig {
24 pub fn new(
26 api_key: impl Into<String>,
27 model: impl Into<String>,
28 input_type: Option<String>,
29 ) -> Self {
30 Self {
31 api_key: api_key.into(),
32 model: model.into(),
33 input_type,
34 }
35 }
36
37 pub fn from_env(model: impl Into<String>, input_type: Option<String>) -> Result<Self> {
39 let api_key = std::env::var("VOYAGE_API_KEY")
40 .map_err(|_| crate::error::Error::InvalidInput("VOYAGE_API_KEY not set".into()))?;
41 Ok(Self::new(api_key, model, input_type))
42 }
43}
44
/// VoyageAI embeddings endpoint; both the blocking and async paths POST here.
const VOYAGEAI_EMBED_URL: &str = "https://api.voyageai.com/v1/embeddings";
46
47#[derive(Serialize)]
48struct VoyageAIEmbedRequest<'a> {
49 model: &'a str,
50 input: Vec<&'a str>,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 input_type: Option<&'a str>,
53}
54
/// Subset of the embeddings response that this module reads.
///
/// Only `data` is deserialized; any other fields in the response JSON are ignored
/// (serde's default behavior for unknown fields).
#[derive(Deserialize)]
struct VoyageAIEmbedResponse {
    /// One entry per returned embedding, consumed in the order listed.
    data: Vec<VoyageAIEmbedDatum>,
}
59
/// A single entry of the response's `data` array.
#[derive(Deserialize)]
struct VoyageAIEmbedDatum {
    /// The embedding vector, decoded as `f32` components.
    embedding: Vec<f32>,
}
64
65#[derive(Debug, Clone)]
69pub struct VoyageAITextVectorizer {
70 config: VoyageAIConfig,
71 client: reqwest::Client,
72 blocking_client: reqwest::blocking::Client,
73}
74
75impl VoyageAITextVectorizer {
76 pub fn new(config: VoyageAIConfig) -> Self {
78 Self {
79 config,
80 client: reqwest::Client::new(),
81 blocking_client: reqwest::blocking::Client::new(),
82 }
83 }
84
85 fn build_request<'a>(&'a self, texts: &[&'a str]) -> VoyageAIEmbedRequest<'a> {
86 VoyageAIEmbedRequest {
87 model: &self.config.model,
88 input: texts.to_vec(),
89 input_type: self.config.input_type.as_deref(),
90 }
91 }
92
93 async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
94 let resp: VoyageAIEmbedResponse = self
95 .client
96 .post(VOYAGEAI_EMBED_URL)
97 .bearer_auth(&self.config.api_key)
98 .json(&self.build_request(texts))
99 .send()
100 .await?
101 .error_for_status()?
102 .json()
103 .await?;
104 Ok(resp.data.into_iter().map(|d| d.embedding).collect())
105 }
106}
107
108impl Vectorizer for VoyageAITextVectorizer {
109 fn embed(&self, text: &str) -> Result<Vec<f32>> {
110 let resp: VoyageAIEmbedResponse = self
111 .blocking_client
112 .post(VOYAGEAI_EMBED_URL)
113 .bearer_auth(&self.config.api_key)
114 .json(&self.build_request(&[text]))
115 .send()?
116 .error_for_status()?
117 .json()?;
118 Ok(resp
119 .data
120 .into_iter()
121 .next()
122 .map_or_else(Vec::new, |d| d.embedding))
123 }
124}
125
126#[async_trait]
127impl AsyncVectorizer for VoyageAITextVectorizer {
128 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
129 let mut v = self.embed_many_inner(&[text]).await?;
130 Ok(v.pop().unwrap_or_default())
131 }
132
133 async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
134 self.embed_many_inner(texts).await
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes the request a vectorizer configured with `input_type` would
    /// send for the single input `"hello"`.
    fn hello_request_json(input_type: Option<&str>) -> serde_json::Value {
        let config = VoyageAIConfig::new("k", "voyage-3-large", input_type.map(String::from));
        let vectorizer = VoyageAITextVectorizer::new(config);
        let request = vectorizer.build_request(&["hello"]);
        serde_json::to_value(&request).unwrap()
    }

    #[test]
    fn voyageai_config_stores_fields() {
        let cfg = VoyageAIConfig::new("key", "voyage-3-large", Some(String::from("document")));
        assert_eq!(
            (
                cfg.api_key.as_str(),
                cfg.model.as_str(),
                cfg.input_type.as_deref()
            ),
            ("key", "voyage-3-large", Some("document"))
        );
    }

    #[test]
    fn voyageai_request_serializes_with_input_type() {
        let json = hello_request_json(Some("query"));
        assert_eq!(json["model"], "voyage-3-large");
        assert_eq!(json["input"], serde_json::json!(["hello"]));
        assert_eq!(json["input_type"], "query");
    }

    #[test]
    fn voyageai_request_omits_none_input_type() {
        let json = hello_request_json(None);
        assert!(json.get("input_type").is_none());
    }

    #[test]
    fn voyageai_vectorizer_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<VoyageAITextVectorizer>();
    }
}