use rusqlite::{params, OptionalExtension};

use crate::sqlite::SqliteGraphStore;
use crate::store::GraphStore;
use crate::types::{
    EdgeChange, EdgeDirection, EdgeKind, GraphEdge, GraphError, GraphNode, ModuleProfile,
    NodeChange,
};
9
10impl GraphStore for SqliteGraphStore {
11 fn get_node(&self, hash: &str) -> Option<GraphNode> {
12 let mut stmt = self
14 .conn
15 .prepare("SELECT * FROM nodes WHERE hash = ?1")
16 .ok()?;
17 if let Ok(node) = stmt.query_row(params![hash], Self::row_to_node) {
18 return Some(self.node_with_relations(node));
19 }
20
21 let mut prev_stmt = self
23 .conn
24 .prepare(
25 "SELECT n.* FROM nodes n
26 JOIN previous_hashes ph ON ph.node_id = n.id
27 WHERE ph.hash = ?1
28 LIMIT 1",
29 )
30 .ok()?;
31 let node = prev_stmt.query_row(params![hash], Self::row_to_node).ok()?;
32 Some(self.node_with_relations(node))
33 }
34
35 fn get_node_by_id(&self, id: u64) -> Option<GraphNode> {
36 let mut stmt = self
37 .conn
38 .prepare("SELECT * FROM nodes WHERE id = ?1")
39 .ok()?;
40 let node = stmt.query_row(params![id], Self::row_to_node).ok()?;
41 Some(self.node_with_relations(node))
42 }
43
44 fn get_edges(&self, node_id: u64, direction: EdgeDirection) -> Vec<GraphEdge> {
45 let query = match direction {
46 EdgeDirection::Incoming => "SELECT * FROM edges WHERE target_id = ?1",
47 EdgeDirection::Outgoing => "SELECT * FROM edges WHERE source_id = ?1",
48 EdgeDirection::Both => "SELECT * FROM edges WHERE source_id = ?1 OR target_id = ?1",
49 };
50
51 let mut stmt = match self.conn.prepare(query) {
52 Ok(s) => s,
53 Err(e) => {
54 eprintln!("[keel] get_edges: prepare failed: {e}");
55 return Vec::new();
56 }
57 };
58 let result = match stmt.query_map(params![node_id], |row| {
59 let kind_str: String = row.get("kind")?;
60 let kind = match kind_str.as_str() {
61 "calls" => EdgeKind::Calls,
62 "imports" => EdgeKind::Imports,
63 "inherits" => EdgeKind::Inherits,
64 "contains" => EdgeKind::Contains,
65 _ => EdgeKind::Calls,
66 };
67 Ok(GraphEdge {
68 id: row.get("id")?,
69 source_id: row.get("source_id")?,
70 target_id: row.get("target_id")?,
71 kind,
72 file_path: row.get("file_path")?,
73 line: row.get("line")?,
74 confidence: row.get("confidence").unwrap_or(1.0),
75 })
76 }) {
77 Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
78 Err(e) => {
79 eprintln!("[keel] get_edges: query failed: {e}");
80 Vec::new()
81 }
82 };
83 result
84 }
85
86 fn get_module_profile(&self, module_id: u64) -> Option<ModuleProfile> {
87 let mut stmt = self
88 .conn
89 .prepare("SELECT * FROM module_profiles WHERE module_id = ?1")
90 .ok()?;
91 stmt.query_row(params![module_id], |row| {
92 let prefixes: String = row.get("function_name_prefixes")?;
93 let types: String = row.get("primary_types")?;
94 let imports: String = row.get("import_sources")?;
95 let exports: String = row.get("export_targets")?;
96 let keywords: String = row.get("responsibility_keywords")?;
97 Ok(ModuleProfile {
98 module_id: row.get("module_id")?,
99 path: row.get("path")?,
100 function_count: row.get("function_count")?,
101 class_count: row.get("class_count")?,
102 line_count: row.get("line_count")?,
103 function_name_prefixes: serde_json::from_str(&prefixes).unwrap_or_default(),
104 primary_types: serde_json::from_str(&types).unwrap_or_default(),
105 import_sources: serde_json::from_str(&imports).unwrap_or_default(),
106 export_targets: serde_json::from_str(&exports).unwrap_or_default(),
107 external_endpoint_count: row.get("external_endpoint_count")?,
108 responsibility_keywords: serde_json::from_str(&keywords).unwrap_or_default(),
109 })
110 })
111 .ok()
112 }
113
114 fn get_nodes_in_file(&self, file_path: &str) -> Vec<GraphNode> {
115 let mut stmt = match self
116 .conn
117 .prepare("SELECT * FROM nodes WHERE file_path = ?1")
118 {
119 Ok(s) => s,
120 Err(e) => {
121 eprintln!("[keel] get_nodes_in_file: prepare failed: {e}");
122 return Vec::new();
123 }
124 };
125 let nodes: Vec<GraphNode> = match stmt.query_map(params![file_path], Self::row_to_node) {
126 Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
127 Err(e) => {
128 eprintln!("[keel] get_nodes_in_file: query failed: {e}");
129 return Vec::new();
130 }
131 };
132 self.nodes_with_relations_batch(nodes)
134 }
135
136 fn get_all_modules(&self) -> Vec<GraphNode> {
137 let mut stmt = match self
138 .conn
139 .prepare("SELECT * FROM nodes WHERE kind = 'module'")
140 {
141 Ok(s) => s,
142 Err(e) => {
143 eprintln!("[keel] get_all_modules: prepare failed: {e}");
144 return Vec::new();
145 }
146 };
147 let nodes: Vec<GraphNode> = match stmt.query_map([], Self::row_to_node) {
148 Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
149 Err(e) => {
150 eprintln!("[keel] get_all_modules: query failed: {e}");
151 return Vec::new();
152 }
153 };
154 self.nodes_with_relations_batch(nodes)
156 }
157
158 fn update_nodes(&mut self, changes: Vec<NodeChange>) -> Result<(), GraphError> {
159 let tx = self.conn.transaction()?;
160 for change in changes {
161 match change {
162 NodeChange::Add(node) => {
163 let existing: Option<String> = tx
165 .query_row(
166 "SELECT name FROM nodes WHERE hash = ?1",
167 params![node.hash],
168 |row| row.get(0),
169 )
170 .ok();
171 if let Some(existing_name) = existing {
172 if existing_name != node.name {
173 return Err(GraphError::HashCollision {
174 hash: node.hash.clone(),
175 existing: existing_name,
176 new_fn: node.name.clone(),
177 });
178 }
179 }
180 tx.execute(
182 "INSERT INTO nodes (id, hash, kind, name, signature, file_path, line_start, line_end, docstring, is_public, type_hints_present, has_docstring, module_id, package)
183 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)
184 ON CONFLICT(hash) DO UPDATE SET
185 kind = excluded.kind,
186 name = excluded.name,
187 signature = excluded.signature,
188 file_path = excluded.file_path,
189 line_start = excluded.line_start,
190 line_end = excluded.line_end,
191 docstring = excluded.docstring,
192 is_public = excluded.is_public,
193 type_hints_present = excluded.type_hints_present,
194 has_docstring = excluded.has_docstring,
195 module_id = excluded.module_id,
196 package = excluded.package,
197 updated_at = datetime('now')",
198 params![
199 node.id,
200 node.hash,
201 node.kind.as_str(),
202 node.name,
203 node.signature,
204 node.file_path,
205 node.line_start,
206 node.line_end,
207 node.docstring,
208 node.is_public as i32,
209 node.type_hints_present as i32,
210 node.has_docstring as i32,
211 if node.module_id == 0 { None } else { Some(node.module_id) },
212 node.package,
213 ],
214 )?;
215 }
216 NodeChange::Update(node) => {
217 let existing: Option<(u64, String)> = tx
219 .query_row(
220 "SELECT id, name FROM nodes WHERE hash = ?1",
221 params![node.hash],
222 |row| Ok((row.get(0)?, row.get(1)?)),
223 )
224 .ok();
225 if let Some((existing_id, existing_name)) = existing {
226 if existing_id != node.id {
227 return Err(GraphError::HashCollision {
228 hash: node.hash.clone(),
229 existing: existing_name,
230 new_fn: node.name.clone(),
231 });
232 }
233 }
234 tx.execute(
235 "UPDATE nodes SET hash = ?1, kind = ?2, name = ?3, signature = ?4, file_path = ?5, line_start = ?6, line_end = ?7, docstring = ?8, is_public = ?9, type_hints_present = ?10, has_docstring = ?11, module_id = ?12, package = ?13, updated_at = datetime('now') WHERE id = ?14",
236 params![
237 node.hash,
238 node.kind.as_str(),
239 node.name,
240 node.signature,
241 node.file_path,
242 node.line_start,
243 node.line_end,
244 node.docstring,
245 node.is_public as i32,
246 node.type_hints_present as i32,
247 node.has_docstring as i32,
248 if node.module_id == 0 { None } else { Some(node.module_id) },
249 node.package,
250 node.id,
251 ],
252 )?;
253 }
254 NodeChange::Remove(id) => {
255 tx.execute("DELETE FROM nodes WHERE id = ?1", params![id])?;
256 }
257 }
258 }
259 tx.commit()?;
260 Ok(())
261 }
262
263 fn update_edges(&mut self, changes: Vec<EdgeChange>) -> Result<(), GraphError> {
264 let tx = self.conn.transaction()?;
265 for change in changes {
266 match change {
267 EdgeChange::Add(edge) => {
268 tx.execute(
272 "INSERT OR IGNORE INTO edges (id, source_id, target_id, kind, file_path, line, confidence) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
273 params![
274 edge.id,
275 edge.source_id,
276 edge.target_id,
277 edge.kind.as_str(),
278 edge.file_path,
279 edge.line,
280 edge.confidence,
281 ],
282 )?;
283 }
284 EdgeChange::Remove(id) => {
285 tx.execute("DELETE FROM edges WHERE id = ?1", params![id])?;
286 }
287 }
288 }
289 tx.commit()?;
290 Ok(())
291 }
292
293 fn get_previous_hashes(&self, node_id: u64) -> Vec<String> {
294 self.load_previous_hashes(node_id)
295 }
296
297 fn find_modules_by_prefix(&self, prefix: &str, exclude_file: &str) -> Vec<ModuleProfile> {
298 let pattern = format!("%\"{}\"%", prefix);
301 let mut stmt = match self.conn.prepare(
302 "SELECT mp.* FROM module_profiles mp
303 JOIN nodes n ON n.id = mp.module_id
304 WHERE n.file_path != ?1
305 AND mp.function_name_prefixes LIKE ?2",
306 ) {
307 Ok(s) => s,
308 Err(e) => {
309 eprintln!("[keel] find_modules_by_prefix: prepare failed: {e}");
310 return Vec::new();
311 }
312 };
313 let result = match stmt.query_map(params![exclude_file, pattern], |row| {
314 let prefixes: String = row.get("function_name_prefixes")?;
315 let types: String = row.get("primary_types")?;
316 let imports: String = row.get("import_sources")?;
317 let exports: String = row.get("export_targets")?;
318 let keywords: String = row.get("responsibility_keywords")?;
319 Ok(ModuleProfile {
320 module_id: row.get("module_id")?,
321 path: row.get("path")?,
322 function_count: row.get("function_count")?,
323 class_count: row.get("class_count")?,
324 line_count: row.get("line_count")?,
325 function_name_prefixes: serde_json::from_str(&prefixes).unwrap_or_default(),
326 primary_types: serde_json::from_str(&types).unwrap_or_default(),
327 import_sources: serde_json::from_str(&imports).unwrap_or_default(),
328 export_targets: serde_json::from_str(&exports).unwrap_or_default(),
329 external_endpoint_count: row.get("external_endpoint_count")?,
330 responsibility_keywords: serde_json::from_str(&keywords).unwrap_or_default(),
331 })
332 }) {
333 Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
334 Err(e) => {
335 eprintln!("[keel] find_modules_by_prefix: query failed: {e}");
336 Vec::new()
337 }
338 };
339 result
340 }
341
342 fn find_nodes_by_name(&self, name: &str, kind: &str, exclude_file: &str) -> Vec<GraphNode> {
343 let sql = match (kind.is_empty(), exclude_file.is_empty()) {
345 (true, true) => "SELECT * FROM nodes WHERE name = ?1",
346 (true, false) => "SELECT * FROM nodes WHERE name = ?1 AND file_path != ?2",
347 (false, true) => "SELECT * FROM nodes WHERE name = ?1 AND kind = ?2",
348 (false, false) => {
349 "SELECT * FROM nodes WHERE name = ?1 AND kind = ?2 AND file_path != ?3"
350 }
351 };
352 let mut stmt = match self.conn.prepare(sql) {
353 Ok(s) => s,
354 Err(e) => {
355 eprintln!("[keel] find_nodes_by_name: prepare failed: {e}");
356 return Vec::new();
357 }
358 };
359 let result = match (kind.is_empty(), exclude_file.is_empty()) {
360 (true, true) => stmt.query_map(params![name], Self::row_to_node),
361 (true, false) => stmt.query_map(params![name, exclude_file], Self::row_to_node),
362 (false, true) => stmt.query_map(params![name, kind], Self::row_to_node),
363 (false, false) => stmt.query_map(params![name, kind, exclude_file], Self::row_to_node),
364 };
365 match result {
366 Ok(rows) => rows.filter_map(|r| r.ok()).collect(),
367 Err(e) => {
368 eprintln!("[keel] find_nodes_by_name: query failed: {e}");
369 Vec::new()
370 }
371 }
372 }
373}