rig/providers/mistral/
embedding.rs1use serde::Deserialize;
2use serde_json::json;
3
4use crate::{
5 embeddings::{self, EmbeddingError},
6 http_client::{self, HttpClientExt},
7};
8
9use super::client::{ApiResponse, Client, Usage};
10
/// Model identifier string for the `mistral-embed` embedding model.
pub const MISTRAL_EMBED: &str = "mistral-embed";
15
/// Value exposed as the `MAX_DOCUMENTS` associated constant of this provider's
/// `embeddings::EmbeddingModel` implementation.
///
/// NOTE(review): presumably the largest batch sent in a single request —
/// confirm against the trait's documentation.
pub const MAX_DOCUMENTS: usize = 1024;
17
18#[derive(Clone)]
19pub struct EmbeddingModel<T = reqwest::Client> {
20 client: Client<T>,
21 pub model: String,
22 ndims: usize,
23}
24
25impl<T> EmbeddingModel<T> {
26 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
27 Self {
28 client,
29 model: model.into(),
30 ndims,
31 }
32 }
33
34 pub fn with_model(client: Client<T>, model: &str, ndims: usize) -> Self {
35 Self {
36 client,
37 model: model.to_string(),
38 ndims,
39 }
40 }
41}
42
43impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
44where
45 T: HttpClientExt + Clone,
46{
47 type Client = Client<T>;
48
49 const MAX_DOCUMENTS: usize = MAX_DOCUMENTS;
50
51 fn make(client: &Self::Client, model: impl Into<String>, dims: Option<usize>) -> Self {
52 Self::new(client.clone(), model, dims.unwrap_or_default())
53 }
54
55 fn ndims(&self) -> usize {
56 self.ndims
57 }
58
59 #[cfg_attr(feature = "worker", worker::send)]
60 async fn embed_texts(
61 &self,
62 documents: impl IntoIterator<Item = String>,
63 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
64 let documents = documents.into_iter().collect::<Vec<_>>();
65
66 let body = serde_json::to_vec(&json!({
67 "model": self.model,
68 "input": documents
69 }))?;
70
71 let req = self
72 .client
73 .post("v1/embeddings")?
74 .header("Content-Type", "application/json")
75 .body(body)
76 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
77
78 let response = self.client.send(req).await?;
79
80 if response.status().is_success() {
81 let body: Vec<u8> = response.into_body().await?;
82 let body: ApiResponse<EmbeddingResponse> = serde_json::from_slice(&body)?;
83
84 match body {
85 ApiResponse::Ok(response) => {
86 tracing::debug!(target: "rig",
87 "Mistral embedding token usage: {}",
88 response.usage
89 );
90
91 if response.data.len() != documents.len() {
92 return Err(EmbeddingError::ResponseError(
93 "Response data length does not match input length".into(),
94 ));
95 }
96
97 Ok(response
98 .data
99 .into_iter()
100 .zip(documents.into_iter())
101 .map(|(embedding, document)| embeddings::Embedding {
102 document,
103 vec: embedding.embedding,
104 })
105 .collect())
106 }
107 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.message)),
108 }
109 } else {
110 let text = http_client::text(response).await?;
111 Err(EmbeddingError::ProviderError(text))
112 }
113 }
114}
115
116#[derive(Debug, Deserialize)]
117pub struct EmbeddingResponse {
118 pub id: String,
119 pub object: String,
120 pub model: String,
121 pub usage: Usage,
122 pub data: Vec<EmbeddingData>,
123}
124
125#[derive(Debug, Deserialize)]
126pub struct EmbeddingData {
127 pub object: String,
128 pub embedding: Vec<f64>,
129 pub index: usize,
130}