1use crate::client::{EmbeddingsClient, ProviderClient, VerifyClient, VerifyError};
2use crate::embeddings::EmbeddingError;
3use crate::{embeddings, http_client, impl_conversion_traits};
4use serde::Deserialize;
5use serde_json::json;
6
/// Default Voyage AI API root, used by `ClientBuilder::new` when no base URL is set.
// NOTE(review): the `OPENAI_` prefix is a misnomer — the URL is Voyage AI's endpoint.
// Renaming it must be done together with its use in `ClientBuilder::new`.
const OPENAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
11
12pub struct ClientBuilder<'a, T = reqwest::Client> {
13 api_key: &'a str,
14 base_url: &'a str,
15 http_client: T,
16}
17
18impl<'a, T> ClientBuilder<'a, T>
19where
20 T: Default,
21{
22 pub fn new(api_key: &'a str) -> Self {
23 Self {
24 api_key,
25 base_url: OPENAI_API_BASE_URL,
26 http_client: Default::default(),
27 }
28 }
29}
30
31impl<'a, T> ClientBuilder<'a, T> {
32 pub fn base_url(mut self, base_url: &'a str) -> Self {
33 self.base_url = base_url;
34 self
35 }
36
37 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
38 ClientBuilder {
39 api_key: self.api_key,
40 base_url: self.base_url,
41 http_client,
42 }
43 }
44
45 pub fn build(self) -> Client<T> {
46 Client {
47 base_url: self.base_url.to_string(),
48 api_key: self.api_key.to_string(),
49 http_client: self.http_client,
50 }
51 }
52}
53
54#[derive(Clone)]
55pub struct Client<T = reqwest::Client> {
56 base_url: String,
57 api_key: String,
58 http_client: T,
59}
60
61impl<T> std::fmt::Debug for Client<T>
62where
63 T: std::fmt::Debug,
64{
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("Client")
67 .field("base_url", &self.base_url)
68 .field("http_client", &self.http_client)
69 .field("api_key", &"<REDACTED>")
70 .finish()
71 }
72}
73
74impl<T> Client<T>
75where
76 T: Default,
77{
78 pub fn builder(api_key: &str) -> ClientBuilder<'_, T> {
89 ClientBuilder::new(api_key)
90 }
91
92 pub fn new(api_key: &str) -> Self {
97 Self::builder(api_key).build()
98 }
99}
100
101impl Client<reqwest::Client> {
102 pub(crate) fn reqwest_post(&self, path: &str) -> reqwest::RequestBuilder {
103 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
104 self.http_client.post(url).bearer_auth(&self.api_key)
105 }
106}
107
impl VerifyClient for Client<reqwest::Client> {
    /// Always returns `Ok(())`.
    ///
    /// No request is made, so this does not check that the API key or the
    /// base URL are valid.
    // NOTE(review): consider a real credential check if Voyage AI exposes a
    // cheap endpoint suitable for it.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        Ok(())
    }
}
115
// NOTE(review): presumably generates the listed `As*` conversion traits for
// `Client<T>` as unsupported-capability stubs, since this file only implements
// verification, provider construction and embeddings — confirm against the
// `impl_conversion_traits!` macro definition.
impl_conversion_traits!(
    AsCompletion,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client<T>
);
122
123impl ProviderClient for Client<reqwest::Client> {
124 fn from_env() -> Self {
127 let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
128 Self::new(&api_key)
129 }
130
131 fn from_val(input: crate::client::ProviderValue) -> Self {
132 let crate::client::ProviderValue::Simple(api_key) = input else {
133 panic!("Incorrect provider value type")
134 };
135 Self::new(&api_key)
136 }
137}
138
139impl EmbeddingsClient for Client<reqwest::Client> {
142 type EmbeddingModel = EmbeddingModel<reqwest::Client>;
143 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
144 let ndims = match model {
145 VOYAGE_CODE_2 => 1536,
146 VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
147 | VOYAGE_LAW_2 => 1024,
148 _ => 0,
149 };
150 EmbeddingModel::new(self.clone(), model, ndims)
151 }
152
153 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
154 EmbeddingModel::new(self.clone(), model, ndims)
155 }
156}
157
158impl<T> EmbeddingModel<T> {
159 pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
160 Self {
161 client,
162 model: model.to_string(),
163 ndims,
164 }
165 }
166}
167
/// Model identifier `voyage-3-large`.
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// Model identifier `voyage-3.5`.
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// Model identifier `voyage-3.5-lite`.
///
/// Previously spelled `voyage.3-5.lite`, which is not a valid model name and
/// would be rejected by the API.
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// Model identifier `voyage-code-3`.
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// Model identifier `voyage-finance-2`.
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// Model identifier `voyage-law-2`.
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// Model identifier `voyage-code-2`.
pub const VOYAGE_CODE_2: &str = "voyage-code-2";
185
/// Successful body of a `POST /embeddings` response.
// Field order is kept as-is: serde's derived `Deserialize` also accepts a
// positional (sequence) form, so reordering fields would change behaviour.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// The response's `object` field, kept as an unvalidated string.
    pub object: String,
    /// One entry per returned embedding.
    pub data: Vec<EmbeddingData>,
    /// Model identifier echoed back by the API.
    pub model: String,
    /// Token accounting for the request.
    pub usage: Usage,
}
193
194#[derive(Clone, Debug, Deserialize)]
195pub struct Usage {
196 pub prompt_tokens: usize,
197 pub total_tokens: usize,
198}
199
/// Error payload decoded from a response body that carries a `message` field.
// NOTE(review): only bodies with a `message` key match this shape; confirm that
// this is the key Voyage AI uses for error payloads.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Human-readable error text from the provider.
    pub(crate) message: String,
}
204
205impl From<ApiErrorResponse> for EmbeddingError {
206 fn from(err: ApiErrorResponse) -> Self {
207 EmbeddingError::ProviderError(err.message)
208 }
209}
210
/// A response body that is either a success payload `T` or an error payload.
///
/// `#[serde(untagged)]` tries the variants in declaration order, so `Ok(T)` is
/// attempted first and `Err` only if `T` fails to decode. Do not reorder the
/// variants.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
217
218impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
219 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
220 match value {
221 ApiResponse::Ok(response) => Ok(response),
222 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
223 }
224 }
225}
226
/// One embedding entry in an [`EmbeddingResponse`].
// Field order is kept as-is because serde's derived `Deserialize` also accepts
// a positional (sequence) form.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// The entry's `object` field, kept as an unvalidated string.
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f64>,
    /// Index reported by the API for this entry.
    // NOTE(review): presumably the position of the corresponding input text in
    // the request's `input` list — confirm against the Voyage AI API reference.
    pub index: usize,
}
233
234#[derive(Clone)]
235pub struct EmbeddingModel<T> {
236 client: Client<T>,
237 pub model: String,
238 ndims: usize,
239}
240
241impl embeddings::EmbeddingModel for EmbeddingModel<reqwest::Client> {
242 const MAX_DOCUMENTS: usize = 1024;
243
244 fn ndims(&self) -> usize {
245 self.ndims
246 }
247
248 #[cfg_attr(feature = "worker", worker::send)]
249 async fn embed_texts(
250 &self,
251 documents: impl IntoIterator<Item = String>,
252 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
253 let documents = documents.into_iter().collect::<Vec<_>>();
254
255 let response = self
256 .client
257 .reqwest_post("/embeddings")
258 .json(&json!({
259 "model": self.model,
260 "input": documents,
261 }))
262 .send()
263 .await
264 .map_err(|e| EmbeddingError::HttpError(http_client::Error::Instance(e.into())))?;
265
266 if response.status().is_success() {
267 match response
268 .json::<ApiResponse<EmbeddingResponse>>()
269 .await
270 .map_err(|e| EmbeddingError::HttpError(http_client::Error::Instance(e.into())))?
271 {
272 ApiResponse::Ok(response) => {
273 tracing::info!(target: "rig",
274 "VoyageAI embedding token usage: {}",
275 response.usage.total_tokens
276 );
277
278 if response.data.len() != documents.len() {
279 return Err(EmbeddingError::ResponseError(
280 "Response data length does not match input length".into(),
281 ));
282 }
283
284 Ok(response
285 .data
286 .into_iter()
287 .zip(documents.into_iter())
288 .map(|(embedding, document)| embeddings::Embedding {
289 document,
290 vec: embedding.embedding,
291 })
292 .collect())
293 }
294 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
295 }
296 } else {
297 Err(EmbeddingError::ProviderError(
298 response.text().await.map_err(|e| {
299 EmbeddingError::HttpError(http_client::Error::Instance(e.into()))
300 })?,
301 ))
302 }
303 }
304}