rig_tei/
embedding.rs

1use rig::embeddings::{self, EmbeddingError};
2use rig::http_client::{self, HttpClientExt};
3use serde::Deserialize;
4use serde_json::{Value, json};
5
6use super::client::Client;
7
8#[derive(Debug, Deserialize)]
9struct MultiEmbeddings {
10    embeddings: Vec<Vec<f32>>,
11}
12
13#[derive(Debug, Deserialize)]
14struct SingleEmbedding {
15    embeddings: Vec<f32>,
16}
17
18#[derive(Debug, Deserialize)]
19#[serde(untagged)]
20enum EmbeddingResponse {
21    Multi(MultiEmbeddings),
22    Single(SingleEmbedding),
23    Bare(Vec<Vec<f32>>),
24}
25
26#[derive(Clone)]
27pub struct EmbeddingModel<T = reqwest::Client> {
28    pub(crate) client: Client<T>,
29    pub model: String,
30    ndims: usize,
31}
32
33impl<T> EmbeddingModel<T> {
34    pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
35        Self {
36            client,
37            model: model.to_string(),
38            ndims,
39        }
40    }
41}
42
43impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
44where
45    T: HttpClientExt + Clone + std::fmt::Debug + Send + 'static,
46{
47    const MAX_DOCUMENTS: usize = 1024;
48
49    fn ndims(&self) -> usize {
50        self.ndims
51    }
52
53    async fn embed_texts(
54        &self,
55        documents: impl IntoIterator<Item = String>,
56    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
57        let docs: Vec<String> = documents.into_iter().collect();
58
59        let inputs_value: Value = if docs.len() == 1 {
60            json!({ "inputs": docs[0] })
61        } else {
62            json!({ "inputs": docs })
63        };
64
65        let body = serde_json::to_vec(&inputs_value)?;
66
67        // Use resolved full endpoint (customizable)
68        let req = self
69            .client
70            .post_full(&self.client.endpoints.embed)
71            .header("Content-Type", "application/json")
72            .body(body)
73            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
74
75        let response = HttpClientExt::send(&self.client.http_client, req).await?;
76
77        if !response.status().is_success() {
78            let text = http_client::text(response).await?;
79            return Err(EmbeddingError::ProviderError(text));
80        }
81
82        let bytes: Vec<u8> = response.into_body().await?;
83        let parsed: EmbeddingResponse = serde_json::from_slice(&bytes).map_err(|e| {
84            EmbeddingError::ResponseError(format!("Failed to parse TEI embeddings: {e}"))
85        })?;
86
87        let embeddings: Vec<Vec<f64>> = match parsed {
88            EmbeddingResponse::Multi(m) => m
89                .embeddings
90                .into_iter()
91                .map(|v| v.into_iter().map(|x| x as f64).collect())
92                .collect(),
93            EmbeddingResponse::Single(s) => {
94                vec![s.embeddings.into_iter().map(|x| x as f64).collect()]
95            }
96            EmbeddingResponse::Bare(arr) => arr
97                .into_iter()
98                .map(|v| v.into_iter().map(|x| x as f64).collect())
99                .collect(),
100        };
101
102        if embeddings.len() != docs.len() {
103            return Err(EmbeddingError::ResponseError(
104                "Response data length does not match input length".into(),
105            ));
106        }
107
108        Ok(embeddings
109            .into_iter()
110            .zip(docs.into_iter())
111            .map(|(vec, document)| embeddings::Embedding { document, vec })
112            .collect())
113    }
114}