1use crate::Result;
2use crate::pitchfork_toml::PitchforkTomlDaemon;
3use indexmap::IndexMap;
4use miette::bail;
5use std::collections::{HashMap, HashSet, VecDeque};
6
/// Start order produced by `resolve_dependencies`.
///
/// The daemons are grouped into levels:
/// - `levels[0]` holds daemons with no dependencies inside the resolved set.
/// - Every later level holds daemons whose dependencies all appear in earlier levels.
///
/// Daemons within one level never depend on each other. Each level is sorted by daemon id.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DependencyOrder {
    /// Groups of daemon ids, in the order the groups should be started.
    pub levels: Vec<Vec<String>>,
}
14
15pub fn resolve_dependencies(
24 requested: &[String],
25 all_daemons: &IndexMap<String, PitchforkTomlDaemon>,
26) -> Result<DependencyOrder> {
27 let mut to_start: HashSet<String> = HashSet::new();
29 let mut queue: VecDeque<String> = requested.iter().cloned().collect();
30
31 while let Some(id) = queue.pop_front() {
32 if to_start.contains(&id) {
33 continue;
34 }
35
36 let daemon = all_daemons
37 .get(&id)
38 .ok_or_else(|| miette::miette!("Daemon '{}' not found in configuration", id))?;
39
40 to_start.insert(id.clone());
41
42 for dep in &daemon.depends {
44 if !all_daemons.contains_key(dep) {
45 bail!(
46 "Daemon '{}' depends on '{}' which is not defined in configuration",
47 id,
48 dep
49 );
50 }
51 if !to_start.contains(dep) {
52 queue.push_back(dep.clone());
53 }
54 }
55 }
56
57 let mut in_degree: HashMap<String, usize> = HashMap::new();
59 let mut dependents: HashMap<String, Vec<String>> = HashMap::new();
60
61 for id in &to_start {
62 in_degree.entry(id.clone()).or_insert(0);
63 dependents.entry(id.clone()).or_default();
64 }
65
66 for id in &to_start {
67 let daemon = all_daemons.get(id).unwrap();
68 for dep in &daemon.depends {
69 if to_start.contains(dep) {
70 *in_degree.get_mut(id).unwrap() += 1;
71 dependents.get_mut(dep).unwrap().push(id.clone());
72 }
73 }
74 }
75
76 let mut processed: HashSet<String> = HashSet::new();
78 let mut levels: Vec<Vec<String>> = Vec::new();
79 let mut current_level: Vec<String> = in_degree
80 .iter()
81 .filter(|(_, deg)| **deg == 0)
82 .map(|(id, _)| id.clone())
83 .collect();
84
85 current_level.sort();
87
88 while !current_level.is_empty() {
89 let mut next_level = Vec::new();
90
91 for id in ¤t_level {
92 processed.insert(id.clone());
93
94 for dependent in &dependents[id] {
95 let deg = in_degree.get_mut(dependent).unwrap();
96 *deg -= 1;
97 if *deg == 0 {
98 next_level.push(dependent.clone());
99 }
100 }
101 }
102
103 levels.push(current_level);
104 next_level.sort(); current_level = next_level;
106 }
107
108 if processed.len() != to_start.len() {
110 let remaining: Vec<_> = to_start.difference(&processed).cloned().collect::<Vec<_>>();
111 bail!("Circular dependency detected involving: {:?}", remaining);
112 }
113
114 Ok(DependencyOrder { levels })
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pitchfork_toml::PitchforkTomlDaemon;
    use indexmap::IndexMap;

    /// Build a minimal daemon definition; only `depends` varies between tests.
    fn make_daemon(depends: Vec<&str>) -> PitchforkTomlDaemon {
        PitchforkTomlDaemon {
            run: "echo test".to_string(),
            auto: vec![],
            cron: None,
            retry: 0,
            ready_delay: None,
            ready_output: None,
            ready_http: None,
            ready_port: None,
            boot_start: None,
            depends: depends.into_iter().map(String::from).collect(),
            path: None,
        }
    }

    #[test]
    fn test_no_dependencies() {
        let mut daemons = IndexMap::new();
        daemons.insert("api".to_string(), make_daemon(vec![]));

        let result = resolve_dependencies(&["api".to_string()], &daemons).unwrap();

        assert_eq!(result.levels.len(), 1);
        assert_eq!(result.levels[0], vec!["api"]);
    }

    #[test]
    fn test_simple_dependency() {
        let mut daemons = IndexMap::new();
        daemons.insert("postgres".to_string(), make_daemon(vec![]));
        daemons.insert("api".to_string(), make_daemon(vec!["postgres"]));

        let result = resolve_dependencies(&["api".to_string()], &daemons).unwrap();

        assert_eq!(result.levels.len(), 2);
        assert_eq!(result.levels[0], vec!["postgres"]);
        assert_eq!(result.levels[1], vec!["api"]);
    }

    #[test]
    fn test_multiple_dependencies() {
        let mut daemons = IndexMap::new();
        daemons.insert("postgres".to_string(), make_daemon(vec![]));
        daemons.insert("redis".to_string(), make_daemon(vec![]));
        daemons.insert("api".to_string(), make_daemon(vec!["postgres", "redis"]));

        let result = resolve_dependencies(&["api".to_string()], &daemons).unwrap();

        assert_eq!(result.levels.len(), 2);
        // Levels are sorted by id, so exact contents (not just membership) can be checked.
        assert_eq!(result.levels[0], vec!["postgres", "redis"]);
        assert_eq!(result.levels[1], vec!["api"]);
    }

    #[test]
    fn test_transitive_dependencies() {
        let mut daemons = IndexMap::new();
        daemons.insert("database".to_string(), make_daemon(vec![]));
        daemons.insert("backend".to_string(), make_daemon(vec!["database"]));
        daemons.insert("api".to_string(), make_daemon(vec!["backend"]));

        let result = resolve_dependencies(&["api".to_string()], &daemons).unwrap();

        assert_eq!(result.levels.len(), 3);
        assert_eq!(result.levels[0], vec!["database"]);
        assert_eq!(result.levels[1], vec!["backend"]);
        assert_eq!(result.levels[2], vec!["api"]);
    }

    #[test]
    fn test_diamond_dependency() {
        let mut daemons = IndexMap::new();
        daemons.insert("db".to_string(), make_daemon(vec![]));
        daemons.insert("auth".to_string(), make_daemon(vec!["db"]));
        daemons.insert("data".to_string(), make_daemon(vec!["db"]));
        daemons.insert("api".to_string(), make_daemon(vec!["auth", "data"]));

        let result = resolve_dependencies(&["api".to_string()], &daemons).unwrap();

        assert_eq!(result.levels.len(), 3);
        assert_eq!(result.levels[0], vec!["db"]);
        assert_eq!(result.levels[1], vec!["auth", "data"]);
        assert_eq!(result.levels[2], vec!["api"]);
    }

    #[test]
    fn test_circular_dependency_detected() {
        let mut daemons = IndexMap::new();
        daemons.insert("a".to_string(), make_daemon(vec!["c"]));
        daemons.insert("b".to_string(), make_daemon(vec!["a"]));
        daemons.insert("c".to_string(), make_daemon(vec!["b"]));

        let result = resolve_dependencies(&["a".to_string()], &daemons);

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Circular dependency"));
    }

    #[test]
    fn test_circular_dependency_reports_members() {
        let mut daemons = IndexMap::new();
        daemons.insert("a".to_string(), make_daemon(vec!["b"]));
        daemons.insert("b".to_string(), make_daemon(vec!["a"]));

        let err = resolve_dependencies(&["a".to_string()], &daemons)
            .unwrap_err()
            .to_string();

        assert!(err.contains("Circular dependency"));
        // Members are printed with `{:?}`, so each id appears quoted.
        assert!(err.contains("\"a\""));
        assert!(err.contains("\"b\""));
    }

    #[test]
    fn test_self_dependency_detected() {
        let mut daemons = IndexMap::new();
        daemons.insert("a".to_string(), make_daemon(vec!["a"]));

        let err = resolve_dependencies(&["a".to_string()], &daemons)
            .unwrap_err()
            .to_string();

        assert!(err.contains("Circular dependency"));
    }

    #[test]
    fn test_missing_dependency_error() {
        let mut daemons = IndexMap::new();
        daemons.insert("api".to_string(), make_daemon(vec!["nonexistent"]));

        let result = resolve_dependencies(&["api".to_string()], &daemons);

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("nonexistent"));
        assert!(err.contains("not defined"));
    }

    #[test]
    fn test_missing_transitive_dependency_error() {
        let mut daemons = IndexMap::new();
        daemons.insert("backend".to_string(), make_daemon(vec!["nonexistent"]));
        daemons.insert("api".to_string(), make_daemon(vec!["backend"]));

        let err = resolve_dependencies(&["api".to_string()], &daemons)
            .unwrap_err()
            .to_string();

        // The error names the daemon that declared the missing dependency.
        assert!(err.contains("backend"));
        assert!(err.contains("nonexistent"));
        assert!(err.contains("not defined"));
    }

    #[test]
    fn test_missing_requested_daemon_error() {
        let daemons = IndexMap::new();

        let result = resolve_dependencies(&["nonexistent".to_string()], &daemons);

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("nonexistent"));
        assert!(err.contains("not found"));
    }

    #[test]
    fn test_empty_request() {
        let mut daemons = IndexMap::new();
        daemons.insert("api".to_string(), make_daemon(vec![]));

        let result = resolve_dependencies(&[], &daemons).unwrap();

        assert!(result.levels.is_empty());
    }

    #[test]
    fn test_duplicate_requested_daemon_started_once() {
        let mut daemons = IndexMap::new();
        daemons.insert("db".to_string(), make_daemon(vec![]));
        daemons.insert("api".to_string(), make_daemon(vec!["db"]));

        let requested = ["api".to_string(), "api".to_string(), "db".to_string()];
        let result = resolve_dependencies(&requested, &daemons).unwrap();

        assert_eq!(result.levels, vec![vec!["db"], vec!["api"]]);
    }

    #[test]
    fn test_unrequested_daemons_are_not_included() {
        let mut daemons = IndexMap::new();
        daemons.insert("db".to_string(), make_daemon(vec![]));
        daemons.insert("api".to_string(), make_daemon(vec!["db"]));
        daemons.insert("unrelated".to_string(), make_daemon(vec![]));

        let result = resolve_dependencies(&["api".to_string()], &daemons).unwrap();

        assert_eq!(result.levels, vec![vec!["db"], vec!["api"]]);
    }

    #[test]
    fn test_multiple_requested_daemons() {
        let mut daemons = IndexMap::new();
        daemons.insert("db".to_string(), make_daemon(vec![]));
        daemons.insert("api".to_string(), make_daemon(vec!["db"]));
        daemons.insert("worker".to_string(), make_daemon(vec!["db"]));

        let result =
            resolve_dependencies(&["api".to_string(), "worker".to_string()], &daemons).unwrap();

        assert_eq!(result.levels.len(), 2);
        assert_eq!(result.levels[0], vec!["db"]);
        assert_eq!(result.levels[1], vec!["api", "worker"]);
    }

    #[test]
    fn test_start_all_with_dependencies() {
        let mut daemons = IndexMap::new();
        daemons.insert("db".to_string(), make_daemon(vec![]));
        daemons.insert("cache".to_string(), make_daemon(vec![]));
        daemons.insert("api".to_string(), make_daemon(vec!["db", "cache"]));
        daemons.insert("worker".to_string(), make_daemon(vec!["db"]));

        let all_ids: Vec<String> = daemons.keys().cloned().collect();
        let result = resolve_dependencies(&all_ids, &daemons).unwrap();

        assert_eq!(result.levels.len(), 2);
        assert_eq!(result.levels[0], vec!["cache", "db"]);
        assert_eq!(result.levels[1], vec!["api", "worker"]);
    }
}