rig/providers/
voyageai.rs1use crate::client::{
2 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3 ProviderClient,
4};
5use crate::embeddings;
6use crate::embeddings::EmbeddingError;
7use crate::http_client::{self, HttpClientExt};
8use bytes::Bytes;
9use serde::Deserialize;
10use serde_json::json;
11
/// Root URL of the Voyage AI v1 HTTP API.
///
/// Used as the `BASE_URL` of [`VoyageBuilder`].
const VOYAGEAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
16
/// Zero-sized marker type that identifies Voyage AI as the provider of a
/// generic client.
///
/// It holds no data; provider behaviour comes from the trait implementations
/// on it (`Provider`, `Capabilities`, `DebugExt`).
#[derive(Clone, Copy, Debug, Default)]
pub struct VoyageExt;
19
/// Zero-sized builder marker for Voyage AI clients.
///
/// Its `ProviderBuilder` implementation supplies the default base URL and
/// produces a [`VoyageExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct VoyageBuilder;
22
23type VoyageApiKey = BearerAuth;
24
25impl Provider for VoyageExt {
26 type Builder = VoyageBuilder;
27
28 const VERIFY_PATH: &'static str = "";
30}
31
32impl<H> Capabilities<H> for VoyageExt {
33 type Completion = Nothing;
34 type Embeddings = Capable<EmbeddingModel<H>>;
35 type Transcription = Nothing;
36 type ModelListing = Nothing;
37 #[cfg(feature = "image")]
38 type ImageGeneration = Nothing;
39
40 #[cfg(feature = "audio")]
41 type AudioGeneration = Nothing;
42}
43
44impl DebugExt for VoyageExt {}
45
46impl ProviderBuilder for VoyageBuilder {
47 type Extension<H>
48 = VoyageExt
49 where
50 H: HttpClientExt;
51 type ApiKey = VoyageApiKey;
52
53 const BASE_URL: &'static str = VOYAGEAI_API_BASE_URL;
54
55 fn build<H>(
56 _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
57 ) -> http_client::Result<Self::Extension<H>>
58 where
59 H: HttpClientExt,
60 {
61 Ok(VoyageExt)
62 }
63}
64
65pub type Client<H = reqwest::Client> = client::Client<VoyageExt, H>;
66pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<VoyageBuilder, VoyageApiKey, H>;
67
68impl ProviderClient for Client {
69 type Input = String;
70
71 fn from_env() -> Self {
74 let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
75 Self::new(&api_key).unwrap()
76 }
77
78 fn from_val(input: Self::Input) -> Self {
79 Self::new(&input).unwrap()
80 }
81}
82
83impl<T> EmbeddingModel<T> {
84 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
85 Self {
86 client,
87 model: model.into(),
88 ndims,
89 }
90 }
91
92 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
93 Self {
94 client,
95 model: model.into(),
96 ndims,
97 }
98 }
99}
100
/// `voyage-3-large` embedding model (Voyage AI)
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model (Voyage AI)
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model (Voyage AI)
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model (Voyage AI)
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model (Voyage AI)
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model (Voyage AI)
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model (Voyage AI)
pub const VOYAGE_CODE_2: &str = "voyage-code-2";

