Skip to main content

dk_engine/
repo.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use dashmap::DashMap;
6use dk_core::{
7    CallEdge, Error, RawCallEdge, RepoId, Result, Symbol, SymbolId,
8};
9use sqlx::postgres::PgPool;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13use crate::changeset::ChangesetStore;
14use crate::git::GitRepository;
15use crate::pipeline::PipelineStore;
16use crate::graph::{
17    CallGraphStore, DependencyStore, SearchIndex, SymbolStore, TypeInfoStore,
18};
19use crate::parser::ParserRegistry;
20use crate::workspace::session_manager::WorkspaceManager;
21
22// ── Public types ──
23
/// High-level summary of a repository's indexed codebase.
///
/// Produced by `Engine::codebase_summary`. `Default` yields the empty summary
/// (no languages, zero counts); the equality derives make summaries directly
/// comparable in assertions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CodebaseSummary {
    /// Distinct file extensions (e.g. `rs`, `py`) seen among indexed files,
    /// in sorted order.
    pub languages: Vec<String>,
    /// Number of symbols recorded for the repository.
    pub total_symbols: u64,
    /// Number of distinct file paths that contributed at least one symbol.
    pub total_files: u64,
}
31
32/// The central orchestration layer that ties together Git storage,
33/// language parsing, the semantic graph stores, and full-text search.
34///
35/// Internally concurrent: all methods take `&self`. The `SearchIndex` is
36/// wrapped in an `RwLock` (write for mutations, read for queries), and
37/// per-repo Git operations are serialised via `repo_locks`.
38pub struct Engine {
39    pub db: PgPool,
40    pub search_index: Arc<RwLock<SearchIndex>>,
41    pub parser: Arc<ParserRegistry>,
42    pub storage_path: PathBuf,
43    symbol_store: SymbolStore,
44    call_graph_store: CallGraphStore,
45    #[allow(dead_code)]
46    dep_store: DependencyStore,
47    type_info_store: TypeInfoStore,
48    changeset_store: ChangesetStore,
49    pipeline_store: PipelineStore,
50    workspace_manager: WorkspaceManager,
51    repo_locks: DashMap<RepoId, Arc<RwLock<()>>>,
52}
53
54impl Engine {
55    /// Create a new Engine instance.
56    ///
57    /// Initialises all graph stores from the provided `PgPool`, creates the
58    /// `ParserRegistry` with Rust/TypeScript/Python parsers, and opens (or
59    /// creates) a Tantivy `SearchIndex` at `storage_path/search_index`.
60    pub fn new(storage_path: PathBuf, db: PgPool) -> Result<Self> {
61        let search_index = SearchIndex::open(&storage_path.join("search_index"))?;
62        let parser = ParserRegistry::new();
63        let symbol_store = SymbolStore::new(db.clone());
64        let call_graph_store = CallGraphStore::new(db.clone());
65        let dep_store = DependencyStore::new(db.clone());
66        let type_info_store = TypeInfoStore::new(db.clone());
67        let changeset_store = ChangesetStore::new(db.clone());
68        let pipeline_store = PipelineStore::new(db.clone());
69        let workspace_manager = WorkspaceManager::new(db.clone());
70
71        Ok(Self {
72            db,
73            search_index: Arc::new(RwLock::new(search_index)),
74            parser: Arc::new(parser),
75            storage_path,
76            symbol_store,
77            call_graph_store,
78            dep_store,
79            type_info_store,
80            changeset_store,
81            pipeline_store,
82            workspace_manager,
83            repo_locks: DashMap::new(),
84        })
85    }
86
87    /// Returns a reference to the symbol store for direct DB queries.
88    pub fn symbol_store(&self) -> &SymbolStore {
89        &self.symbol_store
90    }
91
92    /// Returns a reference to the changeset store.
93    pub fn changeset_store(&self) -> &ChangesetStore {
94        &self.changeset_store
95    }
96
97    /// Returns a reference to the pipeline store.
98    pub fn pipeline_store(&self) -> &PipelineStore {
99        &self.pipeline_store
100    }
101
102    /// Returns a reference to the workspace manager.
103    pub fn workspace_manager(&self) -> &WorkspaceManager {
104        &self.workspace_manager
105    }
106
107    /// Returns a reference to the call graph store for direct DB queries.
108    pub fn call_graph_store(&self) -> &CallGraphStore {
109        &self.call_graph_store
110    }
111
112    /// Returns a reference to the dependency store for direct DB queries.
113    pub fn dep_store(&self) -> &DependencyStore {
114        &self.dep_store
115    }
116
117    /// Returns a reference to the parser registry.
118    pub fn parser(&self) -> &ParserRegistry {
119        &self.parser
120    }
121
122    /// Returns a per-repo lock for serialising Git operations.
123    ///
124    /// Creates a new lock on first access for a given `repo_id`.
125    pub fn repo_lock(&self, repo_id: RepoId) -> Arc<RwLock<()>> {
126        self.repo_locks
127            .entry(repo_id)
128            .or_insert_with(|| Arc::new(RwLock::new(())))
129            .clone()
130    }
131
132    /// Remove the repo lock entry for a deleted repo.
133    pub fn remove_repo_lock(&self, repo_id: RepoId) {
134        self.repo_locks.remove(&repo_id);
135    }
136
137    // ── Repository lifecycle ──
138
139    /// Create a new repository.
140    ///
141    /// Generates a UUID, initialises a Git repository at
142    /// `storage_path/repos/<uuid>`, inserts a row into the `repositories`
143    /// table, and returns the new `RepoId`.
144    pub async fn create_repo(&self, name: &str) -> Result<RepoId> {
145        let repo_id = Uuid::new_v4();
146        let repo_path = self.storage_path.join("repos").join(repo_id.to_string());
147
148        GitRepository::init(&repo_path)?;
149
150        sqlx::query(
151            r#"
152            INSERT INTO repositories (id, name, path)
153            VALUES ($1, $2, $3)
154            "#,
155        )
156        .bind(repo_id)
157        .bind(name)
158        .bind(repo_path.to_string_lossy().as_ref())
159        .execute(&self.db)
160        .await?;
161
162        Ok(repo_id)
163    }
164
165    /// Look up a repository by name.
166    ///
167    /// Returns the `RepoId` and an opened `GitRepository` handle.
168    pub async fn get_repo(&self, name: &str) -> Result<(RepoId, GitRepository)> {
169        let row: (Uuid, String) = sqlx::query_as(
170            "SELECT id, path FROM repositories WHERE name = $1",
171        )
172        .bind(name)
173        .fetch_optional(&self.db)
174        .await?
175        .ok_or_else(|| Error::RepoNotFound(name.to_string()))?;
176
177        let (repo_id, repo_path) = row;
178        let git_repo = GitRepository::open(Path::new(&repo_path))?;
179        Ok((repo_id, git_repo))
180    }
181
182    /// Look up a repository by its UUID.
183    ///
184    /// Returns the `RepoId` and an opened `GitRepository` handle.
185    pub async fn get_repo_by_db_id(&self, repo_id: RepoId) -> Result<(RepoId, GitRepository)> {
186        let row: (String,) = sqlx::query_as(
187            "SELECT path FROM repositories WHERE id = $1",
188        )
189        .bind(repo_id)
190        .fetch_optional(&self.db)
191        .await?
192        .ok_or_else(|| Error::RepoNotFound(repo_id.to_string()))?;
193
194        let git_repo = GitRepository::open(Path::new(&row.0))?;
195        Ok((repo_id, git_repo))
196    }
197
198    // ── Indexing ──
199
200    /// Perform a full index of a repository.
201    ///
202    /// Walks the working directory (skipping `.git`), parses every file with a
203    /// supported extension, and populates the symbol table, type info store,
204    /// call graph, and full-text search index.
205    pub async fn index_repo(
206        &self,
207        repo_id: RepoId,
208        git_repo: &GitRepository,
209    ) -> Result<()> {
210        let root = git_repo.path().to_path_buf();
211        let files = collect_files(&root, &self.parser);
212
213        // Accumulate all symbols across every file so we can resolve call
214        // edges at the end.
215        let mut all_symbols: Vec<Symbol> = Vec::new();
216        let mut all_raw_edges: Vec<RawCallEdge> = Vec::new();
217
218        // Acquire the search index write lock for the duration of indexing.
219        let mut search_index = self.search_index.write().await;
220
221        for file_path in &files {
222            let relative = file_path
223                .strip_prefix(&root)
224                .unwrap_or(file_path);
225
226            let source = std::fs::read(file_path).map_err(|e| {
227                Error::Io(e)
228            })?;
229
230            let analysis = self.parser.parse_file(relative, &source)?;
231
232            // Symbols
233            for sym in &analysis.symbols {
234                self.symbol_store.upsert_symbol(repo_id, sym).await?;
235                search_index.index_symbol(repo_id, sym)?;
236            }
237
238            // Type info
239            for ti in &analysis.types {
240                self.type_info_store.upsert_type_info(ti).await?;
241            }
242
243            all_symbols.extend(analysis.symbols);
244            all_raw_edges.extend(analysis.calls);
245        }
246
247        // Commit search index once after all files.
248        search_index.commit()?;
249        drop(search_index);
250
251        // Resolve and insert call edges.
252        let edges = resolve_call_edges(&all_raw_edges, &all_symbols, repo_id);
253        for edge in &edges {
254            self.call_graph_store.insert_edge(edge).await?;
255        }
256
257        Ok(())
258    }
259
260    /// Incrementally re-index a set of changed files.
261    ///
262    /// For each path: deletes old symbols and call edges, re-parses, and
263    /// upserts the new data.
264    pub async fn update_files(
265        &self,
266        repo_id: RepoId,
267        git_repo: &GitRepository,
268        changed_files: &[PathBuf],
269    ) -> Result<()> {
270        self.update_files_by_root(repo_id, git_repo.path(), changed_files)
271            .await
272    }
273
274    /// Incrementally re-index a set of changed files, given the repository
275    /// root path directly.
276    ///
277    /// This variant avoids holding a `GitRepository` reference (which is
278    /// `!Sync`) across `.await` points, making the resulting future `Send`.
279    pub async fn update_files_by_root(
280        &self,
281        repo_id: RepoId,
282        root: &Path,
283        changed_files: &[PathBuf],
284    ) -> Result<()> {
285        let root = root.to_path_buf();
286
287        let mut all_symbols: Vec<Symbol> = Vec::new();
288        let mut all_raw_edges: Vec<RawCallEdge> = Vec::new();
289
290        // Acquire the search index write lock for the duration of re-indexing.
291        let mut search_index = self.search_index.write().await;
292
293        for file_path in changed_files {
294            let relative = file_path
295                .strip_prefix(&root)
296                .unwrap_or(file_path);
297            let rel_str = relative.to_string_lossy().to_string();
298
299            // Fetch existing symbols for this file so we can remove their
300            // search index entries.
301            let old_symbols = self
302                .symbol_store
303                .find_by_file(repo_id, &rel_str)
304                .await?;
305            for old_sym in &old_symbols {
306                search_index.remove_symbol(old_sym.id)?;
307            }
308
309            // Delete old DB rows.
310            self.call_graph_store
311                .delete_edges_for_file(repo_id, &rel_str)
312                .await?;
313            self.symbol_store
314                .delete_by_file(repo_id, &rel_str)
315                .await?;
316
317            // Re-parse.
318            let full_path = root.join(relative);
319            if !full_path.exists() {
320                // File was deleted; nothing more to do for this path.
321                continue;
322            }
323
324            if !self.parser.supports_file(relative) {
325                continue;
326            }
327
328            let source = std::fs::read(&full_path)?;
329            let analysis = self.parser.parse_file(relative, &source)?;
330
331            for sym in &analysis.symbols {
332                self.symbol_store.upsert_symbol(repo_id, sym).await?;
333                search_index.index_symbol(repo_id, sym)?;
334            }
335
336            for ti in &analysis.types {
337                self.type_info_store.upsert_type_info(ti).await?;
338            }
339
340            all_symbols.extend(analysis.symbols);
341            all_raw_edges.extend(analysis.calls);
342        }
343
344        search_index.commit()?;
345        drop(search_index);
346
347        let edges = resolve_call_edges(&all_raw_edges, &all_symbols, repo_id);
348        for edge in &edges {
349            self.call_graph_store.insert_edge(edge).await?;
350        }
351
352        Ok(())
353    }
354
355    // ── Querying ──
356
357    /// Search for symbols matching a free-text query.
358    ///
359    /// Uses Tantivy full-text search to find candidate `SymbolId`s, then
360    /// fetches the full `Symbol` objects from the database.
361    pub async fn query_symbols(
362        &self,
363        repo_id: RepoId,
364        query: &str,
365        max_results: usize,
366    ) -> Result<Vec<Symbol>> {
367        let search_index = self.search_index.read().await;
368        let ids = search_index.search(repo_id, query, max_results)?;
369        drop(search_index);
370
371        self.symbol_store.get_by_ids(&ids).await
372    }
373
374    /// Retrieve the call graph neighbourhood of a symbol.
375    ///
376    /// Returns `(callers, callees)` — the full `Symbol` objects for every
377    /// direct caller and every direct callee.
378    pub async fn get_call_graph(
379        &self,
380        _repo_id: RepoId,
381        symbol_id: SymbolId,
382    ) -> Result<(Vec<Symbol>, Vec<Symbol>)> {
383        let caller_edges = self.call_graph_store.find_callers(symbol_id).await?;
384        let callee_edges = self.call_graph_store.find_callees(symbol_id).await?;
385
386        let mut callers = Vec::with_capacity(caller_edges.len());
387        for edge in &caller_edges {
388            if let Some(sym) = self.symbol_store.get_by_id(edge.caller).await? {
389                callers.push(sym);
390            }
391        }
392
393        let mut callees = Vec::with_capacity(callee_edges.len());
394        for edge in &callee_edges {
395            if let Some(sym) = self.symbol_store.get_by_id(edge.callee).await? {
396                callees.push(sym);
397            }
398        }
399
400        Ok((callers, callees))
401    }
402
403    /// Produce a high-level summary of the indexed codebase.
404    ///
405    /// Queries the symbols table for distinct file extensions (→ languages),
406    /// distinct file paths (→ total_files), and total row count
407    /// (→ total_symbols).
408    pub async fn codebase_summary(&self, repo_id: RepoId) -> Result<CodebaseSummary> {
409        let total_symbols = self.symbol_store.count(repo_id).await? as u64;
410
411        // Count distinct files and collect unique extensions in a single query.
412        let row: (i64, Vec<String>) = sqlx::query_as(
413            r#"
414            SELECT
415                COUNT(DISTINCT file_path),
416                COALESCE(
417                    array_agg(DISTINCT substring(file_path FROM '\.([^.]+)$'))
418                        FILTER (WHERE substring(file_path FROM '\.([^.]+)$') IS NOT NULL),
419                    ARRAY[]::text[]
420                )
421            FROM symbols
422            WHERE repo_id = $1
423            "#,
424        )
425        .bind(repo_id)
426        .fetch_one(&self.db)
427        .await?;
428
429        let total_files = row.0 as u64;
430        let mut languages = row.1;
431        languages.sort();
432
433        Ok(CodebaseSummary {
434            languages,
435            total_symbols,
436            total_files,
437        })
438    }
439}
440
441// ── Helpers ──
442
443/// Recursively collect all files under `root` that are supported by the
444/// parser registry, skipping the `.git` directory.
445fn collect_files(root: &Path, parser: &ParserRegistry) -> Vec<PathBuf> {
446    let mut files = Vec::new();
447    collect_files_recursive(root, root, parser, &mut files);
448    files
449}
450
451fn collect_files_recursive(
452    root: &Path,
453    dir: &Path,
454    parser: &ParserRegistry,
455    out: &mut Vec<PathBuf>,
456) {
457    let entries = match std::fs::read_dir(dir) {
458        Ok(entries) => entries,
459        Err(_) => return,
460    };
461
462    for entry in entries.flatten() {
463        let path = entry.path();
464
465        if path.is_dir() {
466            // Skip .git and hidden directories.
467            if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
468                if name == ".git" || name.starts_with('.') {
469                    continue;
470                }
471            }
472            collect_files_recursive(root, &path, parser, out);
473        } else if path.is_file() {
474            let relative = path.strip_prefix(root).unwrap_or(&path);
475            if parser.supports_file(relative) {
476                out.push(path);
477            }
478        }
479    }
480}
481
482/// Resolve `RawCallEdge`s (which use string names) into `CallEdge`s
483/// (which use `SymbolId`s) by building a name-to-id lookup table.
484fn resolve_call_edges(
485    raw_edges: &[RawCallEdge],
486    symbols: &[Symbol],
487    repo_id: RepoId,
488) -> Vec<CallEdge> {
489    // Build name -> SymbolId lookup.
490    // Insert both `name` and `qualified_name` so either form resolves.
491    let mut name_to_id: HashMap<String, SymbolId> = HashMap::new();
492    for sym in symbols {
493        name_to_id.insert(sym.name.clone(), sym.id);
494        name_to_id.insert(sym.qualified_name.clone(), sym.id);
495    }
496
497    raw_edges
498        .iter()
499        .filter_map(|raw| {
500            let caller = name_to_id.get(&raw.caller_name)?;
501            let callee = name_to_id.get(&raw.callee_name)?;
502            Some(CallEdge {
503                id: Uuid::new_v4(),
504                repo_id,
505                caller: *caller,
506                callee: *callee,
507                kind: raw.kind.clone(),
508            })
509        })
510        .collect()
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a public function `Symbol` spanning bytes 0..100, filling in
    /// only the fields these tests care about.
    fn function_symbol(
        id: SymbolId,
        name: &'static str,
        qualified_name: &'static str,
        file_path: &'static str,
    ) -> Symbol {
        Symbol {
            id,
            name: name.into(),
            qualified_name: qualified_name.into(),
            kind: dk_core::SymbolKind::Function,
            visibility: dk_core::Visibility::Public,
            file_path: file_path.into(),
            span: dk_core::Span {
                start_byte: 0,
                end_byte: 100,
            },
            signature: None,
            doc_comment: None,
            parent: None,
            last_modified_by: None,
            last_modified_intent: None,
        }
    }

    /// Build a direct-call `RawCallEdge` between two names.
    fn direct_call(caller: &'static str, callee: &'static str) -> RawCallEdge {
        RawCallEdge {
            caller_name: caller.into(),
            callee_name: callee.into(),
            call_site: dk_core::Span {
                start_byte: 50,
                end_byte: 60,
            },
            kind: dk_core::CallKind::DirectCall,
        }
    }

    #[test]
    fn test_resolve_call_edges_basic() {
        let sym_a_id = Uuid::new_v4();
        let sym_b_id = Uuid::new_v4();
        let repo_id = Uuid::new_v4();

        let symbols = vec![
            function_symbol(sym_a_id, "foo", "crate::foo", "src/lib.rs"),
            Symbol {
                // `bar` follows `foo` in the same file.
                span: dk_core::Span {
                    start_byte: 100,
                    end_byte: 200,
                },
                ..function_symbol(sym_b_id, "bar", "crate::bar", "src/lib.rs")
            },
        ];
        let raw_edges = vec![direct_call("foo", "bar")];

        let edges = resolve_call_edges(&raw_edges, &symbols, repo_id);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].caller, sym_a_id);
        assert_eq!(edges[0].callee, sym_b_id);
        assert_eq!(edges[0].repo_id, repo_id);
    }

    #[test]
    fn test_resolve_call_edges_unresolved_skipped() {
        let sym_a_id = Uuid::new_v4();
        let repo_id = Uuid::new_v4();

        let symbols = vec![function_symbol(sym_a_id, "foo", "crate::foo", "src/lib.rs")];

        // callee "unknown" doesn't exist in symbols
        let raw_edges = vec![direct_call("foo", "unknown")];

        let edges = resolve_call_edges(&raw_edges, &symbols, repo_id);
        assert!(edges.is_empty());
    }

    #[test]
    fn test_resolve_call_edges_qualified_name() {
        let sym_a_id = Uuid::new_v4();
        let sym_b_id = Uuid::new_v4();
        let repo_id = Uuid::new_v4();

        let symbols = vec![
            function_symbol(sym_a_id, "foo", "crate::mod_a::foo", "src/mod_a.rs"),
            function_symbol(sym_b_id, "bar", "crate::mod_b::bar", "src/mod_b.rs"),
        ];

        // Use qualified names for resolution
        let raw_edges = vec![direct_call("crate::mod_a::foo", "crate::mod_b::bar")];

        let edges = resolve_call_edges(&raw_edges, &symbols, repo_id);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].caller, sym_a_id);
        assert_eq!(edges[0].callee, sym_b_id);
    }

    #[test]
    fn test_collect_files_skips_git_dir() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path();

        // Create a .git directory with a file inside.
        std::fs::create_dir_all(root.join(".git")).unwrap();
        std::fs::write(root.join(".git/config"), b"git config").unwrap();

        // Create a supported source file.
        std::fs::write(root.join("main.rs"), b"fn main() {}").unwrap();

        // Create an unsupported file.
        std::fs::write(root.join("notes.txt"), b"hello").unwrap();

        let parser = ParserRegistry::new();
        let files = collect_files(root, &parser);

        assert_eq!(files.len(), 1);
        assert!(files[0].ends_with("main.rs"));
    }

    #[test]
    fn test_collect_files_skips_hidden_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path();

        // A supported file inside a hidden directory must be ignored.
        std::fs::create_dir_all(root.join(".cache")).unwrap();
        std::fs::write(root.join(".cache/generated.rs"), b"fn hidden() {}").unwrap();

        // A supported file inside a regular subdirectory must be found.
        std::fs::create_dir_all(root.join("src")).unwrap();
        std::fs::write(root.join("src/lib.rs"), b"fn visible() {}").unwrap();

        let parser = ParserRegistry::new();
        let files = collect_files(root, &parser);

        assert_eq!(files.len(), 1);
        assert!(files[0].ends_with("src/lib.rs"));
    }

    #[test]
    fn test_codebase_summary_struct() {
        let summary = CodebaseSummary {
            languages: vec!["rs".into(), "py".into()],
            total_symbols: 42,
            total_files: 5,
        };
        assert_eq!(summary.languages.len(), 2);
        assert_eq!(summary.total_symbols, 42);
        assert_eq!(summary.total_files, 5);
    }
}