rig/providers/
voyageai.rs

1use crate::client::{
2    self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
3    ProviderClient,
4};
5use crate::embeddings;
6use crate::embeddings::EmbeddingError;
7use crate::http_client::{self, HttpClientExt};
8use bytes::Bytes;
9use serde::Deserialize;
10use serde_json::json;
11
12// ================================================================
13// Main Voyage AI Client
14// ================================================================
15const VOYAGEAI_API_BASE_URL: &str = "https://api.voyageai.com/v1";
16
17#[derive(Debug, Default, Clone, Copy)]
18pub struct VoyageExt;
19
20#[derive(Debug, Default, Clone, Copy)]
21pub struct VoyageBuilder;
22
23type VoyageApiKey = BearerAuth;
24
25impl Provider for VoyageExt {
26    type Builder = VoyageBuilder;
27
28    /// There is currently no way to verify a Voyage api key without consuming tokens
29    const VERIFY_PATH: &'static str = "";
30
31    fn build<H>(
32        _: &crate::client::ClientBuilder<
33            Self::Builder,
34            <Self::Builder as crate::client::ProviderBuilder>::ApiKey,
35            H,
36        >,
37    ) -> http_client::Result<Self> {
38        Ok(Self)
39    }
40}
41
42impl<H> Capabilities<H> for VoyageExt {
43    type Completion = Nothing;
44    type Embeddings = Capable<EmbeddingModel<H>>;
45    type Transcription = Nothing;
46    #[cfg(feature = "image")]
47    type ImageGeneration = Nothing;
48
49    #[cfg(feature = "audio")]
50    type AudioGeneration = Nothing;
51}
52
53impl DebugExt for VoyageExt {}
54
55impl ProviderBuilder for VoyageBuilder {
56    type Output = VoyageExt;
57    type ApiKey = VoyageApiKey;
58
59    const BASE_URL: &'static str = VOYAGEAI_API_BASE_URL;
60}
61
62pub type Client<H = reqwest::Client> = client::Client<VoyageExt, H>;
63pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<VoyageBuilder, VoyageApiKey, H>;
64
65impl ProviderClient for Client {
66    type Input = String;
67
68    /// Create a new OpenAI client from the `OPENAI_API_KEY` environment variable.
69    /// Panics if the environment variable is not set.
70    fn from_env() -> Self {
71        let api_key = std::env::var("VOYAGE_API_KEY").expect("VOYAGE_API_KEY not set");
72        Self::new(&api_key).unwrap()
73    }
74
75    fn from_val(input: Self::Input) -> Self {
76        Self::new(&input).unwrap()
77    }
78}
79
80impl<T> EmbeddingModel<T> {
81    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
82        Self {
83            client,
84            model: model.into(),
85            ndims,
86        }
87    }
88
89    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
90        Self {
91            client,
92            model: model.into(),
93            ndims,
94        }
95    }
96}
97
// ================================================================
// Voyage AI Embedding API
// ================================================================

/// `voyage-3-large` embedding model (Voyage AI)
pub const VOYAGE_3_LARGE: &str = "voyage-3-large";
/// `voyage-3.5` embedding model (Voyage AI)
pub const VOYAGE_3_5: &str = "voyage-3.5";
/// `voyage-3.5-lite` embedding model (Voyage AI)
pub const VOYAGE_3_5_LITE: &str = "voyage-3.5-lite";
/// `voyage-code-3` embedding model (Voyage AI)
pub const VOYAGE_CODE_3: &str = "voyage-code-3";
/// `voyage-finance-2` embedding model (Voyage AI)
pub const VOYAGE_FINANCE_2: &str = "voyage-finance-2";
/// `voyage-law-2` embedding model (Voyage AI)
pub const VOYAGE_LAW_2: &str = "voyage-law-2";
/// `voyage-code-2` embedding model (Voyage AI)
pub const VOYAGE_CODE_2: &str = "voyage-code-2";

