earl 0.5.2

AI-safe CLI for AI agents
use anyhow::{Context, Result};
use fastembed::{
    EmbeddingModel, RerankInitOptions, RerankerModel, TextEmbedding, TextInitOptions, TextRerank,
};

use crate::config::SearchConfig;

use super::cosine_similarity;
use super::index::{SearchDocument, SearchHit};

/// Runs a two-stage local semantic search over `documents`.
///
/// Stage one embeds the query and every document's `text` with the configured
/// embedding model, scores each document by cosine similarity to the query and
/// keeps the `cfg.top_k` best candidates. Stage two hands those candidates to
/// the configured reranker and returns at most `cfg.rerank_k` hits, in the
/// order the reranker produced them, each carrying the reranker's score.
///
/// Returns an empty list without loading any model when `documents` is empty
/// or when `cfg.top_k` / `cfg.rerank_k` is zero, because no hit could be
/// produced in those cases.
///
/// # Errors
///
/// Returns an error when a configured model name cannot be converted into a
/// known model, when either model fails to initialize, or when embedding or
/// reranking fails.
pub fn search_local(
    query: &str,
    documents: &[SearchDocument],
    cfg: &SearchConfig,
) -> Result<Vec<SearchHit>> {
    // Nothing to search, or limits that can never yield a hit: skip the
    // (expensive) model initialization and embedding work entirely.
    if documents.is_empty() || cfg.top_k == 0 || cfg.rerank_k == 0 {
        return Ok(Vec::new());
    }

    let embedding_model: EmbeddingModel = cfg
        .local
        .embedding_model
        .clone()
        .try_into()
        .map_err(|err: String| anyhow::anyhow!(err))
        .with_context(|| format!("invalid embedding model `{}`", cfg.local.embedding_model))?;

    let reranker_model: RerankerModel = cfg
        .local
        .reranker_model
        .clone()
        .try_into()
        .map_err(|err: String| anyhow::anyhow!(err))
        .with_context(|| format!("invalid reranker model `{}`", cfg.local.reranker_model))?;

    // Both models are initialized up front so configuration or download
    // problems surface before any embedding work is done.
    let mut embedder = TextEmbedding::try_new(
        TextInitOptions::new(embedding_model).with_show_download_progress(false),
    )
    .context("failed to initialize local embedding model")?;

    let mut reranker = TextRerank::try_new(
        RerankInitOptions::new(reranker_model).with_show_download_progress(false),
    )
    .context("failed to initialize local reranker model")?;

    let query_embedding = embedder
        .embed(vec![query.to_string()], None)
        .context("failed generating query embedding")?
        .into_iter()
        .next()
        .ok_or_else(|| anyhow::anyhow!("embedding model returned empty query embedding"))?;

    // The owned texts are moved into the embedder; the few texts needed for
    // reranking are cloned from `documents` later, so the corpus is copied
    // only once instead of twice.
    let doc_texts: Vec<String> = documents.iter().map(|d| d.text.clone()).collect();
    let doc_embeddings = embedder
        .embed(doc_texts, None)
        .context("failed generating document embeddings")?;

    let mut scored: Vec<(usize, f32)> = doc_embeddings
        .iter()
        .enumerate()
        .map(|(idx, emb)| (idx, cosine_similarity(&query_embedding, emb)))
        .collect();

    // Highest similarity first. `total_cmp` gives a total order even for NaN,
    // and the stable sort keeps equally scored documents in input order.
    scored.sort_by(|a, b| b.1.total_cmp(&a.1));

    // `take` already stops at the end of the list, so no explicit clamp of
    // `top_k` to the number of documents is required.
    let top_indices: Vec<usize> = scored
        .into_iter()
        .take(cfg.top_k)
        .map(|(idx, _)| idx)
        .collect();

    let rerank_docs: Vec<String> = top_indices
        .iter()
        .map(|&idx| documents[idx].text.clone())
        .collect();

    let reranked = reranker
        .rerank(query.to_string(), rerank_docs, false, None)
        .context("local reranking failed")?;

    // `rr.index` is a position within `rerank_docs`; translate it back to the
    // position in `documents` via `top_indices`. Results whose index falls
    // outside the candidate list are dropped, but still count toward
    // `rerank_k` because the limit is applied first.
    let hits = reranked
        .into_iter()
        .take(cfg.rerank_k)
        .filter_map(|rr| {
            let doc = &documents[*top_indices.get(rr.index)?];
            Some(SearchHit {
                key: doc.key.clone(),
                score: rr.score,
                summary: doc.summary.clone(),
            })
        })
        .collect();

    Ok(hits)
}