// keel_core/sqlite_queries.rs

use rusqlite::{params, OptionalExtension};

use crate::sqlite::SqliteGraphStore;
use crate::store::GraphStore;
use crate::types::{
    EdgeChange, EdgeDirection, EdgeKind, GraphEdge, GraphError, GraphNode, ModuleProfile,
    NodeChange,
};

10impl GraphStore for SqliteGraphStore {
11    fn get_node(&self, hash: &str) -> Option<GraphNode> {
12        // Try direct hash lookup first
13        let mut stmt = self
14            .conn
15            .prepare("SELECT * FROM nodes WHERE hash = ?1")
16            .ok()?;
17        if let Ok(node) = stmt.query_row(params![hash], Self::row_to_node) {
18            return Some(self.node_with_relations(node));
19        }
20
21        // Fall back to previous_hashes table for renamed/updated nodes
22        let mut prev_stmt = self
23            .conn
24            .prepare(
25                "SELECT n.* FROM nodes n
26                 JOIN previous_hashes ph ON ph.node_id = n.id
27                 WHERE ph.hash = ?1
28                 LIMIT 1",
29            )
30            .ok()?;
31        let node = prev_stmt
32            .query_row(params![hash], Self::row_to_node)
33            .ok()?;
34        Some(self.node_with_relations(node))
35    }
36
37    fn get_node_by_id(&self, id: u64) -> Option<GraphNode> {
38        let mut stmt = self
39            .conn
40            .prepare("SELECT * FROM nodes WHERE id = ?1")
41            .ok()?;
42        let node = stmt
43            .query_row(params![id], Self::row_to_node)
44            .ok()?;
45        Some(self.node_with_relations(node))
46    }
47
48    fn get_edges(&self, node_id: u64, direction: EdgeDirection) -> Vec<GraphEdge> {
49        let query = match direction {
50            EdgeDirection::Incoming => "SELECT * FROM edges WHERE target_id = ?1",
51            EdgeDirection::Outgoing => "SELECT * FROM edges WHERE source_id = ?1",
52            EdgeDirection::Both => {
53                "SELECT * FROM edges WHERE source_id = ?1 OR target_id = ?1"
54            }
55        };
56
57        let mut stmt = match self.conn.prepare(query) {
58            Ok(s) => s,
59            Err(e) => {
60                eprintln!("[keel] get_edges: prepare failed: {e}");
61                return Vec::new();
62            }
63        };
64        let result = match stmt.query_map(params![node_id], |row| {
65            let kind_str: String = row.get("kind")?;
66            let kind = match kind_str.as_str() {
67                "calls" => EdgeKind::Calls,
68                "imports" => EdgeKind::Imports,
69                "inherits" => EdgeKind::Inherits,
70                "contains" => EdgeKind::Contains,
71                _ => EdgeKind::Calls,
72            };
73            Ok(GraphEdge {
74                id: row.get("id")?,
75                source_id: row.get("source_id")?,
76                target_id: row.get("target_id")?,
77                kind,
78                file_path: row.get("file_path")?,
79                line: row.get("line")?,
80                confidence: row.get("confidence").unwrap_or(1.0),
81            })
82        }) {
83            Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
84            Err(e) => {
85                eprintln!("[keel] get_edges: query failed: {e}");
86                Vec::new()
87            }
88        };
89        result
90    }
91
92    fn get_module_profile(&self, module_id: u64) -> Option<ModuleProfile> {
93        let mut stmt = self
94            .conn
95            .prepare("SELECT * FROM module_profiles WHERE module_id = ?1")
96            .ok()?;
97        stmt.query_row(params![module_id], |row| {
98            let prefixes: String = row.get("function_name_prefixes")?;
99            let types: String = row.get("primary_types")?;
100            let imports: String = row.get("import_sources")?;
101            let exports: String = row.get("export_targets")?;
102            let keywords: String = row.get("responsibility_keywords")?;
103            Ok(ModuleProfile {
104                module_id: row.get("module_id")?,
105                path: row.get("path")?,
106                function_count: row.get("function_count")?,
107                class_count: row.get("class_count")?,
108                line_count: row.get("line_count")?,
109                function_name_prefixes: serde_json::from_str(&prefixes).unwrap_or_default(),
110                primary_types: serde_json::from_str(&types).unwrap_or_default(),
111                import_sources: serde_json::from_str(&imports).unwrap_or_default(),
112                export_targets: serde_json::from_str(&exports).unwrap_or_default(),
113                external_endpoint_count: row.get("external_endpoint_count")?,
114                responsibility_keywords: serde_json::from_str(&keywords).unwrap_or_default(),
115            })
116        })
117        .ok()
118    }
119
120    fn get_nodes_in_file(&self, file_path: &str) -> Vec<GraphNode> {
121        let mut stmt = match self
122            .conn
123            .prepare("SELECT * FROM nodes WHERE file_path = ?1")
124        {
125            Ok(s) => s,
126            Err(e) => {
127                eprintln!("[keel] get_nodes_in_file: prepare failed: {e}");
128                return Vec::new();
129            }
130        };
131        let nodes: Vec<GraphNode> = match stmt.query_map(params![file_path], Self::row_to_node) {
132            Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
133            Err(e) => {
134                eprintln!("[keel] get_nodes_in_file: query failed: {e}");
135                return Vec::new();
136            }
137        };
138        // Batch-load relations: 2 queries total instead of 2*N
139        self.nodes_with_relations_batch(nodes)
140    }
141
142    fn get_all_modules(&self) -> Vec<GraphNode> {
143        let mut stmt = match self
144            .conn
145            .prepare("SELECT * FROM nodes WHERE kind = 'module'")
146        {
147            Ok(s) => s,
148            Err(e) => {
149                eprintln!("[keel] get_all_modules: prepare failed: {e}");
150                return Vec::new();
151            }
152        };
153        let nodes: Vec<GraphNode> = match stmt.query_map([], Self::row_to_node) {
154            Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
155            Err(e) => {
156                eprintln!("[keel] get_all_modules: query failed: {e}");
157                return Vec::new();
158            }
159        };
160        // Batch-load relations: 2 queries total instead of 2*N
161        self.nodes_with_relations_batch(nodes)
162    }
163
164    fn update_nodes(&mut self, changes: Vec<NodeChange>) -> Result<(), GraphError> {
165        let tx = self.conn.transaction()?;
166        for change in changes {
167            match change {
168                NodeChange::Add(node) => {
169                    // Check for hash collision (different function, same hash)
170                    let existing: Option<String> = tx
171                        .query_row(
172                            "SELECT name FROM nodes WHERE hash = ?1",
173                            params![node.hash],
174                            |row| row.get(0),
175                        )
176                        .ok();
177                    if let Some(existing_name) = existing {
178                        if existing_name != node.name {
179                            return Err(GraphError::HashCollision {
180                                hash: node.hash.clone(),
181                                existing: existing_name,
182                                new_fn: node.name.clone(),
183                            });
184                        }
185                    }
186                    // UPSERT to handle re-map without cascade-deleting related rows
187                    tx.execute(
188                        "INSERT INTO nodes (id, hash, kind, name, signature, file_path, line_start, line_end, docstring, is_public, type_hints_present, has_docstring, module_id, package)
189                         VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)
190                         ON CONFLICT(hash) DO UPDATE SET
191                            kind = excluded.kind,
192                            name = excluded.name,
193                            signature = excluded.signature,
194                            file_path = excluded.file_path,
195                            line_start = excluded.line_start,
196                            line_end = excluded.line_end,
197                            docstring = excluded.docstring,
198                            is_public = excluded.is_public,
199                            type_hints_present = excluded.type_hints_present,
200                            has_docstring = excluded.has_docstring,
201                            module_id = excluded.module_id,
202                            package = excluded.package,
203                            updated_at = datetime('now')",
204                        params![
205                            node.id,
206                            node.hash,
207                            node.kind.as_str(),
208                            node.name,
209                            node.signature,
210                            node.file_path,
211                            node.line_start,
212                            node.line_end,
213                            node.docstring,
214                            node.is_public as i32,
215                            node.type_hints_present as i32,
216                            node.has_docstring as i32,
217                            if node.module_id == 0 { None } else { Some(node.module_id) },
218                            node.package,
219                        ],
220                    )?;
221                }
222                NodeChange::Update(node) => {
223                    // Check for hash collision (different node, same hash)
224                    let existing: Option<(u64, String)> = tx
225                        .query_row(
226                            "SELECT id, name FROM nodes WHERE hash = ?1",
227                            params![node.hash],
228                            |row| Ok((row.get(0)?, row.get(1)?)),
229                        )
230                        .ok();
231                    if let Some((existing_id, existing_name)) = existing {
232                        if existing_id != node.id {
233                            return Err(GraphError::HashCollision {
234                                hash: node.hash.clone(),
235                                existing: existing_name,
236                                new_fn: node.name.clone(),
237                            });
238                        }
239                    }
240                    tx.execute(
241                        "UPDATE nodes SET hash = ?1, kind = ?2, name = ?3, signature = ?4, file_path = ?5, line_start = ?6, line_end = ?7, docstring = ?8, is_public = ?9, type_hints_present = ?10, has_docstring = ?11, module_id = ?12, package = ?13, updated_at = datetime('now') WHERE id = ?14",
242                        params![
243                            node.hash,
244                            node.kind.as_str(),
245                            node.name,
246                            node.signature,
247                            node.file_path,
248                            node.line_start,
249                            node.line_end,
250                            node.docstring,
251                            node.is_public as i32,
252                            node.type_hints_present as i32,
253                            node.has_docstring as i32,
254                            if node.module_id == 0 { None } else { Some(node.module_id) },
255                            node.package,
256                            node.id,
257                        ],
258                    )?;
259                }
260                NodeChange::Remove(id) => {
261                    tx.execute("DELETE FROM nodes WHERE id = ?1", params![id])?;
262                }
263            }
264        }
265        tx.commit()?;
266        Ok(())
267    }
268
269    fn update_edges(&mut self, changes: Vec<EdgeChange>) -> Result<(), GraphError> {
270        let tx = self.conn.transaction()?;
271        for change in changes {
272            match change {
273                EdgeChange::Add(edge) => {
274                    // INSERT OR IGNORE handles UNIQUE constraint violations
275                    // (duplicate edges). FK violations are prevented by the caller
276                    // filtering edges to valid node IDs and disabling FK pragma.
277                    tx.execute(
278                        "INSERT OR IGNORE INTO edges (id, source_id, target_id, kind, file_path, line, confidence) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
279                        params![
280                            edge.id,
281                            edge.source_id,
282                            edge.target_id,
283                            edge.kind.as_str(),
284                            edge.file_path,
285                            edge.line,
286                            edge.confidence,
287                        ],
288                    )?;
289                }
290                EdgeChange::Remove(id) => {
291                    tx.execute("DELETE FROM edges WHERE id = ?1", params![id])?;
292                }
293            }
294        }
295        tx.commit()?;
296        Ok(())
297    }
298
299    fn get_previous_hashes(&self, node_id: u64) -> Vec<String> {
300        self.load_previous_hashes(node_id)
301    }
302
303    fn find_modules_by_prefix(&self, prefix: &str, exclude_file: &str) -> Vec<ModuleProfile> {
304        // Search module_profiles whose function_name_prefixes JSON array contains the prefix.
305        // The LIKE pattern matches the prefix as a quoted JSON string element.
306        let pattern = format!("%\"{}\"%" , prefix);
307        let mut stmt = match self.conn.prepare(
308            "SELECT mp.* FROM module_profiles mp
309             JOIN nodes n ON n.id = mp.module_id
310             WHERE n.file_path != ?1
311             AND mp.function_name_prefixes LIKE ?2",
312        ) {
313            Ok(s) => s,
314            Err(e) => {
315                eprintln!("[keel] find_modules_by_prefix: prepare failed: {e}");
316                return Vec::new();
317            }
318        };
319        let result = match stmt.query_map(params![exclude_file, pattern], |row| {
320            let prefixes: String = row.get("function_name_prefixes")?;
321            let types: String = row.get("primary_types")?;
322            let imports: String = row.get("import_sources")?;
323            let exports: String = row.get("export_targets")?;
324            let keywords: String = row.get("responsibility_keywords")?;
325            Ok(ModuleProfile {
326                module_id: row.get("module_id")?,
327                path: row.get("path")?,
328                function_count: row.get("function_count")?,
329                class_count: row.get("class_count")?,
330                line_count: row.get("line_count")?,
331                function_name_prefixes: serde_json::from_str(&prefixes).unwrap_or_default(),
332                primary_types: serde_json::from_str(&types).unwrap_or_default(),
333                import_sources: serde_json::from_str(&imports).unwrap_or_default(),
334                export_targets: serde_json::from_str(&exports).unwrap_or_default(),
335                external_endpoint_count: row.get("external_endpoint_count")?,
336                responsibility_keywords: serde_json::from_str(&keywords).unwrap_or_default(),
337            })
338        }) {
339            Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
340            Err(e) => {
341                eprintln!("[keel] find_modules_by_prefix: query failed: {e}");
342                Vec::new()
343            }
344        };
345        result
346    }
347
348    fn find_nodes_by_name(&self, name: &str, kind: &str, exclude_file: &str) -> Vec<GraphNode> {
349        // Empty kind/exclude_file act as wildcards (match any)
350        let sql = match (kind.is_empty(), exclude_file.is_empty()) {
351            (true, true) => "SELECT * FROM nodes WHERE name = ?1",
352            (true, false) => "SELECT * FROM nodes WHERE name = ?1 AND file_path != ?2",
353            (false, true) => "SELECT * FROM nodes WHERE name = ?1 AND kind = ?2",
354            (false, false) => {
355                "SELECT * FROM nodes WHERE name = ?1 AND kind = ?2 AND file_path != ?3"
356            }
357        };
358        let mut stmt = match self.conn.prepare(sql) {
359            Ok(s) => s,
360            Err(e) => {
361                eprintln!("[keel] find_nodes_by_name: prepare failed: {e}");
362                return Vec::new();
363            }
364        };
365        let result = match (kind.is_empty(), exclude_file.is_empty()) {
366            (true, true) => stmt.query_map(params![name], Self::row_to_node),
367            (true, false) => stmt.query_map(params![name, exclude_file], Self::row_to_node),
368            (false, true) => stmt.query_map(params![name, kind], Self::row_to_node),
369            (false, false) => stmt.query_map(params![name, kind, exclude_file], Self::row_to_node),
370        };
371        match result {
372            Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
373            Err(e) => {
374                eprintln!("[keel] find_nodes_by_name: query failed: {e}");
375                Vec::new()
376            }
377        }
378    }
379}