Skip to main content

rig/providers/
voyageai.rs

1use crate::client::{
2    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3    ProviderClient,
4};
5use crate::embeddings;
6use crate::embeddings::EmbeddingError;
7use crate::http_client::{self, HttpClientExt};
8use bytes::Bytes;
9use serde::Deserialize;
10use serde_json::json;
11
12// ================================================================
13// Main Voyage AI Client
14// ================================================================
/// Root URL of the Voyage AI HTTP API; it is the value handed to the client builder as
/// `BASE_URL`.
const VOYAGEAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";

/// Zero-sized provider extension that marks a generic client as a Voyage AI client.
#[derive(Clone, Copy, Debug, Default)]
pub struct VoyageExt;

/// Zero-sized builder marker whose `build` step yields a [`VoyageExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct VoyageBuilder;
22
23type VoyageApiKey = BearerAuth;
24
25impl Provider for VoyageExt {
26    type Builder = VoyageBuilder;
27
28    /// There is currently no way to verify a Voyage api key without consuming tokens
29    const VERIFY_PATH: &'static str = "";
30}
31
32impl<H> Capabilities<H> for VoyageExt {
33    type Completion = Nothing;
34    type Embeddings = Capable<EmbeddingModel<H>>;
35    type Transcription = Nothing;
36    type ModelListing = Nothing;
37    #[cfg(feature = "image")]
38    type ImageGeneration = Nothing;
39
40    #[cfg(feature = "audio")]
41    type AudioGeneration = Nothing;
42}
43
44impl DebugExt for VoyageExt {}
45
46impl ProviderBuilder for VoyageBuilder {
47    type Extension<H>
48        = VoyageExt
49    where
50        H: HttpClientExt;
51    type ApiKey = VoyageApiKey;
52
53    const BASE_URL: &'static str = VOYAGEAI_API_BASE_URL;
54
55    fn build<H>(
56        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
57    ) -> http_client::Result<Self::Extension<H>>
58    where
59        H: HttpClientExt,
60    {
61        Ok(VoyageExt)
62    }
63}
64
65pub type Client<H = reqwest::Client> = client::Client<VoyageExt, H>;
66pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<VoyageBuilder, VoyageApiKey, H>;
67
68impl ProviderClient for Client {
69    type Input = String;
70
71    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
72    /// Panics if the environment variable is not set.
73    fn from_env() -> Self {
74        let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
75        Self::new(&api_key).unwrap()
76    }
77
78    fn from_val(input: Self::Input) -> Self {
79        Self::new(&input).unwrap()
80    }
81}
82
83impl<T> EmbeddingModel<T> {
84    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
85        Self {
86            client,
87            model: model.into(),
88            ndims,
89        }
90    }
91
92    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
93        Self {
94            client,
95            model: model.into(),
96            ndims,
97        }
98    }
99}
100
101// ================================================================
102// Voyage AI Embedding API
103// ================================================================
104
/// `voyage-3-large` embedding model (Voyage AI)
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model (Voyage AI)
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model (Voyage AI)
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model (Voyage AI)
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model (Voyage AI)
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model (Voyage AI)
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model (Voyage AI)
pub const VOYAGE_CODE_2: &str = "voyage-code-2";

/// Default output dimensionality of a known Voyage AI model identifier.
///
/// Returns `None` for identifiers not listed here; callers must then supply the
/// dimension count themselves.
pub fn model_dimensions_from_identifier(model_identifier: &str) -> Option<usize> {
    // Match on the public constants so the identifiers above and this table cannot drift apart.
    match model_identifier {
        VOYAGE_CODE_2 => Some(1536),
        VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
        | VOYAGE_LAW_2 => Some(1024),
        _ => None,
    }
}
128
129#[derive(Debug, Deserialize)]
130pub struct EmbeddingResponse {
131    pub object: String,
132    pub data: Vec<EmbeddingData>,
133    pub model: String,
134    pub usage: Usage,
135}
136
137#[derive(Clone, Debug, Deserialize)]
138pub struct Usage {
139    pub total_tokens: usize,
140}
141
142#[derive(Debug, Deserialize)]
143pub struct ApiErrorResponse {
144    pub(crate) message: String,
145}
146
147impl From<ApiErrorResponse> for EmbeddingError {
148    fn from(err: ApiErrorResponse) -> Self {
149        EmbeddingError::ProviderError(err.message)
150    }
151}
152
153#[derive(Debug, Deserialize)]
154#[serde(untagged)]
155pub(crate) enum ApiResponse<T> {
156    Ok(T),
157    Err(ApiErrorResponse),
158}
159
160impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
161    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
162        match value {
163            ApiResponse::Ok(response) => Ok(response),
164            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
165        }
166    }
167}
168
169#[derive(Debug, Deserialize)]
170pub struct EmbeddingData {
171    pub object: String,
172    pub embedding: Vec<f64>,
173    pub index: usize,
174}
175
176#[derive(Clone)]
177pub struct EmbeddingModel<T> {
178    client: Client<T>,
179    pub model: String,
180    ndims: usize,
181}
182
183impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
184where
185    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
186{
187    const MAX_DOCUMENTS: usize = 1024;
188
189    type Client = Client<T>;
190
191    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
192        let model = model.into();
193        let dims = dims
194            .or(model_dimensions_from_identifier(&model))
195            .unwrap_or_default();
196
197        Self::new(client.clone(), model, dims)
198    }
199
200    fn ndims(&self) -> usize {
201        self.ndims
202    }
203
204    async fn embed_texts(
205        &self,
206        documents: impl IntoIterator<Item = String>,
207    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
208        let documents = documents.into_iter().collect::<Vec<_>>();
209        let request = json!({
210            "model": self.model,
211            "input": documents,
212        });
213
214        let body = serde_json::to_vec(&request)?;
215
216        let req = self
217            .client
218            .post("/embeddings")?
219            .body(body)
220            .map_err(|x| EmbeddingError::HttpError(x.into()))?;
221
222        let response = self.client.send::<_, Bytes>(req).await?;
223        let status = response.status();
224        let response_body = response.into_body().into_future().await?.to_vec();
225
226        if status.is_success() {
227            match serde_json::from_slice::<ApiResponse<EmbeddingResponse>>(&response_body)? {
228                ApiResponse::Ok(response) => {
229                    tracing::info!(target: "rig",
230                        "VoyageAI embedding token usage: {}",
231                        response.usage.total_tokens
232                    );
233
234                    if response.data.len() != documents.len() {
235                        return Err(EmbeddingError::ResponseError(
236                            "Response data length does not match input length".into(),
237                        ));
238                    }
239
240                    Ok(response
241                        .data
242                        .into_iter()
243                        .zip(documents.into_iter())
244                        .map(|(embedding, document)| embeddings::Embedding {
245                            document,
246                            vec: embedding.embedding,
247                        })
248                        .collect())
249                }
250                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
251            }
252        } else {
253            Err(EmbeddingError::ProviderError(
254                String::from_utf8_lossy(&response_body).to_string(),
255            ))
256        }
257    }
258}
#[cfg(test)]
mod tests {
    use super::Client;

    /// Both construction paths — direct `new` and the builder — accept a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let builder = Client::builder().api_key("dummy-key");
        let _built = builder.build().expect("Client::builder() failed");
    }
}