rig_volcengine/
embedding.rs1use rig::embeddings::{self, EmbeddingError};
2use rig::http_client::{self, HttpClientExt};
3use rig::providers::openai::completion::Usage;
4use serde::Deserialize;
5use serde_json::json;
6
7use super::client::Client;
8use super::types::ApiResponse;
9
/// Model identifier string for the base Doubao text-embedding model.
///
/// NOTE(review): this identifier starts with an upper-case `D` while
/// [`TEXT_DOUBAO_EMBEDDING_LARGE`] is all lower-case — confirm the exact
/// spelling against the provider's model list.
pub const TEXT_DOUBAO_EMBEDDING: &str = "Doubao-embedding";

/// Model identifier string for the large Doubao text-embedding model.
pub const TEXT_DOUBAO_EMBEDDING_LARGE: &str = "doubao-embedding-large";
13
14#[derive(Debug, Deserialize)]
15pub struct EmbeddingData {
16 pub object: String,
17 pub embedding: Vec<f64>,
18 pub index: usize,
19}
20
21#[derive(Debug, Deserialize)]
22pub struct EmbeddingResponse {
23 pub object: String,
24 pub data: Vec<EmbeddingData>,
25 pub model: String,
26 #[serde(default)]
27 pub usage: Option<Usage>,
28}
29
30#[derive(Clone)]
31pub struct EmbeddingModel<T = reqwest::Client> {
32 pub(crate) client: Client<T>,
33 pub model: String,
34 ndims: usize,
35}
36
37impl<T> EmbeddingModel<T> {
38 pub fn new(client: Client<T>, model: &str, ndims: usize) -> Self {
39 Self {
40 client,
41 model: model.to_string(),
42 ndims,
43 }
44 }
45}
46
47impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
48where
49 T: HttpClientExt + Clone + std::fmt::Debug + Send + 'static,
50{
51 const MAX_DOCUMENTS: usize = 1024;
52
53 fn ndims(&self) -> usize {
54 self.ndims
55 }
56
57 async fn embed_texts(
58 &self,
59 documents: impl IntoIterator<Item = String>,
60 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
61 let documents = documents.into_iter().collect::<Vec<_>>();
62
63 let mut body = json!({
64 "model": self.model,
65 "input": documents,
66 });
67
68 if self.ndims > 0 {
69 body["dimensions"] = json!(self.ndims);
70 }
71
72 let body = serde_json::to_vec(&body)?;
73
74 let req = self
75 .client
76 .post("/embeddings")?
77 .header("Content-Type", "application/json")
78 .body(body)
79 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
80
81 let response = HttpClientExt::send(&self.client.http_client, req).await?;
82
83 if response.status().is_success() {
84 let text = http_client::text(response).await?;
85 let parsed: ApiResponse<EmbeddingResponse> = serde_json::from_str(&text)?;
86
87 match parsed {
88 ApiResponse::Ok(response) => {
89 if let Some(ref usage) = response.usage {
90 tracing::info!(target: "rig", "Volcengine embedding token usage: {}", usage);
91 }
92
93 if response.data.len() != documents.len() {
94 return Err(EmbeddingError::ResponseError(
95 "Response data length does not match input length".into(),
96 ));
97 }
98
99 Ok(response
100 .data
101 .into_iter()
102 .zip(documents.into_iter())
103 .map(|(embedding, document)| embeddings::Embedding {
104 document,
105 vec: embedding.embedding,
106 })
107 .collect())
108 }
109 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.error.message)),
110 }
111 } else {
112 let text = http_client::text(response).await?;
113 Err(EmbeddingError::ProviderError(text))
114 }
115 }
116}