rag-rat-core 0.2.0

Repository evidence engine for source chunks, symbols, graph edges, Git history, GitHub rationale, and source-bound memories.
Documentation
use rusqlite::{Connection, OptionalExtension, params};
use serde::Serialize;

use crate::language::Language;

/// One symbol row joined with its containing file, as returned by the lookup
/// functions in this module.
///
/// The three `logical_*` fields start out as `None` when the row is read and
/// are filled in afterwards if the symbol is a member of a logical symbol
/// group; they are omitted from serialized output while `None`.
#[derive(Debug, Serialize)]
pub struct SymbolHit {
    /// `symbols.id` of the matched symbol.
    pub symbol_id: i64,
    /// Id of the logical symbol this symbol belongs to, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logical_symbol_id: Option<i64>,
    /// Variant count recorded on the owning logical symbol, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logical_variant_count: Option<u64>,
    /// Group reason recorded on the owning logical symbol, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logical_group_reason: Option<String>,
    /// `files.id` of the file containing the symbol.
    pub file_id: i64,
    /// `files.path` of the file containing the symbol.
    pub path: String,
    /// `files.kind` of the file containing the symbol.
    pub file_kind: String,
    /// `symbols.language` value stored for the symbol.
    pub language: String,
    /// Unqualified symbol name (`symbols.name`).
    pub name: String,
    /// Qualified symbol name (`symbols.qualified_name`).
    pub qualified_name: String,
    /// Populated from the same column as `qualified_name`; the two fields
    /// always carry the same value.
    pub symbol_path: String,
    /// `symbols.kind` (the lookup ranking recognises values such as
    /// `struct`, `function`, `impl`).
    pub kind: String,
    /// `symbols.start_byte` of the symbol.
    pub start_byte: i64,
    /// `symbols.end_byte` of the symbol.
    pub end_byte: i64,
    /// Stored signature text, when present.
    pub signature: Option<String>,
    /// Stored documentation text, when present.
    pub docs: Option<String>,
}

/// Result of [`lookup_candidates`]: every candidate matching a selector plus
/// a flag telling the caller whether it must narrow the selection.
#[derive(Debug, Serialize)]
pub struct SymbolLookup {
    /// All symbols matched by the selector, in query order.
    pub candidates: Vec<SymbolHit>,
    /// `true` when more than one candidate matched and the selector did not
    /// allow ambiguity.
    pub disambiguation_required: bool,
}

/// One row of the `logical_symbols` table.
#[derive(Debug, Clone, Serialize)]
pub struct LogicalSymbolHit {
    /// `logical_symbols.id`.
    pub logical_symbol_id: i64,
    /// `logical_symbols.language`.
    pub language: String,
    /// `logical_symbols.path`.
    pub path: String,
    /// `logical_symbols.logical_name`.
    pub logical_name: String,
    /// `logical_symbols.qualified_name`.
    pub qualified_name: String,
    /// `logical_symbols.kind`.
    pub kind: String,
    /// `logical_symbols.variant_count`; a stored negative value is reported
    /// as `0` by the row mapper.
    pub variant_count: u64,
    /// `logical_symbols.group_reason`.
    pub group_reason: String,
}

/// One row of the `logical_symbol_members` table, as returned by
/// [`logical_members`].
#[derive(Debug, Clone, Serialize)]
pub struct LogicalSymbolMember {
    /// Id of the concrete symbol that is a member of the logical symbol.
    pub symbol_id: i64,
    /// `logical_symbol_members.cfg_expr`, when present.
    // NOTE(review): presumably the `cfg` expression guarding this variant — confirm.
    pub cfg_expr: Option<String>,
    /// `logical_symbol_members.signature_hash`, when present.
    pub signature_hash: Option<String>,
    /// `logical_symbol_members.start_line`.
    pub start_line: i64,
    /// `logical_symbol_members.end_line`.
    pub end_line: i64,
}

/// Describes which symbol(s) a caller wants to resolve.
///
/// The selector fields are tried in this order and the first one that is set
/// wins: `logical_symbol_id`, `symbol_id`, `symbol_path`, `symbol`. Resolving
/// a selector with none of them set is an error.
#[derive(Debug, Clone)]
pub struct SymbolSelector {
    /// Select the members of this logical symbol.
    pub logical_symbol_id: Option<i64>,
    /// Select exactly this symbol id.
    pub symbol_id: Option<i64>,
    /// Select symbols whose qualified name equals this value.
    pub symbol_path: Option<String>,
    /// Select symbols by name (exact name or qualified-name substring).
    pub symbol: Option<String>,
    /// Optional language filter; only applied to `symbol_path` and `symbol`
    /// lookups.
    pub language: Option<Language>,
    /// When `true`, multiple candidates are not treated as requiring
    /// disambiguation.
    pub allow_ambiguous: bool,
    /// Maximum number of rows fetched; not applied to `symbol_id` lookups.
    pub limit: u32,
}

