Skip to main content

dk_engine/
repo.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use dashmap::DashMap;
6use dk_core::{
7    CallEdge, Error, RawCallEdge, RepoId, Result, Symbol, SymbolId,
8};
9use sqlx::postgres::PgPool;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13use crate::changeset::ChangesetStore;
14use crate::git::GitRepository;
15use crate::pipeline::PipelineStore;
16use crate::graph::{
17    CallGraphStore, DependencyStore, SearchIndex, SymbolStore, TypeInfoStore,
18};
19use crate::parser::ParserRegistry;
20use crate::workspace::cache::{NoOpCache, WorkspaceCache};
21use crate::workspace::session_manager::WorkspaceManager;
22
23// ── Public types ──
24
/// High-level summary of a repository's indexed codebase.
///
/// Produced by `Engine::codebase_summary` from the rows of the `symbols`
/// table belonging to one repository.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct CodebaseSummary {
    /// Distinct file extensions (without the leading dot, e.g. `"rs"`) of the
    /// files that contributed symbols, sorted ascending. These are extensions,
    /// not language names.
    pub languages: Vec<String>,
    /// Number of symbol rows indexed for the repository.
    pub total_symbols: u64,
    /// Number of distinct files that contributed at least one symbol.
    pub total_files: u64,
}
32
33/// The central orchestration layer that ties together Git storage,
34/// language parsing, the semantic graph stores, and full-text search.
35///
36/// Internally concurrent: all methods take `&self`. The `SearchIndex` is
37/// wrapped in an `RwLock` (write for mutations, read for queries), and
38/// per-repo Git operations are serialised via `repo_locks`.
39pub struct Engine {
40    pub db: PgPool,
41    pub search_index: Arc<RwLock<SearchIndex>>,
42    pub parser: Arc<ParserRegistry>,
43    pub storage_path: PathBuf,
44    symbol_store: SymbolStore,
45    call_graph_store: CallGraphStore,
46    #[allow(dead_code)]
47    dep_store: DependencyStore,
48    type_info_store: TypeInfoStore,
49    changeset_store: ChangesetStore,
50    pipeline_store: PipelineStore,
51    workspace_manager: WorkspaceManager,
52    repo_locks: DashMap<RepoId, Arc<RwLock<()>>>,
53}
54
55impl Engine {
56    /// Create a new Engine instance with the default no-op workspace cache.
57    ///
58    /// Initialises all graph stores from the provided `PgPool`, creates the
59    /// `ParserRegistry` with Rust/TypeScript/Python parsers, and opens (or
60    /// creates) a Tantivy `SearchIndex` at `storage_path/search_index`.
61    ///
62    /// Delegates to [`Engine::with_cache`] with [`NoOpCache`].
63    pub fn new(storage_path: PathBuf, db: PgPool) -> Result<Self> {
64        Self::with_cache(storage_path, db, Arc::new(NoOpCache))
65    }
66
67    /// Create a new Engine with an explicit workspace cache implementation.
68    ///
69    /// This is the primary constructor. [`Engine::new`] delegates here with
70    /// [`NoOpCache`]. Pass a `ValkeyCache` (or any [`WorkspaceCache`] impl)
71    /// for multi-pod deployments.
72    pub fn with_cache(
73        storage_path: PathBuf,
74        db: PgPool,
75        cache: Arc<dyn WorkspaceCache>,
76    ) -> Result<Self> {
77        let search_index = SearchIndex::open(&storage_path.join("search_index"))?;
78        let parser = ParserRegistry::new();
79        let symbol_store = SymbolStore::new(db.clone());
80        let call_graph_store = CallGraphStore::new(db.clone());
81        let dep_store = DependencyStore::new(db.clone());
82        let type_info_store = TypeInfoStore::new(db.clone());
83        let changeset_store = ChangesetStore::new(db.clone());
84        let pipeline_store = PipelineStore::new(db.clone());
85        let workspace_manager = WorkspaceManager::with_cache(db.clone(), cache);
86
87        Ok(Self {
88            db,
89            search_index: Arc::new(RwLock::new(search_index)),
90            parser: Arc::new(parser),
91            storage_path,
92            symbol_store,
93            call_graph_store,
94            dep_store,
95            type_info_store,
96            changeset_store,
97            pipeline_store,
98            workspace_manager,
99            repo_locks: DashMap::new(),
100        })
101    }
102
103    /// Returns a reference to the symbol store for direct DB queries.
104    pub fn symbol_store(&self) -> &SymbolStore {
105        &self.symbol_store
106    }
107
108    /// Returns a reference to the changeset store.
109    pub fn changeset_store(&self) -> &ChangesetStore {
110        &self.changeset_store
111    }
112
113    /// Returns a reference to the pipeline store.
114    pub fn pipeline_store(&self) -> &PipelineStore {
115        &self.pipeline_store
116    }
117
118    /// Returns a reference to the workspace manager.
119    pub fn workspace_manager(&self) -> &WorkspaceManager {
120        &self.workspace_manager
121    }
122
123    /// Returns a reference to the call graph store for direct DB queries.
124    pub fn call_graph_store(&self) -> &CallGraphStore {
125        &self.call_graph_store
126    }
127
128    /// Returns a reference to the dependency store for direct DB queries.
129    pub fn dep_store(&self) -> &DependencyStore {
130        &self.dep_store
131    }
132
133    /// Returns a reference to the parser registry.
134    pub fn parser(&self) -> &ParserRegistry {
135        &self.parser
136    }
137
138    /// Returns a per-repo lock for serialising Git operations.
139    ///
140    /// Creates a new lock on first access for a given `repo_id`.
141    pub fn repo_lock(&self, repo_id: RepoId) -> Arc<RwLock<()>> {
142        self.repo_locks
143            .entry(repo_id)
144            .or_insert_with(|| Arc::new(RwLock::new(())))
145            .clone()
146    }
147
148    /// Remove the repo lock entry for a deleted repo.
149    pub fn remove_repo_lock(&self, repo_id: RepoId) {
150        self.repo_locks.remove(&repo_id);
151    }
152
153    // ── Repository lifecycle ──
154
155    /// Create a new repository.
156    ///
157    /// Generates a UUID, initialises a Git repository at
158    /// `storage_path/repos/<uuid>`, inserts a row into the `repositories`
159    /// table, and returns the new `RepoId`.
160    pub async fn create_repo(&self, name: &str) -> Result<RepoId> {
161        let repo_id = Uuid::new_v4();
162        let repo_path = self.storage_path.join("repos").join(repo_id.to_string());
163
164        GitRepository::init(&repo_path)?;
165
166        sqlx::query(
167            r#"
168            INSERT INTO repositories (id, name, path)
169            VALUES ($1, $2, $3)
170            "#,
171        )
172        .bind(repo_id)
173        .bind(name)
174        .bind(repo_path.to_string_lossy().as_ref())
175        .execute(&self.db)
176        .await?;
177
178        Ok(repo_id)
179    }
180
181    /// Look up a repository by name.
182    ///
183    /// Returns the `RepoId` and an opened `GitRepository` handle.
184    pub async fn get_repo(&self, name: &str) -> Result<(RepoId, GitRepository)> {
185        let row: (Uuid, String) = sqlx::query_as(
186            "SELECT id, path FROM repositories WHERE name = $1",
187        )
188        .bind(name)
189        .fetch_optional(&self.db)
190        .await?
191        .ok_or_else(|| Error::RepoNotFound(name.to_string()))?;
192
193        let (repo_id, repo_path) = row;
194        let git_repo = GitRepository::open(Path::new(&repo_path))?;
195        Ok((repo_id, git_repo))
196    }
197
198    /// Look up a repository by its UUID.
199    ///
200    /// Returns the `RepoId` and an opened `GitRepository` handle.
201    pub async fn get_repo_by_db_id(&self, repo_id: RepoId) -> Result<(RepoId, GitRepository)> {
202        let row: (String,) = sqlx::query_as(
203            "SELECT path FROM repositories WHERE id = $1",
204        )
205        .bind(repo_id)
206        .fetch_optional(&self.db)
207        .await?
208        .ok_or_else(|| Error::RepoNotFound(repo_id.to_string()))?;
209
210        let git_repo = GitRepository::open(Path::new(&row.0))?;
211        Ok((repo_id, git_repo))
212    }
213
214    // ── Indexing ──
215
216    /// Perform a full index of a repository.
217    ///
218    /// Walks the working directory (skipping `.git`), parses every file with a
219    /// supported extension, and populates the symbol table, type info store,
220    /// call graph, and full-text search index.
221    pub async fn index_repo(
222        &self,
223        repo_id: RepoId,
224        git_repo: &GitRepository,
225    ) -> Result<()> {
226        let root = git_repo.path().to_path_buf();
227        let files = collect_files(&root, &self.parser);
228
229        // Accumulate all symbols across every file so we can resolve call
230        // edges at the end.
231        let mut all_symbols: Vec<Symbol> = Vec::new();
232        let mut all_raw_edges: Vec<RawCallEdge> = Vec::new();
233
234        // Acquire the search index write lock for the duration of indexing.
235        let mut search_index = self.search_index.write().await;
236
237        for file_path in &files {
238            let relative = file_path
239                .strip_prefix(&root)
240                .unwrap_or(file_path);
241
242            let source = std::fs::read(file_path).map_err(|e| {
243                Error::Io(e)
244            })?;
245
246            let analysis = self.parser.parse_file(relative, &source)?;
247
248            // Symbols
249            for sym in &analysis.symbols {
250                self.symbol_store.upsert_symbol(repo_id, sym).await?;
251                search_index.index_symbol(repo_id, sym)?;
252            }
253
254            // Type info
255            for ti in &analysis.types {
256                self.type_info_store.upsert_type_info(ti).await?;
257            }
258
259            all_symbols.extend(analysis.symbols);
260            all_raw_edges.extend(analysis.calls);
261        }
262
263        // Commit search index once after all files.
264        search_index.commit()?;
265        drop(search_index);
266
267        // Resolve and insert call edges.
268        let edges = resolve_call_edges(&all_raw_edges, &all_symbols, repo_id);
269        for edge in &edges {
270            self.call_graph_store.insert_edge(edge).await?;
271        }
272
273        Ok(())
274    }
275
276    /// Incrementally re-index a set of changed files.
277    ///
278    /// For each path: deletes old symbols and call edges, re-parses, and
279    /// upserts the new data.
280    pub async fn update_files(
281        &self,
282        repo_id: RepoId,
283        git_repo: &GitRepository,
284        changed_files: &[PathBuf],
285    ) -> Result<()> {
286        self.update_files_by_root(repo_id, git_repo.path(), changed_files)
287            .await
288    }
289
290    /// Incrementally re-index a set of changed files, given the repository
291    /// root path directly.
292    ///
293    /// This variant avoids holding a `GitRepository` reference (which is
294    /// `!Sync`) across `.await` points, making the resulting future `Send`.
295    pub async fn update_files_by_root(
296        &self,
297        repo_id: RepoId,
298        root: &Path,
299        changed_files: &[PathBuf],
300    ) -> Result<()> {
301        let root = root.to_path_buf();
302
303        let mut all_symbols: Vec<Symbol> = Vec::new();
304        let mut all_raw_edges: Vec<RawCallEdge> = Vec::new();
305
306        // Acquire the search index write lock for the duration of re-indexing.
307        let mut search_index = self.search_index.write().await;
308
309        for file_path in changed_files {
310            let relative = file_path
311                .strip_prefix(&root)
312                .unwrap_or(file_path);
313            let rel_str = relative.to_string_lossy().to_string();
314
315            // Fetch existing symbols for this file so we can remove their
316            // search index entries.
317            let old_symbols = self
318                .symbol_store
319                .find_by_file(repo_id, &rel_str)
320                .await?;
321            for old_sym in &old_symbols {
322                search_index.remove_symbol(old_sym.id)?;
323            }
324
325            // Delete old DB rows.
326            self.call_graph_store
327                .delete_edges_for_file(repo_id, &rel_str)
328                .await?;
329            self.symbol_store
330                .delete_by_file(repo_id, &rel_str)
331                .await?;
332
333            // Re-parse.
334            let full_path = root.join(relative);
335            if !full_path.exists() {
336                // File was deleted; nothing more to do for this path.
337                continue;
338            }
339
340            if !self.parser.supports_file(relative) {
341                continue;
342            }
343
344            let source = std::fs::read(&full_path)?;
345            let analysis = self.parser.parse_file(relative, &source)?;
346
347            for sym in &analysis.symbols {
348                self.symbol_store.upsert_symbol(repo_id, sym).await?;
349                search_index.index_symbol(repo_id, sym)?;
350            }
351
352            for ti in &analysis.types {
353                self.type_info_store.upsert_type_info(ti).await?;
354            }
355
356            all_symbols.extend(analysis.symbols);
357            all_raw_edges.extend(analysis.calls);
358        }
359
360        search_index.commit()?;
361        drop(search_index);
362
363        let edges = resolve_call_edges(&all_raw_edges, &all_symbols, repo_id);
364        for edge in &edges {
365            self.call_graph_store.insert_edge(edge).await?;
366        }
367
368        Ok(())
369    }
370
371    // ── Querying ──
372
373    /// Search for symbols matching a free-text query.
374    ///
375    /// Uses Tantivy full-text search to find candidate `SymbolId`s, then
376    /// fetches the full `Symbol` objects from the database.
377    pub async fn query_symbols(
378        &self,
379        repo_id: RepoId,
380        query: &str,
381        max_results: usize,
382    ) -> Result<Vec<Symbol>> {
383        let search_index = self.search_index.read().await;
384        let ids = search_index.search(repo_id, query, max_results)?;
385        drop(search_index);
386
387        self.symbol_store.get_by_ids(&ids).await
388    }
389
390    /// Retrieve the call graph neighbourhood of a symbol.
391    ///
392    /// Returns `(callers, callees)` — the full `Symbol` objects for every
393    /// direct caller and every direct callee.
394    pub async fn get_call_graph(
395        &self,
396        _repo_id: RepoId,
397        symbol_id: SymbolId,
398    ) -> Result<(Vec<Symbol>, Vec<Symbol>)> {
399        let caller_edges = self.call_graph_store.find_callers(symbol_id).await?;
400        let callee_edges = self.call_graph_store.find_callees(symbol_id).await?;
401
402        let mut callers = Vec::with_capacity(caller_edges.len());
403        for edge in &caller_edges {
404            if let Some(sym) = self.symbol_store.get_by_id(edge.caller).await? {
405                callers.push(sym);
406            }
407        }
408
409        let mut callees = Vec::with_capacity(callee_edges.len());
410        for edge in &callee_edges {
411            if let Some(sym) = self.symbol_store.get_by_id(edge.callee).await? {
412                callees.push(sym);
413            }
414        }
415
416        Ok((callers, callees))
417    }
418
419    /// Produce a high-level summary of the indexed codebase.
420    ///
421    /// Queries the symbols table for distinct file extensions (→ languages),
422    /// distinct file paths (→ total_files), and total row count
423    /// (→ total_symbols).
424    pub async fn codebase_summary(&self, repo_id: RepoId) -> Result<CodebaseSummary> {
425        let total_symbols = self.symbol_store.count(repo_id).await? as u64;
426
427        // Count distinct files and collect unique extensions in a single query.
428        let row: (i64, Vec<String>) = sqlx::query_as(
429            r#"
430            SELECT
431                COUNT(DISTINCT file_path),
432                COALESCE(
433                    array_agg(DISTINCT substring(file_path FROM '\.([^.]+)$'))
434                        FILTER (WHERE substring(file_path FROM '\.([^.]+)$') IS NOT NULL),
435                    ARRAY[]::text[]
436                )
437            FROM symbols
438            WHERE repo_id = $1
439            "#,
440        )
441        .bind(repo_id)
442        .fetch_one(&self.db)
443        .await?;
444
445        let total_files = row.0 as u64;
446        let mut languages = row.1;
447        languages.sort();
448
449        Ok(CodebaseSummary {
450            languages,
451            total_symbols,
452            total_files,
453        })
454    }
455}
456
457// ── Helpers ──
458
459/// Recursively collect all files under `root` that are supported by the
460/// parser registry, skipping the `.git` directory.
461fn collect_files(root: &Path, parser: &ParserRegistry) -> Vec<PathBuf> {
462    let mut files = Vec::new();
463    collect_files_recursive(root, root, parser, &mut files);
464    files
465}
466
467fn collect_files_recursive(
468    root: &Path,
469    dir: &Path,
470    parser: &ParserRegistry,
471    out: &mut Vec<PathBuf>,
472) {
473    let entries = match std::fs::read_dir(dir) {
474        Ok(entries) => entries,
475        Err(_) => return,
476    };
477
478    for entry in entries.flatten() {
479        let path = entry.path();
480
481        if path.is_dir() {
482            // Skip .git and hidden directories.
483            if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
484                if name == ".git" || name.starts_with('.') {
485                    continue;
486                }
487            }
488            collect_files_recursive(root, &path, parser, out);
489        } else if path.is_file() {
490            let relative = path.strip_prefix(root).unwrap_or(&path);
491            if parser.supports_file(relative) {
492                out.push(path);
493            }
494        }
495    }
496}
497
498/// Resolve `RawCallEdge`s (which use string names) into `CallEdge`s
499/// (which use `SymbolId`s) by building a name-to-id lookup table.
500fn resolve_call_edges(
501    raw_edges: &[RawCallEdge],
502    symbols: &[Symbol],
503    repo_id: RepoId,
504) -> Vec<CallEdge> {
505    // Build name -> SymbolId lookup.
506    // Insert both `name` and `qualified_name` so either form resolves.
507    let mut name_to_id: HashMap<String, SymbolId> = HashMap::new();
508    for sym in symbols {
509        name_to_id.insert(sym.name.clone(), sym.id);
510        name_to_id.insert(sym.qualified_name.clone(), sym.id);
511    }
512
513    raw_edges
514        .iter()
515        .filter_map(|raw| {
516            let caller = name_to_id.get(&raw.caller_name)?;
517            let callee = name_to_id.get(&raw.callee_name)?;
518            Some(CallEdge {
519                id: Uuid::new_v4(),
520                repo_id,
521                caller: *caller,
522                callee: *callee,
523                kind: raw.kind.clone(),
524            })
525        })
526        .collect()
527}
528
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a public function symbol; only the identifying fields vary
    /// between the tests below.
    fn function_symbol(
        id: Uuid,
        name: &str,
        qualified_name: &str,
        file_path: &str,
        span: dk_core::Span,
    ) -> Symbol {
        Symbol {
            id,
            name: name.into(),
            qualified_name: qualified_name.into(),
            kind: dk_core::SymbolKind::Function,
            visibility: dk_core::Visibility::Public,
            file_path: file_path.into(),
            span,
            signature: None,
            doc_comment: None,
            parent: None,
            last_modified_by: None,
            last_modified_intent: None,
        }
    }

    /// Build a direct call from `caller` to `callee` at a fixed call site.
    fn direct_call(caller: &str, callee: &str) -> RawCallEdge {
        RawCallEdge {
            caller_name: caller.into(),
            callee_name: callee.into(),
            call_site: dk_core::Span {
                start_byte: 50,
                end_byte: 60,
            },
            kind: dk_core::CallKind::DirectCall,
        }
    }

    #[test]
    fn test_resolve_call_edges_basic() {
        let sym_a_id = Uuid::new_v4();
        let sym_b_id = Uuid::new_v4();
        let repo_id = Uuid::new_v4();

        let symbols = vec![
            function_symbol(
                sym_a_id,
                "foo",
                "crate::foo",
                "src/lib.rs",
                dk_core::Span {
                    start_byte: 0,
                    end_byte: 100,
                },
            ),
            function_symbol(
                sym_b_id,
                "bar",
                "crate::bar",
                "src/lib.rs",
                dk_core::Span {
                    start_byte: 100,
                    end_byte: 200,
                },
            ),
        ];
        let raw_edges = vec![direct_call("foo", "bar")];

        let edges = resolve_call_edges(&raw_edges, &symbols, repo_id);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].caller, sym_a_id);
        assert_eq!(edges[0].callee, sym_b_id);
        assert_eq!(edges[0].repo_id, repo_id);
    }

    #[test]
    fn test_resolve_call_edges_unresolved_skipped() {
        let sym_a_id = Uuid::new_v4();
        let repo_id = Uuid::new_v4();

        let symbols = vec![function_symbol(
            sym_a_id,
            "foo",
            "crate::foo",
            "src/lib.rs",
            dk_core::Span {
                start_byte: 0,
                end_byte: 100,
            },
        )];
        // The callee "unknown" does not exist in `symbols`.
        let raw_edges = vec![direct_call("foo", "unknown")];

        let edges = resolve_call_edges(&raw_edges, &symbols, repo_id);
        assert!(edges.is_empty());
    }

    #[test]
    fn test_resolve_call_edges_qualified_name() {
        let sym_a_id = Uuid::new_v4();
        let sym_b_id = Uuid::new_v4();
        let repo_id = Uuid::new_v4();

        let symbols = vec![
            function_symbol(
                sym_a_id,
                "foo",
                "crate::mod_a::foo",
                "src/mod_a.rs",
                dk_core::Span {
                    start_byte: 0,
                    end_byte: 100,
                },
            ),
            function_symbol(
                sym_b_id,
                "bar",
                "crate::mod_b::bar",
                "src/mod_b.rs",
                dk_core::Span {
                    start_byte: 0,
                    end_byte: 100,
                },
            ),
        ];
        // Resolution via qualified names only.
        let raw_edges = vec![direct_call("crate::mod_a::foo", "crate::mod_b::bar")];

        let edges = resolve_call_edges(&raw_edges, &symbols, repo_id);
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].caller, sym_a_id);
        assert_eq!(edges[0].callee, sym_b_id);
    }

    #[test]
    fn test_resolve_call_edges_empty_inputs() {
        let edges = resolve_call_edges(&[], &[], Uuid::new_v4());
        assert!(edges.is_empty());
    }

    #[test]
    fn test_collect_files_skips_git_dir() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path();

        // A .git directory with a file inside.
        std::fs::create_dir_all(root.join(".git")).unwrap();
        std::fs::write(root.join(".git/config"), b"git config").unwrap();

        // A supported source file.
        std::fs::write(root.join("main.rs"), b"fn main() {}").unwrap();

        // An unsupported file.
        std::fs::write(root.join("notes.txt"), b"hello").unwrap();

        let parser = ParserRegistry::new();
        let files = collect_files(root, &parser);

        assert_eq!(files.len(), 1);
        assert!(files[0].ends_with("main.rs"));
    }

    #[test]
    fn test_collect_files_skips_hidden_dirs_but_descends_into_others() {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path();

        // A supported file nested inside a dot-directory must be ignored.
        std::fs::create_dir_all(root.join(".cache/nested")).unwrap();
        std::fs::write(root.join(".cache/nested/hidden.rs"), b"fn hidden() {}").unwrap();

        // A supported file inside a regular subdirectory must be found.
        std::fs::create_dir_all(root.join("src")).unwrap();
        std::fs::write(root.join("src/lib.rs"), b"fn visible() {}").unwrap();

        let parser = ParserRegistry::new();
        let files = collect_files(root, &parser);

        assert_eq!(files.len(), 1);
        assert!(files[0].ends_with("src/lib.rs"));
    }

    #[test]
    fn test_codebase_summary_struct() {
        let summary = CodebaseSummary {
            languages: vec!["rs".into(), "py".into()],
            total_symbols: 42,
            total_files: 5,
        };
        assert_eq!(summary.languages.len(), 2);
        assert_eq!(summary.total_symbols, 42);
        assert_eq!(summary.total_files, 5);
    }
}
720        assert_eq!(summary.languages.len(), 2);
721        assert_eq!(summary.total_symbols, 42);
722        assert_eq!(summary.total_files, 5);
723    }
724}