rig/providers/voyageai.rs

1use crate::client::{
2    ClientBuilderError, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError,
3};
4use crate::embeddings::EmbeddingError;
5use crate::{embeddings, impl_conversion_traits};
6use serde::Deserialize;
7use serde_json::json;
8
9// ================================================================
10// Main Voyage AI Client
11// ================================================================
12const OPENAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
13
14pub struct ClientBuilder<'a> {
15    api_key: &'a str,
16    base_url: &'a str,
17    http_client: Option<reqwest::Client>,
18}
19
20impl<'a> ClientBuilder<'a> {
21    pub fn new(api_key: &'a str) -> Self {
22        Self {
23            api_key,
24            base_url: OPENAI_API_BASE_URL,
25            http_client: None,
26        }
27    }
28
29    pub fn base_url(mut self, base_url: &'a str) -> Self {
30        self.base_url = base_url;
31        self
32    }
33
34    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
35        self.http_client = Some(client);
36        self
37    }
38
39    pub fn build(self) -> Result<Client, ClientBuilderError> {
40        let http_client = if let Some(http_client) = self.http_client {
41            http_client
42        } else {
43            reqwest::Client::builder().build()?
44        };
45
46        Ok(Client {
47            base_url: self.base_url.to_string(),
48            api_key: self.api_key.to_string(),
49            http_client,
50        })
51    }
52}
53
54#[derive(Clone)]
55pub struct Client {
56    base_url: String,
57    api_key: String,
58    http_client: reqwest::Client,
59}
60
61impl std::fmt::Debug for Client {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("Client")
64            .field("base_url", &self.base_url)
65            .field("http_client", &self.http_client)
66            .field("api_key", &"<REDACTED>")
67            .finish()
68    }
69}
70
71impl Client {
72    /// Create a new Voyage AI client builder.
73    ///
74    /// # Example
75    /// ```
76    /// use rig::providers::voyageai::{ClientBuilder, self};
77    ///
78    /// // Initialize the Voyage AI client
79    /// let voyageai = Client::builder("your-voyageai-api-key")
80    ///    .build()
81    /// ```
82    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
83        ClientBuilder::new(api_key)
84    }
85
86    /// Create a new Voyage AI client. For more control, use the `builder` method.
87    ///
88    /// # Panics
89    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
90    pub fn new(api_key: &str) -> Self {
91        Self::builder(api_key)
92            .build()
93            .expect("Voyage AI client should build")
94    }
95
96    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
97        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
98        self.http_client.post(url).bearer_auth(&self.api_key)
99    }
100}
101
102impl VerifyClient for Client {
103    #[cfg_attr(feature = "worker", worker::send)]
104    async fn verify(&self) -> Result<(), VerifyError> {
105        // No API endpoint to verify the API key
106        Ok(())
107    }
108}
109
110impl_conversion_traits!(
111    AsCompletion,
112    AsTranscription,
113    AsImageGeneration,
114    AsAudioGeneration for Client
115);
116
117impl ProviderClient for Client {
118    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
119    /// Panics if the environment variable is not set.
120    fn from_env() -> Self {
121        let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
122        Self::new(&api_key)
123    }
124
125    fn from_val(input: crate::client::ProviderValue) -> Self {
126        let crate::client::ProviderValue::Simple(api_key) = input else {
127            panic!("Incorrect provider value type")
128        };
129        Self::new(&api_key)
130    }
131}
132
133/// Although the models have default embedding dimensions, there are additional alternatives for increasing and decreasing the dimensions to your requirements.
134/// See Voyage AI's documentation:  <https://docs.voyageai.com/docs/embeddings>
135impl EmbeddingsClient for Client {
136    type EmbeddingModel = EmbeddingModel;
137    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
138        let ndims = match model {
139            VOYAGE_CODE_2 => 1536,
140            VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
141            | VOYAGE_LAW_2 => 1024,
142            _ => 0,
143        };
144        EmbeddingModel::new(self.clone(), model, ndims)
145    }
146
147    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
148        EmbeddingModel::new(self.clone(), model, ndims)
149    }
150}
151
152impl EmbeddingModel {
153    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
154        Self {
155            client,
156            model: model.to_string(),
157            ndims,
158        }
159    }
160}
161
// ================================================================
// Voyage AI Embedding API
// ================================================================
/// `voyage-3-large` embedding model (Voyage AI)
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model (Voyage AI)
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model (Voyage AI)
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model (Voyage AI)
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model (Voyage AI)
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model (Voyage AI)
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model (Voyage AI)
pub const VOYAGE_CODE_2: &str = "voyage-code-2";
179
180#[derive(Debug, Deserialize)]
181pub struct EmbeddingResponse {
182    pub object: String,
183    pub data: Vec<EmbeddingData>,
184    pub model: String,
185    pub usage: Usage,
186}
187
188#[derive(Clone, Debug, Deserialize)]
189pub struct Usage {
190    pub prompt_tokens: usize,
191    pub total_tokens: usize,
192}
193
194#[derive(Debug, Deserialize)]
195pub struct ApiErrorResponse {
196    pub(crate) message: String,
197}
198
199impl From<ApiErrorResponse> for EmbeddingError {
200    fn from(err: ApiErrorResponse) -> Self {
201        EmbeddingError::ProviderError(err.message)
202    }
203}
204
205#[derive(Debug, Deserialize)]
206#[serde(untagged)]
207pub(crate) enum ApiResponse<T> {
208    Ok(T),
209    Err(ApiErrorResponse),
210}
211
212impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
213    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
214        match value {
215            ApiResponse::Ok(response) => Ok(response),
216            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
217        }
218    }
219}
220
221#[derive(Debug, Deserialize)]
222pub struct EmbeddingData {
223    pub object: String,
224    pub embedding: Vec<f64>,
225    pub index: usize,
226}
227
228#[derive(Clone)]
229pub struct EmbeddingModel {
230    client: Client,
231    pub model: String,
232    ndims: usize,
233}
234
235impl embeddings::EmbeddingModel for EmbeddingModel {
236    const MAX_DOCUMENTS: usize = 1024;
237
238    fn ndims(&self) -> usize {
239        self.ndims
240    }
241
242    #[cfg_attr(feature = "worker", worker::send)]
243    async fn embed_texts(
244        &self,
245        documents: impl IntoIterator<Item = String>,
246    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
247        let documents = documents.into_iter().collect::<Vec<_>>();
248
249        let response = self
250            .client
251            .post("/embeddings")
252            .json(&json!({
253                "model": self.model,
254                "input": documents,
255            }))
256            .send()
257            .await?;
258
259        if response.status().is_success() {
260            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
261                ApiResponse::Ok(response) => {
262                    tracing::info!(target: "rig",
263                        "VoyageAI embedding token usage: {}",
264                        response.usage.total_tokens
265                    );
266
267                    if response.data.len() != documents.len() {
268                        return Err(EmbeddingError::ResponseError(
269                            "Response data length does not match input length".into(),
270                        ));
271                    }
272
273                    Ok(response
274                        .data
275                        .into_iter()
276                        .zip(documents.into_iter())
277                        .map(|(embedding, document)| embeddings::Embedding {
278                            document,
279                            vec: embedding.embedding,
280                        })
281                        .collect())
282                }
283                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
284            }
285        } else {
286            Err(EmbeddingError::ProviderError(response.text().await?))
287        }
288    }
289}