Skip to main content

rig/providers/mistral/
embedding.rs

1use serde::Deserialize;
2use serde_json::json;
3
4use crate::{
5    embeddings::{self, EmbeddingError},
6    http_client::{self, HttpClientExt},
7};
8
9use super::client::{ApiResponse, Client, Usage};
10
11// ================================================================
12// Mistral Embedding API
13// ================================================================
/// Name of the `mistral-embed` model.
pub const MISTRAL_EMBED: &str = "mistral-embed";

/// Value exposed as the `MAX_DOCUMENTS` associated constant of this
/// provider's `embeddings::EmbeddingModel` implementation.
pub const MAX_DOCUMENTS: usize = 1024;
17
18#[derive(Clone)]
19pub struct EmbeddingModel<T = reqwest::Client> {
20    client: Client<T>,
21    pub model: String,
22    ndims: usize,
23}
24
25impl<T> EmbeddingModel<T> {
26    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
27        Self {
28            client,
29            model: model.into(),
30            ndims,
31        }
32    }
33
34    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
35        Self {
36            client,
37            model: model.to_string(),
38            ndims,
39        }
40    }
41}
42
43impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
44where
45    T: HttpClientExt + Clone + 'static,
46{
47    type Client = Client<T>;
48
49    const MAX_DOCUMENTS: usize = MAX_DOCUMENTS;
50
51    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
52        Self::new(client.clone(), model, dims.unwrap_or_default())
53    }
54
55    fn ndims(&self) -> usize {
56        self.ndims
57    }
58
59    async fn embed_texts(
60        &self,
61        documents: impl IntoIterator<Item = String>,
62    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
63        let documents = documents.into_iter().collect::<Vec<_>>();
64
65        let body = serde_json::to_vec(&json!({
66            "model": self.model,
67            "input": documents
68        }))?;
69
70        let req = self
71            .client
72            .post("v1/embeddings")?
73            .header("Content-Type", "application/json")
74            .body(body)
75            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
76
77        let response = self.client.send(req).await?;
78
79        if response.status().is_success() {
80            let body: Vec<u8> = response.into_body().await?;
81            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
82
83            match body {
84                ApiResponse::Ok(response) => {
85                    tracing::debug!(target: "rig",
86                        "Mistral embedding token usage: {}",
87                        response.usage
88                    );
89
90                    if response.data.len() != documents.len() {
91                        return Err(EmbeddingError::ResponseError(
92                            "Response data length does not match input length".into(),
93                        ));
94                    }
95
96                    Ok(response
97                        .data
98                        .into_iter()
99                        .zip(documents.into_iter())
100                        .map(|(embedding, document)| embeddings::Embedding {
101                            document,
102                            vec: embedding
103                                .embedding
104                                .into_iter()
105                                .filter_map(|n| n.as_f64())
106                                .collect(),
107                        })
108                        .collect())
109                }
110                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
111            }
112        } else {
113            let text = http_client::text(response).await?;
114            Err(EmbeddingError::ProviderError(text))
115        }
116    }
117}
118
119#[derive(Debug, Deserialize)]
120pub struct EmbeddingResponse {
121    pub id: String,
122    pub object: String,
123    pub model: String,
124    pub usage: Usage,
125    pub data: Vec<EmbeddingData>,
126}
127
128#[derive(Debug, Deserialize)]
129pub struct EmbeddingData {
130    pub object: String,
131    pub embedding: Vec<serde_json::Number>,
132    pub index: usize,
133}