1use std::pin::Pin;
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{Error, Result};
7
8use super::backend::EmbeddingBackend;
9use super::config::MistralConfig;
10use super::convert::to_f32_blob;
11
/// Vector length reported by `MistralEmbedding::dimensions`.
///
/// NOTE(review): presumably the output width of the configured Mistral embedding
/// model; nothing in this file checks the response against it — confirm.
const DIMENSIONS: usize = 1024;
16
/// State shared by every clone of [`MistralEmbedding`].
struct Inner {
    /// HTTP client used to send the embeddings request.
    client: reqwest::Client,
    /// API key sent as the bearer token on each request.
    api_key: String,
    /// Model name placed in the request body and returned by `model_name`.
    model: String,
}
22
/// Embedding backend that calls the Mistral embeddings HTTP endpoint.
///
/// The state lives behind an [`Arc`], so cloning the handle copies a pointer
/// rather than the client, key, or model name.
pub struct MistralEmbedding(Arc<Inner>);
36
37impl Clone for MistralEmbedding {
38 fn clone(&self) -> Self {
39 Self(Arc::clone(&self.0))
40 }
41}
42
43impl MistralEmbedding {
44 pub fn new(client: reqwest::Client, config: &MistralConfig) -> Result<Self> {
50 config.validate()?;
51 Ok(Self(Arc::new(Inner {
52 client,
53 api_key: config.api_key.clone(),
54 model: config.model.clone(),
55 })))
56 }
57}
58
59impl EmbeddingBackend for MistralEmbedding {
60 fn embed(&self, input: &str) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send + '_>> {
61 let input = input.to_owned();
62 Box::pin(async move {
63 const URL: &str = concat!("https://api.mistral.ai", "/v1/embeddings");
64 let body = Request {
65 input: &input,
66 model: &self.0.model,
67 };
68
69 let resp = self
70 .0
71 .client
72 .post(URL)
73 .bearer_auth(&self.0.api_key)
74 .json(&body)
75 .send()
76 .await
77 .map_err(|e| Error::internal("mistral embeddings request failed").chain(e))?;
78
79 if !resp.status().is_success() {
80 let status = resp.status();
81 let text = resp.text().await.unwrap_or_default();
82 return Err(Error::internal(format!(
83 "mistral embedding error: {status}: {text}"
84 )));
85 }
86
87 let parsed: Response = resp.json().await.map_err(|e| {
88 Error::internal("failed to parse mistral embedding response").chain(e)
89 })?;
90
91 let values = parsed
92 .data
93 .into_iter()
94 .next()
95 .ok_or_else(|| Error::internal("mistral returned empty embedding data"))?
96 .embedding;
97
98 Ok(to_f32_blob(&values))
99 })
100 }
101
102 fn dimensions(&self) -> usize {
103 DIMENSIONS
104 }
105
106 fn model_name(&self) -> &str {
107 &self.0.model
108 }
109}
110
/// JSON body serialized for the embeddings request.
///
/// Both fields borrow from the caller, so building a request copies no strings.
#[derive(Serialize)]
struct Request<'a> {
    /// Text to embed.
    input: &'a str,
    /// Name of the model asked to produce the embedding.
    model: &'a str,
}
116
/// Part of the embeddings response body that this module reads.
#[derive(Deserialize)]
struct Response {
    /// Embedding entries; `embed` uses only the first one.
    data: Vec<EmbeddingData>,
}
121
/// One entry of the response's `data` array.
#[derive(Deserialize)]
struct EmbeddingData {
    /// Embedding vector components as returned by the API.
    embedding: Vec<f32>,
}