1use crate::graph::EdgeType;
19use rusqlite::{Connection, Result as SqliteResult, params};
20use std::collections::HashMap;
21use std::path::Path;
22use thiserror::Error;
23
/// Schema version stamped into `cross_refs_metadata` under the `schema_version` key.
///
/// `CrossRefStore::open` returns `CrossRefError::SchemaVersionMismatch` when a database's
/// stored version differs from this value. Version "1.1" is the revision that added the
/// `version_spec` and `is_dev_dependency` columns for `DependsOn` edges.
pub const CROSS_REFS_SCHEMA_VERSION: &str = "1.1";
27
/// DDL for the `cross_refs` table: one row per stored cross-reference edge.
///
/// `is_dev_dependency` holds 0/1 (or NULL) because SQLite has no boolean column type.
/// The UNIQUE constraint deduplicates on `(source_id, target_id, edge_type, ref_line)`.
/// SQLite treats NULLs as distinct within UNIQUE constraints, so rows whose `ref_line` is
/// NULL (e.g. `DependsOn` edges built without a line) never collide with each other.
const SCHEMA_CREATE_CROSS_REFS: &str = r#"
CREATE TABLE IF NOT EXISTS cross_refs (
 -- Auto-incrementing ID
 id INTEGER PRIMARY KEY AUTOINCREMENT,

 -- Source node info
 source_id TEXT NOT NULL,
 source_partition TEXT NOT NULL,

 -- Target node info
 target_id TEXT NOT NULL,
 target_partition TEXT NOT NULL,

 -- Edge metadata
 edge_type TEXT NOT NULL,
 ref_line INTEGER,
 ident TEXT,

 -- DependsOn edge metadata (v1.1)
 version_spec TEXT,
 is_dev_dependency INTEGER,

 -- Ensure no duplicate cross-refs
 UNIQUE(source_id, target_id, edge_type, ref_line)
)
"#;
55
/// Secondary indexes on `cross_refs`.
///
/// This string holds several statements, so it must be run with `execute_batch`
/// rather than `execute`. The partition indexes back the `DELETE ... WHERE *_partition = ?`
/// queries used for partition cleanup.
const SCHEMA_CREATE_CROSS_REFS_INDEXES: &str = r#"
-- Index for finding what references a target
CREATE INDEX IF NOT EXISTS idx_cross_refs_target ON cross_refs(target_id);

-- Index for finding what a source references
CREATE INDEX IF NOT EXISTS idx_cross_refs_source ON cross_refs(source_id);

