1use rig::embeddings::{self, EmbeddingError};
2use rig::http_client::{self, HttpClientExt};
3use serde::Deserialize;
4use serde_json::{Value, json};
5
6use super::client::Client;
7
8#[derive(Debug, Deserialize)]
9struct MultiEmbeddings {
10 embeddings: Vec<Vec<f32>>,
11}
12
13#[derive(Debug, Deserialize)]
14struct SingleEmbedding {
15 embeddings: Vec<f32>,
16}
17
18#[derive(Debug, Deserialize)]
19#[serde(untagged)]
20enum EmbeddingResponse {
21 Multi(MultiEmbeddings),
22 Single(SingleEmbedding),
23 Bare(Vec<Vec<f32>>),
24}
25
26#[derive(Clone)]
27pub struct EmbeddingModel<T = reqwest::Client> {
28 pub(crate) client: Client<T>,
29 pub model: String,
30 ndims: usize,
31}
32
33impl<T> EmbeddingModel<T> {
34 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
35 Self {
36 client,
37 model: model.into(),
38 ndims,
39 }
40 }
41}
42
43impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
44where
45 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
46{
47 const MAX_DOCUMENTS: usize = 1024;
48
49 type Client = Client<T>;
50
51 fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
52 let model = model.into();
53 let dims = ndims.unwrap_or(0);
54 Self::new(client.clone(), model, dims)
55 }
56
57 fn ndims(&self) -> usize {
58 self.ndims
59 }
60
61 async fn embed_texts(
62 &self,
63 documents: impl IntoIterator<Item = String>,
64 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
65 let docs: Vec<String> = documents.into_iter().collect();
66
67 let inputs_value: Value = if docs.len() == 1 {
68 json!({ "inputs": docs[0] })
69 } else {
70 json!({ "inputs": docs })
71 };
72
73 let body = serde_json::to_vec(&inputs_value)?;
74
75 let req = self
77 .client
78 .post_full(&self.client.endpoints.embed)
79 .header("Content-Type", "application/json")
80 .body(body)
81 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
82
83 let response = HttpClientExt::send(&self.client.http_client, req).await?;
84
85 if !response.status().is_success() {
86 let text = http_client::text(response).await?;
87 return Err(EmbeddingError::ProviderError(text));
88 }
89
90 let bytes: Vec<u8> = response.into_body().await?;
91 let parsed: EmbeddingResponse = serde_json::from_slice(&bytes).map_err(|e| {
92 EmbeddingError::ResponseError(format!("Failed to parse TEI embeddings: {e}"))
93 })?;
94
95 let embeddings: Vec<Vec<f64>> = match parsed {
96 EmbeddingResponse::Multi(m) => m
97 .embeddings
98 .into_iter()
99 .map(|v| v.into_iter().map(|x| x as f64).collect())
100 .collect(),
101 EmbeddingResponse::Single(s) => {
102 vec![s.embeddings.into_iter().map(|x| x as f64).collect()]
103 }
104 EmbeddingResponse::Bare(arr) => arr
105 .into_iter()
106 .map(|v| v.into_iter().map(|x| x as f64).collect())
107 .collect(),
108 };
109
110 if embeddings.len() != docs.len() {
111 return Err(EmbeddingError::ResponseError(
112 "Response data length does not match input length".into(),
113 ));
114 }
115
116 Ok(embeddings
117 .into_iter()
118 .zip(docs.into_iter())
119 .map(|(vec, document)| embeddings::Embedding { document, vec })
120 .collect())
121 }
122}