redis_vl/vectorizers/
mistral.rs1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8
9use super::{AsyncVectorizer, Vectorizer};
10use crate::error::Result;
11
/// Credentials and model selection for the Mistral embeddings API.
///
/// `Debug` is implemented by hand rather than derived so that the secret
/// `api_key` is never written out by `{:?}` formatting (for example when a
/// value holding this config is logged).
#[derive(Clone)]
pub struct MistralConfig {
    /// Secret API key; sent as the bearer token on every request.
    pub api_key: String,
    /// Name of the embedding model placed in the request body.
    pub model: String,
}

impl std::fmt::Debug for MistralConfig {
    /// Formats the config with `api_key` replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MistralConfig")
            // Never print the real key: Debug output routinely ends up in logs.
            .field("api_key", &format_args!("<redacted>"))
            .field("model", &self.model)
            .finish()
    }
}
20
21impl MistralConfig {
22 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
24 Self {
25 api_key: api_key.into(),
26 model: model.into(),
27 }
28 }
29
30 pub fn from_env(model: impl Into<String>) -> Result<Self> {
32 let api_key = std::env::var("MISTRAL_API_KEY")
33 .map_err(|_| crate::error::Error::InvalidInput("MISTRAL_API_KEY not set".into()))?;
34 Ok(Self::new(api_key, model))
35 }
36}
37
/// URL that every embedding request in this module is POSTed to.
const MISTRAL_EMBED_URL: &str = "https://api.mistral.ai/v1/embeddings";
39
40#[derive(Debug, Serialize)]
42struct MistralEmbedRequest<'a> {
43 model: &'a str,
44 #[serde(rename = "input")]
45 inputs: Vec<&'a str>,
46}
47
48#[derive(Debug, Deserialize)]
49struct MistralEmbedResponse {
50 data: Vec<MistralEmbedDatum>,
51}
52
53#[derive(Debug, Deserialize)]
54struct MistralEmbedDatum {
55 embedding: Vec<f32>,
56}
57
58#[derive(Debug, Clone)]
64pub struct MistralAITextVectorizer {
65 config: MistralConfig,
66 client: reqwest::Client,
67 blocking_client: reqwest::blocking::Client,
68}
69
70impl MistralAITextVectorizer {
71 pub fn new(config: MistralConfig) -> Self {
73 Self {
74 config,
75 client: reqwest::Client::new(),
76 blocking_client: reqwest::blocking::Client::new(),
77 }
78 }
79
80 async fn embed_many_inner(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
81 let resp: MistralEmbedResponse = self
82 .client
83 .post(MISTRAL_EMBED_URL)
84 .bearer_auth(&self.config.api_key)
85 .json(&MistralEmbedRequest {
86 model: &self.config.model,
87 inputs: texts.to_vec(),
88 })
89 .send()
90 .await?
91 .error_for_status()?
92 .json()
93 .await?;
94 Ok(resp.data.into_iter().map(|d| d.embedding).collect())
95 }
96}
97
98impl Vectorizer for MistralAITextVectorizer {
99 fn embed(&self, text: &str) -> Result<Vec<f32>> {
100 let resp: MistralEmbedResponse = self
101 .blocking_client
102 .post(MISTRAL_EMBED_URL)
103 .bearer_auth(&self.config.api_key)
104 .json(&MistralEmbedRequest {
105 model: &self.config.model,
106 inputs: vec![text],
107 })
108 .send()?
109 .error_for_status()?
110 .json()?;
111 Ok(resp
112 .data
113 .into_iter()
114 .next()
115 .map_or_else(Vec::new, |d| d.embedding))
116 }
117}
118
119#[async_trait]
120impl AsyncVectorizer for MistralAITextVectorizer {
121 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
122 let mut v = self.embed_many_inner(&[text]).await?;
123 Ok(v.pop().unwrap_or_default())
124 }
125
126 async fn embed_many(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
127 self.embed_many_inner(texts).await
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// `MistralConfig::new` stores both arguments unchanged.
    #[test]
    fn mistral_config_stores_fields() {
        let MistralConfig { api_key, model } = MistralConfig::new("key", "mistral-embed");
        assert_eq!(api_key, "key");
        assert_eq!(model, "mistral-embed");
    }

    /// The request body exposes the texts under the JSON key `input`.
    #[test]
    fn mistral_request_serializes_input_field() {
        let request = MistralEmbedRequest {
            model: "mistral-embed",
            inputs: vec!["hello"],
        };
        let encoded = serde_json::to_value(&request).unwrap();
        assert_eq!(encoded["model"].as_str(), Some("mistral-embed"));
        assert_eq!(encoded["input"], serde_json::json!(["hello"]));
    }

    /// Compile-time check that the vectorizer can be shared across threads.
    #[test]
    fn mistral_vectorizer_is_send_sync() {
        fn require_send_sync<T>()
        where
            T: Send + Sync,
        {
        }
        require_send_sync::<MistralAITextVectorizer>();
    }
}