rig/providers/
voyageai.rs

1use crate::client::{ClientBuilderError, EmbeddingsClient, ProviderClient};
2use crate::embeddings::EmbeddingError;
3use crate::{embeddings, impl_conversion_traits};
4use serde::Deserialize;
5use serde_json::json;
6
7// ================================================================
8// Main Voyage AI Client
9// ================================================================
10const OPENAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
11
12pub struct ClientBuilder<'a> {
13    api_key: &'a str,
14    base_url: &'a str,
15    http_client: Option<reqwest::Client>,
16}
17
18impl<'a> ClientBuilder<'a> {
19    pub fn new(api_key: &'a str) -> Self {
20        Self {
21            api_key,
22            base_url: OPENAI_API_BASE_URL,
23            http_client: None,
24        }
25    }
26
27    pub fn base_url(mut self, base_url: &'a str) -> Self {
28        self.base_url = base_url;
29        self
30    }
31
32    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
33        self.http_client = Some(client);
34        self
35    }
36
37    pub fn build(self) -> Result<Client, ClientBuilderError> {
38        let http_client = if let Some(http_client) = self.http_client {
39            http_client
40        } else {
41            reqwest::Client::builder().build()?
42        };
43
44        Ok(Client {
45            base_url: self.base_url.to_string(),
46            api_key: self.api_key.to_string(),
47            http_client,
48        })
49    }
50}
51
52#[derive(Clone)]
53pub struct Client {
54    base_url: String,
55    api_key: String,
56    http_client: reqwest::Client,
57}
58
59impl std::fmt::Debug for Client {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("Client")
62            .field("base_url", &self.base_url)
63            .field("http_client", &self.http_client)
64            .field("api_key", &"<REDACTED>")
65            .finish()
66    }
67}
68
69impl Client {
70    /// Create a new Voyage AI client builder.
71    ///
72    /// # Example
73    /// ```
74    /// use rig::providers::voyageai::{ClientBuilder, self};
75    ///
76    /// // Initialize the Voyage AI client
77    /// let voyageai = Client::builder("your-voyageai-api-key")
78    ///    .build()
79    /// ```
80    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
81        ClientBuilder::new(api_key)
82    }
83
84    /// Create a new Voyage AI client. For more control, use the `builder` method.
85    ///
86    /// # Panics
87    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
88    pub fn new(api_key: &str) -> Self {
89        Self::builder(api_key)
90            .build()
91            .expect("Voyage AI client should build")
92    }
93
94    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
95        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
96        self.http_client.post(url).bearer_auth(&self.api_key)
97    }
98}
99
100impl_conversion_traits!(
101    AsCompletion,
102    AsTranscription,
103    AsImageGeneration,
104    AsAudioGeneration for Client
105);
106
107impl ProviderClient for Client {
108    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
109    /// Panics if the environment variable is not set.
110    fn from_env() -> Self {
111        let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
112        Self::new(&api_key)
113    }
114
115    fn from_val(input: crate::client::ProviderValue) -> Self {
116        let crate::client::ProviderValue::Simple(api_key) = input else {
117            panic!("Incorrect provider value type")
118        };
119        Self::new(&api_key)
120    }
121}
122
123/// Although the models have default embedding dimensions, there are additional alternatives for increasing and decreasing the dimensions to your requirements.
124/// See Voyage AI's documentation:  <https://docs.voyageai.com/docs/embeddings>
125impl EmbeddingsClient for Client {
126    type EmbeddingModel = EmbeddingModel;
127    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
128        let ndims = match model {
129            VOYAGE_CODE_2 => 1536,
130            VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
131            | VOYAGE_LAW_2 => 1024,
132            _ => 0,
133        };
134        EmbeddingModel::new(self.clone(), model, ndims)
135    }
136
137    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
138        EmbeddingModel::new(self.clone(), model, ndims)
139    }
140}
141
142impl EmbeddingModel {
143    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
144        Self {
145            client,
146            model: model.to_string(),
147            ndims,
148        }
149    }
150}
151
152// ================================================================
153// Voyage AI Embedding API
154// ================================================================
/// `voyage-3-large` embedding model (Voyage AI)
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model (Voyage AI)
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model (Voyage AI)
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model (Voyage AI)
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model (Voyage AI)
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model (Voyage AI)
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model (Voyage AI)
pub const VOYAGE_CODE_2: &str = "voyage-code-2";
169
170#[derive(Debug, Deserialize)]
171pub struct EmbeddingResponse {
172    pub object: String,
173    pub data: Vec<EmbeddingData>,
174    pub model: String,
175    pub usage: Usage,
176}
177
178#[derive(Clone, Debug, Deserialize)]
179pub struct Usage {
180    pub prompt_tokens: usize,
181    pub total_tokens: usize,
182}
183
184#[derive(Debug, Deserialize)]
185pub struct ApiErrorResponse {
186    pub(crate) message: String,
187}
188
189impl From<ApiErrorResponse> for EmbeddingError {
190    fn from(err: ApiErrorResponse) -> Self {
191        EmbeddingError::ProviderError(err.message)
192    }
193}
194
195#[derive(Debug, Deserialize)]
196#[serde(untagged)]
197pub(crate) enum ApiResponse<T> {
198    Ok(T),
199    Err(ApiErrorResponse),
200}
201
202impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
203    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
204        match value {
205            ApiResponse::Ok(response) => Ok(response),
206            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
207        }
208    }
209}
210
211#[derive(Debug, Deserialize)]
212pub struct EmbeddingData {
213    pub object: String,
214    pub embedding: Vec<f64>,
215    pub index: usize,
216}
217
218#[derive(Clone)]
219pub struct EmbeddingModel {
220    client: Client,
221    pub model: String,
222    ndims: usize,
223}
224
225impl embeddings::EmbeddingModel for EmbeddingModel {
226    const MAX_DOCUMENTS: usize = 1024;
227
228    fn ndims(&self) -> usize {
229        self.ndims
230    }
231
232    #[cfg_attr(feature = "worker", worker::send)]
233    async fn embed_texts(
234        &self,
235        documents: impl IntoIterator<Item = String>,
236    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
237        let documents = documents.into_iter().collect::<Vec<_>>();
238
239        let response = self
240            .client
241            .post("/embeddings")
242            .json(&json!({
243                "model": self.model,
244                "input": documents,
245            }))
246            .send()
247            .await?;
248
249        if response.status().is_success() {
250            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
251                ApiResponse::Ok(response) => {
252                    tracing::info!(target: "rig",
253                        "VoyageAI embedding token usage: {}",
254                        response.usage.total_tokens
255                    );
256
257                    if response.data.len() != documents.len() {
258                        return Err(EmbeddingError::ResponseError(
259                            "Response data length does not match input length".into(),
260                        ));
261                    }
262
263                    Ok(response
264                        .data
265                        .into_iter()
266                        .zip(documents.into_iter())
267                        .map(|(embedding, document)| embeddings::Embedding {
268                            document,
269                            vec: embedding.embedding,
270                        })
271                        .collect())
272                }
273                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
274            }
275        } else {
276            Err(EmbeddingError::ProviderError(response.text().await?))
277        }
278    }
279}