rig/providers/
voyageai.rs1use crate::client::{EmbeddingsClient, ProviderClient};
2use crate::embeddings::EmbeddingError;
3use crate::{embeddings, impl_conversion_traits};
4use serde::Deserialize;
5use serde_json::json;
6
7const OPENAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
11
12#[derive(Clone)]
13pub struct Client {
14 base_url: String,
15 api_key: String,
16 http_client: reqwest::Client,
17}
18
19impl std::fmt::Debug for Client {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 f.debug_struct("Client")
22 .field("base_url", &self.base_url)
23 .field("http_client", &self.http_client)
24 .field("api_key", &"<REDACTED>")
25 .finish()
26 }
27}
28
29impl Client {
30 pub fn new(api_key: &str) -> Self {
32 Self::from_url(api_key, OPENAI_API_BASE_URL)
33 }
34
35 pub fn from_url(api_key: &str, base_url: &str) -> Self {
37 Self {
38 base_url: base_url.to_string(),
39 api_key: api_key.to_string(),
40 http_client: reqwest::Client::builder()
41 .build()
42 .expect("OpenAI reqwest client should build"),
43 }
44 }
45
46 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
49 self.http_client = client;
50
51 self
52 }
53
54 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
55 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
56 self.http_client.post(url).bearer_auth(&self.api_key)
57 }
58}
59
60impl_conversion_traits!(
61 AsCompletion,
62 AsTranscription,
63 AsImageGeneration,
64 AsAudioGeneration for Client
65);
66
67impl ProviderClient for Client {
68 fn from_env() -> Self {
71 let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
72 Self::new(&api_key)
73 }
74
75 fn from_val(input: crate::client::ProviderValue) -> Self {
76 let crate::client::ProviderValue::Simple(api_key) = input else {
77 panic!("Incorrect provider value type")
78 };
79 Self::new(&api_key)
80 }
81}
82
83impl EmbeddingsClient for Client {
86 type EmbeddingModel = EmbeddingModel;
87 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
88 let ndims = match model {
89 VOYAGE_CODE_2 => 1536,
90 VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
91 | VOYAGE_LAW_2 => 1024,
92 _ => 0,
93 };
94 EmbeddingModel::new(self.clone(), model, ndims)
95 }
96
97 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
98 EmbeddingModel::new(self.clone(), model, ndims)
99 }
100}
101
102impl EmbeddingModel {
103 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
104 Self {
105 client,
106 model: model.to_string(),
107 ndims,
108 }
109 }
110}
111
/// `voyage-3-large`: `Client::embedding_model` assigns it 1024 dimensions.
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5`: `Client::embedding_model` assigns it 1024 dimensions.
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite`: `Client::embedding_model` assigns it 1024 dimensions.
// Previously misspelled as "voyage.3-5.lite", which names no VoyageAI model.
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3`: `Client::embedding_model` assigns it 1024 dimensions.
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2`: `Client::embedding_model` assigns it 1024 dimensions.
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2`: `Client::embedding_model` assigns it 1024 dimensions.
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2`: `Client::embedding_model` assigns it 1536 dimensions.
pub const VOYAGE_CODE_2: &str = "voyage-code-2";
129
130#[derive(Debug, Deserialize)]
131pub struct EmbeddingResponse {
132 pub object: String,
133 pub data: Vec<EmbeddingData>,
134 pub model: String,
135 pub usage: Usage,
136}
137
138#[derive(Clone, Debug, Deserialize)]
139pub struct Usage {
140 pub prompt_tokens: usize,
141 pub total_tokens: usize,
142}
143
144#[derive(Debug, Deserialize)]
145pub struct ApiErrorResponse {
146 pub(crate) message: String,
147}
148
149impl From<ApiErrorResponse> for EmbeddingError {
150 fn from(err: ApiErrorResponse) -> Self {
151 EmbeddingError::ProviderError(err.message)
152 }
153}
154
155#[derive(Debug, Deserialize)]
156#[serde(untagged)]
157pub(crate) enum ApiResponse<T> {
158 Ok(T),
159 Err(ApiErrorResponse),
160}
161
162impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
163 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
164 match value {
165 ApiResponse::Ok(response) => Ok(response),
166 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
167 }
168 }
169}
170
171#[derive(Debug, Deserialize)]
172pub struct EmbeddingData {
173 pub object: String,
174 pub embedding: Vec<f64>,
175 pub index: usize,
176}
177
178#[derive(Clone)]
179pub struct EmbeddingModel {
180 client: Client,
181 pub model: String,
182 ndims: usize,
183}
184
185impl embeddings::EmbeddingModel for EmbeddingModel {
186 const MAX_DOCUMENTS: usize = 1024;
187
188 fn ndims(&self) -> usize {
189 self.ndims
190 }
191
192 #[cfg_attr(feature = "worker", worker::send)]
193 async fn embed_texts(
194 &self,
195 documents: impl IntoIterator<Item = String>,
196 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
197 let documents = documents.into_iter().collect::<Vec<_>>();
198
199 let response = self
200 .client
201 .post("/embeddings")
202 .json(&json!({
203 "model": self.model,
204 "input": documents,
205 }))
206 .send()
207 .await?;
208
209 if response.status().is_success() {
210 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
211 ApiResponse::Ok(response) => {
212 tracing::info!(target: "rig",
213 "VoyageAI embedding token usage: {}",
214 response.usage.total_tokens
215 );
216
217 if response.data.len() != documents.len() {
218 return Err(EmbeddingError::ResponseError(
219 "Response data length does not match input length".into(),
220 ));
221 }
222
223 Ok(response
224 .data
225 .into_iter()
226 .zip(documents.into_iter())
227 .map(|(embedding, document)| embeddings::Embedding {
228 document,
229 vec: embedding.embedding,
230 })
231 .collect())
232 }
233 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
234 }
235 } else {
236 Err(EmbeddingError::ProviderError(response.text().await?))
237 }
238 }
239}