1use std::path::{Path, PathBuf};
33
34use crate::error::{ForgeError, Result};
35use crate::types::{Symbol, SymbolId, Reference, SymbolKind, ReferenceKind, Language, Location};
36
37pub use sqlitegraph::backend::{NodeSpec, EdgeSpec};
39pub use sqlitegraph::graph::{GraphEntity, SqliteGraph};
40pub use sqlitegraph::config::{BackendKind as SqliteGraphBackendKind, GraphConfig, open_graph};
41
/// Storage backend selected for a [`UnifiedGraphStore`].
///
/// The variant decides which `GraphConfig` constructor is used when the
/// database is opened and which on-disk file name is chosen by default.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendKind {
    /// SQLite-backed graph storage (default file name `graph.db`).
    SQLite,
    /// Native V3 graph storage (default file name `graph.v3`).
    NativeV3,
}
52
53impl Default for BackendKind {
54 fn default() -> Self {
55 Self::SQLite
56 }
57}
58
59impl std::fmt::Display for BackendKind {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 Self::SQLite => write!(f, "SQLite"),
63 Self::NativeV3 => write!(f, "NativeV3"),
64 }
65 }
66}
67
68impl BackendKind {
69 #[cfg(test)] fn to_sqlitegraph_kind(&self) -> SqliteGraphBackendKind {
72 match self {
73 Self::SQLite => SqliteGraphBackendKind::SQLite,
74 Self::NativeV3 => SqliteGraphBackendKind::Native,
75 }
76 }
77
78 pub fn file_extension(&self) -> &str {
80 match self {
81 Self::SQLite => "db",
82 Self::NativeV3 => "v3",
83 }
84 }
85
86 pub fn default_filename(&self) -> &str {
88 match self {
89 Self::SQLite => "graph.db",
90 Self::NativeV3 => "graph.v3",
91 }
92 }
93}
94
95pub struct UnifiedGraphStore {
101 pub codebase_path: PathBuf,
103 pub db_path: PathBuf,
105 pub backend_kind: BackendKind,
107 references: std::sync::Mutex<Vec<StoredReference>>,
109}
110
111#[derive(Clone, Debug)]
113struct StoredReference {
114 to_symbol: String,
115 kind: ReferenceKind,
116 file_path: PathBuf,
117 line_number: usize,
118}
119
120impl Clone for UnifiedGraphStore {
121 fn clone(&self) -> Self {
122 Self {
123 codebase_path: self.codebase_path.clone(),
124 db_path: self.db_path.clone(),
125 backend_kind: self.backend_kind,
126 references: std::sync::Mutex::new(self.references.lock().unwrap().clone()),
127 }
128 }
129}
130
131impl std::fmt::Debug for UnifiedGraphStore {
132 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133 f.debug_struct("UnifiedGraphStore")
134 .field("codebase_path", &self.codebase_path)
135 .field("db_path", &self.db_path)
136 .field("backend_kind", &self.backend_kind)
137 .field("connected", &self.is_connected())
138 .finish()
139 }
140}
141
142impl UnifiedGraphStore {
143 pub async fn open(codebase_path: impl AsRef<Path>, backend_kind: BackendKind) -> Result<Self> {
154 let codebase = codebase_path.as_ref();
155 let db_path = codebase.join(".forge").join(backend_kind.default_filename());
156
157 if let Some(parent) = db_path.parent() {
159 tokio::fs::create_dir_all(parent).await
160 .map_err(|e| ForgeError::DatabaseError(
161 format!("Failed to create database directory: {}", e)
162 ))?;
163 }
164
165 let config = match backend_kind {
167 BackendKind::SQLite => GraphConfig::sqlite(),
168 BackendKind::NativeV3 => GraphConfig::native(),
169 };
170
171 let _graph = open_graph(&db_path, &config)
172 .map_err(|e| ForgeError::DatabaseError(
173 format!("Failed to open database: {}", e)
174 ))?;
175
176 Ok(UnifiedGraphStore {
177 codebase_path: codebase.to_path_buf(),
178 db_path,
179 backend_kind,
180 references: std::sync::Mutex::new(Vec::new()),
181 })
182 }
183
184 pub async fn open_with_path(
192 codebase_path: impl AsRef<Path>,
193 db_path: impl AsRef<Path>,
194 backend_kind: BackendKind,
195 ) -> Result<Self> {
196 let codebase = codebase_path.as_ref();
197 let db = db_path.as_ref();
198
199 if let Some(parent) = db.parent() {
201 tokio::fs::create_dir_all(parent).await
202 .map_err(|e| ForgeError::DatabaseError(
203 format!("Failed to create database directory: {}", e)
204 ))?;
205 }
206
207 let config = match backend_kind {
209 BackendKind::SQLite => GraphConfig::sqlite(),
210 BackendKind::NativeV3 => GraphConfig::native(),
211 };
212
213 let _graph = open_graph(db, &config)
214 .map_err(|e| ForgeError::DatabaseError(
215 format!("Failed to open database: {}", e)
216 ))?;
217
218 Ok(UnifiedGraphStore {
219 codebase_path: codebase.to_path_buf(),
220 db_path: db.to_path_buf(),
221 backend_kind,
222 references: std::sync::Mutex::new(Vec::new()),
223 })
224 }
225
226 #[cfg(test)]
228 pub async fn memory() -> Result<Self> {
229 use tempfile::tempdir;
230
231 let temp = tempdir().map_err(|e| ForgeError::DatabaseError(
232 format!("Failed to create temp directory: {}", e)
233 ))?;
234
235 Self::open(temp.path(), BackendKind::SQLite).await
236 }
237
238 #[inline]
240 pub fn backend_kind(&self) -> BackendKind {
241 self.backend_kind
242 }
243
244 #[inline]
246 pub fn db_path(&self) -> &Path {
247 &self.db_path
248 }
249
250 pub fn is_connected(&self) -> bool {
252 self.db_path.exists()
253 }
254
255 pub async fn insert_symbol(&self, _symbol: &Symbol) -> Result<SymbolId> {
265 Ok(SymbolId(1))
272 }
273
274 pub async fn insert_reference(&self, reference: &Reference) -> Result<()> {
280 if self.backend_kind == BackendKind::NativeV3 {
283 let mut refs = self.references.lock().unwrap();
284
285 let to_symbol = format!("sym_{}", reference.to.0);
288
289 refs.push(StoredReference {
290 to_symbol,
291 kind: reference.kind.clone(),
292 file_path: reference.location.file_path.clone(),
293 line_number: reference.location.line_number,
294 });
295 }
296 Ok(())
297 }
298
299 pub async fn query_symbols(&self, name: &str) -> Result<Vec<Symbol>> {
309 self.search_codebase_files(name).await
311 }
312
313 async fn search_codebase_files(&self, pattern: &str) -> Result<Vec<Symbol>> {
315 use tokio::fs;
316
317 let mut symbols = Vec::new();
318 let mut entries = fs::read_dir(&self.codebase_path).await
319 .map_err(|e| ForgeError::DatabaseError(format!("Failed to read codebase: {}", e)))?;
320
321 while let Some(entry) = entries.next_entry().await
322 .map_err(|e| ForgeError::DatabaseError(format!("Failed to read entry: {}", e)))?
323 {
324 let path = entry.path();
325 if path.extension().map(|e| e == "rs").unwrap_or(false) {
326 if let Ok(content) = fs::read_to_string(&path).await {
327 for (line_num, line) in content.lines().enumerate() {
328 if line.contains(pattern) {
329 let name = line.split_whitespace()
331 .find(|w| w.contains(pattern))
332 .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric() && c != '_'))
333 .unwrap_or(pattern)
334 .to_string();
335
336 symbols.push(Symbol {
337 id: SymbolId(symbols.len() as i64 + 1),
338 name: name.clone(),
339 fully_qualified_name: name,
340 kind: SymbolKind::Function,
341 language: Language::Rust,
342 location: Location {
343 file_path: path.clone(),
344 byte_start: 0,
345 byte_end: line.len() as u32,
346 line_number: line_num + 1,
347 },
348 parent_id: None,
349 metadata: serde_json::Value::Null,
350 });
351 break; }
353 }
354 }
355 }
356 }
357
358 Ok(symbols)
359 }
360
361 pub async fn get_symbol(&self, _id: SymbolId) -> Result<Symbol> {
371 Err(ForgeError::SymbolNotFound("Not implemented".to_string()))
372 }
373
374 pub async fn symbol_exists(&self, _id: SymbolId) -> Result<bool> {
380 Ok(false)
381 }
382
383 pub async fn query_references(&self, symbol_id: SymbolId) -> Result<Vec<Reference>> {
394 if self.backend_kind == BackendKind::NativeV3 {
397 let refs = self.references.lock().unwrap();
398 let target_symbol = format!("sym_{}", symbol_id.0);
399
400 let mut result = Vec::new();
401 for stored in refs.iter() {
402 if stored.to_symbol == target_symbol {
403 result.push(Reference {
404 from: SymbolId(0),
405 to: symbol_id,
406 kind: stored.kind.clone(),
407 location: Location {
408 file_path: stored.file_path.clone(),
409 byte_start: 0,
410 byte_end: 0,
411 line_number: stored.line_number,
412 },
413 });
414 }
415 }
416 return Ok(result);
417 }
418
419 Ok(Vec::new())
421 }
422
423 pub async fn get_all_symbols(&self) -> Result<Vec<Symbol>> {
425 Ok(Vec::new())
426 }
427
428 pub async fn symbol_count(&self) -> Result<usize> {
430 Ok(0)
431 }
432
433 pub async fn index_cross_file_references(&self) -> Result<usize> {
442 if self.backend_kind != BackendKind::NativeV3 {
443 return Ok(0); }
445
446 self.legacy_index_cross_file_references().await
449 }
450
451 async fn legacy_index_cross_file_references(&self) -> Result<usize> {
453 use tokio::fs;
454 use regex::Regex;
455
456 let mut symbols: std::collections::HashMap<String, (PathBuf, usize)> = std::collections::HashMap::new();
458 self.collect_symbols_recursive(&self.codebase_path, &mut symbols).await?;
459
460 let mut ref_count = 0;
462 let mut refs = self.references.lock().unwrap();
463 refs.clear(); let reference_pattern = Regex::new(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\(").unwrap();
466
467 for (symbol_name, (_file_path, _)) in &symbols {
468 for (target_file, _) in symbols.values() {
470 if let Ok(content) = fs::read_to_string(target_file).await {
471 for (line_num, line) in content.lines().enumerate() {
472 if line.contains("fn ") || line.contains("struct ") {
474 continue;
475 }
476
477 for cap in reference_pattern.captures_iter(line) {
479 if let Some(matched) = cap.get(1) {
480 if matched.as_str() == symbol_name {
481 refs.push(StoredReference {
482 to_symbol: format!("sym_{}", symbol_name),
483 kind: ReferenceKind::Call,
484 file_path: target_file.clone(),
485 line_number: line_num + 1,
486 });
487 ref_count += 1;
488 }
489 }
490 }
491 }
492 }
493 }
494 }
495
496 Ok(ref_count)
497 }
498
499 async fn collect_symbols_recursive(
500 &self,
501 dir: &Path,
502 symbols: &mut std::collections::HashMap<String, (PathBuf, usize)>,
503 ) -> Result<()> {
504 use tokio::fs;
505
506 let mut entries = fs::read_dir(dir).await
507 .map_err(|e| ForgeError::DatabaseError(format!("Failed to read dir: {}", e)))?;
508
509 while let Some(entry) = entries.next_entry().await
510 .map_err(|e| ForgeError::DatabaseError(format!("Failed to read entry: {}", e)))?
511 {
512 let path = entry.path();
513 if path.is_dir() {
514 Box::pin(self.collect_symbols_recursive(&path, symbols)).await?;
515 } else if path.is_file() && path.extension().map(|e| e == "rs").unwrap_or(false) {
516 if let Ok(content) = fs::read_to_string(&path).await {
517 for (line_num, line) in content.lines().enumerate() {
518 if let Some(fn_pos) = line.find("fn ") {
520 let after_fn = &line[fn_pos + 3..];
521 if let Some(end_pos) = after_fn.find(|c: char| c.is_whitespace() || c == '(') {
522 let name = after_fn[..end_pos].trim().to_string();
523 if !name.is_empty() {
524 symbols.insert(name, (path.clone(), line_num + 1));
525 }
526 }
527 }
528 if let Some(struct_pos) = line.find("struct ") {
530 let after_struct = &line[struct_pos + 7..];
531 if let Some(end_pos) = after_struct.find(|c: char| c.is_whitespace() || c == '{' || c == ';') {
532 let name = after_struct[..end_pos].trim().to_string();
533 if !name.is_empty() {
534 symbols.insert(name, (path.clone(), line_num + 1));
535 }
536 }
537 }
538 }
539 }
540 }
541 }
542
543 Ok(())
544 }
545
546 pub async fn query_references_for_symbol(&self, symbol_name: &str) -> Result<Vec<Reference>> {
549 if self.backend_kind != BackendKind::NativeV3 {
550 return Ok(Vec::new());
551 }
552
553 let refs = self.references.lock().unwrap();
554 let mut result = Vec::new();
555
556 for stored in refs.iter() {
557 if stored.to_symbol == format!("sym_{}", symbol_name) ||
558 stored.to_symbol.contains(symbol_name) {
559 result.push(Reference {
560 from: SymbolId(0),
561 to: SymbolId(0),
562 kind: stored.kind.clone(),
563 location: Location {
564 file_path: stored.file_path.clone(),
565 byte_start: 0,
566 byte_end: 0,
567 line_number: stored.line_number,
568 },
569 });
570 }
571 }
572
573 Ok(result)
574 }
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store value by hand, without touching the filesystem.
    fn detached_store() -> UnifiedGraphStore {
        UnifiedGraphStore {
            codebase_path: PathBuf::from("/test"),
            db_path: PathBuf::from("/test/graph.db"),
            backend_kind: BackendKind::SQLite,
            references: std::sync::Mutex::new(Vec::new()),
        }
    }

    /// Opens a store with the default path in a fresh directory and checks
    /// backend, file name and connectivity.
    async fn assert_open_creates(kind: BackendKind, expected_file: &str) {
        let workspace = tempfile::tempdir().unwrap();
        let store = UnifiedGraphStore::open(workspace.path(), kind)
            .await
            .unwrap();

        assert_eq!(store.backend_kind(), kind);
        assert!(store.db_path().ends_with(expected_file));
        assert!(store.is_connected());
    }

    /// The default backend is SQLite.
    #[test]
    fn test_backend_kind_default() {
        let kind = BackendKind::default();
        assert_eq!(kind, BackendKind::SQLite);
    }

    /// Each backend maps onto its `sqlitegraph` counterpart.
    #[test]
    fn test_backend_kind_to_sqlitegraph() {
        let sqlite = BackendKind::SQLite.to_sqlitegraph_kind();
        let native = BackendKind::NativeV3.to_sqlitegraph_kind();

        assert_eq!(sqlite, SqliteGraphBackendKind::SQLite);
        assert_eq!(native, SqliteGraphBackendKind::Native);
    }

    /// File extensions per backend.
    #[test]
    fn test_backend_kind_file_extension() {
        let expectations = [(BackendKind::SQLite, "db"), (BackendKind::NativeV3, "v3")];
        for (kind, extension) in expectations.iter() {
            assert_eq!(kind.file_extension(), *extension);
        }
    }

    /// Default file names per backend.
    #[test]
    fn test_backend_kind_default_filename() {
        let expectations = [
            (BackendKind::SQLite, "graph.db"),
            (BackendKind::NativeV3, "graph.v3"),
        ];
        for (kind, filename) in expectations.iter() {
            assert_eq!(kind.default_filename(), *filename);
        }
    }

    /// `Display` prints the variant name.
    #[test]
    fn test_backend_kind_display() {
        let expectations = [
            (BackendKind::SQLite, "SQLite"),
            (BackendKind::NativeV3, "NativeV3"),
        ];
        for (kind, label) in expectations.iter() {
            assert_eq!(kind.to_string(), *label);
        }
    }

    /// Opening with the SQLite backend creates `graph.db`.
    #[tokio::test]
    async fn test_open_sqlite_creates_database() {
        assert_open_creates(BackendKind::SQLite, "graph.db").await;
    }

    /// Opening with the native backend creates `graph.v3`.
    #[tokio::test]
    async fn test_open_native_v3_creates_database() {
        assert_open_creates(BackendKind::NativeV3, "graph.v3").await;
    }

    /// An explicit database path is honoured, including missing parent directories.
    #[tokio::test]
    async fn test_open_with_custom_path() {
        let workspace = tempfile::tempdir().unwrap();
        let custom_db = workspace.path().join("custom").join("graph.db");

        let opened =
            UnifiedGraphStore::open_with_path(workspace.path(), &custom_db, BackendKind::SQLite)
                .await;
        let store = opened.unwrap();

        assert_eq!(store.db_path(), custom_db);
        assert!(store.is_connected());
    }

    /// Inserting a symbol yields a positive id.
    #[tokio::test]
    async fn test_insert_symbol_returns_id() {
        let store = UnifiedGraphStore::memory().await.unwrap();

        let location = Location {
            file_path: PathBuf::from("src/lib.rs"),
            byte_start: 0,
            byte_end: 100,
            line_number: 10,
        };
        let symbol = Symbol {
            id: SymbolId(0),
            name: "test_function".to_string(),
            fully_qualified_name: "crate::test_function".to_string(),
            kind: SymbolKind::Function,
            language: Language::Rust,
            location,
            parent_id: None,
            metadata: serde_json::json!({"doc": "Test function"}),
        };

        let SymbolId(raw_id) = store.insert_symbol(&symbol).await.unwrap();
        assert!(raw_id > 0);
    }

    /// Querying a name that occurs nowhere yields no symbols.
    #[tokio::test]
    async fn test_query_symbols_empty() {
        let workspace = tempfile::tempdir().unwrap();
        let store = UnifiedGraphStore::open(workspace.path(), BackendKind::SQLite)
            .await
            .unwrap();

        let found = store.query_symbols("nonexistent_xyz").await.unwrap();
        assert!(found.is_empty());
    }

    /// Inserting a reference succeeds (nothing further is checked).
    #[tokio::test]
    async fn test_insert_reference_placeholder() {
        let store = UnifiedGraphStore::memory().await.unwrap();

        let location = Location {
            file_path: PathBuf::from("src/lib.rs"),
            byte_start: 25,
            byte_end: 35,
            line_number: 2,
        };
        let reference = Reference {
            from: SymbolId(1),
            to: SymbolId(2),
            kind: ReferenceKind::Call,
            location,
        };

        let outcome = store.insert_reference(&reference).await;
        outcome.unwrap();
    }

    /// The placeholder existence check reports `false`.
    #[tokio::test]
    async fn test_symbol_exists_placeholder() {
        let store = UnifiedGraphStore::memory().await.unwrap();

        let exists = store.symbol_exists(SymbolId(1)).await.unwrap();
        assert!(!exists);
    }

    /// The placeholder symbol listing is empty.
    #[tokio::test]
    async fn test_get_all_symbols_empty() {
        let store = UnifiedGraphStore::memory().await.unwrap();

        let listed = store.get_all_symbols().await.unwrap();
        assert!(listed.is_empty());
    }

    /// The placeholder symbol count is zero.
    #[tokio::test]
    async fn test_symbol_count_zero() {
        let store = UnifiedGraphStore::memory().await.unwrap();

        let total = store.symbol_count().await.unwrap();
        assert_eq!(total, 0);
    }

    /// Cloning preserves the public fields.
    #[test]
    fn test_unified_graph_store_clone() {
        let cloned = detached_store().clone();

        assert_eq!(cloned.codebase_path, PathBuf::from("/test"));
        assert_eq!(cloned.db_path, PathBuf::from("/test/graph.db"));
        assert_eq!(cloned.backend_kind, BackendKind::SQLite);
    }

    /// Debug output names the type and shows paths and backend.
    #[test]
    fn test_unified_graph_store_debug() {
        let debug_str = format!("{:?}", detached_store());

        let expected_fragments = [
            "UnifiedGraphStore",
            "codebase_path: \"/test\"",
            "db_path: \"/test/graph.db\"",
            "backend_kind: SQLite",
        ];
        for fragment in expected_fragments.iter() {
            assert!(debug_str.contains(fragment));
        }
    }
}