/// Returned by [`select_one`] (as the inner `Err`) when a selector matched
/// several symbols and ambiguity was not allowed.
#[derive(Debug, Serialize)]
pub struct SymbolDisambiguation {
    /// The competing candidates the caller has to choose from.
    pub candidates: Vec<SymbolHit>,
    /// Set to `true` by [`select_one`] when it builds this value.
    pub disambiguation_required: bool,
}

/// Finds symbols by name and attaches logical-symbol metadata to each hit.
///
/// `name` matches either the exact symbol name or a substring of the
/// qualified name; `language`, when given, restricts the search, and at most
/// `limit` hits are returned.
///
/// # Errors
///
/// Returns any error raised by the underlying queries.
pub fn lookup(
    conn: &Connection,
    name: &str,
    language: Option<Language>,
    limit: u32,
) -> anyhow::Result<Vec<SymbolHit>> {
    lookup_name(conn, name, language, limit).and_then(|mut found| {
        enrich_symbol_hits(conn, &mut found)?;
        Ok(found)
    })
}

/// Resolves `selector` to all matching symbols and reports whether the caller
/// still has to pick one of them.
///
/// # Errors
///
/// Returns an error when the selector names no symbol at all or when a query
/// fails.
pub fn lookup_candidates(
    conn: &Connection,
    selector: &SymbolSelector,
) -> anyhow::Result<SymbolLookup> {
    let found = candidates_for_selector(conn, selector)?;
    let ambiguous = needs_disambiguation(&found, selector.allow_ambiguous);
    Ok(SymbolLookup { candidates: found, disambiguation_required: ambiguous })
}

/// Resolves `selector` to a single symbol.
///
/// The outer `Result` carries query/selector errors. The inner value is:
/// - `Ok(None)` when nothing matched,
/// - `Ok(Some(hit))` with the first candidate when the selection is
///   unambiguous, ambiguity is allowed, or the selector used
///   `logical_symbol_id`,
/// - `Err(SymbolDisambiguation)` listing all candidates otherwise.
pub fn select_one(
    conn: &Connection,
    selector: &SymbolSelector,
) -> anyhow::Result<Result<Option<SymbolHit>, SymbolDisambiguation>> {
    let found = candidates_for_selector(conn, selector)?;

    // Selecting by logical symbol id never asks for disambiguation: the
    // first member is returned directly.
    let selected_by_logical_id = selector.logical_symbol_id.is_some();
    if !selected_by_logical_id && needs_disambiguation(&found, selector.allow_ambiguous) {
        return Ok(Err(SymbolDisambiguation { candidates: found, disambiguation_required: true }));
    }

    // An empty candidate list yields `None`; otherwise the first hit wins.
    Ok(Ok(found.into_iter().next()))
}

/// Loads the symbol with the given id, including its logical-symbol metadata.
///
/// Returns `Ok(None)` when no symbol has that id.
///
/// # Errors
///
/// Returns any error raised by the underlying queries.
pub fn lookup_by_id(conn: &Connection, symbol_id: i64) -> anyhow::Result<Option<SymbolHit>> {
    const SELECT_BY_ID: &str = "
        SELECT symbols.id, files.id, files.path, files.kind, symbols.language, symbols.name, symbols.qualified_name,
               symbols.kind, symbols.start_byte, symbols.end_byte, symbols.signature, symbols.docs
        FROM symbols
        JOIN files ON files.id = symbols.file_id
        WHERE symbols.id = ?1
        ";

    let found = conn.query_row(SELECT_BY_ID, [symbol_id], symbol_hit_row).optional()?;
    match found {
        Some(mut hit) => {
            enrich_symbol_hit(conn, &mut hit)?;
            Ok(Some(hit))
        }
        None => Ok(None),
    }
}

fn candidates_for_selector(
    conn: &Connection,
    selector: &SymbolSelector,
) -> anyhow::Result<Vec<SymbolHit>> {
    if let Some(logical_symbol_id) = selector.logical_symbol_id {
        return lookup_logical_members(conn, logical_symbol_id, selector.limit);
    }
    if let Some(symbol_id) = selector.symbol_id {
        return Ok(lookup_by_id(conn, symbol_id)?.into_iter().collect());
    }
    if let Some(symbol_path) = selector.symbol_path.as_deref() {
        let mut hits = lookup_symbol_path(conn, symbol_path, selector.language, selector.limit)?;
        enrich_symbol_hits(conn, &mut hits)?;
        return Ok(hits);
    }
    let Some(symbol) = selector.symbol.as_deref() else {
        anyhow::bail!("one of symbol_id, symbol_path, or symbol is required");
    };
    let mut hits = lookup_name(conn, symbol, selector.language, selector.limit)?;
    enrich_symbol_hits(conn, &mut hits)?;
    Ok(hits)
}

