1use std::collections::HashMap;
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use dashmap::{DashMap, DashSet};
13use dk_core::{CallEdge, Symbol, SymbolId};
14use serde::{Deserialize, Serialize};
15use uuid::Uuid;
16
/// Per-session overlay of symbol and call-edge changes on top of an optional,
/// shared base symbol table.
///
/// All change sets are concurrent maps/sets, so every mutating method takes `&self`.
pub struct SessionGraph {
    /// Shared base symbol table, read via `ArcSwap::load` on each lookup.
    /// `None` for graphs built by `empty()` or restored by `from_msgpack()`.
    base_symbols: Option<Arc<ArcSwap<HashMap<SymbolId, Symbol>>>>,

    /// New versions of symbols that were not added in this session (normally base
    /// symbols); on lookup these shadow the base entry with the same id.
    pub(crate) modified_symbols: DashMap<SymbolId, Symbol>,

    /// Symbols introduced by this session.
    added_symbols: DashMap<SymbolId, Symbol>,

    /// Ids hidden by this session; a removal wins over every other layer on lookup.
    pub(crate) removed_symbols: DashSet<SymbolId>,

    /// Qualified names captured at removal time, so removed symbols can still be
    /// reported when the base no longer contains them (or there is no base).
    removed_symbol_names: DashMap<SymbolId, String>,

    /// Call edges added in this session, keyed by edge id.
    pub(crate) added_edges: DashMap<Uuid, CallEdge>,

