rig/providers/
voyageai.rs1use crate::client::{
2 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3 ProviderClient,
4};
5use crate::embeddings;
6use crate::embeddings::EmbeddingError;
7use crate::http_client::{self, HttpClientExt};
8use bytes::Bytes;
9use serde::Deserialize;
10use serde_json::json;
11
/// Base URL used for requests to the Voyage AI API (version 1).
const VOYAGEAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
16
/// Zero-sized marker type identifying the Voyage AI provider.
///
/// It carries no state; it exists so provider traits can be implemented for it.
#[derive(Clone, Copy, Debug, Default)]
pub struct VoyageExt;
19
/// Zero-sized marker type used as the builder half of the Voyage AI provider.
#[derive(Clone, Copy, Debug, Default)]
pub struct VoyageBuilder;
22
23type VoyageApiKey = BearerAuth;
24
25impl Provider for VoyageExt {
26 type Builder = VoyageBuilder;
27
28 const VERIFY_PATH: &'static str = "";
30
31 fn build<H>(
32 _: &crate::client::ClientBuilder<
33 Self::Builder,
34 <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
35 H,
36 >,
37 ) -> http_client::Result<Self> {
38 Ok(Self)
39 }
40}
41
42impl<H> Capabilities<H> for VoyageExt {
43 type Completion = Nothing;
44 type Embeddings = Capable<EmbeddingModel<H>>;
45 type Transcription = Nothing;
46 #[cfg(feature = "image")]
47 type ImageGeneration = Nothing;
48
49 #[cfg(feature = "audio")]
50 type AudioGeneration = Nothing;
51}
52
53impl DebugExt for VoyageExt {}
54
55impl ProviderBuilder for VoyageBuilder {
56 type Output = VoyageExt;
57 type ApiKey = VoyageApiKey;
58
59 const BASE_URL: &'static str = VOYAGEAI_API_BASE_URL;
60}
61
62pub type Client<H = reqwest::Client> = client::Client<VoyageExt, H>;
63pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<VoyageBuilder, VoyageApiKey, H>;
64
65impl ProviderClient for Client {
66 type Input = String;
67
68 fn from_env() -> Self {
71 let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
72 Self::new(&api_key).unwrap()
73 }
74
75 fn from_val(input: Self::Input) -> Self {
76 Self::new(&input).unwrap()
77 }
78}
79
80impl<T> EmbeddingModel<T> {
81 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
82 Self {
83 client,
84 model: model.into(),
85 ndims,
86 }
87 }
88
89 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
90 Self {
91 client,
92 model: model.into(),
93 ndims,
94 }
95 }
96}
97
/// `voyage-3-large` embedding model identifier.
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model identifier.
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model identifier.
///
/// This was previously misspelled as `"voyage.3-5.lite"`, which is not a model
/// name the Voyage AI API lists.
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model identifier.
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model identifier.
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model identifier.
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model identifier.
pub const VOYAGE_CODE_2: &str = "voyage-code-2";

/// Returns the embedding dimensionality for a known model identifier.
///
/// Returns `None` for identifiers this module does not know about.
///
/// The legacy misspelling `"voyage.3-5.lite"` is still recognized, so callers
/// that stored the old constant value keep resolving to the same dimension.
pub fn model_dimensions_from_identifier(model_identifier: &str) -> Option<usize> {
    match model_identifier {
        VOYAGE_CODE_2 => Some(1536),
        VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
        | VOYAGE_LAW_2 => Some(1024),
        // Backward compatibility with the previously misspelled identifier.
        "voyage.3-5.lite" => Some(1024),
        _ => None,
    }
}
125
126#[derive(Debug, Deserialize)]
127pub struct EmbeddingResponse {
128 pub object: String,
129 pub data: Vec<EmbeddingData>,
130 pub model: String,
131 pub usage: Usage,
132}
133
134#[derive(Clone, Debug, Deserialize)]
135pub struct Usage {
136 pub prompt_tokens: usize,
137 pub total_tokens: usize,
138}
139
140#[derive(Debug, Deserialize)]
141pub struct ApiErrorResponse {
142 pub(crate) message: String,
143}
144
145impl From<ApiErrorResponse> for EmbeddingError {
146 fn from(err: ApiErrorResponse) -> Self {
147 EmbeddingError::ProviderError(err.message)
148 }
149}
150
151#[derive(Debug, Deserialize)]
152#[serde(untagged)]
153pub(crate) enum ApiResponse<T> {
154 Ok(T),
155 Err(ApiErrorResponse),
156}
157
158impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
159 fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
160 match value {
161 ApiResponse::Ok(response) => Ok(response),
162 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
163 }
164 }
165}
166
167#[derive(Debug, Deserialize)]
168pub struct EmbeddingData {
169 pub object: String,
170 pub embedding: Vec<f64>,
171 pub index: usize,
172}
173
174#[derive(Clone)]
175pub struct EmbeddingModel<T> {
176 client: Client<T>,
177 pub model: String,
178 ndims: usize,
179}
180
181impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
182where
183 T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
184{
185 const MAX_DOCUMENTS: usize = 1024;
186
187 type Client = Client<T>;
188
189 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
190 let model = model.into();
191 let dims = dims
192 .or(model_dimensions_from_identifier(&model))
193 .unwrap_or_default();
194
195 Self::new(client.clone(), model, dims)
196 }
197
198 fn ndims(&self) -> usize {
199 self.ndims
200 }
201
202 #[cfg_attr(feature = "worker", worker::send)]
203 async fn embed_texts(
204 &self,
205 documents: impl IntoIterator<Item = String>,
206 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
207 let documents = documents.into_iter().collect::<Vec<_>>();
208 let request = json!({
209 "model": self.model,
210 "input": documents,
211 });
212
213 let body = serde_json::to_vec(&request)?;
214
215 let req = self
216 .client
217 .post("/embeddings")?
218 .body(body)
219 .map_err(|x| EmbeddingError::HttpError(x.into()))?;
220
221 let response = self.client.send::<_, Bytes>(req).await?;
222 let status = response.status();
223 let response_body = response.into_body().into_future().await?.to_vec();
224
225 if status.is_success() {
226 match serde_json::from_slice::<ApiResponse<EmbeddingResponse>>(&response_body)? {
227 ApiResponse::Ok(response) => {
228 tracing::info!(target: "rig",
229 "VoyageAI embedding token usage: {}",
230 response.usage.total_tokens
231 );
232
233 if response.data.len() != documents.len() {
234 return Err(EmbeddingError::ResponseError(
235 "Response data length does not match input length".into(),
236 ));
237 }
238
239 Ok(response
240 .data
241 .into_iter()
242 .zip(documents.into_iter())
243 .map(|(embedding, document)| embeddings::Embedding {
244 document,
245 vec: embedding.embedding,
246 })
247 .collect())
248 }
249 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
250 }
251 } else {
252 Err(EmbeddingError::ProviderError(
253 String::from_utf8_lossy(&response_body).to_string(),
254 ))
255 }
256 }
257}