rig_bailian/
embedding.rs

1use rig::embeddings::{self, EmbeddingError};
2use rig::http_client::{self, HttpClientExt};
3use rig::providers::openai::completion::Usage;
4use serde::Deserialize;
5use serde_json::json;
6
7use super::client::Client;
8use super::types::ApiResponse;
9
/// Model identifier string for the `text-embedding-v4` embedding model.
///
/// Intended to be supplied as the `model` argument when constructing an
/// [`EmbeddingModel`].
pub const TEXT_EMBEDDING_V4: &str = "text-embedding-v4";
12
13#[derive(Debug, Deserialize)]
14pub struct EmbeddingData {
15    pub object: String,
16    pub embedding: Vec<f64>,
17    pub index: usize,
18}
19
20#[derive(Debug, Deserialize)]
21pub struct EmbeddingResponse {
22    pub object: String,
23    pub data: Vec<EmbeddingData>,
24    pub model: String,
25    #[serde(default)]
26    pub usage: Option<Usage>,
27}
28
29#[derive(Clone)]
30pub struct EmbeddingModel<T = reqwest::Client> {
31    pub(crate) client: Client<T>,
32    pub model: String,
33    ndims: usize,
34}
35
36impl<T> EmbeddingModel<T> {
37    pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
38        Self {
39            client,
40            model: model.into(),
41            ndims,
42        }
43    }
44}
45
46impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
47where
48    T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
49{
50    const MAX_DOCUMENTS: usize = 1024;
51
52    type Client = Client<T>;
53
54    fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
55        let model = model.into();
56        let dims = ndims.unwrap_or(0);
57        Self::new(client.clone(), model, dims)
58    }
59
60    fn ndims(&self) -> usize {
61        self.ndims
62    }
63
64    async fn embed_texts(
65        &self,
66        documents: impl IntoIterator<Item = String>,
67    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
68        let documents = documents.into_iter().collect::<Vec<_>>();
69
70        let mut body = json!({
71            "model": self.model,
72            "input": documents,
73        });
74
75        if self.ndims > 0 {
76            body["dimensions"] = json!(self.ndims);
77        }
78
79        let body = serde_json::to_vec(&body)?;
80
81        let req = self
82            .client
83            .post("/embeddings")?
84            .header("Content-Type", "application/json")
85            .body(body)
86            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
87
88        let response = HttpClientExt::send(&self.client.http_client, req).await?;
89
90        if response.status().is_success() {
91            let text = http_client::text(response).await?;
92            let parsed: ApiResponse<EmbeddingResponse> = serde_json::from_str(&text)?;
93
94            match parsed {
95                ApiResponse::Ok(response) => {
96                    if let Some(ref usage) = response.usage {
97                        tracing::info!(target: "rig", "Bailian embedding token usage: {}", usage);
98                    }
99
100                    if response.data.len() != documents.len() {
101                        return Err(EmbeddingError::ResponseError(
102                            "Response data length does not match input length".into(),
103                        ));
104                    }
105
106                    Ok(response
107                        .data
108                        .into_iter()
109                        .zip(documents.into_iter())
110                        .map(|(embedding, document)| embeddings::Embedding {
111                            document,
112                            vec: embedding.embedding,
113                        })
114                        .collect())
115                }
116                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.error.message)),
117            }
118        } else {
119            let text = http_client::text(response).await?;
120            Err(EmbeddingError::ProviderError(text))
121        }
122    }
123}