    /// Ids of call edges removed in this session. Removing a session-added edge
    /// deletes it from `added_edges` instead of recording it here.
    pub(crate) removed_edges: DashSet<Uuid>,
}
47
/// Serializable form of a [`SessionGraph`]'s session changes.
///
/// The concurrent maps and sets are flattened into vectors of pairs / ids. The
/// base symbol table is not part of the snapshot.
#[derive(Serialize, Deserialize)]
struct SessionGraphSnapshot {
    modified_symbols: Vec<(SymbolId, Symbol)>,
    added_symbols: Vec<(SymbolId, Symbol)>,
    removed_symbols: Vec<SymbolId>,
    // Decodes as an empty list when the field is absent, so snapshots encoded
    // without it (presumably written before the field existed) still load.
    #[serde(default)]
    removed_symbol_names: Vec<(SymbolId, String)>,
    added_edges: Vec<(Uuid, CallEdge)>,
    removed_edges: Vec<Uuid>,
}
67
68impl SessionGraph {
69 pub fn fork_from(base: Arc<ArcSwap<HashMap<SymbolId, Symbol>>>) -> Self {
71 Self {
72 base_symbols: Some(base),
73 modified_symbols: DashMap::new(),
74 added_symbols: DashMap::new(),
75 removed_symbols: DashSet::new(),
76 removed_symbol_names: DashMap::new(),
77 added_edges: DashMap::new(),
78 removed_edges: DashSet::new(),
79 }
80 }
81
82 pub fn empty() -> Self {
84 Self {
85 base_symbols: None,
86 modified_symbols: DashMap::new(),
87 added_symbols: DashMap::new(),
88 removed_symbols: DashSet::new(),
89 removed_symbol_names: DashMap::new(),
90 added_edges: DashMap::new(),
91 removed_edges: DashSet::new(),
92 }
93 }
94
95 pub fn get_symbol(&self, id: SymbolId) -> Option<Symbol> {
103 if self.removed_symbols.contains(&id) {
105 return None;
106 }
107
108 if let Some(sym) = self.modified_symbols.get(&id) {
110 return Some(sym.value().clone());
111 }
112
113 if let Some(sym) = self.added_symbols.get(&id) {
115 return Some(sym.value().clone());
116 }
117
118 if let Some(base) = &self.base_symbols {
120 let snapshot = base.load();
121 return snapshot.get(&id).cloned();
122 }
123
124 None
125 }
126
127 pub fn add_symbol(&self, symbol: Symbol) {
129 self.added_symbols.insert(symbol.id, symbol);
130 }
131
132 pub fn modify_symbol(&self, symbol: Symbol) {
134 let id = symbol.id;
135
136 if self.added_symbols.contains_key(&id) {
138 self.added_symbols.insert(id, symbol);
139 } else {
140 self.modified_symbols.insert(id, symbol);
141 }
142 }
143
144 pub fn remove_symbol(&self, id: SymbolId) {
146 if self.added_symbols.remove(&id).is_some() {
148 return;
149 }
150
151 if let Some((_, sym)) = self.modified_symbols.remove(&id) {
153 self.removed_symbol_names
154 .insert(id, sym.qualified_name.clone());
155 } else if let Some(base) = &self.base_symbols {
156 let snapshot = base.load();
158 if let Some(sym) = snapshot.get(&id) {
159 self.removed_symbol_names
160 .insert(id, sym.qualified_name.clone());
161 }
162 }
163
164 self.removed_symbols.insert(id);
166 }
167
168 pub fn add_edge(&self, edge: CallEdge) {
170 self.added_edges.insert(edge.id, edge);
171 }
172
173 pub fn remove_edge(&self, edge_id: Uuid) {
175 if self.added_edges.remove(&edge_id).is_some() {
177 return;
178 }
179 self.removed_edges.insert(edge_id);
180 }
181
182 pub fn get_edge(&self, edge_id: Uuid) -> Option<CallEdge> {
187 if self.removed_edges.contains(&edge_id) {
188 return None;
189 }
190 self.added_edges.get(&edge_id).map(|e| e.value().clone())
191 }
192
193 pub fn is_edge_removed(&self, edge_id: Uuid) -> bool {
196 self.removed_edges.contains(&edge_id)
197 }
198
199 pub fn changed_symbol_names(&self) -> Vec<String> {
204 let mut names = Vec::new();
205
206 for entry in self.added_symbols.iter() {
207 names.push(entry.value().qualified_name.clone());
208 }
209
210 for entry in self.modified_symbols.iter() {
211 names.push(entry.value().qualified_name.clone());
212 }
213
214 for id in self.removed_symbols.iter() {
217 let found = self
218 .base_symbols
219 .as_ref()
220 .and_then(|base| base.load().get(id.key()).map(|s| s.qualified_name.clone()));
221 if let Some(name) = found {
222 names.push(name);
223 } else if let Some(name) = self.removed_symbol_names.get(id.key()) {
224 names.push(name.value().clone());
225 }
226 }
227
228 names
229 }
230
231 pub fn update_from_parse(
240 &self,
241 _file_path: &str,
242 new_symbols: Vec<Symbol>,
243 base_symbols_for_file: &[Symbol],
244 ) {
245 let base_by_name: HashMap<&str, &Symbol> = base_symbols_for_file
247 .iter()
248 .map(|s| (s.qualified_name.as_str(), s))
249 .collect();
250
251 let new_by_name: HashMap<&str, &Symbol> = new_symbols
252 .iter()
253 .map(|s| (s.qualified_name.as_str(), s))
254 .collect();
255
256 for sym in &new_symbols {
259 if let Some(base_sym) = base_by_name.get(sym.qualified_name.as_str()) {
260 if sym.span != base_sym.span
262 || sym.signature != base_sym.signature
263 || sym.kind != base_sym.kind
264 || sym.visibility != base_sym.visibility
265 {
266 self.modify_symbol(sym.clone());
267 }
268 } else {
269 self.add_symbol(sym.clone());
270 }
271 }
272
273 for base_sym in base_symbols_for_file {
275 if !new_by_name.contains_key(base_sym.qualified_name.as_str()) {
276 self.remove_symbol(base_sym.id);
277 }
278 }
279 }
280
281 pub fn changed_symbols_for_file(&self, file_path: &str) -> Vec<String> {
284 let target = std::path::Path::new(file_path);
285 let mut names = Vec::new();
286
287 for entry in self.added_symbols.iter() {
288 if entry.value().file_path == target {
289 names.push(entry.value().name.clone());
290 }
291 }
292
293 for entry in self.modified_symbols.iter() {
294 if entry.value().file_path == target {
295 names.push(entry.value().name.clone());
296 }
297 }
298
299 names
300 }
301
302 pub fn change_count(&self) -> usize {
304 self.added_symbols.len() + self.modified_symbols.len() + self.removed_symbols.len()
305 }
306
307 pub fn to_msgpack(&self) -> anyhow::Result<Vec<u8>> {
315 let snapshot = SessionGraphSnapshot {
316 modified_symbols: self
317 .modified_symbols
318 .iter()
319 .map(|e| (*e.key(), e.value().clone()))
320 .collect(),
321 added_symbols: self
322 .added_symbols
323 .iter()
324 .map(|e| (*e.key(), e.value().clone()))
325 .collect(),
326 removed_symbols: self.removed_symbols.iter().map(|r| *r).collect(),
327 removed_symbol_names: self
328 .removed_symbol_names
329 .iter()
330 .map(|e| (*e.key(), e.value().clone()))
331 .collect(),
332 added_edges: self
333 .added_edges
334 .iter()
335 .map(|e| (*e.key(), e.value().clone()))
336 .collect(),
337 removed_edges: self.removed_edges.iter().map(|r| *r).collect(),
338 };
339
340 Ok(rmp_serde::to_vec_named(&snapshot)?)
341 }
342
343 pub fn from_msgpack(bytes: &[u8]) -> anyhow::Result<Self> {
350 let snapshot: SessionGraphSnapshot = rmp_serde::from_slice(bytes)?;
351
352 let modified_symbols = DashMap::new();
353 for (id, sym) in snapshot.modified_symbols {
354 modified_symbols.insert(id, sym);
355 }
356
357 let added_symbols = DashMap::new();
358 for (id, sym) in snapshot.added_symbols {
359 added_symbols.insert(id, sym);
360 }
361
362 let removed_symbols = DashSet::new();
363 for id in snapshot.removed_symbols {
364 removed_symbols.insert(id);
365 }
366
367 let removed_symbol_names = DashMap::new();
368 for (id, name) in snapshot.removed_symbol_names {
369 removed_symbol_names.insert(id, name);
370 }
371
372 let added_edges = DashMap::new();
373 for (id, edge) in snapshot.added_edges {
374 added_edges.insert(id, edge);
375 }
376
377 let removed_edges = DashSet::new();
378 for id in snapshot.removed_edges {
379 removed_edges.insert(id);
380 }
381
382 Ok(Self {
383 base_symbols: None,
384 modified_symbols,
385 added_symbols,
386 removed_symbols,
387 removed_symbol_names,
388 added_edges,
389 removed_edges,
390 })
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use dk_core::{Span, SymbolKind, Visibility};
    use std::path::PathBuf;

    /// Builds a public function symbol in `test.rs` with a fresh random id.
    fn make_symbol(name: &str) -> Symbol {
        Symbol {
            id: Uuid::new_v4(),
            name: name.to_string(),
            qualified_name: name.to_string(),
            kind: SymbolKind::Function,
            visibility: Visibility::Public,
            file_path: PathBuf::from("test.rs"),
            span: Span {
                start_byte: 0,
                end_byte: 10,
            },
            signature: None,
            doc_comment: None,
            parent: None,
            last_modified_by: None,
            last_modified_intent: None,
        }
    }

    /// Wraps `symbols` in the shared base-table shape that `fork_from` expects.
    fn shared_base(symbols: Vec<Symbol>) -> Arc<ArcSwap<HashMap<SymbolId, Symbol>>> {
        let table: HashMap<SymbolId, Symbol> = symbols.into_iter().map(|s| (s.id, s)).collect();
        Arc::new(ArcSwap::from_pointee(table))
    }

    #[test]
    fn empty_graph_returns_none() {
        let g = SessionGraph::empty();
        assert!(g.get_symbol(Uuid::new_v4()).is_none());
    }

    #[test]
    fn add_and_get_symbol() {
        let g = SessionGraph::empty();
        let sym = make_symbol("foo");
        let id = sym.id;
        g.add_symbol(sym);
        assert!(g.get_symbol(id).is_some());
        assert_eq!(g.get_symbol(id).unwrap().name, "foo");
    }

    #[test]
    fn remove_added_symbol() {
        let g = SessionGraph::empty();
        let sym = make_symbol("bar");
        let id = sym.id;
        g.add_symbol(sym);
        g.remove_symbol(id);
        assert!(g.get_symbol(id).is_none());
    }

    #[test]
    fn modify_added_symbol_updates_in_place() {
        let g = SessionGraph::empty();
        let mut sym = make_symbol("baz");
        let id = sym.id;
        g.add_symbol(sym.clone());

        sym.name = "baz_v2".to_string();
        g.modify_symbol(sym);

        let got = g.get_symbol(id).unwrap();
        assert_eq!(got.name, "baz_v2");
    }

    #[test]
    fn fork_from_base_lookup() {
        let sym = make_symbol("base_fn");
        let id = sym.id;
        let g = SessionGraph::fork_from(shared_base(vec![sym]));

        assert!(g.get_symbol(id).is_some());
        assert_eq!(g.get_symbol(id).unwrap().name, "base_fn");
    }

    #[test]
    fn remove_base_symbol_hides_it() {
        let sym = make_symbol("base_fn");
        let id = sym.id;
        let g = SessionGraph::fork_from(shared_base(vec![sym]));

        g.remove_symbol(id);
        assert!(g.get_symbol(id).is_none());
    }

    #[test]
    fn modify_base_symbol_shadows_base() {
        let mut sym = make_symbol("base_fn");
        let id = sym.id;
        let g = SessionGraph::fork_from(shared_base(vec![sym.clone()]));

        sym.span = Span {
            start_byte: 0,
            end_byte: 20,
        };
        g.modify_symbol(sym);

        assert_eq!(g.get_symbol(id).unwrap().span.end_byte, 20);
        assert_eq!(g.change_count(), 1);
        assert_eq!(g.changed_symbol_names(), vec!["base_fn"]);
    }

    #[test]
    fn remove_modified_symbol_keeps_name_without_base() {
        let g = SessionGraph::empty();
        let sym = make_symbol("edited_then_deleted");
        let id = sym.id;

        g.modify_symbol(sym);
        g.remove_symbol(id);

        assert!(g.get_symbol(id).is_none());
        assert!(g.modified_symbols.is_empty());
        assert_eq!(g.changed_symbol_names(), vec!["edited_then_deleted"]);
    }

    #[test]
    fn changed_symbol_names_collects_all() {
        let sym = make_symbol("removed_fn");
        let removed_id = sym.id;
        let g = SessionGraph::fork_from(shared_base(vec![sym]));

        g.add_symbol(make_symbol("added_fn"));

        let modified = make_symbol("modified_fn");
        g.modified_symbols.insert(modified.id, modified);

        g.remove_symbol(removed_id);

        let names = g.changed_symbol_names();
        assert!(names.contains(&"added_fn".to_string()));
        assert!(names.contains(&"modified_fn".to_string()));
        assert!(names.contains(&"removed_fn".to_string()));
    }

    #[test]
    fn change_count() {
        let g = SessionGraph::empty();
        assert_eq!(g.change_count(), 0);

        g.add_symbol(make_symbol("a"));
        assert_eq!(g.change_count(), 1);
    }

    #[test]
    fn remove_edge_without_addition_is_recorded() {
        let g = SessionGraph::empty();
        let edge_id = Uuid::new_v4();
        assert!(!g.is_edge_removed(edge_id));

        g.remove_edge(edge_id);

        assert!(g.is_edge_removed(edge_id));
        assert!(g.get_edge(edge_id).is_none());
    }

    #[test]
    fn changed_symbols_for_file_filters_by_path() {
        let g = SessionGraph::empty();

        let mut sym1 = make_symbol("create_task");
        sym1.file_path = PathBuf::from("src/tasks.rs");
        g.add_symbol(sym1);

        let mut sym2 = make_symbol("delete_task");
        sym2.file_path = PathBuf::from("src/tasks.rs");
        g.add_symbol(sym2);

        let mut sym3 = make_symbol("run_server");
        sym3.file_path = PathBuf::from("src/main.rs");
        g.add_symbol(sym3);

        let task_syms = g.changed_symbols_for_file("src/tasks.rs");
        assert_eq!(task_syms.len(), 2);
        assert!(task_syms.contains(&"create_task".to_string()));
        assert!(task_syms.contains(&"delete_task".to_string()));

        let main_syms = g.changed_symbols_for_file("src/main.rs");
        assert_eq!(main_syms.len(), 1);
        assert!(main_syms.contains(&"run_server".to_string()));

        let empty = g.changed_symbols_for_file("src/nonexistent.rs");
        assert!(empty.is_empty());
    }

    #[test]
    fn update_from_parse_classifies_changes() {
        let kept = make_symbol("kept");
        let mut edited = make_symbol("edited");
        let dropped = make_symbol("dropped");
        let base_list = vec![kept.clone(), edited.clone(), dropped.clone()];
        let g = SessionGraph::fork_from(shared_base(base_list.clone()));

        edited.span = Span {
            start_byte: 0,
            end_byte: 42,
        };
        let fresh = make_symbol("fresh");
        g.update_from_parse(
            "test.rs",
            vec![kept.clone(), edited.clone(), fresh.clone()],
            &base_list,
        );

        assert!(g.modified_symbols.contains_key(&edited.id));
        assert!(!g.modified_symbols.contains_key(&kept.id));
        assert_eq!(g.get_symbol(edited.id).unwrap().span.end_byte, 42);
        assert_eq!(g.get_symbol(fresh.id).unwrap().name, "fresh");
        assert!(g.get_symbol(dropped.id).is_none());

        let mut names = g.changed_symbol_names();
        names.sort();
        assert_eq!(names, vec!["dropped", "edited", "fresh"]);
    }

    #[test]
    fn msgpack_round_trip_preserves_changes() {
        let gone = make_symbol("gone_fn");
        let removed_id = gone.id;
        let g = SessionGraph::fork_from(shared_base(vec![gone]));

        let added = make_symbol("new_fn");
        let added_id = added.id;
        g.add_symbol(added);
        g.remove_symbol(removed_id);
        let edge_id = Uuid::new_v4();
        g.remove_edge(edge_id);

        let bytes = g.to_msgpack().expect("snapshot should encode");
        let restored = SessionGraph::from_msgpack(&bytes).expect("snapshot should decode");

        assert_eq!(restored.get_symbol(added_id).unwrap().name, "new_fn");
        assert!(restored.get_symbol(removed_id).is_none());
        assert!(restored.is_edge_removed(edge_id));
        assert_eq!(restored.change_count(), g.change_count());

        // The restored graph has no base, so the removed name must come from the
        // names captured at removal time.
        let mut names = restored.changed_symbol_names();
        names.sort();
        assert_eq!(names, vec!["gone_fn", "new_fn"]);
    }
}