1use crate::client::{EmbeddingsClient, ProviderClient, VerifyClient, VerifyError};
2use crate::embeddings::EmbeddingError;
3use crate::http_client::{HttpClientExt, with_bearer_auth};
4use crate::{embeddings, http_client, impl_conversion_traits};
5use bytes::Bytes;
6use http::Method;
7use serde::Deserialize;
8use serde_json::json;
9
/// Default base URL of the VoyageAI REST API.
///
/// Used by both `ClientBuilder` constructors; callers can override it with
/// `ClientBuilder::base_url`.
const VOYAGEAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
14
/// Builder for a VoyageAI [`Client`].
///
/// Holds borrowed copies of the API key and base URL until `build` turns them
/// into owned strings on the resulting client. `T` is the HTTP client type and
/// defaults to `reqwest::Client`.
pub struct ClientBuilder<'a, T = reqwest::Client> {
    /// API key that the built client attaches as a bearer token.
    api_key: &'a str,
    /// Base URL that request paths are joined onto.
    base_url: &'a str,
    /// HTTP client the built client sends requests with.
    http_client: T,
}
20
21impl<'a, T> ClientBuilder<'a, T>
22where
23 T: Default,
24{
25 pub fn new(api_key: &'a str) -> Self {
26 Self {
27 api_key,
28 base_url: VOYAGEAI_API_BASE_URL,
29 http_client: Default::default(),
30 }
31 }
32}
33
34impl<'a, T> ClientBuilder<'a, T> {
35 pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
36 Self {
37 api_key,
38 base_url: VOYAGEAI_API_BASE_URL,
39 http_client,
40 }
41 }
42
43 pub fn base_url(mut self, base_url: &'a str) -> Self {
44 self.base_url = base_url;
45 self
46 }
47
48 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
49 ClientBuilder {
50 api_key: self.api_key,
51 base_url: self.base_url,
52 http_client,
53 }
54 }
55
56 pub fn build(self) -> Client<T> {
57 Client {
58 base_url: self.base_url.to_string(),
59 api_key: self.api_key.to_string(),
60 http_client: self.http_client,
61 }
62 }
63}
64
/// VoyageAI API client.
///
/// `Debug` is implemented by hand (not derived) so the API key is redacted.
/// `T` is the HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct Client<T = reqwest::Client> {
    /// Base URL that endpoint paths are joined onto by `post`.
    base_url: String,
    /// API key attached as a bearer token to requests built by `post`.
    api_key: String,
    /// Underlying HTTP client used to send requests.
    http_client: T,
}
71
72impl<T> std::fmt::Debug for Client<T>
73where
74 T: std::fmt::Debug,
75{
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("Client")
78 .field("base_url", &self.base_url)
79 .field("http_client", &self.http_client)
80 .field("api_key", &"<REDACTED>")
81 .finish()
82 }
83}
84
85impl<T> Client<T> {
86 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
87 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
88
89 let req = http_client::Request::builder()
90 .uri(url)
91 .method(Method::POST);
92
93 with_bearer_auth(req, &self.api_key)
94 }
95}
96
97impl Client<reqwest::Client> {
98 pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
99 ClientBuilder::new(api_key)
100 }
101
102 pub fn new(api_key: &str) -> Self {
103 Self::builder(api_key).build()
104 }
105
106 pub fn from_env() -> Self {
107 <Self as ProviderClient>::from_env()
108 }
109}
110
impl<T> VerifyClient for Client<T>
where
    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
{
    /// Always reports success without making any network request.
    ///
    /// NOTE(review): the API key is not actually checked here — presumably
    /// because no lightweight verification endpoint is used for VoyageAI.
    /// Confirm whether a real check (e.g. a tiny embedding request) is wanted.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        Ok(())
    }
}
121
// NOTE(review): presumably generates the `As*` conversion-trait impls for the
// capabilities listed here (completion, transcription, image and audio
// generation), which this module does not implement itself. Confirm against the
// `impl_conversion_traits!` macro definition.
impl_conversion_traits!(
    AsCompletion,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
128
129impl<T> ProviderClient for Client<T>
130where
131 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
132{
133 fn from_env() -> Self {
136 let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
137 ClientBuilder::<T>::new(&api_key).build()
138 }
139
140 fn from_val(input: crate::client::ProviderValue) -> Self {
141 let crate::client::ProviderValue::Simple(api_key) = input else {
142 panic!("Incorrect provider value type")
143 };
144 ClientBuilder::<T>::new(&api_key).build()
145 }
146}
147
148impl<T> EmbeddingsClient for Client<T>
151where
152 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
153{
154 type EmbeddingModel = EmbeddingModel<T>;
155 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
156 let ndims = match model {
157 VOYAGE_CODE_2 => 1536,
158 VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
159 | VOYAGE_LAW_2 => 1024,
160 _ => 0,
161 };
162 EmbeddingModel::new(self.clone(), model, ndims)
163 }
164
165 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
166 EmbeddingModel::new(self.clone(), model, ndims)
167 }
168}
169
170impl<T> EmbeddingModel<T> {
171 pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
172 Self {
173 client,
174 model: model.to_string(),
175 ndims,
176 }
177 }
178}
179
/// `voyage-3-large` embedding model.
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model.
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model.
// Previously misspelled as "voyage.3-5.lite", which the API does not recognise.
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model.
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model.
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model.
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model.
pub const VOYAGE_CODE_2: &str = "voyage-code-2";
197
/// Successful response body of the VoyageAI `/embeddings` endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// Object type tag reported by the API.
    pub object: String,
    /// Embedding entries; `embed_texts` requires one per input text.
    pub data: Vec<EmbeddingData>,
    /// Model that produced the embeddings.
    pub model: String,
    /// Token usage for the request.
    pub usage: Usage,
}
205
206#[derive(Clone, Debug, Deserialize)]
207pub struct Usage {
208 pub prompt_tokens: usize,
209 pub total_tokens: usize,
210}
211
/// Provider error object, used as the fallback variant of [`ApiResponse`].
///
/// NOTE(review): `embed_texts` only parses bodies of 2xx responses through
/// `ApiResponse`; non-2xx bodies are reported verbatim, so this shape is rarely
/// hit. Confirm it matches VoyageAI's actual error payload.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Error message reported by the provider.
    pub(crate) message: String,
}
216
217impl From<ApiErrorResponse> for EmbeddingError {
218 fn from(err: ApiErrorResponse) -> Self {
219 EmbeddingError::ProviderError(err.message)
220 }
221}
222
/// Either a successful payload `T` or a provider error object.
///
/// `#[serde(untagged)]` tries variants in declaration order, so `Ok(T)` is
/// attempted first and `Err` only when the body does not match `T`. Do not
/// reorder the variants.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
229
230impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
231 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
232 match value {
233 ApiResponse::Ok(response) => Ok(response),
234 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
235 }
236 }
237}
238
/// A single embedding entry within an [`EmbeddingResponse`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Object type tag reported by the API.
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f64>,
    /// Presumably the position of the corresponding text in the request's
    /// `input` list (per the VoyageAI response schema) — TODO confirm.
    pub index: usize,
}
245
/// A VoyageAI embedding model bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel<T> {
    /// Client used to send embedding requests.
    client: Client<T>,
    /// Model identifier sent as `model` in the request body.
    pub model: String,
    /// Embedding dimensionality reported by `ndims()`.
    ndims: usize,
}
252
253impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
254where
255 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
256{
257 const MAX_DOCUMENTS: usize = 1024;
258
259 fn ndims(&self) -> usize {
260 self.ndims
261 }
262
263 #[cfg_attr(feature = "worker", worker::send)]
264 async fn embed_texts(
265 &self,
266 documents: impl IntoIterator<Item = String>,
267 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
268 let documents = documents.into_iter().collect::<Vec<_>>();
269 let request = json!({
270 "model": self.model,
271 "input": documents,
272 });
273
274 let body = serde_json::to_vec(&request)?;
275
276 let req = self
277 .client
278 .post("/embeddings")?
279 .body(body)
280 .map_err(|x| EmbeddingError::HttpError(x.into()))?;
281
282 let response = self.client.http_client.send::<_, Bytes>(req).await?;
283 let status = response.status();
284 let response_body = response.into_body().into_future().await?.to_vec();
285
286 if status.is_success() {
287 match serde_json::from_slice::<ApiResponse<EmbeddingResponse>>(&response_body)? {
288 ApiResponse::Ok(response) => {
289 tracing::info!(target: "rig",
290 "VoyageAI embedding token usage: {}",
291 response.usage.total_tokens
292 );
293
294 if response.data.len() != documents.len() {
295 return Err(EmbeddingError::ResponseError(
296 "Response data length does not match input length".into(),
297 ));
298 }
299
300 Ok(response
301 .data
302 .into_iter()
303 .zip(documents.into_iter())
304 .map(|(embedding, document)| embeddings::Embedding {
305 document,
306 vec: embedding.embedding,
307 })
308 .collect())
309 }
310 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
311 }
312 } else {
313 Err(EmbeddingError::ProviderError(
314 String::from_utf8_lossy(&response_body).to_string(),
315 ))
316 }
317 }
318}