talon-core 0.4.2

Core retrieval engine for Talon: hybrid search (BM25 + semantic + reranker), indexing, and graph-aware ranking over markdown corpora.
Documentation
//! Graph refinement for search results.

use std::collections::{BTreeMap, BTreeSet};

use rusqlite::Connection;

use crate::config::TalonConfig;
use crate::graph::{GraphRankInput, GraphSnapshot, load_graph_snapshot, rank_related};
use crate::numeric::count_u32;
use crate::search::GraphSearchDiagnostics;
use crate::search::types::{RawSearchResult, SearchScores};
use crate::search::{Direction, SearchInput, SearchMode};

use super::search::ScoredRawSearchResult;

/// Minimum retrieval score a result needs before it may act as a graph seed.
const SEED_MIN_SCORE: f64 = 0.62;

/// Leading results eligible as seeds when the search asks for related notes.
const GRAPH_RELATED_SEED_LIMIT: usize = 8;
/// Leading results eligible as seeds when related-note search is off.
const GRAPH_HYBRID_SEED_LIMIT: usize = 1;

/// Fraction of `seed_score * candidate.score` added to a result that was already retrieved.
const GRAPH_EXISTING_BLEND: f64 = 0.04;
/// Fraction of `seed_score * candidate.score` used as the (capped) score of a graph-only result.
const GRAPH_ONLY_BLEND: f64 = 0.025;

/// Maximum number of graph-only results appended in one refinement pass.
const GRAPH_ONLY_LIMIT: usize = 4;
/// Maximum number of graph-only results sharing one community id (`None` counts as one bucket).
const GRAPH_PER_COMMUNITY_LIMIT: usize = 2;

