Skip to main content

dk_engine/graph/
callgraph.rs

1use dk_core::{CallEdge, CallKind, RepoId, SymbolId};
2use sqlx::postgres::PgPool;
3use uuid::Uuid;
4
5/// Intermediate row type for mapping between database rows and `CallEdge`.
6#[derive(sqlx::FromRow)]
7struct CallEdgeRow {
8    id: Uuid,
9    repo_id: Uuid,
10    caller_id: Uuid,
11    callee_id: Uuid,
12    kind: String,
13}
14
15impl CallEdgeRow {
16    fn into_call_edge(self) -> CallEdge {
17        CallEdge {
18            id: self.id,
19            repo_id: self.repo_id,
20            caller: self.caller_id,
21            callee: self.callee_id,
22            kind: parse_call_kind(&self.kind),
23        }
24    }
25}
26
27fn parse_call_kind(s: &str) -> CallKind {
28    match s {
29        "direct_call" => CallKind::DirectCall,
30        "method_call" => CallKind::MethodCall,
31        "import" => CallKind::Import,
32        "implements" => CallKind::Implements,
33        "inherits" => CallKind::Inherits,
34        "macro_invocation" => CallKind::MacroInvocation,
35        other => {
36            tracing::warn!("Unknown call kind: {other}, defaulting to DirectCall");
37            CallKind::DirectCall
38        }
39    }
40}
41
42/// PostgreSQL-backed store for call graph edges.
43#[derive(Clone)]
44pub struct CallGraphStore {
45    pool: PgPool,
46}
47
48impl CallGraphStore {
49    /// Create a new `CallGraphStore` backed by the given connection pool.
50    pub fn new(pool: PgPool) -> Self {
51        Self { pool }
52    }
53
54    /// Insert a call edge. Uses `ON CONFLICT DO NOTHING` so repeated
55    /// insertion of the same edge is idempotent.
56    pub async fn insert_edge(&self, edge: &CallEdge) -> dk_core::Result<()> {
57        let kind_str = edge.kind.to_string();
58
59        sqlx::query(
60            r#"
61            INSERT INTO call_edges (id, repo_id, caller_id, callee_id, kind)
62            VALUES ($1, $2, $3, $4, $5)
63            ON CONFLICT (repo_id, caller_id, callee_id, kind) DO NOTHING
64            "#,
65        )
66        .bind(edge.id)
67        .bind(edge.repo_id)
68        .bind(edge.caller)
69        .bind(edge.callee)
70        .bind(&kind_str)
71        .execute(&self.pool)
72        .await?;
73
74        Ok(())
75    }
76
77    /// Find all edges where the given symbol is the callee (i.e. who calls this symbol).
78    pub async fn find_callers(&self, symbol_id: SymbolId) -> dk_core::Result<Vec<CallEdge>> {
79        let rows = sqlx::query_as::<_, CallEdgeRow>(
80            r#"
81            SELECT id, repo_id, caller_id, callee_id, kind
82            FROM call_edges
83            WHERE callee_id = $1
84            ORDER BY caller_id
85            "#,
86        )
87        .bind(symbol_id)
88        .fetch_all(&self.pool)
89        .await?;
90
91        Ok(rows.into_iter().map(CallEdgeRow::into_call_edge).collect())
92    }
93
94    /// Find all edges where the given symbol is the caller (i.e. what does this symbol call).
95    pub async fn find_callees(&self, symbol_id: SymbolId) -> dk_core::Result<Vec<CallEdge>> {
96        let rows = sqlx::query_as::<_, CallEdgeRow>(
97            r#"
98            SELECT id, repo_id, caller_id, callee_id, kind
99            FROM call_edges
100            WHERE caller_id = $1
101            ORDER BY callee_id
102            "#,
103        )
104        .bind(symbol_id)
105        .fetch_all(&self.pool)
106        .await?;
107
108        Ok(rows.into_iter().map(CallEdgeRow::into_call_edge).collect())
109    }
110
111    /// Delete all call edges where any involved symbol is in the given file.
112    /// This deletes edges where the file's symbols appear as either caller OR
113    /// callee, which is required before deleting the symbols themselves —
114    /// otherwise the `call_edges_callee_id_fkey` FK constraint blocks the
115    /// symbol deletion.
116    /// Returns the total number of rows deleted.
117    pub async fn delete_edges_for_file(
118        &self,
119        repo_id: RepoId,
120        file_path: &str,
121    ) -> dk_core::Result<u64> {
122        let result = sqlx::query(
123            r#"
124            DELETE FROM call_edges ce
125            USING symbols s
126            WHERE (ce.caller_id = s.id OR ce.callee_id = s.id)
127              AND s.repo_id = $1
128              AND s.file_path = $2
129            "#,
130        )
131        .bind(repo_id)
132        .bind(file_path)
133        .execute(&self.pool)
134        .await?;
135
136        Ok(result.rows_affected())
137    }
138}