difflore-core 0.1.0

Core library for the difflore CLI — rule store, retrieval, MCP server, hooks, cloud sync. Not intended for direct use; depend on `difflore-cli` instead.
Documentation
use async_trait::async_trait;
use std::time::Duration;

use crate::errors::CoreError;

use super::{
    DEFAULT_OPENAI_EMBEDDING_DIM, EMBEDDING_RETRY_DELAYS_MS, Embedder, embedding_http_client,
    retryable_embedding_status,
};

/// Cloud-managed embedder. POSTs `{ texts: [..], model }` (plus `rule_ids`
/// when the caller supplies them) to the cloud API's embeddings endpoint,
/// authenticating with the user's existing CLI session token (same
/// `cloud-auth.db` row that powers `cloud::client::CloudClient`).
///
/// The request URL is `{base_url}/embeddings`; `base_url` presumably already
/// carries the `/api` prefix — confirm against the call sites that construct
/// this type.
///
/// This is the happy path for the OSS Free tier — users don't have to
/// manage an OpenAI key locally; the cloud forwards to its own configured
/// embedding provider. Failures (network / 401 / 5xx) bubble up as
/// `CoreError::Internal` so the caller can fall back to local SHA1 after
/// retry rather than surfacing a hard error. The one exception is a 409
/// response whose body carries `code: "embed_cap_reached"`, which is
/// returned as the typed `CoreError::EmbedCapReached`.
pub struct CloudEmbedder {
    /// Root URL the `/embeddings` path is appended to; trailing slashes are
    /// trimmed when the endpoint is built.
    base_url: String,
    /// Bearer token sent with each request. On a 401 the saved token is
    /// refreshed and the refreshed value is used for the rest of that call;
    /// this field itself is not updated.
    token: String,
    /// Model name sent as the `model` field of the request body.
    model: String,
    /// Dimension reported by `Embedder::dim`. Not checked against the
    /// vectors returned by the cloud in this file.
    dim: usize,
    /// HTTP client obtained from `embedding_http_client()`.
    client: reqwest::Client,
}

impl CloudEmbedder {
    /// Creates an embedder that uses the `text-embedding-3-small` model with
    /// `DEFAULT_OPENAI_EMBEDDING_DIM` as its reported dimension.
    pub fn new(base_url: String, token: String) -> Self {
        let default_model = "text-embedding-3-small".to_owned();
        Self::with_model(base_url, token, default_model, DEFAULT_OPENAI_EMBEDDING_DIM)
    }

    /// Creates an embedder for an explicit `model` / `dim` pair. The HTTP
    /// client comes from `embedding_http_client()`.
    pub fn with_model(base_url: String, token: String, model: String, dim: usize) -> Self {
        let client = embedding_http_client();
        Self {
            base_url,
            token,
            model,
            dim,
            client,
        }
    }

    /// Returns the URL requests are POSTed to: `base_url` with any trailing
    /// `/` removed, followed by `/embeddings`.
    pub(crate) fn endpoint(&self) -> String {
        let root = self.base_url.trim_end_matches('/');
        format!("{}/embeddings", root)
    }

    /// Sends one POST with `body` as JSON and `token` as the bearer
    /// credential. Only transport-level failures become an error here; any
    /// HTTP status is handed back to the caller inside the response.
    async fn post_embedding(
        &self,
        token: &str,
        body: &serde_json::Value,
    ) -> Result<reqwest::Response, CoreError> {
        let request = self
            .client
            .post(self.endpoint())
            .bearer_auth(token)
            .json(body);
        match request.send().await {
            Ok(response) => Ok(response),
            Err(e) => Err(CoreError::Internal(format!(
                "cloud embedding request failed: {e}"
            ))),
        }
    }

    /// Calls [`Self::post_embedding`] until it yields a response, sleeping
    /// for the next entry of `EMBEDDING_RETRY_DELAYS_MS` after each
    /// transport failure. That makes one attempt more than there are delays;
    /// once the delays are used up the last failure is reported.
    async fn post_embedding_with_transport_retry(
        &self,
        token: &str,
        body: &serde_json::Value,
    ) -> Result<reqwest::Response, CoreError> {
        let mut backoff = EMBEDDING_RETRY_DELAYS_MS.iter();
        loop {
            let last_error = match self.post_embedding(token, body).await {
                Ok(response) => return Ok(response),
                Err(failure) => failure.to_string(),
            };
            // No delay left means this was the final attempt.
            let Some(pause_ms) = backoff.next() else {
                return Err(CoreError::Internal(format!(
                    "cloud embedding request failed after {} transport attempts: {last_error}",
                    EMBEDDING_RETRY_DELAYS_MS.len() + 1
                )));
            };
            tokio::time::sleep(Duration::from_millis(*pause_ms)).await;
        }
    }
}

#[async_trait]
impl Embedder for CloudEmbedder {
    /// Embeds a single text by sending it through `embed_batch` as a
    /// one-element batch with no rule ids.
    ///
    /// # Errors
    /// Propagates every error from `embed_batch`, and returns
    /// `CoreError::Internal` if the batch result contains no vector.
    async fn embed(&self, text: &str) -> Result<Vec<f32>, CoreError> {
        let single = [text.to_owned()];
        let mut vectors = self.embed_batch(&single, None).await?;
        vectors.pop().ok_or_else(|| {
            CoreError::Internal("cloud embedding response missing first vector".into())
        })
    }

