Skip to main content

rig/providers/
voyageai.rs

1use crate::client::{
2    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3    ProviderClient,
4};
5use crate::embeddings;
6use crate::embeddings::EmbeddingError;
7use crate::http_client::{self, HttpClientExt};
8use bytes::Bytes;
9use serde::Deserialize;
10use serde_json::json;
11
12// ================================================================
13// Main Voyage AI Client
14// ================================================================
/// Base URL of the Voyage AI API; exposed to the client machinery through
/// `ProviderBuilder::BASE_URL` on `VoyageBuilder`.
const VOYAGEAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
16
/// Zero-sized provider type for Voyage AI. It implements `Provider` and
/// declares embeddings as its only supported capability.
#[derive(Debug, Default, Clone, Copy)]
pub struct VoyageExt;
19
/// Zero-sized builder type for Voyage AI. Its `ProviderBuilder`
/// implementation names `VoyageExt` as the output type.
#[derive(Debug, Default, Clone, Copy)]
pub struct VoyageBuilder;
22
/// API-key type used by the Voyage AI client (an alias for `BearerAuth`).
type VoyageApiKey = BearerAuth;
24
25impl Provider for VoyageExt {
26    type Builder = VoyageBuilder;
27
28    /// There is currently no way to verify a Voyage api key without consuming tokens
29    const VERIFY_PATH: &'static str = "";
30
31    fn build<H>(
32        _: &crate::client::ClientBuilder<
33            Self::Builder,
34            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
35            H,
36        >,
37    ) -> http_client::Result<Self> {
38        Ok(Self)
39    }
40}
41
/// Capability table for Voyage AI: only embeddings are marked `Capable`
/// (backed by this module's `EmbeddingModel`); everything else is `Nothing`.
impl<H> Capabilities<H> for VoyageExt {
    type Completion = Nothing;
    type Embeddings = Capable<EmbeddingModel<H>>;
    type Transcription = Nothing;
    type ModelListing = Nothing;
    // Only declared when the `image` feature is enabled.
    #[cfg(feature = "image")]
    type ImageGeneration = Nothing;

    // Only declared when the `audio` feature is enabled.
    #[cfg(feature = "audio")]
    type AudioGeneration = Nothing;
}
53
// Empty impl: every `DebugExt` item falls back to the trait's defaults.
impl DebugExt for VoyageExt {}
55
/// Ties the builder to its output type, its API-key type and the Voyage AI
/// base URL.
impl ProviderBuilder for VoyageBuilder {
    type Output = VoyageExt;
    type ApiKey = VoyageApiKey;

    const BASE_URL: &'static str = VOYAGEAI_API_BASE_URL;
}
62
/// Voyage AI client, generic over the HTTP client type `H`
/// (`reqwest::Client` unless specified otherwise).
pub type Client<H = reqwest::Client> = client::Client<VoyageExt, H>;
/// Builder for the Voyage AI client, generic over the HTTP client type `H`
/// (`reqwest::Client` unless specified otherwise).
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<VoyageBuilder, VoyageApiKey, H>;
65
66impl ProviderClient for Client {
67    type Input = String;
68
69    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
70    /// Panics if the environment variable is not set.
71    fn from_env() -> Self {
72        let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
73        Self::new(&api_key).unwrap()
74    }
75
76    fn from_val(input: Self::Input) -> Self {
77        Self::new(&input).unwrap()
78    }
79}
80
81impl<T> EmbeddingModel<T> {
82    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
83        Self {
84            client,
85            model: model.into(),
86            ndims,
87        }
88    }
89
90    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
91        Self {
92            client,
93            model: model.into(),
94            ndims,
95        }
96    }
97}
98
99// ================================================================
100// Voyage AI Embedding API
101// ================================================================
102
/// `voyage-3-large` embedding model (Voyage AI)
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model (Voyage AI)
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model (Voyage AI)
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model (Voyage AI)
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model (Voyage AI)
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model (Voyage AI)
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model (Voyage AI)
pub const VOYAGE_CODE_2: &str = "voyage-code-2";

