Skip to main content

oximedia_graph/
optimization.rs

1//! Graph optimization passes.
2//!
3//! Provides a pluggable pipeline of optimization passes that simplify a graph
4//! represented as a list of [`NodeSpec`]s.
5
6#![allow(dead_code)]
7
8use std::collections::HashMap;
9
10// ─────────────────────────────────────────────────────────────────────────────
11// NodeSpec
12// ─────────────────────────────────────────────────────────────────────────────
13
/// Lightweight description of a single graph node, as consumed and rewritten
/// by the optimization passes in this module.
///
/// A connection between two nodes may be recorded in the upstream node's
/// `outputs`, in the downstream node's `inputs`, or in both.
#[derive(Debug, Clone, PartialEq)]
pub struct NodeSpec {
    /// Unique identifier of this node.
    pub id: String,
    /// Name of the node's operation, e.g. `"Scale"`, `"Brightness"` or `"Contrast"`.
    pub node_type: String,
    /// Free-form string parameters, keyed by parameter name.
    pub params: HashMap<String, String>,
    /// IDs of the predecessor nodes whose output feeds this node.
    pub inputs: Vec<String>,
    /// IDs of the successor nodes that consume this node's output.
    pub outputs: Vec<String>,
}

impl NodeSpec {
    /// Builds a node with the given ID and type, and no parameters or edges.
    #[must_use]
    pub fn new(id: impl Into<String>, node_type: impl Into<String>) -> Self {
        let (id, node_type) = (id.into(), node_type.into());
        Self {
            id,
            node_type,
            params: HashMap::default(),
            inputs: Vec::new(),
            outputs: Vec::new(),
        }
    }

    /// Builder: sets parameter `key` to `value`, replacing any previous value.
    #[must_use]
    pub fn with_param(self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let mut node = self;
        node.params.insert(key.into(), value.into());
        node
    }

    /// Builder: replaces the predecessor list with `inputs`.
    #[must_use]
    pub fn with_inputs(self, inputs: Vec<String>) -> Self {
        Self { inputs, ..self }
    }

