rig/providers/voyageai.rs

1use crate::client::{EmbeddingsClient, ProviderClient};
2use crate::embeddings::EmbeddingError;
3use crate::{embeddings, impl_conversion_traits};
4use serde::Deserialize;
5use serde_json::json;
6
// ================================================================
// Main Voyage AI Client
// ================================================================
/// Default base URL of the Voyage AI REST API (v1).
const VOYAGE_API_BASE_URL: &str = "https://api.voyageai.com/v1";

/// Legacy alias for [`VOYAGE_API_BASE_URL`].
///
/// The original name is a misnomer: this is the Voyage AI endpoint, not OpenAI's. It is kept
/// so existing references in this module keep compiling; prefer `VOYAGE_API_BASE_URL`.
const OPENAI_API_BASE_URL: &str = VOYAGE_API_BASE_URL;
11
12#[derive(Clone)]
13pub struct Client {
14    base_url: String,
15    api_key: String,
16    http_client: reqwest::Client,
17}
18
19impl std::fmt::Debug for Client {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        f.debug_struct("Client")
22            .field("base_url", &self.base_url)
23            .field("http_client", &self.http_client)
24            .field("api_key", &"<REDACTED>")
25            .finish()
26    }
27}
28
29impl Client {
30    /// Create a new OpenAI client with the given API key.
31    pub fn new(api_key: &str) -> Self {
32        Self::from_url(api_key, OPENAI_API_BASE_URL)
33    }
34
35    /// Create a new OpenAI client with the given API key and base API URL.
36    pub fn from_url(api_key: &str, base_url: &str) -> Self {
37        Self {
38            base_url: base_url.to_string(),
39            api_key: api_key.to_string(),
40            http_client: reqwest::Client::builder()
41                .build()
42                .expect("OpenAI reqwest client should build"),
43        }
44    }
45
46    /// Use your own `reqwest::Client`.
47    /// The default headers will be automatically attached upon trying to make a request.
48    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
49        self.http_client = client;
50
51        self
52    }
53
54    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
55        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
56        self.http_client.post(url).bearer_auth(&self.api_key)
57    }
58}
59
// NOTE(review): presumably generates the completion / transcription / image-generation /
// audio-generation conversion traits for `Client`. This module only implements embeddings, so
// these impls most likely report those capabilities as unsupported — confirm against the
// `impl_conversion_traits!` definition before relying on that.
impl_conversion_traits!(
    AsCompletion,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
66
67impl ProviderClient for Client {
68    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
69    /// Panics if the environment variable is not set.
70    fn from_env() -> Self {
71        let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
72        Self::new(&api_key)
73    }
74}
75
76/// Although the models have default embedding dimensions, there are additional alternatives for increasing and decreasing the dimensions to your requirements.
77/// See Voyage AI's documentation:  <https://docs.voyageai.com/docs/embeddings>
78impl EmbeddingsClient for Client {
79    type EmbeddingModel = EmbeddingModel;
80    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
81        let ndims = match model {
82            VOYAGE_CODE_2 => 1536,
83            VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
84            | VOYAGE_LAW_2 => 1024,
85            _ => 0,
86        };
87        EmbeddingModel::new(self.clone(), model, ndims)
88    }
89
90    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
91        EmbeddingModel::new(self.clone(), model, ndims)
92    }
93}
94
95impl EmbeddingModel {
96    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
97        Self {
98            client,
99            model: model.to_string(),
100            ndims,
101        }
102    }
103}
104
// ================================================================
// Voyage AI Embedding API
// ================================================================
/// `voyage-3-large` embedding model (Voyage AI)
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model (Voyage AI)
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model (Voyage AI)
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model (Voyage AI)
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model (Voyage AI)
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model (Voyage AI)
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model (Voyage AI)
pub const VOYAGE_CODE_2: &str = "voyage-code-2";
122
123#[derive(Debug, Deserialize)]
124pub struct EmbeddingResponse {
125    pub object: String,
126    pub data: Vec<EmbeddingData>,
127    pub model: String,
128    pub usage: Usage,
129}
130
131#[derive(Clone, Debug, Deserialize)]
132pub struct Usage {
133    pub prompt_tokens: usize,
134    pub total_tokens: usize,
135}
136
137#[derive(Debug, Deserialize)]
138pub struct ApiErrorResponse {
139    pub(crate) message: String,
140}
141
142impl From<ApiErrorResponse> for EmbeddingError {
143    fn from(err: ApiErrorResponse) -> Self {
144        EmbeddingError::ProviderError(err.message)
145    }
146}
147
148#[derive(Debug, Deserialize)]
149#[serde(untagged)]
150pub(crate) enum ApiResponse<T> {
151    Ok(T),
152    Err(ApiErrorResponse),
153}
154
155impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
156    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
157        match value {
158            ApiResponse::Ok(response) => Ok(response),
159            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
160        }
161    }
162}
163
164#[derive(Debug, Deserialize)]
165pub struct EmbeddingData {
166    pub object: String,
167    pub embedding: Vec<f64>,
168    pub index: usize,
169}
170
171#[derive(Clone)]
172pub struct EmbeddingModel {
173    client: Client,
174    pub model: String,
175    ndims: usize,
176}
177
178impl embeddings::EmbeddingModel for EmbeddingModel {
179    const MAX_DOCUMENTS: usize = 1024;
180
181    fn ndims(&self) -> usize {
182        self.ndims
183    }
184
185    #[cfg_attr(feature = "worker", worker::send)]
186    async fn embed_texts(
187        &self,
188        documents: impl IntoIterator<Item = String>,
189    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
190        let documents = documents.into_iter().collect::<Vec<_>>();
191
192        let response = self
193            .client
194            .post("/embeddings")
195            .json(&json!({
196                "model": self.model,
197                "input": documents,
198            }))
199            .send()
200            .await?;
201
202        if response.status().is_success() {
203            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
204                ApiResponse::Ok(response) => {
205                    tracing::info!(target: "rig",
206                        "VoyageAI embedding token usage: {}",
207                        response.usage.total_tokens
208                    );
209
210                    if response.data.len() != documents.len() {
211                        return Err(EmbeddingError::ResponseError(
212                            "Response data length does not match input length".into(),
213                        ));
214                    }
215
216                    Ok(response
217                        .data
218                        .into_iter()
219                        .zip(documents.into_iter())
220                        .map(|(embedding, document)| embeddings::Embedding {
221                            document,
222                            vec: embedding.embedding,
223                        })
224                        .collect())
225                }
226                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
227            }
228        } else {
229            Err(EmbeddingError::ProviderError(response.text().await?))
230        }
231    }
232}