Skip to main content

dk_engine/workspace/
session_graph.rs

1//! SessionGraph — delta-based semantic graph layered on a shared base.
2//!
3//! The shared base symbol table is stored in an `ArcSwap` so it can be
4//! atomically replaced when the repository is re-indexed. Each session
5//! maintains its own deltas (added, modified, removed symbols and edges)
//! in concurrent `DashMap`/`DashSet` collections (internally sharded and
//! lock-guarded, so no session-wide mutex is needed for mutation).
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use dashmap::{DashMap, DashSet};
13use dk_core::{CallEdge, Symbol, SymbolId};
14use serde::{Deserialize, Serialize};
15use uuid::Uuid;
16
17// ── SessionGraph ─────────────────────────────────────────────────────
18
/// A delta-based semantic graph for a single session workspace.
///
/// Symbol lookups resolve in order: removed -> modified -> added -> base.
/// Only the session's own changes are stored here; the base table is held
/// behind a shared `Arc`, so each session gets an isolated view of the
/// symbol graph without copying the entire base.
pub struct SessionGraph {
    /// Shared base symbol table (repo-wide), replaceable atomically through
    /// the `ArcSwap`. `None` for graphs created by `empty()` or
    /// `from_msgpack()`, in which case lookups never fall through to a base.
    base_symbols: Option<Arc<ArcSwap<HashMap<SymbolId, Symbol>>>>,

    /// Session versions of symbols that override an entry of the same ID
    /// (normally a base symbol that was modified in this session).
    pub(crate) modified_symbols: DashMap<SymbolId, Symbol>,

    /// Symbols that are newly created in this session.
    added_symbols: DashMap<SymbolId, Symbol>,

    /// IDs hidden from this session's view. `get_symbol` checks this set
    /// before anything else, so a listed ID always resolves to `None`.
    pub(crate) removed_symbols: DashSet<SymbolId>,

    /// Cached qualified names of removed symbols. Entries are recorded by
    /// `remove_symbol` (from the modified or base version) and restored by
    /// `from_msgpack`, so `changed_symbol_names()` can still report removals
    /// when no base table is attached.
    removed_symbol_names: DashMap<SymbolId, String>,

    /// Call edges added in this session, keyed by edge ID.
    pub(crate) added_edges: DashMap<Uuid, CallEdge>,

