llmsdk_mistral/embedding/
model.rs1use std::sync::Arc;
8
9use async_trait::async_trait;
10use llmsdk_provider::ProviderError;
11use llmsdk_provider::embedding_model::{
12 EmbedOptions, EmbedResult, Embedding, EmbeddingModel, EmbeddingUsage,
13};
14use llmsdk_provider::shared::{RequestInfo, ResponseInfo};
15use llmsdk_provider_utils::http::{JsonRequest, post_json};
16
17use crate::PROVIDER_ID;
18use crate::config::Inner;
19
20use super::options::parse as parse_options;
21use super::wire::{EmbeddingRequest, EmbeddingResponse};
22
/// Largest number of input values this client will send in one embeddings
/// request.
///
/// `do_embed` rejects larger batches before any network call is made, and
/// `max_embeddings_per_call` reports this same value to callers.
const MAX_PER_CALL: u32 = 32;
25
26#[derive(Debug, Clone)]
30pub struct MistralEmbeddingModel {
31 inner: Arc<Inner>,
32 model_id: String,
33}
34
35impl MistralEmbeddingModel {
36 pub(crate) fn new(inner: Arc<Inner>, model_id: String) -> Self {
37 Self { inner, model_id }
38 }
39
40 fn endpoint(&self) -> String {
41 format!("{}/embeddings", self.inner.base_url)
42 }
43}
44
45#[async_trait]
46impl EmbeddingModel for MistralEmbeddingModel {
47 fn provider(&self) -> &str {
48 PROVIDER_ID
49 }
50
51 fn model_id(&self) -> &str {
52 &self.model_id
53 }
54
55 async fn max_embeddings_per_call(&self) -> Option<u32> {
56 Some(MAX_PER_CALL)
57 }
58
59 async fn supports_parallel_calls(&self) -> bool {
60 false
61 }
62
63 async fn do_embed(&self, options: EmbedOptions) -> Result<EmbedResult, ProviderError> {
64 let total = options.values.len();
65 if u32::try_from(total).is_ok_and(|n| n > MAX_PER_CALL) {
66 return Err(ProviderError::too_many_embedding_values(
67 MAX_PER_CALL as usize,
68 total,
69 ));
70 }
71
72 let mistral_opts = parse_options(options.provider_options.as_ref());
73
74 let request = EmbeddingRequest {
75 model: self.model_id.clone(),
76 input: options.values,
77 encoding_format: "float",
78 output_dimension: mistral_opts.output_dimension,
79 output_dtype: mistral_opts.output_dtype,
80 };
81
82 let request_body_value = serde_json::to_value(&request).ok();
83
84 let mut request_headers = self.inner.headers.clone();
85 if let Some(headers) = options.headers {
86 for (name, value) in headers {
87 request_headers.insert(name, value);
88 }
89 }
90
91 let mut http_request = JsonRequest::new(self.endpoint(), request);
92 http_request.headers = request_headers;
93
94 let response = post_json::<_, EmbeddingResponse>(&self.inner.http, http_request).await?;
95
96 let embeddings: Vec<Embedding> = response
97 .value
98 .data
99 .into_iter()
100 .map(|d| d.embedding)
101 .collect();
102 let usage = response.value.usage.map(|u| EmbeddingUsage {
103 tokens: Some(u.prompt_tokens),
104 });
105
106 Ok(EmbedResult {
107 embeddings,
108 usage,
109 provider_metadata: None,
110 request: Some(RequestInfo {
111 body: request_body_value,
112 }),
113 response: Some(ResponseInfo {
114 headers: Some(
115 response
116 .headers
117 .into_iter()
118 .map(|(k, v)| (k, Some(v)))
119 .collect(),
120 ),
121 ..ResponseInfo::default()
122 }),
123 })
124 }
125}