// agcodex_core/embeddings/providers/voyage.rs
use super::super::EmbeddingError;
use super::super::EmbeddingProvider;
use super::super::EmbeddingVector;
use reqwest::Client;
use serde::Deserialize;
use serde::Serialize;
use std::fmt;
15
/// Kind of text being embedded, sent to the API as the request's `input_type`.
///
/// The lowercase name of the variant is also embedded in the provider's
/// model identifier.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VoyageInputType {
    /// Rendered as `"document"`.
    Document,
    /// Rendered as `"query"`.
    Query,
}

impl VoyageInputType {
    /// Returns the lowercase wire name of this variant without allocating.
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Document => "document",
            Self::Query => "query",
        }
    }
}

// `Display` is implemented rather than `ToString`: the standard library's
// blanket `impl<T: Display> ToString for T` keeps `.to_string()` working for
// existing callers while also allowing `{}` in format strings.
impl fmt::Display for VoyageInputType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width/alignment flags the same way `str` does.
        f.pad(self.as_str())
    }
}
33
34pub struct VoyageProvider {
36 client: Client,
37 api_key: String,
38 model: String,
39 input_type: VoyageInputType,
40 api_endpoint: Option<String>,
41}
42
43impl VoyageProvider {
44 pub fn new(
46 api_key: String,
47 model: String,
48 input_type: VoyageInputType,
49 api_endpoint: Option<String>,
50 ) -> Self {
51 Self {
52 client: Client::new(),
53 api_key,
54 model,
55 input_type,
56 api_endpoint,
57 }
58 }
59
60 pub fn new_for_documents(api_key: String, model: String) -> Self {
62 Self::new(api_key, model, VoyageInputType::Document, None)
63 }
64
65 pub fn new_for_queries(api_key: String, model: String) -> Self {
67 Self::new(api_key, model, VoyageInputType::Query, None)
68 }
69
70 pub const fn input_type(&self) -> &VoyageInputType {
72 &self.input_type
73 }
74
75 pub const fn set_input_type(&mut self, input_type: VoyageInputType) {
77 self.input_type = input_type;
78 }
79}
80
81#[derive(Debug, Serialize)]
82struct VoyageRequest {
83 model: String,
84 input: Vec<String>,
85 input_type: String,
86}
87
88#[derive(Debug, Deserialize)]
89struct VoyageResponse {
90 data: Vec<VoyageEmbedding>,
91 _usage: VoyageUsage,
92}
93
94#[derive(Debug, Deserialize)]
95struct VoyageEmbedding {
96 embedding: Vec<f32>,
97 index: usize,
98}
99
100#[derive(Debug, Deserialize)]
101struct VoyageUsage {
102 _total_tokens: usize,
103}
104
105#[derive(Debug, Deserialize)]
106struct VoyageError {
107 error: VoyageErrorDetail,
108}
109
110#[derive(Debug, Deserialize)]
111struct VoyageErrorDetail {
112 message: String,
113 #[serde(rename = "type")]
114 error_type: String,
115 _code: Option<String>,
116}
117
118#[async_trait::async_trait]
119impl EmbeddingProvider for VoyageProvider {
120 fn model_id(&self) -> String {
121 format!("voyage:{}:{}", self.model, self.input_type.to_string())
122 }
123
124 fn dimensions(&self) -> usize {
125 match self.model.as_str() {
127 "voyage-3.5" => 1024,
128 "voyage-3.5-lite" => 512,
129 "voyage-3-large" => 1536,
130 "voyage-3" => 1024,
131 "voyage-2" => 1024,
132 "voyage-large-2" => 1536,
133 "voyage-code-2" => 1536,
134 "voyage-multilingual-2" => 1024,
135 _ => 1024, }
137 }
138
139 async fn embed(&self, text: &str) -> Result<EmbeddingVector, EmbeddingError> {
140 self.embed_batch(&[text.to_string()])
141 .await
142 .map(|mut vecs| vecs.pop().unwrap_or_default())
143 }
144
145 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<EmbeddingVector>, EmbeddingError> {
146 if texts.is_empty() {
147 return Ok(vec![]);
148 }
149
150 const MAX_BATCH_SIZE: usize = 128;
152 if texts.len() > MAX_BATCH_SIZE {
153 let mut all_embeddings = Vec::with_capacity(texts.len());
155 for chunk in texts.chunks(MAX_BATCH_SIZE) {
156 let chunk_embeddings = self.embed_batch_internal(chunk).await?;
157 all_embeddings.extend(chunk_embeddings);
158 }
159 return Ok(all_embeddings);
160 }
161
162 self.embed_batch_internal(texts).await
163 }
164
165 fn is_available(&self) -> bool {
166 !self.api_key.is_empty()
167 }
168}
169
170impl VoyageProvider {
171 async fn embed_batch_internal(
172 &self,
173 texts: &[String],
174 ) -> Result<Vec<EmbeddingVector>, EmbeddingError> {
175 let endpoint = self
176 .api_endpoint
177 .as_deref()
178 .unwrap_or("https://api.voyageai.com/v1/embeddings");
179
180 let request = VoyageRequest {
181 model: self.model.clone(),
182 input: texts.to_vec(),
183 input_type: self.input_type.to_string(),
184 };
185
186 let response = self
187 .client
188 .post(endpoint)
189 .header("Authorization", format!("Bearer {}", self.api_key))
190 .header("Content-Type", "application/json")
191 .json(&request)
192 .send()
193 .await
194 .map_err(|e| EmbeddingError::ApiError(format!("Request failed: {}", e)))?;
195
196 let status = response.status();
197 if !status.is_success() {
198 let error_text = response
199 .text()
200 .await
201 .unwrap_or_else(|_| "Unknown error".to_string());
202
203 if let Ok(error) = serde_json::from_str::<VoyageError>(&error_text) {
205 return Err(EmbeddingError::ApiError(format!(
206 "Voyage API error ({}): {} - {}",
207 status, error.error.error_type, error.error.message
208 )));
209 }
210
211 return Err(EmbeddingError::ApiError(format!(
212 "Voyage API error ({}): {}",
213 status, error_text
214 )));
215 }
216
217 let voyage_response: VoyageResponse = response
218 .json()
219 .await
220 .map_err(|e| EmbeddingError::ApiError(format!("Failed to parse response: {}", e)))?;
221
222 let mut embeddings = voyage_response.data;
224 embeddings.sort_by_key(|e| e.index);
225
226 let expected_dims = self.dimensions();
228 for embedding in &embeddings {
229 if embedding.embedding.len() != expected_dims {
230 return Err(EmbeddingError::DimensionMismatch {
231 expected: expected_dims,
232 actual: embedding.embedding.len(),
233 });
234 }
235 }
236
237 Ok(embeddings.into_iter().map(|e| e.embedding).collect())
238 }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider with a dummy key and the default endpoint.
    fn provider(model: &str, input_type: VoyageInputType) -> VoyageProvider {
        VoyageProvider::new("test-key".to_string(), model.to_string(), input_type, None)
    }

    #[test]
    fn test_model_id() {
        let documents = provider("voyage-3.5", VoyageInputType::Document);
        assert_eq!(documents.model_id(), "voyage:voyage-3.5:document");

        let queries = provider("voyage-3.5", VoyageInputType::Query);
        assert_eq!(queries.model_id(), "voyage:voyage-3.5:query");
    }

    #[test]
    fn test_dimensions() {
        // Same models, same order, same expected sizes as before.
        let expectations = [
            ("voyage-3.5", 1024),
            ("voyage-3.5-lite", 512),
            ("voyage-3-large", 1536),
            ("voyage-large-2", 1536),
        ];
        for (model, dims) in expectations {
            let subject = provider(model, VoyageInputType::Document);
            assert_eq!(subject.dimensions(), dims);
        }
    }

    #[test]
    fn test_input_type() {
        assert_eq!(VoyageInputType::Document.to_string(), "document");
        assert_eq!(VoyageInputType::Query.to_string(), "query");
    }

    #[test]
    fn test_convenience_constructors() {
        let for_documents =
            VoyageProvider::new_for_documents("test-key".to_string(), "voyage-3.5".to_string());
        assert_eq!(for_documents.input_type(), &VoyageInputType::Document);

        let for_queries =
            VoyageProvider::new_for_queries("test-key".to_string(), "voyage-3.5".to_string());
        assert_eq!(for_queries.input_type(), &VoyageInputType::Query);
    }

    #[test]
    fn test_is_available() {
        let with_key = provider("voyage-3.5", VoyageInputType::Document);
        assert!(with_key.is_available());

        let without_key = VoyageProvider::new(
            String::new(),
            "voyage-3.5".to_string(),
            VoyageInputType::Document,
            None,
        );
        assert!(!without_key.is_available());
    }

    #[test]
    fn test_set_input_type() {
        let mut subject =
            VoyageProvider::new_for_documents("test-key".to_string(), "voyage-3.5".to_string());
        assert_eq!(subject.input_type(), &VoyageInputType::Document);

        subject.set_input_type(VoyageInputType::Query);
        assert_eq!(subject.input_type(), &VoyageInputType::Query);
    }
}