1use crate::client::{
2 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3 ProviderClient,
4};
5use crate::embeddings;
6use crate::embeddings::EmbeddingError;
7use crate::http_client::{self, HttpClientExt};
8use bytes::Bytes;
9use serde::Deserialize;
10use serde_json::json;
11
12const VOYAGEAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
16
17#[derive(Debug, Default, Clone, Copy)]
18pub struct VoyageExt;
19
20#[derive(Debug, Default, Clone, Copy)]
21pub struct VoyageBuilder;
22
23type VoyageApiKey = BearerAuth;
24
25impl Provider for VoyageExt {
26 type Builder = VoyageBuilder;
27
28 const VERIFY_PATH: &'static str = "";
30}
31
32impl<H> Capabilities<H> for VoyageExt {
33 type Completion = Nothing;
34 type Embeddings = Capable<EmbeddingModel<H>>;
35 type Transcription = Nothing;
36 type ModelListing = Nothing;
37 #[cfg(feature = "image")]
38 type ImageGeneration = Nothing;
39
40 #[cfg(feature = "audio")]
41 type AudioGeneration = Nothing;
42}
43
44impl DebugExt for VoyageExt {}
45
46impl ProviderBuilder for VoyageBuilder {
47 type Extension<H>
48 = VoyageExt
49 where
50 H: HttpClientExt;
51 type ApiKey = VoyageApiKey;
52
53 const BASE_URL: &'static str = VOYAGEAI_API_BASE_URL;
54
55 fn build<H>(
56 _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
57 ) -> http_client::Result<Self::Extension<H>>
58 where
59 H: HttpClientExt,
60 {
61 Ok(VoyageExt)
62 }
63}
64
65pub type Client<H = reqwest::Client> = client::Client<VoyageExt, H>;
66pub type ClientBuilder<H = crate::markers::Missing> =
67 client::ClientBuilder<VoyageBuilder, VoyageApiKey, H>;
68
69impl ProviderClient for Client {
70 type Input = String;
71 type Error = crate::client::ProviderClientError;
72
73 fn from_env() -> Result<Self, Self::Error> {
75 let api_key = crate::client::required_env_var("VOYAGE_API_KEY")?;
76 Self::new(&api_key).map_err(Into::into)
77 }
78
79 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
80 Self::new(&input).map_err(Into::into)
81 }
82}
83
84impl<T> EmbeddingModel<T> {
85 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
86 Self {
87 client,
88 model: model.into(),
89 ndims,
90 }
91 }
92
93 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
94 Self {
95 client,
96 model: model.into(),
97 ndims,
98 }
99 }
100}
101
/// `voyage-3-large` model identifier (1024 dimensions by default).
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` model identifier (1024 dimensions by default).
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` model identifier (1024 dimensions by default).
// Previously misspelled as "voyage.3-5.lite", which the API does not recognise.
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` model identifier (1024 dimensions by default).
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` model identifier (1024 dimensions).
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` model identifier (1024 dimensions).
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` model identifier (1536 dimensions).
pub const VOYAGE_CODE_2: &str = "voyage-code-2";
118pub const VOYAGE_CODE_2: &str = "voyage-code-2";
120
/// Returns the default embedding dimension for a known VoyageAI model identifier.
///
/// Returns `None` for identifiers not listed here.
pub fn model_dimensions_from_identifier(model_identifier: &str) -> Option<usize> {
    match model_identifier {
        "voyage-code-2" => Some(1536),
        "voyage-3-large"
        | "voyage-3.5"
        | "voyage-3.5-lite"
        // Legacy misspelling of "voyage-3.5-lite"; kept so existing callers still resolve.
        | "voyage.3-5.lite"
        | "voyage-code-3"
        | "voyage-finance-2"
        | "voyage-law-2" => Some(1024),
        _ => None,
    }
}
129
130#[derive(Debug, Deserialize)]
131pub struct EmbeddingResponse {
132 pub object: String,
133 pub data: Vec<EmbeddingData>,
134 pub model: String,
135 pub usage: Usage,
136}
137
138#[derive(Clone, Debug, Deserialize)]
139pub struct Usage {
140 pub total_tokens: usize,
141}
142
143#[derive(Debug, Deserialize)]
144pub struct ApiErrorResponse {
145 pub(crate) message: String,
146}
147
148impl From<ApiErrorResponse> for EmbeddingError {
149 fn from(err: ApiErrorResponse) -> Self {
150 EmbeddingError::ProviderError(err.message)
151 }
152}
153
154#[derive(Debug, Deserialize)]
155#[serde(untagged)]
156pub(crate) enum ApiResponse<T> {
157 Ok(T),
158 Err(ApiErrorResponse),
159}
160
161impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
162 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
163 match value {
164 ApiResponse::Ok(response) => Ok(response),
165 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
166 }
167 }
168}
169
170#[derive(Debug, Deserialize)]
171pub struct EmbeddingData {
172 pub object: String,
173 pub embedding: Vec<f64>,
174 pub index: usize,
175}
176
177#[derive(Clone)]
178pub struct EmbeddingModel<T> {
179 client: Client<T>,
180 pub model: String,
181 ndims: usize,
182}
183
184impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
185where
186 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
187{
188 const MAX_DOCUMENTS: usize = 1024;
189
190 type Client = Client<T>;
191
192 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
193 let model = model.into();
194 let dims = dims
195 .or(model_dimensions_from_identifier(&model))
196 .unwrap_or_default();
197
198 Self::new(client.clone(), model, dims)
199 }
200
201 fn ndims(&self) -> usize {
202 self.ndims
203 }
204
205 async fn embed_texts(
206 &self,
207 documents: impl IntoIterator<Item = String>,
208 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
209 let documents: Vec<String> = documents.into_iter().collect();
210 let response = self.embed_texts_with_usage(documents).await?;
211 Ok(response.embeddings)
212 }
213
214 async fn embed_texts_with_usage(
215 &self,
216 documents: impl IntoIterator<Item = String>,
217 ) -> Result<embeddings::EmbeddingResponse, EmbeddingError> {
218 let documents: Vec<String> = documents.into_iter().collect();
219 let request = json!({
220 "model": self.model,
221 "input": documents,
222 });
223
224 let body = serde_json::to_vec(&request)?;
225
226 let req = self
227 .client
228 .post("/embeddings")?
229 .body(body)
230 .map_err(|x| EmbeddingError::HttpError(x.into()))?;
231
232 let response = self.client.send::<_, Bytes>(req).await?;
233 let status = response.status();
234 let response_body = response.into_body().into_future().await?.to_vec();
235
236 if status.is_success() {
237 match serde_json::from_slice::<ApiResponse<EmbeddingResponse>>(&response_body)? {
238 ApiResponse::Ok(response) => {
239 tracing::info!(target: "rig",
240 "VoyageAI embedding token usage: {}",
241 response.usage.total_tokens
242 );
243
244 if response.data.len() != documents.len() {
245 return Err(EmbeddingError::ResponseError(
246 "Response data length does not match input length".into(),
247 ));
248 }
249
250 let usage = crate::completion::Usage {
251 input_tokens: response.usage.total_tokens as u64,
252 output_tokens: 0,
253 total_tokens: response.usage.total_tokens as u64,
254 cached_input_tokens: 0,
255 cache_creation_input_tokens: 0,
256 tool_use_prompt_tokens: 0,
257 reasoning_tokens: 0,
258 };
259
260 let embeddings = response
261 .data
262 .into_iter()
263 .zip(documents.into_iter())
264 .map(|(embedding, document)| embeddings::Embedding {
265 document,
266 vec: embedding.embedding,
267 })
268 .collect();
269
270 Ok(embeddings::EmbeddingResponse { embeddings, usage })
271 }
272 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
273 }
274 } else {
275 Err(EmbeddingError::ProviderError(
276 String::from_utf8_lossy(&response_body).to_string(),
277 ))
278 }
279 }
280}
#[cfg(test)]
mod tests {
    use super::Client;

    /// Both construction paths, `Client::new` and `Client::builder()`, should succeed
    /// with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let direct = Client::new("dummy-key");
        let _client = direct.expect("Client::new() failed");

        let built = Client::builder().api_key("dummy-key").build();
        let _client_from_builder = built.expect("Client::builder() failed");
    }
}