Skip to main content

rig/providers/
voyageai.rs

1use crate::client::{
2    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3    ProviderClient,
4};
5use crate::embeddings;
6use crate::embeddings::EmbeddingError;
7use crate::http_client::{self, HttpClientExt};
8use bytes::Bytes;
9use serde::Deserialize;
10use serde_json::json;
11
12// ================================================================
13// Main Voyage AI Client
14// ================================================================
/// Base URL of the Voyage AI v1 API; exposed below as `ProviderBuilder::BASE_URL` for
/// `VoyageBuilder`.
const VOYAGEAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
16
/// Zero-sized provider extension marker for Voyage AI.
///
/// It carries no state; the `Provider` and `Capabilities` impls in this file hang off it.
#[derive(Clone, Copy, Debug, Default)]
pub struct VoyageExt;
19
/// Zero-sized builder marker for Voyage AI clients.
///
/// Its `ProviderBuilder` impl in this file supplies the base URL and produces a `VoyageExt`.
#[derive(Clone, Copy, Debug, Default)]
pub struct VoyageBuilder;
22
/// API key type used by the Voyage AI client builder; an alias for `BearerAuth`.
type VoyageApiKey = BearerAuth;
24
25impl Provider for VoyageExt {
26    type Builder = VoyageBuilder;
27
28    /// There is currently no way to verify a Voyage api key without consuming tokens
29    const VERIFY_PATH: &'static str = "";
30}
31
32impl<H> Capabilities<H> for VoyageExt {
33    type Completion = Nothing;
34    type Embeddings = Capable<EmbeddingModel<H>>;
35    type Transcription = Nothing;
36    type ModelListing = Nothing;
37    #[cfg(feature = "image")]
38    type ImageGeneration = Nothing;
39
40    #[cfg(feature = "audio")]
41    type AudioGeneration = Nothing;
42}
43
// Empty impl: `VoyageExt` overrides nothing from `DebugExt`.
impl DebugExt for VoyageExt {}
45
46impl ProviderBuilder for VoyageBuilder {
47    type Extension<H>
48        = VoyageExt
49    where
50        H: HttpClientExt;
51    type ApiKey = VoyageApiKey;
52
53    const BASE_URL: &'static str = VOYAGEAI_API_BASE_URL;
54
55    fn build<H>(
56        _builder: &crate::client::ClientBuilder<Self, Self::ApiKey, H>,
57    ) -> http_client::Result<Self::Extension<H>>
58    where
59        H: HttpClientExt,
60    {
61        Ok(VoyageExt)
62    }
63}
64
/// Voyage AI client: the generic `client::Client` specialised with `VoyageExt`.
/// The HTTP client type `H` defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<VoyageExt, H>;
/// Builder for [`Client`]: the generic `client::ClientBuilder` specialised with
/// `VoyageBuilder` and `VoyageApiKey`. `H` defaults to `reqwest::Client`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<VoyageBuilder, VoyageApiKey, H>;
67
68impl ProviderClient for Client {
69    type Input = String;
70
71    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
72    /// Panics if the environment variable is not set.
73    fn from_env() -> Self {
74        let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
75        Self::new(&api_key).unwrap()
76    }
77
78    fn from_val(input: Self::Input) -> Self {
79        Self::new(&input).unwrap()
80    }
81}
82
83impl<T> EmbeddingModel<T> {
84    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
85        Self {
86            client,
87            model: model.into(),
88            ndims,
89        }
90    }
91
92    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
93        Self {
94            client,
95            model: model.into(),
96            ndims,
97        }
98    }
99}
100
101// ================================================================
102// Voyage AI Embedding API
103// ================================================================
104
/// `voyage-3-large` embedding model (Voyage AI)
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model (Voyage AI)
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model (Voyage AI)
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model (Voyage AI)
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model (Voyage AI)
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model (Voyage AI)
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model (Voyage AI)
pub const VOYAGE_CODE_2: &str = "voyage-code-2";

