1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use dashmap::DashMap;
6use dk_core::{
7 CallEdge, Error, RawCallEdge, RepoId, Result, Symbol, SymbolId,
8};
9use sqlx::postgres::PgPool;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13use crate::changeset::ChangesetStore;
14use crate::git::GitRepository;
15use crate::pipeline::PipelineStore;
16use crate::graph::{
17 CallGraphStore, DependencyStore, SearchIndex, SymbolStore, TypeInfoStore,
18};
19use crate::parser::ParserRegistry;
20use crate::workspace::cache::{NoOpCache, WorkspaceCache};
21use crate::workspace::session_manager::WorkspaceManager;
22
/// Aggregate statistics about one indexed repository.
///
/// Produced by `Engine::codebase_summary`. Deriving `PartialEq`/`Eq` lets callers and
/// tests compare summaries directly; `Default` yields the summary of an empty repository.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CodebaseSummary {
    /// Distinct file extensions (e.g. `"rs"`, `"py"`) of files that contributed
    /// symbols, sorted ascending.
    pub languages: Vec<String>,
    /// Number of symbols stored for the repository.
    pub total_symbols: u64,
    /// Number of distinct files that contributed at least one symbol.
    pub total_files: u64,
}
32
33pub struct Engine {
40 pub db: PgPool,
41 pub search_index: Arc<RwLock<SearchIndex>>,
42 pub parser: Arc<ParserRegistry>,
43 pub storage_path: PathBuf,
44 symbol_store: SymbolStore,
45 call_graph_store: CallGraphStore,
46 #[allow(dead_code)]
47 dep_store: DependencyStore,
48 type_info_store: TypeInfoStore,
49 changeset_store: ChangesetStore,
50 pipeline_store: PipelineStore,
51 workspace_manager: WorkspaceManager,
52 repo_locks: DashMap<RepoId, Arc<RwLock<()>>>,
53}
54
55impl Engine {
56 pub fn new(storage_path: PathBuf, db: PgPool) -> Result<Self> {
64 Self::with_cache(storage_path, db, Arc::new(NoOpCache))
65 }
66
67 pub fn with_cache(
73 storage_path: PathBuf,
74 db: PgPool,
75 cache: Arc<dyn WorkspaceCache>,
76 ) -> Result<Self> {
77 let search_index = SearchIndex::open(&storage_path.join("search_index"))?;
78 let parser = ParserRegistry::new();
79 let symbol_store = SymbolStore::new(db.clone());
80 let call_graph_store = CallGraphStore::new(db.clone());
81 let dep_store = DependencyStore::new(db.clone());
82 let type_info_store = TypeInfoStore::new(db.clone());
83 let changeset_store = ChangesetStore::new(db.clone());
84 let pipeline_store = PipelineStore::new(db.clone());
85 let workspace_manager = WorkspaceManager::with_cache(db.clone(), cache);
86
87 Ok(Self {
88 db,
89 search_index: Arc::new(RwLock::new(search_index)),
90 parser: Arc::new(parser),
91 storage_path,
92 symbol_store,
93 call_graph_store,
94 dep_store,
95 type_info_store,
96 changeset_store,
97 pipeline_store,
98 workspace_manager,
99 repo_locks: DashMap::new(),
100 })
101 }
102
103 pub fn symbol_store(&self) -> &SymbolStore {
105 &self.symbol_store
106 }
107
108 pub fn changeset_store(&self) -> &ChangesetStore {
110 &self.changeset_store
111 }
112
113 pub fn pipeline_store(&self) -> &PipelineStore {
115 &self.pipeline_store
116 }
117
118 pub fn workspace_manager(&self) -> &WorkspaceManager {
120 &self.workspace_manager
121 }
122
123 pub fn call_graph_store(&self) -> &CallGraphStore {
125 &self.call_graph_store
126 }
127
128 pub fn dep_store(&self) -> &DependencyStore {
130 &self.dep_store
131 }
132
133 pub fn parser(&self) -> &ParserRegistry {
135 &self.parser
136 }
137
138 pub fn repo_lock(&self, repo_id: RepoId) -> Arc<RwLock<()>> {
142 self.repo_locks
143 .entry(repo_id)
144 .or_insert_with(|| Arc::new(RwLock::new(())))
145 .clone()
146 }
147
148 pub fn remove_repo_lock(&self, repo_id: RepoId) {
150 self.repo_locks.remove(&repo_id);
151 }
152
153 pub async fn create_repo(&self, name: &str) -> Result<RepoId> {
161 let repo_id = Uuid::new_v4();
162 let repo_path = self.storage_path.join("repos").join(repo_id.to_string());
163
164 GitRepository::init(&repo_path)?;
165
166 sqlx::query(
167 r#"
168 INSERT INTO repositories (id, name, path)
169 VALUES ($1, $2, $3)
170 "#,
171 )
172 .bind(repo_id)
173 .bind(name)
174 .bind(repo_path.to_string_lossy().as_ref())
175 .execute(&self.db)
176 .await?;
177
178 Ok(repo_id)
179 }
180
181 pub async fn get_repo(&self, name: &str) -> Result<(RepoId, GitRepository)> {
186 let row: Option<(Uuid, String)> = sqlx::query_as(
188 "SELECT id, path FROM repositories WHERE name = $1",
189 )
190 .bind(name)
191 .fetch_optional(&self.db)
192 .await?;
193
194 let row = match row {
200 Some(r) => r,
201 None => {
202 let mut rows: Vec<(Uuid, String)> = sqlx::query_as(
203 "SELECT id, path FROM repositories \
204 WHERE split_part(name, '/', 2) = $1 \
205 OR (name = split_part($1, '/', 2) AND $1 LIKE '%/%')",
206 )
207 .bind(name)
208 .fetch_all(&self.db)
209 .await?;
210 match rows.len() {
211 0 => return Err(Error::RepoNotFound(name.to_string())),
212 1 => rows.remove(0),
213 _ => return Err(Error::AmbiguousRepoName(name.to_string())),
214 }
215 }
216 };
217
218 let (repo_id, repo_path) = row;
219 let git_repo = GitRepository::open(Path::new(&repo_path))?;
220 Ok((repo_id, git_repo))
221 }
222
223 pub async fn get_repo_by_db_id(&self, repo_id: RepoId) -> Result<(RepoId, GitRepository)> {
227 let row: (String,) = sqlx::query_as(
228 "SELECT path FROM repositories WHERE id = $1",
229 )
230 .bind(repo_id)
231 .fetch_optional(&self.db)
232 .await?
233 .ok_or_else(|| Error::RepoNotFound(repo_id.to_string()))?;
234
235 let git_repo = GitRepository::open(Path::new(&row.0))?;
236 Ok((repo_id, git_repo))
237 }
238
239 pub async fn index_repo(
247 &self,
248 repo_id: RepoId,
249 git_repo: &GitRepository,
250 ) -> Result<()> {
251 let root = git_repo.path().to_path_buf();
252 let files = collect_files(&root, &self.parser);
253
254 let mut all_symbols: Vec<Symbol> = Vec::new();
257 let mut all_raw_edges: Vec<RawCallEdge> = Vec::new();
258
259 let mut search_index = self.search_index.write().await;
261
262 for file_path in &files {
263 let relative = file_path
264 .strip_prefix(&root)
265 .unwrap_or(file_path);
266
267 let source = std::fs::read(file_path).map_err(|e| {
268 Error::Io(e)
269 })?;
270
271 let analysis = self.parser.parse_file(relative, &source)?;
272
273 for sym in &analysis.symbols {
275 self.symbol_store.upsert_symbol(repo_id, sym).await?;
276 search_index.index_symbol(repo_id, sym)?;
277 }
278
279 for ti in &analysis.types {
281 self.type_info_store.upsert_type_info(ti).await?;
282 }
283
284 all_symbols.extend(analysis.symbols);
285 all_raw_edges.extend(analysis.calls);
286 }
287
288 search_index.commit()?;
290 drop(search_index);
291
292 let edges = resolve_call_edges(&all_raw_edges, &all_symbols, repo_id);
294 for edge in &edges {
295 self.call_graph_store.insert_edge(edge).await?;
296 }
297
298 Ok(())
299 }
300
301 pub async fn update_files(
306 &self,
307 repo_id: RepoId,
308 git_repo: &GitRepository,
309 changed_files: &[PathBuf],
310 ) -> Result<()> {
311 self.update_files_by_root(repo_id, git_repo.path(), changed_files)
312 .await
313 }
314
315 pub async fn update_files_by_root(
321 &self,
322 repo_id: RepoId,
323 root: &Path,
324 changed_files: &[PathBuf],
325 ) -> Result<()> {
326 let root = root.to_path_buf();
327
328 let mut all_symbols: Vec<Symbol> = Vec::new();
329 let mut all_raw_edges: Vec<RawCallEdge> = Vec::new();
330
331 let mut search_index = self.search_index.write().await;
333
334 for file_path in changed_files {
335 let relative = file_path
336 .strip_prefix(&root)
337 .unwrap_or(file_path);
338 let rel_str = relative.to_string_lossy().to_string();
339
340 let old_symbols = self
343 .symbol_store
344 .find_by_file(repo_id, &rel_str)
345 .await?;
346 for old_sym in &old_symbols {
347 search_index.remove_symbol(old_sym.id)?;
348 }
349
350 self.call_graph_store
352 .delete_edges_for_file(repo_id, &rel_str)
353 .await?;
354 self.symbol_store
355 .delete_by_file(repo_id, &rel_str)
356 .await?;
357
358 let full_path = root.join(relative);
360 if !full_path.exists() {
361 continue;
363 }
364
365 if !self.parser.supports_file(relative) {
366 continue;
367 }
368
369 let source = std::fs::read(&full_path)?;
370 let analysis = self.parser.parse_file(relative, &source)?;
371
372 for sym in &analysis.symbols {
373 self.symbol_store.upsert_symbol(repo_id, sym).await?;
374 search_index.index_symbol(repo_id, sym)?;
375 }
376
377 for ti in &analysis.types {
378 self.type_info_store.upsert_type_info(ti).await?;
379 }
380
381 all_symbols.extend(analysis.symbols);
382 all_raw_edges.extend(analysis.calls);
383 }
384
385 search_index.commit()?;
386 drop(search_index);
387
388 let edges = resolve_call_edges(&all_raw_edges, &all_symbols, repo_id);
389 for edge in &edges {
390 self.call_graph_store.insert_edge(edge).await?;
391 }
392
393 Ok(())
394 }
395
396 pub async fn query_symbols(
403 &self,
404 repo_id: RepoId,
405 query: &str,
406 max_results: usize,
407 ) -> Result<Vec<Symbol>> {
408 let search_index = self.search_index.read().await;
409 let ids = search_index.search(repo_id, query, max_results)?;
410 drop(search_index);
411
412 self.symbol_store.get_by_ids(&ids).await
413 }
414
415 pub async fn get_call_graph(
420 &self,
421 _repo_id: RepoId,
422 symbol_id: SymbolId,
423 ) -> Result<(Vec<Symbol>, Vec<Symbol>)> {
424 let caller_edges = self.call_graph_store.find_callers(symbol_id).await?;
425 let callee_edges = self.call_graph_store.find_callees(symbol_id).await?;
426
427 let mut callers = Vec::with_capacity(caller_edges.len());
428 for edge in &caller_edges {
429 if let Some(sym) = self.symbol_store.get_by_id(edge.caller).await? {
430 callers.push(sym);
431 }
432 }
433
434 let mut callees = Vec::with_capacity(callee_edges.len());
435 for edge in &callee_edges {
436 if let Some(sym) = self.symbol_store.get_by_id(edge.callee).await? {
437 callees.push(sym);
438 }
439 }
440
441 Ok((callers, callees))
442 }
443
444 pub async fn codebase_summary(&self, repo_id: RepoId) -> Result<CodebaseSummary> {
450 let total_symbols = self.symbol_store.count(repo_id).await? as u64;
451
452 let row: (i64, Vec<String>) = sqlx::query_as(
454 r#"
455 SELECT
456 COUNT(DISTINCT file_path),
457 COALESCE(
458 array_agg(DISTINCT substring(file_path FROM '\.([^.]+)$'))
459 FILTER (WHERE substring(file_path FROM '\.([^.]+)$') IS NOT NULL),
460 ARRAY[]::text[]
461 )
462 FROM symbols
463 WHERE repo_id = $1
464 "#,
465 )
466 .bind(repo_id)
467 .fetch_one(&self.db)
468 .await?;
469
470 let total_files = row.0 as u64;
471 let mut languages = row.1;
472 languages.sort();
473
474 Ok(CodebaseSummary {
475 languages,
476 total_symbols,
477 total_files,
478 })
479 }
480}
481
482fn collect_files(root: &Path, parser: &ParserRegistry) -> Vec<PathBuf> {
487 let mut files = Vec::new();
488 collect_files_recursive(root, root, parser, &mut files);
489 files
490}
491
492fn collect_files_recursive(
493 root: &Path,
494 dir: &Path,
495 parser: &ParserRegistry,
496 out: &mut Vec<PathBuf>,
497) {
498 let entries = match std::fs::read_dir(dir) {
499 Ok(entries) => entries,
500 Err(_) => return,
501 };
502
503 for entry in entries.flatten() {
504 let path = entry.path();
505
506 if path.is_dir() {
507 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
509 if name == ".git" || name.starts_with('.') {
510 continue;
511 }
512 }
513 collect_files_recursive(root, &path, parser, out);
514 } else if path.is_file() {
515 let relative = path.strip_prefix(root).unwrap_or(&path);
516 if parser.supports_file(relative) {
517 out.push(path);
518 }
519 }
520 }
521}
522
523fn resolve_call_edges(
526 raw_edges: &[RawCallEdge],
527 symbols: &[Symbol],
528 repo_id: RepoId,
529) -> Vec<CallEdge> {
530 let mut name_to_id: HashMap<String, SymbolId> = HashMap::new();
533 for sym in symbols {
534 name_to_id.insert(sym.name.clone(), sym.id);
535 name_to_id.insert(sym.qualified_name.clone(), sym.id);
536 }
537
538 raw_edges
539 .iter()
540 .filter_map(|raw| {
541 let caller = name_to_id.get(&raw.caller_name)?;
542 let callee = name_to_id.get(&raw.callee_name)?;
543 Some(CallEdge {
544 id: Uuid::new_v4(),
545 repo_id,
546 caller: *caller,
547 callee: *callee,
548 kind: raw.kind.clone(),
549 })
550 })
551 .collect()
552}
553
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a public function symbol with no signature, docs, parent or
    /// modification metadata.
    fn function_symbol(
        id: SymbolId,
        name: &'static str,
        qualified_name: &'static str,
        file_path: &'static str,
        span: dk_core::Span,
    ) -> Symbol {
        Symbol {
            id,
            name: name.into(),
            qualified_name: qualified_name.into(),
            kind: dk_core::SymbolKind::Function,
            visibility: dk_core::Visibility::Public,
            file_path: file_path.into(),
            span,
            signature: None,
            doc_comment: None,
            parent: None,
            last_modified_by: None,
            last_modified_intent: None,
        }
    }

    /// Builds a direct call from `caller` to `callee` located at bytes 50..60.
    fn direct_call(caller: &'static str, callee: &'static str) -> RawCallEdge {
        RawCallEdge {
            caller_name: caller.into(),
            callee_name: callee.into(),
            call_site: dk_core::Span {
                start_byte: 50,
                end_byte: 60,
            },
            kind: dk_core::CallKind::DirectCall,
        }
    }

    #[test]
    fn test_resolve_call_edges_basic() {
        let foo_id = Uuid::new_v4();
        let bar_id = Uuid::new_v4();
        let repo_id = Uuid::new_v4();

        let symbols = vec![
            function_symbol(
                foo_id,
                "foo",
                "crate::foo",
                "src/lib.rs",
                dk_core::Span {
                    start_byte: 0,
                    end_byte: 100,
                },
            ),
            function_symbol(
                bar_id,
                "bar",
                "crate::bar",
                "src/lib.rs",
                dk_core::Span {
                    start_byte: 100,
                    end_byte: 200,
                },
            ),
        ];
        let raw_edges = vec![direct_call("foo", "bar")];

        let edges = resolve_call_edges(&raw_edges, &symbols, repo_id);

        assert_eq!(edges.len(), 1);
        let edge = &edges[0];
        assert_eq!(edge.caller, foo_id);
        assert_eq!(edge.callee, bar_id);
        assert_eq!(edge.repo_id, repo_id);
    }

    #[test]
    fn test_resolve_call_edges_unresolved_skipped() {
        let foo_id = Uuid::new_v4();
        let repo_id = Uuid::new_v4();

        let symbols = vec![function_symbol(
            foo_id,
            "foo",
            "crate::foo",
            "src/lib.rs",
            dk_core::Span {
                start_byte: 0,
                end_byte: 100,
            },
        )];
        // "unknown" is not defined anywhere, so the edge cannot be resolved.
        let raw_edges = vec![direct_call("foo", "unknown")];

        assert!(resolve_call_edges(&raw_edges, &symbols, repo_id).is_empty());
    }

    #[test]
    fn test_resolve_call_edges_qualified_name() {
        let foo_id = Uuid::new_v4();
        let bar_id = Uuid::new_v4();
        let repo_id = Uuid::new_v4();

        let symbols = vec![
            function_symbol(
                foo_id,
                "foo",
                "crate::mod_a::foo",
                "src/mod_a.rs",
                dk_core::Span {
                    start_byte: 0,
                    end_byte: 100,
                },
            ),
            function_symbol(
                bar_id,
                "bar",
                "crate::mod_b::bar",
                "src/mod_b.rs",
                dk_core::Span {
                    start_byte: 0,
                    end_byte: 100,
                },
            ),
        ];
        // Both endpoints are referenced by their fully-qualified names.
        let raw_edges = vec![direct_call("crate::mod_a::foo", "crate::mod_b::bar")];

        let edges = resolve_call_edges(&raw_edges, &symbols, repo_id);

        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].caller, foo_id);
        assert_eq!(edges[0].callee, bar_id);
    }

    #[test]
    fn test_collect_files_skips_git_dir() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path();

        // Content inside .git must never be collected.
        std::fs::create_dir_all(root.join(".git")).unwrap();
        std::fs::write(root.join(".git/config"), b"git config").unwrap();
        // A supported source file.
        std::fs::write(root.join("main.rs"), b"fn main() {}").unwrap();
        // A file no parser handles.
        std::fs::write(root.join("notes.txt"), b"hello").unwrap();

        let parser = ParserRegistry::new();
        let files = collect_files(root, &parser);

        assert_eq!(files.len(), 1);
        assert!(files[0].ends_with("main.rs"));
    }

    #[test]
    fn test_codebase_summary_struct() {
        let summary = CodebaseSummary {
            languages: vec!["rs".into(), "py".into()],
            total_symbols: 42,
            total_files: 5,
        };

        assert_eq!(summary.languages.len(), 2);
        assert_eq!(summary.total_symbols, 42);
        assert_eq!(summary.total_files, 5);
    }
}