/// Returns the embedding dimensionality this module associates with a known
/// Voyage AI model identifier, or `None` when the identifier is not one of
/// the constants above.
///
/// The match arms use the model constants directly so the identifiers and
/// this lookup table cannot drift apart.
pub fn model_dimensions_from_identifier(model_identifier: &str) -> Option<usize> {
    match model_identifier {
        VOYAGE_CODE_2 => Some(1536),
        VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
        | VOYAGE_LAW_2 => Some(1024),
        _ => None,
    }
}
126
/// Successful response body of the embeddings endpoint, as deserialized
/// from JSON.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// Raw `object` string from the response.
    pub object: String,
    /// Returned embeddings; `embed_texts` requires one entry per input document.
    pub data: Vec<EmbeddingData>,
    /// `model` string from the response.
    pub model: String,
    /// Token usage reported for the request.
    pub usage: Usage,
}
134
135#[derive(Clone, Debug, Deserialize)]
136pub struct Usage {
137    pub prompt_tokens: usize,
138    pub total_tokens: usize,
139}
140
/// Error payload shape this module tries to deserialize: a JSON object with
/// a `message` string.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    // Error text; forwarded into `EmbeddingError::ProviderError`.
    pub(crate) message: String,
}
145
146impl From<ApiErrorResponse> for EmbeddingError {
147    fn from(err: ApiErrorResponse) -> Self {
148        EmbeddingError::ProviderError(err.message)
149    }
150}
151
/// Either a successful payload `T` or an `ApiErrorResponse`.
///
/// Deserialized as an untagged enum: serde tries the variants in declaration
/// order and keeps the first one that matches, so `Ok` is attempted first.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    /// The body matched the expected payload type.
    Ok(T),
    /// The body matched the error payload shape instead.
    Err(ApiErrorResponse),
}
158
159impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
160    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
161        match value {
162            ApiResponse::Ok(response) => Ok(response),
163            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
164        }
165    }
166}
167
/// One embedding entry inside an `EmbeddingResponse`.
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Raw `object` string from the entry.
    pub object: String,
    /// The embedding vector.
    pub embedding: Vec<f64>,
    /// `index` value from the entry.
    /// NOTE(review): presumably the position of the matching input document;
    /// `embed_texts` does not read it and pairs entries by response order — confirm.
    pub index: usize,
}
174
/// Handle for one Voyage AI embedding model, bound to a client.
#[derive(Clone)]
pub struct EmbeddingModel<T> {
    // Client used to build and send the embedding requests.
    client: Client<T>,
    /// Model identifier sent as the `model` field of each request.
    pub model: String,
    // Dimension count returned by `ndims()`.
    ndims: usize,
}
181
182impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
183where
184    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
185{
186    const MAX_DOCUMENTS: usize = 1024;
187
188    type Client = Client<T>;
189
190    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
191        let model = model.into();
192        let dims = dims
193            .or(model_dimensions_from_identifier(&model))
194            .unwrap_or_default();
195
196        Self::new(client.clone(), model, dims)
197    }
198
199    fn ndims(&self) -> usize {
200        self.ndims
201    }
202
203    async fn embed_texts(
204        &self,
205        documents: impl IntoIterator<Item = String>,
206    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
207        let documents = documents.into_iter().collect::<Vec<_>>();
208        let request = json!({
209            "model": self.model,
210            "input": documents,
211        });
212
213        let body = serde_json::to_vec(&request)?;
214
215        let req = self
216            .client
217            .post("/embeddings")?
218            .body(body)
219            .map_err(|x| EmbeddingError::HttpError(x.into()))?;
220
221        let response = self.client.send::<_, Bytes>(req).await?;
222        let status = response.status();
223        let response_body = response.into_body().into_future().await?.to_vec();
224
225        if status.is_success() {
226            match serde_json::from_slice::<ApiResponse<EmbeddingResponse>>(&response_body)? {
227                ApiResponse::Ok(response) => {
228                    tracing::info!(target: "rig",
229                        "VoyageAI embedding token usage: {}",
230                        response.usage.total_tokens
231                    );
232
233                    if response.data.len() != documents.len() {
234                        return Err(EmbeddingError::ResponseError(
235                            "Response data length does not match input length".into(),
236                        ));
237                    }
238
239                    Ok(response
240                        .data
241                        .into_iter()
242                        .zip(documents.into_iter())
243                        .map(|(embedding, document)| embeddings::Embedding {
244                            document,
245                            vec: embedding.embedding,
246                        })
247                        .collect())
248                }
249                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
250            }
251        } else {
252            Err(EmbeddingError::ProviderError(
253                String::from_utf8_lossy(&response_body).to_string(),
254            ))
255        }
256    }
257}
#[cfg(test)]
mod tests {
    use crate::providers::voyageai::Client;

    /// Both construction paths (direct and via the builder) must succeed
    /// when given a placeholder API key.
    #[test]
    fn test_client_initialization() {
        let _client: Client = Client::new("dummy-key").expect("Client::new() failed");

        let _client_from_builder: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}