rig/providers/mistral/
embedding.rs

1use serde::Deserialize;
2use serde_json::json;
3
4use crate::{
5    embeddings::{self, EmbeddingError},
6    http_client::{self, HttpClientExt},
7};
8
9use super::client::{ApiResponse, Client, Usage};
10
// ---------------------------------------------------------------
// Mistral Embedding API
// ---------------------------------------------------------------

/// Model identifier for Mistral's `mistral-embed` embedding model.
pub const MISTRAL_EMBED: &str = "mistral-embed";

/// Value this provider reports as `EmbeddingModel::MAX_DOCUMENTS`.
///
/// Presumably the maximum number of documents sent in one embedding
/// request — confirm against Mistral's published API limits.
pub const MAX_DOCUMENTS: usize = 1_024;
17
/// Handle to a Mistral embedding model.
///
/// `T` is the HTTP client type used by the wrapped [`Client`]; it defaults to
/// `reqwest::Client`.
#[derive(Clone)]
pub struct EmbeddingModel<T = reqwest::Client> {
    /// Mistral API client used to build (`post`) and `send` embedding requests.
    client: Client<T>,
    /// Model identifier sent as the `model` field of every request body.
    pub model: String,
    /// Value returned by `ndims()`. Presumably the embedding dimensionality;
    /// it is not checked against the vectors the API actually returns.
    ndims: usize,
}
24
25impl<T> EmbeddingModel<T> {
26    pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
27        Self {
28            client,
29            model: model.to_string(),
30            ndims,
31        }
32    }
33}
34
35impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
36where
37    T: HttpClientExt + Clone,
38{
39    const MAX_DOCUMENTS: usize = MAX_DOCUMENTS;
40    fn ndims(&self) -> usize {
41        self.ndims
42    }
43
44    #[cfg_attr(feature = "worker", worker::send)]
45    async fn embed_texts(
46        &self,
47        documents: impl IntoIterator<Item = String>,
48    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
49        let documents = documents.into_iter().collect::<Vec<_>>();
50
51        let body = serde_json::to_vec(&json!({
52            "model": self.model,
53            "input": documents
54        }))?;
55
56        let req = self
57            .client
58            .post("v1/embeddings")?
59            .header("Content-Type", "application/json")
60            .body(body)
61            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
62
63        let response = self.client.send(req).await?;
64
65        if response.status().is_success() {
66            let body: Vec<u8> = response.into_body().await?;
67            let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
68
69            match body {
70                ApiResponse::Ok(response) => {
71                    tracing::debug!(target: "rig",
72                        "Mistral embedding token usage: {}",
73                        response.usage
74                    );
75
76                    if response.data.len() != documents.len() {
77                        return Err(EmbeddingError::ResponseError(
78                            "Response data length does not match input length".into(),
79                        ));
80                    }
81
82                    Ok(response
83                        .data
84                        .into_iter()
85                        .zip(documents.into_iter())
86                        .map(|(embedding, document)| embeddings::Embedding {
87                            document,
88                            vec: embedding.embedding,
89                        })
90                        .collect())
91                }
92                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
93            }
94        } else {
95            let text = http_client::text(response).await?;
96            Err(EmbeddingError::ProviderError(text))
97        }
98    }
99}
100
/// Successful body of a Mistral `v1/embeddings` response.
///
/// `embed_texts` deserializes it wrapped in `ApiResponse`.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// Response identifier, as reported by the API.
    pub id: String,
    /// Object type string, as reported by the API.
    pub object: String,
    /// Model name, as reported by the API.
    pub model: String,
    /// Token usage; `embed_texts` logs it at debug level.
    pub usage: Usage,
    /// Embedding entries; `embed_texts` rejects the response if this length
    /// differs from the number of input documents.
    pub data: Vec<EmbeddingData>,
}
109
/// A single embedding entry within [`EmbeddingResponse::data`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Object type string, as reported by the API.
    pub object: String,
    /// The embedding vector; copied into `embeddings::Embedding::vec`.
    pub embedding: Vec<f64>,
    /// Presumably the position of the corresponding input in the request.
    /// NOTE(review): confirm against the Mistral embeddings API reference.
    pub index: usize,
}