rig/providers/mistral/
embedding.rs

1use serde::Deserialize;
2use serde_json::json;
3
4use crate::embeddings::{self, EmbeddingError};
5
6use super::client::{ApiResponse, Client, Usage};
7
// ================================================================
// Mistral Embedding API
// ================================================================
/// `mistral-embed` embedding model
pub const MISTRAL_EMBED: &str = "mistral-embed";
/// Document limit exposed as the `MAX_DOCUMENTS` associated constant of this
/// module's `embeddings::EmbeddingModel` implementation.
pub const MAX_DOCUMENTS: usize = 1024;
14
15#[derive(Clone)]
16pub struct EmbeddingModel {
17    client: Client,
18    pub model: String,
19    ndims: usize,
20}
21
22impl EmbeddingModel {
23    pub fn new(client: Client, model: &str, ndims: usize) -> Self {
24        Self {
25            client,
26            model: model.to_string(),
27            ndims,
28        }
29    }
30}
31
32impl embeddings::EmbeddingModel for EmbeddingModel {
33    const MAX_DOCUMENTS: usize = MAX_DOCUMENTS;
34    fn ndims(&self) -> usize {
35        self.ndims
36    }
37
38    #[cfg_attr(feature = "worker", worker::send)]
39    async fn embed_texts(
40        &self,
41        documents: impl IntoIterator<Item = String>,
42    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
43        let documents = documents.into_iter().collect::<Vec<_>>();
44
45        let response = self
46            .client
47            .post("v1/embeddings")
48            .json(&json!({
49                "model": self.model,
50                "input": documents,
51            }))
52            .send()
53            .await?;
54
55        if response.status().is_success() {
56            match response.json::<ApiResponse<EmbeddingResponse>>().await? {
57                ApiResponse::Ok(response) => {
58                    tracing::debug!(target: "rig",
59                        "Mistral embedding token usage: {}",
60                        response.usage
61                    );
62
63                    if response.data.len() != documents.len() {
64                        return Err(EmbeddingError::ResponseError(
65                            "Response data length does not match input length".into(),
66                        ));
67                    }
68
69                    Ok(response
70                        .data
71                        .into_iter()
72                        .zip(documents.into_iter())
73                        .map(|(embedding, document)| embeddings::Embedding {
74                            document,
75                            vec: embedding.embedding,
76                        })
77                        .collect())
78                }
79                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
80            }
81        } else {
82            Err(EmbeddingError::ProviderError(response.text().await?))
83        }
84    }
85}
86
87#[derive(Debug, Deserialize)]
88pub struct EmbeddingResponse {
89    pub id: String,
90    pub object: String,
91    pub model: String,
92    pub usage: Usage,
93    pub data: Vec<EmbeddingData>,
94}
95
96#[derive(Debug, Deserialize)]
97pub struct EmbeddingData {
98    pub object: String,
99    pub embedding: Vec<f64>,
100    pub index: usize,
101}