rig_tei/
embedding.rs

1use rig::embeddings::{self, EmbeddingError};
2use rig::http_client::{self, HttpClientExt};
3use serde::Deserialize;
4use serde_json::{Value, json};
5
6use super::client::Client;
7
8#[derive(Debug, Deserialize)]
9struct MultiEmbeddings {
10    embeddings: Vec<Vec<f32>>,
11}
12
13#[derive(Debug, Deserialize)]
14struct SingleEmbedding {
15    embeddings: Vec<f32>,
16}
17
18#[derive(Debug, Deserialize)]
19#[serde(untagged)]
20enum EmbeddingResponse {
21    Multi(MultiEmbeddings),
22    Single(SingleEmbedding),
23    Bare(Vec<Vec<f32>>),
24}
25
26#[derive(Clone)]
27pub struct EmbeddingModel<T = reqwest::Client> {
28    pub(crate) client: Client<T>,
29    pub model: String,
30    ndims: usize,
31}
32
33impl<T> EmbeddingModel<T> {
34    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
35        Self {
36            client,
37            model: model.into(),
38            ndims,
39        }
40    }
41}
42
43impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
44where
45    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
46{
47    const MAX_DOCUMENTS: usize = 1024;
48
49    type Client = Client<T>;
50
51    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
52        let model = model.into();
53        let dims = ndims.unwrap_or(0);
54        Self::new(client.clone(), model, dims)
55    }
56
57    fn ndims(&self) -> usize {
58        self.ndims
59    }
60
61    async fn embed_texts(
62        &self,
63        documents: impl IntoIterator<Item = String>,
64    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
65        let docs: Vec<String> = documents.into_iter().collect();
66
67        let inputs_value: Value = if docs.len() == 1 {
68            json!({ "inputs": docs[0] })
69        } else {
70            json!({ "inputs": docs })
71        };
72
73        let body = serde_json::to_vec(&inputs_value)?;
74
75        // Use resolved full endpoint (customizable)
76        let req = self
77            .client
78            .post_full(&self.client.endpoints.embed)
79            .header("Content-Type", "application/json")
80            .body(body)
81            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
82
83        let response = HttpClientExt::send(&self.client.http_client, req).await?;
84
85        if !response.status().is_success() {
86            let text = http_client::text(response).await?;
87            return Err(EmbeddingError::ProviderError(text));
88        }
89
90        let bytes: Vec<u8> = response.into_body().await?;
91        let parsed: EmbeddingResponse = serde_json::from_slice(&bytes).map_err(|e| {
92            EmbeddingError::ResponseError(format!("Failed to parse TEI embeddings: {e}"))
93        })?;
94
95        let embeddings: Vec<Vec<f64>> = match parsed {
96            EmbeddingResponse::Multi(m) => m
97                .embeddings
98                .into_iter()
99                .map(|v| v.into_iter().map(|x| x as f64).collect())
100                .collect(),
101            EmbeddingResponse::Single(s) => {
102                vec![s.embeddings.into_iter().map(|x| x as f64).collect()]
103            }
104            EmbeddingResponse::Bare(arr) => arr
105                .into_iter()
106                .map(|v| v.into_iter().map(|x| x as f64).collect())
107                .collect(),
108        };
109
110        if embeddings.len() != docs.len() {
111            return Err(EmbeddingError::ResponseError(
112                "Response data length does not match input length".into(),
113            ));
114        }
115
116        Ok(embeddings
117            .into_iter()
118            .zip(docs.into_iter())
119            .map(|(vec, document)| embeddings::Embedding { document, vec })
120            .collect())
121    }
122}