polykit_core/
graph.rs

1//! Dependency graph management using petgraph.
2
3use std::collections::HashSet;
4use std::fs;
5use std::path::Path;
6use std::sync::Arc;
7
8use petgraph::algo::toposort;
9use petgraph::graph::{DiGraph, NodeIndex};
10use petgraph::Direction;
11use petgraph::visit::EdgeRef;
12use rustc_hash::FxHashMap;
13use serde::{Deserialize, Serialize};
14use dashmap::DashMap;
15
16use crate::error::{Error, Result};
17use crate::package::Package;
18use crate::string_interner::intern;
19
20/// Tracks changes to packages for incremental graph updates.
21#[derive(Debug, Clone)]
22pub struct GraphChange {
23    pub added: Vec<Package>,
24    pub modified: Vec<Package>,
25    pub removed: Vec<String>,
26    pub dependency_changes: Vec<(String, Vec<String>)>,
27}
28
29#[derive(Debug, Clone)]
30pub struct GraphNode {
31    pub package: Package,
32    pub index: NodeIndex,
33}
34
35/// Serializable graph data for persistence.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37struct SerializableGraph {
38    packages: Vec<Package>,
39    #[serde(serialize_with = "serialize_arc_str_vec")]
40    #[serde(deserialize_with = "deserialize_arc_str_vec")]
41    topological_order: Vec<Arc<str>>,
42    #[serde(serialize_with = "serialize_arc_str_vec_vec")]
43    #[serde(deserialize_with = "deserialize_arc_str_vec_vec")]
44    dependency_levels: Vec<Vec<Arc<str>>>,
45}
46
47fn serialize_arc_str_vec<S>(vec: &[Arc<str>], serializer: S) -> std::result::Result<S::Ok, S::Error>
48where
49    S: serde::Serializer,
50{
51    use serde::Serialize;
52    let strings: Vec<&str> = vec.iter().map(|s| s.as_ref()).collect();
53    strings.serialize(serializer)
54}
55
56fn deserialize_arc_str_vec<'de, D>(deserializer: D) -> std::result::Result<Vec<Arc<str>>, D::Error>
57where
58    D: serde::Deserializer<'de>,
59{
60    use serde::Deserialize;
61    let strings: Vec<String> = Vec::deserialize(deserializer)?;
62    Ok(strings.into_iter().map(Arc::from).collect())
63}
64
65fn serialize_arc_str_vec_vec<S>(
66    vec: &[Vec<Arc<str>>],
67    serializer: S,
68) -> std::result::Result<S::Ok, S::Error>
69where
70    S: serde::Serializer,
71{
72    use serde::Serialize;
73    let strings: Vec<Vec<&str>> = vec
74        .iter()
75        .map(|level| level.iter().map(|s| s.as_ref()).collect())
76        .collect();
77    strings.serialize(serializer)
78}
79
80fn deserialize_arc_str_vec_vec<'de, D>(
81    deserializer: D,
82) -> std::result::Result<Vec<Vec<Arc<str>>>, D::Error>
83where
84    D: serde::Deserializer<'de>,
85{
86    use serde::Deserialize;
87    let strings: Vec<Vec<String>> = Vec::deserialize(deserializer)?;
88    Ok(strings
89        .into_iter()
90        .map(|level| level.into_iter().map(Arc::from).collect())
91        .collect())
92}
93
94/// Directed acyclic graph of package dependencies.
95///
96/// Uses compact u32 indices internally for better memory efficiency and cache performance.
97/// Public API maintains String-based interface for compatibility.
98#[derive(Debug, Clone)]
99pub struct DependencyGraph {
100    // Internal compact representation using u32 indices
101    graph: DiGraph<u32, ()>,
102    id_to_name: Vec<Arc<str>>,
103    name_to_id: FxHashMap<Arc<str>, u32>,
104    
105    // Package data indexed by u32 ID
106    packages: Vec<Package>,
107    
108    // Cached computations using u32 indices internally
109    cached_topological_order: Vec<u32>,
110    dependency_levels: Vec<Vec<u32>>,
111    
112    // Public API compatibility (kept for fast lookups)
113    name_to_node: FxHashMap<String, NodeIndex>,
114    node_to_id: FxHashMap<NodeIndex, u32>,
115    
116    // Cache for transitive dependencies
117    transitive_deps_cache: DashMap<String, Arc<HashSet<String>>>,
118}
119
120impl DependencyGraph {
121    /// Creates a new dependency graph from a list of packages.
122    ///
123    /// # Errors
124    ///
125    /// Returns an error if circular dependencies are detected.
126    pub fn new(packages: Vec<Package>) -> Result<Self> {
127        let package_count = packages.len();
128        let mut graph = DiGraph::with_capacity(package_count, package_count * 2);
129        let mut id_to_name = Vec::with_capacity(package_count);
130        let mut name_to_id = FxHashMap::with_capacity_and_hasher(package_count, Default::default());
131        let mut name_to_node = FxHashMap::with_capacity_and_hasher(package_count, Default::default());
132        let mut node_to_id = FxHashMap::with_capacity_and_hasher(package_count, Default::default());
133        let mut packages_vec = Vec::with_capacity(package_count);
134
135        // First pass: intern names and assign u32 IDs
136        for (idx, package) in packages.iter().enumerate() {
137            let name_arc = intern(&package.name);
138            let id = idx as u32;
139            
140            id_to_name.push(Arc::clone(&name_arc));
141            name_to_id.insert(Arc::clone(&name_arc), id);
142            packages_vec.push(package.clone());
143        }
144
145        // Second pass: add nodes to graph with u32 IDs
146        for (idx, package) in packages.iter().enumerate() {
147            let id = idx as u32;
148            let node = graph.add_node(id);
149            name_to_node.insert(package.name.clone(), node);
150            node_to_id.insert(node, id);
151        }
152
153        // Third pass: add edges
154        for package in &packages {
155            let from_node = name_to_node.get(&package.name).unwrap();
156
157            for dep_name in &package.deps {
158                let to_node = name_to_node.get(dep_name).ok_or_else(|| {
159                    let name = dep_name.clone();
160                    Error::PackageNotFound {
161                        name: name.clone(),
162                        available: format!("Dependency '{}' not found", name),
163                    }
164                })?;
165
166                graph.add_edge(*from_node, *to_node, ());
167            }
168        }
169
170        let sorted = toposort(&graph, None).map_err(|cycle| {
171            let cycle_id = graph[cycle.node_id()];
172            let cycle_name = &id_to_name[cycle_id as usize];
173            Error::CircularDependency(format!("Cycle detected involving: {}", cycle_name))
174        })?;
175
176        let topological_order: Vec<u32> = sorted
177            .into_iter()
178            .rev()
179            .map(|idx| graph[idx])
180            .collect();
181
182        let dependency_levels =
183            Self::compute_dependency_levels_compact(&graph, &node_to_id, &topological_order)?;
184
185        Ok(Self {
186            graph,
187            id_to_name,
188            name_to_id,
189            packages: packages_vec,
190            cached_topological_order: topological_order,
191            dependency_levels,
192            name_to_node,
193            node_to_id,
194            transitive_deps_cache: DashMap::new(),
195        })
196    }
197
198    fn compute_dependency_levels_compact(
199        graph: &DiGraph<u32, ()>,
200        node_to_id: &FxHashMap<NodeIndex, u32>,
201        order: &[u32],
202    ) -> Result<Vec<Vec<u32>>> {
203        let mut levels = Vec::new();
204        let mut level_map = FxHashMap::with_capacity_and_hasher(order.len(), Default::default());
205
206        for &package_id in order {
207            // Find the node index for this ID
208            let node = node_to_id
209                .iter()
210                .find(|(_, &id)| id == package_id)
211                .map(|(node_idx, _)| *node_idx)
212                .ok_or_else(|| Error::PackageNotFound {
213                    name: format!("id-{}", package_id),
214                    available: format!("Package ID {} not found", package_id),
215                })?;
216
217            let dep_ids: Vec<u32> = graph
218                .neighbors_directed(node, Direction::Outgoing)
219                .map(|idx| graph[idx])
220                .collect();
221
222            let level = if dep_ids.is_empty() {
223                0
224            } else {
225                dep_ids.iter()
226                    .filter_map(|dep_id| level_map.get(dep_id))
227                    .max()
228                    .map(|l| l + 1)
229                    .unwrap_or(0)
230            };
231
232            level_map.insert(package_id, level);
233            while levels.len() <= level {
234                levels.push(Vec::new());
235            }
236            levels[level].push(package_id);
237        }
238
239        Ok(levels)
240    }
241
242    /// Converts a u32 ID to a package name string.
243    #[inline]
244    fn id_to_name(&self, id: u32) -> &str {
245        &self.id_to_name[id as usize]
246    }
247
248    /// Converts a package name to a u32 ID.
249    #[inline]
250    fn name_to_id(&self, name: &str) -> Option<u32> {
251        let name_arc = intern(name);
252        self.name_to_id.get(&name_arc).copied()
253    }
254
255    /// Retrieves a package by name.
256    #[inline]
257    pub fn get_package(&self, name: &str) -> Option<&Package> {
258        self.name_to_id(name)
259            .and_then(|id| self.packages.get(id as usize))
260    }
261
262    /// Returns packages in topological order (dependencies before dependents).
263    ///
264    /// This is cached during graph construction for fast access.
265    #[inline]
266    pub fn topological_order(&self) -> Vec<String> {
267        self.cached_topological_order
268            .iter()
269            .map(|&id| self.id_to_name(id).to_string())
270            .collect()
271    }
272
273    /// Returns dependency levels for parallel execution.
274    ///
275    /// Each level contains packages that can be executed in parallel.
276    #[inline]
277    pub fn dependency_levels(&self) -> Vec<Vec<String>> {
278        self.dependency_levels
279            .iter()
280            .map(|level| level.iter().map(|&id| self.id_to_name(id).to_string()).collect())
281            .collect()
282    }
283
284    /// Returns direct dependencies of a package.
285    ///
286    /// # Errors
287    ///
288    /// Returns an error if the package is not found in the graph.
289    pub fn dependencies(&self, package_name: &str) -> Result<Vec<String>> {
290        let node = self
291            .name_to_node
292            .get(package_name)
293            .ok_or_else(|| Error::PackageNotFound {
294                name: package_name.to_string(),
295                available: format!("Package '{}' not found", package_name),
296            })?;
297
298        let deps: Vec<String> = self
299            .graph
300            .neighbors_directed(*node, Direction::Outgoing)
301            .map(|idx| {
302                let dep_id = self.graph[idx];
303                self.id_to_name(dep_id).to_string()
304            })
305            .collect();
306
307        Ok(deps)
308    }
309
310    /// Returns direct dependents of a package (packages that depend on it).
311    ///
312    /// # Errors
313    ///
314    /// Returns an error if the package is not found in the graph.
315    pub fn dependents(&self, package_name: &str) -> Result<Vec<String>> {
316        let node = self
317            .name_to_node
318            .get(package_name)
319            .ok_or_else(|| Error::PackageNotFound {
320                name: package_name.to_string(),
321                available: format!("Package '{}' not found", package_name),
322            })?;
323
324        let dependents: Vec<String> = self
325            .graph
326            .neighbors_directed(*node, Direction::Incoming)
327            .map(|idx| {
328                let dep_id = self.graph[idx];
329                self.id_to_name(dep_id).to_string()
330            })
331            .collect();
332
333        Ok(dependents)
334    }
335
336    /// Returns all transitive dependents of a package.
337    ///
338    /// This includes both direct and indirect dependents (packages that depend
339    /// on packages that depend on this package, etc.).
340    ///
341    /// Results are cached for performance.
342    ///
343    /// # Errors
344    ///
345    /// Returns an error if the package is not found in the graph.
346    pub fn all_dependents(&self, package_name: &str) -> Result<HashSet<String>> {
347        if let Some(cached) = self.transitive_deps_cache.get(package_name) {
348            return Ok((**cached.value()).clone());
349        }
350
351        let mut result = HashSet::new();
352        let mut stack = vec![package_name.to_string()];
353
354        while let Some(current) = stack.pop() {
355            if result.contains(&current) {
356                continue;
357            }
358            result.insert(current.clone());
359
360            let direct_dependents = self.dependents(&current)?;
361            for dep in direct_dependents {
362                if !result.contains(&dep) {
363                    stack.push(dep);
364                }
365            }
366        }
367
368        result.remove(package_name);
369        let result_arc = Arc::new(result.clone());
370        self.transitive_deps_cache.insert(package_name.to_string(), result_arc);
371        Ok(result)
372    }
373
374    /// Returns all packages affected by changes to the given packages.
375    ///
376    /// This includes the changed packages themselves and all their transitive
377    /// dependents.
378    ///
379    /// Uses parallel BFS for better performance on large graphs.
380    ///
381    /// # Errors
382    ///
383    /// Returns an error if any of the changed packages are not found in the graph.
384    pub fn affected_packages(&self, changed_packages: &[String]) -> Result<HashSet<String>> {
385        use dashmap::DashSet;
386        use rayon::prelude::*;
387
388        let affected = DashSet::new();
389        let queue: Vec<String> = changed_packages.to_vec();
390
391        queue.par_iter().for_each(|pkg| {
392            affected.insert(pkg.clone());
393            let mut local_queue = vec![pkg.clone()];
394            let mut local_visited = HashSet::new();
395
396            while let Some(current) = local_queue.pop() {
397                if local_visited.contains(&current) {
398                    continue;
399                }
400                local_visited.insert(current.clone());
401
402                if let Ok(dependents) = self.dependents(&current) {
403                    for dep in dependents {
404                        if affected.insert(dep.clone()) {
405                            local_queue.push(dep);
406                        }
407                    }
408                }
409            }
410        });
411
412        Ok(affected.into_iter().collect())
413    }
414
415    /// Returns all packages in the graph.
416    pub fn all_packages(&self) -> Vec<&Package> {
417        // Only return packages that are still in name_to_node (not removed)
418        self.name_to_node
419            .keys()
420            .filter_map(|name| self.get_package(name))
421            .collect()
422    }
423
424    /// Serializes the graph to a file for fast loading.
425    pub fn save_to_file(&self, path: impl AsRef<Path>) -> Result<()> {
426        // Convert u32 indices back to Arc<str> for serialization
427        let topological_order: Vec<Arc<str>> = self.cached_topological_order
428            .iter()
429            .map(|&id| Arc::clone(&self.id_to_name[id as usize]))
430            .collect();
431        let dependency_levels: Vec<Vec<Arc<str>>> = self.dependency_levels
432            .iter()
433            .map(|level| level.iter().map(|&id| Arc::clone(&self.id_to_name[id as usize])).collect())
434            .collect();
435
436        let serializable = SerializableGraph {
437            packages: self.packages.clone(),
438            topological_order,
439            dependency_levels,
440        };
441
442        let serialized = bincode::serialize(&serializable).map_err(|e| Error::Adapter {
443            package: "graph".to_string(),
444            message: format!("Failed to serialize graph: {}", e),
445        })?;
446
447        let compressed = zstd::encode_all(&serialized[..], 3).map_err(|e| Error::Adapter {
448            package: "graph".to_string(),
449            message: format!("Failed to compress graph: {}", e),
450        })?;
451
452        fs::write(path, compressed).map_err(Error::Io)?;
453        Ok(())
454    }
455
456    /// Loads a graph from a previously saved file.
457    pub fn load_from_file(path: impl AsRef<Path>) -> Result<Self> {
458        let compressed = fs::read(path).map_err(Error::Io)?;
459        let serialized = zstd::decode_all(&compressed[..]).map_err(|e| Error::Adapter {
460            package: "graph".to_string(),
461            message: format!("Failed to decompress graph: {}", e),
462        })?;
463
464        let serializable: SerializableGraph =
465            bincode::deserialize(&serialized).map_err(|e| Error::Adapter {
466                package: "graph".to_string(),
467                message: format!("Failed to deserialize graph: {}", e),
468            })?;
469
470        Self::new(serializable.packages)
471    }
472
473    /// Updates the graph incrementally based on detected changes.
474    ///
475    /// This is much faster than rebuilding the entire graph when only
476    /// a few packages have changed.
477    ///
478    /// # Errors
479    ///
480    /// Returns an error if any package operations fail or circular dependencies are detected.
481    pub fn update_incremental(&mut self, changes: GraphChange) -> Result<()> {
482        // 1. Remove deleted packages
483        for package_name in &changes.removed {
484            self.remove_package(package_name)?;
485        }
486
487        // 2. Update modified packages
488        for package in &changes.modified {
489            self.update_package(package)?;
490        }
491
492        // 3. Add new packages
493        for package in &changes.added {
494            self.add_package(package)?;
495        }
496
497        // 4. Recompute only affected cached values
498        self.recompute_affected_levels(&changes)?;
499
500        Ok(())
501    }
502
503    fn remove_package(&mut self, name: &str) -> Result<()> {
504        let package_id = self.name_to_id(name).ok_or_else(|| Error::PackageNotFound {
505            name: name.to_string(),
506            available: String::new(),
507        })?;
508
509        let node = *self.name_to_node.get(name).ok_or_else(|| Error::PackageNotFound {
510            name: name.to_string(),
511            available: String::new(),
512        })?;
513
514        // Remove node from graph (this removes all edges automatically)
515        self.graph.remove_node(node);
516
517        // Remove from mappings
518        self.name_to_node.remove(name);
519        self.node_to_id.remove(&node);
520        let name_arc = intern(name);
521        self.name_to_id.remove(&name_arc);
522
523        // Remove package from packages vector (mark as removed, don't actually remove to preserve indices)
524        // For now, we'll keep the package but mark it as removed by clearing the name
525        // In a production system, you might want a more sophisticated approach
526        if (package_id as usize) < self.packages.len() {
527            // We can't easily remove from the middle of a Vec without breaking indices
528            // So we'll leave it but it won't be accessible via name lookup
529        }
530
531        // Clear affected cache entries
532        self.transitive_deps_cache.remove(name);
533
534        Ok(())
535    }
536
537    fn update_package(&mut self, package: &Package) -> Result<()> {
538        let package_id = self.name_to_id(&package.name).ok_or_else(|| Error::PackageNotFound {
539            name: package.name.clone(),
540            available: String::new(),
541        })?;
542
543        // Update package data
544        self.packages[package_id as usize] = package.clone();
545
546        // Update edges - remove old edges, add new ones
547        let node = self.name_to_node.get(&package.name).ok_or_else(|| Error::PackageNotFound {
548            name: package.name.clone(),
549            available: String::new(),
550        })?;
551
552        // Remove all outgoing edges
553        let old_edges: Vec<_> = self.graph
554            .edges_directed(*node, Direction::Outgoing)
555            .map(|e| e.id())
556            .collect();
557        for edge_id in old_edges {
558            let _ = self.graph.remove_edge(edge_id);
559        }
560
561        // Add new edges based on updated dependencies
562        for dep_name in &package.deps {
563            if let Some(dep_node) = self.name_to_node.get(dep_name) {
564                self.graph.add_edge(*node, *dep_node, ());
565            }
566        }
567
568        // Invalidate cache
569        self.transitive_deps_cache.remove(&package.name);
570
571        Ok(())
572    }
573
574    fn add_package(&mut self, package: &Package) -> Result<()> {
575        let name_arc = intern(&package.name);
576        let package_id = self.packages.len() as u32;
577
578        // Add to mappings
579        self.id_to_name.push(Arc::clone(&name_arc));
580        self.name_to_id.insert(name_arc, package_id);
581        self.packages.push(package.clone());
582
583        // Add node to graph
584        let node = self.graph.add_node(package_id);
585        self.name_to_node.insert(package.name.clone(), node);
586        self.node_to_id.insert(node, package_id);
587
588        // Add edges for dependencies
589        for dep_name in &package.deps {
590            if let Some(dep_node) = self.name_to_node.get(dep_name) {
591                self.graph.add_edge(node, *dep_node, ());
592            }
593        }
594
595        Ok(())
596    }
597
598    fn recompute_affected_levels(&mut self, changes: &GraphChange) -> Result<()> {
599        // Collect all affected package IDs
600        let mut affected_ids = HashSet::new();
601        
602        for pkg in &changes.added {
603            if let Some(id) = self.name_to_id(&pkg.name) {
604                affected_ids.insert(id);
605            }
606        }
607        for pkg in &changes.modified {
608            if let Some(id) = self.name_to_id(&pkg.name) {
609                affected_ids.insert(id);
610            }
611        }
612        for name in &changes.removed {
613            if let Some(id) = self.name_to_id(name) {
614                affected_ids.insert(id);
615            }
616        }
617
618        // For now, recompute entire topological order and levels
619        // TODO: Optimize to only recompute affected subgraph
620        let sorted = toposort(&self.graph, None).map_err(|cycle| {
621            let cycle_id = self.graph[cycle.node_id()];
622            let cycle_name = self.id_to_name(cycle_id);
623            Error::CircularDependency(format!("Cycle detected involving: {}", cycle_name))
624        })?;
625
626        self.cached_topological_order = sorted
627            .into_iter()
628            .rev()
629            .map(|idx| self.graph[idx])
630            .collect();
631
632        self.dependency_levels =
633            Self::compute_dependency_levels_compact(&self.graph, &self.node_to_id, &self.cached_topological_order)?;
634
635        Ok(())
636    }
637}