1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use dashmap::DashMap;
6use dk_core::{
7 CallEdge, Error, RawCallEdge, RepoId, Result, Symbol, SymbolId,
8};
9use sqlx::postgres::PgPool;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13use crate::changeset::ChangesetStore;
14use crate::git::GitRepository;
15use crate::pipeline::PipelineStore;
16use crate::graph::{
17 CallGraphStore, DependencyStore, SearchIndex, SymbolStore, TypeInfoStore,
18};
19use crate::parser::ParserRegistry;
20use crate::workspace::session_manager::WorkspaceManager;
21
/// Aggregate figures describing one indexed repository.
///
/// Produced by `Engine::codebase_summary`.
#[derive(Debug, Clone)]
pub struct CodebaseSummary {
    /// Distinct file-name extensions (for example `"rs"`) found among the
    /// indexed files, in sorted order.
    pub languages: Vec<String>,
    /// Number of symbols stored for the repository.
    pub total_symbols: u64,
    /// Number of distinct files that contributed at least one symbol.
    pub total_files: u64,
}
31
32pub struct Engine {
39 pub db: PgPool,
40 pub search_index: Arc<RwLock<SearchIndex>>,
41 pub parser: Arc<ParserRegistry>,
42 pub storage_path: PathBuf,
43 symbol_store: SymbolStore,
44 call_graph_store: CallGraphStore,
45 #[allow(dead_code)]
46 dep_store: DependencyStore,
47 type_info_store: TypeInfoStore,
48 changeset_store: ChangesetStore,
49 pipeline_store: PipelineStore,
50 workspace_manager: WorkspaceManager,
51 repo_locks: DashMap<RepoId, Arc<RwLock<()>>>,
52}
53
54impl Engine {
55 pub fn new(storage_path: PathBuf, db: PgPool) -> Result<Self> {
61 let search_index = SearchIndex::open(&storage_path.join("search_index"))?;
62 let parser = ParserRegistry::new();
63 let symbol_store = SymbolStore::new(db.clone());
64 let call_graph_store = CallGraphStore::new(db.clone());
65 let dep_store = DependencyStore::new(db.clone());
66 let type_info_store = TypeInfoStore::new(db.clone());
67 let changeset_store = ChangesetStore::new(db.clone());
68 let pipeline_store = PipelineStore::new(db.clone());
69 let workspace_manager = WorkspaceManager::new(db.clone());
70
71 Ok(Self {
72 db,
73 search_index: Arc::new(RwLock::new(search_index)),
74 parser: Arc::new(parser),
75 storage_path,
76 symbol_store,
77 call_graph_store,
78 dep_store,
79 type_info_store,
80 changeset_store,
81 pipeline_store,
82 workspace_manager,
83 repo_locks: DashMap::new(),
84 })
85 }
86
87 pub fn symbol_store(&self) -> &SymbolStore {
89 &self.symbol_store
90 }
91
92 pub fn changeset_store(&self) -> &ChangesetStore {
94 &self.changeset_store
95 }
96
97 pub fn pipeline_store(&self) -> &PipelineStore {
99 &self.pipeline_store
100 }
101
102 pub fn workspace_manager(&self) -> &WorkspaceManager {
104 &self.workspace_manager
105 }
106
107 pub fn call_graph_store(&self) -> &CallGraphStore {
109 &self.call_graph_store
110 }
111
112 pub fn dep_store(&self) -> &DependencyStore {
114 &self.dep_store
115 }
116
117 pub fn parser(&self) -> &ParserRegistry {
119 &self.parser
120 }
121
122 pub fn repo_lock(&self, repo_id: RepoId) -> Arc<RwLock<()>> {
126 self.repo_locks
127 .entry(repo_id)
128 .or_insert_with(|| Arc::new(RwLock::new(())))
129 .clone()
130 }
131
132 pub fn remove_repo_lock(&self, repo_id: RepoId) {
134 self.repo_locks.remove(&repo_id);
135 }
136
137 pub async fn create_repo(&self, name: &str) -> Result<RepoId> {
145 let repo_id = Uuid::new_v4();
146 let repo_path = self.storage_path.join("repos").join(repo_id.to_string());
147
148 GitRepository::init(&repo_path)?;
149
150 sqlx::query(
151 r#"
152 INSERT INTO repositories (id, name, path)
153 VALUES ($1, $2, $3)
154 "#,
155 )
156 .bind(repo_id)
157 .bind(name)
158 .bind(repo_path.to_string_lossy().as_ref())
159 .execute(&self.db)
160 .await?;
161
162 Ok(repo_id)
163 }
164
165 pub async fn get_repo(&self, name: &str) -> Result<(RepoId, GitRepository)> {
169 let row: (Uuid, String) = sqlx::query_as(
170 "SELECT id, path FROM repositories WHERE name = $1",
171 )
172 .bind(name)
173 .fetch_optional(&self.db)
174 .await?
175 .ok_or_else(|| Error::RepoNotFound(name.to_string()))?;
176
177 let (repo_id, repo_path) = row;
178 let git_repo = GitRepository::open(Path::new(&repo_path))?;
179 Ok((repo_id, git_repo))
180 }
181
182 pub async fn get_repo_by_db_id(&self, repo_id: RepoId) -> Result<(RepoId, GitRepository)> {
186 let row: (String,) = sqlx::query_as(
187 "SELECT path FROM repositories WHERE id = $1",
188 )
189 .bind(repo_id)
190 .fetch_optional(&self.db)
191 .await?
192 .ok_or_else(|| Error::RepoNotFound(repo_id.to_string()))?;
193
194 let git_repo = GitRepository::open(Path::new(&row.0))?;
195 Ok((repo_id, git_repo))
196 }
197
198 pub async fn index_repo(
206 &self,
207 repo_id: RepoId,
208 git_repo: &GitRepository,
209 ) -> Result<()> {
210 let root = git_repo.path().to_path_buf();
211 let files = collect_files(&root, &self.parser);
212
213 let mut all_symbols: Vec<Symbol> = Vec::new();
216 let mut all_raw_edges: Vec<RawCallEdge> = Vec::new();
217
218 let mut search_index = self.search_index.write().await;
220
221 for file_path in &files {
222 let relative = file_path
223 .strip_prefix(&root)
224 .unwrap_or(file_path);
225
226 let source = std::fs::read(file_path).map_err(|e| {
227 Error::Io(e)
228 })?;
229
230 let analysis = self.parser.parse_file(relative, &source)?;
231
232 for sym in &analysis.symbols {
234 self.symbol_store.upsert_symbol(repo_id, sym).await?;
235 search_index.index_symbol(repo_id, sym)?;
236 }
237
238 for ti in &analysis.types {
240 self.type_info_store.upsert_type_info(ti).await?;
241 }
242
243 all_symbols.extend(analysis.symbols);
244 all_raw_edges.extend(analysis.calls);
245 }
246
247 search_index.commit()?;
249 drop(search_index);
250
251 let edges = resolve_call_edges(&all_raw_edges, &all_symbols, repo_id);
253 for edge in &edges {
254 self.call_graph_store.insert_edge(edge).await?;
255 }
256
257 Ok(())
258 }
259
260 pub async fn update_files(
265 &self,
266 repo_id: RepoId,
267 git_repo: &GitRepository,
268 changed_files: &[PathBuf],
269 ) -> Result<()> {
270 self.update_files_by_root(repo_id, git_repo.path(), changed_files)
271 .await
272 }
273
274 pub async fn update_files_by_root(
280 &self,
281 repo_id: RepoId,
282 root: &Path,
283 changed_files: &[PathBuf],
284 ) -> Result<()> {
285 let root = root.to_path_buf();
286
287 let mut all_symbols: Vec<Symbol> = Vec::new();
288 let mut all_raw_edges: Vec<RawCallEdge> = Vec::new();
289
290 let mut search_index = self.search_index.write().await;
292
293 for file_path in changed_files {
294 let relative = file_path
295 .strip_prefix(&root)
296 .unwrap_or(file_path);
297 let rel_str = relative.to_string_lossy().to_string();
298
299 let old_symbols = self
302 .symbol_store
303 .find_by_file(repo_id, &rel_str)
304 .await?;
305 for old_sym in &old_symbols {
306 search_index.remove_symbol(old_sym.id)?;
307 }
308
309 self.call_graph_store
311 .delete_edges_for_file(repo_id, &rel_str)
312 .await?;
313 self.symbol_store
314 .delete_by_file(repo_id, &rel_str)
315 .await?;
316
317 let full_path = root.join(relative);
319 if !full_path.exists() {
320 continue;
322 }
323
324 if !self.parser.supports_file(relative) {
325 continue;
326 }
327
328 let source = std::fs::read(&full_path)?;
329 let analysis = self.parser.parse_file(relative, &source)?;
330
331 for sym in &analysis.symbols {
332 self.symbol_store.upsert_symbol(repo_id, sym).await?;
333 search_index.index_symbol(repo_id, sym)?;
334 }
335
336 for ti in &analysis.types {
337 self.type_info_store.upsert_type_info(ti).await?;
338 }
339
340 all_symbols.extend(analysis.symbols);
341 all_raw_edges.extend(analysis.calls);
342 }
343
344 search_index.commit()?;
345 drop(search_index);
346
347 let edges = resolve_call_edges(&all_raw_edges, &all_symbols, repo_id);
348 for edge in &edges {
349 self.call_graph_store.insert_edge(edge).await?;
350 }
351
352 Ok(())
353 }
354
355 pub async fn query_symbols(
362 &self,
363 repo_id: RepoId,
364 query: &str,
365 max_results: usize,
366 ) -> Result<Vec<Symbol>> {
367 let search_index = self.search_index.read().await;
368 let ids = search_index.search(repo_id, query, max_results)?;
369 drop(search_index);
370
371 self.symbol_store.get_by_ids(&ids).await
372 }
373
374 pub async fn get_call_graph(
379 &self,
380 _repo_id: RepoId,
381 symbol_id: SymbolId,
382 ) -> Result<(Vec<Symbol>, Vec<Symbol>)> {
383 let caller_edges = self.call_graph_store.find_callers(symbol_id).await?;
384 let callee_edges = self.call_graph_store.find_callees(symbol_id).await?;
385
386 let mut callers = Vec::with_capacity(caller_edges.len());
387 for edge in &caller_edges {
388 if let Some(sym) = self.symbol_store.get_by_id(edge.caller).await? {
389 callers.push(sym);
390 }
391 }
392
393 let mut callees = Vec::with_capacity(callee_edges.len());
394 for edge in &callee_edges {
395 if let Some(sym) = self.symbol_store.get_by_id(edge.callee).await? {
396 callees.push(sym);
397 }
398 }
399
400 Ok((callers, callees))
401 }
402
403 pub async fn codebase_summary(&self, repo_id: RepoId) -> Result<CodebaseSummary> {
409 let total_symbols = self.symbol_store.count(repo_id).await? as u64;
410
411 let row: (i64, Vec<String>) = sqlx::query_as(
413 r#"
414 SELECT
415 COUNT(DISTINCT file_path),
416 COALESCE(
417 array_agg(DISTINCT substring(file_path FROM '\.([^.]+)$'))
418 FILTER (WHERE substring(file_path FROM '\.([^.]+)$') IS NOT NULL),
419 ARRAY[]::text[]
420 )
421 FROM symbols
422 WHERE repo_id = $1
423 "#,
424 )
425 .bind(repo_id)
426 .fetch_one(&self.db)
427 .await?;
428
429 let total_files = row.0 as u64;
430 let mut languages = row.1;
431 languages.sort();
432
433 Ok(CodebaseSummary {
434 languages,
435 total_symbols,
436 total_files,
437 })
438 }
439}
440
441fn collect_files(root: &Path, parser: &ParserRegistry) -> Vec<PathBuf> {
446 let mut files = Vec::new();
447 collect_files_recursive(root, root, parser, &mut files);
448 files
449}
450
451fn collect_files_recursive(
452 root: &Path,
453 dir: &Path,
454 parser: &ParserRegistry,
455 out: &mut Vec<PathBuf>,
456) {
457 let entries = match std::fs::read_dir(dir) {
458 Ok(entries) => entries,
459 Err(_) => return,
460 };
461
462 for entry in entries.flatten() {
463 let path = entry.path();
464
465 if path.is_dir() {
466 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
468 if name == ".git" || name.starts_with('.') {
469 continue;
470 }
471 }
472 collect_files_recursive(root, &path, parser, out);
473 } else if path.is_file() {
474 let relative = path.strip_prefix(root).unwrap_or(&path);
475 if parser.supports_file(relative) {
476 out.push(path);
477 }
478 }
479 }
480}
481
482fn resolve_call_edges(
485 raw_edges: &[RawCallEdge],
486 symbols: &[Symbol],
487 repo_id: RepoId,
488) -> Vec<CallEdge> {
489 let mut name_to_id: HashMap<String, SymbolId> = HashMap::new();
492 for sym in symbols {
493 name_to_id.insert(sym.name.clone(), sym.id);
494 name_to_id.insert(sym.qualified_name.clone(), sym.id);
495 }
496
497 raw_edges
498 .iter()
499 .filter_map(|raw| {
500 let caller = name_to_id.get(&raw.caller_name)?;
501 let callee = name_to_id.get(&raw.callee_name)?;
502 Some(CallEdge {
503 id: Uuid::new_v4(),
504 repo_id,
505 caller: *caller,
506 callee: *callee,
507 kind: raw.kind.clone(),
508 })
509 })
510 .collect()
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a public function symbol with every optional field left unset.
    fn function_symbol(
        id: Uuid,
        name: &str,
        qualified_name: &str,
        file_path: &str,
        span: dk_core::Span,
    ) -> Symbol {
        Symbol {
            id,
            name: name.into(),
            qualified_name: qualified_name.into(),
            kind: dk_core::SymbolKind::Function,
            visibility: dk_core::Visibility::Public,
            file_path: file_path.into(),
            span,
            signature: None,
            doc_comment: None,
            parent: None,
            last_modified_by: None,
            last_modified_intent: None,
        }
    }

    /// Builds a direct-call raw edge with a fixed call site of bytes 50..60.
    fn direct_call(caller_name: &str, callee_name: &str) -> RawCallEdge {
        RawCallEdge {
            caller_name: caller_name.into(),
            callee_name: callee_name.into(),
            call_site: dk_core::Span {
                start_byte: 50,
                end_byte: 60,
            },
            kind: dk_core::CallKind::DirectCall,
        }
    }

    /// Bare names on both ends resolve to the matching symbol ids.
    #[test]
    fn test_resolve_call_edges_basic() {
        let sym_a_id = Uuid::new_v4();
        let sym_b_id = Uuid::new_v4();
        let repo_id = Uuid::new_v4();

        let symbols = vec![
            function_symbol(
                sym_a_id,
                "foo",
                "crate::foo",
                "src/lib.rs",
                dk_core::Span {
                    start_byte: 0,
                    end_byte: 100,
                },
            ),
            function_symbol(
                sym_b_id,
                "bar",
                "crate::bar",
                "src/lib.rs",
                dk_core::Span {
                    start_byte: 100,
                    end_byte: 200,
                },
            ),
        ];
        let raw_edges = vec![direct_call("foo", "bar")];

        let edges = resolve_call_edges(&raw_edges, &symbols, repo_id);

        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].caller, sym_a_id);
        assert_eq!(edges[0].callee, sym_b_id);
        assert_eq!(edges[0].repo_id, repo_id);
    }

    /// An edge whose callee is not among the symbols is dropped.
    #[test]
    fn test_resolve_call_edges_unresolved_skipped() {
        let sym_a_id = Uuid::new_v4();
        let repo_id = Uuid::new_v4();

        let symbols = vec![function_symbol(
            sym_a_id,
            "foo",
            "crate::foo",
            "src/lib.rs",
            dk_core::Span {
                start_byte: 0,
                end_byte: 100,
            },
        )];
        let raw_edges = vec![direct_call("foo", "unknown")];

        let edges = resolve_call_edges(&raw_edges, &symbols, repo_id);

        assert!(edges.is_empty());
    }

    /// Fully qualified names on both ends resolve as well.
    #[test]
    fn test_resolve_call_edges_qualified_name() {
        let sym_a_id = Uuid::new_v4();
        let sym_b_id = Uuid::new_v4();
        let repo_id = Uuid::new_v4();

        let symbols = vec![
            function_symbol(
                sym_a_id,
                "foo",
                "crate::mod_a::foo",
                "src/mod_a.rs",
                dk_core::Span {
                    start_byte: 0,
                    end_byte: 100,
                },
            ),
            function_symbol(
                sym_b_id,
                "bar",
                "crate::mod_b::bar",
                "src/mod_b.rs",
                dk_core::Span {
                    start_byte: 0,
                    end_byte: 100,
                },
            ),
        ];
        let raw_edges = vec![direct_call("crate::mod_a::foo", "crate::mod_b::bar")];

        let edges = resolve_call_edges(&raw_edges, &symbols, repo_id);

        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].caller, sym_a_id);
        assert_eq!(edges[0].callee, sym_b_id);
    }

    /// Only `main.rs` is collected: `.git` is skipped and `notes.txt` is not
    /// a supported file.
    #[test]
    fn test_collect_files_skips_git_dir() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path();

        std::fs::create_dir_all(root.join(".git")).unwrap();
        std::fs::write(root.join(".git/config"), b"git config").unwrap();
        std::fs::write(root.join("main.rs"), b"fn main() {}").unwrap();
        std::fs::write(root.join("notes.txt"), b"hello").unwrap();

        let parser = ParserRegistry::new();
        let files = collect_files(root, &parser);

        assert_eq!(files.len(), 1);
        assert!(files[0].ends_with("main.rs"));
    }

    /// The summary struct stores the values it is constructed with.
    #[test]
    fn test_codebase_summary_struct() {
        let summary = CodebaseSummary {
            languages: vec!["rs".into(), "py".into()],
            total_symbols: 42,
            total_files: 5,
        };

        assert_eq!(summary.languages.len(), 2);
        assert_eq!(summary.total_symbols, 42);
        assert_eq!(summary.total_files, 5);
    }
}