difflore_core/context/embedding/
cloud.rs1use async_trait::async_trait;
2use std::time::Duration;
3
4use crate::errors::CoreError;
5
6use super::{
7 DEFAULT_OPENAI_EMBEDDING_DIM, EMBEDDING_RETRY_DELAYS_MS, Embedder, embedding_http_client,
8 retryable_embedding_status,
9};
10
11pub struct CloudEmbedder {
22 base_url: String,
23 token: String,
24 model: String,
25 dim: usize,
26 client: reqwest::Client,
27}
28
29impl CloudEmbedder {
30 pub fn new(base_url: String, token: String) -> Self {
31 Self::with_model(
32 base_url,
33 token,
34 "text-embedding-3-small".to_owned(),
35 DEFAULT_OPENAI_EMBEDDING_DIM,
36 )
37 }
38
39 pub fn with_model(base_url: String, token: String, model: String, dim: usize) -> Self {
40 Self {
41 base_url,
42 token,
43 model,
44 dim,
45 client: embedding_http_client(),
46 }
47 }
48
49 pub(crate) fn endpoint(&self) -> String {
50 format!("{}/embeddings", self.base_url.trim_end_matches('/'))
51 }
52
53 async fn post_embedding(
54 &self,
55 token: &str,
56 body: &serde_json::Value,
57 ) -> Result<reqwest::Response, CoreError> {
58 self.client
59 .post(self.endpoint())
60 .bearer_auth(token)
61 .json(body)
62 .send()
63 .await
64 .map_err(|e| CoreError::Internal(format!("cloud embedding request failed: {e}")))
65 }
66
67 async fn post_embedding_with_transport_retry(
68 &self,
69 token: &str,
70 body: &serde_json::Value,
71 ) -> Result<reqwest::Response, CoreError> {
72 let mut last_error = String::new();
73 for attempt in 0..=EMBEDDING_RETRY_DELAYS_MS.len() {
74 match self.post_embedding(token, body).await {
75 Ok(resp) => return Ok(resp),
76 Err(error) => {
77 last_error = error.to_string();
78 if let Some(delay_ms) = EMBEDDING_RETRY_DELAYS_MS.get(attempt) {
79 tokio::time::sleep(Duration::from_millis(*delay_ms)).await;
80 }
81 }
82 }
83 }
84 Err(CoreError::Internal(format!(
85 "cloud embedding request failed after {} transport attempts: {last_error}",
86 EMBEDDING_RETRY_DELAYS_MS.len() + 1
87 )))
88 }
89}
90
91#[async_trait]
92impl Embedder for CloudEmbedder {
93 async fn embed(&self, text: &str) -> Result<Vec<f32>, CoreError> {
94 let single = vec![text.to_owned()];
95 let mut vectors = self.embed_batch(&single, None).await?;
96 return vectors.pop().ok_or_else(|| {
97 CoreError::Internal("cloud embedding response missing first vector".into())
98 });
99 }
100
101 async fn embed_batch(
102 &self,
103 texts: &[String],
104 rule_ids: Option<&[String]>,
105 ) -> Result<Vec<Vec<f32>>, CoreError> {
106 if texts.is_empty() {
107 return Ok(Vec::new());
108 }
109 let body = serde_json::json!({
110 "texts": texts,
111 "model": self.model,
112 });
113 let body = if let Some(rule_ids) = rule_ids {
114 let mut value = body;
115 value["rule_ids"] = serde_json::json!(rule_ids);
116 value
117 } else {
118 body
119 };
120 let mut active_token = self.token.clone();
121 let mut resp = self
122 .post_embedding_with_transport_retry(&active_token, &body)
123 .await?;
124
125 let mut status = resp.status();
126 if status == reqwest::StatusCode::UNAUTHORIZED
127 && let Some(refreshed_token) =
128 crate::cloud::client::CloudClient::refresh_saved_token().await
129 {
130 active_token = refreshed_token;
131 resp = self
132 .post_embedding_with_transport_retry(&active_token, &body)
133 .await?;
134 status = resp.status();
135 }
136 for delay_ms in EMBEDDING_RETRY_DELAYS_MS {
137 if !retryable_embedding_status(status) {
138 break;
139 }
140 tokio::time::sleep(Duration::from_millis(*delay_ms)).await;
141 resp = self
142 .post_embedding_with_transport_retry(&active_token, &body)
143 .await?;
144 status = resp.status();
145 }
146 if !status.is_success() {
147 let body_text = resp.text().await.unwrap_or_default();
148 if status.as_u16() == 409
154 && let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&body_text)
155 && parsed.get("code").and_then(|c| c.as_str()) == Some("embed_cap_reached")
156 {
157 let cap = u32::try_from(
158 parsed
159 .get("cap")
160 .and_then(serde_json::Value::as_u64)
161 .unwrap_or(0),
162 )
163 .unwrap_or(u32::MAX);
164 let used = u32::try_from(
165 parsed
166 .get("used")
167 .and_then(serde_json::Value::as_u64)
168 .unwrap_or(0),
169 )
170 .unwrap_or(u32::MAX);
171 crate::activity_stream::record(
172 crate::activity_stream::ActivityPayload::EmbedCapReached { cap, used },
173 );
174 return Err(CoreError::EmbedCapReached { cap, used });
175 }
176 return Err(CoreError::Internal(format!(
177 "cloud embedding endpoint returned {status}; semantic recall will fall back to file-pattern and keyword matching"
178 )));
179 }
180
181 let json: serde_json::Value = resp
182 .json()
183 .await
184 .map_err(|e| CoreError::Internal(format!("cloud embedding decode error: {e}")))?;
185
186 let vectors = json
187 .get("vectors")
188 .and_then(|v| v.as_array())
189 .ok_or_else(|| CoreError::Internal("cloud embedding response missing vectors".into()))?
190 .iter()
191 .map(|vector| {
192 vector
193 .as_array()
194 .ok_or_else(|| {
195 CoreError::Internal("cloud embedding vector is not an array".into())
196 })
197 .map(|items| {
198 items
199 .iter()
200 .map(|n| n.as_f64().unwrap_or(0.0) as f32)
201 .collect::<Vec<f32>>()
202 })
203 })
204 .collect::<Result<Vec<Vec<f32>>, CoreError>>()?;
205 if vectors.len() != texts.len() {
206 return Err(CoreError::Internal(format!(
207 "cloud embedding response length mismatch: expected {}, got {}",
208 texts.len(),
209 vectors.len()
210 )));
211 }
212 Ok(vectors)
213 }
214
215 fn dim(&self) -> usize {
216 self.dim
217 }
218}