// neurographrag 1.2.0
//
// Local GraphRAG memory for LLMs in a single SQLite file.
// Documentation
use crate::cli::MemoryType;
use crate::errors::AppError;
use crate::output::{self, OutputFormat, RecallItem};
use crate::paths::AppPaths;
use crate::storage::connection::open_ro;
use crate::storage::memories;

use std::collections::HashMap;

// Command-line arguments for the hybrid (vector + full-text) search, consumed by `run`.
// `///` comments on fields become clap `--help` text.
// NOTE(review): `with_graph`, `max_hops`, `min_weight` and `format` are parsed but not read by
// `run` in this module, which always emits JSON — wire them up or remove them.
#[derive(clap::Args)]
pub struct HybridSearchArgs {
    /// Search text: embedded for the vector search and also passed to full-text search.
    pub query: String,
    /// Number of fused results to return; each retriever fetches up to twice as many candidates.
    #[arg(short = 'k', long, default_value = "10")]
    pub k: usize,
    /// Reciprocal Rank Fusion constant added to each 1-based rank before taking the reciprocal.
    #[arg(long, default_value = "60")]
    pub rrf_k: u32,
    /// Multiplier for the vector-search contribution to the fused score.
    #[arg(long, default_value = "1.0")]
    pub weight_vec: f32,
    /// Multiplier for the full-text-search contribution to the fused score.
    #[arg(long, default_value = "1.0")]
    pub weight_fts: f32,
    /// Memory type filter passed to both the vector and full-text searches.
    #[arg(long, value_enum)]
    pub r#type: Option<MemoryType>,
    /// Namespace to search.
    #[arg(long)]
    pub namespace: Option<String>,
    // Not read by `run` in this module.
    #[arg(long)]
    pub with_graph: bool,
    // Not read by `run` in this module.
    #[arg(long, default_value = "2")]
    pub max_hops: u32,
    // Not read by `run` in this module.
    #[arg(long, default_value = "0.3")]
    pub min_weight: f64,
    // Not read by `run` in this module: results are always emitted as JSON.
    #[arg(long, value_enum, default_value = "json")]
    pub format: OutputFormat,
    /// Path to the SQLite database.
    #[arg(long, env = "NEUROGRAPHRAG_DB_PATH")]
    pub db: Option<String>,
}

/// JSON payload emitted by `run` for a hybrid search.
///
/// Besides the fused ranking, both per-retriever rankings are included so callers can see
/// where each hit came from.
#[derive(serde::Serialize)]
pub struct HybridSearchResponse {
    /// The query string exactly as supplied in the arguments.
    pub query: String,
    /// Top-`k` memories after Reciprocal Rank Fusion of `vec_rank` and `fts_rank`.
    pub combined_rank: Vec<RecallItem>,
    /// Vector (KNN) hits in retrieval order, at most `2 * k`; each has `source = "vector"`.
    pub vec_rank: Vec<RecallItem>,
    /// Full-text hits in retrieval order, at most `2 * k`; each has `source = "fts"` and
    /// `distance = 0.0` (FTS produces no distance).
    pub fts_rank: Vec<RecallItem>,
}

/// Upper bound on characters (Unicode scalar values) copied from a memory body into a snippet.
const SNIPPET_CHARS: usize = 300;

