dk_engine/workspace/
session_graph.rs1use std::collections::HashMap;
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use dashmap::{DashMap, DashSet};
13use dk_core::{CallEdge, Symbol, SymbolId};
14use uuid::Uuid;
15
16pub struct SessionGraph {
24 base_symbols: Option<Arc<ArcSwap<HashMap<SymbolId, Symbol>>>>,
26
27 modified_symbols: DashMap<SymbolId, Symbol>,
29
30 added_symbols: DashMap<SymbolId, Symbol>,
32
33 removed_symbols: DashSet<SymbolId>,
35
36 added_edges: DashMap<Uuid, CallEdge>,
38
39 removed_edges: DashSet<Uuid>,
41}
42
43impl SessionGraph {
44 pub fn fork_from(base: Arc<ArcSwap<HashMap<SymbolId, Symbol>>>) -> Self {
46 Self {
47 base_symbols: Some(base),
48 modified_symbols: DashMap::new(),
49 added_symbols: DashMap::new(),
50 removed_symbols: DashSet::new(),
51 added_edges: DashMap::new(),
52 removed_edges: DashSet::new(),
53 }
54 }
55
56 pub fn empty() -> Self {
58 Self {
59 base_symbols: None,
60 modified_symbols: DashMap::new(),
61 added_symbols: DashMap::new(),
62 removed_symbols: DashSet::new(),
63 added_edges: DashMap::new(),
64 removed_edges: DashSet::new(),
65 }
66 }
67
68 pub fn get_symbol(&self, id: SymbolId) -> Option<Symbol> {
76 if self.removed_symbols.contains(&id) {
78 return None;
79 }
80
81 if let Some(sym) = self.modified_symbols.get(&id) {
83 return Some(sym.value().clone());
84 }
85
86 if let Some(sym) = self.added_symbols.get(&id) {
88 return Some(sym.value().clone());
89 }
90
91 if let Some(base) = &self.base_symbols {
93 let snapshot = base.load();
94 return snapshot.get(&id).cloned();
95 }
96
97 None
98 }
99
100 pub fn add_symbol(&self, symbol: Symbol) {
102 self.added_symbols.insert(symbol.id, symbol);
103 }
104
105 pub fn modify_symbol(&self, symbol: Symbol) {
107 let id = symbol.id;
108
109 if self.added_symbols.contains_key(&id) {
111 self.added_symbols.insert(id, symbol);
112 } else {
113 self.modified_symbols.insert(id, symbol);
114 }
115 }
116
117 pub fn remove_symbol(&self, id: SymbolId) {
119 if self.added_symbols.remove(&id).is_some() {
121 return;
122 }
123
124 self.modified_symbols.remove(&id);
126
127 self.removed_symbols.insert(id);
129 }
130
131 pub fn add_edge(&self, edge: CallEdge) {
133 self.added_edges.insert(edge.id, edge);
134 }
135
136 pub fn remove_edge(&self, edge_id: Uuid) {
138 if self.added_edges.remove(&edge_id).is_some() {
140 return;
141 }
142 self.removed_edges.insert(edge_id);
143 }
144
145 pub fn changed_symbol_names(&self) -> Vec<String> {
150 let mut names = Vec::new();
151
152 for entry in self.added_symbols.iter() {
153 names.push(entry.value().qualified_name.clone());
154 }
155
156 for entry in self.modified_symbols.iter() {
157 names.push(entry.value().qualified_name.clone());
158 }
159
160 if let Some(base) = &self.base_symbols {
162 let snapshot = base.load();
163 for id in self.removed_symbols.iter() {
164 if let Some(sym) = snapshot.get(id.key()) {
165 names.push(sym.qualified_name.clone());
166 }
167 }
168 }
169
170 names
171 }
172
173 pub fn update_from_parse(
182 &self,
183 _file_path: &str,
184 new_symbols: Vec<Symbol>,
185 base_symbols_for_file: &[Symbol],
186 ) {
187 let base_by_name: HashMap<&str, &Symbol> = base_symbols_for_file
189 .iter()
190 .map(|s| (s.qualified_name.as_str(), s))
191 .collect();
192
193 let new_by_name: HashMap<&str, &Symbol> = new_symbols
194 .iter()
195 .map(|s| (s.qualified_name.as_str(), s))
196 .collect();
197
198 for sym in &new_symbols {
201 if let Some(base_sym) = base_by_name.get(sym.qualified_name.as_str()) {
202 if sym.span != base_sym.span
204 || sym.signature != base_sym.signature
205 || sym.kind != base_sym.kind
206 || sym.visibility != base_sym.visibility
207 {
208 self.modify_symbol(sym.clone());
209 }
210 } else {
211 self.add_symbol(sym.clone());
212 }
213 }
214
215 for base_sym in base_symbols_for_file {
217 if !new_by_name.contains_key(base_sym.qualified_name.as_str()) {
218 self.remove_symbol(base_sym.id);
219 }
220 }
221 }
222
223 pub fn changed_symbols_for_file(&self, file_path: &str) -> Vec<String> {
226 let target = std::path::Path::new(file_path);
227 let mut names = Vec::new();
228
229 for entry in self.added_symbols.iter() {
230 if entry.value().file_path == target {
231 names.push(entry.value().name.clone());
232 }
233 }
234
235 for entry in self.modified_symbols.iter() {
236 if entry.value().file_path == target {
237 names.push(entry.value().name.clone());
238 }
239 }
240
241 names
242 }
243
244 pub fn change_count(&self) -> usize {
246 self.added_symbols.len() + self.modified_symbols.len() + self.removed_symbols.len()
247 }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use dk_core::{Span, SymbolKind, Visibility};
    use std::path::PathBuf;

    /// Builds a public function symbol in `test.rs` with a fresh random id.
    fn make_symbol(name: &str) -> Symbol {
        Symbol {
            id: Uuid::new_v4(),
            name: name.to_string(),
            qualified_name: name.to_string(),
            kind: SymbolKind::Function,
            visibility: Visibility::Public,
            file_path: PathBuf::from("test.rs"),
            span: Span {
                start_byte: 0,
                end_byte: 10,
            },
            signature: None,
            doc_comment: None,
            parent: None,
            last_modified_by: None,
            last_modified_intent: None,
        }
    }

    /// Puts clones of `symbols` into a shared base snapshot and forks a session from it.
    fn fork_with(symbols: &[Symbol]) -> SessionGraph {
        let base: HashMap<SymbolId, Symbol> =
            symbols.iter().map(|s| (s.id, s.clone())).collect();
        SessionGraph::fork_from(Arc::new(ArcSwap::from_pointee(base)))
    }

    #[test]
    fn empty_graph_returns_none() {
        let g = SessionGraph::empty();
        assert!(g.get_symbol(Uuid::new_v4()).is_none());
    }

    #[test]
    fn add_and_get_symbol() {
        let g = SessionGraph::empty();
        let sym = make_symbol("foo");
        let id = sym.id;
        g.add_symbol(sym);
        assert!(g.get_symbol(id).is_some());
        assert_eq!(g.get_symbol(id).unwrap().name, "foo");
    }

    #[test]
    fn remove_added_symbol() {
        let g = SessionGraph::empty();
        let sym = make_symbol("bar");
        let id = sym.id;
        g.add_symbol(sym);
        g.remove_symbol(id);
        assert!(g.get_symbol(id).is_none());
        // Removing a session-only symbol leaves no trace in the overlay.
        assert_eq!(g.change_count(), 0);
    }

    #[test]
    fn modify_added_symbol_updates_in_place() {
        let g = SessionGraph::empty();
        let mut sym = make_symbol("baz");
        let id = sym.id;
        g.add_symbol(sym.clone());

        sym.name = "baz_v2".to_string();
        g.modify_symbol(sym);

        let got = g.get_symbol(id).unwrap();
        assert_eq!(got.name, "baz_v2");
        assert_eq!(g.change_count(), 1);
    }

    #[test]
    fn fork_from_base_lookup() {
        let sym = make_symbol("base_fn");
        let id = sym.id;
        let g = fork_with(&[sym]);

        assert!(g.get_symbol(id).is_some());
        assert_eq!(g.get_symbol(id).unwrap().name, "base_fn");
    }

    #[test]
    fn remove_base_symbol_hides_it() {
        let sym = make_symbol("base_fn");
        let id = sym.id;
        let g = fork_with(&[sym]);

        g.remove_symbol(id);
        assert!(g.get_symbol(id).is_none());
    }

    #[test]
    fn modify_base_symbol_shadows_base() {
        let sym = make_symbol("base_fn");
        let id = sym.id;
        let g = fork_with(&[sym.clone()]);

        let mut edited = sym;
        edited.name = "base_fn_v2".to_string();
        g.modify_symbol(edited);

        assert_eq!(g.get_symbol(id).unwrap().name, "base_fn_v2");
        assert_eq!(g.change_count(), 1);
    }

    #[test]
    fn changed_symbol_names_collects_all() {
        let removed = make_symbol("removed_fn");
        let removed_id = removed.id;
        let g = fork_with(&[removed]);

        g.add_symbol(make_symbol("added_fn"));
        g.modify_symbol(make_symbol("modified_fn"));
        g.remove_symbol(removed_id);

        let names = g.changed_symbol_names();
        assert!(names.contains(&"added_fn".to_string()));
        assert!(names.contains(&"modified_fn".to_string()));
        assert!(names.contains(&"removed_fn".to_string()));
    }

    #[test]
    fn change_count() {
        let g = SessionGraph::empty();
        assert_eq!(g.change_count(), 0);

        g.add_symbol(make_symbol("a"));
        assert_eq!(g.change_count(), 1);
    }

    #[test]
    fn changed_symbols_for_file_filters_by_path() {
        let g = SessionGraph::empty();

        let mut sym1 = make_symbol("create_task");
        sym1.file_path = PathBuf::from("src/tasks.rs");
        g.add_symbol(sym1);

        let mut sym2 = make_symbol("delete_task");
        sym2.file_path = PathBuf::from("src/tasks.rs");
        g.add_symbol(sym2);

        let mut sym3 = make_symbol("run_server");
        sym3.file_path = PathBuf::from("src/main.rs");
        g.add_symbol(sym3);

        let task_syms = g.changed_symbols_for_file("src/tasks.rs");
        assert_eq!(task_syms.len(), 2);
        assert!(task_syms.contains(&"create_task".to_string()));
        assert!(task_syms.contains(&"delete_task".to_string()));

        let main_syms = g.changed_symbols_for_file("src/main.rs");
        assert_eq!(main_syms.len(), 1);
        assert!(main_syms.contains(&"run_server".to_string()));

        let empty = g.changed_symbols_for_file("src/nonexistent.rs");
        assert!(empty.is_empty());
    }

    #[test]
    fn update_from_parse_classifies_changes() {
        let unchanged = make_symbol("unchanged_fn");
        let edited = make_symbol("edited_fn");
        let deleted = make_symbol("deleted_fn");
        let base_for_file = vec![unchanged.clone(), edited, deleted.clone()];
        let g = fork_with(&base_for_file);

        // The re-parse hands out fresh ids, as make_symbol always does.
        let mut edited_v2 = make_symbol("edited_fn");
        edited_v2.span = Span {
            start_byte: 0,
            end_byte: 42,
        };
        let new_symbols = vec![make_symbol("unchanged_fn"), edited_v2, make_symbol("brand_new_fn")];

        g.update_from_parse("test.rs", new_symbols, &base_for_file);

        assert_eq!(g.change_count(), 3);
        let names = g.changed_symbol_names();
        assert!(names.contains(&"edited_fn".to_string()));
        assert!(names.contains(&"deleted_fn".to_string()));
        assert!(names.contains(&"brand_new_fn".to_string()));
        assert!(!names.contains(&"unchanged_fn".to_string()));
        assert!(g.get_symbol(deleted.id).is_none());
        assert!(g.get_symbol(unchanged.id).is_some());
    }
}