// polykit_core/graph.rs

//! Dependency graph management using petgraph.

use std::collections::HashSet;
use std::fs;
use std::path::Path;
use std::sync::Arc;

use petgraph::algo::toposort;
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::Direction;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};

use crate::error::{Error, Result};
use crate::package::Package;

17#[derive(Debug, Clone)]
18pub struct GraphNode {
19    pub package: Package,
20    pub index: NodeIndex,
21}
22
23/// Serializable graph data for persistence.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25struct SerializableGraph {
26    packages: Vec<Package>,
27    #[serde(serialize_with = "serialize_arc_str_vec")]
28    #[serde(deserialize_with = "deserialize_arc_str_vec")]
29    topological_order: Vec<Arc<str>>,
30    #[serde(serialize_with = "serialize_arc_str_vec_vec")]
31    #[serde(deserialize_with = "deserialize_arc_str_vec_vec")]
32    dependency_levels: Vec<Vec<Arc<str>>>,
33}
34
35fn serialize_arc_str_vec<S>(vec: &[Arc<str>], serializer: S) -> std::result::Result<S::Ok, S::Error>
36where
37    S: serde::Serializer,
38{
39    use serde::Serialize;
40    let strings: Vec<&str> = vec.iter().map(|s| s.as_ref()).collect();
41    strings.serialize(serializer)
42}
43
44fn deserialize_arc_str_vec<'de, D>(deserializer: D) -> std::result::Result<Vec<Arc<str>>, D::Error>
45where
46    D: serde::Deserializer<'de>,
47{
48    use serde::Deserialize;
49    let strings: Vec<String> = Vec::deserialize(deserializer)?;
50    Ok(strings.into_iter().map(Arc::from).collect())
51}
52
53fn serialize_arc_str_vec_vec<S>(
54    vec: &[Vec<Arc<str>>],
55    serializer: S,
56) -> std::result::Result<S::Ok, S::Error>
57where
58    S: serde::Serializer,
59{
60    use serde::Serialize;
61    let strings: Vec<Vec<&str>> = vec
62        .iter()
63        .map(|level| level.iter().map(|s| s.as_ref()).collect())
64        .collect();
65    strings.serialize(serializer)
66}
67
68fn deserialize_arc_str_vec_vec<'de, D>(
69    deserializer: D,
70) -> std::result::Result<Vec<Vec<Arc<str>>>, D::Error>
71where
72    D: serde::Deserializer<'de>,
73{
74    use serde::Deserialize;
75    let strings: Vec<Vec<String>> = Vec::deserialize(deserializer)?;
76    Ok(strings
77        .into_iter()
78        .map(|level| level.into_iter().map(Arc::from).collect())
79        .collect())
80}
81
82/// Directed acyclic graph of package dependencies.
83#[derive(Debug)]
84pub struct DependencyGraph {
85    graph: DiGraph<Arc<str>, ()>,
86    #[allow(dead_code)]
87    node_map: FxHashMap<Arc<str>, NodeIndex>,
88    name_to_node: FxHashMap<String, NodeIndex>,
89    packages: FxHashMap<NodeIndex, Package>,
90    cached_topological_order: Vec<Arc<str>>,
91    dependency_levels: Vec<Vec<Arc<str>>>,
92}
93
94impl DependencyGraph {
95    /// Creates a new dependency graph from a list of packages.
96    ///
97    /// # Errors
98    ///
99    /// Returns an error if circular dependencies are detected.
100    pub fn new(packages: Vec<Package>) -> Result<Self> {
101        let package_count = packages.len();
102        let mut graph = DiGraph::with_capacity(package_count, package_count * 2);
103        let mut node_map = FxHashMap::with_capacity_and_hasher(package_count, Default::default());
104        let mut name_to_node =
105            FxHashMap::with_capacity_and_hasher(package_count, Default::default());
106        let mut packages_map =
107            FxHashMap::with_capacity_and_hasher(package_count, Default::default());
108
109        let mut name_cache: FxHashMap<String, Arc<str>> =
110            FxHashMap::with_capacity_and_hasher(package_count, Default::default());
111        for package in &packages {
112            let name_arc = Arc::from(package.name.as_str());
113            name_cache.insert(package.name.clone(), Arc::clone(&name_arc));
114        }
115
116        for package in &packages {
117            let name_arc = name_cache.get(&package.name).unwrap();
118            let node = graph.add_node(Arc::clone(name_arc));
119            node_map.insert(Arc::clone(name_arc), node);
120            name_to_node.insert(package.name.clone(), node);
121            packages_map.insert(node, package.clone());
122        }
123
124        for package in &packages {
125            let name_arc = name_cache.get(&package.name).unwrap();
126            let from_node = node_map
127                .get(name_arc)
128                .ok_or_else(|| Error::PackageNotFound {
129                    name: package.name.clone(),
130                    available: format!(
131                        "Package '{}' not found during graph construction",
132                        package.name
133                    ),
134                })?;
135
136            for dep_name in &package.deps {
137                let dep_arc = name_cache
138                    .get(dep_name)
139                    .ok_or_else(|| Error::PackageNotFound {
140                        name: dep_name.clone(),
141                        available: format!("Dependency '{}' not found", dep_name),
142                    })?;
143                let to_node = node_map.get(dep_arc).ok_or_else(|| {
144                    let name = dep_name.clone();
145                    Error::PackageNotFound {
146                        name: name.clone(),
147                        available: format!("Dependency '{}' not found", name),
148                    }
149                })?;
150
151                graph.add_edge(*from_node, *to_node, ());
152            }
153        }
154
155        let sorted = toposort(&graph, None).map_err(|cycle| {
156            let cycle_node = graph[cycle.node_id()].as_ref();
157            Error::CircularDependency(format!("Cycle detected involving: {}", cycle_node))
158        })?;
159
160        let topological_order: Vec<Arc<str>> = sorted
161            .into_iter()
162            .rev()
163            .map(|idx| Arc::clone(&graph[idx]))
164            .collect();
165
166        let dependency_levels =
167            Self::compute_dependency_levels(&graph, &node_map, &topological_order)?;
168
169        Ok(Self {
170            graph,
171            node_map,
172            name_to_node,
173            packages: packages_map,
174            cached_topological_order: topological_order,
175            dependency_levels,
176        })
177    }
178
179    fn compute_dependency_levels(
180        graph: &DiGraph<Arc<str>, ()>,
181        node_map: &FxHashMap<Arc<str>, NodeIndex>,
182        order: &[Arc<str>],
183    ) -> Result<Vec<Vec<Arc<str>>>> {
184        let mut levels = Vec::new();
185        let mut level_map = FxHashMap::with_capacity_and_hasher(order.len(), Default::default());
186
187        for package_name in order {
188            let node = node_map
189                .get(package_name)
190                .ok_or_else(|| Error::PackageNotFound {
191                    name: package_name.to_string(),
192                    available: format!("Package '{}' not found in node_map", package_name),
193                })?;
194
195            let deps: Vec<Arc<str>> = graph
196                .neighbors_directed(*node, Direction::Outgoing)
197                .map(|idx| Arc::clone(&graph[idx]))
198                .collect();
199
200            let level = if deps.is_empty() {
201                0
202            } else {
203                deps.iter()
204                    .filter_map(|dep| level_map.get(dep))
205                    .max()
206                    .map(|l| l + 1)
207                    .unwrap_or(0)
208            };
209
210            level_map.insert(Arc::clone(package_name), level);
211            while levels.len() <= level {
212                levels.push(Vec::new());
213            }
214            levels[level].push(Arc::clone(package_name));
215        }
216
217        Ok(levels)
218    }
219
220    /// Retrieves a package by name.
221    #[inline]
222    pub fn get_package(&self, name: &str) -> Option<&Package> {
223        self.name_to_node
224            .get(name)
225            .and_then(|idx| self.packages.get(idx))
226    }
227
228    /// Returns packages in topological order (dependencies before dependents).
229    ///
230    /// This is cached during graph construction for fast access.
231    #[inline]
232    pub fn topological_order(&self) -> Vec<String> {
233        self.cached_topological_order
234            .iter()
235            .map(|s| s.to_string())
236            .collect()
237    }
238
239    /// Returns dependency levels for parallel execution.
240    ///
241    /// Each level contains packages that can be executed in parallel.
242    #[inline]
243    pub fn dependency_levels(&self) -> Vec<Vec<String>> {
244        self.dependency_levels
245            .iter()
246            .map(|level| level.iter().map(|s| s.to_string()).collect())
247            .collect()
248    }
249
250    /// Returns direct dependencies of a package.
251    ///
252    /// # Errors
253    ///
254    /// Returns an error if the package is not found in the graph.
255    pub fn dependencies(&self, package_name: &str) -> Result<Vec<String>> {
256        let node = self
257            .name_to_node
258            .get(package_name)
259            .ok_or_else(|| Error::PackageNotFound {
260                name: package_name.to_string(),
261                available: format!("Package '{}' not found", package_name),
262            })?;
263
264        let deps: Vec<String> = self
265            .graph
266            .neighbors_directed(*node, Direction::Outgoing)
267            .map(|idx| self.graph[idx].to_string())
268            .collect();
269
270        Ok(deps)
271    }
272
273    /// Returns direct dependents of a package (packages that depend on it).
274    ///
275    /// # Errors
276    ///
277    /// Returns an error if the package is not found in the graph.
278    pub fn dependents(&self, package_name: &str) -> Result<Vec<String>> {
279        let node = self
280            .name_to_node
281            .get(package_name)
282            .ok_or_else(|| Error::PackageNotFound {
283                name: package_name.to_string(),
284                available: format!("Package '{}' not found", package_name),
285            })?;
286
287        let dependents: Vec<String> = self
288            .graph
289            .neighbors_directed(*node, Direction::Incoming)
290            .map(|idx| self.graph[idx].to_string())
291            .collect();
292
293        Ok(dependents)
294    }
295
296    /// Returns all transitive dependents of a package.
297    ///
298    /// This includes both direct and indirect dependents (packages that depend
299    /// on packages that depend on this package, etc.).
300    ///
301    /// # Errors
302    ///
303    /// Returns an error if the package is not found in the graph.
304    pub fn all_dependents(&self, package_name: &str) -> Result<HashSet<String>> {
305        let mut result = HashSet::new();
306        let mut stack = vec![package_name.to_string()];
307
308        while let Some(current) = stack.pop() {
309            if result.contains(&current) {
310                continue;
311            }
312            result.insert(current.clone());
313
314            let direct_dependents = self.dependents(&current)?;
315            for dep in direct_dependents {
316                if !result.contains(&dep) {
317                    stack.push(dep);
318                }
319            }
320        }
321
322        result.remove(package_name);
323        Ok(result)
324    }
325
326    /// Returns all packages affected by changes to the given packages.
327    ///
328    /// This includes the changed packages themselves and all their transitive
329    /// dependents.
330    ///
331    /// # Errors
332    ///
333    /// Returns an error if any of the changed packages are not found in the graph.
334    pub fn affected_packages(&self, changed_packages: &[String]) -> Result<HashSet<String>> {
335        let mut affected = HashSet::new();
336
337        for package_name in changed_packages {
338            affected.insert(package_name.clone());
339            let dependents = self.all_dependents(package_name)?;
340            affected.extend(dependents);
341        }
342
343        Ok(affected)
344    }
345
346    /// Returns all packages in the graph.
347    pub fn all_packages(&self) -> Vec<&Package> {
348        self.packages.values().collect()
349    }
350
351    /// Serializes the graph to a file for fast loading.
352    pub fn save_to_file(&self, path: impl AsRef<Path>) -> Result<()> {
353        let packages: Vec<Package> = self.packages.values().cloned().collect();
354
355        let serializable = SerializableGraph {
356            packages,
357            topological_order: self.cached_topological_order.clone(),
358            dependency_levels: self.dependency_levels.clone(),
359        };
360
361        let serialized = bincode::serialize(&serializable).map_err(|e| Error::Adapter {
362            package: "graph".to_string(),
363            message: format!("Failed to serialize graph: {}", e),
364        })?;
365
366        let compressed = zstd::encode_all(&serialized[..], 3).map_err(|e| Error::Adapter {
367            package: "graph".to_string(),
368            message: format!("Failed to compress graph: {}", e),
369        })?;
370
371        fs::write(path, compressed).map_err(Error::Io)?;
372        Ok(())
373    }
374
375    /// Loads a graph from a previously saved file.
376    pub fn load_from_file(path: impl AsRef<Path>) -> Result<Self> {
377        let compressed = fs::read(path).map_err(Error::Io)?;
378        let serialized = zstd::decode_all(&compressed[..]).map_err(|e| Error::Adapter {
379            package: "graph".to_string(),
380            message: format!("Failed to decompress graph: {}", e),
381        })?;
382
383        let serializable: SerializableGraph =
384            bincode::deserialize(&serialized).map_err(|e| Error::Adapter {
385                package: "graph".to_string(),
386                message: format!("Failed to deserialize graph: {}", e),
387            })?;
388
389        Self::new(serializable.packages)
390    }
391}