rig/providers/mistral/
embedding.rs1use serde::Deserialize;
2use serde_json::json;
3
4use crate::embeddings::{self, EmbeddingError};
5
6use super::client::{ApiResponse, Client, Usage};
7
/// Model identifier string `mistral-embed`.
// NOTE(review): presumably intended to be passed as the `model` argument of
// `EmbeddingModel::new` — confirm against callers.
pub const MISTRAL_EMBED: &str = "mistral-embed";
/// Value re-exported as the `MAX_DOCUMENTS` associated constant of the
/// `embeddings::EmbeddingModel` implementation in this file.
// NOTE(review): 1024 is presumably a per-request document cap — confirm
// against the provider's documented limits.
pub const MAX_DOCUMENTS: usize = 1024;
14
/// Handle for requesting embeddings from a specific model through a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
    /// Client used to issue the HTTP request in `embed_texts`.
    client: Client,
    /// Model name sent as the `"model"` field of the request body.
    pub model: String,
    /// Dimensionality reported by `ndims()`; supplied by the caller at
    /// construction time and not validated against the API response.
    ndims: usize,
}
21
22impl EmbeddingModel {
23 pub fn new(client: Client, model: &str, ndims: usize) -> Self {
24 Self {
25 client,
26 model: model.to_string(),
27 ndims,
28 }
29 }
30}
31
32impl embeddings::EmbeddingModel for EmbeddingModel {
33 const MAX_DOCUMENTS: usize = MAX_DOCUMENTS;
34 fn ndims(&self) -> usize {
35 self.ndims
36 }
37
38 #[cfg_attr(feature = "worker", worker::send)]
39 async fn embed_texts(
40 &self,
41 documents: impl IntoIterator<Item = String>,
42 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
43 let documents = documents.into_iter().collect::<Vec<_>>();
44
45 let response = self
46 .client
47 .post("v1/embeddings")
48 .json(&json!({
49 "model": self.model,
50 "input": documents,
51 }))
52 .send()
53 .await?;
54
55 if response.status().is_success() {
56 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
57 ApiResponse::Ok(response) => {
58 tracing::debug!(target: "rig",
59 "Mistral embedding token usage: {}",
60 response.usage
61 );
62
63 if response.data.len() != documents.len() {
64 return Err(EmbeddingError::ResponseError(
65 "Response data length does not match input length".into(),
66 ));
67 }
68
69 Ok(response
70 .data
71 .into_iter()
72 .zip(documents.into_iter())
73 .map(|(embedding, document)| embeddings::Embedding {
74 document,
75 vec: embedding.embedding,
76 })
77 .collect())
78 }
79 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
80 }
81 } else {
82 Err(EmbeddingError::ProviderError(response.text().await?))
83 }
84 }
85}
86
/// Success payload deserialized from the `v1/embeddings` response body.
#[derive(Debug, Deserialize)]
pub struct EmbeddingResponse {
    /// Deserialized from the response; not read elsewhere in this file.
    pub id: String,
    /// Deserialized from the response; not read elsewhere in this file.
    pub object: String,
    /// Deserialized from the response; not read elsewhere in this file.
    pub model: String,
    /// Token usage for the request; written to the debug log by `embed_texts`.
    pub usage: Usage,
    /// Returned embedding entries; `embed_texts` requires exactly one per
    /// input document.
    pub data: Vec<EmbeddingData>,
}
95
/// A single entry of [`EmbeddingResponse::data`].
#[derive(Debug, Deserialize)]
pub struct EmbeddingData {
    /// Deserialized from the response; not read elsewhere in this file.
    pub object: String,
    /// The embedding vector; becomes the `vec` of the resulting
    /// `embeddings::Embedding`.
    pub embedding: Vec<f64>,
    /// Position reported by the API for this entry.
    // NOTE(review): presumably the zero-based index of the corresponding
    // input document — confirm against the API schema.
    pub index: usize,
}