//! VoyageAI provider (`providers/voyageai.rs`): client type aliases and an embedding
//! model backed by the VoyageAI `/embeddings` endpoint.

use crate::client::{
    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
    ProviderClient,
};
use crate::embeddings;
use crate::embeddings::EmbeddingError;
use crate::http_client::{self, HttpClientExt};
use bytes::Bytes;
use serde::Deserialize;
use serde_json::json;

/// Default base URL of the VoyageAI REST API (used as `VoyageBuilder`'s `BASE_URL`).
const VOYAGEAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";

/// Stateless provider-extension marker for VoyageAI clients.
///
/// It carries no data; it only serves as the type that selects VoyageAI's
/// `Provider` / `Capabilities` implementations.
#[derive(Clone, Copy, Debug, Default)]
pub struct VoyageExt;

/// Stateless builder marker that yields a [`VoyageExt`] when a client is built.
#[derive(Clone, Copy, Debug, Default)]
pub struct VoyageBuilder;

23type VoyageApiKey = BearerAuth;
24
25impl Provider for VoyageExt {
26 type Builder = VoyageBuilder;
27
28 const VERIFY_PATH: &'static str = "";
30}
31
32impl<H> Capabilities<H> for VoyageExt {
33 type Completion = Nothing;
34 type Embeddings = Capable<EmbeddingModel<H>>;
35 type Transcription = Nothing;
36 type ModelListing = Nothing;
37 #[cfg(feature = "image")]
38 type ImageGeneration = Nothing;
39
40 #[cfg(feature = "audio")]
41 type AudioGeneration = Nothing;
42}
43
44impl DebugExt for VoyageExt {}
45
46impl ProviderBuilder for VoyageBuilder {
47 type Extension<H>
48 = VoyageExt
49 where
50 H: HttpClientExt;
51 type ApiKey = VoyageApiKey;
52
53 const BASE_URL: &'static str = VOYAGEAI_API_BASE_URL;
54
55 fn build<H>(
56 _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
57 ) -> http_client::Result<Self::Extension<H>>
58 where
59 H: HttpClientExt,
60 {
61 Ok(VoyageExt)
62 }
63}
64
65pub type Client<H = reqwest::Client> = client::Client<VoyageExt, H>;
66pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<VoyageBuilder, VoyageApiKey, H>;
67
68impl ProviderClient for Client {
69 type Input = String;
70
71 fn from_env() -> Self {
74 let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
75 Self::new(&api_key).unwrap()
76 }
77
78 fn from_val(input: Self::Input) -> Self {
79 Self::new(&input).unwrap()
80 }
81}
82
83impl<T> EmbeddingModel<T> {
84 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
85 Self {
86 client,
87 model: model.into(),
88 ndims,
89 }
90 }
91
92 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
93 Self {
94 client,
95 model: model.into(),
96 ndims,
97 }
98 }
99}
100
/// VoyageAI model identifier `voyage-3-large`.
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// VoyageAI model identifier `voyage-3.5`.
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// VoyageAI model identifier `voyage-3.5-lite`.
///
/// Previously misspelled as `voyage.3-5.lite`, which is not a model name the API accepts.
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// VoyageAI model identifier `voyage-code-3`.
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// VoyageAI model identifier `voyage-finance-2`.
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// VoyageAI model identifier `voyage-law-2`.
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// VoyageAI model identifier `voyage-code-2`.
pub const VOYAGE_CODE_2: &str = "voyage-code-2";

/// Returns the embedding dimensionality for one of the model identifiers above.
///
/// `voyage-code-2` maps to 1536; the other listed models map to 1024. Any other
/// identifier yields `None`.
pub fn model_dimensions_from_identifier(model_identifier: &str) -> Option<usize> {
    // Match on the constants themselves so the table cannot drift from their values.
    match model_identifier {
        VOYAGE_CODE_2 => Some(1536),
        VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
        | VOYAGE_LAW_2 => Some(1024),
        _ => None,
    }
}

129#[derive(Debug, Deserialize)]
130pub struct EmbeddingResponse {
131 pub object: String,
132 pub data: Vec<EmbeddingData>,
133 pub model: String,
134 pub usage: Usage,
135}
136
137#[derive(Clone, Debug, Deserialize)]
138pub struct Usage {
139 pub total_tokens: usize,
140}
141
142#[derive(Debug, Deserialize)]
143pub struct ApiErrorResponse {
144 pub(crate) message: String,
145}
146
147impl From<ApiErrorResponse> for EmbeddingError {
148 fn from(err: ApiErrorResponse) -> Self {
149 EmbeddingError::ProviderError(err.message)
150 }
151}
152
153#[derive(Debug, Deserialize)]
154#[serde(untagged)]
155pub(crate) enum ApiResponse<T> {
156 Ok(T),
157 Err(ApiErrorResponse),
158}
159
160impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
161 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
162 match value {
163 ApiResponse::Ok(response) => Ok(response),
164 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
165 }
166 }
167}
168
169#[derive(Debug, Deserialize)]
170pub struct EmbeddingData {
171 pub object: String,
172 pub embedding: Vec<f64>,
173 pub index: usize,
174}
175
176#[derive(Clone)]
177pub struct EmbeddingModel<T> {
178 client: Client<T>,
179 pub model: String,
180 ndims: usize,
181}
182
183impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
184where
185 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
186{
187 const MAX_DOCUMENTS: usize = 1024;
188
189 type Client = Client<T>;
190
191 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
192 let model = model.into();
193 let dims = dims
194 .or(model_dimensions_from_identifier(&model))
195 .unwrap_or_default();
196
197 Self::new(client.clone(), model, dims)
198 }
199
200 fn ndims(&self) -> usize {
201 self.ndims
202 }
203
204 async fn embed_texts(
205 &self,
206 documents: impl IntoIterator<Item = String>,
207 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
208 let documents = documents.into_iter().collect::<Vec<_>>();
209 let request = json!({
210 "model": self.model,
211 "input": documents,
212 });
213
214 let body = serde_json::to_vec(&request)?;
215
216 let req = self
217 .client
218 .post("/embeddings")?
219 .body(body)
220 .map_err(|x| EmbeddingError::HttpError(x.into()))?;
221
222 let response = self.client.send::<_, Bytes>(req).await?;
223 let status = response.status();
224 let response_body = response.into_body().into_future().await?.to_vec();
225
226 if status.is_success() {
227 match serde_json::from_slice::<ApiResponse<EmbeddingResponse>>(&response_body)? {
228 ApiResponse::Ok(response) => {
229 tracing::info!(target: "rig",
230 "VoyageAI embedding token usage: {}",
231 response.usage.total_tokens
232 );
233
234 if response.data.len() != documents.len() {
235 return Err(EmbeddingError::ResponseError(
236 "Response data length does not match input length".into(),
237 ));
238 }
239
240 Ok(response
241 .data
242 .into_iter()
243 .zip(documents.into_iter())
244 .map(|(embedding, document)| embeddings::Embedding {
245 document,
246 vec: embedding.embedding,
247 })
248 .collect())
249 }
250 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
251 }
252 } else {
253 Err(EmbeddingError::ProviderError(
254 String::from_utf8_lossy(&response_body).to_string(),
255 ))
256 }
257 }
258}
#[cfg(test)]
mod tests {
    use super::Client;

    /// Both construction paths (direct `new` and the builder) accept a placeholder key.
    #[test]
    fn test_client_initialization() {
        let direct = Client::new("dummy-key");
        let _client = direct.expect("Client::new() failed");

        let built = Client::builder().api_key("dummy-key").build();
        let _client_from_builder = built.expect("Client::builder() failed");
    }
}