-- Index for partition-based cleanup
CREATE INDEX IF NOT EXISTS idx_cross_refs_source_partition ON cross_refs(source_partition);
CREATE INDEX IF NOT EXISTS idx_cross_refs_target_partition ON cross_refs(target_partition);
"#;
68
/// DDL for the key/value metadata table.
///
/// Currently the only key written by this module is `schema_version`.
const SCHEMA_CREATE_METADATA: &str = r#"
CREATE TABLE IF NOT EXISTS cross_refs_metadata (
 key TEXT PRIMARY KEY NOT NULL,
 value TEXT NOT NULL
)
"#;
76
/// Errors produced by `CrossRefStore` operations.
#[derive(Debug, Error)]
pub enum CrossRefError {
    /// A rusqlite/SQLite call failed (open, pragma, DDL, query, or transaction).
    #[error("SQLite error: {0}")]
    Sqlite(#[from] rusqlite::Error),

    /// The database's stored `schema_version` differs from `CROSS_REFS_SCHEMA_VERSION`.
    #[error("Schema version mismatch: expected {expected}, found {found}")]
    SchemaVersionMismatch {
        /// The version this build of the code understands.
        expected: String,
        /// The version read from `cross_refs_metadata`.
        found: String,
    },

    /// A filesystem operation failed (e.g. creating the database's parent directory).
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
}
89
/// One edge from a source node to a target node, tagged with the partition of each end.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CrossRef {
    /// ID of the node the edge starts from (tests use the form `path:name`).
    pub source_id: String,
    /// Partition that owns the source node.
    pub source_partition: String,
    /// ID of the node the edge points to.
    pub target_id: String,
    /// Partition that owns the target node.
    pub target_partition: String,
    /// Kind of relationship; persisted via `EdgeType::as_str`.
    pub edge_type: EdgeType,
    /// Line number of the reference, if known (presumably in the source file — confirm).
    pub ref_line: Option<usize>,
    /// Identifier text recorded at the reference site, if any.
    pub ident: Option<String>,
    /// Version requirement carried by `DependsOn` edges (schema v1.1).
    pub version_spec: Option<String>,
    /// Whether a `DependsOn` edge is a dev-only dependency (schema v1.1).
    pub is_dev_dependency: Option<bool>,
}
112
113impl CrossRef {
114 pub fn new(
116 source_id: String,
117 source_partition: String,
118 target_id: String,
119 target_partition: String,
120 edge_type: EdgeType,
121 ref_line: Option<usize>,
122 ident: Option<String>,
123 ) -> Self {
124 Self {
125 source_id,
126 source_partition,
127 target_id,
128 target_partition,
129 edge_type,
130 ref_line,
131 ident,
132 version_spec: None,
133 is_dev_dependency: None,
134 }
135 }
136
137 pub fn with_dependency(
139 source_id: String,
140 source_partition: String,
141 target_id: String,
142 target_partition: String,
143 ident: Option<String>,
144 version_spec: Option<String>,
145 is_dev_dependency: Option<bool>,
146 ) -> Self {
147 Self {
148 source_id,
149 source_partition,
150 target_id,
151 target_partition,
152 edge_type: EdgeType::DependsOn,
153 ref_line: None,
154 ident,
155 version_spec,
156 is_dev_dependency,
157 }
158 }
159}
160
/// In-memory index of cross-references, keyed in both directions.
///
/// `add` stores a clone of every reference under its `target_id` and the original under
/// its `source_id`, so "who references X?" and "what does X reference?" are each a single
/// hash-map lookup.
#[derive(Debug, Default)]
pub struct CrossRefIndex {
    /// References grouped by `target_id`.
    by_target: HashMap<String, Vec<CrossRef>>,
    /// References grouped by `source_id`; also the map `len`/`iter` walk.
    by_source: HashMap<String, Vec<CrossRef>>,
}
172
173impl CrossRefIndex {
174 pub fn new() -> Self {
176 Self::default()
177 }
178
179 pub fn add(&mut self, cross_ref: CrossRef) {
181 self.by_target
183 .entry(cross_ref.target_id.clone())
184 .or_default()
185 .push(cross_ref.clone());
186
187 self.by_source
189 .entry(cross_ref.source_id.clone())
190 .or_default()
191 .push(cross_ref);
192 }
193
194 pub fn add_all(&mut self, cross_refs: impl IntoIterator<Item = CrossRef>) {
196 for cross_ref in cross_refs {
197 self.add(cross_ref);
198 }
199 }
200
201 pub fn get_by_target(&self, target_id: &str) -> Option<&Vec<CrossRef>> {
203 self.by_target.get(target_id)
204 }
205
206 pub fn get_by_source(&self, source_id: &str) -> Option<&Vec<CrossRef>> {
208 self.by_source.get(source_id)
209 }
210
211 pub fn remove_by_source_partition(&mut self, partition: &str) {
213 let source_ids_to_remove: Vec<String> = self
215 .by_source
216 .iter()
217 .filter(|(_, refs)| refs.iter().any(|r| r.source_partition == partition))
218 .map(|(id, _)| id.clone())
219 .collect();
220
221 for source_id in &source_ids_to_remove {
223 self.by_source.remove(source_id);
224 }
225
226 for refs in self.by_target.values_mut() {
228 refs.retain(|r| r.source_partition != partition);
229 }
230
231 self.by_target.retain(|_, refs| !refs.is_empty());
233 }
234
235 pub fn remove_by_partition(&mut self, partition: &str) {
237 self.by_source.retain(|_, refs| {
239 refs.retain(|r| r.source_partition != partition && r.target_partition != partition);
240 !refs.is_empty()
241 });
242
243 self.by_target.retain(|_, refs| {
245 refs.retain(|r| r.source_partition != partition && r.target_partition != partition);
246 !refs.is_empty()
247 });
248 }
249
250 pub fn len(&self) -> usize {
252 self.by_source.values().map(|v| v.len()).sum()
253 }
254
255 pub fn is_empty(&self) -> bool {
257 self.by_source.is_empty()
258 }
259
260 pub fn source_partitions(&self) -> impl Iterator<Item = &str> {
262 self.by_source
263 .values()
264 .flat_map(|refs| refs.iter().map(|r| r.source_partition.as_str()))
265 .collect::<std::collections::HashSet<_>>()
266 .into_iter()
267 }
268
269 pub fn target_partitions(&self) -> impl Iterator<Item = &str> {
271 self.by_target
272 .values()
273 .flat_map(|refs| refs.iter().map(|r| r.target_partition.as_str()))
274 .collect::<std::collections::HashSet<_>>()
275 .into_iter()
276 }
277
278 pub fn clear(&mut self) {
280 self.by_source.clear();
281 self.by_target.clear();
282 }
283
284 pub fn iter(&self) -> impl Iterator<Item = &CrossRef> {
286 self.by_source.values().flat_map(|refs| refs.iter())
287 }
288}
289
/// SQLite-backed persistence for cross-references.
///
/// All methods take `&self`; write paths use `Connection::unchecked_transaction`, which
/// defers the nested-transaction check to runtime.
pub struct CrossRefStore {
    /// Connection configured by `configure_connection` (WAL, cache and sync pragmas).
    conn: Connection,
}
298
299impl CrossRefStore {
300 pub fn open(path: &Path) -> Result<Self, CrossRefError> {
302 let conn = Connection::open(path)?;
303 Self::configure_connection(&conn)?;
304
305 let store = Self { conn };
306
307 if let Some(version) = store.get_metadata("schema_version")? {
309 if version != CROSS_REFS_SCHEMA_VERSION {
310 return Err(CrossRefError::SchemaVersionMismatch {
311 expected: CROSS_REFS_SCHEMA_VERSION.to_string(),
312 found: version,
313 });
314 }
315 }
316
317 Ok(store)
318 }
319
320 pub fn create(path: &Path) -> Result<Self, CrossRefError> {
322 if let Some(parent) = path.parent() {
324 std::fs::create_dir_all(parent)?;
325 }
326
327 let conn = Connection::open(path)?;
328 Self::configure_connection(&conn)?;
329
330 conn.execute(SCHEMA_CREATE_CROSS_REFS, [])?;
332 conn.execute(SCHEMA_CREATE_METADATA, [])?;
333 conn.execute_batch(SCHEMA_CREATE_CROSS_REFS_INDEXES)?;
334
335 let store = Self { conn };
336
337 store.set_metadata("schema_version", CROSS_REFS_SCHEMA_VERSION)?;
339
340 Ok(store)
341 }
342
343 pub fn in_memory() -> Result<Self, CrossRefError> {
345 let conn = Connection::open_in_memory()?;
346 Self::configure_connection(&conn)?;
347
348 conn.execute(SCHEMA_CREATE_CROSS_REFS, [])?;
350 conn.execute(SCHEMA_CREATE_METADATA, [])?;
351 conn.execute_batch(SCHEMA_CREATE_CROSS_REFS_INDEXES)?;
352
353 let store = Self { conn };
354
355 store.set_metadata("schema_version", CROSS_REFS_SCHEMA_VERSION)?;
356
357 Ok(store)
358 }
359
360 fn configure_connection(conn: &Connection) -> SqliteResult<()> {
362 conn.pragma_update(None, "journal_mode", "WAL")?;
363 conn.pragma_update(None, "foreign_keys", "ON")?;
364 conn.pragma_update(None, "cache_size", -16000)?; conn.pragma_update(None, "synchronous", "NORMAL")?;
366 conn.pragma_update(None, "temp_store", "MEMORY")?;
367 Ok(())
368 }
369
370 fn get_metadata(&self, key: &str) -> Result<Option<String>, CrossRefError> {
372 let result = self
373 .conn
374 .query_row(
375 "SELECT value FROM cross_refs_metadata WHERE key = ?1",
376 [key],
377 |row| row.get(0),
378 )
379 .optional()?;
380 Ok(result)
381 }
382
383 fn set_metadata(&self, key: &str, value: &str) -> Result<(), CrossRefError> {
385 self.conn.execute(
386 "INSERT OR REPLACE INTO cross_refs_metadata (key, value) VALUES (?1, ?2)",
387 params![key, value],
388 )?;
389 Ok(())
390 }
391
392 pub fn load_all(&self) -> Result<CrossRefIndex, CrossRefError> {
394 let mut index = CrossRefIndex::new();
395
396 let mut stmt = self.conn.prepare(
397 "SELECT source_id, source_partition, target_id, target_partition, edge_type, ref_line, ident, version_spec, is_dev_dependency FROM cross_refs",
398 )?;
399
400 let rows = stmt.query_map([], |row| {
401 let edge_type_str: String = row.get(4)?;
402 let edge_type = match edge_type_str.as_str() {
403 "CONTAINS" => EdgeType::Contains,
404 "USES" => EdgeType::Uses,
405 "DEFINES" => EdgeType::Defines,
406 "DEPENDS_ON" => EdgeType::DependsOn,
407 _ => EdgeType::Uses, };
409
410 Ok(CrossRef {
411 source_id: row.get(0)?,
412 source_partition: row.get(1)?,
413 target_id: row.get(2)?,
414 target_partition: row.get(3)?,
415 edge_type,
416 ref_line: row.get::<_, Option<i64>>(5)?.map(|v| v as usize),
417 ident: row.get(6)?,
418 version_spec: row.get(7)?,
419 is_dev_dependency: row.get::<_, Option<i64>>(8)?.map(|v| v != 0),
420 })
421 })?;
422
423 for row in rows {
424 index.add(row?);
425 }
426
427 Ok(index)
428 }
429
430 pub fn save_all(&self, index: &CrossRefIndex) -> Result<(), CrossRefError> {
432 let tx = self.conn.unchecked_transaction()?;
433
434 tx.execute("DELETE FROM cross_refs", [])?;
436
437 let mut stmt = tx.prepare(
439 "INSERT INTO cross_refs (source_id, source_partition, target_id, target_partition, edge_type, ref_line, ident, version_spec, is_dev_dependency) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
440 )?;
441
442 for cross_ref in index.iter() {
443 stmt.execute(params![
444 cross_ref.source_id,
445 cross_ref.source_partition,
446 cross_ref.target_id,
447 cross_ref.target_partition,
448 cross_ref.edge_type.as_str(),
449 cross_ref.ref_line.map(|v| v as i64),
450 cross_ref.ident,
451 cross_ref.version_spec,
452 cross_ref
453 .is_dev_dependency
454 .map(|b| if b { 1i64 } else { 0i64 }),
455 ])?;
456 }
457
458 drop(stmt);
459 tx.commit()?;
460
461 Ok(())
462 }
463
464 pub fn add_refs(&self, refs: &[CrossRef]) -> Result<(), CrossRefError> {
466 let tx = self.conn.unchecked_transaction()?;
467
468 let mut stmt = tx.prepare(
469 "INSERT OR IGNORE INTO cross_refs (source_id, source_partition, target_id, target_partition, edge_type, ref_line, ident, version_spec, is_dev_dependency) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
470 )?;
471
472 for cross_ref in refs {
473 stmt.execute(params![
474 cross_ref.source_id,
475 cross_ref.source_partition,
476 cross_ref.target_id,
477 cross_ref.target_partition,
478 cross_ref.edge_type.as_str(),
479 cross_ref.ref_line.map(|v| v as i64),
480 cross_ref.ident,
481 cross_ref.version_spec,
482 cross_ref
483 .is_dev_dependency
484 .map(|b| if b { 1i64 } else { 0i64 }),
485 ])?;
486 }
487
488 drop(stmt);
489 tx.commit()?;
490
491 Ok(())
492 }
493
494 pub fn remove_refs_by_partition(&self, partition: &str) -> Result<usize, CrossRefError> {
496 let deleted = self.conn.execute(
497 "DELETE FROM cross_refs WHERE source_partition = ?1 OR target_partition = ?1",
498 [partition],
499 )?;
500 Ok(deleted)
501 }
502
503 pub fn remove_refs_by_source_partition(&self, partition: &str) -> Result<usize, CrossRefError> {
505 let deleted = self.conn.execute(
506 "DELETE FROM cross_refs WHERE source_partition = ?1",
507 [partition],
508 )?;
509 Ok(deleted)
510 }
511
512 pub fn count(&self) -> Result<usize, CrossRefError> {
514 let count: i64 = self
515 .conn
516 .query_row("SELECT COUNT(*) FROM cross_refs", [], |row| row.get(0))?;
517 Ok(count as usize)
518 }
519}
520
521use rusqlite::OptionalExtension;
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic fixture: source `n` lives in one of three partitions and uses one
    /// of five helpers that all live in `partition_lib`.
    fn sample_cross_ref(n: usize) -> CrossRef {
        let helper = format!("helper{}", n % 5);
        CrossRef::new(
            format!("src/mod{0}.rs:func{0}", n),
            format!("partition_{}", n % 3),
            format!("src/lib.rs:{}", helper),
            "partition_lib".to_string(),
            EdgeType::Uses,
            Some(n + 10),
            Some(helper),
        )
    }

    /// Index pre-filled with `sample_cross_ref(i)` for every `i` in `ids`.
    fn index_of(ids: impl IntoIterator<Item = usize>) -> CrossRefIndex {
        let mut index = CrossRefIndex::new();
        index.add_all(ids.into_iter().map(sample_cross_ref));
        index
    }

    /// Edge of the given kind with no line number or identifier.
    fn edge(
        source: &str,
        source_part: &str,
        target: &str,
        target_part: &str,
        kind: EdgeType,
    ) -> CrossRef {
        CrossRef::new(
            source.to_owned(),
            source_part.to_owned(),
            target.to_owned(),
            target_part.to_owned(),
            kind,
            None,
            None,
        )
    }

    /// `Uses` edge with no line number or identifier.
    fn bare_ref(source: &str, source_part: &str, target: &str, target_part: &str) -> CrossRef {
        edge(source, source_part, target, target_part, EdgeType::Uses)
    }

    #[test]
    fn test_index_add_and_get() {
        let expected = sample_cross_ref(1);
        let mut index = CrossRefIndex::new();
        index.add(expected.clone());

        let by_target = index
            .get_by_target("src/lib.rs:helper1")
            .expect("target entry");
        assert_eq!(by_target.as_slice(), std::slice::from_ref(&expected));

        let by_source = index
            .get_by_source("src/mod1.rs:func1")
            .expect("source entry");
        assert_eq!(by_source.as_slice(), std::slice::from_ref(&expected));
    }

    #[test]
    fn test_index_multiple_refs_to_same_target() {
        let mut index = CrossRefIndex::new();
        index.add_all((0..5).map(|i| {
            CrossRef::new(
                format!("src/mod{0}.rs:caller{0}", i),
                format!("partition_{}", i),
                "src/lib.rs:shared_func".to_string(),
                "partition_lib".to_string(),
                EdgeType::Uses,
                Some(i + 10),
                Some("shared_func".to_string()),
            )
        }));

        let shared = index.get_by_target("src/lib.rs:shared_func").unwrap();
        assert_eq!(shared.len(), 5);
    }

    #[test]
    fn test_index_remove_by_source_partition() {
        let mut index = index_of(0..10);
        assert_ne!(index.len(), 0);

        index.remove_by_source_partition("partition_0");

        assert!(index
            .by_source
            .values()
            .flatten()
            .all(|r| r.source_partition != "partition_0"));
    }

    #[test]
    fn test_index_remove_by_partition() {
        let mut index = CrossRefIndex::new();
        index.add_all(vec![
            bare_ref("a:func", "part_a", "b:target", "part_b"),
            bare_ref("c:func", "part_c", "a:target", "part_a"),
            bare_ref("d:func", "part_d", "e:target", "part_e"),
        ]);
        assert_eq!(index.len(), 3);

        // Drops the first two (part_a as source, then as target); keeps d -> e.
        index.remove_by_partition("part_a");
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_index_len_and_is_empty() {
        let mut index = CrossRefIndex::new();
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);

        index.add(sample_cross_ref(1));
        assert!(!index.is_empty());
        assert_eq!(index.len(), 1);

        index.add(sample_cross_ref(2));
        assert_eq!(index.len(), 2);
    }

    #[test]
    fn test_index_clear() {
        let mut index = index_of(0..5);
        assert!(!index.is_empty());

        index.clear();
        assert!(index.is_empty());
    }

    #[test]
    fn test_index_iter() {
        let index = index_of(0..5);
        assert_eq!(index.iter().count(), 5);
    }

    #[test]
    fn test_store_create_and_open() {
        let store = CrossRefStore::in_memory().expect("in-memory store");
        assert_eq!(store.count().unwrap(), 0);
    }

    #[test]
    fn test_store_save_and_load() {
        let store = CrossRefStore::in_memory().unwrap();

        store.save_all(&index_of(0..10)).unwrap();
        assert_eq!(store.count().unwrap(), 10);

        let loaded = store.load_all().unwrap();
        assert_eq!(loaded.len(), 10);

        let missing: Vec<String> = (0..10)
            .map(|i| format!("src/lib.rs:helper{}", i % 5))
            .filter(|target| loaded.get_by_target(target).is_none())
            .collect();
        assert!(missing.is_empty(), "targets missing after reload: {:?}", missing);
    }

    #[test]
    fn test_store_add_refs() {
        let store = CrossRefStore::in_memory().unwrap();

        for (batch, expected_total) in [(0..5, 5), (5..10, 10)] {
            let refs: Vec<CrossRef> = batch.map(sample_cross_ref).collect();
            store.add_refs(&refs).unwrap();
            assert_eq!(store.count().unwrap(), expected_total);
        }
    }

    #[test]
    fn test_store_remove_by_partition() {
        let store = CrossRefStore::in_memory().unwrap();
        let refs: Vec<CrossRef> = (0..10).map(sample_cross_ref).collect();
        store.add_refs(&refs).unwrap();
        assert_eq!(store.count().unwrap(), 10);

        // Every sample targets partition_lib, so this empties the table.
        assert_eq!(store.remove_refs_by_partition("partition_lib").unwrap(), 10);
        assert_eq!(store.count().unwrap(), 0);
    }

    #[test]
    fn test_store_remove_by_source_partition() {
        let store = CrossRefStore::in_memory().unwrap();
        let refs: Vec<CrossRef> = (0..9).map(sample_cross_ref).collect();
        store.add_refs(&refs).unwrap();
        assert_eq!(store.count().unwrap(), 9);

        // Samples 0, 3 and 6 fall in partition_0.
        let removed = store
            .remove_refs_by_source_partition("partition_0")
            .unwrap();
        assert_eq!(removed, 3);
        assert_eq!(store.count().unwrap(), 6);
    }

    #[test]
    fn test_store_roundtrip_edge_types() {
        let store = CrossRefStore::in_memory().unwrap();

        let mut index = CrossRefIndex::new();
        index.add_all(vec![
            edge("a:x", "p1", "b:y", "p2", EdgeType::Contains),
            CrossRef {
                ref_line: Some(42),
                ident: Some("ident".to_string()),
                ..edge("c:x", "p1", "d:y", "p2", EdgeType::Uses)
            },
            edge("e:x", "p1", "f:y", "p2", EdgeType::Defines),
        ]);

        store.save_all(&index).unwrap();
        let loaded = store.load_all().unwrap();
        assert_eq!(loaded.len(), 3);

        let contains = &loaded.get_by_source("a:x").unwrap()[0];
        assert_eq!(contains.edge_type, EdgeType::Contains);

        let uses = &loaded.get_by_source("c:x").unwrap()[0];
        assert_eq!(uses.edge_type, EdgeType::Uses);
        assert_eq!(uses.ref_line, Some(42));
        assert_eq!(uses.ident.as_deref(), Some("ident"));

        let defines = &loaded.get_by_source("e:x").unwrap()[0];
        assert_eq!(defines.edge_type, EdgeType::Defines);
    }

    #[test]
    fn test_store_duplicate_handling() {
        let store = CrossRefStore::in_memory().unwrap();
        let cross_ref = sample_cross_ref(1);

        for _ in 0..2 {
            store.add_refs(std::slice::from_ref(&cross_ref)).unwrap();
        }

        assert_eq!(store.count().unwrap(), 1);
    }
}