rig/providers/
voyageai.rs1use crate::client::{ClientBuilderError, EmbeddingsClient, ProviderClient};
2use crate::embeddings::EmbeddingError;
3use crate::{embeddings, impl_conversion_traits};
4use serde::Deserialize;
5use serde_json::json;
6
7const OPENAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
11
12pub struct ClientBuilder<'a> {
13 api_key: &'a str,
14 base_url: &'a str,
15 http_client: Option<reqwest::Client>,
16}
17
18impl<'a> ClientBuilder<'a> {
19 pub fn new(api_key: &'a str) -> Self {
20 Self {
21 api_key,
22 base_url: OPENAI_API_BASE_URL,
23 http_client: None,
24 }
25 }
26
27 pub fn base_url(mut self, base_url: &'a str) -> Self {
28 self.base_url = base_url;
29 self
30 }
31
32 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
33 self.http_client = Some(client);
34 self
35 }
36
37 pub fn build(self) -> Result<Client, ClientBuilderError> {
38 let http_client = if let Some(http_client) = self.http_client {
39 http_client
40 } else {
41 reqwest::Client::builder().build()?
42 };
43
44 Ok(Client {
45 base_url: self.base_url.to_string(),
46 api_key: self.api_key.to_string(),
47 http_client,
48 })
49 }
50}
51
52#[derive(Clone)]
53pub struct Client {
54 base_url: String,
55 api_key: String,
56 http_client: reqwest::Client,
57}
58
59impl std::fmt::Debug for Client {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("Client")
62 .field("base_url", &self.base_url)
63 .field("http_client", &self.http_client)
64 .field("api_key", &"<REDACTED>")
65 .finish()
66 }
67}
68
69impl Client {
70 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
81 ClientBuilder::new(api_key)
82 }
83
84 pub fn new(api_key: &str) -> Self {
89 Self::builder(api_key)
90 .build()
91 .expect("Voyage AI client should build")
92 }
93
94 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
95 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
96 self.http_client.post(url).bearer_auth(&self.api_key)
97 }
98}
99
100impl_conversion_traits!(
101 AsCompletion,
102 AsTranscription,
103 AsImageGeneration,
104 AsAudioGeneration for Client
105);
106
107impl ProviderClient for Client {
108 fn from_env() -> Self {
111 let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
112 Self::new(&api_key)
113 }
114
115 fn from_val(input: crate::client::ProviderValue) -> Self {
116 let crate::client::ProviderValue::Simple(api_key) = input else {
117 panic!("Incorrect provider value type")
118 };
119 Self::new(&api_key)
120 }
121}
122
123impl EmbeddingsClient for Client {
126 type EmbeddingModel = EmbeddingModel;
127 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
128 let ndims = match model {
129 VOYAGE_CODE_2 => 1536,
130 VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
131 | VOYAGE_LAW_2 => 1024,
132 _ => 0,
133 };
134 EmbeddingModel::new(self.clone(), model, ndims)
135 }
136
137 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
138 EmbeddingModel::new(self.clone(), model, ndims)
139 }
140}
141
142impl EmbeddingModel {
143 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
144 Self {
145 client,
146 model: model.to_string(),
147 ndims,
148 }
149 }
150}
151
/// Model identifier `voyage-3-large`.
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// Model identifier `voyage-3.5`.
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// Model identifier `voyage-3.5-lite`.
///
/// Uses the same `voyage-<version>[-<variant>]` spelling as the other
/// identifiers (previously misspelled as `voyage.3-5.lite`).
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// Model identifier `voyage-code-3`.
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// Model identifier `voyage-finance-2`.
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// Model identifier `voyage-law-2`.
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// Model identifier `voyage-code-2`.
pub const VOYAGE_CODE_2: &str = "voyage-code-2";
169
170#[derive(Debug, Deserialize)]
171pub struct EmbeddingResponse {
172 pub object: String,
173 pub data: Vec<EmbeddingData>,
174 pub model: String,
175 pub usage: Usage,
176}
177
178#[derive(Clone, Debug, Deserialize)]
179pub struct Usage {
180 pub prompt_tokens: usize,
181 pub total_tokens: usize,
182}
183
184#[derive(Debug, Deserialize)]
185pub struct ApiErrorResponse {
186 pub(crate) message: String,
187}
188
189impl From<ApiErrorResponse> for EmbeddingError {
190 fn from(err: ApiErrorResponse) -> Self {
191 EmbeddingError::ProviderError(err.message)
192 }
193}
194
195#[derive(Debug, Deserialize)]
196#[serde(untagged)]
197pub(crate) enum ApiResponse<T> {
198 Ok(T),
199 Err(ApiErrorResponse),
200}
201
202impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
203 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
204 match value {
205 ApiResponse::Ok(response) => Ok(response),
206 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
207 }
208 }
209}
210
211#[derive(Debug, Deserialize)]
212pub struct EmbeddingData {
213 pub object: String,
214 pub embedding: Vec<f64>,
215 pub index: usize,
216}
217
218#[derive(Clone)]
219pub struct EmbeddingModel {
220 client: Client,
221 pub model: String,
222 ndims: usize,
223}
224
225impl embeddings::EmbeddingModel for EmbeddingModel {
226 const MAX_DOCUMENTS: usize = 1024;
227
228 fn ndims(&self) -> usize {
229 self.ndims
230 }
231
232 #[cfg_attr(feature = "worker", worker::send)]
233 async fn embed_texts(
234 &self,
235 documents: impl IntoIterator<Item = String>,
236 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
237 let documents = documents.into_iter().collect::<Vec<_>>();
238
239 let response = self
240 .client
241 .post("/embeddings")
242 .json(&json!({
243 "model": self.model,
244 "input": documents,
245 }))
246 .send()
247 .await?;
248
249 if response.status().is_success() {
250 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
251 ApiResponse::Ok(response) => {
252 tracing::info!(target: "rig",
253 "VoyageAI embedding token usage: {}",
254 response.usage.total_tokens
255 );
256
257 if response.data.len() != documents.len() {
258 return Err(EmbeddingError::ResponseError(
259 "Response data length does not match input length".into(),
260 ));
261 }
262
263 Ok(response
264 .data
265 .into_iter()
266 .zip(documents.into_iter())
267 .map(|(embedding, document)| embeddings::Embedding {
268 document,
269 vec: embedding.embedding,
270 })
271 .collect())
272 }
273 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
274 }
275 } else {
276 Err(EmbeddingError::ProviderError(response.text().await?))
277 }
278 }
279}