    /// Builder: replaces the successor list with `outputs`.
    #[must_use]
    pub fn with_outputs(self, outputs: Vec<String>) -> Self {
        Self { outputs, ..self }
    }
}
64
65// ─────────────────────────────────────────────────────────────────────────────
66// OptimizationPass trait
67// ─────────────────────────────────────────────────────────────────────────────
68
69/// An optimization pass that transforms a list of [`NodeSpec`]s in-place.
70pub trait OptimizationPass: Send + Sync {
71    /// Human-readable name of the pass.
72    fn name(&self) -> &str;
73
74    /// Apply the pass.  Returns the number of optimizations applied.
75    fn optimize(&self, nodes: &mut Vec<NodeSpec>) -> usize;
76}
77
78// ─────────────────────────────────────────────────────────────────────────────
79// ConstantFoldingPass
80// ─────────────────────────────────────────────────────────────────────────────
81
82/// Removes identity nodes whose output is always equal to their input.
83///
84/// Currently folds:
85/// - `Scale` nodes where the `"factor"` parameter is `"1.0"` or `"1"`.
86pub struct ConstantFoldingPass;
87
88impl ConstantFoldingPass {
89    /// Create a new constant folding pass.
90    #[must_use]
91    pub fn new() -> Self {
92        Self
93    }
94
95    fn is_identity(node: &NodeSpec) -> bool {
96        match node.node_type.as_str() {
97            "Scale" => {
98                let factor = node
99                    .params
100                    .get("factor")
101                    .map(String::as_str)
102                    .unwrap_or("1.0");
103                factor == "1.0" || factor == "1"
104            }
105            _ => false,
106        }
107    }
108}
109
110impl Default for ConstantFoldingPass {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116impl OptimizationPass for ConstantFoldingPass {
117    fn name(&self) -> &str {
118        "ConstantFolding"
119    }
120
121    fn optimize(&self, nodes: &mut Vec<NodeSpec>) -> usize {
122        let before = nodes.len();
123
124        // Collect IDs of identity nodes.
125        let identity_ids: Vec<String> = nodes
126            .iter()
127            .filter(|n| Self::is_identity(n))
128            .map(|n| n.id.clone())
129            .collect();
130
131        if identity_ids.is_empty() {
132            return 0;
133        }
134
135        let identity_set: std::collections::HashSet<&str> =
136            identity_ids.iter().map(String::as_str).collect();
137
138        // Rewire: for every node whose input is an identity node, replace that
139        // input with the identity node's own inputs.
140        for node in nodes.iter_mut() {
141            node.inputs = node
142                .inputs
143                .iter()
144                .flat_map(|inp| {
145                    if identity_set.contains(inp.as_str()) {
146                        // Find the identity node's inputs (they carry the original data).
147                        // We already have the id; we need to look up its inputs from the
148                        // original slice.  Since we're iterating mutably we keep a copy.
149                        // In a real compiler IR this would be a proper use-def chain;
150                        // here we use a simplified placeholder approach: remove the edge.
151                        vec![] // no inputs → effectively bypass
152                    } else {
153                        vec![inp.clone()]
154                    }
155                })
156                .collect();
157        }
158
159        // Remove identity nodes.
160        nodes.retain(|n| !identity_set.contains(n.id.as_str()));
161
162        before - nodes.len()
163    }
164}
165
166// ─────────────────────────────────────────────────────────────────────────────
167// DeadNodeEliminationPass
168// ─────────────────────────────────────────────────────────────────────────────
169
170/// Removes nodes whose `outputs` list is empty (i.e., no other node consumes
171/// their output).
172///
173/// Nodes with no outputs are considered dead unless they are the only node in
174/// the graph (treating them as implicit sinks).
175pub struct DeadNodeEliminationPass;
176
177impl DeadNodeEliminationPass {
178    /// Create a new dead-node elimination pass.
179    #[must_use]
180    pub fn new() -> Self {
181        Self
182    }
183}
184
185impl Default for DeadNodeEliminationPass {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
191impl OptimizationPass for DeadNodeEliminationPass {
192    fn name(&self) -> &str {
193        "DeadNodeElimination"
194    }
195
196    fn optimize(&self, nodes: &mut Vec<NodeSpec>) -> usize {
197        if nodes.len() <= 1 {
198            return 0; // preserve single-node graphs
199        }
200
201        let before = nodes.len();
202
203        // Compute the set of all node IDs referenced by any other node (either as
204        // an input source or as an output target listed in another node's outputs).
205        let referenced: std::collections::HashSet<String> = nodes
206            .iter()
207            .flat_map(|n| n.inputs.iter().chain(n.outputs.iter()).cloned())
208            .collect();
209
210        // A node is "dead" if it has no successors AND is not referenced by anyone.
211        nodes.retain(|n| !n.outputs.is_empty() || referenced.contains(&n.id));
212
213        before - nodes.len()
214    }
215}
216
217// ─────────────────────────────────────────────────────────────────────────────
218// NodeFusionPass
219// ─────────────────────────────────────────────────────────────────────────────
220
221/// Fuses compatible sequential node pairs into a single fused node.
222///
223/// Currently fuses:
224/// - `Brightness` immediately followed by `Contrast` → `BrightnessContrast`
225pub struct NodeFusionPass;
226
227impl NodeFusionPass {
228    /// Create a new node fusion pass.
229    #[must_use]
230    pub fn new() -> Self {
231        Self
232    }
233}
234
235impl Default for NodeFusionPass {
236    fn default() -> Self {
237        Self::new()
238    }
239}
240
241impl OptimizationPass for NodeFusionPass {
242    fn name(&self) -> &str {
243        "NodeFusion"
244    }
245
246    fn optimize(&self, nodes: &mut Vec<NodeSpec>) -> usize {
247        let mut fusions = 0usize;
248        let mut i = 0;
249
250        while i + 1 < nodes.len() {
251            let (a, b) = (&nodes[i], &nodes[i + 1]);
252
253            // Check for Brightness → Contrast sequential pair.
254            let is_fusable = a.node_type == "Brightness"
255                && b.node_type == "Contrast"
256                && b.inputs.contains(&a.id);
257
258            if is_fusable {
259                // Merge parameters.
260                let mut params = a.params.clone();
261                for (k, v) in &b.params {
262                    params.insert(k.clone(), v.clone());
263                }
264
265                let fused = NodeSpec {
266                    id: format!("{}_{}", a.id, b.id),
267                    node_type: "BrightnessContrast".to_string(),
268                    params,
269                    inputs: a.inputs.clone(),
270                    outputs: b.outputs.clone(),
271                };
272
273                let a_id = a.id.clone();
274                let b_id = b.id.clone();
275                let fused_id = fused.id.clone();
276
277                // Replace the pair with the fused node.
278                nodes.remove(i + 1);
279                nodes[i] = fused;
280
281                // Update all references in the remaining nodes.
282                for node in nodes.iter_mut() {
283                    for inp in &mut node.inputs {
284                        if *inp == a_id || *inp == b_id {
285                            *inp = fused_id.clone();
286                        }
287                    }
288                    for out in &mut node.outputs {
289                        if *out == a_id || *out == b_id {
290                            *out = fused_id.clone();
291                        }
292                    }
293                }
294
295                fusions += 1;
296                // Do not advance i – re-check from the same position in case
297                // the fused node can itself be fused with a successor.
298            } else {
299                i += 1;
300            }
301        }
302
303        fusions
304    }
305}
306
307// ─────────────────────────────────────────────────────────────────────────────
308// OptimizationReport
309// ─────────────────────────────────────────────────────────────────────────────
310
/// Summary of optimizations applied by the [`GraphOptimizer`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OptimizationReport {
    /// Names of the passes that were run, in execution order. A pass is
    /// listed even if it made no changes.
    pub passes_applied: Vec<String>,
    /// Number of nodes before optimization.
    pub nodes_before: usize,
    /// Number of nodes after optimization.
    pub nodes_after: usize,
    /// Total number of individual optimizations performed, summed over all
    /// passes.
    pub optimizations: usize,
}

impl OptimizationReport {
    /// Net number of nodes removed by the run.
    ///
    /// Returns zero if the graph did not shrink; this uses saturating
    /// subtraction, so it never panics or wraps.
    #[must_use]
    pub fn nodes_removed(&self) -> usize {
        self.nodes_before.saturating_sub(self.nodes_after)
    }
}
323
324// ─────────────────────────────────────────────────────────────────────────────
325// GraphOptimizer
326// ─────────────────────────────────────────────────────────────────────────────
327
328/// Runs a sequence of [`OptimizationPass`]es over a node list.
329#[derive(Default)]
330pub struct GraphOptimizer {
331    passes: Vec<Box<dyn OptimizationPass>>,
332}
333
334impl GraphOptimizer {
335    /// Create a new optimizer with no passes.
336    #[must_use]
337    pub fn new() -> Self {
338        Self { passes: vec![] }
339    }
340
341    /// Add an optimization pass.
342    pub fn add_pass(&mut self, pass: Box<dyn OptimizationPass>) {
343        self.passes.push(pass);
344    }
345
346    /// Run all registered passes over `nodes`.
347    ///
348    /// Returns the optimized node list and an [`OptimizationReport`].
349    #[must_use]
350    pub fn run(&self, mut nodes: Vec<NodeSpec>) -> (Vec<NodeSpec>, OptimizationReport) {
351        let nodes_before = nodes.len();
352        let mut passes_applied = Vec::new();
353        let mut total_optimizations = 0;
354
355        for pass in &self.passes {
356            let count = pass.optimize(&mut nodes);
357            passes_applied.push(pass.name().to_string());
358            total_optimizations += count;
359        }
360
361        let report = OptimizationReport {
362            passes_applied,
363            nodes_before,
364            nodes_after: nodes.len(),
365            optimizations: total_optimizations,
366        };
367
368        (nodes, report)
369    }
370}
371
372// ─────────────────────────────────────────────────────────────────────────────
373// Unit tests
374// ─────────────────────────────────────────────────────────────────────────────
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned ID list from string literals.
    fn ids(names: &[&str]) -> Vec<String> {
        names.iter().map(|s| (*s).to_string()).collect()
    }

    fn scale_node(id: &str, factor: &str) -> NodeSpec {
        NodeSpec::new(id, "Scale").with_param("factor", factor)
    }

    fn brightness_node(id: &str) -> NodeSpec {
        NodeSpec::new(id, "Brightness").with_param("value", "1.2")
    }

    fn contrast_node(id: &str, input: &str) -> NodeSpec {
        NodeSpec::new(id, "Contrast")
            .with_param("value", "1.1")
            .with_inputs(ids(&[input]))
    }

    /// `b1 -> c1`: a Brightness node feeding a Contrast node.
    fn brightness_contrast_pair() -> Vec<NodeSpec> {
        vec![
            brightness_node("b1").with_outputs(ids(&["c1"])),
            contrast_node("c1", "b1"),
        ]
    }

    /// Runs `pass` over `nodes`; returns the reported count and the survivors.
    fn apply(pass: &dyn OptimizationPass, mut nodes: Vec<NodeSpec>) -> (usize, Vec<NodeSpec>) {
        let count = pass.optimize(&mut nodes);
        (count, nodes)
    }

    /// Builds an optimizer running `passes` in the given order.
    fn optimizer_with(passes: Vec<Box<dyn OptimizationPass>>) -> GraphOptimizer {
        let mut opt = GraphOptimizer::new();
        for pass in passes {
            opt.add_pass(pass);
        }
        opt
    }

    // ── ConstantFoldingPass ───────────────────────────────────────────────────

    #[test]
    fn test_constant_folding_removes_scale_one() {
        let (removed, left) = apply(&ConstantFoldingPass::new(), vec![scale_node("s1", "1.0")]);
        assert_eq!(removed, 1);
        assert!(left.is_empty());
    }

    #[test]
    fn test_constant_folding_keeps_scale_two() {
        let (removed, left) = apply(&ConstantFoldingPass::new(), vec![scale_node("s1", "2.0")]);
        assert_eq!(removed, 0);
        assert_eq!(left.len(), 1);
    }

    #[test]
    fn test_constant_folding_integer_one() {
        let (removed, _) = apply(&ConstantFoldingPass::new(), vec![scale_node("s1", "1")]);
        assert_eq!(removed, 1);
    }

    #[test]
    fn test_constant_folding_mixed_nodes() {
        let graph = vec![scale_node("s1", "1.0"), scale_node("s2", "0.5")];
        let (removed, left) = apply(&ConstantFoldingPass::new(), graph);
        assert_eq!(removed, 1);
        let surviving: Vec<&str> = left.iter().map(|n| n.id.as_str()).collect();
        assert_eq!(surviving, ["s2"]);
    }

    // ── DeadNodeEliminationPass ───────────────────────────────────────────────

    #[test]
    fn test_dead_node_elimination_no_outputs() {
        // Neither node lists outputs and neither is referenced: both are dead.
        let graph = vec![NodeSpec::new("a", "Filter"), NodeSpec::new("b", "Filter")];
        let (removed, left) = apply(&DeadNodeEliminationPass::new(), graph);
        assert_eq!(removed, 2);
        assert!(left.is_empty());
    }

    #[test]
    fn test_dead_node_elimination_referenced_node_kept() {
        let graph = vec![
            NodeSpec::new("a", "Source").with_outputs(ids(&["b"])),
            NodeSpec::new("b", "Sink").with_inputs(ids(&["a"])),
        ];
        let (removed, _) = apply(&DeadNodeEliminationPass::new(), graph);
        assert_eq!(removed, 0);
    }

    #[test]
    fn test_dead_node_elimination_single_node_preserved() {
        let graph = vec![NodeSpec::new("a", "Source")];
        let (removed, left) = apply(&DeadNodeEliminationPass::new(), graph);
        // A lone node is treated as an implicit sink and survives.
        assert_eq!(removed, 0);
        assert_eq!(left.len(), 1);
    }

    // ── NodeFusionPass ────────────────────────────────────────────────────────

    #[test]
    fn test_node_fusion_brightness_contrast() {
        let (fusions, left) = apply(&NodeFusionPass::new(), brightness_contrast_pair());
        assert_eq!(fusions, 1);
        assert_eq!(left.len(), 1);
        assert_eq!(left[0].node_type, "BrightnessContrast");
    }

    #[test]
    fn test_node_fusion_no_match() {
        let graph = vec![NodeSpec::new("a", "Scale"), NodeSpec::new("b", "Gamma")];
        let (fusions, left) = apply(&NodeFusionPass::new(), graph);
        assert_eq!(fusions, 0);
        assert_eq!(left.len(), 2);
    }

    #[test]
    fn test_node_fusion_fused_node_has_merged_params() {
        let (_, left) = apply(&NodeFusionPass::new(), brightness_contrast_pair());
        assert!(left[0].params.contains_key("value"));
    }

    // ── GraphOptimizer ────────────────────────────────────────────────────────

    #[test]
    fn test_optimizer_empty_graph() {
        let opt = optimizer_with(vec![Box::new(ConstantFoldingPass::new())]);
        let (nodes, report) = opt.run(Vec::new());
        assert!(nodes.is_empty());
        assert_eq!((report.nodes_before, report.nodes_after), (0, 0));
    }

    #[test]
    fn test_optimizer_report_fields() {
        let opt = optimizer_with(vec![Box::new(ConstantFoldingPass::new())]);
        let (_, report) = opt.run(vec![scale_node("s1", "1.0"), scale_node("s2", "2.0")]);
        assert_eq!(report.nodes_before, 2);
        assert_eq!(report.nodes_after, 1);
        assert_eq!(report.optimizations, 1);
        assert_eq!(report.passes_applied, ["ConstantFolding"]);
    }

    #[test]
    fn test_optimizer_multiple_passes() {
        let opt = optimizer_with(vec![
            Box::new(ConstantFoldingPass::new()),
            Box::new(DeadNodeEliminationPass::new()),
        ]);
        let graph = vec![scale_node("s1", "1.0"), NodeSpec::new("orphan", "Filter")];
        let (_, report) = opt.run(graph);
        assert_eq!(report.passes_applied.len(), 2);
    }

    #[test]
    fn test_optimizer_no_passes() {
        let (result, report) = GraphOptimizer::new().run(vec![NodeSpec::new("a", "Filter")]);
        assert_eq!(result.len(), 1);
        assert_eq!(report.optimizations, 0);
        assert!(report.passes_applied.is_empty());
    }
}