1use crate::client::{
2 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3 ProviderClient,
4};
5use crate::embeddings;
6use crate::embeddings::EmbeddingError;
7use crate::http_client::{self, HttpClientExt};
8use crate::rerank;
9use crate::rerank::RerankError;
10use bytes::Bytes;
11use serde::Deserialize;
12use serde_json::json;
13
/// Root URL of the Voyage AI REST API (v1); exposed to the client machinery as
/// [`VoyageBuilder`]'s `BASE_URL`.
const VOYAGEAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
18
/// Zero-sized provider marker for Voyage AI.
///
/// It carries no state; it exists so the [`Provider`] and [`Capabilities`]
/// traits can be implemented for the Voyage AI client.
#[derive(Debug, Default, Clone, Copy)]
pub struct VoyageExt;
21
/// Zero-sized builder marker whose [`ProviderBuilder`] implementation yields a
/// [`VoyageExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct VoyageBuilder;
24
/// API-key type used by Voyage AI clients: a plain alias for [`BearerAuth`].
type VoyageApiKey = BearerAuth;
26
/// Registers [`VoyageBuilder`] as the builder type of the Voyage AI provider.
impl Provider for VoyageExt {
    type Builder = VoyageBuilder;

    // Intentionally empty: no verification path is configured for this provider.
    const VERIFY_PATH: &'static str = "";
}
33
/// Capability table for Voyage AI: only embeddings and reranking are marked
/// [`Capable`]; every other capability is declared as [`Nothing`].
impl<H> Capabilities<H> for VoyageExt {
    type Completion = Nothing;
    // Embeddings are served by this module's `EmbeddingModel`.
    type Embeddings = Capable<EmbeddingModel<H>>;
    // Reranking is served by this module's `RerankModel`.
    type Rerank = Capable<RerankModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    // Only present when the crate is built with the `image` feature.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    // Only present when the crate is built with the `audio` feature.
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
46
// Empty impl: `VoyageExt` overrides nothing and relies on whatever `DebugExt`
// provides by default.
impl DebugExt for VoyageExt {}
48
/// Builder wiring for Voyage AI clients.
impl ProviderBuilder for VoyageBuilder {
    /// The extension is always [`VoyageExt`], whatever HTTP client type `H` is used.
    type Extension<H>
        = VoyageExt
    where
        H: HttpClientExt;
    /// Authentication type of the client (an alias for [`BearerAuth`]).
    type ApiKey = VoyageApiKey;

    /// Default base URL for requests issued by the client.
    const BASE_URL: &'static str = VOYAGEAI_API_BASE_URL;

    /// Produces the provider extension for a client under construction.
    ///
    /// The builder argument is not inspected, and this implementation always
    /// returns `Ok(VoyageExt)`.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(VoyageExt)
    }
}
67
/// Voyage AI client; the HTTP client type parameter defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<VoyageExt, H>;
/// Builder for [`Client`]; the HTTP client type parameter defaults to
/// `crate::markers::Missing`.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<VoyageBuilder, VoyageApiKey, H>;
71
72impl ProviderClient for Client {
73 type Input = String;
74 type Error = crate::client::ProviderClientError;
75
76 fn from_env() -> Result<Self, Self::Error> {
78 let api_key = crate::client::required_env_var("VOYAGE_API_KEY")?;
79 Self::new(&api_key).map_err(Into::into)
80 }
81
82 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
83 Self::new(&input).map_err(Into::into)
84 }
85}
86
87impl<T> EmbeddingModel<T> {
88 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
89 Self {
90 client,
91 model: model.into(),
92 ndims,
93 }
94 }
95
96 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
97 Self {
98 client,
99 model: model.into(),
100 ndims,
101 }
102 }
103}
104
/// `voyage-3-large` embedding model identifier.
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model identifier.
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model identifier.
///
/// Earlier revisions spelled this `voyage.3-5.lite`, which is not a model name
/// the Voyage AI API recognises.
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model identifier.
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model identifier.
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model identifier.
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model identifier.
pub const VOYAGE_CODE_2: &str = "voyage-code-2";