/// Refines retrieval results using the persisted note graph.
///
/// Up to `graph_seed_limit(input)` leading entries of `scored` whose score is at least
/// `SEED_MIN_SCORE` act as seeds. Entries are taken in list order, so this presumably
/// expects `scored` to be sorted best-first — confirm with callers. For each seed,
/// `rank_related` is asked for up to 8 related notes. Each returned candidate
/// contributes `seed_score * candidate.score`:
///
/// * if the candidate is already in `scored`, its score is raised by
///   `contribution * GRAPH_EXISTING_BLEND`;
/// * otherwise it may be appended as a graph-only result (see `graph_only_result`).
///   Appending is limited to `GRAPH_ONLY_LIMIT` results in total and
///   `GRAPH_PER_COMMUNITY_LIMIT` per `community_id`; notes without a community share
///   one bucket.
///
/// `scored` is updated in place. Graph-only results are appended at the end, so callers
/// must re-sort if they need ranked order.
///
/// Returns `None` when:
/// * refinement is disabled for `input`;
/// * the graph snapshot fails to load (the error is discarded);
/// * the snapshot has no nodes or no edges;
/// * nothing was boosted or added.
///
/// Otherwise returns `GraphSearchDiagnostics`:
/// * `boosted_results` counts boost events; an existing result reached from several
///   seeds is counted once per seed.
/// * `expanded_results` counts appended graph-only results.
/// * `score_contribution` sums every boost plus every appended graph-only score.
pub(super) fn refine_graph_results(
    conn: &Connection,
    input: &SearchInput,
    config: Option<&TalonConfig>,
    scored: &mut Vec<ScoredRawSearchResult>,
) -> Option<GraphSearchDiagnostics> {
    if !graph_refinement_enabled(input) {
        return None;
    }
    // Best-effort: without a usable graph the retrieval results are left untouched.
    let snapshot = load_graph_snapshot(conn).ok()?;
    if snapshot.nodes.is_empty() || snapshot.edges.is_empty() {
        return None;
    }

    // Traversal settings do not depend on the seed, so compute them once.
    let depth = input.depth.clamp(1, crate::constants::RELATED_MAX_DEPTH);
    let priorities = scope_priorities(config);

    // Index of every pre-existing result by path. It doubles as the "already retrieved"
    // set: a candidate found here is boosted and never duplicated as graph-only.
    let by_path = scored
        .iter()
        .enumerate()
        .map(|(index, result)| (result.raw.path.clone(), index))
        .collect::<BTreeMap<_, _>>();

    let mut boosted_results = 0_u32;
    let mut score_contribution = 0.0_f64;
    let mut graph_only = Vec::new();
    let mut graph_only_paths = BTreeSet::new();
    let mut community_counts: BTreeMap<Option<u32>, usize> = BTreeMap::new();

    // Collected up front because `scored` is mutated below; seed scores are therefore
    // the pre-boost values.
    let seeds = scored
        .iter()
        .take(graph_seed_limit(input))
        .filter(|result| result.raw.score >= SEED_MIN_SCORE)
        .map(|result| (result.raw.path.clone(), result.raw.score))
        .collect::<Vec<_>>();
    for (seed_path, seed_score) in seeds {
        let ranked = rank_related(
            &snapshot,
            &GraphRankInput {
                source_path: seed_path,
                direction: Direction::Both,
                depth,
                // Related notes requested per seed.
                limit: 8,
                scope_priorities: priorities.clone(),
            },
        );
        for candidate in ranked {
            let contribution = seed_score * candidate.score;
            if let Some(&index) = by_path.get(&candidate.vault_path) {
                let boost = contribution * GRAPH_EXISTING_BLEND;
                scored[index].raw.score += boost;
                boosted_results = boosted_results.saturating_add(1);
                score_contribution += boost;
                continue;
            }
            // Once the graph-only quota is full, later candidates can still boost
            // existing results (handled above) but are not appended.
            if graph_only.len() >= GRAPH_ONLY_LIMIT
                || graph_only_paths.contains(&candidate.vault_path)
            {
                continue;
            }
            let community = snapshot
                .nodes
                .get(&candidate.vault_path)
                .and_then(|node| node.community_id);
            let count = community_counts.entry(community).or_default();
            if *count >= GRAPH_PER_COMMUNITY_LIMIT {
                continue;
            }
            let Some(raw) = graph_only_result(&snapshot, &candidate.vault_path, contribution)
            else {
                continue;
            };
            // Only candidates that actually become results consume community quota.
            *count = count.saturating_add(1);
            graph_only_paths.insert(candidate.vault_path);
            score_contribution += raw.raw.score;
            graph_only.push(raw);
        }
    }

    let expanded_results = count_u32(graph_only.len());
    scored.extend(graph_only);
    (boosted_results > 0 || expanded_results > 0).then_some(GraphSearchDiagnostics {
        boosted_results,
        expanded_results,
        score_contribution,
    })
}

/// Reports whether graph refinement applies to `input`.
///
/// Refinement runs for every related-note search and for every hybrid-mode search.
const fn graph_refinement_enabled(input: &SearchInput) -> bool {
    if input.related {
        return true;
    }
    matches!(input.mode, SearchMode::Hybrid)
}

/// Number of leading results eligible to seed graph traversal.
///
/// Related-note searches use `GRAPH_RELATED_SEED_LIMIT`; all other searches use
/// `GRAPH_HYBRID_SEED_LIMIT`.
const fn graph_seed_limit(input: &SearchInput) -> usize {
    if !input.related {
        return GRAPH_HYBRID_SEED_LIMIT;
    }
    GRAPH_RELATED_SEED_LIMIT
}

/// Maps each configured scope name to its priority.
///
/// Returns an empty map when no configuration is supplied.
fn scope_priorities(
    config: Option<&TalonConfig>,
) -> BTreeMap<String, crate::config::ScopePriority> {
    let Some(cfg) = config else {
        return BTreeMap::new();
    };
    cfg.scopes
        .iter()
        .map(|(scope_name, scope)| (scope_name.clone(), scope.priority))
        .collect()
}

