rig/providers/mistral/
embedding.rs1use serde::Deserialize;
2use serde_json::json;
3
4use crate::{
5 embeddings::{self, EmbeddingError},
6 http_client::{self, HttpClientExt},
7};
8
9use super::client::{ApiResponse, Client, Usage};
10
/// Model name for Mistral's `mistral-embed` embedding model.
pub const MISTRAL_EMBED: &str = "mistral-embed";

/// Largest batch this provider's [`EmbeddingModel`] advertises through the
/// `EmbeddingModel::MAX_DOCUMENTS` trait constant.
/// NOTE(review): presumably mirrors a Mistral API batch limit — confirm.
pub const MAX_DOCUMENTS: usize = 1024;
17
18#[derive(Clone)]
19pub struct EmbeddingModel<T = reqwest::Client> {
20 client: Client<T>,
21 pub model: String,
22 ndims: usize,
23}
24
25impl<T> EmbeddingModel<T> {
26 pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
27 Self {
28 client,
29 model: model.to_string(),
30 ndims,
31 }
32 }
33}
34
35impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
36where
37 T: HttpClientExt + Clone,
38{
39 const MAX_DOCUMENTS: usize = MAX_DOCUMENTS;
40 fn ndims(&self) -> usize {
41 self.ndims
42 }
43
44 #[cfg_attr(feature = "worker", worker::send)]
45 async fn embed_texts(
46 &self,
47 documents: impl IntoIterator<Item = String>,
48 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
49 let documents = documents.into_iter().collect::<Vec<_>>();
50
51 let body = serde_json::to_vec(&json!({
52 "model": self.model,
53 "input": documents
54 }))?;
55
56 let req = self
57 .client
58 .post("v1/embeddings")?
59 .header("Content-Type", "application/json")
60 .body(body)
61 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
62
63 let response = self.client.send(req).await?;
64
65 if response.status().is_success() {
66 let body: Vec<u8> = response.into_body().await?;
67 let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
68
69 match body {
70 ApiResponse::Ok(response) => {
71 tracing::debug!(target: "rig",
72 "Mistral embedding token usage: {}",
73 response.usage
74 );
75
76 if response.data.len() != documents.len() {
77 return Err(EmbeddingError::ResponseError(
78 "Response data length does not match input length".into(),
79 ));
80 }
81
82 Ok(response
83 .data
84 .into_iter()
85 .zip(documents.into_iter())
86 .map(|(embedding, document)| embeddings::Embedding {
87 document,
88 vec: embedding.embedding,
89 })
90 .collect())
91 }
92 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
93 }
94 } else {
95 let text = http_client::text(response).await?;
96 Err(EmbeddingError::ProviderError(text))
97 }
98 }
99}
100
101#[derive(Debug, Deserialize)]
102pub struct EmbeddingResponse {
103 pub id: String,
104 pub object: String,
105 pub model: String,
106 pub usage: Usage,
107 pub data: Vec<EmbeddingData>,
108}
109
110#[derive(Debug, Deserialize)]
111pub struct EmbeddingData {
112 pub object: String,
113 pub embedding: Vec<f64>,
114 pub index: usize,
115}