rig_volcengine/
embedding.rs1use rig::embeddings::{self, EmbeddingError};
2use rig::http_client::{self, HttpClientExt};
3use rig::providers::openai::completion::Usage;
4use serde::Deserialize;
5use serde_json::json;
6
7use super::client::Client;
8use super::types::ApiResponse;
9
/// Model id for the standard Doubao text embedding model.
///
/// Lowercased to match the sibling constant; model ids sent in the `model` field
/// are presumably case-sensitive (NOTE(review): confirm against Volcengine Ark docs).
pub const TEXT_DOUBAO_EMBEDDING: &str = "doubao-embedding";
/// Model id for the large Doubao text embedding model.
pub const TEXT_DOUBAO_EMBEDDING_LARGE: &str = "doubao-embedding-large";
13
14#[derive(Debug, Deserialize)]
15pub struct EmbeddingData {
16 pub object: String,
17 pub embedding: Vec<f64>,
18 pub index: usize,
19}
20
21#[derive(Debug, Deserialize)]
22pub struct EmbeddingResponse {
23 pub object: String,
24 pub data: Vec<EmbeddingData>,
25 pub model: String,
26 #[serde(default)]
27 pub usage: Option<Usage>,
28}
29
30#[derive(Clone)]
31pub struct EmbeddingModel<T = reqwest::Client> {
32 pub(crate) client: Client<T>,
33 pub model: String,
34 ndims: usize,
35}
36
37impl<T> EmbeddingModel<T> {
38 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
39 Self {
40 client,
41 model: model.into(),
42 ndims,
43 }
44 }
45}
46
47impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
48where
49 T: HttpClientExt + Clone + std::fmt::Debug + Send + 'static,
50{
51 const MAX_DOCUMENTS: usize = 1024;
52
53 type Client = Client<T>;
54
55 fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
56 let model = model.into();
57 let dims = ndims.unwrap_or(0);
58 Self::new(client.clone(), model, dims)
59 }
60
61 fn ndims(&self) -> usize {
62 self.ndims
63 }
64
65 async fn embed_texts(
66 &self,
67 documents: impl IntoIterator<Item = String>,
68 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
69 let documents = documents.into_iter().collect::<Vec<_>>();
70
71 let mut body = json!({
72 "model": self.model,
73 "input": documents,
74 });
75
76 if self.ndims > 0 {
77 body["dimensions"] = json!(self.ndims);
78 }
79
80 let body = serde_json::to_vec(&body)?;
81
82 let req = self
83 .client
84 .post("/embeddings")?
85 .header("Content-Type", "application/json")
86 .body(body)
87 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
88
89 let response = HttpClientExt::send(&self.client.http_client, req).await?;
90
91 if response.status().is_success() {
92 let text = http_client::text(response).await?;
93 let parsed: ApiResponse<EmbeddingResponse> = serde_json::from_str(&text)?;
94
95 match parsed {
96 ApiResponse::Ok(response) => {
97 if let Some(ref usage) = response.usage {
98 tracing::info!(target: "rig", "Volcengine embedding token usage: {}", usage);
99 }
100
101 if response.data.len() != documents.len() {
102 return Err(EmbeddingError::ResponseError(
103 "Response data length does not match input length".into(),
104 ));
105 }
106
107 Ok(response
108 .data
109 .into_iter()
110 .zip(documents.into_iter())
111 .map(|(embedding, document)| embeddings::Embedding {
112 document,
113 vec: embedding.embedding,
114 })
115 .collect())
116 }
117 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.error.message)),
118 }
119 } else {
120 let text = http_client::text(response).await?;
121 Err(EmbeddingError::ProviderError(text))
122 }
123 }
124}