Skip to main content

pitchfork_cli/
deps.rs

1use crate::Result;
2use crate::daemon_id::DaemonId;
3use crate::error::{DependencyError, find_similar_daemon};
4use crate::pitchfork_toml::PitchforkTomlDaemon;
5use indexmap::IndexMap;
6use std::collections::{HashMap, HashSet, VecDeque};
7
8/// Result of dependency resolution
9#[derive(Debug)]
10pub struct DependencyOrder {
11    /// Groups of daemons that can be started in parallel.
12    /// Each level depends only on daemons in previous levels.
13    pub levels: Vec<Vec<DaemonId>>,
14}
15
16/// Resolve dependency order using Kahn's algorithm (topological sort).
17///
18/// Returns daemons grouped into levels where:
19/// - Level 0: daemons with no dependencies (or deps already satisfied)
20/// - Level 1: daemons that only depend on level 0
21/// - Level N: daemons that only depend on levels 0..(N-1)
22///
23/// Daemons within the same level can be started in parallel.
24pub fn resolve_dependencies(
25    requested: &[DaemonId],
26    all_daemons: &IndexMap<DaemonId, PitchforkTomlDaemon>,
27) -> Result<DependencyOrder> {
28    // 1. Build the full set of daemons to start (requested + transitive deps)
29    let mut to_start: HashSet<DaemonId> = HashSet::new();
30    let mut queue: VecDeque<DaemonId> = requested.iter().cloned().collect();
31
32    while let Some(id) = queue.pop_front() {
33        if to_start.contains(&id) {
34            continue;
35        }
36
37        let daemon = all_daemons.get(&id).ok_or_else(|| {
38            let suggestion = find_similar_daemon(
39                &id.qualified(),
40                all_daemons
41                    .keys()
42                    .map(|k| k.qualified())
43                    .collect::<Vec<_>>()
44                    .iter()
45                    .map(|s| s.as_str()),
46            );
47            DependencyError::DaemonNotFound {
48                name: id.qualified(),
49                suggestion,
50            }
51        })?;
52
53        to_start.insert(id.clone());
54
55        // Add dependencies to queue
56        for dep in &daemon.depends {
57            if !all_daemons.contains_key(dep) {
58                return Err(DependencyError::MissingDependency {
59                    daemon: id.qualified(),
60                    dependency: dep.qualified(),
61                }
62                .into());
63            }
64            if !to_start.contains(dep) {
65                queue.push_back(dep.clone());
66            }
67        }
68    }
69
70    // 2. Build adjacency list and in-degree map
71    let mut in_degree: HashMap<DaemonId, usize> = HashMap::new();
72    let mut dependents: HashMap<DaemonId, Vec<DaemonId>> = HashMap::new();
73
74    for id in &to_start {
75        in_degree.entry(id.clone()).or_insert(0);
76        dependents.entry(id.clone()).or_default();
77    }
78
79    for id in &to_start {
80        let daemon = all_daemons.get(id).ok_or_else(|| {
81            miette::miette!("Internal error: daemon '{}' missing from configuration", id)
82        })?;
83        for dep in &daemon.depends {
84            if to_start.contains(dep) {
85                *in_degree.get_mut(id).ok_or_else(|| {
86                    miette::miette!("Internal error: in_degree missing for daemon '{}'", id)
87                })? += 1;
88                dependents
89                    .get_mut(dep)
90                    .ok_or_else(|| {
91                        miette::miette!("Internal error: dependents missing for daemon '{}'", dep)
92                    })?
93                    .push(id.clone());
94            }
95        }
96    }
97
98    // 3. Kahn's algorithm with level tracking
99    let mut processed: HashSet<DaemonId> = HashSet::new();
100    let mut levels: Vec<Vec<DaemonId>> = Vec::new();
101    let mut current_level: Vec<DaemonId> = in_degree
102        .iter()
103        .filter(|(_, deg)| **deg == 0)
104        .map(|(id, _)| id.clone())
105        .collect();
106
107    // Sort for deterministic order
108    current_level.sort();
109
110    while !current_level.is_empty() {
111        let mut next_level = Vec::new();
112
113        for id in &current_level {
114            processed.insert(id.clone());
115
116            let deps = dependents.get(id).ok_or_else(|| {
117                miette::miette!("Internal error: dependents missing for daemon '{}'", id)
118            })?;
119            for dependent in deps {
120                let deg = in_degree.get_mut(dependent).ok_or_else(|| {
121                    miette::miette!(
122                        "Internal error: in_degree missing for daemon '{}'",
123                        dependent
124                    )
125                })?;
126                *deg -= 1;
127                if *deg == 0 {
128                    next_level.push(dependent.clone());
129                }
130            }
131        }
132
133        levels.push(current_level);
134        next_level.sort(); // Sort for deterministic order
135        current_level = next_level;
136    }
137
138    // 4. Check for cycles
139    if processed.len() != to_start.len() {
140        let mut involved: Vec<_> = to_start
141            .difference(&processed)
142            .map(|id| id.qualified())
143            .collect();
144        involved.sort(); // Deterministic output
145        return Err(DependencyError::CircularDependency { involved }.into());
146    }
147
148    Ok(DependencyOrder { levels })
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::daemon_id::DaemonId;
    use crate::pitchfork_toml::PitchforkTomlDaemon;
    use indexmap::IndexMap;

    /// Builds a `DaemonId` via `DaemonId::new("global", name)`.
    fn id(name: &str) -> DaemonId {
        DaemonId::new("global", name)
    }

    /// Builds a daemon definition whose `depends` list holds the given names.
    fn make_daemon(depends: Vec<&str>) -> PitchforkTomlDaemon {
        let depends = depends.into_iter().map(id).collect();
        PitchforkTomlDaemon {
            run: "echo test".to_string(),
            port_bump_attempts: 10,
            depends,
            ..PitchforkTomlDaemon::default()
        }
    }

    /// Builds a daemon table from `(name, dependencies)` pairs, keeping the given order.
    fn table_of(entries: Vec<(&str, Vec<&str>)>) -> IndexMap<DaemonId, PitchforkTomlDaemon> {
        entries
            .into_iter()
            .map(|(name, deps)| (id(name), make_daemon(deps)))
            .collect()
    }

    /// Asserts that resolution failed and returns the rendered error message.
    fn failure_text(outcome: crate::Result<DependencyOrder>) -> String {
        assert!(outcome.is_err());
        outcome.unwrap_err().to_string()
    }

    #[test]
    fn test_no_dependencies() {
        let daemons = table_of(vec![("api", vec![])]);

        let order = resolve_dependencies(&[id("api")], &daemons).unwrap();

        assert_eq!(order.levels, vec![vec![id("api")]]);
    }

    #[test]
    fn test_simple_dependency() {
        let daemons = table_of(vec![("postgres", vec![]), ("api", vec!["postgres"])]);

        let order = resolve_dependencies(&[id("api")], &daemons).unwrap();

        assert_eq!(
            order.levels,
            vec![vec![id("postgres")], vec![id("api")]]
        );
    }

    #[test]
    fn test_multiple_dependencies() {
        let daemons = table_of(vec![
            ("postgres", vec![]),
            ("redis", vec![]),
            ("api", vec!["postgres", "redis"]),
        ]);

        let order = resolve_dependencies(&[id("api")], &daemons).unwrap();

        assert_eq!(order.levels.len(), 2);
        // postgres and redis share the first level, so they can start in parallel
        let first = &order.levels[0];
        assert!(first.contains(&id("postgres")));
        assert!(first.contains(&id("redis")));
        assert_eq!(order.levels[1], vec![id("api")]);
    }

    #[test]
    fn test_transitive_dependencies() {
        let daemons = table_of(vec![
            ("database", vec![]),
            ("backend", vec!["database"]),
            ("api", vec!["backend"]),
        ]);

        let order = resolve_dependencies(&[id("api")], &daemons).unwrap();

        assert_eq!(
            order.levels,
            vec![vec![id("database")], vec![id("backend")], vec![id("api")]]
        );
    }

    #[test]
    fn test_diamond_dependency() {
        let daemons = table_of(vec![
            ("db", vec![]),
            ("auth", vec!["db"]),
            ("data", vec!["db"]),
            ("api", vec!["auth", "data"]),
        ]);

        let order = resolve_dependencies(&[id("api")], &daemons).unwrap();

        assert_eq!(order.levels.len(), 3);
        assert_eq!(order.levels[0], vec![id("db")]);
        // auth and data share the middle level, so they can start in parallel
        let middle = &order.levels[1];
        assert!(middle.contains(&id("auth")));
        assert!(middle.contains(&id("data")));
        assert_eq!(order.levels[2], vec![id("api")]);
    }

    #[test]
    fn test_circular_dependency_detected() {
        // a -> c -> b -> a
        let daemons = table_of(vec![
            ("a", vec!["c"]),
            ("b", vec!["a"]),
            ("c", vec!["b"]),
        ]);

        let message = failure_text(resolve_dependencies(&[id("a")], &daemons));

        assert!(message.contains("circular dependency"));
    }

    #[test]
    fn test_missing_dependency_error() {
        let daemons = table_of(vec![("api", vec!["nonexistent"])]);

        let message = failure_text(resolve_dependencies(&[id("api")], &daemons));

        assert!(message.contains("nonexistent"));
        assert!(message.contains("not defined"));
    }

    #[test]
    fn test_missing_requested_daemon_error() {
        let daemons = table_of(vec![]);

        let message = failure_text(resolve_dependencies(&[id("nonexistent")], &daemons));

        assert!(message.contains("nonexistent"));
        assert!(message.contains("not found"));
    }

    #[test]
    fn test_multiple_requested_daemons() {
        let daemons = table_of(vec![
            ("db", vec![]),
            ("api", vec!["db"]),
            ("worker", vec!["db"]),
        ]);

        let order = resolve_dependencies(&[id("api"), id("worker")], &daemons).unwrap();

        assert_eq!(order.levels.len(), 2);
        assert_eq!(order.levels[0], vec![id("db")]);
        // api and worker share the second level, so they can start in parallel
        let second = &order.levels[1];
        assert!(second.contains(&id("api")));
        assert!(second.contains(&id("worker")));
    }

    #[test]
    fn test_start_all_with_dependencies() {
        let daemons = table_of(vec![
            ("db", vec![]),
            ("cache", vec![]),
            ("api", vec!["db", "cache"]),
            ("worker", vec!["db"]),
        ]);

        let all_ids: Vec<DaemonId> = daemons.keys().cloned().collect();
        let order = resolve_dependencies(&all_ids, &daemons).unwrap();

        assert_eq!(order.levels.len(), 2);
        // db and cache have no deps
        let first = &order.levels[0];
        assert!(first.contains(&id("db")));
        assert!(first.contains(&id("cache")));
        // api and worker depend only on the first level
        let second = &order.levels[1];
        assert!(second.contains(&id("api")));
        assert!(second.contains(&id("worker")));
    }
}