1use crate::client::{
2 ClientBuilderError, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError,
3};
4use crate::embeddings::EmbeddingError;
5use crate::{embeddings, impl_conversion_traits};
6use serde::Deserialize;
7use serde_json::json;
8
/// Base URL that [`ClientBuilder::new`] uses unless it is overridden.
///
/// NOTE(review): the constant is named `OPENAI_…` although the URL points at
/// `api.voyageai.com`; the name looks like a copy-paste leftover — confirm
/// before renaming.
const OPENAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
13
/// Builder for [`Client`]; holds borrowed settings until `build` copies them
/// into owned strings.
pub struct ClientBuilder<'a> {
    /// API key the built client sends as a bearer token.
    api_key: &'a str,
    /// Base URL that request paths are appended to.
    base_url: &'a str,
    /// Caller-supplied HTTP client; when `None`, `build` creates one.
    http_client: Option<reqwest::Client>,
}
19
20impl<'a> ClientBuilder<'a> {
21 pub fn new(api_key: &'a str) -> Self {
22 Self {
23 api_key,
24 base_url: OPENAI_API_BASE_URL,
25 http_client: None,
26 }
27 }
28
29 pub fn base_url(mut self, base_url: &'a str) -> Self {
30 self.base_url = base_url;
31 self
32 }
33
34 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
35 self.http_client = Some(client);
36 self
37 }
38
39 pub fn build(self) -> Result<Client, ClientBuilderError> {
40 let http_client = if let Some(http_client) = self.http_client {
41 http_client
42 } else {
43 reqwest::Client::builder().build()?
44 };
45
46 Ok(Client {
47 base_url: self.base_url.to_string(),
48 api_key: self.api_key.to_string(),
49 http_client,
50 })
51 }
52}
53
/// Voyage AI API client: the base URL, the API key and the `reqwest` client
/// used to send requests.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are appended to.
    base_url: String,
    /// API key sent as a bearer token; the `Debug` impl prints a placeholder instead.
    api_key: String,
    /// HTTP client used to issue requests.
    http_client: reqwest::Client,
}
60
61impl std::fmt::Debug for Client {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 f.debug_struct("Client")
64 .field("base_url", &self.base_url)
65 .field("http_client", &self.http_client)
66 .field("api_key", &"<REDACTED>")
67 .finish()
68 }
69}
70
71impl Client {
72 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
83 ClientBuilder::new(api_key)
84 }
85
86 pub fn new(api_key: &str) -> Self {
91 Self::builder(api_key)
92 .build()
93 .expect("Voyage AI client should build")
94 }
95
96 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
97 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
98 self.http_client.post(url).bearer_auth(&self.api_key)
99 }
100}
101
impl VerifyClient for Client {
    /// Always returns `Ok(())`; no request is made, so the API key is not
    /// actually checked here.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn verify(&self) -> Result<(), VerifyError> {
        // NOTE(review): presumably a placeholder until a real verification
        // call exists — confirm.
        Ok(())
    }
}
109
// Invokes the crate's `impl_conversion_traits!` macro for `Client` with the
// listed trait names.
// NOTE(review): presumably this marks completion, transcription, image and
// audio generation as unsupported by this provider — confirm against the
// macro definition.
impl_conversion_traits!(
    AsCompletion,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
116
117impl ProviderClient for Client {
118 fn from_env() -> Self {
121 let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
122 Self::new(&api_key)
123 }
124
125 fn from_val(input: crate::client::ProviderValue) -> Self {
126 let crate::client::ProviderValue::Simple(api_key) = input else {
127 panic!("Incorrect provider value type")
128 };
129 Self::new(&api_key)
130 }
131}
132
133impl EmbeddingsClient for Client {
136 type EmbeddingModel = EmbeddingModel;
137 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
138 let ndims = match model {
139 VOYAGE_CODE_2 => 1536,
140 VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
141 | VOYAGE_LAW_2 => 1024,
142 _ => 0,
143 };
144 EmbeddingModel::new(self.clone(), model, ndims)
145 }
146
147 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
148 EmbeddingModel::new(self.clone(), model, ndims)
149 }
150}
151
152impl EmbeddingModel {
153 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
154 Self {
155 client,
156 model: model.to_string(),
157 ndims,
158 }
159 }
160}
161
/// `voyage-3-large` embedding model.
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model.
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model.
// The identifier uses hyphens around a dotted version, like `voyage-3.5`; the
// previous value `voyage.3-5.lite` was not a model id the API would recognise.
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model.
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model.
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model.
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model.
pub const VOYAGE_CODE_2: &str = "voyage-code-2";
179
/// Body deserialized from a successful `/embeddings` response.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// `object` field of the response; not read by the code in this file.
    pub object: String,
    /// Returned embeddings; `embed_texts` requires one entry per input document.
    pub data: Vec<EmbeddingData>,
    /// `model` field of the response; presumably the model that served the
    /// request — confirm against the API reference.
    pub model: String,
    /// Token usage reported with the response.
    pub usage: Usage,
}
187
/// Token counts reported inside an [`EmbeddingResponse`].
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// `prompt_tokens` field of the response; not read by the code in this file.
    pub prompt_tokens: usize,
    /// `total_tokens` field of the response; logged by `embed_texts`.
    pub total_tokens: usize,
}
193
/// Error body with a single `message` field, as deserialized from a response.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Error text; surfaced to callers as `EmbeddingError::ProviderError`.
    pub(crate) message: String,
}
198
199impl From<ApiErrorResponse> for EmbeddingError {
200 fn from(err: ApiErrorResponse) -> Self {
201 EmbeddingError::ProviderError(err.message)
202 }
203}
204
/// Either a successful payload `T` or an [`ApiErrorResponse`].
///
/// Being `#[serde(untagged)]`, deserialization tries the variants in
/// declaration order, so `Ok(T)` is attempted before `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    /// The body matched the expected payload shape.
    Ok(T),
    /// The body matched the error shape instead.
    Err(ApiErrorResponse),
}
211
212impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
213 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
214 match value {
215 ApiResponse::Ok(response) => Ok(response),
216 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
217 }
218 }
219}
220
/// One entry of [`EmbeddingResponse::data`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// `object` field of the entry; not read by the code in this file.
    pub object: String,
    /// Embedding vector; becomes the `vec` of the resulting `embeddings::Embedding`.
    pub embedding: Vec<f64>,
    /// `index` field of the entry; presumably the position of the matching
    /// input — confirm. `embed_texts` pairs entries by order and does not read it.
    pub index: usize,
}
227
/// Handle for one embedding model, bound to the [`Client`] used to call it.
#[derive(Clone)]
pub struct EmbeddingModel {
    /// Client used to send embedding requests.
    client: Client,
    /// Model name sent as the `model` field of each request.
    pub model: String,
    /// Dimension count returned by `ndims()`.
    ndims: usize,
}
234
235impl embeddings::EmbeddingModel for EmbeddingModel {
236 const MAX_DOCUMENTS: usize = 1024;
237
238 fn ndims(&self) -> usize {
239 self.ndims
240 }
241
242 #[cfg_attr(feature = "worker", worker::send)]
243 async fn embed_texts(
244 &self,
245 documents: impl IntoIterator<Item = String>,
246 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
247 let documents = documents.into_iter().collect::<Vec<_>>();
248
249 let response = self
250 .client
251 .post("/embeddings")
252 .json(&json!({
253 "model": self.model,
254 "input": documents,
255 }))
256 .send()
257 .await?;
258
259 if response.status().is_success() {
260 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
261 ApiResponse::Ok(response) => {
262 tracing::info!(target: "rig",
263 "VoyageAI embedding token usage: {}",
264 response.usage.total_tokens
265 );
266
267 if response.data.len() != documents.len() {
268 return Err(EmbeddingError::ResponseError(
269 "Response data length does not match input length".into(),
270 ));
271 }
272
273 Ok(response
274 .data
275 .into_iter()
276 .zip(documents.into_iter())
277 .map(|(embedding, document)| embeddings::Embedding {
278 document,
279 vec: embedding.embedding,
280 })
281 .collect())
282 }
283 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
284 }
285 } else {
286 Err(EmbeddingError::ProviderError(response.text().await?))
287 }
288 }
289}