agcodex_core/embeddings/providers/
voyage.rs

1//! Voyage AI embeddings provider - completely separate from chat models
2//!
3//! Supports:
4//! - voyage-3.5 (1024 dimensions by default)
5//! - Batch processing (up to 128 inputs)
6//! - Document vs Query input types for optimized embeddings
7//! - Uses VOYAGE_API_KEY environment variable
8
9use super::super::EmbeddingError;
10use super::super::EmbeddingProvider;
11use super::super::EmbeddingVector;
12use reqwest::Client;
13use serde::Deserialize;
14use serde::Serialize;
15
/// Input type for Voyage AI embeddings.
///
/// The textual form (see [`VoyageInputType::as_str`]) is what this provider
/// places in the request's `input_type` field and in its model identifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VoyageInputType {
    /// For documents being indexed/stored
    Document,
    /// For search queries
    Query,
}

impl VoyageInputType {
    /// Returns the lowercase name of the variant without allocating.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Document => "document",
            Self::Query => "query",
        }
    }
}

// `Display` is implemented instead of `ToString` directly: the standard
// library's blanket `impl<T: Display> ToString for T` keeps `.to_string()`
// working for existing callers while also enabling `{}` formatting.
impl std::fmt::Display for VoyageInputType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (rather than `write_str`) honours width/alignment flags.
        f.pad(self.as_str())
    }
}
33
34/// Voyage AI embedding provider
35pub struct VoyageProvider {
36    client: Client,
37    api_key: String,
38    model: String,
39    input_type: VoyageInputType,
40    api_endpoint: Option<String>,
41}
42
43impl VoyageProvider {
44    /// Create a new Voyage provider
45    pub fn new(
46        api_key: String,
47        model: String,
48        input_type: VoyageInputType,
49        api_endpoint: Option<String>,
50    ) -> Self {
51        Self {
52            client: Client::new(),
53            api_key,
54            model,
55            input_type,
56            api_endpoint,
57        }
58    }
59
60    /// Create a new Voyage provider for documents
61    pub fn new_for_documents(api_key: String, model: String) -> Self {
62        Self::new(api_key, model, VoyageInputType::Document, None)
63    }
64
65    /// Create a new Voyage provider for queries
66    pub fn new_for_queries(api_key: String, model: String) -> Self {
67        Self::new(api_key, model, VoyageInputType::Query, None)
68    }
69
70    /// Get the current input type
71    pub const fn input_type(&self) -> &VoyageInputType {
72        &self.input_type
73    }
74
75    /// Set the input type
76    pub const fn set_input_type(&mut self, input_type: VoyageInputType) {
77        self.input_type = input_type;
78    }
79}
80
81#[derive(Debug, Serialize)]
82struct VoyageRequest {
83    model: String,
84    input: Vec<String>,
85    input_type: String,
86}
87
88#[derive(Debug, Deserialize)]
89struct VoyageResponse {
90    data: Vec<VoyageEmbedding>,
91    _usage: VoyageUsage,
92}
93
94#[derive(Debug, Deserialize)]
95struct VoyageEmbedding {
96    embedding: Vec<f32>,
97    index: usize,
98}
99
100#[derive(Debug, Deserialize)]
101struct VoyageUsage {
102    _total_tokens: usize,
103}
104
105#[derive(Debug, Deserialize)]
106struct VoyageError {
107    error: VoyageErrorDetail,
108}
109
110#[derive(Debug, Deserialize)]
111struct VoyageErrorDetail {
112    message: String,
113    #[serde(rename = "type")]
114    error_type: String,
115    _code: Option<String>,
116}
117
118#[async_trait::async_trait]
119impl EmbeddingProvider for VoyageProvider {
120    fn model_id(&self) -> String {
121        format!("voyage:{}:{}", self.model, self.input_type.to_string())
122    }
123
124    fn dimensions(&self) -> usize {
125        // Return model-specific dimensions
126        match self.model.as_str() {
127            "voyage-3.5" => 1024,
128            "voyage-3.5-lite" => 512,
129            "voyage-3-large" => 1536,
130            "voyage-3" => 1024,
131            "voyage-2" => 1024,
132            "voyage-large-2" => 1536,
133            "voyage-code-2" => 1536,
134            "voyage-multilingual-2" => 1024,
135            _ => 1024, // Default fallback
136        }
137    }
138
139    async fn embed(&self, text: &str) -> Result<EmbeddingVector, EmbeddingError> {
140        self.embed_batch(&[text.to_string()])
141            .await
142            .map(|mut vecs| vecs.pop().unwrap_or_default())
143    }
144
145    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<EmbeddingVector>, EmbeddingError> {
146        if texts.is_empty() {
147            return Ok(vec![]);
148        }
149
150        // Voyage AI has a limit of 128 inputs per batch
151        const MAX_BATCH_SIZE: usize = 128;
152        if texts.len() > MAX_BATCH_SIZE {
153            // Process in chunks
154            let mut all_embeddings = Vec::with_capacity(texts.len());
155            for chunk in texts.chunks(MAX_BATCH_SIZE) {
156                let chunk_embeddings = self.embed_batch_internal(chunk).await?;
157                all_embeddings.extend(chunk_embeddings);
158            }
159            return Ok(all_embeddings);
160        }
161
162        self.embed_batch_internal(texts).await
163    }
164
165    fn is_available(&self) -> bool {
166        !self.api_key.is_empty()
167    }
168}
169
170impl VoyageProvider {
171    async fn embed_batch_internal(
172        &self,
173        texts: &[String],
174    ) -> Result<Vec<EmbeddingVector>, EmbeddingError> {
175        let endpoint = self
176            .api_endpoint
177            .as_deref()
178            .unwrap_or("https://api.voyageai.com/v1/embeddings");
179
180        let request = VoyageRequest {
181            model: self.model.clone(),
182            input: texts.to_vec(),
183            input_type: self.input_type.to_string(),
184        };
185
186        let response = self
187            .client
188            .post(endpoint)
189            .header("Authorization", format!("Bearer {}", self.api_key))
190            .header("Content-Type", "application/json")
191            .json(&request)
192            .send()
193            .await
194            .map_err(|e| EmbeddingError::ApiError(format!("Request failed: {}", e)))?;
195
196        let status = response.status();
197        if !status.is_success() {
198            let error_text = response
199                .text()
200                .await
201                .unwrap_or_else(|_| "Unknown error".to_string());
202
203            // Try to parse Voyage error format
204            if let Ok(error) = serde_json::from_str::<VoyageError>(&error_text) {
205                return Err(EmbeddingError::ApiError(format!(
206                    "Voyage API error ({}): {} - {}",
207                    status, error.error.error_type, error.error.message
208                )));
209            }
210
211            return Err(EmbeddingError::ApiError(format!(
212                "Voyage API error ({}): {}",
213                status, error_text
214            )));
215        }
216
217        let voyage_response: VoyageResponse = response
218            .json()
219            .await
220            .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {}", e)))?;
221
222        // Sort by index to ensure correct order
223        let mut embeddings = voyage_response.data;
224        embeddings.sort_by_key(|e| e.index);
225
226        // Validate dimensions
227        let expected_dims = self.dimensions();
228        for embedding in &embeddings {
229            if embedding.embedding.len() != expected_dims {
230                return Err(EmbeddingError::DimensionMismatch {
231                    expected: expected_dims,
232                    actual: embedding.embedding.len(),
233                });
234            }
235        }
236
237        Ok(embeddings.into_iter().map(|e| e.embedding).collect())
238    }
239}
240
241#[cfg(test)]
242mod tests {
243    use super::*;
244
245    #[test]
246    fn test_model_id() {
247        let provider = VoyageProvider::new(
248            "test-key".to_string(),
249            "voyage-3.5".to_string(),
250            VoyageInputType::Document,
251            None,
252        );
253        assert_eq!(provider.model_id(), "voyage:voyage-3.5:document");
254
255        let provider_query = VoyageProvider::new(
256            "test-key".to_string(),
257            "voyage-3.5".to_string(),
258            VoyageInputType::Query,
259            None,
260        );
261        assert_eq!(provider_query.model_id(), "voyage:voyage-3.5:query");
262    }
263
264    #[test]
265    fn test_dimensions() {
266        let provider = VoyageProvider::new(
267            "test-key".to_string(),
268            "voyage-3.5".to_string(),
269            VoyageInputType::Document,
270            None,
271        );
272        assert_eq!(provider.dimensions(), 1024);
273
274        let provider_lite = VoyageProvider::new(
275            "test-key".to_string(),
276            "voyage-3.5-lite".to_string(),
277            VoyageInputType::Document,
278            None,
279        );
280        assert_eq!(provider_lite.dimensions(), 512);
281
282        let provider_3_large = VoyageProvider::new(
283            "test-key".to_string(),
284            "voyage-3-large".to_string(),
285            VoyageInputType::Document,
286            None,
287        );
288        assert_eq!(provider_3_large.dimensions(), 1536);
289
290        let provider_large = VoyageProvider::new(
291            "test-key".to_string(),
292            "voyage-large-2".to_string(),
293            VoyageInputType::Document,
294            None,
295        );
296        assert_eq!(provider_large.dimensions(), 1536);
297    }
298
299    #[test]
300    fn test_input_type() {
301        assert_eq!(VoyageInputType::Document.to_string(), "document");
302        assert_eq!(VoyageInputType::Query.to_string(), "query");
303    }
304
305    #[test]
306    fn test_convenience_constructors() {
307        let provider_doc =
308            VoyageProvider::new_for_documents("test-key".to_string(), "voyage-3.5".to_string());
309        assert_eq!(provider_doc.input_type(), &VoyageInputType::Document);
310
311        let provider_query =
312            VoyageProvider::new_for_queries("test-key".to_string(), "voyage-3.5".to_string());
313        assert_eq!(provider_query.input_type(), &VoyageInputType::Query);
314    }
315
316    #[test]
317    fn test_is_available() {
318        let provider = VoyageProvider::new(
319            "test-key".to_string(),
320            "voyage-3.5".to_string(),
321            VoyageInputType::Document,
322            None,
323        );
324        assert!(provider.is_available());
325
326        let provider_empty = VoyageProvider::new(
327            String::new(),
328            "voyage-3.5".to_string(),
329            VoyageInputType::Document,
330            None,
331        );
332        assert!(!provider_empty.is_available());
333    }
334
335    #[test]
336    fn test_set_input_type() {
337        let mut provider =
338            VoyageProvider::new_for_documents("test-key".to_string(), "voyage-3.5".to_string());
339        assert_eq!(provider.input_type(), &VoyageInputType::Document);
340
341        provider.set_input_type(VoyageInputType::Query);
342        assert_eq!(provider.input_type(), &VoyageInputType::Query);
343    }
344}