1use std::time::Duration;
2
3use serde::{Deserialize, Serialize};
4use tracing;
5
6use crate::api_key::ApiKey;
7use crate::config::EmbedConfig;
8use crate::error::{EmbedError, Result};
9
/// Base URL of the hosted Voyage AI API, used when the config carries no `base_url` override.
const VOYAGE_BASE_URL: &str = "https://api.voyageai.com";
/// Per-text cap applied by `truncate_text`, counted in whitespace-separated words.
const MAX_TOKENS_PER_TEXT: usize = 16_384;
/// Largest configured `batch_size` that `embed` accepts before returning `BatchTooLarge`.
const MAX_BATCH_SIZE: usize = 128;
/// Model identifier assigned to every embedder built by `with_api_key`.
const DEFAULT_MODEL: &str = "voyage-code-2";
14
15#[derive(Serialize)]
16struct VoyageEmbedRequest<'a> {
17 model: &'a str,
18 input: &'a [String],
19 input_type: &'a str,
20}
21
22#[derive(Deserialize)]
23struct VoyageEmbedResponse {
24 data: Vec<VoyageEmbeddingData>,
25}
26
27#[derive(Deserialize)]
28struct VoyageEmbeddingData {
29 embedding: Vec<f32>,
30}
31
32#[derive(Deserialize)]
33struct VoyageErrorResponse {
34 detail: Option<String>,
35}
36
37pub struct VoyageEmbedder {
38 api_key: ApiKey,
39 client: reqwest::Client,
40 config: EmbedConfig,
41 model: String,
42}
43
44impl VoyageEmbedder {
45 pub fn new(config: EmbedConfig) -> Result<Self> {
46 let api_key = ApiKey::from_env("VOYAGE_API_KEY")?;
47 Self::with_api_key(config, api_key)
48 }
49
50 pub fn with_api_key(config: EmbedConfig, api_key: ApiKey) -> Result<Self> {
51 let client = crate::http::build_client(&config)
52 .map_err(|e| EmbedError::Config(format!("failed to build HTTP client: {e}")))?;
53 Ok(Self {
54 api_key,
55 client,
56 config,
57 model: DEFAULT_MODEL.to_string(),
58 })
59 }
60
61 fn dimension_for_model(model: &str) -> usize {
62 match model {
63 "voyage-code-3" => 1024,
64 "voyage-large-2" => 512,
65 _ => 1536,
66 }
67 }
68
69 fn base_url(&self) -> &str {
70 self.config.base_url.as_deref().unwrap_or(VOYAGE_BASE_URL)
71 }
72
73 fn truncate_text(text: &str) -> String {
74 let words: Vec<&str> = text.split_whitespace().collect();
75 if words.len() <= MAX_TOKENS_PER_TEXT {
76 text.to_string()
77 } else {
78 words[..MAX_TOKENS_PER_TEXT].join(" ")
79 }
80 }
81}
82
83#[async_trait::async_trait]
84impl crate::Embedder for VoyageEmbedder {
85 fn dimension(&self) -> usize {
86 Self::dimension_for_model(&self.model)
87 }
88
89 fn model_id(&self) -> &str {
90 &self.model
91 }
92
93 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
94 if texts.is_empty() {
95 return Err(EmbedError::EmptyInput);
96 }
97
98 if self.config.batch_size > MAX_BATCH_SIZE {
99 return Err(EmbedError::BatchTooLarge {
100 batch_size: self.config.batch_size,
101 max_batch_size: MAX_BATCH_SIZE,
102 });
103 }
104
105 let truncated: Vec<String> = texts.iter().map(|t| Self::truncate_text(t)).collect();
106 let url = format!("{}/v1/embeddings", self.base_url());
107
108 let mut all_embeddings: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
109
110 for (batch_idx, chunk) in truncated.chunks(self.config.batch_size).enumerate() {
111 let batch: Vec<String> = chunk.to_vec();
112 let batch_start = batch_idx * self.config.batch_size;
113
114 tracing::debug!(
115 model = %self.model,
116 batch_index = batch_idx,
117 batch_size = batch.len(),
118 url = %url,
119 input_type = "document",
120 "sending embedding request"
121 );
122
123 let response_data = self.send_with_retry(&url, &batch, "document").await?;
124
125 for (i, data) in response_data.into_iter().enumerate() {
126 let global_idx = batch_start + i;
127 if global_idx < all_embeddings.len() {
128 all_embeddings[global_idx] = Some(data.embedding);
129 }
130 }
131
132 tracing::info!(
133 model = %self.model,
134 batch_index = batch_idx,
135 batch_size = batch.len(),
136 "batch embedding completed"
137 );
138 }
139
140 all_embeddings
141 .into_iter()
142 .collect::<Option<Vec<_>>>()
143 .ok_or_else(|| EmbedError::InvalidResponse("missing embeddings in response".into()))
144 }
145
146 async fn embed_query(&self, query: &str) -> Result<Vec<f32>> {
147 if query.is_empty() {
148 return Err(EmbedError::EmptyInput);
149 }
150
151 let truncated = Self::truncate_text(query);
152 let url = format!("{}/v1/embeddings", self.base_url());
153 let batch = vec![truncated];
154
155 tracing::debug!(
156 model = %self.model,
157 url = %url,
158 input_type = "query",
159 "sending query embedding request"
160 );
161
162 let mut response_data = self.send_with_retry(&url, &batch, "query").await?;
163
164 response_data
165 .pop()
166 .map(|d| d.embedding)
167 .ok_or_else(|| EmbedError::InvalidResponse("empty response for query embedding".into()))
168 }
169}
170
171impl VoyageEmbedder {
172 async fn send_with_retry(
173 &self,
174 url: &str,
175 batch: &[String],
176 input_type: &str,
177 ) -> Result<Vec<VoyageEmbeddingData>> {
178 let request_body = VoyageEmbedRequest {
179 model: &self.model,
180 input: batch,
181 input_type,
182 };
183
184 let mut last_error: Option<EmbedError> = None;
185
186 for attempt in 0..=self.config.max_retries {
187 if attempt > 0 {
188 let delay = self.config.base_delay * 2u32.pow(attempt - 1);
189 tokio::time::sleep(delay).await;
190 }
191
192 let response = self
193 .client
194 .post(url)
195 .bearer_auth(&*self.api_key)
196 .json(&request_body)
197 .send()
198 .await;
199
200 match response {
201 Ok(resp) => {
202 let status = resp.status();
203
204 if status.is_success() {
205 match resp.json::<VoyageEmbedResponse>().await {
206 Ok(parsed) => return Ok(parsed.data),
207 Err(e) => {
208 last_error = Some(EmbedError::InvalidResponse(format!(
209 "failed to parse response: {e}"
210 )));
211 break;
212 }
213 }
214 }
215
216 if status.as_u16() == 429 {
217 let retry_after = resp
218 .headers()
219 .get("retry-after")
220 .and_then(|v| v.to_str().ok())
221 .and_then(|v| v.parse::<u64>().ok())
222 .map(Duration::from_secs);
223 return Err(EmbedError::RateLimited { retry_after });
224 }
225
226 if status.as_u16() == 401 || status.as_u16() == 403 {
227 let body = resp.text().await.unwrap_or_default();
228 return Err(EmbedError::Auth(body));
229 }
230
231 let body_text = resp.text().await.unwrap_or_default();
232 let detail = serde_json::from_str::<VoyageErrorResponse>(&body_text)
233 .ok()
234 .and_then(|e| e.detail);
235 let error_msg = if let Some(d) = detail {
236 format!("HTTP {}: {}", status.as_u16(), d)
237 } else {
238 format!("HTTP {}: {}", status.as_u16(), body_text)
239 };
240 last_error = Some(EmbedError::Http(error_msg));
241 }
242 Err(e) => {
243 last_error = Some(EmbedError::Http(e.to_string()));
244 }
245 }
246 }
247
248 Err(last_error.unwrap_or_else(|| EmbedError::Http("unknown error".into())))
249 }
250}
251
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::api_key::ApiKey;
    use crate::config::EmbedConfig;
    use crate::Embedder;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Default configuration redirected to `base_url` (the mock server).
    fn test_config(base_url: String) -> EmbedConfig {
        EmbedConfig {
            base_url: Some(base_url),
            ..EmbedConfig::default()
        }
    }

    /// Same as [`test_config`], but with a batch size of 64.
    fn test_config_batch64(base_url: String) -> EmbedConfig {
        EmbedConfig {
            batch_size: 64,
            ..test_config(base_url)
        }
    }

    /// Fixed placeholder key used by every mocked test.
    fn test_api_key() -> ApiKey {
        ApiKey::from("vp-test-key")
    }

    /// Builds an embedder from `config` using the placeholder key.
    fn embedder_from(config: EmbedConfig) -> VoyageEmbedder {
        VoyageEmbedder::with_api_key(config, test_api_key()).unwrap()
    }

    /// Wraps `embeddings` in the JSON envelope returned by the embeddings endpoint.
    fn make_voyage_response(embeddings: Vec<Vec<f32>>) -> serde_json::Value {
        let mut data = Vec::with_capacity(embeddings.len());
        for embedding in embeddings {
            data.push(json!({
                "object": "embedding",
                "embedding": embedding,
            }));
        }

        json!({
            "object": "list",
            "data": data,
            "model": "voyage-code-2",
        })
    }

    #[tokio::test]
    async fn happy_path_returns_correct_vectors() {
        let server = MockServer::start().await;
        let expected = vec![vec![0.1_f32, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        let body = make_voyage_response(expected.clone());

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .expect(1)
            .mount(&server)
            .await;

        let embedder = embedder_from(test_config(server.uri()));

        let texts = vec![String::from("hello"), String::from("world")];
        let vectors = embedder.embed(&texts).await.unwrap();

        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0], expected[0]);
        assert_eq!(vectors[1], expected[1]);
    }

    #[tokio::test]
    async fn auth_failure_401_returns_auth_error() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(ResponseTemplate::new(401).set_body_string("invalid api key"))
            .expect(1)
            .mount(&server)
            .await;

        let embedder = embedder_from(test_config(server.uri()));

        let texts = vec![String::from("hello")];
        let outcome = embedder.embed(&texts).await;

        assert!(outcome.is_err());
        let other = outcome.unwrap_err();
        if !matches!(other, EmbedError::Auth(_)) {
            panic!("expected Auth error, got: {other:?}");
        }
    }

    #[tokio::test]
    async fn rate_limit_429_returns_rate_limited_error() {
        let server = MockServer::start().await;
        let throttled = ResponseTemplate::new(429)
            .set_body_string("rate limited")
            .insert_header("retry-after", "42");

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(throttled)
            .expect(1)
            .mount(&server)
            .await;

        let embedder = embedder_from(test_config(server.uri()));

        let texts = vec![String::from("hello")];
        let outcome = embedder.embed(&texts).await;

        assert!(outcome.is_err());
        match outcome.unwrap_err() {
            EmbedError::RateLimited { retry_after } => {
                assert_eq!(retry_after, Some(Duration::from_secs(42)));
            }
            other => panic!("expected RateLimited error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn batching_splits_200_texts_into_4_batches() {
        let server = MockServer::start().await;

        // Echo back exactly as many embeddings as the request carried inputs.
        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(|req: &wiremock::Request| {
                let body: serde_json::Value = serde_json::from_slice(&req.body).unwrap_or_default();
                let input_len = body["input"].as_array().map_or(0, Vec::len);
                let embeddings = vec![vec![0.1_f32, 0.2, 0.3]; input_len];
                ResponseTemplate::new(200).set_body_json(make_voyage_response(embeddings))
            })
            .expect(4)
            .mount(&server)
            .await;

        let embedder = embedder_from(test_config_batch64(server.uri()));

        let mut texts: Vec<String> = Vec::with_capacity(200);
        for i in 0..200 {
            texts.push(format!("text {i}"));
        }
        let vectors = embedder.embed(&texts).await.unwrap();

        assert_eq!(vectors.len(), 200);
        for embedding in &vectors {
            assert_eq!(embedding, &vec![0.1_f32, 0.2, 0.3]);
        }
    }

    #[tokio::test]
    async fn empty_input_returns_empty_input_error() {
        let server = MockServer::start().await;
        let embedder = embedder_from(test_config(server.uri()));

        let texts: Vec<String> = Vec::new();
        let outcome = embedder.embed(&texts).await;

        assert!(outcome.is_err());
        let other = outcome.unwrap_err();
        if !matches!(other, EmbedError::EmptyInput) {
            panic!("expected EmptyInput error, got: {other:?}");
        }
    }

    #[tokio::test]
    async fn embed_query_uses_input_type_query() {
        let server = MockServer::start().await;
        let expected = vec![0.1_f32, 0.2, 0.3];
        let response_value = make_voyage_response(vec![expected.clone()]);

        Mock::given(method("POST"))
            .and(path("/v1/embeddings"))
            .respond_with(move |req: &wiremock::Request| {
                let body: serde_json::Value = serde_json::from_slice(&req.body).unwrap_or_default();
                let input_type = body["input_type"].as_str().unwrap_or("");
                assert_eq!(
                    input_type, "query",
                    "embed_query must send input_type: query"
                );
                ResponseTemplate::new(200).set_body_json(response_value.clone())
            })
            .expect(1)
            .mount(&server)
            .await;

        let embedder = embedder_from(test_config(server.uri()));

        let vector = embedder.embed_query("hello").await.unwrap();
        assert_eq!(vector, expected);
    }

    /// Hits the real API; silently skipped when `VOYAGE_API_KEY` is not set.
    #[cfg(feature = "live-providers")]
    #[tokio::test]
    async fn voyage_live_smoke() {
        if std::env::var("VOYAGE_API_KEY").is_err() {
            return;
        }
        let embedder = VoyageEmbedder::new(EmbedConfig::default()).unwrap();

        assert_eq!(embedder.dimension(), 1536);
        assert_eq!(embedder.model_id(), "voyage-code-2");

        let texts = vec![String::from("hello world"), String::from("goodbye world")];
        let embeddings = embedder.embed(&texts).await.unwrap();

        assert_eq!(embeddings.len(), 2);
        for embedding in &embeddings {
            assert_eq!(embedding.len(), 1536);
            let sum: f32 = embedding.iter().sum();
            assert!(sum != 0.0, "embedding should not be all zeros");
        }
    }
}