rig/providers/mistral/embedding.rs

1use serde::Deserialize;
2use serde_json::json;
3
4use crate::{
5    embeddings::{self, EmbeddingError},
6    http_client::{self, HttpClientExt},
7};
8
9use super::client::{ApiResponse, Client, Usage};
10
// ================================================================
// Mistral Embedding API
// ================================================================

/// Model identifier for Mistral's `mistral-embed` embedding model.
pub const MISTRAL_EMBED: &str = "mistral-embed";

/// Upper bound on documents per embedding request; exposed as the
/// `MAX_DOCUMENTS` associated constant of this provider's embedding model.
pub const MAX_DOCUMENTS: usize = 1024;
17
18#[derive(Clone)]
19pub struct EmbeddingModel<T = reqwest::Client> {
20    client: Client<T>,
21    pub model: String,
22    ndims: usize,
23}
24
25impl<T> EmbeddingModel<T> {
26    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
27        Self {
28            client,
29            model: model.into(),
30            ndims,
31        }
32    }
33
34    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
35        Self {
36            client,
37            model: model.to_string(),
38            ndims,
39        }
40    }
41}
42
43impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
44where
45    T: HttpClientExt + Clone,
46{
47    type Client = Client<T>;
48
49    const MAX_DOCUMENTS: usize = MAX_DOCUMENTS;
50
51    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
52        Self::new(client.clone(), model, dims.unwrap_or_default())
53    }
54
55    fn ndims(&self) -> usize {
56        self.ndims
57    }
58
59    #[cfg_attr(feature = "worker", worker::send)]
60    async fn embed_texts(
61        &self,
62        documents: impl IntoIterator<Item = String>,
63    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
64        let documents = documents.into_iter().collect::<Vec<_>>();
65
66        let body = serde_json::to_vec(&json!({
67            "model": self.model,
68            "input": documents
69        }))?;
70
71        let req = self
72            .client
73            .post("v1/embeddings")?
74            .header("Content-Type", "application/json")
75            .body(body)
76            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
77
78        let response = self.client.send(req).await?;
79
80        if response.status().is_success() {
81            let body: Vec<u8> = response.into_body().await?;
82            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
83
84            match body {
85                ApiResponse::Ok(response) => {
86                    tracing::debug!(target: "rig",
87                        "Mistral embedding token usage: {}",
88                        response.usage
89                    );
90
91                    if response.data.len() != documents.len() {
92                        return Err(EmbeddingError::ResponseError(
93                            "Response data length does not match input length".into(),
94                        ));
95                    }
96
97                    Ok(response
98                        .data
99                        .into_iter()
100                        .zip(documents.into_iter())
101                        .map(|(embedding, document)| embeddings::Embedding {
102                            document,
103                            vec: embedding.embedding,
104                        })
105                        .collect())
106                }
107                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
108            }
109        } else {
110            let text = http_client::text(response).await?;
111            Err(EmbeddingError::ProviderError(text))
112        }
113    }
114}
115
116#[derive(Debug, Deserialize)]
117pub struct EmbeddingResponse {
118    pub id: String,
119    pub object: String,
120    pub model: String,
121    pub usage: Usage,
122    pub data: Vec<EmbeddingData>,
123}
124
125#[derive(Debug, Deserialize)]
126pub struct EmbeddingData {
127    pub object: String,
128    pub embedding: Vec<f64>,
129    pub index: usize,
130}