Skip to main content

modo/embed/
mistral.rs

1use std::pin::Pin;
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{Error, Result};
7
8use super::backend::EmbeddingBackend;
9use super::config::MistralConfig;
10use super::convert::to_f32_blob;
11
/// Length of every vector this provider reports through `dimensions()`.
///
/// Requests carry no dimension parameter (only `input` and `model` are
/// sent), and `mistral-embed` produces 1024-dimensional vectors.
/// NOTE(review): if a different `model` is configured, confirm it also
/// returns 1024 values.
const DIMENSIONS: usize = 1_024;
16
17struct Inner {
18    client: reqwest::Client,
19    api_key: String,
20    model: String,
21}
22
23/// Mistral embedding provider.
24///
25/// Calls `POST https://api.mistral.ai/v1/embeddings` and returns a
26/// little-endian f32 blob.
27///
28/// # Example
29///
30/// ```rust,ignore
31/// let client = reqwest::Client::new();
32/// let provider = MistralEmbedding::new(client, &config)?;
33/// let embedder = EmbeddingProvider::new(provider);
34/// ```
35pub struct MistralEmbedding(Arc<Inner>);
36
37impl Clone for MistralEmbedding {
38    fn clone(&self) -> Self {
39        Self(Arc::clone(&self.0))
40    }
41}
42
43impl MistralEmbedding {
44    /// Create from config. Validates config at construction.
45    ///
46    /// # Errors
47    ///
48    /// Returns `Error::bad_request` if config validation fails.
49    pub fn new(client: reqwest::Client, config: &MistralConfig) -> Result<Self> {
50        config.validate()?;
51        Ok(Self(Arc::new(Inner {
52            client,
53            api_key: config.api_key.clone(),
54            model: config.model.clone(),
55        })))
56    }
57}
58
59impl EmbeddingBackend for MistralEmbedding {
60    fn embed(&self, input: &str) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send + '_>> {
61        let input = input.to_owned();
62        Box::pin(async move {
63            const URL: &str = concat!("https://api.mistral.ai", "/v1/embeddings");
64            let body = Request {
65                input: &input,
66                model: &self.0.model,
67            };
68
69            let resp = self
70                .0
71                .client
72                .post(URL)
73                .bearer_auth(&self.0.api_key)
74                .json(&body)
75                .send()
76                .await
77                .map_err(|e| Error::internal("mistral embeddings request failed").chain(e))?;
78
79            if !resp.status().is_success() {
80                let status = resp.status();
81                let text = resp.text().await.unwrap_or_default();
82                return Err(Error::internal(format!(
83                    "mistral embedding error: {status}: {text}"
84                )));
85            }
86
87            let parsed: Response = resp.json().await.map_err(|e| {
88                Error::internal("failed to parse mistral embedding response").chain(e)
89            })?;
90
91            let values = parsed
92                .data
93                .into_iter()
94                .next()
95                .ok_or_else(|| Error::internal("mistral returned empty embedding data"))?
96                .embedding;
97
98            Ok(to_f32_blob(&values))
99        })
100    }
101
102    fn dimensions(&self) -> usize {
103        DIMENSIONS
104    }
105
106    fn model_name(&self) -> &str {
107        &self.0.model
108    }
109}
110
111#[derive(Serialize)]
112struct Request<'a> {
113    input: &'a str,
114    model: &'a str,
115}
116
117#[derive(Deserialize)]
118struct Response {
119    data: Vec<EmbeddingData>,
120}
121
122#[derive(Deserialize)]
123struct EmbeddingData {
124    embedding: Vec<f32>,
125}