1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use dashmap::DashMap;
6use dk_core::{
7 CallEdge, Error, RawCallEdge, RepoId, Result, Symbol, SymbolId,
8};
9use sqlx::postgres::PgPool;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13use crate::changeset::ChangesetStore;
14use crate::git::GitRepository;
15use crate::pipeline::PipelineStore;
16use crate::graph::{
17 CallGraphStore, DependencyStore, SearchIndex, SymbolStore, TypeInfoStore,
18};
19use crate::parser::ParserRegistry;
20use crate::workspace::cache::{NoOpCache, WorkspaceCache};
21use crate::workspace::session_manager::WorkspaceManager;
22
/// Aggregate statistics for a single repository, as reported by `Engine::codebase_summary`.
#[derive(Clone, Debug)]
pub struct CodebaseSummary {
    /// Distinct file extensions (without the leading dot) found among indexed symbols, sorted.
    pub languages: Vec<String>,
    /// Number of symbols stored for the repository.
    pub total_symbols: u64,
    /// Number of distinct files containing at least one stored symbol.
    pub total_files: u64,
}
32
33pub struct Engine {
40 pub db: PgPool,
41 pub search_index: Arc<RwLock<SearchIndex>>,
42 pub parser: Arc<ParserRegistry>,
43 pub storage_path: PathBuf,
44 symbol_store: SymbolStore,
45 call_graph_store: CallGraphStore,
46 #[allow(dead_code)]
47 dep_store: DependencyStore,
48 type_info_store: TypeInfoStore,
49 changeset_store: ChangesetStore,
50 pipeline_store: PipelineStore,
51 workspace_manager: WorkspaceManager,
52 repo_locks: DashMap<RepoId, Arc<RwLock<()>>>,
53}
54
55impl Engine {
56 pub fn new(storage_path: PathBuf, db: PgPool) -> Result<Self> {
64 Self::with_cache(storage_path, db, Arc::new(NoOpCache))
65 }
66
67 pub fn with_cache(
73 storage_path: PathBuf,
74 db: PgPool,
75 cache: Arc<dyn WorkspaceCache>,
76 ) -> Result<Self> {
77 let search_index = SearchIndex::open(&storage_path.join("search_index"))?;
78 let parser = ParserRegistry::new();
79 let symbol_store = SymbolStore::new(db.clone());
80 let call_graph_store = CallGraphStore::new(db.clone());
81 let dep_store = DependencyStore::new(db.clone());
82 let type_info_store = TypeInfoStore::new(db.clone());
83 let changeset_store = ChangesetStore::new(db.clone());
84 let pipeline_store = PipelineStore::new(db.clone());
85 let workspace_manager = WorkspaceManager::with_cache(db.clone(), cache);
86
87 Ok(Self {
88 db,
89 search_index: Arc::new(RwLock::new(search_index)),
90 parser: Arc::new(parser),
91 storage_path,
92 symbol_store,
93 call_graph_store,
94 dep_store,
95 type_info_store,
96 changeset_store,
97 pipeline_store,
98 workspace_manager,
99 repo_locks: DashMap::new(),
100 })
101 }
102
103 pub fn symbol_store(&self) -> &SymbolStore {
105 &self.symbol_store
106 }
107
108 pub fn changeset_store(&self) -> &ChangesetStore {
110 &self.changeset_store
111 }
112
113 pub fn pipeline_store(&self) -> &PipelineStore {
115 &self.pipeline_store
116 }
117
118 pub fn workspace_manager(&self) -> &WorkspaceManager {
120 &self.workspace_manager
121 }
122
123 pub fn call_graph_store(&self) -> &CallGraphStore {
125 &self.call_graph_store
126 }
127
128 pub fn dep_store(&self) -> &DependencyStore {
130 &self.dep_store
131 }
132
133 pub fn parser(&self) -> &ParserRegistry {
135 &self.parser
136 }
137
138 pub fn repo_lock(&self, repo_id: RepoId) -> Arc<RwLock<()>> {
142 self.repo_locks
143 .entry(repo_id)
144 .or_insert_with(|| Arc::new(RwLock::new(())))
145 .clone()
146 }
147
148 pub fn remove_repo_lock(&self, repo_id: RepoId) {
150 self.repo_locks.remove(&repo_id);
151 }
152
153 pub async fn create_repo(&self, name: &str) -> Result<RepoId> {
161 let repo_id = Uuid::new_v4();
162 let repo_path = self.storage_path.join("repos").join(repo_id.to_string());
163
164 GitRepository::init(&repo_path)?;
165
166 sqlx::query(
167 r#"
168 INSERT INTO repositories (id, name, path)
169 VALUES ($1, $2, $3)
170 "#,
171 )
172 .bind(repo_id)
173 .bind(name)
174 .bind(repo_path.to_string_lossy().as_ref())
175 .execute(&self.db)
176 .await?;
177
178 Ok(repo_id)
179 }
180
181 pub async fn get_repo(&self, name: &str) -> Result<(RepoId, GitRepository)> {
185 let row: (Uuid, String) = sqlx::query_as(
186 "SELECT id, path FROM repositories WHERE name = $1",
187 )
188 .bind(name)
189 .fetch_optional(&self.db)
190 .await?
191 .ok_or_else(|| Error::RepoNotFound(name.to_string()))?;
192
193 let (repo_id, repo_path) = row;
194 let git_repo = GitRepository::open(Path::new(&repo_path))?;
195 Ok((repo_id, git_repo))
196 }
197
198 pub async fn get_repo_by_db_id(&self, repo_id: RepoId) -> Result<(RepoId, GitRepository)> {
202 let row: (String,) = sqlx::query_as(
203 "SELECT path FROM repositories WHERE id = $1",
204 )
205 .bind(repo_id)
206 .fetch_optional(&self.db)
207 .await?
208 .ok_or_else(|| Error::RepoNotFound(repo_id.to_string()))?;
209
210 let git_repo = GitRepository::open(Path::new(&row.0))?;
211 Ok((repo_id, git_repo))
212 }
213
214 pub async fn index_repo(
222 &self,
223 repo_id: RepoId,
224 git_repo: &GitRepository,
225 ) -> Result<()> {
226 let root = git_repo.path().to_path_buf();
227 let files = collect_files(&root, &self.parser);
228
229 let mut all_symbols: Vec<Symbol> = Vec::new();
232 let mut all_raw_edges: Vec<RawCallEdge> = Vec::new();
233
234 let mut search_index = self.search_index.write().await;
236
237 for file_path in &files {
238 let relative = file_path
239 .strip_prefix(&root)
240 .unwrap_or(file_path);
241
242 let source = std::fs::read(file_path).map_err(|e| {
243 Error::Io(e)
244 })?;
245
246 let analysis = self.parser.parse_file(relative, &source)?;
247
248 for sym in &analysis.symbols {
250 self.symbol_store.upsert_symbol(repo_id, sym).await?;
251 search_index.index_symbol(repo_id, sym)?;
252 }
253
254 for ti in &analysis.types {
256 self.type_info_store.upsert_type_info(ti).await?;
257 }
258
259 all_symbols.extend(analysis.symbols);
260 all_raw_edges.extend(analysis.calls);
261 }
262
263 search_index.commit()?;
265 drop(search_index);
266
267 let edges = resolve_call_edges(&all_raw_edges, &all_symbols, repo_id);
269 for edge in &edges {
270 self.call_graph_store.insert_edge(edge).await?;
271 }
272
273 Ok(())
274 }
275
276 pub async fn update_files(
281 &self,
282 repo_id: RepoId,
283 git_repo: &GitRepository,
284 changed_files: &[PathBuf],
285 ) -> Result<()> {
286 self.update_files_by_root(repo_id, git_repo.path(), changed_files)
287 .await
288 }
289
290 pub async fn update_files_by_root(
296 &self,
297 repo_id: RepoId,
298 root: &Path,
299 changed_files: &[PathBuf],
300 ) -> Result<()> {
301 let root = root.to_path_buf();
302
303 let mut all_symbols: Vec<Symbol> = Vec::new();
304 let mut all_raw_edges: Vec<RawCallEdge> = Vec::new();
305
306 let mut search_index = self.search_index.write().await;
308
309 for file_path in changed_files {
310 let relative = file_path
311 .strip_prefix(&root)
312 .unwrap_or(file_path);
313 let rel_str = relative.to_string_lossy().to_string();
314
315 let old_symbols = self
318 .symbol_store
319 .find_by_file(repo_id, &rel_str)
320 .await?;
321 for old_sym in &old_symbols {
322 search_index.remove_symbol(old_sym.id)?;
323 }
324
325 self.call_graph_store
327 .delete_edges_for_file(repo_id, &rel_str)
328 .await?;
329 self.symbol_store
330 .delete_by_file(repo_id, &rel_str)
331 .await?;
332
333 let full_path = root.join(relative);
335 if !full_path.exists() {
336 continue;
338 }
339
340 if !self.parser.supports_file(relative) {
341 continue;
342 }
343
344 let source = std::fs::read(&full_path)?;
345 let analysis = self.parser.parse_file(relative, &source)?;
346
347 for sym in &analysis.symbols {
348 self.symbol_store.upsert_symbol(repo_id, sym).await?;
349 search_index.index_symbol(repo_id, sym)?;
350 }
351
352 for ti in &analysis.types {
353 self.type_info_store.upsert_type_info(ti).await?;
354 }
355
356 all_symbols.extend(analysis.symbols);
357 all_raw_edges.extend(analysis.calls);
358 }
359
360 search_index.commit()?;
361 drop(search_index);
362
363 let edges = resolve_call_edges(&all_raw_edges, &all_symbols, repo_id);
364 for edge in &edges {
365 self.call_graph_store.insert_edge(edge).await?;
366 }
367
368 Ok(())
369 }
370
371 pub async fn query_symbols(
378 &self,
379 repo_id: RepoId,
380 query: &str,
381 max_results: usize,
382 ) -> Result<Vec<Symbol>> {
383 let search_index = self.search_index.read().await;
384 let ids = search_index.search(repo_id, query, max_results)?;
385 drop(search_index);
386
387 self.symbol_store.get_by_ids(&ids).await
388 }
389
390 pub async fn get_call_graph(
395 &self,
396 _repo_id: RepoId,
397 symbol_id: SymbolId,
398 ) -> Result<(Vec<Symbol>, Vec<Symbol>)> {
399 let caller_edges = self.call_graph_store.find_callers(symbol_id).await?;
400 let callee_edges = self.call_graph_store.find_callees(symbol_id).await?;
401
402 let mut callers = Vec::with_capacity(caller_edges.len());
403 for edge in &caller_edges {
404 if let Some(sym) = self.symbol_store.get_by_id(edge.caller).await? {
405 callers.push(sym);
406 }
407 }
408
409 let mut callees = Vec::with_capacity(callee_edges.len());
410 for edge in &callee_edges {
411 if let Some(sym) = self.symbol_store.get_by_id(edge.callee).await? {
412 callees.push(sym);
413 }
414 }
415
416 Ok((callers, callees))
417 }
418
419 pub async fn codebase_summary(&self, repo_id: RepoId) -> Result<CodebaseSummary> {
425 let total_symbols = self.symbol_store.count(repo_id).await? as u64;
426
427 let row: (i64, Vec<String>) = sqlx::query_as(
429 r#"
430 SELECT
431 COUNT(DISTINCT file_path),
432 COALESCE(
433 array_agg(DISTINCT substring(file_path FROM '\.([^.]+)$'))
434 FILTER (WHERE substring(file_path FROM '\.([^.]+)$') IS NOT NULL),
435 ARRAY[]::text[]
436 )
437 FROM symbols
438 WHERE repo_id = $1
439 "#,
440 )
441 .bind(repo_id)
442 .fetch_one(&self.db)
443 .await?;
444
445 let total_files = row.0 as u64;
446 let mut languages = row.1;
447 languages.sort();
448
449 Ok(CodebaseSummary {
450 languages,
451 total_symbols,
452 total_files,
453 })
454 }
455}
456
457fn collect_files(root: &Path, parser: &ParserRegistry) -> Vec<PathBuf> {
462 let mut files = Vec::new();
463 collect_files_recursive(root, root, parser, &mut files);
464 files
465}
466
467fn collect_files_recursive(
468 root: &Path,
469 dir: &Path,
470 parser: &ParserRegistry,
471 out: &mut Vec<PathBuf>,
472) {
473 let entries = match std::fs::read_dir(dir) {
474 Ok(entries) => entries,
475 Err(_) => return,
476 };
477
478 for entry in entries.flatten() {
479 let path = entry.path();
480
481 if path.is_dir() {
482 if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
484 if name == ".git" || name.starts_with('.') {
485 continue;
486 }
487 }
488 collect_files_recursive(root, &path, parser, out);
489 } else if path.is_file() {
490 let relative = path.strip_prefix(root).unwrap_or(&path);
491 if parser.supports_file(relative) {
492 out.push(path);
493 }
494 }
495 }
496}
497
498fn resolve_call_edges(
501 raw_edges: &[RawCallEdge],
502 symbols: &[Symbol],
503 repo_id: RepoId,
504) -> Vec<CallEdge> {
505 let mut name_to_id: HashMap<String, SymbolId> = HashMap::new();
508 for sym in symbols {
509 name_to_id.insert(sym.name.clone(), sym.id);
510 name_to_id.insert(sym.qualified_name.clone(), sym.id);
511 }
512
513 raw_edges
514 .iter()
515 .filter_map(|raw| {
516 let caller = name_to_id.get(&raw.caller_name)?;
517 let callee = name_to_id.get(&raw.callee_name)?;
518 Some(CallEdge {
519 id: Uuid::new_v4(),
520 repo_id,
521 caller: *caller,
522 callee: *callee,
523 kind: raw.kind.clone(),
524 })
525 })
526 .collect()
527}
528
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a public function symbol with no signature, docs, parent or modification info.
    fn function_symbol(
        id: Uuid,
        name: &str,
        qualified_name: &str,
        file_path: &str,
        span: dk_core::Span,
    ) -> Symbol {
        Symbol {
            id,
            name: name.into(),
            qualified_name: qualified_name.into(),
            kind: dk_core::SymbolKind::Function,
            visibility: dk_core::Visibility::Public,
            file_path: file_path.into(),
            span,
            signature: None,
            doc_comment: None,
            parent: None,
            last_modified_by: None,
            last_modified_intent: None,
        }
    }

    /// Build a direct-call raw edge at a fixed call site.
    fn direct_call(caller: &str, callee: &str) -> RawCallEdge {
        RawCallEdge {
            caller_name: caller.into(),
            callee_name: callee.into(),
            call_site: dk_core::Span {
                start_byte: 50,
                end_byte: 60,
            },
            kind: dk_core::CallKind::DirectCall,
        }
    }

    #[test]
    fn test_resolve_call_edges_basic() {
        let (foo_id, bar_id, repo) = (Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4());
        let symbols = vec![
            function_symbol(
                foo_id,
                "foo",
                "crate::foo",
                "src/lib.rs",
                dk_core::Span {
                    start_byte: 0,
                    end_byte: 100,
                },
            ),
            function_symbol(
                bar_id,
                "bar",
                "crate::bar",
                "src/lib.rs",
                dk_core::Span {
                    start_byte: 100,
                    end_byte: 200,
                },
            ),
        ];

        let resolved = resolve_call_edges(&[direct_call("foo", "bar")], &symbols, repo);

        assert_eq!(resolved.len(), 1);
        let edge = &resolved[0];
        assert_eq!(edge.caller, foo_id);
        assert_eq!(edge.callee, bar_id);
        assert_eq!(edge.repo_id, repo);
    }

    #[test]
    fn test_resolve_call_edges_unresolved_skipped() {
        let repo = Uuid::new_v4();
        let symbols = vec![function_symbol(
            Uuid::new_v4(),
            "foo",
            "crate::foo",
            "src/lib.rs",
            dk_core::Span {
                start_byte: 0,
                end_byte: 100,
            },
        )];

        // The callee name matches no known symbol, so the edge is dropped.
        let resolved = resolve_call_edges(&[direct_call("foo", "unknown")], &symbols, repo);
        assert!(resolved.is_empty());
    }

    #[test]
    fn test_resolve_call_edges_qualified_name() {
        let (foo_id, bar_id, repo) = (Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4());
        let symbols = vec![
            function_symbol(
                foo_id,
                "foo",
                "crate::mod_a::foo",
                "src/mod_a.rs",
                dk_core::Span {
                    start_byte: 0,
                    end_byte: 100,
                },
            ),
            function_symbol(
                bar_id,
                "bar",
                "crate::mod_b::bar",
                "src/mod_b.rs",
                dk_core::Span {
                    start_byte: 0,
                    end_byte: 100,
                },
            ),
        ];

        let raw = [direct_call("crate::mod_a::foo", "crate::mod_b::bar")];
        let resolved = resolve_call_edges(&raw, &symbols, repo);

        assert_eq!(resolved.len(), 1);
        assert_eq!(resolved[0].caller, foo_id);
        assert_eq!(resolved[0].callee, bar_id);
    }

    #[test]
    fn test_collect_files_skips_git_dir() {
        let tmp = tempfile::tempdir().unwrap();
        let root = tmp.path();

        let git_dir = root.join(".git");
        std::fs::create_dir_all(&git_dir).unwrap();
        std::fs::write(git_dir.join("config"), b"git config").unwrap();
        std::fs::write(root.join("main.rs"), b"fn main() {}").unwrap();
        std::fs::write(root.join("notes.txt"), b"hello").unwrap();

        let registry = ParserRegistry::new();
        let found = collect_files(root, &registry);

        assert_eq!(found.len(), 1);
        assert!(found[0].ends_with("main.rs"));
    }

    #[test]
    fn test_codebase_summary_struct() {
        let summary = CodebaseSummary {
            languages: vec!["rs".into(), "py".into()],
            total_symbols: 42,
            total_files: 5,
        };

        assert_eq!(summary.languages.len(), 2);
        assert_eq!(summary.total_symbols, 42);
        assert_eq!(summary.total_files, 5);
    }
}