1use super::DependencyGraph;
2use super::vertex::VertexId;
3use formualizer_common::ExcelError;
4use rustc_hash::{FxHashMap, FxHashSet};
5
6pub struct Scheduler<'a> {
7 graph: &'a DependencyGraph,
8}
9
10#[derive(Debug)]
11pub struct Layer {
12 pub vertices: Vec<VertexId>,
13}
14
15#[derive(Debug)]
16pub struct Schedule {
17 pub layers: Vec<Layer>,
18 pub cycles: Vec<Vec<VertexId>>,
19}
20
21impl<'a> Scheduler<'a> {
22 pub fn new(graph: &'a DependencyGraph) -> Self {
23 Self { graph }
24 }
25
26 pub fn create_schedule(&self, vertices: &[VertexId]) -> Result<Schedule, ExcelError> {
27 #[cfg(feature = "tracing")]
28 let _span = tracing::info_span!("scheduler", vertices = vertices.len()).entered();
29 #[cfg(feature = "tracing")]
31 let _scc_span = tracing::info_span!("tarjan_scc").entered();
32 let sccs = self.tarjan_scc(vertices)?;
33 #[cfg(feature = "tracing")]
34 drop(_scc_span);
35
36 let (cycles, acyclic_sccs) = self.separate_cycles(sccs);
38
39 let layers = if self.graph.dynamic_topo_enabled() {
41 let subset: Vec<VertexId> = acyclic_sccs.into_iter().flatten().collect();
42 if subset.is_empty() {
43 Vec::new()
44 } else {
45 self.graph
46 .pk_layers_for(&subset)
47 .unwrap_or(self.build_layers(vec![subset])?)
48 }
49 } else {
50 self.build_layers(acyclic_sccs)?
51 };
52
53 Ok(Schedule { layers, cycles })
54 }
55
56 pub fn create_schedule_with_virtual(
59 &self,
60 vertices: &[VertexId],
61 vdeps: &FxHashMap<VertexId, Vec<VertexId>>,
62 ) -> Result<Schedule, ExcelError> {
63 #[cfg(feature = "tracing")]
64 let _span = tracing::info_span!(
65 "scheduler_with_virtual",
66 vertices = vertices.len(),
67 vdeps = vdeps.len()
68 )
69 .entered();
70 #[cfg(feature = "tracing")]
72 let _scc_span = tracing::info_span!("tarjan_scc_with_virtual").entered();
73 let sccs = self.tarjan_scc_with_virtual(vertices, vdeps)?;
74 #[cfg(feature = "tracing")]
75 drop(_scc_span);
76 let (cycles, acyclic_sccs) = self.separate_cycles(sccs);
78 #[cfg(feature = "tracing")]
80 let _layers_span = tracing::info_span!("build_layers_with_virtual").entered();
81 let layers = self.build_layers_with_virtual(acyclic_sccs, vdeps)?;
82 #[cfg(feature = "tracing")]
83 drop(_layers_span);
84 Ok(Schedule { layers, cycles })
85 }
86
87 pub fn tarjan_scc(&self, vertices: &[VertexId]) -> Result<Vec<Vec<VertexId>>, ExcelError> {
89 let mut index_counter = 0;
90 let mut stack = Vec::new();
91 let mut indices = FxHashMap::default();
92 let mut lowlinks = FxHashMap::default();
93 let mut on_stack = FxHashSet::default();
94 let mut sccs = Vec::new();
95 let vertex_set: FxHashSet<VertexId> = vertices.iter().copied().collect();
96
97 for &vertex in vertices {
98 if !indices.contains_key(&vertex) {
99 self.tarjan_visit(
100 vertex,
101 &mut index_counter,
102 &mut stack,
103 &mut indices,
104 &mut lowlinks,
105 &mut on_stack,
106 &mut sccs,
107 &vertex_set,
108 )?;
109 }
110 }
111
112 Ok(sccs)
113 }
114
115 fn tarjan_scc_with_virtual(
117 &self,
118 vertices: &[VertexId],
119 vdeps: &FxHashMap<VertexId, Vec<VertexId>>,
120 ) -> Result<Vec<Vec<VertexId>>, ExcelError> {
121 let mut index_counter = 0;
122 let mut stack = Vec::new();
123 let mut indices = FxHashMap::default();
124 let mut lowlinks = FxHashMap::default();
125 let mut on_stack = FxHashSet::default();
126 let mut sccs = Vec::new();
127 let vertex_set: FxHashSet<VertexId> = vertices.iter().copied().collect();
128
129 for &vertex in vertices {
130 if !indices.contains_key(&vertex) {
131 self.tarjan_visit_with_virtual(
132 vertex,
133 &mut index_counter,
134 &mut stack,
135 &mut indices,
136 &mut lowlinks,
137 &mut on_stack,
138 &mut sccs,
139 &vertex_set,
140 vdeps,
141 )?;
142 }
143 }
144
145 Ok(sccs)
146 }
147
148 fn tarjan_visit(
149 &self,
150 vertex: VertexId,
151 index_counter: &mut usize,
152 stack: &mut Vec<VertexId>,
153 indices: &mut FxHashMap<VertexId, usize>,
154 lowlinks: &mut FxHashMap<VertexId, usize>,
155 on_stack: &mut FxHashSet<VertexId>,
156 sccs: &mut Vec<Vec<VertexId>>,
157 vertex_set: &FxHashSet<VertexId>,
158 ) -> Result<(), ExcelError> {
159 indices.insert(vertex, *index_counter);
161 lowlinks.insert(vertex, *index_counter);
162 *index_counter += 1;
163 stack.push(vertex);
164 on_stack.insert(vertex);
165
166 let dependencies = self.graph.get_dependencies(vertex);
168 for &dependency in &dependencies {
169 if !vertex_set.contains(&dependency) {
171 continue;
172 }
173
174 if !indices.contains_key(&dependency) {
175 self.tarjan_visit(
177 dependency,
178 index_counter,
179 stack,
180 indices,
181 lowlinks,
182 on_stack,
183 sccs,
184 vertex_set,
185 )?;
186 let dep_lowlink = lowlinks[&dependency];
187 lowlinks.insert(vertex, lowlinks[&vertex].min(dep_lowlink));
188 } else if on_stack.contains(&dependency) {
189 let dep_index = indices[&dependency];
191 lowlinks.insert(vertex, lowlinks[&vertex].min(dep_index));
192 }
193 }
194
195 if lowlinks[&vertex] == indices[&vertex] {
197 let mut scc = Vec::new();
198 loop {
199 let w = stack.pop().unwrap();
200 on_stack.remove(&w);
201 scc.push(w);
202 if w == vertex {
203 break;
204 }
205 }
206 sccs.push(scc);
207 }
208
209 Ok(())
210 }
211
212 fn tarjan_visit_with_virtual(
213 &self,
214 vertex: VertexId,
215 index_counter: &mut usize,
216 stack: &mut Vec<VertexId>,
217 indices: &mut FxHashMap<VertexId, usize>,
218 lowlinks: &mut FxHashMap<VertexId, usize>,
219 on_stack: &mut FxHashSet<VertexId>,
220 sccs: &mut Vec<Vec<VertexId>>,
221 vertex_set: &FxHashSet<VertexId>,
222 vdeps: &FxHashMap<VertexId, Vec<VertexId>>,
223 ) -> Result<(), ExcelError> {
224 indices.insert(vertex, *index_counter);
226 lowlinks.insert(vertex, *index_counter);
227 *index_counter += 1;
228 stack.push(vertex);
229 on_stack.insert(vertex);
230
231 let mut dependencies = self.graph.get_dependencies(vertex).to_vec();
233 if let Some(extra) = vdeps.get(&vertex) {
234 dependencies.extend(extra.iter().copied());
235 }
236 for dependency in dependencies.into_iter() {
237 if !vertex_set.contains(&dependency) {
239 continue;
240 }
241
242 if !indices.contains_key(&dependency) {
243 self.tarjan_visit_with_virtual(
245 dependency,
246 index_counter,
247 stack,
248 indices,
249 lowlinks,
250 on_stack,
251 sccs,
252 vertex_set,
253 vdeps,
254 )?;
255 let dep_lowlink = lowlinks[&dependency];
256 lowlinks.insert(vertex, lowlinks[&vertex].min(dep_lowlink));
257 } else if on_stack.contains(&dependency) {
258 let dep_index = indices[&dependency];
260 lowlinks.insert(vertex, lowlinks[&vertex].min(dep_index));
261 }
262 }
263
264 if lowlinks[&vertex] == indices[&vertex] {
266 let mut scc = Vec::new();
267 loop {
268 let w = stack.pop().unwrap();
269 on_stack.remove(&w);
270 scc.push(w);
271 if w == vertex {
272 break;
273 }
274 }
275 sccs.push(scc);
276 }
277
278 Ok(())
279 }
280
281 pub(crate) fn separate_cycles(
282 &self,
283 sccs: Vec<Vec<VertexId>>,
284 ) -> (Vec<Vec<VertexId>>, Vec<Vec<VertexId>>) {
285 let mut cycles = Vec::new();
286 let mut acyclic = Vec::new();
287
288 for scc in sccs {
289 if scc.len() > 1 || (scc.len() == 1 && self.has_self_loop(scc[0])) {
290 cycles.push(scc);
291 } else {
292 acyclic.push(scc);
293 }
294 }
295
296 (cycles, acyclic)
297 }
298
299 fn has_self_loop(&self, vertex: VertexId) -> bool {
300 self.graph.has_self_loop(vertex)
301 }
302
303 pub(crate) fn build_layers(
304 &self,
305 acyclic_sccs: Vec<Vec<VertexId>>,
306 ) -> Result<Vec<Layer>, ExcelError> {
307 let vertices: Vec<VertexId> = acyclic_sccs.into_iter().flatten().collect();
308 if vertices.is_empty() {
309 return Ok(Vec::new());
310 }
311 let vertex_set: FxHashSet<VertexId> = vertices.iter().copied().collect();
312
313 let mut in_degrees: FxHashMap<VertexId, usize> = vertices.iter().map(|&v| (v, 0)).collect();
315 for &vertex_id in &vertices {
316 let dependencies = self.graph.get_dependencies(vertex_id);
317 for &dep_id in &dependencies {
318 if vertex_set.contains(&dep_id) {
319 if let Some(in_degree) = in_degrees.get_mut(&vertex_id) {
320 *in_degree += 1;
321 }
322 }
323 }
324 }
325
326 let mut queue: std::collections::VecDeque<VertexId> = in_degrees
328 .iter()
329 .filter(|&(_, &in_degree)| in_degree == 0)
330 .map(|(&v, _)| v)
331 .collect();
332
333 let mut layers = Vec::new();
334 let mut processed_count = 0;
335
336 while !queue.is_empty() {
337 let mut current_layer_vertices = Vec::new();
338 for _ in 0..queue.len() {
339 let u = queue.pop_front().unwrap();
340 current_layer_vertices.push(u);
341 processed_count += 1;
342
343 for v_dep in self.graph.get_dependents(u) {
345 if let Some(in_degree) = in_degrees.get_mut(&v_dep) {
346 *in_degree -= 1;
347 if *in_degree == 0 {
348 queue.push_back(v_dep);
349 }
350 }
351 }
352 }
353 current_layer_vertices.sort();
355 layers.push(Layer {
356 vertices: current_layer_vertices,
357 });
358 }
359
360 if processed_count != vertices.len() {
361 return Err(
362 ExcelError::new(formualizer_common::ExcelErrorKind::Circ).with_message(
363 "Unexpected cycle detected in acyclic components during layer construction"
364 .to_string(),
365 ),
366 );
367 }
368
369 Ok(layers)
370 }
371
372 pub(crate) fn build_layers_with_virtual(
373 &self,
374 acyclic_sccs: Vec<Vec<VertexId>>,
375 vdeps: &FxHashMap<VertexId, Vec<VertexId>>,
376 ) -> Result<Vec<Layer>, ExcelError> {
377 use std::collections::VecDeque;
378 let vertices: Vec<VertexId> = acyclic_sccs.into_iter().flatten().collect();
379 if vertices.is_empty() {
380 return Ok(Vec::new());
381 }
382 let vertex_set: FxHashSet<VertexId> = vertices.iter().copied().collect();
383
384 let mut combined_deps: FxHashMap<VertexId, Vec<VertexId>> = FxHashMap::default();
386 let mut combined_out: FxHashMap<VertexId, Vec<VertexId>> = FxHashMap::default();
387 for &v in &vertices {
388 let mut deps: Vec<VertexId> = self
389 .graph
390 .get_dependencies(v)
391 .iter()
392 .copied()
393 .filter(|d| vertex_set.contains(d))
394 .collect();
395 if let Some(extra) = vdeps.get(&v) {
396 deps.extend(extra.iter().copied().filter(|d| vertex_set.contains(d)));
397 }
398 deps.sort_unstable();
399 deps.dedup();
400 combined_deps.insert(v, deps);
401 }
402 for (&v, deps) in combined_deps.iter() {
404 for &d in deps {
405 combined_out.entry(d).or_default().push(v);
406 }
407 }
408 let mut in_degrees: FxHashMap<VertexId, usize> = FxHashMap::default();
410 for &v in &vertices {
411 let indeg = combined_deps.get(&v).map(|v| v.len()).unwrap_or(0);
412 in_degrees.insert(v, indeg);
413 }
414 let mut queue: VecDeque<VertexId> = in_degrees
416 .iter()
417 .filter(|&(_, °)| deg == 0)
418 .map(|(&v, _)| v)
419 .collect();
420
421 let mut layers = Vec::new();
422 let mut processed_count = 0;
423 while !queue.is_empty() {
424 let mut cur = Vec::new();
425 for _ in 0..queue.len() {
426 let u = queue.pop_front().unwrap();
427 cur.push(u);
428 processed_count += 1;
429 if let Some(dependents) = combined_out.get(&u) {
430 for &w in dependents {
431 if let Some(ind) = in_degrees.get_mut(&w) {
432 *ind = ind.saturating_sub(1);
433 if *ind == 0 {
434 queue.push_back(w);
435 }
436 }
437 }
438 }
439 }
440 cur.sort_unstable();
441 layers.push(Layer { vertices: cur });
442 }
443 if processed_count != vertices.len() {
444 return Err(
445 ExcelError::new(formualizer_common::ExcelErrorKind::Circ).with_message(
446 "Unexpected cycle detected in acyclic components during layer construction (virtual)"
447 .to_string(),
448 ),
449 );
450 }
451 Ok(layers)
452 }
453}