/// Returns the embedding dimension this module associates with a known
/// Voyage AI model identifier, or `None` when the identifier is not
/// recognised.
///
/// The match is exact and case-sensitive.
pub fn model_dimensions_from_identifier(model_identifier: &str) -> Option<usize> {
    // Earlier versions exposed `VOYAGE_3_5_LITE` with this misspelled value.
    // It is still recognised here so callers that stored the old string keep
    // getting the same dimension back.
    const LEGACY_VOYAGE_3_5_LITE: &str = "voyage.3-5.lite";

    match model_identifier {
        VOYAGE_CODE_2 => Some(1536),
        VOYAGE_3_LARGE
        | VOYAGE_3_5
        | VOYAGE_3_5_LITE
        | LEGACY_VOYAGE_3_5_LITE
        | VOYAGE_CODE_3
        | VOYAGE_FINANCE_2
        | VOYAGE_LAW_2 => Some(1024),
        _ => None,
    }
}
128
129#[derive(Debug, Deserialize)]
130pub struct EmbeddingResponse {
131 pub object: String,
132 pub data: Vec<EmbeddingData>,
133 pub model: String,
134 pub usage: Usage,
135}
136
137#[derive(Clone, Debug, Deserialize)]
138pub struct Usage {
139 pub prompt_tokens: usize,
140 pub total_tokens: usize,
141}
142
143#[derive(Debug, Deserialize)]
144pub struct ApiErrorResponse {
145 pub(crate) message: String,
146}
147
148impl From<ApiErrorResponse> for EmbeddingError {
149 fn from(err: ApiErrorResponse) -> Self {
150 EmbeddingError::ProviderError(err.message)
151 }
152}
153
154#[derive(Debug, Deserialize)]
155#[serde(untagged)]
156pub(crate) enum ApiResponse<T> {
157 Ok(T),
158 Err(ApiErrorResponse),
159}
160
161impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
162 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
163 match value {
164 ApiResponse::Ok(response) => Ok(response),
165 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
166 }
167 }
168}
169
170#[derive(Debug, Deserialize)]
171pub struct EmbeddingData {
172 pub object: String,
173 pub embedding: Vec<f64>,
174 pub index: usize,
175}
176
177#[derive(Clone)]
178pub struct EmbeddingModel<T> {
179 client: Client<T>,
180 pub model: String,
181 ndims: usize,
182}
183
184impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
185where
186 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
187{
188 const MAX_DOCUMENTS: usize = 1024;
189
190 type Client = Client<T>;
191
192 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
193 let model = model.into();
194 let dims = dims
195 .or(model_dimensions_from_identifier(&model))
196 .unwrap_or_default();
197
198 Self::new(client.clone(), model, dims)
199 }
200
201 fn ndims(&self) -> usize {
202 self.ndims
203 }
204
205 async fn embed_texts(
206 &self,
207 documents: impl IntoIterator<Item = String>,
208 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
209 let documents = documents.into_iter().collect::<Vec<_>>();
210 let request = json!({
211 "model": self.model,
212 "input": documents,
213 });
214
215 let body = serde_json::to_vec(&request)?;
216
217 let req = self
218 .client
219 .post("/embeddings")?
220 .body(body)
221 .map_err(|x| EmbeddingError::HttpError(x.into()))?;
222
223 let response = self.client.send::<_, Bytes>(req).await?;
224 let status = response.status();
225 let response_body = response.into_body().into_future().await?.to_vec();
226
227 if status.is_success() {
228 match serde_json::from_slice::<ApiResponse<EmbeddingResponse>>(&response_body)? {
229 ApiResponse::Ok(response) => {
230 tracing::info!(target: "rig",
231 "VoyageAI embedding token usage: {}",
232 response.usage.total_tokens
233 );
234
235 if response.data.len() != documents.len() {
236 return Err(EmbeddingError::ResponseError(
237 "Response data length does not match input length".into(),
238 ));
239 }
240
241 Ok(response
242 .data
243 .into_iter()
244 .zip(documents.into_iter())
245 .map(|(embedding, document)| embeddings::Embedding {
246 document,
247 vec: embedding.embedding,
248 })
249 .collect())
250 }
251 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
252 }
253 } else {
254 Err(EmbeddingError::ProviderError(
255 String::from_utf8_lossy(&response_body).to_string(),
256 ))
257 }
258 }
259}
#[cfg(test)]
mod tests {
    use crate::providers::voyageai::Client;

    /// Both construction paths (direct and via the builder) succeed with a
    /// placeholder key.
    #[test]
    fn test_client_initialization() {
        let _client = Client::new("dummy-key").expect("Client::new() failed");

        let builder = Client::builder().api_key("dummy-key");
        let _client_from_builder = builder.build().expect("Client::builder() failed");
    }
}