/// Builds a synthetic search result for a note reached only through the graph.
///
/// The note's title, tags and aliases are copied from its snapshot node. The snippet is
/// empty and no semantic span is set. The score is `contribution * GRAPH_ONLY_BLEND`,
/// capped at `GRAPH_ONLY_MAX_SCORE`. It is stored as `score`, as `raw_score` and as the
/// `hybrid` component of `scores`.
///
/// Returns `None` when `path` has no node in `snapshot`, or when `contribution` is NaN.
/// The NaN guard is needed because `f64::min` returns its non-NaN operand. Without it, a
/// NaN contribution would be silently promoted to the maximum graph-only score.
fn graph_only_result(
    snapshot: &GraphSnapshot,
    path: &str,
    contribution: f64,
) -> Option<ScoredRawSearchResult> {
    /// Upper bound on the score of a graph-only result.
    const GRAPH_ONLY_MAX_SCORE: f64 = 0.72;

    let node = snapshot.nodes.get(path)?;
    if contribution.is_nan() {
        return None;
    }
    let score = (contribution * GRAPH_ONLY_BLEND).min(GRAPH_ONLY_MAX_SCORE);
    Some(ScoredRawSearchResult {
        raw: RawSearchResult {
            path: path.to_string(),
            title: node.title.clone(),
            tags: node.tags.clone(),
            aliases: node.aliases.clone(),
            snippet: String::new(),
            score,
            scores: SearchScores {
                hybrid: Some(score),
                ..SearchScores::default()
            },
            semantic_heading: None,
            semantic_char_start: None,
            semantic_char_end: None,
        },
        raw_score: score,
    })
}

#[cfg(test)]
mod tests {
    use rusqlite::{Connection, params};

    use crate::indexing::migrations::run_migrations;
    use crate::query::search::ScoredRawSearchResult;
    use crate::search::input::{SearchInput, SearchMode};
    use crate::search::types::{RawSearchResult, SearchScores};

    use super::refine_graph_results;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Opens a migrated in-memory database holding the given graph nodes and
    /// `(from, to, weight)` edges, inserted in the order given.
    fn graph_db(
        nodes: &[&str],
        edges: &[(&str, &str, u32)],
    ) -> Result<Connection, Box<dyn std::error::Error>> {
        let mut conn = Connection::open_in_memory()?;
        run_migrations(&mut conn)?;
        for &node in nodes {
            insert_graph_node(&conn, node)?;
        }
        for &(from_path, to_path, weight) in edges {
            insert_graph_edge(&conn, from_path, to_path, weight)?;
        }
        Ok(conn)
    }

    /// Default search input with related-note expansion switched on.
    fn related_input() -> SearchInput {
        SearchInput {
            related: true,
            ..SearchInput::default()
        }
    }

    #[test]
    fn graph_refinement_adds_bounded_graph_only_candidate() -> TestResult {
        let conn = graph_db(&["Seed.md", "Neighbor.md"], &[("Seed.md", "Neighbor.md", 2)])?;
        let mut results = vec![hit("Seed.md", 0.9)];

        let diagnostics = refine_graph_results(&conn, &related_input(), None, &mut results)
            .ok_or_else(|| std::io::Error::other("graph refinement should report expansion"))?;

        assert_eq!(results.len(), 2);
        assert_eq!(results[1].raw.path, "Neighbor.md");
        assert_eq!(diagnostics.expanded_results, 1);
        assert_eq!(diagnostics.boosted_results, 0);
        assert!(diagnostics.score_contribution > 0.0);
        Ok(())
    }

    #[test]
    fn graph_refinement_boosts_existing_neighbor() -> TestResult {
        let conn = graph_db(&["Seed.md", "Neighbor.md"], &[("Seed.md", "Neighbor.md", 2)])?;
        let mut results = vec![hit("Seed.md", 0.9), hit("Neighbor.md", 0.2)];

        let diagnostics = refine_graph_results(&conn, &related_input(), None, &mut results)
            .ok_or_else(|| std::io::Error::other("graph refinement should report boost"))?;

        assert!(results[1].raw.score > 0.2);
        assert_eq!(diagnostics.expanded_results, 0);
        assert_eq!(diagnostics.boosted_results, 1);
        assert!(diagnostics.score_contribution > 0.0);
        Ok(())
    }

