rig/providers/mistral/
embedding.rs1use serde::Deserialize;
2use serde_json::json;
3
4use crate::{
5 embeddings::{self, EmbeddingError},
6 http_client::{self, HttpClientExt},
7};
8
9use super::client::{ApiResponse, Client, Usage};
10
/// Model identifier string for the `mistral-embed` embedding model.
pub const MISTRAL_EMBED: &str = "mistral-embed";

/// Document limit exposed as `MAX_DOCUMENTS` by this module's
/// `embeddings::EmbeddingModel` implementation.
pub const MAX_DOCUMENTS: usize = 1024;
17
18#[derive(Clone)]
19pub struct EmbeddingModel<T = reqwest::Client> {
20 client: Client<T>,
21 pub model: String,
22 ndims: usize,
23}
24
25impl<T> EmbeddingModel<T> {
26 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
27 Self {
28 client,
29 model: model.into(),
30 ndims,
31 }
32 }
33
34 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
35 Self {
36 client,
37 model: model.to_string(),
38 ndims,
39 }
40 }
41}
42
43impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
44where
45 T: HttpClientExt + Clone + 'static,
46{
47 type Client = Client<T>;
48
49 const MAX_DOCUMENTS: usize = MAX_DOCUMENTS;
50
51 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
52 Self::new(client.clone(), model, dims.unwrap_or_default())
53 }
54
55 fn ndims(&self) -> usize {
56 self.ndims
57 }
58
59 async fn embed_texts(
60 &self,
61 documents: impl IntoIterator<Item = String>,
62 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
63 let documents = documents.into_iter().collect::<Vec<_>>();
64
65 let body = serde_json::to_vec(&json!({
66 "model": self.model,
67 "input": documents
68 }))?;
69
70 let req = self
71 .client
72 .post("v1/embeddings")?
73 .header("Content-Type", "application/json")
74 .body(body)
75 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
76
77 let response = self.client.send(req).await?;
78
79 if response.status().is_success() {
80 let body: Vec<u8> = response.into_body().await?;
81 let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
82
83 match body {
84 ApiResponse::Ok(response) => {
85 tracing::debug!(target: "rig",
86 "Mistral embedding token usage: {}",
87 response.usage
88 );
89
90 if response.data.len() != documents.len() {
91 return Err(EmbeddingError::ResponseError(
92 "Response data length does not match input length".into(),
93 ));
94 }
95
96 Ok(response
97 .data
98 .into_iter()
99 .zip(documents.into_iter())
100 .map(|(embedding, document)| embeddings::Embedding {
101 document,
102 vec: embedding.embedding,
103 })
104 .collect())
105 }
106 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
107 }
108 } else {
109 let text = http_client::text(response).await?;
110 Err(EmbeddingError::ProviderError(text))
111 }
112 }
113}
114
115#[derive(Debug, Deserialize)]
116pub struct EmbeddingResponse {
117 pub id: String,
118 pub object: String,
119 pub model: String,
120 pub usage: Usage,
121 pub data: Vec<EmbeddingData>,
122}
123
124#[derive(Debug, Deserialize)]
125pub struct EmbeddingData {
126 pub object: String,
127 pub embedding: Vec<f64>,
128 pub index: usize,
129}