ll-core 0.1.2

Hybrid semantic search engine: BM25 + vector + PageRank + cross-encoder reranking over SQLite, with on-device ONNX inference
use std::collections::HashMap;

pub const RRF_K: f64 = 5.0;

pub const PRF_ALPHA: f32 = 0.7;
pub const PRF_K: usize = 3;

#[derive(Debug, Clone, Copy)]
pub struct PrfParams {
    pub alpha: f32,
    pub beta: f32,
    pub k: usize,
}

impl Default for PrfParams {
    fn default() -> Self {
        Self { alpha: PRF_ALPHA, beta: 1.0 - PRF_ALPHA, k: PRF_K }
    }
}

pub struct FtsConfig {
    pub fts_table: &'static str,
    pub content_table: &'static str,
    pub items_table: &'static str,
    pub id_column: &'static str,
    pub path_column: &'static str,
    pub bm25_weights: &'static str,
}

pub const VAULT_FTS: FtsConfig = FtsConfig {
    fts_table: "notes_fts",
    content_table: "notes_content",
    items_table: "notes",
    id_column: "id",
    path_column: "path",
    bm25_weights: "10.0, 5.0, 1.0",
};

pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

pub fn fts_escape(text: &str) -> String {
    text.split_whitespace()
        .filter(|t| !t.is_empty())
        .map(|t| format!("\"{}\"", t.replace('"', "\"\"")))
        .collect::<Vec<_>>()
        .join(" ")
}

pub fn fts_bm25_query(
    conn: &rusqlite::Connection,
    query: &str,
    limit: usize,
    config: &FtsConfig,
) -> Vec<(i64, String, f64)> {
    let escaped = fts_escape(query);
    if escaped.is_empty() {
        return Vec::new();
    }

    let sql = format!(
        "SELECT nc.{id}, n.{path}, bm25({fts}, {weights}) as score
         FROM {fts}
         JOIN {content} nc ON nc.{id} = {fts}.rowid
         JOIN {items} n ON n.{id} = nc.{id}
         WHERE {fts} MATCH ?1
         ORDER BY score
         LIMIT ?2",
        id = config.id_column,
        path = config.path_column,
        fts = config.fts_table,
        content = config.content_table,
        items = config.items_table,
        weights = config.bm25_weights,
    );

    let mut stmt = match conn.prepare(&sql) {
        Ok(s) => s,
        Err(_) => return Vec::new(),
    };

    let rows = match stmt.query_map(rusqlite::params![escaped, limit as i64], |row| {
        Ok((
            row.get::<_, i64>(0)?,
            row.get::<_, String>(1)?,
            row.get::<_, f64>(2)?,
        ))
    }) {
        Ok(r) => r,
        Err(_) => return Vec::new(),
    };

    rows.filter_map(|r| r.ok()).collect()
}

pub fn add_ranked_rrf<'a>(rrf_scores: &mut HashMap<String, f64>, items: impl Iterator<Item = &'a str>) {
    for (rank, path) in items.enumerate() {
        *rrf_scores.entry(path.to_string()).or_default() += 1.0 / (RRF_K + rank as f64 + 1.0);
    }
}

pub fn finalize_rrf(rrf_scores: HashMap<String, f64>, top_n: usize) -> Vec<(String, f64)> {
    let mut results: Vec<(String, f64)> = rrf_scores.into_iter().collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    results.truncate(top_n);
    results
}

pub fn collect_seeds(
    vec_scored: &[(String, f64)],
    fts_results: &[(i64, String, f64)],
) -> Vec<String> {
    use std::collections::HashSet;
    let mut seeds: Vec<String> = Vec::new();
    let mut seen = HashSet::new();
    for (path, _) in vec_scored.iter().take(10) {
        if seen.insert(path.clone()) {
            seeds.push(path.clone());
        }
    }
    for (_, path, _) in fts_results.iter().take(10) {
        if seen.insert(path.clone()) {
            seeds.push(path.clone());
        }
    }
    seeds
}

pub fn rocchio_prf_with(
    query_vec: &[f32],
    top_results: &[(String, f64)],
    all_embeddings: &[(i64, String, Vec<f32>)],
    params: &PrfParams,
) -> Vec<(String, f64)> {
    let dim = query_vec.len();
    let emb_map: HashMap<&str, &Vec<f32>> = all_embeddings
        .iter()
        .map(|(_, path, emb)| (path.as_str(), emb))
        .collect();

    let feedback_vecs: Vec<&Vec<f32>> = top_results
        .iter()
        .take(params.k)
        .filter_map(|(path, _)| emb_map.get(path.as_str()).copied())
        .collect();

    if feedback_vecs.is_empty() {
        return Vec::new();
    }

    let mut expanded = vec![0.0f32; dim];
    for d in 0..dim {
        let fb_mean: f32 = feedback_vecs.iter().map(|v| v[d]).sum::<f32>() / feedback_vecs.len() as f32;
        expanded[d] = params.alpha * query_vec[d] + params.beta * fb_mean;
    }

    let norm: f32 = expanded.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for d in 0..dim {
            expanded[d] /= norm;
        }
    }

    let mut prf_scored: Vec<(String, f64)> = all_embeddings
        .iter()
        .map(|(_, path, emb)| (path.clone(), cosine(&expanded, emb) as f64))
        .collect();
    prf_scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    prf_scored.truncate(30);
    prf_scored
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_identical() {
        let v = vec![1.0f32, 0.0, 0.0];
        let sim = cosine(&v, &v);
        assert!((sim - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![0.0f32, 1.0, 0.0];
        let sim = cosine(&a, &b);
        assert!(sim.abs() < 1e-5);
    }

    #[test]
    fn test_fts_escape() {
        assert_eq!(fts_escape("hello world"), "\"hello\" \"world\"");
        assert_eq!(fts_escape(""), "");
        assert_eq!(fts_escape("  "), "");
    }

    #[test]
    fn test_rrf_basic() {
        let mut scores = HashMap::new();
        add_ranked_rrf(&mut scores, ["a", "b", "c"].iter().copied());
        assert!(scores["a"] > scores["b"]);
        assert!(scores["b"] > scores["c"]);
    }

    #[test]
    fn test_finalize_rrf_truncates() {
        let mut scores = HashMap::new();
        for i in 0..20 {
            scores.insert(format!("doc_{}", i), 1.0 / (i + 1) as f64);
        }
        let results = finalize_rrf(scores, 5);
        assert_eq!(results.len(), 5);
    }
}