    #[test]
    fn graph_refinement_uses_snapshot_in_fast_hybrid_mode() -> TestResult {
        let conn = graph_db(&["Seed.md", "Neighbor.md"], &[("Seed.md", "Neighbor.md", 2)])?;
        let fast_hybrid = SearchInput {
            fast: true,
            mode: SearchMode::Hybrid,
            ..SearchInput::default()
        };
        let mut results = vec![hit("Seed.md", 0.9)];

        let diagnostics = refine_graph_results(&conn, &fast_hybrid, None, &mut results)
            .ok_or_else(|| std::io::Error::other("fast hybrid should still use persisted graph"))?;

        assert_eq!(results.len(), 2);
        assert_eq!(diagnostics.expanded_results, 1);
        Ok(())
    }

    #[test]
    fn plain_hybrid_graph_refinement_uses_only_top_seed() -> TestResult {
        let conn = graph_db(
            &["Top.md", "Weak.md", "WeakNeighbor.md"],
            &[("Weak.md", "WeakNeighbor.md", 2)],
        )?;
        let hybrid = SearchInput {
            mode: SearchMode::Hybrid,
            ..SearchInput::default()
        };
        let mut results = vec![hit("Top.md", 0.9), hit("Weak.md", 0.8)];

        let diagnostics = refine_graph_results(&conn, &hybrid, None, &mut results);

        assert!(diagnostics.is_none());
        assert_eq!(results.len(), 2);
        let leaked_neighbor = results
            .iter()
            .any(|result| result.raw.path == "WeakNeighbor.md");
        assert!(!leaked_neighbor);
        Ok(())
    }

    #[test]
    fn graph_expansion_does_not_displace_strong_retrieval_hit() -> TestResult {
        let conn = graph_db(
            &["Strong.md", "Seed.md", "GraphOnly.md"],
            &[("Seed.md", "GraphOnly.md", 4)],
        )?;
        let mut results = vec![hit("Strong.md", 0.95), hit("Seed.md", 0.9)];

        refine_graph_results(&conn, &related_input(), None, &mut results)
            .ok_or_else(|| std::io::Error::other("graph refinement should add candidate"))?;
        results.sort_by(|left, right| right.raw.score.total_cmp(&left.raw.score));

        assert_eq!(results[0].raw.path, "Strong.md");
        let has_graph_only = results
            .iter()
            .any(|result| result.raw.path == "GraphOnly.md");
        assert!(has_graph_only);
        Ok(())
    }

    /// A retrieval hit whose title mirrors its path and whose score is also its hybrid score.
    fn hit(path: &str, score: f64) -> ScoredRawSearchResult {
        let raw = RawSearchResult {
            path: path.into(),
            title: path.into(),
            tags: Vec::new(),
            aliases: Vec::new(),
            snippet: String::new(),
            score,
            scores: SearchScores {
                hybrid: Some(score),
                ..SearchScores::default()
            },
            semantic_heading: None,
            semantic_char_start: None,
            semantic_char_end: None,
        };
        ScoredRawSearchResult {
            raw,
            raw_score: score,
        }
    }

    fn insert_graph_node(conn: &Connection, path: &str) -> Result<(), rusqlite::Error> {
        conn.execute(
            "INSERT INTO graph_nodes (
               vault_path, title, aliases, tags, scope, note_type, sources,
               outgoing_degree, backlink_degree, total_degree, structural,
               community_id, community_cohesion, community_neighbor_count, bridge_weight
             ) VALUES (?1, ?1, '[]', '[]', '', NULL, '[]', 0, 0, 0, 0, NULL, 0.0, 0, 0.0)",
            params![path],
        )
        .map(|_| ())
    }

    fn insert_graph_edge(
        conn: &Connection,
        from_path: &str,
        to_path: &str,
        weight: u32,
    ) -> Result<(), rusqlite::Error> {
        conn.execute(
            "INSERT INTO graph_edges (from_path, to_path, link_text, weight)
             VALUES (?1, ?2, ?2, ?3)",
            params![from_path, to_path, weight],
        )
        .map(|_| ())
    }
}