rig/providers/
voyageai.rs1use crate::client::{
2 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3 ProviderClient,
4};
5use crate::embeddings;
6use crate::embeddings::EmbeddingError;
7use crate::http_client::{self, HttpClientExt};
8use bytes::Bytes;
9use serde::Deserialize;
10use serde_json::json;
11
/// Root URL that every Voyage AI request path is resolved against.
///
/// Used as [`ProviderBuilder::BASE_URL`] for [`VoyageBuilder`].
const VOYAGEAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
16
/// Zero-sized marker type that identifies Voyage AI as a provider.
///
/// It carries no state; it only exists to hang the provider trait
/// implementations in this module on.
#[derive(Clone, Copy, Debug, Default)]
pub struct VoyageExt;
19
/// Zero-sized marker type used as the builder-side counterpart of
/// [`VoyageExt`]; see its `ProviderBuilder` implementation in this module.
#[derive(Clone, Copy, Debug, Default)]
pub struct VoyageBuilder;
22
23type VoyageApiKey = BearerAuth;
24
25impl Provider for VoyageExt {
26 type Builder = VoyageBuilder;
27
28 const VERIFY_PATH: &'static str = "";
30
31 fn build<H>(
32 _: &crate::client::ClientBuilder<
33 Self::Builder,
34 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
35 H,
36 >,
37 ) -> http_client::Result<Self> {
38 Ok(Self)
39 }
40}
41
42impl<H> Capabilities<H> for VoyageExt {
43 type Completion = Nothing;
44 type Embeddings = Capable<EmbeddingModel<H>>;
45 type Transcription = Nothing;
46 type ModelListing = Nothing;
47 #[cfg(feature = "image")]
48 type ImageGeneration = Nothing;
49
50 #[cfg(feature = "audio")]
51 type AudioGeneration = Nothing;
52}
53
54impl DebugExt for VoyageExt {}
55
56impl ProviderBuilder for VoyageBuilder {
57 type Output = VoyageExt;
58 type ApiKey = VoyageApiKey;
59
60 const BASE_URL: &'static str = VOYAGEAI_API_BASE_URL;
61}
62
63pub type Client<H = reqwest::Client> = client::Client<VoyageExt, H>;
64pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<VoyageBuilder, VoyageApiKey, H>;
65
66impl ProviderClient for Client {
67 type Input = String;
68
69 fn from_env() -> Self {
72 let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
73 Self::new(&api_key).unwrap()
74 }
75
76 fn from_val(input: Self::Input) -> Self {
77 Self::new(&input).unwrap()
78 }
79}
80
81impl<T> EmbeddingModel<T> {
82 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
83 Self {
84 client,
85 model: model.into(),
86 ndims,
87 }
88 }
89
90 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
91 Self {
92 client,
93 model: model.into(),
94 ndims,
95 }
96 }
97}
98
/// `voyage-3-large` embedding model.
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model.
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model.
///
/// This constant previously held the misspelled identifier
/// `"voyage.3-5.lite"`, which does not follow the `voyage-3.5` naming used by
/// its sibling constant.
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model.
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model.
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model.
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model.
pub const VOYAGE_CODE_2: &str = "voyage-code-2";

/// Returns the embedding dimensionality for a known model identifier.
///
/// Returns `Some(1536)` for `voyage-code-2`, `Some(1024)` for the other
/// models declared in this module, and `None` for any identifier this module
/// does not know about.
pub fn model_dimensions_from_identifier(model_identifier: &str) -> Option<usize> {
    match model_identifier {
        VOYAGE_CODE_2 => Some(1536),
        VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
        | VOYAGE_LAW_2 => Some(1024),
        // Legacy misspelling formerly stored in `VOYAGE_3_5_LITE`; still
        // recognised so callers that hard-coded it keep getting a dimension.
        "voyage.3-5.lite" => Some(1024),
        _ => None,
    }
}
126
127#[derive(Debug, Deserialize)]
128pub struct EmbeddingResponse {
129 pub object: String,
130 pub data: Vec<EmbeddingData>,
131 pub model: String,
132 pub usage: Usage,
133}
134
135#[derive(Clone, Debug, Deserialize)]
136pub struct Usage {
137 pub prompt_tokens: usize,
138 pub total_tokens: usize,
139}
140
141#[derive(Debug, Deserialize)]
142pub struct ApiErrorResponse {
143 pub(crate) message: String,
144}
145
146impl From<ApiErrorResponse> for EmbeddingError {
147 fn from(err: ApiErrorResponse) -> Self {
148 EmbeddingError::ProviderError(err.message)
149 }
150}
151
152#[derive(Debug, Deserialize)]
153#[serde(untagged)]
154pub(crate) enum ApiResponse<T> {
155 Ok(T),
156 Err(ApiErrorResponse),
157}
158
159impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
160 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
161 match value {
162 ApiResponse::Ok(response) => Ok(response),
163 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
164 }
165 }
166}
167
168#[derive(Debug, Deserialize)]
169pub struct EmbeddingData {
170 pub object: String,
171 pub embedding: Vec<f64>,
172 pub index: usize,
173}
174
175#[derive(Clone)]
176pub struct EmbeddingModel<T> {
177 client: Client<T>,
178 pub model: String,
179 ndims: usize,
180}
181
182impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
183where
184 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
185{
186 const MAX_DOCUMENTS: usize = 1024;
187
188 type Client = Client<T>;
189
190 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
191 let model = model.into();
192 let dims = dims
193 .or(model_dimensions_from_identifier(&model))
194 .unwrap_or_default();
195
196 Self::new(client.clone(), model, dims)
197 }
198
199 fn ndims(&self) -> usize {
200 self.ndims
201 }
202
203 async fn embed_texts(
204 &self,
205 documents: impl IntoIterator<Item = String>,
206 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
207 let documents = documents.into_iter().collect::<Vec<_>>();
208 let request = json!({
209 "model": self.model,
210 "input": documents,
211 });
212
213 let body = serde_json::to_vec(&request)?;
214
215 let req = self
216 .client
217 .post("/embeddings")?
218 .body(body)
219 .map_err(|x| EmbeddingError::HttpError(x.into()))?;
220
221 let response = self.client.send::<_, Bytes>(req).await?;
222 let status = response.status();
223 let response_body = response.into_body().into_future().await?.to_vec();
224
225 if status.is_success() {
226 match serde_json::from_slice::<ApiResponse<EmbeddingResponse>>(&response_body)? {
227 ApiResponse::Ok(response) => {
228 tracing::info!(target: "rig",
229 "VoyageAI embedding token usage: {}",
230 response.usage.total_tokens
231 );
232
233 if response.data.len() != documents.len() {
234 return Err(EmbeddingError::ResponseError(
235 "Response data length does not match input length".into(),
236 ));
237 }
238
239 Ok(response
240 .data
241 .into_iter()
242 .zip(documents.into_iter())
243 .map(|(embedding, document)| embeddings::Embedding {
244 document,
245 vec: embedding.embedding,
246 })
247 .collect())
248 }
249 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
250 }
251 } else {
252 Err(EmbeddingError::ProviderError(
253 String::from_utf8_lossy(&response_body).to_string(),
254 ))
255 }
256 }
257}
#[cfg(test)]
mod tests {
    use crate::providers::voyageai::Client;

    /// A client can be constructed both directly and through the builder
    /// when given a placeholder API key.
    #[test]
    fn test_client_initialization() {
        let _client: Client = Client::new("dummy-key").expect("Client::new() failed");

        let _client_from_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}