1use std::path::{Path, PathBuf};
33use std::sync::Arc;
34
35use crate::error::{ForgeError, Result};
36use crate::types::{Symbol, SymbolId, Reference, SymbolKind, ReferenceKind, Language, Location};
37
38pub use sqlitegraph::backend::{NodeSpec, EdgeSpec};
40pub use sqlitegraph::graph::{GraphEntity, SqliteGraph};
41pub use sqlitegraph::config::{BackendKind as SqliteGraphBackendKind, GraphConfig, open_graph};
42
/// Storage backend used by `UnifiedGraphStore`.
///
/// Each variant has its own database file name: `graph.db` for SQLite and
/// `graph.v3` for the native V3 format.
///
/// `Hash` is derived alongside `Eq` so the kind can be used as a key in hash
/// maps and sets.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BackendKind {
    /// SQLite-backed graph database (`graph.db`).
    SQLite,
    /// Native V3 graph file (`graph.v3`).
    NativeV3,
}
53
54impl Default for BackendKind {
55 fn default() -> Self {
56 Self::SQLite
57 }
58}
59
60impl std::fmt::Display for BackendKind {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 match self {
63 Self::SQLite => write!(f, "SQLite"),
64 Self::NativeV3 => write!(f, "NativeV3"),
65 }
66 }
67}
68
69impl BackendKind {
70 fn to_sqlitegraph_kind(&self) -> SqliteGraphBackendKind {
72 match self {
73 Self::SQLite => SqliteGraphBackendKind::SQLite,
74 Self::NativeV3 => SqliteGraphBackendKind::Native,
75 }
76 }
77
78 pub fn file_extension(&self) -> &str {
80 match self {
81 Self::SQLite => "db",
82 Self::NativeV3 => "v3",
83 }
84 }
85
86 pub fn default_filename(&self) -> &str {
88 match self {
89 Self::SQLite => "graph.db",
90 Self::NativeV3 => "graph.v3",
91 }
92 }
93}
94
95pub struct UnifiedGraphStore {
101 pub codebase_path: PathBuf,
103 pub db_path: PathBuf,
105 pub backend_kind: BackendKind,
107 references: std::sync::Mutex<Vec<StoredReference>>,
109}
110
111#[derive(Clone, Debug)]
113struct StoredReference {
114 from_symbol: String,
115 to_symbol: String,
116 kind: ReferenceKind,
117 file_path: PathBuf,
118 line_number: usize,
119}
120
121impl Clone for UnifiedGraphStore {
122 fn clone(&self) -> Self {
123 Self {
124 codebase_path: self.codebase_path.clone(),
125 db_path: self.db_path.clone(),
126 backend_kind: self.backend_kind,
127 references: std::sync::Mutex::new(self.references.lock().unwrap().clone()),
128 }
129 }
130}
131
132impl std::fmt::Debug for UnifiedGraphStore {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 f.debug_struct("UnifiedGraphStore")
135 .field("codebase_path", &self.codebase_path)
136 .field("db_path", &self.db_path)
137 .field("backend_kind", &self.backend_kind)
138 .field("connected", &self.is_connected())
139 .finish()
140 }
141}
142
143impl UnifiedGraphStore {
144 pub async fn open(codebase_path: impl AsRef<Path>, backend_kind: BackendKind) -> Result<Self> {
155 let codebase = codebase_path.as_ref();
156 let db_path = codebase.join(".forge").join(backend_kind.default_filename());
157
158 if let Some(parent) = db_path.parent() {
160 tokio::fs::create_dir_all(parent).await
161 .map_err(|e| ForgeError::DatabaseError(
162 format!("Failed to create database directory: {}", e)
163 ))?;
164 }
165
166 let config = match backend_kind {
168 BackendKind::SQLite => GraphConfig::sqlite(),
169 BackendKind::NativeV3 => GraphConfig::native(),
170 };
171
172 let _graph = open_graph(&db_path, &config)
173 .map_err(|e| ForgeError::DatabaseError(
174 format!("Failed to open database: {}", e)
175 ))?;
176
177 Ok(UnifiedGraphStore {
178 codebase_path: codebase.to_path_buf(),
179 db_path,
180 backend_kind,
181 references: std::sync::Mutex::new(Vec::new()),
182 })
183 }
184
185 pub async fn open_with_path(
193 codebase_path: impl AsRef<Path>,
194 db_path: impl AsRef<Path>,
195 backend_kind: BackendKind,
196 ) -> Result<Self> {
197 let codebase = codebase_path.as_ref();
198 let db = db_path.as_ref();
199
200 if let Some(parent) = db.parent() {
202 tokio::fs::create_dir_all(parent).await
203 .map_err(|e| ForgeError::DatabaseError(
204 format!("Failed to create database directory: {}", e)
205 ))?;
206 }
207
208 let config = match backend_kind {
210 BackendKind::SQLite => GraphConfig::sqlite(),
211 BackendKind::NativeV3 => GraphConfig::native(),
212 };
213
214 let _graph = open_graph(db, &config)
215 .map_err(|e| ForgeError::DatabaseError(
216 format!("Failed to open database: {}", e)
217 ))?;
218
219 Ok(UnifiedGraphStore {
220 codebase_path: codebase.to_path_buf(),
221 db_path: db.to_path_buf(),
222 backend_kind,
223 references: std::sync::Mutex::new(Vec::new()),
224 })
225 }
226
227 #[cfg(test)]
229 pub async fn memory() -> Result<Self> {
230 use tempfile::tempdir;
231
232 let temp = tempdir().map_err(|e| ForgeError::DatabaseError(
233 format!("Failed to create temp directory: {}", e)
234 ))?;
235
236 Self::open(temp.path(), BackendKind::SQLite).await
237 }
238
239 #[inline]
241 pub fn backend_kind(&self) -> BackendKind {
242 self.backend_kind
243 }
244
245 #[inline]
247 pub fn db_path(&self) -> &Path {
248 &self.db_path
249 }
250
251 pub fn is_connected(&self) -> bool {
253 self.db_path.exists()
254 }
255
256 pub async fn insert_symbol(&self, _symbol: &Symbol) -> Result<SymbolId> {
266 Ok(SymbolId(1))
273 }
274
275 pub async fn insert_reference(&self, reference: &Reference) -> Result<()> {
281 if self.backend_kind == BackendKind::NativeV3 {
284 let mut refs = self.references.lock().unwrap();
285
286 let from_symbol = format!("sym_{}", reference.from.0);
289 let to_symbol = format!("sym_{}", reference.to.0);
290
291 refs.push(StoredReference {
292 from_symbol,
293 to_symbol,
294 kind: reference.kind.clone(),
295 file_path: reference.location.file_path.clone(),
296 line_number: reference.location.line_number,
297 });
298 }
299 Ok(())
300 }
301
302 pub async fn query_symbols(&self, name: &str) -> Result<Vec<Symbol>> {
312 self.search_codebase_files(name).await
314 }
315
316 async fn search_codebase_files(&self, pattern: &str) -> Result<Vec<Symbol>> {
318 use tokio::fs;
319
320 let mut symbols = Vec::new();
321 let mut entries = fs::read_dir(&self.codebase_path).await
322 .map_err(|e| ForgeError::DatabaseError(format!("Failed to read codebase: {}", e)))?;
323
324 while let Some(entry) = entries.next_entry().await
325 .map_err(|e| ForgeError::DatabaseError(format!("Failed to read entry: {}", e)))?
326 {
327 let path = entry.path();
328 if path.extension().map(|e| e == "rs").unwrap_or(false) {
329 if let Ok(content) = fs::read_to_string(&path).await {
330 for (line_num, line) in content.lines().enumerate() {
331 if line.contains(pattern) {
332 let name = line.split_whitespace()
334 .find(|w| w.contains(pattern))
335 .map(|s| s.trim_matches(|c: char| !c.is_alphanumeric() && c != '_'))
336 .unwrap_or(pattern)
337 .to_string();
338
339 symbols.push(Symbol {
340 id: SymbolId(symbols.len() as i64 + 1),
341 name: name.clone(),
342 fully_qualified_name: name,
343 kind: SymbolKind::Function,
344 language: Language::Rust,
345 location: Location {
346 file_path: path.clone(),
347 byte_start: 0,
348 byte_end: line.len() as u32,
349 line_number: line_num + 1,
350 },
351 parent_id: None,
352 metadata: serde_json::Value::Null,
353 });
354 break; }
356 }
357 }
358 }
359 }
360
361 Ok(symbols)
362 }
363
364 pub async fn get_symbol(&self, _id: SymbolId) -> Result<Symbol> {
374 Err(ForgeError::SymbolNotFound("Not implemented".to_string()))
375 }
376
377 pub async fn symbol_exists(&self, _id: SymbolId) -> Result<bool> {
383 Ok(false)
384 }
385
386 pub async fn query_references(&self, symbol_id: SymbolId) -> Result<Vec<Reference>> {
397 if self.backend_kind == BackendKind::NativeV3 {
400 let refs = self.references.lock().unwrap();
401 let target_symbol = format!("sym_{}", symbol_id.0);
402
403 let mut result = Vec::new();
404 for stored in refs.iter() {
405 if stored.to_symbol == target_symbol {
406 result.push(Reference {
407 from: SymbolId(0),
408 to: symbol_id,
409 kind: stored.kind.clone(),
410 location: Location {
411 file_path: stored.file_path.clone(),
412 byte_start: 0,
413 byte_end: 0,
414 line_number: stored.line_number,
415 },
416 });
417 }
418 }
419 return Ok(result);
420 }
421
422 Ok(Vec::new())
424 }
425
426 pub async fn get_all_symbols(&self) -> Result<Vec<Symbol>> {
428 Ok(Vec::new())
429 }
430
431 pub async fn symbol_count(&self) -> Result<usize> {
433 Ok(0)
434 }
435
436 pub async fn index_cross_file_references(&self) -> Result<usize> {
445 if self.backend_kind != BackendKind::NativeV3 {
446 return Ok(0); }
448
449 self.legacy_index_cross_file_references().await
452 }
453
454 async fn legacy_index_cross_file_references(&self) -> Result<usize> {
456 use tokio::fs;
457 use regex::Regex;
458
459 let mut symbols: std::collections::HashMap<String, (PathBuf, usize)> = std::collections::HashMap::new();
461 self.collect_symbols_recursive(&self.codebase_path, &mut symbols).await?;
462
463 let mut ref_count = 0;
465 let mut refs = self.references.lock().unwrap();
466 refs.clear(); let reference_pattern = Regex::new(r"\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\(").unwrap();
469
470 for (symbol_name, (file_path, _)) in &symbols {
471 for (target_file, _) in symbols.values() {
473 if let Ok(content) = fs::read_to_string(target_file).await {
474 for (line_num, line) in content.lines().enumerate() {
475 if line.contains("fn ") || line.contains("struct ") {
477 continue;
478 }
479
480 for cap in reference_pattern.captures_iter(line) {
482 if let Some(matched) = cap.get(1) {
483 if matched.as_str() == symbol_name {
484 refs.push(StoredReference {
485 from_symbol: format!("file_{}", target_file.display()),
486 to_symbol: format!("sym_{}", symbol_name),
487 kind: ReferenceKind::Call,
488 file_path: target_file.clone(),
489 line_number: line_num + 1,
490 });
491 ref_count += 1;
492 }
493 }
494 }
495 }
496 }
497 }
498 }
499
500 Ok(ref_count)
501 }
502
503 async fn collect_symbols_recursive(
504 &self,
505 dir: &Path,
506 symbols: &mut std::collections::HashMap<String, (PathBuf, usize)>,
507 ) -> Result<()> {
508 use tokio::fs;
509
510 let mut entries = fs::read_dir(dir).await
511 .map_err(|e| ForgeError::DatabaseError(format!("Failed to read dir: {}", e)))?;
512
513 while let Some(entry) = entries.next_entry().await
514 .map_err(|e| ForgeError::DatabaseError(format!("Failed to read entry: {}", e)))?
515 {
516 let path = entry.path();
517 if path.is_dir() {
518 Box::pin(self.collect_symbols_recursive(&path, symbols)).await?;
519 } else if path.is_file() && path.extension().map(|e| e == "rs").unwrap_or(false) {
520 if let Ok(content) = fs::read_to_string(&path).await {
521 for (line_num, line) in content.lines().enumerate() {
522 if let Some(fn_pos) = line.find("fn ") {
524 let after_fn = &line[fn_pos + 3..];
525 if let Some(end_pos) = after_fn.find(|c: char| c.is_whitespace() || c == '(') {
526 let name = after_fn[..end_pos].trim().to_string();
527 if !name.is_empty() {
528 symbols.insert(name, (path.clone(), line_num + 1));
529 }
530 }
531 }
532 if let Some(struct_pos) = line.find("struct ") {
534 let after_struct = &line[struct_pos + 7..];
535 if let Some(end_pos) = after_struct.find(|c: char| c.is_whitespace() || c == '{' || c == ';') {
536 let name = after_struct[..end_pos].trim().to_string();
537 if !name.is_empty() {
538 symbols.insert(name, (path.clone(), line_num + 1));
539 }
540 }
541 }
542 }
543 }
544 }
545 }
546
547 Ok(())
548 }
549
550 pub async fn query_references_for_symbol(&self, symbol_name: &str) -> Result<Vec<Reference>> {
553 if self.backend_kind != BackendKind::NativeV3 {
554 return Ok(Vec::new());
555 }
556
557 let refs = self.references.lock().unwrap();
558 let mut result = Vec::new();
559
560 for stored in refs.iter() {
561 if stored.to_symbol == format!("sym_{}", symbol_name) ||
562 stored.to_symbol.contains(symbol_name) {
563 result.push(Reference {
564 from: SymbolId(0),
565 to: SymbolId(0),
566 kind: stored.kind.clone(),
567 location: Location {
568 file_path: stored.file_path.clone(),
569 byte_start: 0,
570 byte_end: 0,
571 line_number: stored.line_number,
572 },
573 });
574 }
575 }
576
577 Ok(result)
578 }
579}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// A store assembled field-by-field, without touching the filesystem.
    fn detached_store() -> UnifiedGraphStore {
        UnifiedGraphStore {
            codebase_path: PathBuf::from("/test"),
            db_path: PathBuf::from("/test/graph.db"),
            backend_kind: BackendKind::SQLite,
            references: std::sync::Mutex::new(Vec::new()),
        }
    }

    /// Opens a store of `kind` rooted at a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned so the directory outlives the assertions.
    async fn open_in_tempdir(kind: BackendKind) -> (tempfile::TempDir, UnifiedGraphStore) {
        let dir = tempfile::tempdir().unwrap();
        let store = UnifiedGraphStore::open(dir.path(), kind).await.unwrap();
        (dir, store)
    }

    #[test]
    fn test_backend_kind_default() {
        let kind: BackendKind = Default::default();
        assert_eq!(kind, BackendKind::SQLite);
    }

    #[test]
    fn test_backend_kind_to_sqlitegraph() {
        let cases = [
            (BackendKind::SQLite, SqliteGraphBackendKind::SQLite),
            (BackendKind::NativeV3, SqliteGraphBackendKind::Native),
        ];
        for (ours, theirs) in cases {
            assert_eq!(ours.to_sqlitegraph_kind(), theirs);
        }
    }

    #[test]
    fn test_backend_kind_file_extension() {
        for (kind, ext) in [(BackendKind::SQLite, "db"), (BackendKind::NativeV3, "v3")] {
            assert_eq!(kind.file_extension(), ext);
        }
    }

    #[test]
    fn test_backend_kind_default_filename() {
        for (kind, file) in [
            (BackendKind::SQLite, "graph.db"),
            (BackendKind::NativeV3, "graph.v3"),
        ] {
            assert_eq!(kind.default_filename(), file);
        }
    }

    #[test]
    fn test_backend_kind_display() {
        for (kind, shown) in [(BackendKind::SQLite, "SQLite"), (BackendKind::NativeV3, "NativeV3")] {
            assert_eq!(kind.to_string(), shown);
        }
    }

    #[tokio::test]
    async fn test_open_sqlite_creates_database() {
        let (_dir, store) = open_in_tempdir(BackendKind::SQLite).await;

        assert_eq!(store.backend_kind(), BackendKind::SQLite);
        assert!(store.db_path().ends_with("graph.db"));
        assert!(store.is_connected());
    }

    #[tokio::test]
    async fn test_open_native_v3_creates_database() {
        let (_dir, store) = open_in_tempdir(BackendKind::NativeV3).await;

        assert_eq!(store.backend_kind(), BackendKind::NativeV3);
        assert!(store.db_path().ends_with("graph.v3"));
        assert!(store.is_connected());
    }

    #[tokio::test]
    async fn test_open_with_custom_path() {
        let dir = tempfile::tempdir().unwrap();
        let custom_db = dir.path().join("custom").join("graph.db");

        let store = UnifiedGraphStore::open_with_path(dir.path(), &custom_db, BackendKind::SQLite)
            .await
            .unwrap();

        assert_eq!(store.db_path(), custom_db);
        assert!(store.is_connected());
    }

    #[tokio::test]
    async fn test_insert_symbol_returns_id() {
        let store = UnifiedGraphStore::memory().await.unwrap();

        let location = Location {
            file_path: PathBuf::from("src/lib.rs"),
            byte_start: 0,
            byte_end: 100,
            line_number: 10,
        };
        let symbol = Symbol {
            id: SymbolId(0),
            name: "test_function".to_string(),
            fully_qualified_name: "crate::test_function".to_string(),
            kind: SymbolKind::Function,
            language: Language::Rust,
            location,
            parent_id: None,
            metadata: serde_json::json!({"doc": "Test function"}),
        };

        let SymbolId(raw) = store.insert_symbol(&symbol).await.unwrap();
        assert!(raw > 0);
    }

    #[tokio::test]
    async fn test_query_symbols_empty() {
        let (_dir, store) = open_in_tempdir(BackendKind::SQLite).await;

        let hits = store.query_symbols("nonexistent_xyz").await.unwrap();
        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn test_insert_reference_placeholder() {
        let store = UnifiedGraphStore::memory().await.unwrap();

        let location = Location {
            file_path: PathBuf::from("src/lib.rs"),
            byte_start: 25,
            byte_end: 35,
            line_number: 2,
        };
        let reference = Reference {
            from: SymbolId(1),
            to: SymbolId(2),
            kind: ReferenceKind::Call,
            location,
        };

        store.insert_reference(&reference).await.unwrap();
    }

    #[tokio::test]
    async fn test_symbol_exists_placeholder() {
        let store = UnifiedGraphStore::memory().await.unwrap();

        let exists = store.symbol_exists(SymbolId(1)).await.unwrap();
        assert!(!exists);
    }

    #[tokio::test]
    async fn test_get_all_symbols_empty() {
        let store = UnifiedGraphStore::memory().await.unwrap();

        assert!(store.get_all_symbols().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_symbol_count_zero() {
        let store = UnifiedGraphStore::memory().await.unwrap();

        assert_eq!(store.symbol_count().await.unwrap(), 0);
    }

    #[test]
    fn test_unified_graph_store_clone() {
        let original = detached_store();
        let copy = original.clone();

        assert_eq!(copy.codebase_path, PathBuf::from("/test"));
        assert_eq!(copy.db_path, PathBuf::from("/test/graph.db"));
        assert_eq!(copy.backend_kind, BackendKind::SQLite);
    }

    #[test]
    fn test_unified_graph_store_debug() {
        let rendered = format!("{:?}", detached_store());

        for needle in [
            "UnifiedGraphStore",
            "codebase_path: \"/test\"",
            "db_path: \"/test/graph.db\"",
            "backend_kind: SQLite",
        ] {
            assert!(rendered.contains(needle));
        }
    }
}