earl 0.5.2

AI-safe CLI for AI agents
use anyhow::{Context, Result};
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;

use crate::config::RemoteSearchConfig;
use crate::secrets::SecretManager;
use crate::secrets::store::require_secret;

use super::cosine_similarity;
use super::index::{SearchDocument, SearchHit};

#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingResponse {
    data: Vec<OpenAiEmbeddingItem>,
}

#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingItem {
    embedding: Vec<f32>,
}

#[derive(Debug, Deserialize)]
struct RerankResponse {
    data: Vec<RerankItem>,
}

#[derive(Debug, Deserialize)]
struct RerankItem {
    index: usize,
    score: f32,
}

pub async fn search_remote(
    query: &str,
    documents: &[SearchDocument],
    cfg: &RemoteSearchConfig,
    secrets: &SecretManager,
) -> Result<Option<Vec<SearchHit>>> {
    if !cfg.enabled {
        return Ok(None);
    }

    let base_url = match &cfg.base_url {
        Some(url) => url.trim_end_matches('/').to_string(),
        None => return Ok(None),
    };

    let api_key = match &cfg.api_key_secret {
        Some(secret_key) => require_secret(secrets.store(), secrets.resolvers(), secret_key)?,
        None => return Ok(None),
    };

    let client = Client::builder()
        .timeout(std::time::Duration::from_millis(cfg.timeout_ms))
        .redirect(reqwest::redirect::Policy::none())
        .build()
        .context("failed building remote search client")?;

    let mut texts = Vec::with_capacity(documents.len() + 1);
    texts.push(query.to_string());
    texts.extend(documents.iter().map(|d| d.text.clone()));

    let embeddings_url = format!("{base_url}{}", cfg.embeddings_path);
    let embed_response = client
        .post(&embeddings_url)
        .bearer_auth(&api_key)
        .json(&json!({
            "model": "text-embedding-3-small",
            "input": texts,
        }))
        .send()
        .await
        .with_context(|| format!("remote embeddings request failed: {embeddings_url}"))?;

    if !embed_response.status().is_success() {
        return Ok(None);
    }

    let embeddings = embed_response
        .json::<OpenAiEmbeddingResponse>()
        .await
        .context("failed decoding remote embeddings response")?;

    if embeddings.data.len() != documents.len() + 1 {
        return Ok(None);
    }

    let query_embedding = &embeddings.data[0].embedding;
    let mut scored: Vec<(usize, f32)> = embeddings
        .data
        .iter()
        .enumerate()
        .skip(1)
        .map(|(idx, emb)| (idx - 1, cosine_similarity(query_embedding, &emb.embedding)))
        .collect();

    scored.sort_by(|a, b| b.1.total_cmp(&a.1));

    let rerank_url = format!("{base_url}{}", cfg.rerank_path);
    let rerank_docs: Vec<String> = scored
        .iter()
        .take(40)
        .map(|(idx, _)| documents[*idx].text.clone())
        .collect();

    let reranked = client
        .post(&rerank_url)
        .bearer_auth(&api_key)
        .json(&json!({
            "query": query,
            "documents": rerank_docs,
            "top_n": 10,
        }))
        .send()
        .await;

    let mut hits = Vec::new();

    if let Ok(resp) = reranked
        && resp.status().is_success()
        && let Ok(parsed) = resp.json::<RerankResponse>().await
    {
        let top_indices: Vec<usize> = scored.iter().take(40).map(|(idx, _)| *idx).collect();
        for item in parsed.data.into_iter().take(10) {
            if let Some(doc_idx) = top_indices.get(item.index) {
                let doc = &documents[*doc_idx];
                hits.push(SearchHit {
                    key: doc.key.clone(),
                    score: item.score,
                    summary: doc.summary.clone(),
                });
            }
        }
        if !hits.is_empty() {
            return Ok(Some(hits));
        }
    }

    for (idx, score) in scored.into_iter().take(10) {
        let doc = &documents[idx];
        hits.push(SearchHit {
            key: doc.key.clone(),
            score,
            summary: doc.summary.clone(),
        });
    }

    Ok(Some(hits))
}