Skip to main content

argyph_embed/
voyage.rs

1use std::time::Duration;
2
3use serde::{Deserialize, Serialize};
4use tracing;
5
6use crate::api_key::ApiKey;
7use crate::config::EmbedConfig;
8use crate::error::{EmbedError, Result};
9
/// Base URL of the Voyage AI API, used when the configuration does not
/// supply its own `base_url`.
const VOYAGE_BASE_URL: &str = "https://api.voyageai.com";

/// Maximum number of whitespace-separated words kept per input text.
///
/// Despite the name, the limit is applied to words rather than to model
/// tokens (see `VoyageEmbedder::truncate_text`).
const MAX_TOKENS_PER_TEXT: usize = 16384;

/// Largest `batch_size` that `embed` accepts before returning
/// `EmbedError::BatchTooLarge`.
const MAX_BATCH_SIZE: usize = 128;

/// Model assigned to every embedder built by `VoyageEmbedder::with_api_key`.
const DEFAULT_MODEL: &str = "voyage-code-2";
14
15#[derive(Serialize)]
16struct VoyageEmbedRequest<'a> {
17    model: &'a str,
18    input: &'a [String],
19    input_type: &'a str,
20}
21
22#[derive(Deserialize)]
23struct VoyageEmbedResponse {
24    data: Vec<VoyageEmbeddingData>,
25}
26
27#[derive(Deserialize)]
28struct VoyageEmbeddingData {
29    embedding: Vec<f32>,
30}
31
32#[derive(Deserialize)]
33struct VoyageErrorResponse {
34    detail: Option<String>,
35}
36
37pub struct VoyageEmbedder {
38    api_key: ApiKey,
39    client: reqwest::Client,
40    config: EmbedConfig,
41    model: String,
42}
43
44impl VoyageEmbedder {
45    pub fn new(config: EmbedConfig) -> Result<Self> {
46        let api_key = ApiKey::from_env("VOYAGE_API_KEY")?;
47        Self::with_api_key(config, api_key)
48    }
49
50    pub fn with_api_key(config: EmbedConfig, api_key: ApiKey) -> Result<Self> {
51        let client = crate::http::build_client(&config)
52            .map_err(|e| EmbedError::Config(format!("failed to build HTTP client: {e}")))?;
53        Ok(Self {
54            api_key,
55            client,
56            config,
57            model: DEFAULT_MODEL.to_string(),
58        })
59    }
60
61    fn dimension_for_model(model: &str) -> usize {
62        match model {
63            "voyage-code-3" => 1024,
64            "voyage-large-2" => 512,
65            _ => 1536,
66        }
67    }
68
69    fn base_url(&self) -> &str {
70        self.config.base_url.as_deref().unwrap_or(VOYAGE_BASE_URL)
71    }
72
73    fn truncate_text(text: &str) -> String {
74        let words: Vec<&str> = text.split_whitespace().collect();
75        if words.len() <= MAX_TOKENS_PER_TEXT {
76            text.to_string()
77        } else {
78            words[..MAX_TOKENS_PER_TEXT].join(" ")
79        }
80    }
81}
82
83#[async_trait::async_trait]
84impl crate::Embedder for VoyageEmbedder {
85    fn dimension(&self) -> usize {
86        Self::dimension_for_model(&self.model)
87    }
88
89    fn model_id(&self) -> &str {
90        &self.model
91    }
92
93    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
94        if texts.is_empty() {
95            return Err(EmbedError::EmptyInput);
96        }
97
98        if self.config.batch_size > MAX_BATCH_SIZE {
99            return Err(EmbedError::BatchTooLarge {
100                batch_size: self.config.batch_size,
101                max_batch_size: MAX_BATCH_SIZE,
102            });
103        }
104
105        let truncated: Vec<String> = texts.iter().map(|t| Self::truncate_text(t)).collect();
106        let url = format!("{}/v1/embeddings", self.base_url());
107
108        let mut all_embeddings: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
109
110        for (batch_idx, chunk) in truncated.chunks(self.config.batch_size).enumerate() {
111            let batch: Vec<String> = chunk.to_vec();
112            let batch_start = batch_idx * self.config.batch_size;
113
114            tracing::debug!(
115                model = %self.model,
116                batch_index = batch_idx,
117                batch_size = batch.len(),
118                url = %url,
119                input_type = "document",
120                "sending embedding request"
121            );
122
123            let response_data = self.send_with_retry(&url, &batch, "document").await?;
124
125            for (i, data) in response_data.into_iter().enumerate() {
126                let global_idx = batch_start + i;
127                if global_idx < all_embeddings.len() {
128                    all_embeddings[global_idx] = Some(data.embedding);
129                }
130            }
131
132            tracing::info!(
133                model = %self.model,
134                batch_index = batch_idx,
135                batch_size = batch.len(),
136                "batch embedding completed"
137            );
138        }
139
140        all_embeddings
141            .into_iter()
142            .collect::<Option<Vec<_>>>()
143            .ok_or_else(|| EmbedError::InvalidResponse("missing embeddings in response".into()))
144    }
145
146    async fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
147        if query.is_empty() {
148            return Err(EmbedError::EmptyInput);
149        }
150
151        let truncated = Self::truncate_text(query);
152        let url = format!("{}/v1/embeddings", self.base_url());
153        let batch = vec![truncated];
154
155        tracing::debug!(
156            model = %self.model,
157            url = %url,
158            input_type = "query",
159            "sending query embedding request"
160        );
161
162        let mut response_data = self.send_with_retry(&url, &batch, "query").await?;
163
164        response_data
165            .pop()
166            .map(|d| d.embedding)
167            .ok_or_else(|| EmbedError::InvalidResponse("empty response for query embedding".into()))
168    }
169}
170
171impl VoyageEmbedder {
172    async fn send_with_retry(
173        &self,
174        url: &str,
175        batch: &[String],
176        input_type: &str,
177    ) -> Result<Vec<VoyageEmbeddingData>> {
178        let request_body = VoyageEmbedRequest {
179            model: &self.model,
180            input: batch,
181            input_type,
182        };
183
184        let mut last_error: Option<EmbedError> = None;
185
186        for attempt in 0..=self.config.max_retries {
187            if attempt > 0 {
188                let delay = self.config.base_delay * 2u32.pow(attempt - 1);
189                tokio::time::sleep(delay).await;
190            }
191
192            let response = self
193                .client
194                .post(url)
195                .bearer_auth(&*self.api_key)
196                .json(&request_body)
197                .send()
198                .await;
199
200            match response {
201                Ok(resp) => {
202                    let status = resp.status();
203
204                    if status.is_success() {
205                        match resp.json::<VoyageEmbedResponse>().await {
206                            Ok(parsed) => return Ok(parsed.data),
207                            Err(e) => {
208                                last_error = Some(EmbedError::InvalidResponse(format!(
209                                    "failed to parse response: {e}"
210                                )));
211                                break;
212                            }
213                        }
214                    }
215
216                    if status.as_u16() == 429 {
217                        let retry_after = resp
218                            .headers()
219                            .get("retry-after")
220                            .and_then(|v| v.to_str().ok())
221                            .and_then(|v| v.parse::<u64>().ok())
222                            .map(Duration::from_secs);
223                        return Err(EmbedError::RateLimited { retry_after });
224                    }
225
226                    if status.as_u16() == 401 || status.as_u16() == 403 {
227                        let body = resp.text().await.unwrap_or_default();
228                        return Err(EmbedError::Auth(body));
229                    }
230
231                    let body_text = resp.text().await.unwrap_or_default();
232                    let detail = serde_json::from_str::<VoyageErrorResponse>(&body_text)
233                        .ok()
234                        .and_then(|e| e.detail);
235                    let error_msg = if let Some(d) = detail {
236                        format!("HTTP {}: {}", status.as_u16(), d)
237                    } else {
238                        format!("HTTP {}: {}", status.as_u16(), body_text)
239                    };
240                    last_error = Some(EmbedError::Http(error_msg));
241                }
242                Err(e) => {
243                    last_error = Some(EmbedError::Http(e.to_string()));
244                }
245            }
246        }
247
248        Err(last_error.unwrap_or_else(|| EmbedError::Http("unknown error".into())))
249    }
250}
251
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::api_key::ApiKey;
    use crate::config::EmbedConfig;
    use crate::Embedder;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Default configuration whose base URL points at the mock server.
    fn test_config(base_url: String) -> EmbedConfig {
        EmbedConfig {
            base_url: Some(base_url),
            ..EmbedConfig::default()
        }
    }

    /// Same as `test_config`, but with a batch size of 64.
    fn test_config_batch64(base_url: String) -> EmbedConfig {
        EmbedConfig {
            batch_size: 64,
            ..test_config(base_url)
        }
    }

    /// Dummy API key used by every mocked test.
    fn test_api_key() -> ApiKey {
        ApiKey::from("vp-test-key")
    }

    /// Builds an embedder from `config` using the dummy API key.
    fn embedder_with(config: EmbedConfig) -> VoyageEmbedder {
        VoyageEmbedder::with_api_key(config, test_api_key()).unwrap()
    }

    /// Builds a response payload shaped like the embeddings endpoint's output.
    fn make_voyage_response(embeddings: Vec<Vec<f32>>) -> serde_json::Value {
        let mut data = Vec::with_capacity(embeddings.len());
        for embedding in embeddings {
            data.push(json!({
                "object": "embedding",
                "embedding": embedding,
            }));
        }

        json!({
            "object": "list",
            "data": data,
            "model": "voyage-code-2",
        })
    }

    #[tokio::test]
    async fn happy_path_returns_correct_vectors() {
        let server = MockServer::start().await;
        let expected = vec![vec![0.1_f32, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let payload = make_voyage_response(expected.clone());

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(ResponseTemplate::new(200).set_body_json(payload))
            .expect(1)
            .mount(&server)
            .await;

        let embedder = embedder_with(test_config(server.uri()));

        let texts = vec![String::from("hello"), String::from("world")];
        let vectors = embedder.embed(&texts).await.unwrap();

        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0], expected[0]);
        assert_eq!(vectors[1], expected[1]);
    }

    #[tokio::test]
    async fn auth_failure_401_returns_auth_error() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(ResponseTemplate::new(401).set_body_string("invalid api key"))
            .expect(1)
            .mount(&server)
            .await;

        let embedder = embedder_with(test_config(server.uri()));

        let texts = vec![String::from("hello")];
        let outcome = embedder.embed(&texts).await;

        assert!(outcome.is_err());
        let error = outcome.unwrap_err();
        assert!(
            matches!(error, EmbedError::Auth(_)),
            "expected Auth error, got: {error:?}"
        );
    }

    #[tokio::test]
    async fn rate_limit_429_returns_rate_limited_error() {
        let server = MockServer::start().await;
        let throttled = ResponseTemplate::new(429)
            .set_body_string("rate limited")
            .insert_header("retry-after", "42");

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(throttled)
            .expect(1)
            .mount(&server)
            .await;

        let embedder = embedder_with(test_config(server.uri()));

        let texts = vec![String::from("hello")];
        let outcome = embedder.embed(&texts).await;

        assert!(outcome.is_err());
        match outcome.unwrap_err() {
            EmbedError::RateLimited { retry_after } => {
                assert_eq!(retry_after, Some(Duration::from_secs(42)));
            }
            other => panic!("expected RateLimited error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn batching_splits_200_texts_into_4_batches() {
        let server = MockServer::start().await;

        // Answer each request with one fixed vector per input it carried.
        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(|req: &wiremock::Request| {
                let body: serde_json::Value = serde_json::from_slice(&req.body).unwrap_or_default();
                let count = body["input"].as_array().map_or(0, Vec::len);
                let payload = make_voyage_response(vec![vec![0.1, 0.2, 0.3]; count]);
                ResponseTemplate::new(200).set_body_json(payload)
            })
            .expect(4)
            .mount(&server)
            .await;

        let embedder = embedder_with(test_config_batch64(server.uri()));

        let texts: Vec<String> = (0..200).map(|i| format!("text {i}")).collect();
        let vectors = embedder.embed(&texts).await.unwrap();

        assert_eq!(vectors.len(), 200);
        for vector in &vectors {
            assert_eq!(vector, &vec![0.1_f32, 0.2, 0.3]);
        }
    }

    #[tokio::test]
    async fn empty_input_returns_empty_input_error() {
        let server = MockServer::start().await;
        let embedder = embedder_with(test_config(server.uri()));

        let texts: Vec<String> = Vec::new();
        let outcome = embedder.embed(&texts).await;

        assert!(outcome.is_err());
        let error = outcome.unwrap_err();
        assert!(
            matches!(error, EmbedError::EmptyInput),
            "expected EmptyInput error, got: {error:?}"
        );
    }

    #[tokio::test]
    async fn embed_query_uses_input_type_query() {
        let server = MockServer::start().await;
        let expected = vec![0.1_f32, 0.2, 0.3];
        let payload = make_voyage_response(vec![expected.clone()]);

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(move |req: &wiremock::Request| {
                let body: serde_json::Value = serde_json::from_slice(&req.body).unwrap_or_default();
                let input_type = body["input_type"].as_str().unwrap_or("");
                assert_eq!(
                    input_type, "query",
                    "embed_query must send input_type: query"
                );
                ResponseTemplate::new(200).set_body_json(payload.clone())
            })
            .expect(1)
            .mount(&server)
            .await;

        let embedder = embedder_with(test_config(server.uri()));

        let vector = embedder.embed_query("hello").await.unwrap();
        assert_eq!(vector, expected);
    }

    /// Smoke test against the real API; it does nothing unless
    /// `VOYAGE_API_KEY` is set.
    #[cfg(feature = "live-providers")]
    #[tokio::test]
    async fn voyage_live_smoke() {
        if std::env::var("VOYAGE_API_KEY").is_err() {
            return;
        }
        let embedder = VoyageEmbedder::new(EmbedConfig::default()).unwrap();

        assert_eq!(embedder.dimension(), 1536);
        assert_eq!(embedder.model_id(), "voyage-code-2");

        let texts = vec![String::from("hello world"), String::from("goodbye world")];
        let vectors = embedder.embed(&texts).await.unwrap();

        assert_eq!(vectors.len(), 2);
        for vector in &vectors {
            assert_eq!(vector.len(), 1536);
            let sum: f32 = vector.iter().sum();
            assert!(sum != 0.0, "embedding should not be all zeros");
        }
    }
}