rig_volcengine/
embedding.rs

1use rig::embeddings::{self, EmbeddingError};
2use rig::http_client::{self, HttpClientExt};
3use rig::providers::openai::completion::Usage;
4use serde::Deserialize;
5use serde_json::json;
6
7use super::client::Client;
8use super::types::ApiResponse;
9
// Model name constants for Volcengine embedding models. The string values are kept
// exactly as in the original provider implementation.

/// Model name for the Doubao embedding model.
///
/// NOTE(review): the leading capital `D` differs from the casing of
/// [`TEXT_DOUBAO_EMBEDDING_LARGE`]; confirm the exact name the provider expects.
pub const TEXT_DOUBAO_EMBEDDING: &str = "Doubao-embedding";

/// Model name for the large Doubao embedding model.
pub const TEXT_DOUBAO_EMBEDDING_LARGE: &str = "doubao-embedding-large";
13
14#[derive(Debug, Deserialize)]
15pub struct EmbeddingData {
16    pub object: String,
17    pub embedding: Vec<f64>,
18    pub index: usize,
19}
20
21#[derive(Debug, Deserialize)]
22pub struct EmbeddingResponse {
23    pub object: String,
24    pub data: Vec<EmbeddingData>,
25    pub model: String,
26    #[serde(default)]
27    pub usage: Option<Usage>,
28}
29
30#[derive(Clone)]
31pub struct EmbeddingModel<T = reqwest::Client> {
32    pub(crate) client: Client<T>,
33    pub model: String,
34    ndims: usize,
35}
36
37impl<T> EmbeddingModel<T> {
38    pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
39        Self {
40            client,
41            model: model.to_string(),
42            ndims,
43        }
44    }
45}
46
47impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
48where
49    T: HttpClientExt + Clone + std::fmt::Debug + Send + 'static,
50{
51    const MAX_DOCUMENTS: usize = 1024;
52
53    fn ndims(&self) -> usize {
54        self.ndims
55    }
56
57    async fn embed_texts(
58        &self,
59        documents: impl IntoIterator<Item = String>,
60    ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
61        let documents = documents.into_iter().collect::<Vec<_>>();
62
63        let mut body = json!({
64            "model": self.model,
65            "input": documents,
66        });
67
68        if self.ndims > 0 {
69            body["dimensions"] = json!(self.ndims);
70        }
71
72        let body = serde_json::to_vec(&body)?;
73
74        let req = self
75            .client
76            .post("/embeddings")?
77            .header("Content-Type", "application/json")
78            .body(body)
79            .map_err(|e| EmbeddingError::HttpError(e.into()))?;
80
81        let response = HttpClientExt::send(&self.client.http_client, req).await?;
82
83        if response.status().is_success() {
84            let text = http_client::text(response).await?;
85            let parsed: ApiResponse<EmbeddingResponse> = serde_json::from_str(&text)?;
86
87            match parsed {
88                ApiResponse::Ok(response) => {
89                    if let Some(ref usage) = response.usage {
90                        tracing::info!(target: "rig", "Volcengine embedding token usage: {}", usage);
91                    }
92
93                    if response.data.len() != documents.len() {
94                        return Err(EmbeddingError::ResponseError(
95                            "Response data length does not match input length".into(),
96                        ));
97                    }
98
99                    Ok(response
100                        .data
101                        .into_iter()
102                        .zip(documents.into_iter())
103                        .map(|(embedding, document)| embeddings::Embedding {
104                            document,
105                            vec: embedding.embedding,
106                        })
107                        .collect())
108                }
109                ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.error.message)),
110            }
111        } else {
112            let text = http_client::text(response).await?;
113            Err(EmbeddingError::ProviderError(text))
114        }
115    }
116}