Skip to main content

rig_core/providers/
voyageai.rs

1use crate::client::{
2    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3    ProviderClient,
4};
5use crate::embeddings;
6use crate::embeddings::EmbeddingError;
7use crate::http_client::{self, HttpClientExt};
8use crate::rerank;
9use crate::rerank::RerankError;
10use bytes::Bytes;
11use serde::Deserialize;
12use serde_json::json;
13
14// ================================================================
15// Main Voyage AI Client
16// ================================================================
/// Base URL of the Voyage AI REST API; exposed to the client builder as
/// `ProviderBuilder::BASE_URL`.
const VOYAGEAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
18
/// Stateless marker type that carries the Voyage AI implementations of
/// [`Provider`], [`Capabilities`] and [`DebugExt`].
#[derive(Debug, Default, Clone, Copy)]
pub struct VoyageExt;
21
/// Stateless marker type that carries the Voyage AI implementation of
/// [`ProviderBuilder`].
#[derive(Debug, Default, Clone, Copy)]
pub struct VoyageBuilder;
24
/// API-key type used by Voyage AI clients; an alias for [`BearerAuth`].
type VoyageApiKey = BearerAuth;
26
/// Registers Voyage AI as a provider whose clients are built through
/// [`VoyageBuilder`].
impl Provider for VoyageExt {
    type Builder = VoyageBuilder;

    /// Left empty: there is currently no way to verify a Voyage API key without
    /// consuming tokens.
    const VERIFY_PATH: &'static str = "";
}
33
/// Declares which capabilities a Voyage AI client exposes: embeddings and
/// reranking are [`Capable`]; every other capability is [`Nothing`].
impl<H> Capabilities<H> for VoyageExt {
    type Completion = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Rerank = Capable<RerankModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    // The two generation capabilities only exist when the matching crate
    // feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
46
// Empty impl: nothing is overridden, so any behavior comes from `DebugExt` itself.
impl DebugExt for VoyageExt {}
48
/// Builder hooks for Voyage AI clients: fixes the extension type, the API-key
/// type and the default base URL.
impl ProviderBuilder for VoyageBuilder {
    // The extension is `VoyageExt` regardless of the HTTP client type `H`.
    type Extension<H>
        = VoyageExt
    where
        H: HttpClientExt;
    type ApiKey = VoyageApiKey;

    const BASE_URL: &'static str = VOYAGEAI_API_BASE_URL;

    /// Produces the provider extension for a client that is being built.
    ///
    /// The builder argument is ignored and the call always returns `Ok`,
    /// because [`VoyageExt`] is a unit struct with no state to configure.
    fn build<H>(
        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
    ) -> http_client::Result<Self::Extension<H>>
    where
        H: HttpClientExt,
    {
        Ok(VoyageExt)
    }
}
67
/// Voyage AI client; the HTTP backend `H` defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<VoyageExt, H>;
/// Builder for a Voyage AI [`Client`]; `H` defaults to the `Missing` marker type.
pub type ClientBuilder<H = crate::markers::Missing> =
    client::ClientBuilder<VoyageBuilder, VoyageApiKey, H>;
71
72impl ProviderClient for Client {
73    type Input = String;
74    type Error = crate::client::ProviderClientError;
75
76    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
77    fn from_env() -> Result<Self, Self::Error> {
78        let api_key = crate::client::required_env_var("VOYAGE_API_KEY")?;
79        Self::new(&api_key).map_err(Into::into)
80    }
81
82    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
83        Self::new(&input).map_err(Into::into)
84    }
85}
86
87impl<T> EmbeddingModel<T> {
88    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
89        Self {
90            client,
91            model: model.into(),
92            ndims,
93        }
94    }
95
96    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
97        Self {
98            client,
99            model: model.into(),
100            ndims,
101        }
102    }
103}
104
105// ================================================================
106// Voyage AI Embedding API
107// ================================================================
108
/// `voyage-3-large` embedding model (Voyage AI)
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model (Voyage AI)
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model (Voyage AI)
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model (Voyage AI)
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model (Voyage AI)
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model (Voyage AI)
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model (Voyage AI)
pub const VOYAGE_CODE_2: &str = "voyage-code-2";

