Skip to main content

difflore_core/context/embedding/
cloud.rs

1use async_trait::async_trait;
2use std::time::Duration;
3
4use crate::errors::CoreError;
5
6use super::{
7    DEFAULT_OPENAI_EMBEDDING_DIM, EMBEDDING_RETRY_DELAYS_MS, Embedder, embedding_http_client,
8    retryable_embedding_status,
9};
10
11/// Cloud-managed embedder. POSTs `{ texts: [..] }` to the cloud API's
12/// `/api/embeddings` endpoint, authenticating with the user's existing
13/// CLI session token (same `cloud-auth.db` row that powers
14/// `cloud::client::CloudClient`).
15///
16/// This is the happy path for the OSS Free tier — users don't have to
17/// manage an OpenAI key locally; the cloud forwards to its own configured
18/// embedding provider. Failures (network / 401 / 5xx) bubble up as
19/// `CoreError::Internal` so the caller can fall back to local SHA1 after
20/// retry rather than surfacing a hard error.
21pub struct CloudEmbedder {
22    base_url: String,
23    token: String,
24    model: String,
25    dim: usize,
26    client: reqwest::Client,
27}
28
29impl CloudEmbedder {
30    pub fn new(base_url: String, token: String) -> Self {
31        Self::with_model(
32            base_url,
33            token,
34            "text-embedding-3-small".to_owned(),
35            DEFAULT_OPENAI_EMBEDDING_DIM,
36        )
37    }
38
39    pub fn with_model(base_url: String, token: String, model: String, dim: usize) -> Self {
40        Self {
41            base_url,
42            token,
43            model,
44            dim,
45            client: embedding_http_client(),
46        }
47    }
48
49    pub(crate) fn endpoint(&self) -> String {
50        format!("{}/embeddings", self.base_url.trim_end_matches('/'))
51    }
52
53    async fn post_embedding(
54        &self,
55        token: &str,
56        body: &serde_json::Value,
57    ) -> Result<reqwest::Response, CoreError> {
58        self.client
59            .post(self.endpoint())
60            .bearer_auth(token)
61            .json(body)
62            .send()
63            .await
64            .map_err(|e| CoreError::Internal(format!("cloud embedding request failed: {e}")))
65    }
66
67    async fn post_embedding_with_transport_retry(
68        &self,
69        token: &str,
70        body: &serde_json::Value,
71    ) -> Result<reqwest::Response, CoreError> {
72        let mut last_error = String::new();
73        for attempt in 0..=EMBEDDING_RETRY_DELAYS_MS.len() {
74            match self.post_embedding(token, body).await {
75                Ok(resp) => return Ok(resp),
76                Err(error) => {
77                    last_error = error.to_string();
78                    if let Some(delay_ms) = EMBEDDING_RETRY_DELAYS_MS.get(attempt) {
79                        tokio::time::sleep(Duration::from_millis(*delay_ms)).await;
80                    }
81                }
82            }
83        }
84        Err(CoreError::Internal(format!(
85            "cloud embedding request failed after {} transport attempts: {last_error}",
86            EMBEDDING_RETRY_DELAYS_MS.len() + 1
87        )))
88    }
89}
90
91#[async_trait]
92impl Embedder for CloudEmbedder {
93    async fn embed(&self, text: &str) -> Result<Vec<f32>, CoreError> {
94        let single = vec![text.to_owned()];
95        let mut vectors = self.embed_batch(&single, None).await?;
96        return vectors.pop().ok_or_else(|| {
97            CoreError::Internal("cloud embedding response missing first vector".into())
98        });
99    }
100
101    async fn embed_batch(
102        &self,
103        texts: &[String],
104        rule_ids: Option<&[String]>,
105    ) -> Result<Vec<Vec<f32>>, CoreError> {
106        if texts.is_empty() {
107            return Ok(Vec::new());
108        }
109        let body = serde_json::json!({
110            "texts": texts,
111            "model": self.model,
112        });
113        let body = if let Some(rule_ids) = rule_ids {
114            let mut value = body;
115            value["rule_ids"] = serde_json::json!(rule_ids);
116            value
117        } else {
118            body
119        };
120        let mut active_token = self.token.clone();
121        let mut resp = self
122            .post_embedding_with_transport_retry(&active_token, &body)
123            .await?;
124
125        let mut status = resp.status();
126        if status == reqwest::StatusCode::UNAUTHORIZED
127            && let Some(refreshed_token) =
128                crate::cloud::client::CloudClient::refresh_saved_token().await
129        {
130            active_token = refreshed_token;
131            resp = self
132                .post_embedding_with_transport_retry(&active_token, &body)
133                .await?;
134            status = resp.status();
135        }
136        for delay_ms in EMBEDDING_RETRY_DELAYS_MS {
137            if !retryable_embedding_status(status) {
138                break;
139            }
140            tokio::time::sleep(Duration::from_millis(*delay_ms)).await;
141            resp = self
142                .post_embedding_with_transport_retry(&active_token, &body)
143                .await?;
144            status = resp.status();
145        }
146        if !status.is_success() {
147            let body_text = resp.text().await.unwrap_or_default();
148            // 409 with `embed_cap_reached` is the Free-tier rule cap. We
149            // surface it as a typed error so the caller can fall back to
150            // lexical retrieval for this single embed call AND record an
151            // activity event for doctor — `Internal(...)` would lose both
152            // signals.
153            if status.as_u16() == 409
154                && let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&body_text)
155                && parsed.get("code").and_then(|c| c.as_str()) == Some("embed_cap_reached")
156            {
157                let cap = u32::try_from(
158                    parsed
159                        .get("cap")
160                        .and_then(serde_json::Value::as_u64)
161                        .unwrap_or(0),
162                )
163                .unwrap_or(u32::MAX);
164                let used = u32::try_from(
165                    parsed
166                        .get("used")
167                        .and_then(serde_json::Value::as_u64)
168                        .unwrap_or(0),
169                )
170                .unwrap_or(u32::MAX);
171                crate::activity_stream::record(
172                    crate::activity_stream::ActivityPayload::EmbedCapReached { cap, used },
173                );
174                return Err(CoreError::EmbedCapReached { cap, used });
175            }
176            return Err(CoreError::Internal(format!(
177                "cloud embedding endpoint returned {status}; semantic recall will fall back to file-pattern and keyword matching"
178            )));
179        }
180
181        let json: serde_json::Value = resp
182            .json()
183            .await
184            .map_err(|e| CoreError::Internal(format!("cloud embedding decode error: {e}")))?;
185
186        let vectors = json
187            .get("vectors")
188            .and_then(|v| v.as_array())
189            .ok_or_else(|| CoreError::Internal("cloud embedding response missing vectors".into()))?
190            .iter()
191            .map(|vector| {
192                vector
193                    .as_array()
194                    .ok_or_else(|| {
195                        CoreError::Internal("cloud embedding vector is not an array".into())
196                    })
197                    .map(|items| {
198                        items
199                            .iter()
200                            .map(|n| n.as_f64().unwrap_or(0.0) as f32)
201                            .collect::<Vec<f32>>()
202                    })
203            })
204            .collect::<Result<Vec<Vec<f32>>, CoreError>>()?;
205        if vectors.len() != texts.len() {
206            return Err(CoreError::Internal(format!(
207                "cloud embedding response length mismatch: expected {}, got {}",
208                texts.len(),
209                vectors.len()
210            )));
211        }
212        Ok(vectors)
213    }
214
215    fn dim(&self) -> usize {
216        self.dim
217    }
218}