1use rig::embeddings::{self, EmbeddingError};
2use rig::http_client::{self, HttpClientExt};
3use rig::providers::openai::completion::Usage;
4use serde::Deserialize;
5use serde_json::json;
6
7use super::client::Client;
8use super::types::ApiResponse;
9
/// Model identifier string for the `text-embedding-v4` embedding model.
pub const TEXT_EMBEDDING_V4: &str = "text-embedding-v4";
12
13#[derive(Debug, Deserialize)]
14pub struct EmbeddingData {
15 pub object: String,
16 pub embedding: Vec<f64>,
17 pub index: usize,
18}
19
20#[derive(Debug, Deserialize)]
21pub struct EmbeddingResponse {
22 pub object: String,
23 pub data: Vec<EmbeddingData>,
24 pub model: String,
25 #[serde(default)]
26 pub usage: Option<Usage>,
27}
28
29#[derive(Clone)]
30pub struct EmbeddingModel<T = reqwest::Client> {
31 pub(crate) client: Client<T>,
32 pub model: String,
33 ndims: usize,
34}
35
36impl<T> EmbeddingModel<T> {
37 pub fn new(client: Client<T>, model: impl Into<String>, ndims: usize) -> Self {
38 Self {
39 client,
40 model: model.into(),
41 ndims,
42 }
43 }
44}
45
46impl<T> embeddings::EmbeddingModel for EmbeddingModel<T>
47where
48 T: HttpClientExt + Clone + std::fmt::Debug + Default + Send + 'static,
49{
50 const MAX_DOCUMENTS: usize = 1024;
51
52 type Client = Client<T>;
53
54 fn make(client: &Self::Client, model: impl Into<String>, ndims: Option<usize>) -> Self {
55 let model = model.into();
56 let dims = ndims.unwrap_or(0);
57 Self::new(client.clone(), model, dims)
58 }
59
60 fn ndims(&self) -> usize {
61 self.ndims
62 }
63
64 async fn embed_texts(
65 &self,
66 documents: impl IntoIterator<Item = String>,
67 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
68 let documents = documents.into_iter().collect::<Vec<_>>();
69
70 let mut body = json!({
71 "model": self.model,
72 "input": documents,
73 });
74
75 if self.ndims > 0 {
76 body["dimensions"] = json!(self.ndims);
77 }
78
79 let body = serde_json::to_vec(&body)?;
80
81 let req = self
82 .client
83 .post("/embeddings")?
84 .header("Content-Type", "application/json")
85 .body(body)
86 .map_err(|e| EmbeddingError::HttpError(e.into()))?;
87
88 let response = HttpClientExt::send(&self.client.http_client, req).await?;
89
90 if response.status().is_success() {
91 let text = http_client::text(response).await?;
92 let parsed: ApiResponse<EmbeddingResponse> = serde_json::from_str(&text)?;
93
94 match parsed {
95 ApiResponse::Ok(response) => {
96 if let Some(ref usage) = response.usage {
97 tracing::info!(target: "rig", "Bailian embedding token usage: {}", usage);
98 }
99
100 if response.data.len() != documents.len() {
101 return Err(EmbeddingError::ResponseError(
102 "Response data length does not match input length".into(),
103 ));
104 }
105
106 Ok(response
107 .data
108 .into_iter()
109 .zip(documents.into_iter())
110 .map(|(embedding, document)| embeddings::Embedding {
111 document,
112 vec: embedding.embedding,
113 })
114 .collect())
115 }
116 ApiResponse::Err(err) => Err(EmbeddingError::ProviderError(err.error.message)),
117 }
118 } else {
119 let text = http_client::text(response).await?;
120 Err(EmbeddingError::ProviderError(text))
121 }
122 }
123}