rig/providers/
voyageai.rs1use crate::client::{EmbeddingsClient, ProviderClient};
2use crate::embeddings::EmbeddingError;
3use crate::{embeddings, impl_conversion_traits};
4use serde::Deserialize;
5use serde_json::json;
6
/// Default base URL for the VoyageAI API; `Client::new` passes it to `Client::from_url`.
// NOTE(review): the `OPENAI_` prefix looks copied from the OpenAI provider — the value is
// VoyageAI's endpoint. Consider renaming it together with its use in `Client::new`.
const OPENAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
11
/// Client for the VoyageAI HTTP API.
///
/// `Debug` is implemented by hand (not derived) so that the API key is never printed.
#[derive(Clone)]
pub struct Client {
    /// URL that request paths passed to `post` are appended to.
    base_url: String,
    /// Credential attached as a bearer token to every request built by `post`.
    api_key: String,
    /// Underlying HTTP client; replaceable via `with_custom_client`.
    http_client: reqwest::Client,
}
18
19impl std::fmt::Debug for Client {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 f.debug_struct("Client")
22 .field("base_url", &self.base_url)
23 .field("http_client", &self.http_client)
24 .field("api_key", &"<REDACTED>")
25 .finish()
26 }
27}
28
29impl Client {
30 pub fn new(api_key: &str) -> Self {
32 Self::from_url(api_key, OPENAI_API_BASE_URL)
33 }
34
35 pub fn from_url(api_key: &str, base_url: &str) -> Self {
37 Self {
38 base_url: base_url.to_string(),
39 api_key: api_key.to_string(),
40 http_client: reqwest::Client::builder()
41 .build()
42 .expect("OpenAI reqwest client should build"),
43 }
44 }
45
46 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
49 self.http_client = client;
50
51 self
52 }
53
54 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
55 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
56 self.http_client.post(url).bearer_auth(&self.api_key)
57 }
58}
59
// Invoke `impl_conversion_traits!` for `Client` with the completion, transcription,
// image-generation and audio-generation conversion traits. Of the client capability traits,
// this file implements only `EmbeddingsClient` (plus `ProviderClient`) for `Client`.
// NOTE(review): presumably the macro supplies "not supported" conversions for these
// capabilities — confirm against the macro definition.
impl_conversion_traits!(
    AsCompletion,
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);
66
67impl ProviderClient for Client {
68 fn from_env() -> Self {
71 let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
72 Self::new(&api_key)
73 }
74}
75
76impl EmbeddingsClient for Client {
79 type EmbeddingModel = EmbeddingModel;
80 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
81 let ndims = match model {
82 VOYAGE_CODE_2 => 1536,
83 VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
84 | VOYAGE_LAW_2 => 1024,
85 _ => 0,
86 };
87 EmbeddingModel::new(self.clone(), model, ndims)
88 }
89
90 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
91 EmbeddingModel::new(self.clone(), model, ndims)
92 }
93}
94
95impl EmbeddingModel {
96 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
97 Self {
98 client,
99 model: model.to_string(),
100 ndims,
101 }
102 }
103}
104
/// Model id for `voyage-3-large`.
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// Model id for `voyage-3.5`.
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// Model id for `voyage-3.5-lite` (same `voyage-3.5` prefix as [`VOYAGE_3_5`]).
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// Model id for `voyage-code-3`.
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// Model id for `voyage-finance-2`.
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// Model id for `voyage-law-2`.
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// Model id for `voyage-code-2`.
pub const VOYAGE_CODE_2: &str = "voyage-code-2";
122
/// Success body of the embeddings endpoint.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// Object type tag as reported in the response body.
    pub object: String,
    /// Embedding entries; `embed_texts` requires exactly one per input document.
    pub data: Vec<EmbeddingData>,
    /// Model identifier as reported in the response body.
    pub model: String,
    /// Token accounting for the request (`total_tokens` is logged by `embed_texts`).
    pub usage: Usage,
}
130
131#[derive(Clone, Debug, Deserialize)]
132pub struct Usage {
133 pub prompt_tokens: usize,
134 pub total_tokens: usize,
135}
136
/// Error payload carrying a `message`; converted into `EmbeddingError::ProviderError`.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Error text, forwarded verbatim as the provider error message.
    pub(crate) message: String,
}
141
142impl From<ApiErrorResponse> for EmbeddingError {
143 fn from(err: ApiErrorResponse) -> Self {
144 EmbeddingError::ProviderError(err.message)
145 }
146}
147
/// Response envelope that decodes as either a success body or an error payload.
///
/// With `#[serde(untagged)]`, serde tries `Ok(T)` first and falls back to
/// `Err(ApiErrorResponse)` only if the body does not deserialize as `T`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    Ok(T),
    Err(ApiErrorResponse),
}
154
155impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
156 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
157 match value {
158 ApiResponse::Ok(response) => Ok(response),
159 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
160 }
161 }
162}
163
/// One entry of [`EmbeddingResponse::data`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Object type tag as reported in the response body.
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f64>,
    /// NOTE(review): presumably the position of the matching input in the request's
    /// `input` list — confirm against the VoyageAI API reference.
    pub index: usize,
}
170
171#[derive(Clone)]
172pub struct EmbeddingModel {
173 client: Client,
174 pub model: String,
175 ndims: usize,
176}
177
178impl embeddings::EmbeddingModel for EmbeddingModel {
179 const MAX_DOCUMENTS: usize = 1024;
180
181 fn ndims(&self) -> usize {
182 self.ndims
183 }
184
185 #[cfg_attr(feature = "worker", worker::send)]
186 async fn embed_texts(
187 &self,
188 documents: impl IntoIterator<Item = String>,
189 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
190 let documents = documents.into_iter().collect::<Vec<_>>();
191
192 let response = self
193 .client
194 .post("/embeddings")
195 .json(&json!({
196 "model": self.model,
197 "input": documents,
198 }))
199 .send()
200 .await?;
201
202 if response.status().is_success() {
203 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
204 ApiResponse::Ok(response) => {
205 tracing::info!(target: "rig",
206 "VoyageAI embedding token usage: {}",
207 response.usage.total_tokens
208 );
209
210 if response.data.len() != documents.len() {
211 return Err(EmbeddingError::ResponseError(
212 "Response data length does not match input length".into(),
213 ));
214 }
215
216 Ok(response
217 .data
218 .into_iter()
219 .zip(documents.into_iter())
220 .map(|(embedding, document)| embeddings::Embedding {
221 document,
222 vec: embedding.embedding,
223 })
224 .collect())
225 }
226 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
227 }
228 } else {
229 Err(EmbeddingError::ProviderError(response.text().await?))
230 }
231 }
232}