Skip to main content

redis_vl/vectorizers/
mistral.rs

//! Mistral AI embedding adapter.
//!
//! Enabled by the `mistral` feature flag. Mistral's embedding REST API mirrors
//! OpenAI's: the texts to embed are sent in the `input` field of the request
//! body (the Python client calls the argument `inputs`, but the wire key is
//! `input`).
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9use super::{AsyncVectorizer, Vectorizer};
10use crate::error::Result;
11
/// Configuration for the Mistral AI embedding provider.
///
/// `Debug` is implemented by hand so that the API key is never written to
/// logs, panic messages, or test output; only the model name is shown.
#[derive(Clone)]
pub struct MistralConfig {
    /// API key for Mistral.
    pub api_key: String,
    /// Embedding model name (for example `mistral-embed`).
    pub model: String,
}

impl std::fmt::Debug for MistralConfig {
    /// Formats the config with the API key replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MistralConfig")
            // The secret is deliberately not printed.
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .finish()
    }
}
20
21impl MistralConfig {
22    /// Creates a new Mistral config.
23    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
24        Self {
25            api_key: api_key.into(),
26            model: model.into(),
27        }
28    }
29
30    /// Constructs from `MISTRAL_API_KEY` environment variable.
31    pub fn from_env(model: impl Into<String>) -> Result<Self> {
32        let api_key = std::env::var("MISTRAL_API_KEY")
33            .map_err(|_| crate::error::Error::InvalidInput("MISTRAL_API_KEY not set".into()))?;
34        Ok(Self::new(api_key, model))
35    }
36}
37
/// Endpoint that every embedding request in this module is posted to.
const MISTRAL_EMBED_URL: &str = "https://api.mistral.ai/v1/embeddings";
39
40/// Mistral uses `inputs` instead of `input`.
41#[derive(Debug, Serialize)]
42struct MistralEmbedRequest<'a> {
43    model: &'a str,
44    #[serde(rename = "input")]
45    inputs: Vec<&'a str>,
46}
47
48#[derive(Debug, Deserialize)]
49struct MistralEmbedResponse {
50    data: Vec<MistralEmbedDatum>,
51}
52
53#[derive(Debug, Deserialize)]
54struct MistralEmbedDatum {
55    embedding: Vec<f32>,
56}
57
58/// Mistral AI embedding adapter.
59///
60/// Uses the Mistral embeddings API. The request format is similar to OpenAI but
61/// the Python client sends the field as `inputs`; the actual Mistral REST API
62/// accepts `input` (same as OpenAI), so we use the standard field name.
63#[derive(Debug, Clone)]
64pub struct MistralAITextVectorizer {
65    config: MistralConfig,
66    client: reqwest::Client,
67    blocking_client: reqwest::blocking::Client,
68}
69
70impl MistralAITextVectorizer {
71    /// Creates a new Mistral AI adapter.
72    pub fn new(config: MistralConfig) -> Self {
73        Self {
74            config,
75            client: reqwest::Client::new(),
76            blocking_client: reqwest::blocking::Client::new(),
77        }
78    }
79
80    async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
81        let resp: MistralEmbedResponse = self
82            .client
83            .post(MISTRAL_EMBED_URL)
84            .bearer_auth(&self.config.api_key)
85            .json(&MistralEmbedRequest {
86                model: &self.config.model,
87                inputs: texts.to_vec(),
88            })
89            .send()
90            .await?
91            .error_for_status()?
92            .json()
93            .await?;
94        Ok(resp.data.into_iter().map(|d| d.embedding).collect())
95    }
96}
97
98impl Vectorizer for MistralAITextVectorizer {
99    fn embed(&self, text: &str) -> Result<Vec<f32>> {
100        let resp: MistralEmbedResponse = self
101            .blocking_client
102            .post(MISTRAL_EMBED_URL)
103            .bearer_auth(&self.config.api_key)
104            .json(&MistralEmbedRequest {
105                model: &self.config.model,
106                inputs: vec![text],
107            })
108            .send()?
109            .error_for_status()?
110            .json()?;
111        Ok(resp
112            .data
113            .into_iter()
114            .next()
115            .map_or_else(Vec::new, |d| d.embedding))
116    }
117}
118
119#[async_trait]
120impl AsyncVectorizer for MistralAITextVectorizer {
121    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
122        let mut v = self.embed_many_inner(&[text]).await?;
123        Ok(v.pop().unwrap_or_default())
124    }
125
126    async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
127        self.embed_many_inner(texts).await
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// `MistralConfig::new` must keep both arguments verbatim.
    #[test]
    fn mistral_config_stores_fields() {
        let MistralConfig { api_key, model } = MistralConfig::new("key", "mistral-embed");
        assert_eq!(api_key, "key");
        assert_eq!(model, "mistral-embed");
    }

    /// The request body must use the wire key `input`, not the Rust field name.
    #[test]
    fn mistral_request_serializes_input_field() {
        let request = MistralEmbedRequest {
            model: "mistral-embed",
            inputs: vec!["hello"],
        };
        let encoded = serde_json::to_value(&request).unwrap();
        // Mistral REST API uses "input" field name
        assert_eq!(encoded["model"], "mistral-embed");
        let expected_input = serde_json::json!(["hello"]);
        assert_eq!(encoded["input"], expected_input);
    }

    /// The adapter must be shareable across threads.
    #[test]
    fn mistral_vectorizer_is_send_sync() {
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<MistralAITextVectorizer>();
    }
}