/// Returns the embedding dimension used for a known Voyage AI model identifier.
///
/// Yields `Some(1536)` for `voyage-code-2`, `Some(1024)` for the other model
/// constants declared above, and `None` for any identifier not in the table.
pub fn model_dimensions_from_identifier(model_identifier: &str) -> Option<usize> {
    // Matching on the constants keeps this table in sync with the exported names.
    match model_identifier {
        VOYAGE_CODE_2 => Some(1536),
        VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
        | VOYAGE_LAW_2 => Some(1024),
        // Misspelled identifier that `VOYAGE_3_5_LITE` previously held; still
        // recognised so callers that hard-coded it keep getting a dimension.
        "voyage.3-5.lite" => Some(1024),
        _ => None,
    }
}
132
/// Successful response body of the `/embeddings` endpoint, as deserialized by
/// `embed_texts_with_usage`.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// `object` field of the response body.
    pub object: String,
    /// One entry per embedded input.
    pub data: Vec<EmbeddingData>,
    /// `model` field of the response body.
    pub model: String,
    /// Token accounting reported for the request.
    pub usage: Usage,
}
140
/// Token usage section of an embeddings response.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Total token count reported by the API; the only usage figure in this struct.
    pub total_tokens: usize,
}
145
/// Error payload shape: a JSON object carrying a `message` field.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Error text; surfaced to callers inside a `ProviderError`.
    pub(crate) message: String,
}
150
151impl From<ApiErrorResponse> for EmbeddingError {
152    fn from(err: ApiErrorResponse) -> Self {
153        EmbeddingError::ProviderError(err.message)
154    }
155}
156
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// The enum is `#[serde(untagged)]`, so serde tries the variants in
/// declaration order: `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    /// The body deserialized as the expected payload.
    Ok(T),
    /// The body deserialized as an error object instead.
    Err(ApiErrorResponse),
}
163
164impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
165    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
166        match value {
167            ApiResponse::Ok(response) => Ok(response),
168            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
169        }
170    }
171}
172
/// A single entry of [`EmbeddingResponse::data`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// `object` field of the entry.
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f64>,
    /// `index` field of the entry; `embed_texts_with_usage` does not read it.
    pub index: usize,
}
179
/// Handle for one Voyage AI embedding model, bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel<T> {
    /// Client used to send embedding requests.
    client: Client<T>,
    /// Model identifier sent as the `model` field of each request.
    pub model: String,
    /// Dimension reported by `ndims()`.
    ndims: usize,
}
186
/// Implements the crate's embedding-model trait by POSTing to `/embeddings`.
impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
{
    // NOTE(review): presumably mirrors the provider's per-request batch limit —
    // confirm against the Voyage AI embeddings API documentation.
    const MAX_DOCUMENTS: usize = 1024;

    type Client = Client<T>;

    /// Builds a model handle from a client and a model identifier.
    ///
    /// An explicit `dims` takes precedence; otherwise the dimension comes from
    /// `model_dimensions_from_identifier`, falling back to `0` when the model
    /// is not in that table.
    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
        let model = model.into();
        let dims = dims
            .or(model_dimensions_from_identifier(&model))
            .unwrap_or_default();

        Self::new(client.clone(), model, dims)
    }

    /// Returns the dimension this handle was created with.
    fn ndims(&self) -> usize {
        self.ndims
    }

    /// Embeds `documents` and returns only the embeddings, discarding the
    /// usage information produced by `embed_texts_with_usage`.
    async fn embed_texts(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
        let documents: Vec<String> = documents.into_iter().collect();
        let response = self.embed_texts_with_usage(documents).await?;
        Ok(response.embeddings)
    }

    /// Embeds `documents` in a single request and returns the embeddings
    /// together with token usage.
    ///
    /// # Errors
    /// - `ProviderError` when the HTTP status is not a success (carrying the
    ///   raw response body) or when the body parses as an error payload.
    /// - `ResponseError` when the number of returned embeddings differs from
    ///   the number of input documents.
    /// - Serialization, deserialization and transport failures are propagated
    ///   with `?`.
    async fn embed_texts_with_usage(
        &self,
        documents: impl IntoIterator<Item = String>,
    ) -> Result<embeddings::EmbeddingResponse, EmbeddingError> {
        let documents: Vec<String> = documents.into_iter().collect();
        // Request body: the model identifier plus the list of input texts.
        let request = json!({
            "model": self.model,
            "input": documents,
        });

        let body = serde_json::to_vec(&request)?;

        let req = self
            .client
            .post("/embeddings")?
            .body(body)
            .map_err(|x| EmbeddingError::HttpError(x.into()))?;

        let response = self.client.send::<_, Bytes>(req).await?;
        // Capture the status before the response is consumed to read its body.
        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<EmbeddingResponse>>(&response_body)? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "VoyageAI embedding token usage: {}",
                        response.usage.total_tokens
                    );

                    // Guard the zip below: it would silently truncate to the
                    // shorter side if the counts differed.
                    if response.data.len() != documents.len() {
                        return Err(EmbeddingError::ResponseError(
                            "Response data length does not match input length".into(),
                        ));
                    }

                    // The response only carries `total_tokens`, so that figure
                    // is reported as both input and total; the rest are zero.
                    let usage = crate::completion::Usage {
                        input_tokens: response.usage.total_tokens as u64,
                        output_tokens: 0,
                        total_tokens: response.usage.total_tokens as u64,
                        cached_input_tokens: 0,
                        cache_creation_input_tokens: 0,
                        tool_use_prompt_tokens: 0,
                        reasoning_tokens: 0,
                    };

                    // Embeddings are paired with documents by position in the
                    // response. NOTE(review): the `index` field is not
                    // consulted — confirm the API returns entries in input order.
                    let embeddings = response
                        .data
                        .into_iter()
                        .zip(documents.into_iter())
                        .map(|(embedding, document)| embeddings::Embedding {
                            document,
                            vec: embedding.embedding,
                        })
                        .collect();

                    Ok(embeddings::EmbeddingResponse { embeddings, usage })
                }
                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
            }
        } else {
            // Non-success status: hand back the raw body, lossily decoded as UTF-8.
            Err(EmbeddingError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}
284
285// ================================================================
286// Voyage AI Rerank API
287// ================================================================
288
// Model identifiers accepted as the `model` argument of `RerankModel::new`.

/// `rerank-2.5` reranker model (Voyage AI)
pub const RERANK_2_5: &str = "rerank-2.5";
/// `rerank-2.5-lite` reranker model (Voyage AI)
pub const RERANK_2_5_LITE: &str = "rerank-2.5-lite";
/// `rerank-2` reranker model (Voyage AI)
pub const RERANK_2: &str = "rerank-2";
/// `rerank-2-lite` reranker model (Voyage AI)
pub const RERANK_2_LITE: &str = "rerank-2-lite";
/// `rerank-1` reranker model (Voyage AI)
pub const RERANK_1: &str = "rerank-1";
/// `rerank-lite-1` reranker model (Voyage AI)
pub const RERANK_LITE_1: &str = "rerank-lite-1";
301
/// Successful response body of the `/rerank` endpoint, as deserialized by
/// `rerank`.
#[derive(Debug, Deserialize)]
pub struct RerankApiResponse {
    /// Scored entries returned by the API.
    pub data: Vec<RerankApiData>,
    /// `model` field of the response; copied into the returned `RerankResponse`.
    pub model: String,
    /// Token accounting reported for the request.
    pub usage: RerankApiUsage,
}
308
/// Token usage section of a rerank response.
#[derive(Debug, Deserialize)]
pub struct RerankApiUsage {
    /// Total token count reported by the API; the only usage figure in this struct.
    pub total_tokens: usize,
}
313
/// A single scored entry of [`RerankApiResponse::data`].
#[derive(Debug, Deserialize)]
pub struct RerankApiData {
    /// `index` field of the entry.
    /// NOTE(review): presumably the position of the document in the request's
    /// `documents` list — confirm against the API reference.
    pub index: usize,
    /// Relevance score assigned by the reranker.
    pub relevance_score: f64,
    /// Document text; `None` when the field is absent from the response.
    /// NOTE(review): presumably only present when `return_documents` was
    /// requested — confirm.
    #[serde(default)]
    pub document: Option<String>,
}
321
322impl From<ApiErrorResponse> for RerankError {
323    fn from(err: ApiErrorResponse) -> Self {
324        RerankError::ProviderError(err.message)
325    }
326}
327
/// Handle for one Voyage AI reranker model, bound to a [`Client`], together
/// with the optional request parameters sent on every `rerank` call.
#[derive(Clone)]
pub struct RerankModel<T = reqwest::Client> {
    /// Client used to send rerank requests.
    client: Client<T>,
    /// Model identifier sent as the `model` field of each request.
    pub model: String,
    /// Sent as `top_k` when `Some`; omitted from the request when `None`.
    pub top_k: Option<usize>,
    /// Always sent as `return_documents`.
    pub return_documents: bool,
    /// Sent as `truncation` when `Some`; omitted from the request when `None`.
    pub truncation: Option<bool>,
}
336
337impl<T> RerankModel<T> {
338    pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
339        Self {
340            client,
341            model: model.into(),
342            top_k: None,
343            return_documents: false,
344            truncation: None,
345        }
346    }
347
348    pub fn top_k(mut self, top_k: usize) -> Self {
349        self.top_k = Some(top_k);
350        self
351    }
352
353    pub fn return_documents(mut self, return_documents: bool) -> Self {
354        self.return_documents = return_documents;
355        self
356    }
357
358    pub fn truncation(mut self, truncation: bool) -> Self {
359        self.truncation = Some(truncation);
360        self
361    }
362}
363
/// Implements the crate's rerank-model trait by POSTing to `/rerank`.
impl<T> rerank::RerankModel for RerankModel<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
{
    // NOTE(review): presumably mirrors the provider's per-request document
    // limit — confirm against the Voyage AI reranker API documentation.
    const MAX_DOCUMENTS: usize = 1000;

