Skip to main content

formualizer_eval/engine/
scheduler.rs

1use super::DependencyGraph;
2use super::vertex::VertexId;
3use formualizer_common::ExcelError;
4use rustc_hash::{FxHashMap, FxHashSet};
5
/// Builds evaluation schedules by querying a borrowed [`DependencyGraph`].
///
/// The scheduler stores only a shared reference to the graph, so it is cheap to
/// create and lives no longer than the graph it was built from.
pub struct Scheduler<'a> {
    /// Graph consulted for dependency and dependent edges during scheduling.
    graph: &'a DependencyGraph,
}
9
/// One layer of a [`Schedule`]: a group of vertices emitted together by the
/// layer-building pass.
///
/// The layer builders in this file place a vertex in a layer only after every
/// dependency it has inside the scheduled set was placed in an earlier layer.
#[derive(Debug, Clone)]
pub struct Layer {
    /// Vertices belonging to this layer. The layer builders in this file sort them
    /// ascending before storing them.
    pub vertices: Vec<VertexId>,
}
14
/// Result of scheduling a set of vertices.
#[derive(Debug, Clone)]
pub struct Schedule {
    /// Layers of acyclic vertices, in the order produced by the layer builder.
    pub layers: Vec<Layer>,
    /// Strongly connected components classified as cycles: components with more than
    /// one vertex, or a single vertex for which the graph reports a self-loop.
    pub cycles: Vec<Vec<VertexId>>,
}
20
21impl<'a> Scheduler<'a> {
22    pub fn new(graph: &'a DependencyGraph) -> Self {
23        Self { graph }
24    }
25
26    pub fn create_schedule(&self, vertices: &[VertexId]) -> Result<Schedule, ExcelError> {
27        #[cfg(feature = "tracing")]
28        let _span = tracing::info_span!("scheduler", vertices = vertices.len()).entered();
29        // 1. Find strongly connected components using Tarjan's algorithm
30        #[cfg(feature = "tracing")]
31        let _scc_span = tracing::info_span!("tarjan_scc").entered();
32        let sccs = self.tarjan_scc(vertices)?;
33        #[cfg(feature = "tracing")]
34        drop(_scc_span);
35
36        // 2. Separate cyclic from acyclic components
37        let (cycles, acyclic_sccs) = self.separate_cycles(sccs);
38
39        // 3. Topologically sort acyclic components into layers
40        let layers = if self.graph.dynamic_topo_enabled() {
41            let subset: Vec<VertexId> = acyclic_sccs.into_iter().flatten().collect();
42            if subset.is_empty() {
43                Vec::new()
44            } else {
45                self.graph
46                    .pk_layers_for(&subset)
47                    .unwrap_or(self.build_layers(vec![subset])?)
48            }
49        } else {
50            self.build_layers(acyclic_sccs)?
51        };
52
53        Ok(Schedule { layers, cycles })
54    }
55
56    /// Create a schedule considering additional ephemeral (virtual) dependencies just for this pass.
57    /// `vdeps` maps a vertex to extra dependency vertices that should be considered as incoming edges.
58    pub fn create_schedule_with_virtual(
59        &self,
60        vertices: &[VertexId],
61        vdeps: &FxHashMap<VertexId, Vec<VertexId>>,
62    ) -> Result<Schedule, ExcelError> {
63        #[cfg(feature = "tracing")]
64        let _span = tracing::info_span!(
65            "scheduler_with_virtual",
66            vertices = vertices.len(),
67            vdeps = vdeps.len()
68        )
69        .entered();
70        // 1. SCC detection with virtual deps
71        #[cfg(feature = "tracing")]
72        let _scc_span = tracing::info_span!("tarjan_scc_with_virtual").entered();
73        let sccs = self.tarjan_scc_with_virtual(vertices, vdeps)?;
74        #[cfg(feature = "tracing")]
75        drop(_scc_span);
76        // 2. Separate cycles and acyclic components
77        let (cycles, acyclic_sccs) = self.separate_cycles(sccs);
78        // 3. Build layers over combined adjacency (graph + vdeps)
79        #[cfg(feature = "tracing")]
80        let _layers_span = tracing::info_span!("build_layers_with_virtual").entered();
81        let layers = self.build_layers_with_virtual(acyclic_sccs, vdeps)?;
82        #[cfg(feature = "tracing")]
83        drop(_layers_span);
84        Ok(Schedule { layers, cycles })
85    }
86
87    /// Tarjan's strongly connected components algorithm
88    pub fn tarjan_scc(&self, vertices: &[VertexId]) -> Result<Vec<Vec<VertexId>>, ExcelError> {
89        let mut index_counter = 0;
90        let mut stack = Vec::new();
91        let mut indices = FxHashMap::default();
92        let mut lowlinks = FxHashMap::default();
93        let mut on_stack = FxHashSet::default();
94        let mut sccs = Vec::new();
95        let vertex_set: FxHashSet<VertexId> = vertices.iter().copied().collect();
96
97        for &vertex in vertices {
98            if !indices.contains_key(&vertex) {
99                self.tarjan_visit(
100                    vertex,
101                    &mut index_counter,
102                    &mut stack,
103                    &mut indices,
104                    &mut lowlinks,
105                    &mut on_stack,
106                    &mut sccs,
107                    &vertex_set,
108                )?;
109            }
110        }
111
112        Ok(sccs)
113    }
114
115    /// Tarjan with virtual deps
116    fn tarjan_scc_with_virtual(
117        &self,
118        vertices: &[VertexId],
119        vdeps: &FxHashMap<VertexId, Vec<VertexId>>,
120    ) -> Result<Vec<Vec<VertexId>>, ExcelError> {
121        let mut index_counter = 0;
122        let mut stack = Vec::new();
123        let mut indices = FxHashMap::default();
124        let mut lowlinks = FxHashMap::default();
125        let mut on_stack = FxHashSet::default();
126        let mut sccs = Vec::new();
127        let vertex_set: FxHashSet<VertexId> = vertices.iter().copied().collect();
128
129        for &vertex in vertices {
130            if !indices.contains_key(&vertex) {
131                self.tarjan_visit_with_virtual(
132                    vertex,
133                    &mut index_counter,
134                    &mut stack,
135                    &mut indices,
136                    &mut lowlinks,
137                    &mut on_stack,
138                    &mut sccs,
139                    &vertex_set,
140                    vdeps,
141                )?;
142            }
143        }
144
145        Ok(sccs)
146    }
147
148    fn tarjan_visit(
149        &self,
150        vertex: VertexId,
151        index_counter: &mut usize,
152        stack: &mut Vec<VertexId>,
153        indices: &mut FxHashMap<VertexId, usize>,
154        lowlinks: &mut FxHashMap<VertexId, usize>,
155        on_stack: &mut FxHashSet<VertexId>,
156        sccs: &mut Vec<Vec<VertexId>>,
157        vertex_set: &FxHashSet<VertexId>,
158    ) -> Result<(), ExcelError> {
159        // Set the depth index for vertex to the smallest unused index
160        indices.insert(vertex, *index_counter);
161        lowlinks.insert(vertex, *index_counter);
162        *index_counter += 1;
163        stack.push(vertex);
164        on_stack.insert(vertex);
165
166        // Consider successors of vertex (dependencies)
167        if let Some(dependencies) = self.graph.dependencies_slice(vertex) {
168            for &dependency in dependencies {
169                // Only consider dependencies that are part of the current scheduling task
170                if !vertex_set.contains(&dependency) {
171                    continue;
172                }
173
174                if !indices.contains_key(&dependency) {
175                    // Successor dependency has not yet been visited; recurse on it
176                    self.tarjan_visit(
177                        dependency,
178                        index_counter,
179                        stack,
180                        indices,
181                        lowlinks,
182                        on_stack,
183                        sccs,
184                        vertex_set,
185                    )?;
186                    let dep_lowlink = lowlinks[&dependency];
187                    lowlinks.insert(vertex, lowlinks[&vertex].min(dep_lowlink));
188                } else if on_stack.contains(&dependency) {
189                    // Successor dependency is in stack and hence in the current SCC
190                    let dep_index = indices[&dependency];
191                    lowlinks.insert(vertex, lowlinks[&vertex].min(dep_index));
192                }
193            }
194        } else {
195            let dependencies = self.graph.get_dependencies(vertex);
196            for dependency in dependencies {
197                // Only consider dependencies that are part of the current scheduling task
198                if !vertex_set.contains(&dependency) {
199                    continue;
200                }
201
202                if !indices.contains_key(&dependency) {
203                    // Successor dependency has not yet been visited; recurse on it
204                    self.tarjan_visit(
205                        dependency,
206                        index_counter,
207                        stack,
208                        indices,
209                        lowlinks,
210                        on_stack,
211                        sccs,
212                        vertex_set,
213                    )?;
214                    let dep_lowlink = lowlinks[&dependency];
215                    lowlinks.insert(vertex, lowlinks[&vertex].min(dep_lowlink));
216                } else if on_stack.contains(&dependency) {
217                    // Successor dependency is in stack and hence in the current SCC
218                    let dep_index = indices[&dependency];
219                    lowlinks.insert(vertex, lowlinks[&vertex].min(dep_index));
220                }
221            }
222        }
223
224        // If vertex is a root node, pop the stack and print an SCC
225        if lowlinks[&vertex] == indices[&vertex] {
226            let mut scc = Vec::new();
227            loop {
228                let w = stack.pop().unwrap();
229                on_stack.remove(&w);
230                scc.push(w);
231                if w == vertex {
232                    break;
233                }
234            }
235            sccs.push(scc);
236        }
237
238        Ok(())
239    }
240
241    fn tarjan_visit_with_virtual(
242        &self,
243        vertex: VertexId,
244        index_counter: &mut usize,
245        stack: &mut Vec<VertexId>,
246        indices: &mut FxHashMap<VertexId, usize>,
247        lowlinks: &mut FxHashMap<VertexId, usize>,
248        on_stack: &mut FxHashSet<VertexId>,
249        sccs: &mut Vec<Vec<VertexId>>,
250        vertex_set: &FxHashSet<VertexId>,
251        vdeps: &FxHashMap<VertexId, Vec<VertexId>>,
252    ) -> Result<(), ExcelError> {
253        // Set the depth index for vertex to the smallest unused index
254        indices.insert(vertex, *index_counter);
255        lowlinks.insert(vertex, *index_counter);
256        *index_counter += 1;
257        stack.push(vertex);
258        on_stack.insert(vertex);
259
260        // Consider successors of vertex (dependencies) including virtual deps
261        if let Some(extra) = vdeps.get(&vertex) {
262            let mut dependencies: Vec<VertexId> =
263                if let Some(base) = self.graph.dependencies_slice(vertex) {
264                    base.to_vec()
265                } else {
266                    self.graph.get_dependencies(vertex)
267                };
268            dependencies.extend(extra.iter().copied());
269
270            for dependency in dependencies {
271                // Only consider dependencies that are part of the current scheduling task
272                if !vertex_set.contains(&dependency) {
273                    continue;
274                }
275
276                if !indices.contains_key(&dependency) {
277                    // Successor dependency has not yet been visited; recurse on it
278                    self.tarjan_visit_with_virtual(
279                        dependency,
280                        index_counter,
281                        stack,
282                        indices,
283                        lowlinks,
284                        on_stack,
285                        sccs,
286                        vertex_set,
287                        vdeps,
288                    )?;
289                    let dep_lowlink = lowlinks[&dependency];
290                    lowlinks.insert(vertex, lowlinks[&vertex].min(dep_lowlink));
291                } else if on_stack.contains(&dependency) {
292                    // Successor dependency is in stack and hence in the current SCC
293                    let dep_index = indices[&dependency];
294                    lowlinks.insert(vertex, lowlinks[&vertex].min(dep_index));
295                }
296            }
297        } else if let Some(dependencies) = self.graph.dependencies_slice(vertex) {
298            for &dependency in dependencies {
299                // Only consider dependencies that are part of the current scheduling task
300                if !vertex_set.contains(&dependency) {
301                    continue;
302                }
303
304                if !indices.contains_key(&dependency) {
305                    // Successor dependency has not yet been visited; recurse on it
306                    self.tarjan_visit_with_virtual(
307                        dependency,
308                        index_counter,
309                        stack,
310                        indices,
311                        lowlinks,
312                        on_stack,
313                        sccs,
314                        vertex_set,
315                        vdeps,
316                    )?;
317                    let dep_lowlink = lowlinks[&dependency];
318                    lowlinks.insert(vertex, lowlinks[&vertex].min(dep_lowlink));
319                } else if on_stack.contains(&dependency) {
320                    // Successor dependency is in stack and hence in the current SCC
321                    let dep_index = indices[&dependency];
322                    lowlinks.insert(vertex, lowlinks[&vertex].min(dep_index));
323                }
324            }
325        } else {
326            let dependencies = self.graph.get_dependencies(vertex);
327            for dependency in dependencies {
328                // Only consider dependencies that are part of the current scheduling task
329                if !vertex_set.contains(&dependency) {
330                    continue;
331                }
332
333                if !indices.contains_key(&dependency) {
334                    // Successor dependency has not yet been visited; recurse on it
335                    self.tarjan_visit_with_virtual(
336                        dependency,
337                        index_counter,
338                        stack,
339                        indices,
340                        lowlinks,
341                        on_stack,
342                        sccs,
343                        vertex_set,
344                        vdeps,
345                    )?;
346                    let dep_lowlink = lowlinks[&dependency];
347                    lowlinks.insert(vertex, lowlinks[&vertex].min(dep_lowlink));
348                } else if on_stack.contains(&dependency) {
349                    // Successor dependency is in stack and hence in the current SCC
350                    let dep_index = indices[&dependency];
351                    lowlinks.insert(vertex, lowlinks[&vertex].min(dep_index));
352                }
353            }
354        }
355
356        // If vertex is a root node, pop the stack and produce an SCC
357        if lowlinks[&vertex] == indices[&vertex] {
358            let mut scc = Vec::new();
359            loop {
360                let w = stack.pop().unwrap();
361                on_stack.remove(&w);
362                scc.push(w);
363                if w == vertex {
364                    break;
365                }
366            }
367            sccs.push(scc);
368        }
369
370        Ok(())
371    }
372
373    pub(crate) fn separate_cycles(
374        &self,
375        sccs: Vec<Vec<VertexId>>,
376    ) -> (Vec<Vec<VertexId>>, Vec<Vec<VertexId>>) {
377        let mut cycles = Vec::new();
378        let mut acyclic = Vec::new();
379
380        for scc in sccs {
381            if scc.len() > 1 || (scc.len() == 1 && self.has_self_loop(scc[0])) {
382                cycles.push(scc);
383            } else {
384                acyclic.push(scc);
385            }
386        }
387
388        (cycles, acyclic)
389    }
390
    /// Returns whether the graph reports a self-loop on `vertex`.
    ///
    /// Thin delegate to the graph's own `has_self_loop`.
    fn has_self_loop(&self, vertex: VertexId) -> bool {
        self.graph.has_self_loop(vertex)
    }
394
395    pub(crate) fn build_layers(
396        &self,
397        acyclic_sccs: Vec<Vec<VertexId>>,
398    ) -> Result<Vec<Layer>, ExcelError> {
399        let vertices: Vec<VertexId> = acyclic_sccs.into_iter().flatten().collect();
400        if vertices.is_empty() {
401            return Ok(Vec::new());
402        }
403        let vertex_set: FxHashSet<VertexId> = vertices.iter().copied().collect();
404
405        // Calculate in-degrees for all vertices in the acyclic subgraph
406        let mut in_degrees: FxHashMap<VertexId, usize> = vertices.iter().map(|&v| (v, 0)).collect();
407        for &vertex_id in &vertices {
408            if let Some(dependencies) = self.graph.dependencies_slice(vertex_id) {
409                for &dep_id in dependencies {
410                    if vertex_set.contains(&dep_id)
411                        && let Some(in_degree) = in_degrees.get_mut(&vertex_id)
412                    {
413                        *in_degree += 1;
414                    }
415                }
416            } else {
417                let dependencies = self.graph.get_dependencies(vertex_id);
418                for dep_id in dependencies {
419                    if vertex_set.contains(&dep_id)
420                        && let Some(in_degree) = in_degrees.get_mut(&vertex_id)
421                    {
422                        *in_degree += 1;
423                    }
424                }
425            }
426        }
427
428        // Initialize the queue with all nodes having an in-degree of 0
429        let mut queue: std::collections::VecDeque<VertexId> = in_degrees
430            .iter()
431            .filter(|&(_, &in_degree)| in_degree == 0)
432            .map(|(&v, _)| v)
433            .collect();
434
435        let mut layers = Vec::new();
436        let mut processed_count = 0;
437
438        while !queue.is_empty() {
439            let mut current_layer_vertices = Vec::new();
440            for _ in 0..queue.len() {
441                let u = queue.pop_front().unwrap();
442                current_layer_vertices.push(u);
443                processed_count += 1;
444
445                // For each dependent of u, reduce its in-degree
446                if let Some(dependents) = self.graph.dependents_slice(u) {
447                    for &v_dep in dependents {
448                        if let Some(in_degree) = in_degrees.get_mut(&v_dep) {
449                            *in_degree -= 1;
450                            if *in_degree == 0 {
451                                queue.push_back(v_dep);
452                            }
453                        }
454                    }
455                } else {
456                    for v_dep in self.graph.get_dependents(u) {
457                        if let Some(in_degree) = in_degrees.get_mut(&v_dep) {
458                            *in_degree -= 1;
459                            if *in_degree == 0 {
460                                queue.push_back(v_dep);
461                            }
462                        }
463                    }
464                }
465            }
466            // Sort for deterministic output in tests
467            current_layer_vertices.sort();
468            layers.push(Layer {
469                vertices: current_layer_vertices,
470            });
471        }
472
473        if processed_count != vertices.len() {
474            return Err(
475                ExcelError::new(formualizer_common::ExcelErrorKind::Circ).with_message(
476                    "Unexpected cycle detected in acyclic components during layer construction"
477                        .to_string(),
478                ),
479            );
480        }
481
482        Ok(layers)
483    }
484
485    pub(crate) fn build_layers_with_virtual(
486        &self,
487        acyclic_sccs: Vec<Vec<VertexId>>,
488        vdeps: &FxHashMap<VertexId, Vec<VertexId>>,
489    ) -> Result<Vec<Layer>, ExcelError> {
490        use std::collections::VecDeque;
491        let vertices: Vec<VertexId> = acyclic_sccs.into_iter().flatten().collect();
492        if vertices.is_empty() {
493            return Ok(Vec::new());
494        }
495        let vertex_set: FxHashSet<VertexId> = vertices.iter().copied().collect();
496
497        // Build combined adjacency (dependencies and dependents) within the subset
498        let mut combined_deps: FxHashMap<VertexId, Vec<VertexId>> = FxHashMap::default();
499        let mut combined_out: FxHashMap<VertexId, Vec<VertexId>> = FxHashMap::default();
500        for &v in &vertices {
501            let mut deps: Vec<VertexId> = Vec::new();
502            if let Some(base) = self.graph.dependencies_slice(v) {
503                deps.extend(base.iter().copied().filter(|d| vertex_set.contains(d)));
504            } else {
505                deps.extend(
506                    self.graph
507                        .get_dependencies(v)
508                        .into_iter()
509                        .filter(|d| vertex_set.contains(d)),
510                );
511            }
512            if let Some(extra) = vdeps.get(&v) {
513                deps.extend(extra.iter().copied().filter(|d| vertex_set.contains(d)));
514            }
515            deps.sort_unstable();
516            deps.dedup();
517            combined_deps.insert(v, deps);
518        }
519        // invert
520        for (&v, deps) in combined_deps.iter() {
521            for &d in deps {
522                combined_out.entry(d).or_default().push(v);
523            }
524        }
525        // in-degrees
526        let mut in_degrees: FxHashMap<VertexId, usize> = FxHashMap::default();
527        for &v in &vertices {
528            let indeg = combined_deps.get(&v).map(|v| v.len()).unwrap_or(0);
529            in_degrees.insert(v, indeg);
530        }
531        // queue of 0 in-degree
532        let mut queue: VecDeque<VertexId> = in_degrees
533            .iter()
534            .filter(|&(_, &deg)| deg == 0)
535            .map(|(&v, _)| v)
536            .collect();
537
538        let mut layers = Vec::new();
539        let mut processed_count = 0;
540        while !queue.is_empty() {
541            let mut cur = Vec::new();
542            for _ in 0..queue.len() {
543                let u = queue.pop_front().unwrap();
544                cur.push(u);
545                processed_count += 1;
546                if let Some(dependents) = combined_out.get(&u) {
547                    for &w in dependents {
548                        if let Some(ind) = in_degrees.get_mut(&w) {
549                            *ind = ind.saturating_sub(1);
550                            if *ind == 0 {
551                                queue.push_back(w);
552                            }
553                        }
554                    }
555                }
556            }
557            cur.sort_unstable();
558            layers.push(Layer { vertices: cur });
559        }
560        if processed_count != vertices.len() {
561            return Err(
562                ExcelError::new(formualizer_common::ExcelErrorKind::Circ).with_message(
563                    "Unexpected cycle detected in acyclic components during layer construction (virtual)"
564                        .to_string(),
565                ),
566            );
567        }
568        Ok(layers)
569    }
570}