/// Returns the embedding dimensionality assumed for a known Voyage AI model identifier.
///
/// Returns `None` for identifiers this module does not know about.
pub fn model_dimensions_from_identifier(model_identifier: &str) -> Option<usize> {
    // Match on the public constants so the identifiers and this table cannot drift apart.
    match model_identifier {
        VOYAGE_CODE_2 => Some(1536),
        VOYAGE_3_LARGE | VOYAGE_3_5 | VOYAGE_3_5_LITE | VOYAGE_CODE_3 | VOYAGE_FINANCE_2
        | VOYAGE_LAW_2 => Some(1024),
        _ => None,
    }
}
125
126#[derive(Debug, Deserialize)]
127pub struct EmbeddingResponse {
128    pub object: String,
129    pub data: Vec<EmbeddingData>,
130    pub model: String,
131    pub usage: Usage,
132}
133
134#[derive(Clone, Debug, Deserialize)]
135pub struct Usage {
136    pub prompt_tokens: usize,
137    pub total_tokens: usize,
138}
139
140#[derive(Debug, Deserialize)]
141pub struct ApiErrorResponse {
142    pub(crate) message: String,
143}
144
145impl From<ApiErrorResponse> for EmbeddingError {
146    fn from(err: ApiErrorResponse) -> Self {
147        EmbeddingError::ProviderError(err.message)
148    }
149}
150
151#[derive(Debug, Deserialize)]
152#[serde(untagged)]
153pub(crate) enum ApiResponse<T> {
154    Ok(T),
155    Err(ApiErrorResponse),
156}
157
158impl From<ApiResponse<EmbeddingResponse>> for Result<EmbeddingResponse, EmbeddingError> {
159    fn from(value: ApiResponse<EmbeddingResponse>) -> Self {
160        match value {
161            ApiResponse::Ok(response) => Ok(response),
162            ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
163        }
164    }
165}
166
167#[derive(Debug, Deserialize)]
168pub struct EmbeddingData {
169    pub object: String,
170    pub embedding: Vec<f64>,
171    pub index: usize,
172}
173
174#[derive(Clone)]
175pub struct EmbeddingModel<T> {
176    client: Client<T>,
177    pub model: String,
178    ndims: usize,
179}
180
181impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
182where
183    T: HttpClientExt + Clone + std::fmt::Debug + Default + 'static,
184{
185    const MAX_DOCUMENTS: usize = 1024;
186
187    type Client = Client<T>;
188
189    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
190        let model = model.into();
191        let dims = dims
192            .or(model_dimensions_from_identifier(&model))
193            .unwrap_or_default();
194
195        Self::new(client.clone(), model, dims)
196    }
197
198    fn ndims(&self) -> usize {
199        self.ndims
200    }
201
202    #[cfg_attr(feature = "worker", worker::send)]
203    async fn embed_texts(
204        &self,
205        documents: impl IntoIterator<Item = String>,
206    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
207        let documents = documents.into_iter().collect::<Vec<_>>();
208        let request = json!({
209            "model": self.model,
210            "input": documents,
211        });
212
213        let body = serde_json::to_vec(&request)?;
214
215        let req = self
216            .client
217            .post("/embeddings")?
218            .body(body)
219            .map_err(|x| EmbeddingError::HttpError(x.into()))?;
220
221        let response = self.client.send::<_, Bytes>(req).await?;
222        let status = response.status();
223        let response_body = response.into_body().into_future().await?.to_vec();
224
225        if status.is_success() {
226            match serde_json::from_slice::<ApiResponse<EmbeddingResponse>>(&response_body)? {
227                ApiResponse::Ok(response) => {
228                    tracing::info!(target: "rig",
229                        "VoyageAI embedding token usage: {}",
230                        response.usage.total_tokens
231                    );
232
233                    if response.data.len() != documents.len() {
234                        return Err(EmbeddingError::ResponseError(
235                            "Response data length does not match input length".into(),
236                        ));
237                    }
238
239                    Ok(response
240                        .data
241                        .into_iter()
242                        .zip(documents.into_iter())
243                        .map(|(embedding, document)| embeddings::Embedding {
244                            document,
245                            vec: embedding.embedding,
246                        })
247                        .collect())
248                }
249                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
250            }
251        } else {
252            Err(EmbeddingError::ProviderError(
253                String::from_utf8_lossy(&response_body).to_string(),
254            ))
255        }
256    }
257}