rig/providers/mistral/
embedding.rs

1use serde::Deserialize;
2use serde_json::json;
3
4use crate::{
5    embeddings::{self, EmbeddingError},
6    http_client::{self, HttpClientExt},
7};
8
9use super::client::{ApiResponse, Client, Usage};
10
// ================================================================
// Mistral Embedding API
// ================================================================

/// Model identifier string for Mistral's `mistral-embed` embedding model.
pub const MISTRAL_EMBED: &str = "mistral-embed";

/// Maximum number of documents submitted in a single embedding request.
///
/// Exposed through the `MAX_DOCUMENTS` associated constant of the
/// `embeddings::EmbeddingModel` implementation in this module.
pub const MAX_DOCUMENTS: usize = 1024;
17
18#[derive(Clone)]
19pub struct EmbeddingModel<T = reqwest::Client> {
20    client: Client<T>,
21    pub model: String,
22    ndims: usize,
23}
24
25impl<T> EmbeddingModel<T> {
26    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
27        Self {
28            client,
29            model: model.into(),
30            ndims,
31        }
32    }
33
34    pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
35        Self {
36            client,
37            model: model.to_string(),
38            ndims,
39        }
40    }
41}
42
43impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
44where
45    T: HttpClientExt + Clone + 'static,
46{
47    type Client = Client<T>;
48
49    const MAX_DOCUMENTS: usize = MAX_DOCUMENTS;
50
51    fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
52        Self::new(client.clone(), model, dims.unwrap_or_default())
53    }
54
55    fn ndims(&self) -> usize {
56        self.ndims
57    }
58
59    async fn embed_texts(
60        &self,
61        documents: impl IntoIterator<Item = String>,
62    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
63        let documents = documents.into_iter().collect::<Vec<_>>();
64
65        let body = serde_json::to_vec(&json!({
66            "model": self.model,
67            "input": documents
68        }))?;
69
70        let req = self
71            .client
72            .post("v1/embeddings")?
73            .header("Content-Type", "application/json")
74            .body(body)
75            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
76
77        let response = self.client.send(req).await?;
78
79        if response.status().is_success() {
80            let body: Vec<u8> = response.into_body().await?;
81            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
82
83            match body {
84                ApiResponse::Ok(response) => {
85                    tracing::debug!(target: "rig",
86                        "Mistral embedding token usage: {}",
87                        response.usage
88                    );
89
90                    if response.data.len() != documents.len() {
91                        return Err(EmbeddingError::ResponseError(
92                            "Response data length does not match input length".into(),
93                        ));
94                    }
95
96                    Ok(response
97                        .data
98                        .into_iter()
99                        .zip(documents.into_iter())
100                        .map(|(embedding, document)| embeddings::Embedding {
101                            document,
102                            vec: embedding.embedding,
103                        })
104                        .collect())
105                }
106                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
107            }
108        } else {
109            let text = http_client::text(response).await?;
110            Err(EmbeddingError::ProviderError(text))
111        }
112    }
113}
114
115#[derive(Debug, Deserialize)]
116pub struct EmbeddingResponse {
117    pub id: String,
118    pub object: String,
119    pub model: String,
120    pub usage: Usage,
121    pub data: Vec<EmbeddingData>,
122}
123
124#[derive(Debug, Deserialize)]
125pub struct EmbeddingData {
126    pub object: String,
127    pub embedding: Vec<f64>,
128    pub index: usize,
129}