fn lookup_name(
    conn: &Connection,
    name: &str,
    language: Option<Language>,
    limit: u32,
) -> anyhow::Result<Vec<SymbolHit>> {
    let mut sql = "
        SELECT symbols.id, files.id, files.path, files.kind, symbols.language, symbols.name, symbols.qualified_name,
               symbols.kind, symbols.start_byte, symbols.end_byte, symbols.signature, symbols.docs
        FROM symbols
        JOIN files ON files.id = symbols.file_id
        WHERE (symbols.name = ?1 OR symbols.qualified_name LIKE ?2)
    "
    .to_string();
    if language.is_some() {
        sql.push_str(" AND symbols.language = ?3");
    }
    sql.push_str(
        "
        ORDER BY
          CASE WHEN symbols.name = ?1 THEN 0 ELSE 1 END,
          CASE symbols.kind
            WHEN 'struct' THEN 0
            WHEN 'class' THEN 0
            WHEN 'object' THEN 0
            WHEN 'enum' THEN 1
            WHEN 'trait' THEN 1
            WHEN 'interface' THEN 1
            WHEN 'type' THEN 1
            WHEN 'function' THEN 2
            WHEN 'method' THEN 2
            WHEN 'const' THEN 3
            WHEN 'property' THEN 3
            WHEN 'static' THEN 3
            WHEN 'impl' THEN 8
            ELSE 9
          END,
          files.path,
          symbols.start_byte
        LIMIT ?
        ",
    );

    let fuzzy = format!("%{name}%");
    let mut stmt = conn.prepare(&sql)?;
    let rows = if let Some(language) = language {
        stmt.query_map(params![name, fuzzy, language.as_str(), limit], symbol_hit_row)?
    } else {
        stmt.query_map(params![name, fuzzy, limit], symbol_hit_row)?
    };

    let mut hits = Vec::new();
    for row in rows {
        hits.push(row?);
    }
    enrich_symbol_hits(conn, &mut hits)?;
    Ok(hits)
}

/// Loads the logical symbol with the given id, or `Ok(None)` when there is
/// no such row.
///
/// # Errors
///
/// Returns any error raised by the query.
pub fn lookup_logical_by_id(
    conn: &Connection,
    logical_symbol_id: i64,
) -> anyhow::Result<Option<LogicalSymbolHit>> {
    const SELECT_LOGICAL: &str = "
        SELECT id, language, path, logical_name, qualified_name, kind, variant_count, group_reason
        FROM logical_symbols
        WHERE id = ?1
        ";

    let found = conn
        .query_row(SELECT_LOGICAL, [logical_symbol_id], logical_symbol_hit_row)
        .optional()?;
    Ok(found)
}

/// Finds the logical symbol that lists `symbol_id` as one of its members.
///
/// Returns `Ok(None)` when the symbol is not a member of any logical symbol.
///
/// # Errors
///
/// Returns any error raised by the query.
pub fn logical_for_symbol_id(
    conn: &Connection,
    symbol_id: i64,
) -> anyhow::Result<Option<LogicalSymbolHit>> {
    const SELECT_OWNER: &str = "
        SELECT logical_symbols.id, logical_symbols.language, logical_symbols.path,
               logical_symbols.logical_name, logical_symbols.qualified_name, logical_symbols.kind,
               logical_symbols.variant_count, logical_symbols.group_reason
        FROM logical_symbol_members
        JOIN logical_symbols ON logical_symbols.id = logical_symbol_members.logical_symbol_id
        WHERE logical_symbol_members.symbol_id = ?1
        ";

    let owner = conn.query_row(SELECT_OWNER, [symbol_id], logical_symbol_hit_row).optional()?;
    Ok(owner)
}

pub fn logical_members(
    conn: &Connection,
    logical_symbol_id: i64,
) -> anyhow::Result<Vec<LogicalSymbolMember>> {
    let mut stmt = conn.prepare(
        "
        SELECT symbol_id, cfg_expr, signature_hash, start_line, end_line
        FROM logical_symbol_members
        WHERE logical_symbol_id = ?1
        ORDER BY start_line, symbol_id
        ",
    )?;
    let rows = stmt.query_map([logical_symbol_id], |row| {
        Ok(LogicalSymbolMember {
            symbol_id: row.get(0)?,
            cfg_expr: row.get(1)?,
            signature_hash: row.get(2)?,
            start_line: row.get(3)?,
            end_line: row.get(4)?,
        })
    })?;
    let mut members = Vec::new();
    for row in rows {
        members.push(row?);
    }
    Ok(members)
}