/// Looks up the embedding dimension count for a known Voyage AI model identifier.
///
/// Returns `Some(1536)` for `voyage-code-2`, `Some(1024)` for the other models listed in
/// this file, and `None` for any identifier that is not recognised.
pub fn model_dimensions_from_identifier(model_identifier: &str) -> Option<usize> {
    match model_identifier {
        VOYAGE_CODE_2 => Some(1536),
        VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
        | VOYAGE_LAW_2 => Some(1024),
        // Misspelled identifier that `VOYAGE_3_5_LITE` previously held; still accepted so
        // callers that hard-coded it keep getting a dimension count.
        "voyage.3-5.lite" => Some(1024),
        _ => None,
    }
}
128
129#[derive(Debug, Deserialize)]
130pub struct EmbeddingResponse {
131    pub object: String,
132    pub data: Vec<EmbeddingData>,
133    pub model: String,
134    pub usage: Usage,
135}
136
137#[derive(Clone, Debug, Deserialize)]
138pub struct Usage {
139    pub prompt_tokens: usize,
140    pub total_tokens: usize,
141}
142
143#[derive(Debug, Deserialize)]
144pub struct ApiErrorResponse {
145    pub(crate) message: String,
146}
147
148impl From<ApiErrorResponse> for EmbeddingError {
149    fn from(err: ApiErrorResponse) -> Self {
150        EmbeddingError::ProviderError(err.message)
151    }
152}
153
154#[derive(Debug, Deserialize)]
155#[serde(untagged)]
156pub(crate) enum ApiResponse<T> {
157    Ok(T),
158    Err(ApiErrorResponse),
159}
160
161impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
162    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
163        match value {
164            ApiResponse::Ok(response) => Ok(response),
165            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
166        }
167    }
168}
169
170#[derive(Debug, Deserialize)]
171pub struct EmbeddingData {
172    pub object: String,
173    pub embedding: Vec<f64>,
174    pub index: usize,
175}
176
/// Handle for one Voyage AI embedding model, bound to a client.
#[derive(Clone)]
pub struct EmbeddingModel<T> {
    // Client used to build and send the embedding requests.
    client: Client<T>,
    /// Model identifier sent as the `model` field of each request.
    pub model: String,
    // Dimension count returned by `ndims()`.
    ndims: usize,
}
183
184impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
185where
186    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
187{
188    const MAX_DOCUMENTS: usize = 1024;
189
190    type Client = Client<T>;
191
192    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
193        let model = model.into();
194        let dims = dims
195            .or(model_dimensions_from_identifier(&model))
196            .unwrap_or_default();
197
198        Self::new(client.clone(), model, dims)
199    }
200
201    fn ndims(&self) -> usize {
202        self.ndims
203    }
204
205    async fn embed_texts(
206        &self,
207        documents: impl IntoIterator<Item = String>,
208    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
209        let documents = documents.into_iter().collect::<Vec<_>>();
210        let request = json!({
211            "model": self.model,
212            "input": documents,
213        });
214
215        let body = serde_json::to_vec(&request)?;
216
217        let req = self
218            .client
219            .post("/embeddings")?
220            .body(body)
221            .map_err(|x| EmbeddingError::HttpError(x.into()))?;
222
223        let response = self.client.send::<_, Bytes>(req).await?;
224        let status = response.status();
225        let response_body = response.into_body().into_future().await?.to_vec();
226
227        if status.is_success() {
228            match serde_json::from_slice::<ApiResponse<EmbeddingResponse>>(&response_body)? {
229                ApiResponse::Ok(response) => {
230                    tracing::info!(target: "rig",
231                        "VoyageAI embedding token usage: {}",
232                        response.usage.total_tokens
233                    );
234
235                    if response.data.len() != documents.len() {
236                        return Err(EmbeddingError::ResponseError(
237                            "Response data length does not match input length".into(),
238                        ));
239                    }
240
241                    Ok(response
242                        .data
243                        .into_iter()
244                        .zip(documents.into_iter())
245                        .map(|(embedding, document)| embeddings::Embedding {
246                            document,
247                            vec: embedding.embedding,
248                        })
249                        .collect())
250                }
251                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
252            }
253        } else {
254            Err(EmbeddingError::ProviderError(
255                String::from_utf8_lossy(&response_body).to_string(),
256            ))
257        }
258    }
259}
#[cfg(test)]
mod tests {
    use crate::providers::voyageai::Client;

    /// Both construction paths (direct and via the builder) must accept a dummy key.
    #[test]
    fn test_client_initialization() {
        let direct = Client::new("dummy-key");
        let _client = direct.expect("Client::new() failed");

        let built = Client::builder().api_key("dummy-key").build();
        let _client_from_builder = built.expect("Client::builder() failed");
    }
}