Skip to main content

keel_core/
sqlite_queries.rs

1use rusqlite::params;
2
3use crate::sqlite::SqliteGraphStore;
4use crate::store::GraphStore;
5use crate::types::{
6    EdgeChange, EdgeDirection, EdgeKind, GraphEdge, GraphError, GraphNode, ModuleProfile,
7    NodeChange,
8};
9
10impl GraphStore for SqliteGraphStore {
11    fn get_node(&self, hash: &str) -> Option<GraphNode> {
12        // Try direct hash lookup first
13        let mut stmt = self
14            .conn
15            .prepare("SELECT * FROM nodes WHERE hash = ?1")
16            .ok()?;
17        if let Ok(node) = stmt.query_row(params![hash], Self::row_to_node) {
18            return Some(self.node_with_relations(node));
19        }
20
21        // Fall back to previous_hashes table for renamed/updated nodes
22        let mut prev_stmt = self
23            .conn
24            .prepare(
25                "SELECT n.* FROM nodes n
26                 JOIN previous_hashes ph ON ph.node_id = n.id
27                 WHERE ph.hash = ?1
28                 LIMIT 1",
29            )
30            .ok()?;
31        let node = prev_stmt.query_row(params![hash], Self::row_to_node).ok()?;
32        Some(self.node_with_relations(node))
33    }
34
35    fn get_node_by_id(&self, id: u64) -> Option<GraphNode> {
36        let mut stmt = self
37            .conn
38            .prepare("SELECT * FROM nodes WHERE id = ?1")
39            .ok()?;
40        let node = stmt.query_row(params![id], Self::row_to_node).ok()?;
41        Some(self.node_with_relations(node))
42    }
43
44    fn get_edges(&self, node_id: u64, direction: EdgeDirection) -> Vec<GraphEdge> {
45        let query = match direction {
46            EdgeDirection::Incoming => "SELECT * FROM edges WHERE target_id = ?1",
47            EdgeDirection::Outgoing => "SELECT * FROM edges WHERE source_id = ?1",
48            EdgeDirection::Both => "SELECT * FROM edges WHERE source_id = ?1 OR target_id = ?1",
49        };
50
51        let mut stmt = match self.conn.prepare(query) {
52            Ok(s) => s,
53            Err(e) => {
54                eprintln!("[keel] get_edges: prepare failed: {e}");
55                return Vec::new();
56            }
57        };
58        let result = match stmt.query_map(params![node_id], |row| {
59            let kind_str: String = row.get("kind")?;
60            let kind = match kind_str.as_str() {
61                "calls" => EdgeKind::Calls,
62                "imports" => EdgeKind::Imports,
63                "inherits" => EdgeKind::Inherits,
64                "contains" => EdgeKind::Contains,
65                _ => EdgeKind::Calls,
66            };
67            Ok(GraphEdge {
68                id: row.get("id")?,
69                source_id: row.get("source_id")?,
70                target_id: row.get("target_id")?,
71                kind,
72                file_path: row.get("file_path")?,
73                line: row.get("line")?,
74                confidence: row.get("confidence").unwrap_or(1.0),
75            })
76        }) {
77            Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
78            Err(e) => {
79                eprintln!("[keel] get_edges: query failed: {e}");
80                Vec::new()
81            }
82        };
83        result
84    }
85
86    fn get_module_profile(&self, module_id: u64) -> Option<ModuleProfile> {
87        let mut stmt = self
88            .conn
89            .prepare("SELECT * FROM module_profiles WHERE module_id = ?1")
90            .ok()?;
91        stmt.query_row(params![module_id], |row| {
92            let prefixes: String = row.get("function_name_prefixes")?;
93            let types: String = row.get("primary_types")?;
94            let imports: String = row.get("import_sources")?;
95            let exports: String = row.get("export_targets")?;
96            let keywords: String = row.get("responsibility_keywords")?;
97            Ok(ModuleProfile {
98                module_id: row.get("module_id")?,
99                path: row.get("path")?,
100                function_count: row.get("function_count")?,
101                class_count: row.get("class_count")?,
102                line_count: row.get("line_count")?,
103                function_name_prefixes: serde_json::from_str(&prefixes).unwrap_or_default(),
104                primary_types: serde_json::from_str(&types).unwrap_or_default(),
105                import_sources: serde_json::from_str(&imports).unwrap_or_default(),
106                export_targets: serde_json::from_str(&exports).unwrap_or_default(),
107                external_endpoint_count: row.get("external_endpoint_count")?,
108                responsibility_keywords: serde_json::from_str(&keywords).unwrap_or_default(),
109            })
110        })
111        .ok()
112    }
113
114    fn get_nodes_in_file(&self, file_path: &str) -> Vec<GraphNode> {
115        let mut stmt = match self
116            .conn
117            .prepare("SELECT * FROM nodes WHERE file_path = ?1")
118        {
119            Ok(s) => s,
120            Err(e) => {
121                eprintln!("[keel] get_nodes_in_file: prepare failed: {e}");
122                return Vec::new();
123            }
124        };
125        let nodes: Vec<GraphNode> = match stmt.query_map(params![file_path], Self::row_to_node) {
126            Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
127            Err(e) => {
128                eprintln!("[keel] get_nodes_in_file: query failed: {e}");
129                return Vec::new();
130            }
131        };
132        // Batch-load relations: 2 queries total instead of 2*N
133        self.nodes_with_relations_batch(nodes)
134    }
135
136    fn get_all_modules(&self) -> Vec<GraphNode> {
137        let mut stmt = match self
138            .conn
139            .prepare("SELECT * FROM nodes WHERE kind = 'module'")
140        {
141            Ok(s) => s,
142            Err(e) => {
143                eprintln!("[keel] get_all_modules: prepare failed: {e}");
144                return Vec::new();
145            }
146        };
147        let nodes: Vec<GraphNode> = match stmt.query_map([], Self::row_to_node) {
148            Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
149            Err(e) => {
150                eprintln!("[keel] get_all_modules: query failed: {e}");
151                return Vec::new();
152            }
153        };
154        // Batch-load relations: 2 queries total instead of 2*N
155        self.nodes_with_relations_batch(nodes)
156    }
157
158    fn update_nodes(&mut self, changes: Vec<NodeChange>) -> Result<(), GraphError> {
159        let tx = self.conn.transaction()?;
160        for change in changes {
161            match change {
162                NodeChange::Add(node) => {
163                    // Check for hash collision (different function, same hash)
164                    let existing: Option<String> = tx
165                        .query_row(
166                            "SELECT name FROM nodes WHERE hash = ?1",
167                            params![node.hash],
168                            |row| row.get(0),
169                        )
170                        .ok();
171                    if let Some(existing_name) = existing {
172                        if existing_name != node.name {
173                            return Err(GraphError::HashCollision {
174                                hash: node.hash.clone(),
175                                existing: existing_name,
176                                new_fn: node.name.clone(),
177                            });
178                        }
179                    }
180                    // UPSERT to handle re-map without cascade-deleting related rows
181                    tx.execute(
182                        "INSERT INTO nodes (id, hash, kind, name, signature, file_path, line_start, line_end, docstring, is_public, type_hints_present, has_docstring, module_id, package)
183                         VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)
184                         ON CONFLICT(hash) DO UPDATE SET
185                            kind = excluded.kind,
186                            name = excluded.name,
187                            signature = excluded.signature,
188                            file_path = excluded.file_path,
189                            line_start = excluded.line_start,
190                            line_end = excluded.line_end,
191                            docstring = excluded.docstring,
192                            is_public = excluded.is_public,
193                            type_hints_present = excluded.type_hints_present,
194                            has_docstring = excluded.has_docstring,
195                            module_id = excluded.module_id,
196                            package = excluded.package,
197                            updated_at = datetime('now')",
198                        params![
199                            node.id,
200                            node.hash,
201                            node.kind.as_str(),
202                            node.name,
203                            node.signature,
204                            node.file_path,
205                            node.line_start,
206                            node.line_end,
207                            node.docstring,
208                            node.is_public as i32,
209                            node.type_hints_present as i32,
210                            node.has_docstring as i32,
211                            if node.module_id == 0 { None } else { Some(node.module_id) },
212                            node.package,
213                        ],
214                    )?;
215                }
216                NodeChange::Update(node) => {
217                    // Check for hash collision (different node, same hash)
218                    let existing: Option<(u64, String)> = tx
219                        .query_row(
220                            "SELECT id, name FROM nodes WHERE hash = ?1",
221                            params![node.hash],
222                            |row| Ok((row.get(0)?, row.get(1)?)),
223                        )
224                        .ok();
225                    if let Some((existing_id, existing_name)) = existing {
226                        if existing_id != node.id {
227                            return Err(GraphError::HashCollision {
228                                hash: node.hash.clone(),
229                                existing: existing_name,
230                                new_fn: node.name.clone(),
231                            });
232                        }
233                    }
234                    tx.execute(
235                        "UPDATE nodes SET hash = ?1, kind = ?2, name = ?3, signature = ?4, file_path = ?5, line_start = ?6, line_end = ?7, docstring = ?8, is_public = ?9, type_hints_present = ?10, has_docstring = ?11, module_id = ?12, package = ?13, updated_at = datetime('now') WHERE id = ?14",
236                        params![
237                            node.hash,
238                            node.kind.as_str(),
239                            node.name,
240                            node.signature,
241                            node.file_path,
242                            node.line_start,
243                            node.line_end,
244                            node.docstring,
245                            node.is_public as i32,
246                            node.type_hints_present as i32,
247                            node.has_docstring as i32,
248                            if node.module_id == 0 { None } else { Some(node.module_id) },
249                            node.package,
250                            node.id,
251                        ],
252                    )?;
253                }
254                NodeChange::Remove(id) => {
255                    tx.execute("DELETE FROM nodes WHERE id = ?1", params![id])?;
256                }
257            }
258        }
259        tx.commit()?;
260        Ok(())
261    }
262
263    fn update_edges(&mut self, changes: Vec<EdgeChange>) -> Result<(), GraphError> {
264        let tx = self.conn.transaction()?;
265        for change in changes {
266            match change {
267                EdgeChange::Add(edge) => {
268                    // INSERT OR IGNORE handles UNIQUE constraint violations
269                    // (duplicate edges). FK violations are prevented by the caller
270                    // filtering edges to valid node IDs and disabling FK pragma.
271                    tx.execute(
272                        "INSERT OR IGNORE INTO edges (id, source_id, target_id, kind, file_path, line, confidence) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
273                        params![
274                            edge.id,
275                            edge.source_id,
276                            edge.target_id,
277                            edge.kind.as_str(),
278                            edge.file_path,
279                            edge.line,
280                            edge.confidence,
281                        ],
282                    )?;
283                }
284                EdgeChange::Remove(id) => {
285                    tx.execute("DELETE FROM edges WHERE id = ?1", params![id])?;
286                }
287            }
288        }
289        tx.commit()?;
290        Ok(())
291    }
292
293    fn get_previous_hashes(&self, node_id: u64) -> Vec<String> {
294        self.load_previous_hashes(node_id)
295    }
296
297    fn find_modules_by_prefix(&self, prefix: &str, exclude_file: &str) -> Vec<ModuleProfile> {
298        // Search module_profiles whose function_name_prefixes JSON array contains the prefix.
299        // The LIKE pattern matches the prefix as a quoted JSON string element.
300        let pattern = format!("%\"{}\"%", prefix);
301        let mut stmt = match self.conn.prepare(
302            "SELECT mp.* FROM module_profiles mp
303             JOIN nodes n ON n.id = mp.module_id
304             WHERE n.file_path != ?1
305             AND mp.function_name_prefixes LIKE ?2",
306        ) {
307            Ok(s) => s,
308            Err(e) => {
309                eprintln!("[keel] find_modules_by_prefix: prepare failed: {e}");
310                return Vec::new();
311            }
312        };
313        let result = match stmt.query_map(params![exclude_file, pattern], |row| {
314            let prefixes: String = row.get("function_name_prefixes")?;
315            let types: String = row.get("primary_types")?;
316            let imports: String = row.get("import_sources")?;
317            let exports: String = row.get("export_targets")?;
318            let keywords: String = row.get("responsibility_keywords")?;
319            Ok(ModuleProfile {
320                module_id: row.get("module_id")?,
321                path: row.get("path")?,
322                function_count: row.get("function_count")?,
323                class_count: row.get("class_count")?,
324                line_count: row.get("line_count")?,
325                function_name_prefixes: serde_json::from_str(&prefixes).unwrap_or_default(),
326                primary_types: serde_json::from_str(&types).unwrap_or_default(),
327                import_sources: serde_json::from_str(&imports).unwrap_or_default(),
328                export_targets: serde_json::from_str(&exports).unwrap_or_default(),
329                external_endpoint_count: row.get("external_endpoint_count")?,
330                responsibility_keywords: serde_json::from_str(&keywords).unwrap_or_default(),
331            })
332        }) {
333            Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
334            Err(e) => {
335                eprintln!("[keel] find_modules_by_prefix: query failed: {e}");
336                Vec::new()
337            }
338        };
339        result
340    }
341
342    fn find_nodes_by_name(&self, name: &str, kind: &str, exclude_file: &str) -> Vec<GraphNode> {
343        // Empty kind/exclude_file act as wildcards (match any)
344        let sql = match (kind.is_empty(), exclude_file.is_empty()) {
345            (true, true) => "SELECT * FROM nodes WHERE name = ?1",
346            (true, false) => "SELECT * FROM nodes WHERE name = ?1 AND file_path != ?2",
347            (false, true) => "SELECT * FROM nodes WHERE name = ?1 AND kind = ?2",
348            (false, false) => {
349                "SELECT * FROM nodes WHERE name = ?1 AND kind = ?2 AND file_path != ?3"
350            }
351        };
352        let mut stmt = match self.conn.prepare(sql) {
353            Ok(s) => s,
354            Err(e) => {
355                eprintln!("[keel] find_nodes_by_name: prepare failed: {e}");
356                return Vec::new();
357            }
358        };
359        let result = match (kind.is_empty(), exclude_file.is_empty()) {
360            (true, true) => stmt.query_map(params![name], Self::row_to_node),
361            (true, false) => stmt.query_map(params![name, exclude_file], Self::row_to_node),
362            (false, true) => stmt.query_map(params![name, kind], Self::row_to_node),
363            (false, false) => stmt.query_map(params![name, kind, exclude_file], Self::row_to_node),
364        };
365        match result {
366            Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
367            Err(e) => {
368                eprintln!("[keel] find_nodes_by_name: query failed: {e}");
369                Vec::new()
370            }
371        }
372    }
373}