/// Looks up the embedding dimensionality used for a known model identifier.
///
/// Returns `Some(1536)` for `voyage-code-2`, `Some(1024)` for the other models
/// listed in this module, and `None` for any identifier that is not in the table.
pub fn model_dimensions_from_identifier(model_identifier: &str) -> Option<usize> {
    match model_identifier {
        VOYAGE_CODE_2 => Some(1536),
        VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
        | VOYAGE_LAW_2 => Some(1024),
        // Legacy misspelling of `voyage-3.5-lite`; still accepted here so callers
        // that hard-coded the old string keep getting a dimension.
        "voyage.3-5.lite" => Some(1024),
        _ => None,
    }
}
132
/// Successful response body of the `/embeddings` endpoint, as deserialized by
/// this module.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// Value of the response's `object` field (not interpreted by this module).
    pub object: String,
    /// The returned embeddings, one entry per embedded input.
    pub data: Vec<EmbeddingData>,
    /// Value of the response's `model` field.
    pub model: String,
    /// Token usage reported for the request.
    pub usage: Usage,
}
140
/// Token accounting section of an embedding response.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Value of the response's `total_tokens` field.
    pub total_tokens: usize,
}
145
/// Error payload: a JSON object carrying a `message` string.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Error text supplied by the provider.
    pub(crate) message: String,
}
150
151impl From<ApiErrorResponse> for EmbeddingError {
152 fn from(err: ApiErrorResponse) -> Self {
153 EmbeddingError::ProviderError(err.message)
154 }
155}
156
/// A response body that is either a successful payload `T` or an
/// [`ApiErrorResponse`].
///
/// The enum is `untagged`, so serde tries the variants in declaration order:
/// `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error object.
    Err(ApiErrorResponse),
}
163
164impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
165 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
166 match value {
167 ApiResponse::Ok(response) => Ok(response),
168 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
169 }
170 }
171}
172
/// One element of [`EmbeddingResponse::data`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Value of the element's `object` field (not interpreted by this module).
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f64>,
    /// Value of the element's `index` field. This module does not consult it:
    /// embeddings are paired with the input documents by position.
    pub index: usize,
}
179
/// Handle for requesting embeddings from one Voyage AI model.
#[derive(Clone)]
pub struct EmbeddingModel<T> {
    /// Client used to build and send requests.
    client: Client<T>,
    /// Model identifier sent in the request's `model` field.
    pub model: String,
    /// Dimensionality reported by `ndims()`.
    ndims: usize,
}
186
187impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
188where
189 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
190{
191 const MAX_DOCUMENTS: usize = 1024;
192
193 type Client = Client<T>;
194
195 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
196 let model = model.into();
197 let dims = dims
198 .or(model_dimensions_from_identifier(&model))
199 .unwrap_or_default();
200
201 Self::new(client.clone(), model, dims)
202 }
203
204 fn ndims(&self) -> usize {
205 self.ndims
206 }
207
208 async fn embed_texts(
209 &self,
210 documents: impl IntoIterator<Item = String>,
211 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
212 let documents: Vec<String> = documents.into_iter().collect();
213 let response = self.embed_texts_with_usage(documents).await?;
214 Ok(response.embeddings)
215 }
216
217 async fn embed_texts_with_usage(
218 &self,
219 documents: impl IntoIterator<Item = String>,
220 ) -> Result<embeddings::EmbeddingResponse, EmbeddingError> {
221 let documents: Vec<String> = documents.into_iter().collect();
222 let request = json!({
223 "model": self.model,
224 "input": documents,
225 });
226
227 let body = serde_json::to_vec(&request)?;
228
229 let req = self
230 .client
231 .post("/embeddings")?
232 .body(body)
233 .map_err(|x| EmbeddingError::HttpError(x.into()))?;
234
235 let response = self.client.send::<_, Bytes>(req).await?;
236 let status = response.status();
237 let response_body = response.into_body().into_future().await?.to_vec();
238
239 if status.is_success() {
240 match serde_json::from_slice::<ApiResponse<EmbeddingResponse>>(&response_body)? {
241 ApiResponse::Ok(response) => {
242 tracing::info!(target: "rig",
243 "VoyageAI embedding token usage: {}",
244 response.usage.total_tokens
245 );
246
247 if response.data.len() != documents.len() {
248 return Err(EmbeddingError::ResponseError(
249 "Response data length does not match input length".into(),
250 ));
251 }
252
253 let usage = crate::completion::Usage {
254 input_tokens: response.usage.total_tokens as u64,
255 output_tokens: 0,
256 total_tokens: response.usage.total_tokens as u64,
257 cached_input_tokens: 0,
258 cache_creation_input_tokens: 0,
259 tool_use_prompt_tokens: 0,
260 reasoning_tokens: 0,
261 };
262
263 let embeddings = response
264 .data
265 .into_iter()
266 .zip(documents.into_iter())
267 .map(|(embedding, document)| embeddings::Embedding {
268 document,
269 vec: embedding.embedding,
270 })
271 .collect();
272
273 Ok(embeddings::EmbeddingResponse { embeddings, usage })
274 }
275 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
276 }
277 } else {
278 Err(EmbeddingError::ProviderError(
279 String::from_utf8_lossy(&response_body).to_string(),
280 ))
281 }
282 }
283}
284
/// `rerank-2.5` reranker model identifier.
pub const RERANK_2_5: &str = "rerank-2.5";
/// `rerank-2.5-lite` reranker model identifier.
pub const RERANK_2_5_LITE: &str = "rerank-2.5-lite";
/// `rerank-2` reranker model identifier.
pub const RERANK_2: &str = "rerank-2";
/// `rerank-2-lite` reranker model identifier.
pub const RERANK_2_LITE: &str = "rerank-2-lite";
/// `rerank-1` reranker model identifier.
pub const RERANK_1: &str = "rerank-1";
/// `rerank-lite-1` reranker model identifier.
pub const RERANK_LITE_1: &str = "rerank-lite-1";
301
/// Successful response body of the `/rerank` endpoint, as deserialized by this
/// module.
#[derive(Debug, Deserialize)]
pub struct RerankApiResponse {
    /// The reranking results.
    pub data: Vec<RerankApiData>,
    /// Value of the response's `model` field.
    pub model: String,
    /// Token usage reported for the request.
    pub usage: RerankApiUsage,
}
308
/// Token accounting section of a rerank response.
#[derive(Debug, Deserialize)]
pub struct RerankApiUsage {
    /// Value of the response's `total_tokens` field.
    pub total_tokens: usize,
}
313
/// One element of [`RerankApiResponse::data`].
#[derive(Debug, Deserialize)]
pub struct RerankApiData {
    /// Value of the element's `index` field; presumably the position of the
    /// document in the request's `documents` list (confirm against the API docs).
    pub index: usize,
    /// Value of the element's `relevance_score` field.
    pub relevance_score: f64,
    /// Document text, when the response includes it; `None` when the field is
    /// absent from the JSON.
    #[serde(default)]
    pub document: Option<String>,
}
321
322impl From<ApiErrorResponse> for RerankError {
323 fn from(err: ApiErrorResponse) -> Self {
324 RerankError::ProviderError(err.message)
325 }
326}
327
/// Handle for reranking documents with one Voyage AI reranker model.
#[derive(Clone)]
pub struct RerankModel<T = reqwest::Client> {
    /// Client used to build and send requests.
    client: Client<T>,
    /// Model identifier sent in the request's `model` field.
    pub model: String,
    /// When set, sent as the request's `top_k` field; omitted otherwise.
    pub top_k: Option<usize>,
    /// Always sent as the request's `return_documents` field.
    pub return_documents: bool,
    /// When set, sent as the request's `truncation` field; omitted otherwise.
    pub truncation: Option<bool>,
}
336
337impl<T> RerankModel<T> {
338 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
339 Self {
340 client,
341 model: model.into(),
342 top_k: None,
343 return_documents: false,
344 truncation: None,
345 }
346 }
347
348 pub fn top_k(mut self, top_k: usize) -> Self {
349 self.top_k = Some(top_k);
350 self
351 }
352
353 pub fn return_documents(mut self, return_documents: bool) -> Self {
354 self.return_documents = return_documents;
355 self
356 }
357
358 pub fn truncation(mut self, truncation: bool) -> Self {
359 self.truncation = Some(truncation);
360 self
361 }
362}
363
364impl<T> rerank::RerankModel for RerankModel<T>
365where
366 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
367{
368 const MAX_DOCUMENTS: usize = 1000;
369
370 type Client = Client<T>;
371
372 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
373 Self::new(client.clone(), model)
374 }
375
376 async fn rerank(
377 &self,
378 query: &str,
379 documents: Vec<String>,
380 ) -> Result<rerank::RerankResponse, RerankError> {
381 let mut body = json!({
382 "query": query,
383 "documents": documents,
384 "model": self.model,
385 });
386
387 let body_obj = body.as_object_mut().ok_or_else(|| {
388 RerankError::ResponseError("rerank request body must be a JSON object".into())
389 })?;
390
391 if let Some(top_k) = self.top_k {
392 body_obj.insert("top_k".to_owned(), json!(top_k));
393 }
394
395 body_obj.insert("return_documents".to_owned(), json!(self.return_documents));
396
397 if let Some(truncation) = self.truncation {
398 body_obj.insert("truncation".to_owned(), json!(truncation));
399 }
400
401 let body = serde_json::to_vec(&body)?;
402
403 let req = self
404 .client
405 .post("/rerank")?
406 .body(body)
407 .map_err(|x| RerankError::HttpError(x.into()))?;
408
409 let response = self.client.send::<_, Bytes>(req).await?;
410 let status = response.status();
411 let response_body = response.into_body().into_future().await?.to_vec();
412
413 if status.is_success() {
414 match serde_json::from_slice::<ApiResponse<RerankApiResponse>>(&response_body)? {
415 ApiResponse::Ok(response) => {
416 tracing::info!(target: "rig",
417 "VoyageAI rerank token usage: {}",
418 response.usage.total_tokens
419 );
420
421 let usage = crate::completion::Usage {
422 input_tokens: response.usage.total_tokens as u64,
423 output_tokens: 0,
424 total_tokens: response.usage.total_tokens as u64,
425 cached_input_tokens: 0,
426 cache_creation_input_tokens: 0,
427 reasoning_tokens: 0,
428 tool_use_prompt_tokens: 0,
429 };
430
431 let results = response
432 .data
433 .into_iter()
434 .map(|d| rerank::RerankResult {
435 index: d.index,
436 document: d.document,
437 relevance_score: d.relevance_score,
438 })
439 .collect();
440
441 Ok(rerank::RerankResponse {
442 results,
443 model: response.model,
444 usage,
445 })
446 }
447 ApiResponse::Err(err) => Err(RerankError::ProviderError(err.message)),
448 }
449 } else {
450 Err(RerankError::ProviderError(
451 String::from_utf8_lossy(&response_body).to_string(),
452 ))
453 }
454 }
455}
456
#[cfg(test)]
mod tests {
    use crate::providers::voyageai::Client;

    /// Both construction paths must accept a placeholder key without error.
    #[test]
    fn test_client_initialization() {
        let _client = Client::new("dummy-key").expect("Client::new() failed");

        let builder = Client::builder().api_key("dummy-key");
        let _client_from_builder = builder.build().expect("Client::builder() failed");
    }
}