rig/providers/
voyageai.rs

1use crate::client::{EmbeddingsClient, ProviderClient};
2use crate::embeddings::EmbeddingError;
3use crate::{embeddings, impl_conversion_traits};
4use serde::Deserialize;
5use serde_json::json;
6
// ================================================================
// Main Voyage AI Client
// ================================================================
/// Default base URL of the Voyage AI REST API.
const VOYAGE_API_BASE_URL: &str = "https://api.voyageai.com/v1";

/// Legacy alias for [`VOYAGE_API_BASE_URL`].
///
/// The `OPENAI_` prefix is a misnomer (the URL points at Voyage AI, not OpenAI); the name is
/// kept so that existing references in this module continue to compile.
const OPENAI_API_BASE_URL: &str = VOYAGE_API_BASE_URL;
11
12#[derive(Clone)]
13pub struct Client {
14    base_url: String,
15    api_key: String,
16    http_client: reqwest::Client,
17}
18
19impl std::fmt::Debug for Client {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        f.debug_struct("Client")
22            .field("base_url", &self.base_url)
23            .field("http_client", &self.http_client)
24            .field("api_key", &"<REDACTED>")
25            .finish()
26    }
27}
28
29impl Client {
30    /// Create a new OpenAI client with the given API key.
31    pub fn new(api_key: &str) -> Self {
32        Self::from_url(api_key, OPENAI_API_BASE_URL)
33    }
34
35    /// Create a new OpenAI client with the given API key and base API URL.
36    pub fn from_url(api_key: &str, base_url: &str) -> Self {
37        Self {
38            base_url: base_url.to_string(),
39            api_key: api_key.to_string(),
40            http_client: reqwest::Client::builder()
41                .build()
42                .expect("OpenAI reqwest client should build"),
43        }
44    }
45
46    /// Use your own `reqwest::Client`.
47    /// The default headers will be automatically attached upon trying to make a request.
48    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
49        self.http_client = client;
50
51        self
52    }
53
54    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
55        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
56        self.http_client.post(url).bearer_auth(&self.api_key)
57    }
58}
59
60impl_conversion_traits!(
61    AsCompletion,
62    AsTranscription,
63    AsImageGeneration,
64    AsAudioGeneration for Client
65);
66
67impl ProviderClient for Client {
68    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
69    /// Panics if the environment variable is not set.
70    fn from_env() -> Self {
71        let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
72        Self::new(&api_key)
73    }
74
75    fn from_val(input: crate::client::ProviderValue) -> Self {
76        let crate::client::ProviderValue::Simple(api_key) = input else {
77            panic!("Incorrect provider value type")
78        };
79        Self::new(&api_key)
80    }
81}
82
83/// Although the models have default embedding dimensions, there are additional alternatives for increasing and decreasing the dimensions to your requirements.
84/// See Voyage AI's documentation:  <https://docs.voyageai.com/docs/embeddings>
85impl EmbeddingsClient for Client {
86    type EmbeddingModel = EmbeddingModel;
87    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
88        let ndims = match model {
89            VOYAGE_CODE_2 => 1536,
90            VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
91            | VOYAGE_LAW_2 => 1024,
92            _ => 0,
93        };
94        EmbeddingModel::new(self.clone(), model, ndims)
95    }
96
97    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
98        EmbeddingModel::new(self.clone(), model, ndims)
99    }
100}
101
102impl EmbeddingModel {
103    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
104        Self {
105            client,
106            model: model.to_string(),
107            ndims,
108        }
109    }
110}
111
// ================================================================
// Voyage AI Embedding API
// ================================================================
// Model identifiers, exactly as they are sent in the `model` field of an embeddings request.

/// `voyage-3-large` embedding model (Voyage AI)
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model (Voyage AI)
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model (Voyage AI)
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model (Voyage AI)
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model (Voyage AI)
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model (Voyage AI)
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model (Voyage AI)
pub const VOYAGE_CODE_2: &str = "voyage-code-2";
129
130#[derive(Debug, Deserialize)]
131pub struct EmbeddingResponse {
132    pub object: String,
133    pub data: Vec<EmbeddingData>,
134    pub model: String,
135    pub usage: Usage,
136}
137
138#[derive(Clone, Debug, Deserialize)]
139pub struct Usage {
140    pub prompt_tokens: usize,
141    pub total_tokens: usize,
142}
143
144#[derive(Debug, Deserialize)]
145pub struct ApiErrorResponse {
146    pub(crate) message: String,
147}
148
149impl From<ApiErrorResponse> for EmbeddingError {
150    fn from(err: ApiErrorResponse) -> Self {
151        EmbeddingError::ProviderError(err.message)
152    }
153}
154
155#[derive(Debug, Deserialize)]
156#[serde(untagged)]
157pub(crate) enum ApiResponse<T> {
158    Ok(T),
159    Err(ApiErrorResponse),
160}
161
162impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
163    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
164        match value {
165            ApiResponse::Ok(response) => Ok(response),
166            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
167        }
168    }
169}
170
171#[derive(Debug, Deserialize)]
172pub struct EmbeddingData {
173    pub object: String,
174    pub embedding: Vec<f64>,
175    pub index: usize,
176}
177
178#[derive(Clone)]
179pub struct EmbeddingModel {
180    client: Client,
181    pub model: String,
182    ndims: usize,
183}
184
185impl embeddings::EmbeddingModel for EmbeddingModel {
186    const MAX_DOCUMENTS: usize = 1024;
187
188    fn ndims(&self) -> usize {
189        self.ndims
190    }
191
192    #[cfg_attr(feature = "worker", worker::send)]
193    async fn embed_texts(
194        &self,
195        documents: impl IntoIterator<Item = String>,
196    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
197        let documents = documents.into_iter().collect::<Vec<_>>();
198
199        let response = self
200            .client
201            .post("/embeddings")
202            .json(&json!({
203                "model": self.model,
204                "input": documents,
205            }))
206            .send()
207            .await?;
208
209        if response.status().is_success() {
210            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
211                ApiResponse::Ok(response) => {
212                    tracing::info!(target: "rig",
213                        "VoyageAI embedding token usage: {}",
214                        response.usage.total_tokens
215                    );
216
217                    if response.data.len() != documents.len() {
218                        return Err(EmbeddingError::ResponseError(
219                            "Response data length does not match input length".into(),
220                        ));
221                    }
222
223                    Ok(response
224                        .data
225                        .into_iter()
226                        .zip(documents.into_iter())
227                        .map(|(embedding, document)| embeddings::Embedding {
228                            document,
229                            vec: embedding.embedding,
230                        })
231                        .collect())
232                }
233                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
234            }
235        } else {
236            Err(EmbeddingError::ProviderError(response.text().await?))
237        }
238    }
239}