1use std::path::Path;
7
8use rusqlite::{Connection, params};
9use tracing::debug;
10
11use crate::error::{Result, SedimentError};
12
/// An edge pointing at `target_id`, labelled with a relation name and a strength.
///
/// NOTE(review): this type is not constructed anywhere in this file; it looks
/// like a public data-transfer type for callers — confirm its intended use.
#[derive(Debug, Clone)]
pub struct Edge {
    /// Id of the node on the far end of the edge.
    pub target_id: String,
    /// Relation label (free-form text supplied by the caller).
    pub rel_type: String,
    /// Edge weight; larger presumably means a stronger link — confirm with callers.
    pub strength: f64,
}
20
/// A co-access link to `target_id` together with how many times it was recorded.
///
/// NOTE(review): not constructed in this file; presumably mirrors the
/// `count` column of `co_accessed` rows in `graph_edges` — confirm.
#[derive(Debug, Clone)]
pub struct CoAccessEdge {
    /// Id of the node that was accessed together with the source node.
    pub target_id: String,
    /// Number of recorded co-accesses.
    pub count: i64,
}
27
/// One connection of a node, as reported by `GraphStore::get_full_connections`.
#[derive(Debug, Clone)]
pub struct ConnectionInfo {
    /// Id of the node on the other end of the edge (edge direction is not kept).
    pub target_id: String,
    /// The edge's `rel_type` for `related` edges; otherwise the edge type
    /// itself (e.g. `"supersedes"` or `"co_accessed"`).
    pub rel_type: String,
    /// Stored edge strength.
    pub strength: f64,
    /// Co-access count; `Some` only for `co_accessed` edges.
    pub count: Option<i64>,
}
36
37pub struct GraphStore {
39 conn: Connection,
40}
41
42impl GraphStore {
43 pub fn open(path: &Path) -> Result<Self> {
46 let conn = Connection::open(path).map_err(|e| {
47 SedimentError::Database(format!("Failed to open graph database: {}", e))
48 })?;
49
50 if let Err(e) = conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=5000;") {
51 tracing::warn!("Failed to set SQLite PRAGMAs (graph): {}", e);
52 }
53
54 conn.execute_batch(
55 "CREATE TABLE IF NOT EXISTS graph_nodes (
56 id TEXT PRIMARY KEY,
57 project_id TEXT NOT NULL DEFAULT '',
58 created_at INTEGER NOT NULL
59 );
60
61 CREATE TABLE IF NOT EXISTS graph_edges (
62 from_id TEXT NOT NULL,
63 to_id TEXT NOT NULL,
64 edge_type TEXT NOT NULL,
65 strength REAL NOT NULL DEFAULT 0.0,
66 rel_type TEXT NOT NULL DEFAULT '',
67 count INTEGER NOT NULL DEFAULT 0,
68 last_at INTEGER NOT NULL DEFAULT 0,
69 created_at INTEGER NOT NULL,
70 UNIQUE(from_id, to_id, edge_type)
71 );
72
73 CREATE INDEX IF NOT EXISTS idx_edges_from ON graph_edges(from_id);
74 CREATE INDEX IF NOT EXISTS idx_edges_to ON graph_edges(to_id);",
75 )
76 .map_err(|e| SedimentError::Database(format!("Failed to create graph tables: {}", e)))?;
77
78 Ok(Self { conn })
79 }
80
81 pub fn add_node(&self, id: &str, project_id: Option<&str>, created_at: i64) -> Result<()> {
83 let pid = project_id.unwrap_or("");
84
85 self.conn
86 .execute(
87 "INSERT OR IGNORE INTO graph_nodes (id, project_id, created_at) VALUES (?1, ?2, ?3)",
88 params![id, pid, created_at],
89 )
90 .map_err(|e| SedimentError::Database(format!("Failed to add node: {}", e)))?;
91
92 debug!("Added graph node: {}", id);
93 Ok(())
94 }
95
96 pub fn ensure_node_exists(
98 &self,
99 id: &str,
100 project_id: Option<&str>,
101 created_at: i64,
102 ) -> Result<()> {
103 self.add_node(id, project_id, created_at)
104 }
105
106 pub fn remove_node(&self, id: &str) -> Result<()> {
108 self.conn
111 .execute(
112 "DELETE FROM graph_edges WHERE from_id = ?1 OR (to_id = ?1 AND edge_type != 'supersedes')",
113 params![id],
114 )
115 .map_err(|e| SedimentError::Database(format!("Failed to remove edges: {}", e)))?;
116
117 self.conn
118 .execute("DELETE FROM graph_nodes WHERE id = ?1", params![id])
119 .map_err(|e| SedimentError::Database(format!("Failed to remove node: {}", e)))?;
120
121 debug!("Removed graph node: {}", id);
122 Ok(())
123 }
124
125 pub fn add_related_edge(
127 &self,
128 from_id: &str,
129 to_id: &str,
130 strength: f64,
131 rel_type: &str,
132 ) -> Result<()> {
133 let now = chrono::Utc::now().timestamp();
134
135 self.conn
136 .execute(
137 "INSERT OR IGNORE INTO graph_edges (from_id, to_id, edge_type, strength, rel_type, created_at)
138 VALUES (?1, ?2, 'related', ?3, ?4, ?5)",
139 params![from_id, to_id, strength, rel_type, now],
140 )
141 .map_err(|e| SedimentError::Database(format!("Failed to add related edge: {}", e)))?;
142
143 debug!(
144 "Added RELATED edge: {} -> {} ({})",
145 from_id, to_id, rel_type
146 );
147 Ok(())
148 }
149
150 pub fn add_supersedes_edge(&self, new_id: &str, old_id: &str) -> Result<()> {
152 let now = chrono::Utc::now().timestamp();
153
154 self.conn
155 .execute(
156 "INSERT OR IGNORE INTO graph_edges (from_id, to_id, edge_type, strength, created_at)
157 VALUES (?1, ?2, 'supersedes', 1.0, ?3)",
158 params![new_id, old_id, now],
159 )
160 .map_err(|e| SedimentError::Database(format!("Failed to add supersedes edge: {}", e)))?;
161
162 debug!("Added SUPERSEDES edge: {} -> {}", new_id, old_id);
163 Ok(())
164 }
165
166 pub fn get_neighbors(
173 &self,
174 ids: &[&str],
175 min_strength: f64,
176 ) -> Result<Vec<(String, String, f64)>> {
177 if ids.is_empty() {
178 return Ok(Vec::new());
179 }
180
181 let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("?{}", i)).collect();
182 let ph = placeholders.join(",");
183 let strength_idx = ids.len() + 1;
184
185 let sql = format!(
186 "SELECT
187 CASE WHEN from_id IN ({ph}) THEN to_id ELSE from_id END AS neighbor,
188 CASE WHEN edge_type = 'related' THEN rel_type ELSE 'supersedes' END AS rtype,
189 strength
190 FROM graph_edges
191 WHERE (from_id IN ({ph}) OR to_id IN ({ph}))
192 AND edge_type IN ('related', 'supersedes')
193 AND strength >= ?{strength_idx}
194 LIMIT 100"
195 );
196
197 let mut stmt = self.conn.prepare(&sql).map_err(|e| {
198 SedimentError::Database(format!("Failed to prepare neighbors query: {}", e))
199 })?;
200
201 let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
202 for id in ids {
203 param_values.push(Box::new(id.to_string()));
204 }
205 param_values.push(Box::new(min_strength));
206
207 let params_ref: Vec<&dyn rusqlite::types::ToSql> =
208 param_values.iter().map(|b| b.as_ref()).collect();
209
210 let rows = stmt
211 .query_map(params_ref.as_slice(), |row| {
212 Ok((
213 row.get::<_, String>(0)?,
214 row.get::<_, String>(1)?,
215 row.get::<_, f64>(2)?,
216 ))
217 })
218 .map_err(|e| SedimentError::Database(format!("Failed to query neighbors: {}", e)))?;
219
220 let input_set: std::collections::HashSet<&str> = ids.iter().copied().collect();
222 let mut results = Vec::new();
223 for row in rows {
224 let r = row
225 .map_err(|e| SedimentError::Database(format!("Failed to read neighbor: {}", e)))?;
226 if !input_set.contains(r.0.as_str()) {
227 results.push(r);
228 }
229 }
230
231 Ok(results)
232 }
233
234 pub fn record_co_access(&self, item_ids: &[String]) -> Result<()> {
237 if item_ids.len() < 2 {
238 return Ok(());
239 }
240
241 let item_ids = if item_ids.len() > 3 {
245 &item_ids[..3]
246 } else {
247 item_ids
248 };
249
250 let now = chrono::Utc::now().timestamp();
251
252 for i in 0..item_ids.len() {
253 for j in (i + 1)..item_ids.len() {
254 let (a, b) = if item_ids[i] <= item_ids[j] {
257 (&item_ids[i], &item_ids[j])
258 } else {
259 (&item_ids[j], &item_ids[i])
260 };
261
262 self.conn
263 .execute(
264 "INSERT INTO graph_edges (from_id, to_id, edge_type, count, last_at, created_at)
265 VALUES (?1, ?2, 'co_accessed', 1, ?3, ?3)
266 ON CONFLICT(from_id, to_id, edge_type)
267 DO UPDATE SET count = count + 1, last_at = ?3",
268 params![a, b, now],
269 )
270 .map_err(|e| {
271 SedimentError::Database(format!("Failed to record co-access: {}", e))
272 })?;
273 }
274 }
275
276 Ok(())
277 }
278
279 pub fn get_co_accessed(&self, ids: &[&str], min_count: i64) -> Result<Vec<(String, i64)>> {
282 if ids.is_empty() {
283 return Ok(Vec::new());
284 }
285
286 let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("?{}", i)).collect();
287 let ph = placeholders.join(",");
288 let min_idx = ids.len() + 1;
289
290 let sql = format!(
291 "SELECT
292 CASE WHEN from_id IN ({ph}) THEN to_id ELSE from_id END AS neighbor,
293 count
294 FROM graph_edges
295 WHERE (from_id IN ({ph}) OR to_id IN ({ph}))
296 AND edge_type = 'co_accessed'
297 AND count >= ?{min_idx}
298 ORDER BY count DESC"
299 );
300
301 let mut stmt = self.conn.prepare(&sql).map_err(|e| {
302 SedimentError::Database(format!("Failed to prepare co-access query: {}", e))
303 })?;
304
305 let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
306 for id in ids {
307 param_values.push(Box::new(id.to_string()));
308 }
309 param_values.push(Box::new(min_count));
310
311 let params_ref: Vec<&dyn rusqlite::types::ToSql> =
312 param_values.iter().map(|b| b.as_ref()).collect();
313
314 let rows = stmt
315 .query_map(params_ref.as_slice(), |row| {
316 Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?))
317 })
318 .map_err(|e| SedimentError::Database(format!("Failed to query co-access: {}", e)))?;
319
320 let mut results = Vec::new();
321 for row in rows {
322 let r = row
323 .map_err(|e| SedimentError::Database(format!("Failed to read co-access: {}", e)))?;
324 results.push(r);
325 }
326
327 results.sort_by(|a, b| b.1.cmp(&a.1));
329 let mut seen = std::collections::HashSet::new();
330 results.retain(|(id, _)| seen.insert(id.clone()));
331
332 Ok(results)
333 }
334
335 pub fn transfer_edges(&self, from_id: &str, to_id: &str) -> Result<()> {
337 let mut stmt = self
339 .conn
340 .prepare(
341 "SELECT from_id, to_id, strength, rel_type, created_at
342 FROM graph_edges
343 WHERE (from_id = ?1 OR to_id = ?1)
344 AND edge_type = 'related'
345 AND from_id != ?2 AND to_id != ?2",
346 )
347 .map_err(|e| {
348 SedimentError::Database(format!("Failed to prepare transfer query: {}", e))
349 })?;
350
351 let edges: Vec<(String, f64, String, i64)> = stmt
352 .query_map(params![from_id, to_id], |row| {
353 let fid: String = row.get(0)?;
354 let tid: String = row.get(1)?;
355 let neighbor = if fid == from_id { tid } else { fid };
356 Ok((neighbor, row.get(2)?, row.get(3)?, row.get(4)?))
357 })
358 .map_err(|e| {
359 SedimentError::Database(format!("Failed to query edges for transfer: {}", e))
360 })?
361 .filter_map(|r| match r {
362 Ok(v) => Some(v),
363 Err(e) => {
364 tracing::warn!("transfer_edges: failed to read row: {}", e);
365 None
366 }
367 })
368 .collect();
369
370 for (neighbor, strength, rel_type, _) in &edges {
372 if let Err(e) = self.add_related_edge(to_id, neighbor, *strength, rel_type) {
373 tracing::warn!("transfer edge to {} failed: {}", neighbor, e);
374 }
375 }
376
377 Ok(())
378 }
379
380 pub fn detect_clusters(&self) -> Result<Vec<(String, String, String)>> {
383 let mut stmt = self
384 .conn
385 .prepare(
386 "WITH biedges AS (
387 SELECT from_id AS a, to_id AS b FROM graph_edges WHERE edge_type = 'related'
388 UNION ALL
389 SELECT to_id AS a, from_id AS b FROM graph_edges WHERE edge_type = 'related'
390 )
391 SELECT DISTINCT e1.a, e1.b, e2.b
392 FROM biedges e1
393 JOIN biedges e2 ON e1.b = e2.a
394 JOIN biedges e3 ON e2.b = e3.a AND e3.b = e1.a
395 WHERE e1.a < e1.b AND e1.b < e2.b
396 LIMIT 50",
397 )
398 .map_err(|e| SedimentError::Database(format!("Failed to detect clusters: {}", e)))?;
399
400 let rows = stmt
401 .query_map([], |row| {
402 Ok((
403 row.get::<_, String>(0)?,
404 row.get::<_, String>(1)?,
405 row.get::<_, String>(2)?,
406 ))
407 })
408 .map_err(|e| SedimentError::Database(format!("Failed to read clusters: {}", e)))?;
409
410 let mut clusters = Vec::new();
411 for r in rows.flatten() {
412 clusters.push(r);
413 }
414
415 Ok(clusters)
416 }
417
418 pub fn get_full_connections(&self, item_id: &str) -> Result<Vec<ConnectionInfo>> {
420 let mut stmt = self
421 .conn
422 .prepare(
423 "SELECT
424 CASE WHEN from_id = ?1 THEN to_id ELSE from_id END AS neighbor,
425 edge_type,
426 strength,
427 rel_type,
428 count
429 FROM graph_edges
430 WHERE from_id = ?1 OR to_id = ?1",
431 )
432 .map_err(|e| {
433 SedimentError::Database(format!("Failed to prepare connections query: {}", e))
434 })?;
435
436 let rows = stmt
437 .query_map(params![item_id], |row| {
438 let neighbor: String = row.get(0)?;
439 let edge_type: String = row.get(1)?;
440 let strength: f64 = row.get(2)?;
441 let rel_type_val: String = row.get(3)?;
442 let count: i64 = row.get(4)?;
443
444 let display_type = match edge_type.as_str() {
445 "related" => rel_type_val.clone(),
446 "supersedes" => "supersedes".to_string(),
447 "co_accessed" => "co_accessed".to_string(),
448 _ => edge_type.clone(),
449 };
450
451 Ok(ConnectionInfo {
452 target_id: neighbor,
453 rel_type: display_type,
454 strength,
455 count: if edge_type == "co_accessed" {
456 Some(count)
457 } else {
458 None
459 },
460 })
461 })
462 .map_err(|e| SedimentError::Database(format!("Failed to query connections: {}", e)))?;
463
464 let mut connections = Vec::new();
465 for row in rows {
466 let r = row.map_err(|e| {
467 SedimentError::Database(format!("Failed to read connection: {}", e))
468 })?;
469 connections.push(r);
470 }
471
472 Ok(connections)
473 }
474
475 pub fn get_edge_count(&self, item_id: &str) -> Result<u32> {
477 let count: i64 = self
478 .conn
479 .query_row(
480 "SELECT COUNT(*) FROM graph_edges WHERE from_id = ?1 OR to_id = ?1",
481 params![item_id],
482 |row| row.get(0),
483 )
484 .map_err(|e| SedimentError::Database(format!("Failed to count edges: {}", e)))?;
485
486 Ok(count as u32)
487 }
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Opens a fresh graph database inside its own temporary directory.
    ///
    /// The `TempDir` is returned so the test keeps it alive while the
    /// connection is in use. Previously a `NamedTempFile` was dropped on
    /// return, which deleted the database file out from under the open
    /// connection. Dropping the directory also removes SQLite's WAL side files.
    fn open_test_graph() -> (TempDir, GraphStore) {
        let dir = TempDir::new().unwrap();
        let graph = GraphStore::open(&dir.path().join("graph.db")).unwrap();
        (dir, graph)
    }

    #[test]
    fn test_get_neighbors_excludes_input_ids() {
        let (_dir, graph) = open_test_graph();
        let now = chrono::Utc::now().timestamp();
        graph.add_node("A", Some("proj"), now).unwrap();
        graph.add_node("B", Some("proj"), now).unwrap();
        graph.add_node("C", Some("proj"), now).unwrap();

        graph.add_related_edge("A", "B", 0.9, "test").unwrap();
        graph.add_related_edge("B", "C", 0.9, "test").unwrap();

        let neighbors = graph.get_neighbors(&["A", "B"], 0.0).unwrap();
        let neighbor_ids: Vec<&str> = neighbors.iter().map(|(id, _, _)| id.as_str()).collect();
        assert!(neighbor_ids.contains(&"C"));
        assert!(!neighbor_ids.contains(&"A"));
        assert!(!neighbor_ids.contains(&"B"));
    }

    #[test]
    fn test_co_access_normalized_direction() {
        let (_dir, graph) = open_test_graph();
        let now = chrono::Utc::now().timestamp();
        graph.add_node("Z", Some("proj"), now).unwrap();
        graph.add_node("A", Some("proj"), now).unwrap();

        graph
            .record_co_access(&["Z".to_string(), "A".to_string()])
            .unwrap();
        graph
            .record_co_access(&["A".to_string(), "Z".to_string()])
            .unwrap();

        let count: i64 = graph
            .conn
            .query_row(
                "SELECT COUNT(*) FROM graph_edges WHERE edge_type = 'co_accessed'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(count, 1, "Should have exactly 1 co-access edge");

        let edge_count: i64 = graph
            .conn
            .query_row(
                "SELECT count FROM graph_edges WHERE edge_type = 'co_accessed'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(edge_count, 2, "Edge count should be 2 (incremented twice)");
    }

    #[test]
    fn test_transfer_edges_preserves_relationships() {
        let (_dir, graph) = open_test_graph();
        let now = chrono::Utc::now().timestamp();
        graph.add_node("old", Some("proj"), now).unwrap();
        graph.add_node("new", Some("proj"), now).unwrap();
        graph.add_node("friend", Some("proj"), now).unwrap();

        graph
            .add_related_edge("old", "friend", 0.9, "test")
            .unwrap();

        graph.transfer_edges("old", "new").unwrap();

        let neighbors = graph.get_neighbors(&["new"], 0.0).unwrap();
        assert!(
            !neighbors.is_empty(),
            "New node should have inherited edges"
        );
        let neighbor_ids: Vec<&str> = neighbors.iter().map(|(id, _, _)| id.as_str()).collect();
        assert!(neighbor_ids.contains(&"friend"));
    }

    #[test]
    fn test_remove_node_preserves_incoming_supersedes() {
        let (_dir, graph) = open_test_graph();
        let now = chrono::Utc::now().timestamp();
        graph.add_node("new", Some("proj"), now).unwrap();
        graph.add_node("old", Some("proj"), now).unwrap();

        graph.add_supersedes_edge("new", "old").unwrap();

        graph.remove_node("old").unwrap();

        let connections = graph.get_full_connections("new").unwrap();
        assert_eq!(connections.len(), 1, "SUPERSEDES edge should be preserved");
        assert_eq!(connections[0].target_id, "old");
        assert_eq!(connections[0].rel_type, "supersedes");
    }

    #[test]
    fn test_get_neighbors_bounded() {
        let (_dir, graph) = open_test_graph();
        let now = chrono::Utc::now().timestamp();
        graph.add_node("center", Some("proj"), now).unwrap();

        for i in 0..150 {
            let id = format!("n{}", i);
            graph.add_node(&id, Some("proj"), now).unwrap();
            graph.add_related_edge("center", &id, 0.9, "test").unwrap();
        }

        let neighbors = graph.get_neighbors(&["center"], 0.0).unwrap();
        assert!(
            neighbors.len() <= 100,
            "get_neighbors should return at most 100, got {}",
            neighbors.len()
        );
    }

    #[test]
    fn test_detect_clusters_finds_triangle_once() {
        let (_dir, graph) = open_test_graph();
        graph.add_related_edge("A", "B", 0.9, "test").unwrap();
        graph.add_related_edge("B", "C", 0.9, "test").unwrap();
        graph.add_related_edge("C", "A", 0.9, "test").unwrap();

        let clusters = graph.detect_clusters().unwrap();
        assert_eq!(
            clusters,
            vec![("A".to_string(), "B".to_string(), "C".to_string())]
        );
    }

    #[test]
    fn test_get_edge_count_counts_both_directions() {
        let (_dir, graph) = open_test_graph();
        graph.add_related_edge("A", "B", 0.9, "test").unwrap();
        graph.add_related_edge("B", "C", 0.9, "test").unwrap();

        assert_eq!(graph.get_edge_count("B").unwrap(), 2);
        assert_eq!(graph.get_edge_count("A").unwrap(), 1);
        assert_eq!(graph.get_edge_count("missing").unwrap(), 0);
    }
}
627}