Skip to main content

llmsdk_mistral/embedding/
model.rs

1//! [`EmbeddingModel`] implementation for Mistral.
2//!
3//! Mirrors `mistral-embedding-model.ts`. Entry: [`MistralEmbeddingModel::new`]
4//! via [`crate::Mistral::embedding`].
5// Rust guideline compliant 2026-05-25
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use llmsdk_provider::ProviderError;
11use llmsdk_provider::embedding_model::{
12    EmbedOptions, EmbedResult, Embedding, EmbeddingModel, EmbeddingUsage,
13};
14use llmsdk_provider::shared::{RequestInfo, ResponseInfo};
15use llmsdk_provider_utils::http::{JsonRequest, post_json};
16
17use crate::PROVIDER_ID;
18use crate::config::Inner;
19
20use super::options::parse as parse_options;
21use super::wire::{EmbeddingRequest, EmbeddingResponse};
22
/// Maximum number of input values accepted by one `do_embed` call.
///
/// Reported to callers through `max_embeddings_per_call` (`maxEmbeddingsPerCall`
/// in ai-sdk) and enforced at the top of `do_embed`, which rejects larger
/// batches with `ProviderError::too_many_embedding_values`.
const MAX_PER_CALL: u32 = 32;
25
/// Mistral Embeddings model handle.
///
/// Cheap to clone — shares the provider's HTTP client and auth state.
#[derive(Debug, Clone)]
pub struct MistralEmbeddingModel {
    /// Shared provider state; this handle reads `base_url`, `headers`, and
    /// `http` from it when issuing requests.
    inner: Arc<Inner>,
    /// Model identifier returned by `model_id()` and sent as `model` in each
    /// embeddings request body.
    model_id: String,
}
34
35impl MistralEmbeddingModel {
36    pub(crate) fn new(inner: Arc<Inner>, model_id: String) -> Self {
37        Self { inner, model_id }
38    }
39
40    fn endpoint(&self) -> String {
41        format!("{}/embeddings", self.inner.base_url)
42    }
43}
44
45#[async_trait]
46impl EmbeddingModel for MistralEmbeddingModel {
47    fn provider(&self) -> &str {
48        PROVIDER_ID
49    }
50
51    fn model_id(&self) -> &str {
52        &self.model_id
53    }
54
55    async fn max_embeddings_per_call(&self) -> Option<u32> {
56        Some(MAX_PER_CALL)
57    }
58
59    async fn supports_parallel_calls(&self) -> bool {
60        false
61    }
62
63    async fn do_embed(&self, options: EmbedOptions) -> Result<EmbedResult, ProviderError> {
64        let total = options.values.len();
65        if u32::try_from(total).is_ok_and(|n| n > MAX_PER_CALL) {
66            return Err(ProviderError::too_many_embedding_values(
67                MAX_PER_CALL as usize,
68                total,
69            ));
70        }
71
72        let mistral_opts = parse_options(options.provider_options.as_ref());
73
74        let request = EmbeddingRequest {
75            model: self.model_id.clone(),
76            input: options.values,
77            encoding_format: "float",
78            output_dimension: mistral_opts.output_dimension,
79            output_dtype: mistral_opts.output_dtype,
80        };
81
82        let request_body_value = serde_json::to_value(&request).ok();
83
84        let mut request_headers = self.inner.headers.clone();
85        if let Some(headers) = options.headers {
86            for (name, value) in headers {
87                request_headers.insert(name, value);
88            }
89        }
90
91        let mut http_request = JsonRequest::new(self.endpoint(), request);
92        http_request.headers = request_headers;
93
94        let response = post_json::<_, EmbeddingResponse>(&self.inner.http, http_request).await?;
95
96        let embeddings: Vec<Embedding> = response
97            .value
98            .data
99            .into_iter()
100            .map(|d| d.embedding)
101            .collect();
102        let usage = response.value.usage.map(|u| EmbeddingUsage {
103            tokens: Some(u.prompt_tokens),
104        });
105
106        Ok(EmbedResult {
107            embeddings,
108            usage,
109            provider_metadata: None,
110            request: Some(RequestInfo {
111                body: request_body_value,
112            }),
113            response: Some(ResponseInfo {
114                headers: Some(
115                    response
116                        .headers
117                        .into_iter()
118                        .map(|(k, v)| (k, Some(v)))
119                        .collect(),
120                ),
121                ..ResponseInfo::default()
122            }),
123        })
124    }
125}