/// Executes a hybrid (vector + full-text) memory search and prints the result as JSON.
///
/// Steps:
/// 1. Resolve the namespace and application paths, then embed `args.query`.
/// 2. Fetch up to `2 * k` candidates from the KNN index and from full-text search.
/// 3. Fuse both rankings with weighted Reciprocal Rank Fusion (`fuse_rrf`).
/// 4. Emit a `HybridSearchResponse` with the fused top-`k` plus both raw rankings.
///
/// Vector hits for which `memories::read_full` returns `None` are skipped.
///
/// # Errors
/// Returns any [`AppError`] raised while resolving the namespace or paths, loading the
/// embedder, embedding the query, opening the database, running either search, reading a
/// memory row, or emitting the JSON output.
pub fn run(args: HybridSearchArgs) -> Result<(), AppError> {
    // NOTE(review): `with_graph`, `max_hops`, `min_weight` and `format` are not read here;
    // the response is always emitted as JSON. Wire them up or drop the flags.
    let namespace = crate::namespace::resolve_namespace(args.namespace.as_deref())?;
    let paths = AppPaths::resolve(args.db.as_deref())?;

    output::emit_progress("Computing query embedding...");
    let embedder = crate::embedder::get_embedder(&paths.models)?;
    let embedding = crate::embedder::embed_query(embedder, &args.query)?;

    let conn = open_ro(&paths.db)?;
    let memory_type_str = args.r#type.map(|t| t.as_str());
    // Over-fetch so fusion sees more candidates than it returns; saturate rather than
    // overflow on an absurdly large `-k`.
    let candidates = args.k.saturating_mul(2);

    let vec_results =
        memories::knn_search(&conn, &embedding, &namespace, memory_type_str, candidates)?;
    let mut vec_rank = Vec::new();
    for &(memory_id, distance) in vec_results.iter() {
        if let Some(row) = memories::read_full(&conn, memory_id)? {
            vec_rank.push(RecallItem {
                snippet: snippet_of(&row.body),
                memory_id: row.id,
                name: row.name,
                namespace: row.namespace,
                memory_type: row.memory_type,
                description: row.description,
                distance,
                source: "vector".to_string(),
            });
        }
    }

    // FTS rows are consumed by value, so their fields move instead of being cloned.
    let fts_rank: Vec<RecallItem> =
        memories::fts_search(&conn, &args.query, &namespace, memory_type_str, candidates)?
            .into_iter()
            .map(|row| RecallItem {
                snippet: snippet_of(&row.body),
                memory_id: row.id,
                name: row.name,
                namespace: row.namespace,
                memory_type: row.memory_type,
                description: row.description,
                distance: 0.0,
                source: "fts".to_string(),
            })
            .collect();

    // `u32 -> f32` is exact up to 2^24, far beyond any sensible RRF constant.
    let combined_rank = fuse_rrf(
        &vec_rank,
        &fts_rank,
        args.rrf_k as f32,
        args.weight_vec,
        args.weight_fts,
        args.k,
    );

    output::emit_json(&HybridSearchResponse {
        query: args.query,
        combined_rank,
        vec_rank,
        fts_rank,
    })?;

    Ok(())
}

/// Fuses the vector and full-text rankings with weighted Reciprocal Rank Fusion.
///
/// A hit at 0-based position `rank` in a list contributes
/// `weight * 1 / (rrf_k + rank + 1)` to its memory's score. A memory present in both lists
/// accumulates both contributions, each counted exactly once. The item kept for a memory is
/// the first one seen (vector before FTS), so it keeps that list's `source` and `distance`.
///
/// Returns at most `limit` items, ordered by descending score. Ties break on ascending
/// `memory_id`, so the output does not depend on `HashMap` iteration order.
fn fuse_rrf(
    vec_rank: &[RecallItem],
    fts_rank: &[RecallItem],
    rrf_k: f32,
    weight_vec: f32,
    weight_fts: f32,
    limit: usize,
) -> Vec<RecallItem> {
    let mut scores: HashMap<i64, (f32, RecallItem)> =
        HashMap::with_capacity(vec_rank.len() + fts_rank.len());

    for (list, weight) in [(vec_rank, weight_vec), (fts_rank, weight_fts)] {
        for (rank, item) in list.iter().enumerate() {
            let contribution = weight * (1.0 / (rrf_k + rank as f32 + 1.0));
            // Seed with 0.0 so the contribution is added exactly once, not twice.
            scores
                .entry(item.memory_id)
                .or_insert_with(|| (0.0, item.clone()))
                .0 += contribution;
        }
    }

    let mut fused: Vec<(i64, (f32, RecallItem))> = scores.into_iter().collect();
    // `total_cmp` gives a true total order even if a weight is NaN; the id tie-break
    // makes the result deterministic.
    fused.sort_unstable_by(|(id_a, (score_a, _)), (id_b, (score_b, _))| {
        score_b.total_cmp(score_a).then_with(|| id_a.cmp(id_b))
    });

    fused
        .into_iter()
        .take(limit)
        .map(|(_, (_, item))| item)
        .collect()
}

/// Returns the first `SNIPPET_CHARS` characters of `body`, never splitting a UTF-8 sequence.
fn snippet_of(body: &str) -> String {
    body.chars().take(SNIPPET_CHARS).collect()
}