    type Client = Client<T>;

    /// Builds a reranker handle with default options (see `RerankModel::new`).
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(client.clone(), model)
    }

    /// Scores `documents` against `query` in a single request.
    ///
    /// The request always carries `query`, `documents`, `model` and
    /// `return_documents`; `top_k` and `truncation` are added only when set on
    /// this handle.
    ///
    /// # Errors
    /// - `ProviderError` when the HTTP status is not a success (carrying the
    ///   raw response body) or when the body parses as an error payload.
    /// - Serialization, deserialization and transport failures are propagated
    ///   with `?`.
    async fn rerank(
        &self,
        query: &str,
        documents: Vec<String>,
    ) -> Result<rerank::RerankResponse, RerankError> {
        let mut body = json!({
            "query": query,
            "documents": documents,
            "model": self.model,
        });

        // The `json!` literal above is an object, so this error path is not
        // expected to trigger; it only avoids an `unwrap`.
        let body_obj = body.as_object_mut().ok_or_else(|| {
            RerankError::ResponseError("rerank request body must be a JSON object".into())
        })?;

        if let Some(top_k) = self.top_k {
            body_obj.insert("top_k".to_owned(), json!(top_k));
        }

        body_obj.insert("return_documents".to_owned(), json!(self.return_documents));

        if let Some(truncation) = self.truncation {
            body_obj.insert("truncation".to_owned(), json!(truncation));
        }

        let body = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post("/rerank")?
            .body(body)
            .map_err(|x| RerankError::HttpError(x.into()))?;

        let response = self.client.send::<_, Bytes>(req).await?;
        // Capture the status before the response is consumed to read its body.
        let status = response.status();
        let response_body = response.into_body().into_future().await?.to_vec();

        if status.is_success() {
            match serde_json::from_slice::<ApiResponse<RerankApiResponse>>(&response_body)? {
                ApiResponse::Ok(response) => {
                    tracing::info!(target: "rig",
                        "VoyageAI rerank token usage: {}",
                        response.usage.total_tokens
                    );

                    // The response only carries `total_tokens`, so that figure
                    // is reported as both input and total; the rest are zero.
                    let usage = crate::completion::Usage {
                        input_tokens: response.usage.total_tokens as u64,
                        output_tokens: 0,
                        total_tokens: response.usage.total_tokens as u64,
                        cached_input_tokens: 0,
                        cache_creation_input_tokens: 0,
                        reasoning_tokens: 0,
                        tool_use_prompt_tokens: 0,
                    };

                    // Field-for-field copy into the crate-level result type;
                    // the order of entries is kept as returned by the API.
                    let results = response
                        .data
                        .into_iter()
                        .map(|d| rerank::RerankResult {
                            index: d.index,
                            document: d.document,
                            relevance_score: d.relevance_score,
                        })
                        .collect();

                    Ok(rerank::RerankResponse {
                        results,
                        model: response.model,
                        usage,
                    })
                }
                ApiResponse::Err(err) => Err(RerankError::ProviderError(err.message)),
            }
        } else {
            // Non-success status: hand back the raw body, lossily decoded as UTF-8.
            Err(RerankError::ProviderError(
                String::from_utf8_lossy(&response_body).to_string(),
            ))
        }
    }
}
456
#[cfg(test)]
mod tests {
    /// Smoke test: a client can be constructed both with `Client::new` and
    /// through the builder, using a placeholder key. The test only constructs
    /// clients; it calls no embedding or rerank method.
    #[test]
    fn test_client_initialization() {
        let _client =
            crate::providers::voyageai::Client::new("dummy-key").expect("Client::new() failed");
        let _client_from_builder = crate::providers::voyageai::Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}