    /// IDs of call edges marked as removed in this session.
    pub(crate) removed_edges: DashSet<Uuid>,
}
47
48// ── Snapshot (serde bridge) ───────────────────────────────────────────
49
/// Serializable snapshot of the session delta.
///
/// `DashMap`/`DashSet` are not directly serializable, so each collection
/// is flattened into a `Vec` of entries. The shared `base_symbols` table is
/// intentionally excluded — it is repo-wide state that is not owned by the
/// session.
#[derive(Serialize, Deserialize)]
struct SessionGraphSnapshot {
    // Entries of `SessionGraph::modified_symbols` as (id, symbol) pairs.
    modified_symbols: Vec<(SymbolId, Symbol)>,
    // Entries of `SessionGraph::added_symbols` as (id, symbol) pairs.
    added_symbols: Vec<(SymbolId, Symbol)>,
    // IDs from `SessionGraph::removed_symbols`.
    removed_symbols: Vec<SymbolId>,
    /// Names of removed symbols, so `changed_symbol_names()` works on
    /// deserialized graphs without requiring the base symbol table.
    /// `#[serde(default)]` makes this an empty list when the field is
    /// absent from the encoded data.
    #[serde(default)]
    removed_symbol_names: Vec<(SymbolId, String)>,
    // Entries of `SessionGraph::added_edges` as (id, edge) pairs.
    added_edges: Vec<(Uuid, CallEdge)>,
    // IDs from `SessionGraph::removed_edges`.
    removed_edges: Vec<Uuid>,
}
67
68impl SessionGraph {
69    /// Fork from a shared base symbol table.
70    pub fn fork_from(base: Arc<ArcSwap<HashMap<SymbolId, Symbol>>>) -> Self {
71        Self {
72            base_symbols: Some(base),
73            modified_symbols: DashMap::new(),
74            added_symbols: DashMap::new(),
75            removed_symbols: DashSet::new(),
76            removed_symbol_names: DashMap::new(),
77            added_edges: DashMap::new(),
78            removed_edges: DashSet::new(),
79        }
80    }
81
82    /// Create an empty session graph (no shared base).
83    pub fn empty() -> Self {
84        Self {
85            base_symbols: None,
86            modified_symbols: DashMap::new(),
87            added_symbols: DashMap::new(),
88            removed_symbols: DashSet::new(),
89            removed_symbol_names: DashMap::new(),
90            added_edges: DashMap::new(),
91            removed_edges: DashSet::new(),
92        }
93    }
94
95    /// Look up a symbol by ID, respecting the session delta.
96    ///
97    /// Resolution order:
98    /// 1. If removed in this session, return `None`.
99    /// 2. If modified in this session, return the modified version.
100    /// 3. If added in this session, return it.
101    /// 4. Fall through to the shared base.
102    pub fn get_symbol(&self, id: SymbolId) -> Option<Symbol> {
103        // Removed in this session?
104        if self.removed_symbols.contains(&id) {
105            return None;
106        }
107
108        // Modified in this session?
109        if let Some(sym) = self.modified_symbols.get(&id) {
110            return Some(sym.value().clone());
111        }
112
113        // Added in this session?
114        if let Some(sym) = self.added_symbols.get(&id) {
115            return Some(sym.value().clone());
116        }
117
118        // Base lookup
119        if let Some(base) = &self.base_symbols {
120            let snapshot = base.load();
121            return snapshot.get(&id).cloned();
122        }
123
124        None
125    }
126
127    /// Add a new symbol to this session.
128    pub fn add_symbol(&self, symbol: Symbol) {
129        self.added_symbols.insert(symbol.id, symbol);
130    }
131
132    /// Modify an existing symbol (base or previously added).
133    pub fn modify_symbol(&self, symbol: Symbol) {
134        let id = symbol.id;
135
136        // If it was added in this session, update the added entry.
137        if self.added_symbols.contains_key(&id) {
138            self.added_symbols.insert(id, symbol);
139        } else {
140            self.modified_symbols.insert(id, symbol);
141        }
142    }
143
144    /// Remove a symbol from the session view.
145    pub fn remove_symbol(&self, id: SymbolId) {
146        // If it was added in this session, just drop it.
147        if self.added_symbols.remove(&id).is_some() {
148            return;
149        }
150
151        // If it was modified, capture the name before dropping.
152        if let Some((_, sym)) = self.modified_symbols.remove(&id) {
153            self.removed_symbol_names
154                .insert(id, sym.qualified_name.clone());
155        } else if let Some(base) = &self.base_symbols {
156            // Look up name from base for the cache.
157            let snapshot = base.load();
158            if let Some(sym) = snapshot.get(&id) {
159                self.removed_symbol_names
160                    .insert(id, sym.qualified_name.clone());
161            }
162        }
163
164        // Mark as removed from base.
165        self.removed_symbols.insert(id);
166    }
167
168    /// Add a call edge.
169    pub fn add_edge(&self, edge: CallEdge) {
170        self.added_edges.insert(edge.id, edge);
171    }
172
173    /// Remove a call edge.
174    pub fn remove_edge(&self, edge_id: Uuid) {
175        // If it was added in this session, just drop it.
176        if self.added_edges.remove(&edge_id).is_some() {
177            return;
178        }
179        self.removed_edges.insert(edge_id);
180    }
181
182    /// Look up an added edge by ID.
183    ///
184    /// Returns `None` if the edge was not added in this session or has been
185    /// removed.
186    pub fn get_edge(&self, edge_id: Uuid) -> Option<CallEdge> {
187        if self.removed_edges.contains(&edge_id) {
188            return None;
189        }
190        self.added_edges.get(&edge_id).map(|e| e.value().clone())
191    }
192
193    /// Returns `true` if the given edge ID is marked as removed in this
194    /// session.
195    pub fn is_edge_removed(&self, edge_id: Uuid) -> bool {
196        self.removed_edges.contains(&edge_id)
197    }
198
199    /// Return the names of all symbols changed in this session
200    /// (added, modified, or removed).
201    ///
202    /// Used by the conflict detector to find overlapping changes.
203    pub fn changed_symbol_names(&self) -> Vec<String> {
204        let mut names = Vec::new();
205
206        for entry in self.added_symbols.iter() {
207            names.push(entry.value().qualified_name.clone());
208        }
209
210        for entry in self.modified_symbols.iter() {
211            names.push(entry.value().qualified_name.clone());
212        }
213
214        // For removed symbols, try the base first, then the cached names
215        // (which are populated during remove_symbol and deserialization).
216        for id in self.removed_symbols.iter() {
217            let found = self
218                .base_symbols
219                .as_ref()
220                .and_then(|base| base.load().get(id.key()).map(|s| s.qualified_name.clone()));
221            if let Some(name) = found {
222                names.push(name);
223            } else if let Some(name) = self.removed_symbol_names.get(id.key()) {
224                names.push(name.value().clone());
225            }
226        }
227
228        names
229    }
230
231    /// Update the session graph from a parse result for a single file.
232    ///
233    /// Compares the new symbols against the base symbols for that file,
234    /// and classifies each as added, modified, or removed within the
235    /// session delta.
236    ///
237    /// `base_symbols_for_file` should contain all symbols from the base
238    /// that belong to the given file path.
239    pub fn update_from_parse(
240        &self,
241        _file_path: &str,
242        new_symbols: Vec<Symbol>,
243        base_symbols_for_file: &[Symbol],
244    ) {
245        // Build a lookup of base symbols by qualified name for this file.
246        let base_by_name: HashMap<&str, &Symbol> = base_symbols_for_file
247            .iter()
248            .map(|s| (s.qualified_name.as_str(), s))
249            .collect();
250
251        let new_by_name: HashMap<&str, &Symbol> = new_symbols
252            .iter()
253            .map(|s| (s.qualified_name.as_str(), s))
254            .collect();
255
256        // Symbols in new but not in base -> added
257        // Symbols in both but changed -> modified
258        for sym in &new_symbols {
259            if let Some(base_sym) = base_by_name.get(sym.qualified_name.as_str()) {
260                // Compare span, signature, etc. to detect modification.
261                if sym.span != base_sym.span
262                    || sym.signature != base_sym.signature
263                    || sym.kind != base_sym.kind
264                    || sym.visibility != base_sym.visibility
265                {
266                    self.modify_symbol(sym.clone());
267                }
268            } else {
269                self.add_symbol(sym.clone());
270            }
271        }
272
273        // Symbols in base but not in new -> removed
274        for base_sym in base_symbols_for_file {
275            if !new_by_name.contains_key(base_sym.qualified_name.as_str()) {
276                self.remove_symbol(base_sym.id);
277            }
278        }
279    }
280
281    /// Return the names of symbols changed in this session that belong
282    /// to the given file path. Useful for cross-session file awareness.
283    pub fn changed_symbols_for_file(&self, file_path: &str) -> Vec<String> {
284        let target = std::path::Path::new(file_path);
285        let mut names = Vec::new();
286
287        for entry in self.added_symbols.iter() {
288            if entry.value().file_path == target {
289                names.push(entry.value().name.clone());
290            }
291        }
292
293        for entry in self.modified_symbols.iter() {
294            if entry.value().file_path == target {
295                names.push(entry.value().name.clone());
296            }
297        }
298
299        names
300    }
301
302    /// Number of symbols changed (added + modified + removed).
303    pub fn change_count(&self) -> usize {
304        self.added_symbols.len() + self.modified_symbols.len() + self.removed_symbols.len()
305    }
306
307    // ── Serialization ─────────────────────────────────────────────────
308
309    /// Serialize the session delta (modified/added/removed symbols and edges)
310    /// to MessagePack bytes.
311    ///
312    /// The shared `base_symbols` table is NOT included — it is repo-wide
313    /// state managed independently of individual sessions.
314    pub fn to_msgpack(&self) -> anyhow::Result<Vec<u8>> {
315        let snapshot = SessionGraphSnapshot {
316            modified_symbols: self
317                .modified_symbols
318                .iter()
319                .map(|e| (*e.key(), e.value().clone()))
320                .collect(),
321            added_symbols: self
322                .added_symbols
323                .iter()
324                .map(|e| (*e.key(), e.value().clone()))
325                .collect(),
326            removed_symbols: self.removed_symbols.iter().map(|r| *r).collect(),
327            removed_symbol_names: self
328                .removed_symbol_names
329                .iter()
330                .map(|e| (*e.key(), e.value().clone()))
331                .collect(),
332            added_edges: self
333                .added_edges
334                .iter()
335                .map(|e| (*e.key(), e.value().clone()))
336                .collect(),
337            removed_edges: self.removed_edges.iter().map(|r| *r).collect(),
338        };
339
340        Ok(rmp_serde::to_vec_named(&snapshot)?)
341    }
342
343    /// Deserialize a session delta from MessagePack bytes produced by
344    /// [`Self::to_msgpack`].
345    ///
346    /// The returned graph has no shared base (`base_symbols` is `None`).
347    /// Callers that need base-symbol lookups must call
348    /// [`Self::fork_from`] and replay the delta on top.
349    pub fn from_msgpack(bytes: &[u8]) -> anyhow::Result<Self> {
350        let snapshot: SessionGraphSnapshot = rmp_serde::from_slice(bytes)?;
351
352        let modified_symbols = DashMap::new();
353        for (id, sym) in snapshot.modified_symbols {
354            modified_symbols.insert(id, sym);
355        }
356
357        let added_symbols = DashMap::new();
358        for (id, sym) in snapshot.added_symbols {
359            added_symbols.insert(id, sym);
360        }
361
362        let removed_symbols = DashSet::new();
363        for id in snapshot.removed_symbols {
364            removed_symbols.insert(id);
365        }
366
367        let removed_symbol_names = DashMap::new();
368        for (id, name) in snapshot.removed_symbol_names {
369            removed_symbol_names.insert(id, name);
370        }
371
372        let added_edges = DashMap::new();
373        for (id, edge) in snapshot.added_edges {
374            added_edges.insert(id, edge);
375        }
376
377        let removed_edges = DashSet::new();
378        for id in snapshot.removed_edges {
379            removed_edges.insert(id);
380        }
381
382        Ok(Self {
383            base_symbols: None,
384            modified_symbols,
385            added_symbols,
386            removed_symbols,
387            removed_symbol_names,
388            added_edges,
389            removed_edges,
390        })
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use dk_core::{Span, SymbolKind, Visibility};
    use std::path::PathBuf;

    /// Build a public function symbol in `test.rs` with a fresh random ID;
    /// `name` is used for both the short and the qualified name.
    fn make_symbol(name: &str) -> Symbol {
        let label = name.to_string();
        Symbol {
            id: Uuid::new_v4(),
            name: label.clone(),
            qualified_name: label,
            kind: SymbolKind::Function,
            visibility: Visibility::Public,
            file_path: PathBuf::from("test.rs"),
            span: Span {
                start_byte: 0,
                end_byte: 10,
            },
            signature: None,
            doc_comment: None,
            parent: None,
            last_modified_by: None,
            last_modified_intent: None,
        }
    }

    /// Same as `make_symbol`, but located in the given file.
    fn make_symbol_in(name: &str, path: &str) -> Symbol {
        let mut sym = make_symbol(name);
        sym.file_path = PathBuf::from(path);
        sym
    }

    /// Fork a graph from a base holding exactly one symbol; returns the
    /// graph together with that symbol's ID.
    fn fork_with_base_symbol(name: &str) -> (SessionGraph, SymbolId) {
        let sym = make_symbol(name);
        let id = sym.id;

        let mut table = HashMap::new();
        table.insert(id, sym);

        let graph = SessionGraph::fork_from(Arc::new(ArcSwap::from_pointee(table)));
        (graph, id)
    }

    /// True when `names` contains `wanted`.
    fn has_name(names: &[String], wanted: &str) -> bool {
        names.iter().any(|n| n == wanted)
    }

    #[test]
    fn empty_graph_returns_none() {
        let graph = SessionGraph::empty();
        assert!(graph.get_symbol(Uuid::new_v4()).is_none());
    }

    #[test]
    fn add_and_get_symbol() {
        let graph = SessionGraph::empty();
        let sym = make_symbol("foo");
        let id = sym.id;
        graph.add_symbol(sym);

        let found = graph.get_symbol(id);
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "foo");
    }

    #[test]
    fn remove_added_symbol() {
        let graph = SessionGraph::empty();
        let sym = make_symbol("bar");
        let id = sym.id;

        graph.add_symbol(sym);
        graph.remove_symbol(id);

        assert!(graph.get_symbol(id).is_none());
    }

    #[test]
    fn modify_added_symbol_updates_in_place() {
        let graph = SessionGraph::empty();
        let original = make_symbol("baz");
        let id = original.id;
        graph.add_symbol(original.clone());

        let mut renamed = original;
        renamed.name = "baz_v2".to_string();
        graph.modify_symbol(renamed);

        assert_eq!(graph.get_symbol(id).unwrap().name, "baz_v2");
    }

    #[test]
    fn fork_from_base_lookup() {
        let (graph, id) = fork_with_base_symbol("base_fn");

        let found = graph.get_symbol(id);
        assert!(found.is_some());
        assert_eq!(found.unwrap().name, "base_fn");
    }

    #[test]
    fn remove_base_symbol_hides_it() {
        let (graph, id) = fork_with_base_symbol("base_fn");

        graph.remove_symbol(id);

        assert!(graph.get_symbol(id).is_none());
    }

    #[test]
    fn changed_symbol_names_collects_all() {
        let (graph, removed_id) = fork_with_base_symbol("removed_fn");

        graph.add_symbol(make_symbol("added_fn"));

        // Pretend this one lives in the base by writing straight into the
        // modified map.
        let modified = make_symbol("modified_fn");
        graph.modified_symbols.insert(modified.id, modified);

        graph.remove_symbol(removed_id);

        let names = graph.changed_symbol_names();
        assert!(has_name(&names, "added_fn"));
        assert!(has_name(&names, "modified_fn"));
        assert!(has_name(&names, "removed_fn"));
    }

    #[test]
    fn change_count() {
        let graph = SessionGraph::empty();
        assert_eq!(graph.change_count(), 0);

        graph.add_symbol(make_symbol("a"));
        assert_eq!(graph.change_count(), 1);
    }

    #[test]
    fn changed_symbols_for_file_filters_by_path() {
        let graph = SessionGraph::empty();
        graph.add_symbol(make_symbol_in("create_task", "src/tasks.rs"));
        graph.add_symbol(make_symbol_in("delete_task", "src/tasks.rs"));
        graph.add_symbol(make_symbol_in("run_server", "src/main.rs"));

        let task_syms = graph.changed_symbols_for_file("src/tasks.rs");
        assert_eq!(task_syms.len(), 2);
        assert!(has_name(&task_syms, "create_task"));
        assert!(has_name(&task_syms, "delete_task"));

        let main_syms = graph.changed_symbols_for_file("src/main.rs");
        assert_eq!(main_syms.len(), 1);
        assert!(has_name(&main_syms, "run_server"));

        let none = graph.changed_symbols_for_file("src/nonexistent.rs");
        assert!(none.is_empty());
    }
}