fn lookup_logical_members(
    conn: &Connection,
    logical_symbol_id: i64,
    limit: u32,
) -> anyhow::Result<Vec<SymbolHit>> {
    let mut stmt = conn.prepare(
        "
        SELECT symbols.id, files.id, files.path, files.kind, symbols.language, symbols.name,
               symbols.qualified_name, symbols.kind, symbols.start_byte, symbols.end_byte,
               symbols.signature, symbols.docs
        FROM logical_symbol_members
        JOIN symbols ON symbols.id = logical_symbol_members.symbol_id
        JOIN files ON files.id = symbols.file_id
        WHERE logical_symbol_members.logical_symbol_id = ?1
        ORDER BY symbols.start_byte, symbols.id
        LIMIT ?2
        ",
    )?;
    let rows = stmt.query_map(params![logical_symbol_id, limit], symbol_hit_row)?;
    let mut hits = Vec::new();
    for row in rows {
        hits.push(row?);
    }
    Ok(hits)
}

fn lookup_symbol_path(
    conn: &Connection,
    symbol_path: &str,
    language: Option<Language>,
    limit: u32,
) -> anyhow::Result<Vec<SymbolHit>> {
    let mut sql = "
        SELECT symbols.id, files.id, files.path, files.kind, symbols.language, symbols.name, symbols.qualified_name,
               symbols.kind, symbols.start_byte, symbols.end_byte, symbols.signature, symbols.docs
        FROM symbols
        JOIN files ON files.id = symbols.file_id
        WHERE symbols.qualified_name = ?1
    "
    .to_string();
    if language.is_some() {
        sql.push_str(" AND symbols.language = ?2");
    }
    sql.push_str(" ORDER BY files.path, symbols.start_byte LIMIT ?");

    let mut stmt = conn.prepare(&sql)?;
    let rows = if let Some(language) = language {
        stmt.query_map(params![symbol_path, language.as_str(), limit], symbol_hit_row)?
    } else {
        stmt.query_map(params![symbol_path, limit], symbol_hit_row)?
    };

    let mut hits = Vec::new();
    for row in rows {
        hits.push(row?);
    }
    Ok(hits)
}

/// Reports whether the caller must choose among `candidates`: true only when
/// ambiguity is not allowed and there are at least two candidates.
fn needs_disambiguation(candidates: &[SymbolHit], allow_ambiguous: bool) -> bool {
    if allow_ambiguous {
        return false;
    }
    candidates.len() >= 2
}

/// Maps one result row to a [`SymbolHit`].
///
/// Expects the twelve columns selected by the symbol queries in this module,
/// in that order. The logical-symbol fields are left as `None`, and
/// `symbol_path` mirrors the qualified-name column.
fn symbol_hit_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<SymbolHit> {
    // Column 6 feeds both `symbol_path` and `qualified_name`.
    let symbol_path: String = row.get(6)?;
    let hit = SymbolHit {
        symbol_id: row.get(0)?,
        logical_symbol_id: None,
        logical_variant_count: None,
        logical_group_reason: None,
        file_id: row.get(1)?,
        path: row.get(2)?,
        file_kind: row.get(3)?,
        language: row.get(4)?,
        name: row.get(5)?,
        qualified_name: symbol_path.clone(),
        symbol_path,
        kind: row.get(7)?,
        start_byte: row.get(8)?,
        end_byte: row.get(9)?,
        signature: row.get(10)?,
        docs: row.get(11)?,
    };
    Ok(hit)
}

/// Maps one result row to a [`LogicalSymbolHit`].
///
/// Expects the eight columns selected by the logical-symbol queries in this
/// module, in that order. A stored variant count that does not fit in `u64`
/// (i.e. a negative value) is reported as `0`.
fn logical_symbol_hit_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<LogicalSymbolHit> {
    let stored_count: i64 = row.get(6)?;
    let variant_count = u64::try_from(stored_count).unwrap_or_default();
    let hit = LogicalSymbolHit {
        logical_symbol_id: row.get(0)?,
        language: row.get(1)?,
        path: row.get(2)?,
        logical_name: row.get(3)?,
        qualified_name: row.get(4)?,
        kind: row.get(5)?,
        variant_count,
        group_reason: row.get(7)?,
    };
    Ok(hit)
}

/// Attaches logical-symbol metadata to every hit in `hits`, stopping at the
/// first query error.
fn enrich_symbol_hits(conn: &Connection, hits: &mut [SymbolHit]) -> anyhow::Result<()> {
    hits.iter_mut().try_for_each(|hit| enrich_symbol_hit(conn, hit))
}

fn enrich_symbol_hit(conn: &Connection, hit: &mut SymbolHit) -> anyhow::Result<()> {
    if let Some(logical) = logical_for_symbol_id(conn, hit.symbol_id)? {
        hit.logical_symbol_id = Some(logical.logical_symbol_id);
        hit.logical_variant_count = Some(logical.variant_count);
        hit.logical_group_reason = Some(logical.group_reason);
    }
    Ok(())
}