Skip to main content

tidepool_optimize/
pipeline.rs

1use crate::beta::BetaReduce;
2use crate::case_reduce::CaseReduce;
3use crate::dce::Dce;
4use crate::inline::Inline;
5use crate::partial::PartialEval;
6use tidepool_eval::pass::{Changed, Pass};
7use tidepool_repr::CoreExpr;
8
/// Upper bound on fixed-point iterations, shared by `run_pipeline` and
/// `run_pass_to_fixpoint` so that a pass which never stops reporting changes
/// produces an error instead of looping forever.
pub const MAX_PIPELINE_ITERATIONS: usize = 1_000;
11
/// Counters collected while [`run_pipeline`] drives a pass list to a fixed point.
#[derive(Clone, Debug, Default)]
pub struct PipelineStats {
    /// Number of full sweeps over the pass list, counting the last sweep in
    /// which no pass reported a change.
    pub iterations: usize,
    /// One `(pass name, invocation count)` entry per pass, in pass-list order.
    pub pass_invocations: Vec<(String, usize)>,
}
21
22/// Run a sequence of passes to fixed point.
23/// Keeps iterating until no pass reports a change or MAX_PIPELINE_ITERATIONS is reached.
24/// Returns stats about how many iterations and per-pass invocations.
25///
26/// Returns `Err` if the number of iterations exceeds MAX_PIPELINE_ITERATIONS.
27pub fn run_pipeline(
28    passes: &[Box<dyn Pass>],
29    expr: &mut CoreExpr,
30) -> Result<PipelineStats, String> {
31    let mut stats = PipelineStats {
32        iterations: 0,
33        pass_invocations: passes.iter().map(|p| (p.name().to_string(), 0)).collect(),
34    };
35
36    if passes.is_empty() {
37        return Ok(stats);
38    }
39
40    loop {
41        stats.iterations += 1;
42        if stats.iterations > MAX_PIPELINE_ITERATIONS {
43            return Err(format!(
44                "Optimization pipeline exceeded maximum iterations ({}). Potential infinite loop in passes: {:?}",
45                MAX_PIPELINE_ITERATIONS,
46                passes.iter().map(|p| p.name()).collect::<Vec<_>>()
47            ));
48        }
49
50        let mut changed: Changed = false;
51        for (i, pass) in passes.iter().enumerate() {
52            if pass.run(expr) {
53                changed = true;
54            }
55            stats.pass_invocations[i].1 += 1;
56        }
57
58        if !changed {
59            break;
60        }
61    }
62
63    Ok(stats)
64}
65
66/// Returns the default optimization pass sequence.
67/// Order: BetaReduce → Inline → CaseReduce → Dce → PartialEval.
68pub fn default_passes() -> Vec<Box<dyn Pass>> {
69    vec![
70        Box::new(BetaReduce),
71        Box::new(Inline),
72        Box::new(CaseReduce),
73        Box::new(Dce),
74        Box::new(PartialEval),
75    ]
76}
77
78/// Run the default optimization pipeline to fixed point.
79pub fn optimize(expr: &mut CoreExpr) -> Result<PipelineStats, String> {
80    run_pipeline(&default_passes(), expr)
81}
82
83/// Run a single pass to fixed point (convenience).
84/// Returns the number of times the pass reported a change.
85pub fn run_pass_to_fixpoint(pass: &dyn Pass, expr: &mut CoreExpr) -> Result<usize, String> {
86    let mut changes = 0;
87    loop {
88        if !pass.run(expr) {
89            break;
90        }
91        changes += 1;
92        if changes >= MAX_PIPELINE_ITERATIONS {
93            return Err(format!(
94                "Pass '{}' exceeded maximum iterations ({}) in run_pass_to_fixpoint.",
95                pass.name(),
96                MAX_PIPELINE_ITERATIONS
97            ));
98        }
99    }
100    Ok(changes)
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use tidepool_repr::{CoreFrame, RecursiveTree, VarId};

    /// A pass that reports a change a fixed number of times, then settles.
    struct TestPass {
        name: String,
        changes_remaining: Cell<usize>,
    }

    impl TestPass {
        /// Creates a pass called `name` that reports `changes` changes before settling.
        fn new(name: &str, changes: usize) -> Self {
            TestPass {
                name: name.to_string(),
                changes_remaining: Cell::new(changes),
            }
        }
    }

    impl Pass for TestPass {
        fn run(&self, _expr: &mut CoreExpr) -> Changed {
            match self.changes_remaining.get() {
                0 => false,
                left => {
                    self.changes_remaining.set(left - 1);
                    true
                }
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// A single-variable expression; the test passes never inspect it.
    fn dummy_expr() -> CoreExpr {
        RecursiveTree {
            nodes: vec![CoreFrame::Var(VarId(0))],
        }
    }

    /// Boxes each pass so the list matches `run_pipeline`'s `&[Box<dyn Pass>]` parameter.
    fn boxed(passes: Vec<TestPass>) -> Vec<Box<dyn Pass>> {
        passes
            .into_iter()
            .map(|p| Box::new(p) as Box<dyn Pass>)
            .collect()
    }

    #[test]
    fn test_empty_pipeline() {
        let mut expr = dummy_expr();
        let stats = run_pipeline(&boxed(Vec::new()), &mut expr).unwrap();
        assert_eq!(stats.iterations, 0);
        assert!(stats.pass_invocations.is_empty());
    }

    #[test]
    fn test_single_noop_pass() {
        let mut expr = dummy_expr();
        let passes = boxed(vec![TestPass::new("NoOp", 0)]);
        let stats = run_pipeline(&passes, &mut expr).unwrap();
        assert_eq!(stats.iterations, 1);
        assert_eq!(stats.pass_invocations[0], ("NoOp".to_string(), 1));
    }

    #[test]
    fn test_single_changing_pass() {
        let mut expr = dummy_expr();
        let passes = boxed(vec![TestPass::new("Changing", 1)]);
        let stats = run_pipeline(&passes, &mut expr).unwrap();
        assert_eq!(stats.iterations, 2);
        assert_eq!(stats.pass_invocations[0], ("Changing".to_string(), 2));
    }

    #[test]
    fn test_fixed_point_terminates() {
        let mut expr = dummy_expr();
        let n: usize = 5;
        let passes = boxed(vec![TestPass::new("N-Times", n)]);
        let stats = run_pipeline(&passes, &mut expr).unwrap();
        assert_eq!(stats.iterations, n + 1);
        assert_eq!(stats.pass_invocations[0], ("N-Times".to_string(), n + 1));
    }

    #[test]
    fn test_pipeline_stats() {
        let mut expr = dummy_expr();
        let passes = boxed(vec![TestPass::new("P1", 2), TestPass::new("P2", 1)]);
        let stats = run_pipeline(&passes, &mut expr).unwrap();
        // Iteration 1: P1 changes (2->1), P2 changes (1->0) -> keep going.
        // Iteration 2: P1 changes (1->0), P2 quiet          -> keep going.
        // Iteration 3: both quiet                            -> fixed point.
        assert_eq!(stats.iterations, 3);
        assert_eq!(stats.pass_invocations[0], ("P1".to_string(), 3));
        assert_eq!(stats.pass_invocations[1], ("P2".to_string(), 3));
    }

    #[test]
    fn test_run_pass_to_fixpoint() {
        let mut expr = dummy_expr();
        let n: usize = 3;
        let pass = TestPass::new("N-Times", n);
        let changes = run_pass_to_fixpoint(&pass, &mut expr).unwrap();
        assert_eq!(changes, n);
    }

    #[test]
    fn test_infinite_loop_returns_err() {
        /// A pass that claims to change the expression on every run.
        struct InfinitePass;

        impl Pass for InfinitePass {
            fn run(&self, _expr: &mut CoreExpr) -> Changed {
                true
            }

            fn name(&self) -> &str {
                "Infinite"
            }
        }

        let mut expr = dummy_expr();
        let passes: Vec<Box<dyn Pass>> = vec![Box::new(InfinitePass)];
        let err = run_pipeline(&passes, &mut expr)
            .expect_err("a pass that always reports a change must hit the iteration cap");
        assert!(err.contains("exceeded maximum iterations"));
    }
}