    /// Embeds `texts` in one request and returns one vector per input, in
    /// the order the cloud returned them.
    ///
    /// An empty `texts` slice returns an empty result without any network
    /// call. When `rule_ids` is `Some`, it is forwarded as the `rule_ids`
    /// field of the request body.
    ///
    /// Retry behaviour: transport failures are retried inside
    /// `post_embedding_with_transport_retry`; a 401 triggers a single
    /// attempt to refresh the saved token and resend; afterwards, statuses
    /// accepted by `retryable_embedding_status` are retried once per entry
    /// of `EMBEDDING_RETRY_DELAYS_MS`.
    ///
    /// # Errors
    /// * `CoreError::EmbedCapReached` — the final response is a 409 whose
    ///   JSON body has `code == "embed_cap_reached"`; an activity event is
    ///   recorded before returning.
    /// * `CoreError::Internal` — transport failure, any other non-success
    ///   status, an undecodable body, a missing or malformed `vectors`
    ///   array, a non-numeric vector component, or a vector count that
    ///   differs from `texts.len()`.
    async fn embed_batch(
        &self,
        texts: &[String],
        rule_ids: Option<&[String]>,
    ) -> Result<Vec<Vec<f32>>, CoreError> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let mut body = serde_json::json!({
            "texts": texts,
            "model": self.model,
        });
        if let Some(rule_ids) = rule_ids {
            body["rule_ids"] = serde_json::json!(rule_ids);
        }

        let mut active_token = self.token.clone();
        let mut resp = self
            .post_embedding_with_transport_retry(&active_token, &body)
            .await?;

        let mut status = resp.status();
        // A 401 gets one chance: refresh the saved token and resend. If the
        // refresh yields nothing, the 401 falls through to the error path.
        if status == reqwest::StatusCode::UNAUTHORIZED
            && let Some(refreshed_token) =
                crate::cloud::client::CloudClient::refresh_saved_token().await
        {
            active_token = refreshed_token;
            resp = self
                .post_embedding_with_transport_retry(&active_token, &body)
                .await?;
            status = resp.status();
        }
        // Status-level retries: stop as soon as the status is not retryable
        // (which includes success), otherwise back off and resend.
        for delay_ms in EMBEDDING_RETRY_DELAYS_MS {
            if !retryable_embedding_status(status) {
                break;
            }
            tokio::time::sleep(Duration::from_millis(*delay_ms)).await;
            resp = self
                .post_embedding_with_transport_retry(&active_token, &body)
                .await?;
            status = resp.status();
        }
        if !status.is_success() {
            let body_text = resp.text().await.unwrap_or_default();
            // 409 with `embed_cap_reached` is the Free-tier rule cap. We
            // surface it as a typed error so the caller can fall back to
            // lexical retrieval for this single embed call AND record an
            // activity event for doctor — `Internal(...)` would lose both
            // signals.
            if status == reqwest::StatusCode::CONFLICT
                && let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&body_text)
                && parsed.get("code").and_then(|c| c.as_str()) == Some("embed_cap_reached")
            {
                // Missing or non-integer counters read as 0; values too
                // large for `u32` saturate at `u32::MAX`.
                let read_count = |key: &str| -> u32 {
                    let raw = parsed
                        .get(key)
                        .and_then(serde_json::Value::as_u64)
                        .unwrap_or(0);
                    u32::try_from(raw).unwrap_or(u32::MAX)
                };
                let cap = read_count("cap");
                let used = read_count("used");
                crate::activity_stream::record(
                    crate::activity_stream::ActivityPayload::EmbedCapReached { cap, used },
                );
                return Err(CoreError::EmbedCapReached { cap, used });
            }
            return Err(CoreError::Internal(format!(
                "cloud embedding endpoint returned {status}; semantic recall will fall back to file-pattern and keyword matching"
            )));
        }

        let json: serde_json::Value = resp
            .json()
            .await
            .map_err(|e| CoreError::Internal(format!("cloud embedding decode error: {e}")))?;

        let vectors = json
            .get("vectors")
            .and_then(|v| v.as_array())
            .ok_or_else(|| CoreError::Internal("cloud embedding response missing vectors".into()))?
            .iter()
            .map(|vector| {
                vector
                    .as_array()
                    .ok_or_else(|| {
                        CoreError::Internal("cloud embedding vector is not an array".into())
                    })
                    .and_then(|items| {
                        items
                            .iter()
                            .map(|n| {
                                // Reject non-numeric components instead of
                                // silently substituting 0.0, which would
                                // yield a corrupted embedding. The f64 -> f32
                                // narrowing is intentional: vectors are
                                // stored as `f32`.
                                n.as_f64().map(|component| component as f32).ok_or_else(|| {
                                    CoreError::Internal(
                                        "cloud embedding vector contains a non-numeric component"
                                            .into(),
                                    )
                                })
                            })
                            .collect::<Result<Vec<f32>, CoreError>>()
                    })
            })
            .collect::<Result<Vec<Vec<f32>>, CoreError>>()?;
        if vectors.len() != texts.len() {
            return Err(CoreError::Internal(format!(
                "cloud embedding response length mismatch: expected {}, got {}",
                texts.len(),
                vectors.len()
            )));
        }
        Ok(vectors)
    }

    /// Returns the embedding dimension this embedder was configured with.
    fn dim(&self) -> usize {
        self.dim
    }
}