1use rig::embeddings::{self, EmbeddingError};
2use rig::http_client::{self, HttpClientExt};
3use serde::Deserialize;
4use serde_json::{Value, json};
5
6use super::client::Client;
7
8#[derive(Debug, Deserialize)]
9struct MultiEmbeddings {
10 embeddings: Vec<Vec<f32>>,
11}
12
13#[derive(Debug, Deserialize)]
14struct SingleEmbedding {
15 embeddings: Vec<f32>,
16}
17
18#[derive(Debug, Deserialize)]
19#[serde(untagged)]
20enum EmbeddingResponse {
21 Multi(MultiEmbeddings),
22 Single(SingleEmbedding),
23 Bare(Vec<Vec<f32>>),
24}
25
26#[derive(Clone)]
27pub struct EmbeddingModel<T = reqwest::Client> {
28 pub(crate) client: Client<T>,
29 pub model: String,
30 ndims: usize,
31}
32
33impl<T> EmbeddingModel<T> {
34 pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
35 Self {
36 client,
37 model: model.to_string(),
38 ndims,
39 }
40 }
41}
42
43impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
44where
45 T: HttpClientExt + Clone + std::fmt::Debug + Send + 'static,
46{
47 const MAX_DOCUMENTS: usize = 1024;
48
49 fn ndims(&self) -> usize {
50 self.ndims
51 }
52
53 async fn embed_texts(
54 &self,
55 documents: impl IntoIterator<Item = String>,
56 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
57 let docs: Vec<String> = documents.into_iter().collect();
58
59 let inputs_value: Value = if docs.len() == 1 {
60 json!({ "inputs": docs[0] })
61 } else {
62 json!({ "inputs": docs })
63 };
64
65 let body = serde_json::to_vec(&inputs_value)?;
66
67 let req = self
69 .client
70 .post_full(&self.client.endpoints.embed)
71 .header("Content-Type", "application/json")
72 .body(body)
73 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
74
75 let response = HttpClientExt::send(&self.client.http_client, req).await?;
76
77 if !response.status().is_success() {
78 let text = http_client::text(response).await?;
79 return Err(EmbeddingError::ProviderError(text));
80 }
81
82 let bytes: Vec<u8> = response.into_body().await?;
83 let parsed: EmbeddingResponse = serde_json::from_slice(&bytes).map_err(|e| {
84 EmbeddingError::ResponseError(format!("Failed to parse TEI embeddings: {e}"))
85 })?;
86
87 let embeddings: Vec<Vec<f64>> = match parsed {
88 EmbeddingResponse::Multi(m) => m
89 .embeddings
90 .into_iter()
91 .map(|v| v.into_iter().map(|x| x as f64).collect())
92 .collect(),
93 EmbeddingResponse::Single(s) => {
94 vec![s.embeddings.into_iter().map(|x| x as f64).collect()]
95 }
96 EmbeddingResponse::Bare(arr) => arr
97 .into_iter()
98 .map(|v| v.into_iter().map(|x| x as f64).collect())
99 .collect(),
100 };
101
102 if embeddings.len() != docs.len() {
103 return Err(EmbeddingError::ResponseError(
104 "Response data length does not match input length".into(),
105 ));
106 }
107
108 Ok(embeddings
109 .into_iter()
110 .zip(docs.into_iter())
111 .map(|(vec, document)| embeddings::Embedding { document, vec })
112 .collect())
113 }
114}