formualizer_eval/engine/
scheduler.rs

1use super::DependencyGraph;
2use super::vertex::VertexId;
3use formualizer_common::ExcelError;
4use rustc_hash::{FxHashMap, FxHashSet};
5
6pub struct Scheduler<'a> {
7    graph: &'a DependencyGraph,
8}
9
10#[derive(Debug)]
11pub struct Layer {
12    pub vertices: Vec<VertexId>,
13}
14
15#[derive(Debug)]
16pub struct Schedule {
17    pub layers: Vec<Layer>,
18    pub cycles: Vec<Vec<VertexId>>,
19}
20
21impl<'a> Scheduler<'a> {
22    pub fn new(graph: &'a DependencyGraph) -> Self {
23        Self { graph }
24    }
25
26    pub fn create_schedule(&self, vertices: &[VertexId]) -> Result<Schedule, ExcelError> {
27        #[cfg(feature = "tracing")]
28        let _span = tracing::info_span!("scheduler", vertices = vertices.len()).entered();
29        // 1. Find strongly connected components using Tarjan's algorithm
30        #[cfg(feature = "tracing")]
31        let _scc_span = tracing::info_span!("tarjan_scc").entered();
32        let sccs = self.tarjan_scc(vertices)?;
33        #[cfg(feature = "tracing")]
34        drop(_scc_span);
35
36        // 2. Separate cyclic from acyclic components
37        let (cycles, acyclic_sccs) = self.separate_cycles(sccs);
38
39        // 3. Topologically sort acyclic components into layers
40        let layers = if self.graph.dynamic_topo_enabled() {
41            let subset: Vec<VertexId> = acyclic_sccs.into_iter().flatten().collect();
42            if subset.is_empty() {
43                Vec::new()
44            } else {
45                self.graph
46                    .pk_layers_for(&subset)
47                    .unwrap_or(self.build_layers(vec![subset])?)
48            }
49        } else {
50            self.build_layers(acyclic_sccs)?
51        };
52
53        Ok(Schedule { layers, cycles })
54    }
55
56    /// Create a schedule considering additional ephemeral (virtual) dependencies just for this pass.
57    /// `vdeps` maps a vertex to extra dependency vertices that should be considered as incoming edges.
58    pub fn create_schedule_with_virtual(
59        &self,
60        vertices: &[VertexId],
61        vdeps: &FxHashMap<VertexId, Vec<VertexId>>,
62    ) -> Result<Schedule, ExcelError> {
63        #[cfg(feature = "tracing")]
64        let _span = tracing::info_span!(
65            "scheduler_with_virtual",
66            vertices = vertices.len(),
67            vdeps = vdeps.len()
68        )
69        .entered();
70        // 1. SCC detection with virtual deps
71        #[cfg(feature = "tracing")]
72        let _scc_span = tracing::info_span!("tarjan_scc_with_virtual").entered();
73        let sccs = self.tarjan_scc_with_virtual(vertices, vdeps)?;
74        #[cfg(feature = "tracing")]
75        drop(_scc_span);
76        // 2. Separate cycles and acyclic components
77        let (cycles, acyclic_sccs) = self.separate_cycles(sccs);
78        // 3. Build layers over combined adjacency (graph + vdeps)
79        #[cfg(feature = "tracing")]
80        let _layers_span = tracing::info_span!("build_layers_with_virtual").entered();
81        let layers = self.build_layers_with_virtual(acyclic_sccs, vdeps)?;
82        #[cfg(feature = "tracing")]
83        drop(_layers_span);
84        Ok(Schedule { layers, cycles })
85    }
86
87    /// Tarjan's strongly connected components algorithm
88    pub fn tarjan_scc(&self, vertices: &[VertexId]) -> Result<Vec<Vec<VertexId>>, ExcelError> {
89        let mut index_counter = 0;
90        let mut stack = Vec::new();
91        let mut indices = FxHashMap::default();
92        let mut lowlinks = FxHashMap::default();
93        let mut on_stack = FxHashSet::default();
94        let mut sccs = Vec::new();
95        let vertex_set: FxHashSet<VertexId> = vertices.iter().copied().collect();
96
97        for &vertex in vertices {
98            if !indices.contains_key(&vertex) {
99                self.tarjan_visit(
100                    vertex,
101                    &mut index_counter,
102                    &mut stack,
103                    &mut indices,
104                    &mut lowlinks,
105                    &mut on_stack,
106                    &mut sccs,
107                    &vertex_set,
108                )?;
109            }
110        }
111
112        Ok(sccs)
113    }
114
115    /// Tarjan with virtual deps
116    fn tarjan_scc_with_virtual(
117        &self,
118        vertices: &[VertexId],
119        vdeps: &FxHashMap<VertexId, Vec<VertexId>>,
120    ) -> Result<Vec<Vec<VertexId>>, ExcelError> {
121        let mut index_counter = 0;
122        let mut stack = Vec::new();
123        let mut indices = FxHashMap::default();
124        let mut lowlinks = FxHashMap::default();
125        let mut on_stack = FxHashSet::default();
126        let mut sccs = Vec::new();
127        let vertex_set: FxHashSet<VertexId> = vertices.iter().copied().collect();
128
129        for &vertex in vertices {
130            if !indices.contains_key(&vertex) {
131                self.tarjan_visit_with_virtual(
132                    vertex,
133                    &mut index_counter,
134                    &mut stack,
135                    &mut indices,
136                    &mut lowlinks,
137                    &mut on_stack,
138                    &mut sccs,
139                    &vertex_set,
140                    vdeps,
141                )?;
142            }
143        }
144
145        Ok(sccs)
146    }
147
148    fn tarjan_visit(
149        &self,
150        vertex: VertexId,
151        index_counter: &mut usize,
152        stack: &mut Vec<VertexId>,
153        indices: &mut FxHashMap<VertexId, usize>,
154        lowlinks: &mut FxHashMap<VertexId, usize>,
155        on_stack: &mut FxHashSet<VertexId>,
156        sccs: &mut Vec<Vec<VertexId>>,
157        vertex_set: &FxHashSet<VertexId>,
158    ) -> Result<(), ExcelError> {
159        // Set the depth index for vertex to the smallest unused index
160        indices.insert(vertex, *index_counter);
161        lowlinks.insert(vertex, *index_counter);
162        *index_counter += 1;
163        stack.push(vertex);
164        on_stack.insert(vertex);
165
166        // Consider successors of vertex (dependencies)
167        let dependencies = self.graph.get_dependencies(vertex);
168        for &dependency in &dependencies {
169            // Only consider dependencies that are part of the current scheduling task
170            if !vertex_set.contains(&dependency) {
171                continue;
172            }
173
174            if !indices.contains_key(&dependency) {
175                // Successor dependency has not yet been visited; recurse on it
176                self.tarjan_visit(
177                    dependency,
178                    index_counter,
179                    stack,
180                    indices,
181                    lowlinks,
182                    on_stack,
183                    sccs,
184                    vertex_set,
185                )?;
186                let dep_lowlink = lowlinks[&dependency];
187                lowlinks.insert(vertex, lowlinks[&vertex].min(dep_lowlink));
188            } else if on_stack.contains(&dependency) {
189                // Successor dependency is in stack and hence in the current SCC
190                let dep_index = indices[&dependency];
191                lowlinks.insert(vertex, lowlinks[&vertex].min(dep_index));
192            }
193        }
194
195        // If vertex is a root node, pop the stack and print an SCC
196        if lowlinks[&vertex] == indices[&vertex] {
197            let mut scc = Vec::new();
198            loop {
199                let w = stack.pop().unwrap();
200                on_stack.remove(&w);
201                scc.push(w);
202                if w == vertex {
203                    break;
204                }
205            }
206            sccs.push(scc);
207        }
208
209        Ok(())
210    }
211
212    fn tarjan_visit_with_virtual(
213        &self,
214        vertex: VertexId,
215        index_counter: &mut usize,
216        stack: &mut Vec<VertexId>,
217        indices: &mut FxHashMap<VertexId, usize>,
218        lowlinks: &mut FxHashMap<VertexId, usize>,
219        on_stack: &mut FxHashSet<VertexId>,
220        sccs: &mut Vec<Vec<VertexId>>,
221        vertex_set: &FxHashSet<VertexId>,
222        vdeps: &FxHashMap<VertexId, Vec<VertexId>>,
223    ) -> Result<(), ExcelError> {
224        // Set the depth index for vertex to the smallest unused index
225        indices.insert(vertex, *index_counter);
226        lowlinks.insert(vertex, *index_counter);
227        *index_counter += 1;
228        stack.push(vertex);
229        on_stack.insert(vertex);
230
231        // Consider successors of vertex (dependencies) including virtual deps
232        let mut dependencies = self.graph.get_dependencies(vertex).to_vec();
233        if let Some(extra) = vdeps.get(&vertex) {
234            dependencies.extend(extra.iter().copied());
235        }
236        for dependency in dependencies.into_iter() {
237            // Only consider dependencies that are part of the current scheduling task
238            if !vertex_set.contains(&dependency) {
239                continue;
240            }
241
242            if !indices.contains_key(&dependency) {
243                // Successor dependency has not yet been visited; recurse on it
244                self.tarjan_visit_with_virtual(
245                    dependency,
246                    index_counter,
247                    stack,
248                    indices,
249                    lowlinks,
250                    on_stack,
251                    sccs,
252                    vertex_set,
253                    vdeps,
254                )?;
255                let dep_lowlink = lowlinks[&dependency];
256                lowlinks.insert(vertex, lowlinks[&vertex].min(dep_lowlink));
257            } else if on_stack.contains(&dependency) {
258                // Successor dependency is in stack and hence in the current SCC
259                let dep_index = indices[&dependency];
260                lowlinks.insert(vertex, lowlinks[&vertex].min(dep_index));
261            }
262        }
263
264        // If vertex is a root node, pop the stack and produce an SCC
265        if lowlinks[&vertex] == indices[&vertex] {
266            let mut scc = Vec::new();
267            loop {
268                let w = stack.pop().unwrap();
269                on_stack.remove(&w);
270                scc.push(w);
271                if w == vertex {
272                    break;
273                }
274            }
275            sccs.push(scc);
276        }
277
278        Ok(())
279    }
280
281    pub(crate) fn separate_cycles(
282        &self,
283        sccs: Vec<Vec<VertexId>>,
284    ) -> (Vec<Vec<VertexId>>, Vec<Vec<VertexId>>) {
285        let mut cycles = Vec::new();
286        let mut acyclic = Vec::new();
287
288        for scc in sccs {
289            if scc.len() > 1 || (scc.len() == 1 && self.has_self_loop(scc[0])) {
290                cycles.push(scc);
291            } else {
292                acyclic.push(scc);
293            }
294        }
295
296        (cycles, acyclic)
297    }
298
299    fn has_self_loop(&self, vertex: VertexId) -> bool {
300        self.graph.has_self_loop(vertex)
301    }
302
303    pub(crate) fn build_layers(
304        &self,
305        acyclic_sccs: Vec<Vec<VertexId>>,
306    ) -> Result<Vec<Layer>, ExcelError> {
307        let vertices: Vec<VertexId> = acyclic_sccs.into_iter().flatten().collect();
308        if vertices.is_empty() {
309            return Ok(Vec::new());
310        }
311        let vertex_set: FxHashSet<VertexId> = vertices.iter().copied().collect();
312
313        // Calculate in-degrees for all vertices in the acyclic subgraph
314        let mut in_degrees: FxHashMap<VertexId, usize> = vertices.iter().map(|&v| (v, 0)).collect();
315        for &vertex_id in &vertices {
316            let dependencies = self.graph.get_dependencies(vertex_id);
317            for &dep_id in &dependencies {
318                if vertex_set.contains(&dep_id) {
319                    if let Some(in_degree) = in_degrees.get_mut(&vertex_id) {
320                        *in_degree += 1;
321                    }
322                }
323            }
324        }
325
326        // Initialize the queue with all nodes having an in-degree of 0
327        let mut queue: std::collections::VecDeque<VertexId> = in_degrees
328            .iter()
329            .filter(|&(_, &in_degree)| in_degree == 0)
330            .map(|(&v, _)| v)
331            .collect();
332
333        let mut layers = Vec::new();
334        let mut processed_count = 0;
335
336        while !queue.is_empty() {
337            let mut current_layer_vertices = Vec::new();
338            for _ in 0..queue.len() {
339                let u = queue.pop_front().unwrap();
340                current_layer_vertices.push(u);
341                processed_count += 1;
342
343                // For each dependent of u, reduce its in-degree
344                for v_dep in self.graph.get_dependents(u) {
345                    if let Some(in_degree) = in_degrees.get_mut(&v_dep) {
346                        *in_degree -= 1;
347                        if *in_degree == 0 {
348                            queue.push_back(v_dep);
349                        }
350                    }
351                }
352            }
353            // Sort for deterministic output in tests
354            current_layer_vertices.sort();
355            layers.push(Layer {
356                vertices: current_layer_vertices,
357            });
358        }
359
360        if processed_count != vertices.len() {
361            return Err(
362                ExcelError::new(formualizer_common::ExcelErrorKind::Circ).with_message(
363                    "Unexpected cycle detected in acyclic components during layer construction"
364                        .to_string(),
365                ),
366            );
367        }
368
369        Ok(layers)
370    }
371
372    pub(crate) fn build_layers_with_virtual(
373        &self,
374        acyclic_sccs: Vec<Vec<VertexId>>,
375        vdeps: &FxHashMap<VertexId, Vec<VertexId>>,
376    ) -> Result<Vec<Layer>, ExcelError> {
377        use std::collections::VecDeque;
378        let vertices: Vec<VertexId> = acyclic_sccs.into_iter().flatten().collect();
379        if vertices.is_empty() {
380            return Ok(Vec::new());
381        }
382        let vertex_set: FxHashSet<VertexId> = vertices.iter().copied().collect();
383
384        // Build combined adjacency (dependencies and dependents) within the subset
385        let mut combined_deps: FxHashMap<VertexId, Vec<VertexId>> = FxHashMap::default();
386        let mut combined_out: FxHashMap<VertexId, Vec<VertexId>> = FxHashMap::default();
387        for &v in &vertices {
388            let mut deps: Vec<VertexId> = self
389                .graph
390                .get_dependencies(v)
391                .iter()
392                .copied()
393                .filter(|d| vertex_set.contains(d))
394                .collect();
395            if let Some(extra) = vdeps.get(&v) {
396                deps.extend(extra.iter().copied().filter(|d| vertex_set.contains(d)));
397            }
398            deps.sort_unstable();
399            deps.dedup();
400            combined_deps.insert(v, deps);
401        }
402        // invert
403        for (&v, deps) in combined_deps.iter() {
404            for &d in deps {
405                combined_out.entry(d).or_default().push(v);
406            }
407        }
408        // in-degrees
409        let mut in_degrees: FxHashMap<VertexId, usize> = FxHashMap::default();
410        for &v in &vertices {
411            let indeg = combined_deps.get(&v).map(|v| v.len()).unwrap_or(0);
412            in_degrees.insert(v, indeg);
413        }
414        // queue of 0 in-degree
415        let mut queue: VecDeque<VertexId> = in_degrees
416            .iter()
417            .filter(|&(_, &deg)| deg == 0)
418            .map(|(&v, _)| v)
419            .collect();
420
421        let mut layers = Vec::new();
422        let mut processed_count = 0;
423        while !queue.is_empty() {
424            let mut cur = Vec::new();
425            for _ in 0..queue.len() {
426                let u = queue.pop_front().unwrap();
427                cur.push(u);
428                processed_count += 1;
429                if let Some(dependents) = combined_out.get(&u) {
430                    for &w in dependents {
431                        if let Some(ind) = in_degrees.get_mut(&w) {
432                            *ind = ind.saturating_sub(1);
433                            if *ind == 0 {
434                                queue.push_back(w);
435                            }
436                        }
437                    }
438                }
439            }
440            cur.sort_unstable();
441            layers.push(Layer { vertices: cur });
442        }
443        if processed_count != vertices.len() {
444            return Err(
445                ExcelError::new(formualizer_common::ExcelErrorKind::Circ).with_message(
446                    "Unexpected cycle detected in acyclic components during layer construction (virtual)"
447                        .to_string(),
448                ),
449            );
450        }
451        Ok(layers)
452    }
453}