Skip to main content

rag_rat_core/query/
graph_meta.rs

1use std::collections::BTreeSet;
2
3use rusqlite::{Connection, OptionalExtension, params};
4use serde::Serialize;
5
6use crate::{query::ReadChunk, search::lexical::SearchHit};
7
/// Edge kinds that are treated as call relationships when counting a symbol's callees.
const CALL_EDGE_KINDS: &[&str] = &["calls_name", "constructs", "uses_macro"];

/// Caveat appended to the `notes` of every evidence payload built in `GraphMetaMode::Full`.
const FULL_GRAPH_NOTE: &str = "Call graph is tree-sitter/syntactic, not compiler-resolved.";
10
/// How much call-graph metadata is attached to search hits and read chunks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GraphMetaMode {
    /// Attach nothing; the attach functions return immediately.
    None,
    /// Attach caller/callee counts plus a short list of top callers and callees.
    Compact,
    /// Attach the symbol, caller/callee lists, imports, referenced types and a caveat note.
    Full,
}
17
18impl GraphMetaMode {
19    pub fn parse(value: &str) -> anyhow::Result<Self> {
20        match value {
21            "none" | "false" => Ok(Self::None),
22            "compact" | "true" => Ok(Self::Compact),
23            "full" => Ok(Self::Full),
24            other => anyhow::bail!(
25                "unknown graph metadata mode `{other}`; expected none, compact, or full"
26            ),
27        }
28    }
29}
30
31#[derive(Debug, Clone, Serialize)]
32pub struct GraphEvidence {
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub symbol: Option<GraphSymbol>,
35    pub caller_count: u64,
36    pub callee_count: u64,
37    #[serde(skip_serializing_if = "Vec::is_empty")]
38    pub top_callers: Vec<CallerEvidence>,
39    #[serde(skip_serializing_if = "Vec::is_empty")]
40    pub top_callees: Vec<CalleeEvidence>,
41    #[serde(skip_serializing_if = "Vec::is_empty")]
42    pub callers: Vec<CallerEvidence>,
43    #[serde(skip_serializing_if = "Vec::is_empty")]
44    pub callees: Vec<CalleeEvidence>,
45    #[serde(skip_serializing_if = "Vec::is_empty")]
46    pub imports: Vec<ImportEvidence>,
47    #[serde(skip_serializing_if = "Vec::is_empty")]
48    pub referenced_types: Vec<TypeEvidence>,
49    pub truncated: GraphTruncation,
50    #[serde(skip_serializing_if = "Vec::is_empty")]
51    pub notes: Vec<String>,
52}
53
54#[derive(Debug, Clone, Serialize)]
55pub struct GraphSymbol {
56    pub id: i64,
57    pub name: String,
58    pub qualified_name: String,
59    pub kind: String,
60    pub symbol_path: String,
61}
62
63#[derive(Debug, Clone, Serialize)]
64pub struct CallerEvidence {
65    pub symbol_path: String,
66    pub path: String,
67    pub line: i64,
68    pub callsite: CallsiteEvidence,
69    pub edge_kind: String,
70    pub confidence: String,
71}
72
73#[derive(Debug, Clone, Serialize)]
74pub struct CalleeEvidence {
75    pub target: String,
76    #[serde(skip_serializing_if = "Option::is_none")]
77    pub resolved_symbol_path: Option<String>,
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub path: Option<String>,
80    #[serde(skip_serializing_if = "Option::is_none")]
81    pub line: Option<i64>,
82    pub callsite: CallsiteEvidence,
83    pub edge_kind: String,
84    pub confidence: String,
85}
86
87#[derive(Debug, Clone, Serialize)]
88pub struct CallsiteEvidence {
89    pub path: String,
90    pub line: i64,
91    pub span: [i64; 2],
92}
93
94#[derive(Debug, Clone, Serialize)]
95pub struct ImportEvidence {
96    pub target: String,
97    pub confidence: String,
98}
99
100#[derive(Debug, Clone, Serialize)]
101pub struct TypeEvidence {
102    pub name: String,
103    pub confidence: String,
104}
105
106#[derive(Debug, Clone, Default, Serialize)]
107pub struct GraphTruncation {
108    pub callers: bool,
109    pub callees: bool,
110    #[serde(skip_serializing_if = "is_false")]
111    pub imports: bool,
112    #[serde(skip_serializing_if = "is_false")]
113    pub referenced_types: bool,
114}
115
116pub fn attach_to_search_hits(
117    conn: &Connection,
118    hits: &mut [SearchHit],
119    mode: GraphMetaMode,
120    limit: u32,
121) -> anyhow::Result<()> {
122    if mode == GraphMetaMode::None {
123        return Ok(());
124    }
125    let limit = limit.max(1);
126    for hit in hits {
127        hit.graph = evidence_for_chunk(conn, hit.chunk_id, mode, limit)?;
128    }
129    Ok(())
130}
131
132pub fn attach_to_read_chunk(
133    conn: &Connection,
134    chunk: &mut ReadChunk,
135    mode: GraphMetaMode,
136    limit: u32,
137) -> anyhow::Result<()> {
138    if mode == GraphMetaMode::None {
139        return Ok(());
140    }
141    chunk.graph = evidence_for_chunk(conn, chunk.chunk_id, mode, limit.max(1))?;
142    Ok(())
143}
144
145fn evidence_for_chunk(
146    conn: &Connection,
147    chunk_id: i64,
148    mode: GraphMetaMode,
149    limit: u32,
150) -> anyhow::Result<Option<GraphEvidence>> {
151    let Some(symbol) = primary_symbol(conn, chunk_id)? else {
152        return Ok(None);
153    };
154    let caller_count = count_callers(conn, &symbol)?;
155    let callee_count = count_callees(conn, symbol.id)?;
156    let mut evidence = GraphEvidence {
157        symbol: (mode == GraphMetaMode::Full).then(|| symbol.public.clone()),
158        caller_count,
159        callee_count,
160        top_callers: Vec::new(),
161        top_callees: Vec::new(),
162        callers: Vec::new(),
163        callees: Vec::new(),
164        imports: Vec::new(),
165        referenced_types: Vec::new(),
166        truncated: GraphTruncation::default(),
167        notes: Vec::new(),
168    };
169    let callers = callers(conn, &symbol, limit)?;
170    let callees = callees(conn, symbol.id, limit)?;
171    evidence.truncated.callers = caller_count > u64::try_from(callers.len()).unwrap_or(u64::MAX);
172    evidence.truncated.callees = callee_count > u64::try_from(callees.len()).unwrap_or(u64::MAX);
173    if mode == GraphMetaMode::Full {
174        evidence.callers = callers;
175        evidence.callees = callees;
176        evidence.imports = imports(conn, chunk_id, limit)?;
177        evidence.referenced_types = referenced_types(conn, symbol.id, limit)?;
178        evidence.truncated.imports =
179            count_imports(conn, chunk_id)? > u64::try_from(evidence.imports.len()).unwrap_or(0);
180        evidence.truncated.referenced_types = count_referenced_types(conn, symbol.id)?
181            > u64::try_from(evidence.referenced_types.len()).unwrap_or(0);
182        evidence.notes.push(FULL_GRAPH_NOTE.to_string());
183    } else {
184        evidence.top_callers = callers;
185        evidence.top_callees = callees;
186    }
187    Ok(Some(evidence))
188}
189
190#[derive(Debug, Clone)]
191struct PrimarySymbol {
192    id: i64,
193    name: String,
194    public: GraphSymbol,
195}
196
197fn primary_symbol(conn: &Connection, chunk_id: i64) -> anyhow::Result<Option<PrimarySymbol>> {
198    Ok(conn
199        .query_row(
200            "
201            SELECT symbols.id, symbols.name, symbols.qualified_name, symbols.kind, files.path
202            FROM chunks
203            JOIN symbols ON symbols.file_id = chunks.file_id
204             AND symbols.start_byte < chunks.end_byte
205             AND symbols.end_byte > chunks.start_byte
206            JOIN files ON files.id = symbols.file_id
207            WHERE chunks.id = ?1
208            ORDER BY
209              CASE symbols.kind
210                WHEN 'function' THEN 0
211                WHEN 'method' THEN 1
212                WHEN 'class' THEN 2
213                WHEN 'struct' THEN 3
214                ELSE 9
215              END,
216              symbols.start_byte ASC
217            LIMIT 1
218            ",
219            [chunk_id],
220            |row| {
221                let id = row.get(0)?;
222                let name: String = row.get(1)?;
223                let qualified_name: String = row.get(2)?;
224                let kind = row.get(3)?;
225                let path: String = row.get(4)?;
226                Ok(PrimarySymbol {
227                    id,
228                    name: name.clone(),
229                    public: GraphSymbol {
230                        id,
231                        name,
232                        qualified_name: qualified_name.clone(),
233                        kind,
234                        symbol_path: symbol_path(&path, &qualified_name),
235                    },
236                })
237            },
238        )
239        .optional()?)
240}
241
242fn count_callers(conn: &Connection, symbol: &PrimarySymbol) -> anyhow::Result<u64> {
243    let count = conn.query_row(
244        "
245        SELECT COUNT(DISTINCT COALESCE(from_symbol_id, -id))
246        FROM edges
247        WHERE edge_kind IN ('calls_name', 'constructs', 'uses_macro')
248          AND (to_symbol_id = ?1 OR (to_symbol_id IS NULL AND to_name = ?2))
249        ",
250        params![symbol.id, symbol.name],
251        |row| row.get::<_, i64>(0),
252    )?;
253    Ok(u64::try_from(count).unwrap_or(0))
254}
255
256fn count_callees(conn: &Connection, symbol_id: i64) -> anyhow::Result<u64> {
257    count_edges_for_symbol(conn, symbol_id, CALL_EDGE_KINDS)
258}
259
260fn count_imports(conn: &Connection, chunk_id: i64) -> anyhow::Result<u64> {
261    count_edges_for_chunk(conn, chunk_id, &["imports"])
262}
263
264fn count_referenced_types(conn: &Connection, symbol_id: i64) -> anyhow::Result<u64> {
265    count_edges_for_symbol(conn, symbol_id, &["references_type", "implements", "extends"])
266}
267
268fn count_edges_for_symbol(
269    conn: &Connection,
270    symbol_id: i64,
271    edge_kinds: &[&str],
272) -> anyhow::Result<u64> {
273    let count = conn.query_row(
274        &format!(
275            "
276        SELECT COUNT(DISTINCT COALESCE(CAST(to_symbol_id AS TEXT), to_name))
277        FROM edges
278            WHERE from_symbol_id = ?1
279              AND edge_kind IN ({})
280            ",
281            quoted(edge_kinds),
282        ),
283        [symbol_id],
284        |row| row.get::<_, i64>(0),
285    )?;
286    Ok(u64::try_from(count).unwrap_or(0))
287}
288
289fn count_edges_for_chunk(
290    conn: &Connection,
291    chunk_id: i64,
292    edge_kinds: &[&str],
293) -> anyhow::Result<u64> {
294    let count = conn.query_row(
295        &format!(
296            "
297            SELECT COUNT(*)
298            FROM edges
299            JOIN chunks ON chunks.file_id = edges.source_file_id
300            WHERE chunks.id = ?1
301              AND edges.from_symbol_id IS NULL
302              AND edges.edge_kind IN ({})
303            ",
304            quoted(edge_kinds),
305        ),
306        [chunk_id],
307        |row| row.get::<_, i64>(0),
308    )?;
309    Ok(u64::try_from(count).unwrap_or(0))
310}
311
312fn callers(
313    conn: &Connection,
314    symbol: &PrimarySymbol,
315    limit: u32,
316) -> anyhow::Result<Vec<CallerEvidence>> {
317    let mut stmt = conn.prepare(
318        "
319        SELECT DISTINCT
320               source_files.path,
321               COALESCE(source_symbols.qualified_name, edges.from_name, source_files.path),
322               COALESCE(NULLIF(edges.source_start_line, 0), source_chunks.start_line, 1),
323               COALESCE(NULLIF(edges.source_end_line, 0), NULLIF(edges.source_start_line, 0), source_chunks.start_line, 1),
324               edges.edge_kind,
325               edges.confidence
326        FROM edges
327        JOIN files source_files ON source_files.id = edges.source_file_id
328        LEFT JOIN symbols source_symbols ON source_symbols.id = edges.from_symbol_id
329        LEFT JOIN chunks source_chunks ON source_chunks.file_id = edges.source_file_id
330          AND source_symbols.start_byte >= source_chunks.start_byte
331          AND source_symbols.start_byte < source_chunks.end_byte
332        WHERE edges.edge_kind IN ('calls_name', 'constructs', 'uses_macro')
333          AND (edges.to_symbol_id = ?1 OR (edges.to_symbol_id IS NULL AND edges.to_name = ?2))
334        ORDER BY
335          CASE edges.confidence
336            WHEN 'Exact' THEN 0
337            WHEN 'Syntactic' THEN 1
338            WHEN 'NameOnly' THEN 2
339            ELSE 3
340          END,
341          source_files.path,
342          source_chunks.start_line
343        LIMIT ?3
344        ",
345    )?;
346    let rows = stmt.query_map(params![symbol.id, symbol.name, expanded_limit(limit)], |row| {
347        let path: String = row.get(0)?;
348        let qualified_name: String = row.get(1)?;
349        let source_start_line = row.get(2)?;
350        let source_end_line = row.get(3)?;
351        Ok(CallerEvidence {
352            symbol_path: symbol_path(&path, &qualified_name),
353            path: path.clone(),
354            line: source_start_line,
355            callsite: CallsiteEvidence {
356                path,
357                line: source_start_line,
358                span: [source_start_line, source_end_line],
359            },
360            edge_kind: row.get(4)?,
361            confidence: confidence(row.get::<_, String>(5)?.as_str()).to_string(),
362        })
363    })?;
364    let mut seen = BTreeSet::new();
365    let mut callers = collect_rows(rows)?
366        .into_iter()
367        .filter(|caller| seen.insert((caller.symbol_path.clone(), caller.edge_kind.clone())))
368        .collect::<Vec<_>>();
369    callers.truncate(usize::try_from(limit).unwrap_or(usize::MAX));
370    Ok(callers)
371}
372
373fn callees(conn: &Connection, symbol_id: i64, limit: u32) -> anyhow::Result<Vec<CalleeEvidence>> {
374    let mut stmt = conn.prepare(
375        "
376        SELECT DISTINCT
377               edges.to_name,
378               target_files.path,
379               target_symbols.qualified_name,
380               COALESCE(edges.target_start_line, target_chunks.start_line),
381               source_files.path,
382               COALESCE(NULLIF(edges.source_start_line, 0), source_chunks.start_line, 1),
383               COALESCE(NULLIF(edges.source_end_line, 0), NULLIF(edges.source_start_line, 0), source_chunks.start_line, 1),
384               edges.edge_kind,
385               edges.confidence
386        FROM edges
387        JOIN files source_files ON source_files.id = edges.source_file_id
388        LEFT JOIN symbols target_symbols ON target_symbols.id = edges.to_symbol_id
389        LEFT JOIN files target_files ON target_files.id = target_symbols.file_id
390        LEFT JOIN chunks target_chunks ON target_chunks.file_id = target_symbols.file_id
391          AND target_symbols.start_byte >= target_chunks.start_byte
392          AND target_symbols.start_byte < target_chunks.end_byte
393        LEFT JOIN symbols source_symbols ON source_symbols.id = edges.from_symbol_id
394        LEFT JOIN chunks source_chunks ON source_chunks.file_id = edges.source_file_id
395          AND source_symbols.start_byte >= source_chunks.start_byte
396          AND source_symbols.start_byte < source_chunks.end_byte
397        WHERE edges.from_symbol_id = ?1
398          AND edges.edge_kind IN ('calls_name', 'constructs', 'uses_macro')
399        ORDER BY
400          CASE edges.confidence
401            WHEN 'Exact' THEN 0
402            WHEN 'Syntactic' THEN 1
403            WHEN 'NameOnly' THEN 2
404            ELSE 3
405          END,
406          source_chunks.start_line,
407          edges.to_name
408        LIMIT ?2
409        ",
410    )?;
411    let rows = stmt.query_map(params![symbol_id, expanded_limit(limit)], |row| {
412        let target: String = row.get(0)?;
413        let path: Option<String> = row.get(1)?;
414        let qualified_name: Option<String> = row.get(2)?;
415        let callsite_path: String = row.get(4)?;
416        let callsite_start_line = row.get(5)?;
417        let callsite_end_line = row.get(6)?;
418        Ok(CalleeEvidence {
419            target,
420            resolved_symbol_path: path
421                .as_ref()
422                .zip(qualified_name.as_ref())
423                .map(|(path, qualified_name)| symbol_path(path, qualified_name)),
424            path,
425            line: row.get(3)?,
426            callsite: CallsiteEvidence {
427                path: callsite_path,
428                line: callsite_start_line,
429                span: [callsite_start_line, callsite_end_line],
430            },
431            edge_kind: row.get(7)?,
432            confidence: confidence(row.get::<_, String>(8)?.as_str()).to_string(),
433        })
434    })?;
435    let mut seen = BTreeSet::new();
436    let mut callees = collect_rows(rows)?
437        .into_iter()
438        .filter(|callee| {
439            seen.insert((
440                callee.target.clone(),
441                callee.resolved_symbol_path.clone(),
442                callee.edge_kind.clone(),
443            ))
444        })
445        .collect::<Vec<_>>();
446    callees.truncate(usize::try_from(limit).unwrap_or(usize::MAX));
447    Ok(callees)
448}
449
450fn imports(conn: &Connection, chunk_id: i64, limit: u32) -> anyhow::Result<Vec<ImportEvidence>> {
451    let mut stmt = conn.prepare(
452        "
453        SELECT edges.to_name, edges.confidence
454        FROM edges
455        JOIN chunks ON chunks.file_id = edges.source_file_id
456        WHERE chunks.id = ?1
457          AND edges.from_symbol_id IS NULL
458          AND edges.edge_kind = 'imports'
459        ORDER BY edges.to_name
460        LIMIT ?2
461        ",
462    )?;
463    let rows = stmt.query_map(params![chunk_id, i64::from(limit)], |row| {
464        Ok(ImportEvidence {
465            target: row.get(0)?,
466            confidence: confidence(row.get::<_, String>(1)?.as_str()).to_string(),
467        })
468    })?;
469    collect_rows(rows)
470}
471
472fn referenced_types(
473    conn: &Connection,
474    symbol_id: i64,
475    limit: u32,
476) -> anyhow::Result<Vec<TypeEvidence>> {
477    let mut stmt = conn.prepare(
478        "
479        SELECT DISTINCT edges.to_name, edges.confidence
480        FROM edges
481        WHERE edges.from_symbol_id = ?1
482          AND edges.edge_kind IN ('references_type', 'implements', 'extends')
483        ORDER BY
484          CASE edges.confidence
485            WHEN 'Exact' THEN 0
486            WHEN 'Syntactic' THEN 1
487            WHEN 'NameOnly' THEN 2
488            ELSE 3
489          END,
490          edges.to_name
491        LIMIT ?2
492        ",
493    )?;
494    let rows = stmt.query_map(params![symbol_id, i64::from(limit)], |row| {
495        Ok(TypeEvidence {
496            name: row.get(0)?,
497            confidence: confidence(row.get::<_, String>(1)?.as_str()).to_string(),
498        })
499    })?;
500    collect_rows(rows)
501}
502
/// Combines a file path and a qualified name into `path::qualified_name`.
///
/// When `qualified_name` already equals `path` or already starts with `path::`, it is returned
/// unchanged so the prefix is never doubled.
fn symbol_path(path: &str, qualified_name: &str) -> String {
    // `strip_prefix` checks the prefix in place; no temporary `path::` string is allocated.
    match qualified_name.strip_prefix(path) {
        Some(rest) if rest.is_empty() || rest.starts_with("::") => qualified_name.to_owned(),
        _ => format!("{path}::{qualified_name}"),
    }
}
509
/// Maps a stored confidence value to the lowercase label used in serialized evidence.
///
/// `NameOnly` and every unrecognized value map to `name_only`.
fn confidence(value: &str) -> &'static str {
    match value {
        "Exact" => "exact",
        "Syntactic" => "syntactic",
        "Ambiguous" => "ambiguous",
        // Covers "NameOnly" as well as anything unexpected.
        _ => "name_only",
    }
}
519
/// Renders `values` as a comma-separated list of SQL string literals, e.g. `'a', 'b'`.
///
/// Embedded single quotes are doubled, as SQL string literals require, so a value can never
/// terminate its literal early. The result is meant to be spliced into an `IN (...)` clause;
/// an empty slice yields an empty string.
fn quoted(values: &[&str]) -> String {
    values
        .iter()
        .map(|value| format!("'{}'", value.replace('\'', "''")))
        .collect::<Vec<_>>()
        .join(", ")
}
523
/// Returns the row limit used for over-fetching: four times `limit`, treating zero as one.
fn expanded_limit(limit: u32) -> i64 {
    const OVERFETCH_FACTOR: i64 = 4;
    let base = if limit == 0 { 1 } else { i64::from(limit) };
    base.saturating_mul(OVERFETCH_FACTOR)
}
527
/// Serde `skip_serializing_if` predicate: `true` exactly when the flag is `false`.
fn is_false(value: &bool) -> bool {
    matches!(*value, false)
}
531
532fn collect_rows<T>(
533    rows: rusqlite::MappedRows<'_, impl FnMut(&rusqlite::Row<'_>) -> rusqlite::Result<T>>,
534) -> anyhow::Result<Vec<T>> {
535    let mut out = Vec::new();
536    for row in rows {
537        out.push(row?);
538    }
539    Ok(out)
540}