tidepool_optimize/
pipeline.rs1use crate::beta::BetaReduce;
2use crate::case_reduce::CaseReduce;
3use crate::dce::Dce;
4use crate::inline::Inline;
5use crate::partial::PartialEval;
6use tidepool_eval::pass::{Changed, Pass};
7use tidepool_repr::CoreExpr;
8
/// Upper bound used by both fixpoint drivers in this module.
///
/// [`run_pipeline`] gives up once it would start sweep number
/// `MAX_PIPELINE_ITERATIONS + 1`. [`run_pass_to_fixpoint`] gives up once a single
/// pass has reported this many consecutive changes. In both cases the caller gets
/// an `Err` instead of an endless loop.
pub const MAX_PIPELINE_ITERATIONS: usize = 1_000;

/// Counters reported by [`run_pipeline`] after a successful run.
#[derive(Clone, Debug, Default)]
pub struct PipelineStats {
    /// How many sweeps over the pass list were performed, counting the final
    /// sweep in which no pass reported a change.
    pub iterations: usize,
    /// One `(pass name, times run)` entry per pass, in the order the passes
    /// were supplied.
    pub pass_invocations: Vec<(String, usize)>,
}

22pub fn run_pipeline(
28 passes: &[Box<dyn Pass>],
29 expr: &mut CoreExpr,
30) -> Result<PipelineStats, String> {
31 let mut stats = PipelineStats {
32 iterations: 0,
33 pass_invocations: passes.iter().map(|p| (p.name().to_string(), 0)).collect(),
34 };
35
36 if passes.is_empty() {
37 return Ok(stats);
38 }
39
40 loop {
41 stats.iterations += 1;
42 if stats.iterations > MAX_PIPELINE_ITERATIONS {
43 return Err(format!(
44 "Optimization pipeline exceeded maximum iterations ({}). Potential infinite loop in passes: {:?}",
45 MAX_PIPELINE_ITERATIONS,
46 passes.iter().map(|p| p.name()).collect::<Vec<_>>()
47 ));
48 }
49
50 let mut changed: Changed = false;
51 for (i, pass) in passes.iter().enumerate() {
52 if pass.run(expr) {
53 changed = true;
54 }
55 stats.pass_invocations[i].1 += 1;
56 }
57
58 if !changed {
59 break;
60 }
61 }
62
63 Ok(stats)
64}
65
66pub fn default_passes() -> Vec<Box<dyn Pass>> {
69 vec![
70 Box::new(BetaReduce),
71 Box::new(Inline),
72 Box::new(CaseReduce),
73 Box::new(Dce),
74 Box::new(PartialEval),
75 ]
76}
77
78pub fn optimize(expr: &mut CoreExpr) -> Result<PipelineStats, String> {
80 run_pipeline(&default_passes(), expr)
81}
82
83pub fn run_pass_to_fixpoint(pass: &dyn Pass, expr: &mut CoreExpr) -> Result<usize, String> {
86 let mut changes = 0;
87 loop {
88 if !pass.run(expr) {
89 break;
90 }
91 changes += 1;
92 if changes >= MAX_PIPELINE_ITERATIONS {
93 return Err(format!(
94 "Pass '{}' exceeded maximum iterations ({}) in run_pass_to_fixpoint.",
95 pass.name(),
96 MAX_PIPELINE_ITERATIONS
97 ));
98 }
99 }
100 Ok(changes)
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use tidepool_repr::{CoreFrame, RecursiveTree, VarId};

    /// A pass that reports a change for its first `changes_remaining` runs and
    /// no change afterwards. It never touches the expression itself.
    struct TestPass {
        name: String,
        changes_remaining: Cell<usize>,
    }

    impl Pass for TestPass {
        fn run(&self, _expr: &mut CoreExpr) -> Changed {
            let rem = self.changes_remaining.get();
            if rem > 0 {
                self.changes_remaining.set(rem - 1);
                true
            } else {
                false
            }
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// A pass that claims to change the expression on every run, i.e. it
    /// never converges.
    struct InfinitePass;

    impl Pass for InfinitePass {
        fn run(&self, _expr: &mut CoreExpr) -> Changed {
            true
        }
        fn name(&self) -> &str {
            "Infinite"
        }
    }

    /// Smallest expression the passes can be handed: a single variable node.
    fn dummy_expr() -> CoreExpr {
        RecursiveTree {
            nodes: vec![CoreFrame::Var(VarId(0))],
        }
    }

    /// Convenience constructor for a boxed [`TestPass`].
    fn test_pass(name: &str, changes: usize) -> Box<TestPass> {
        Box::new(TestPass {
            name: name.to_string(),
            changes_remaining: Cell::new(changes),
        })
    }

    #[test]
    fn test_empty_pipeline() {
        let mut expr = dummy_expr();
        let stats = run_pipeline(&[], &mut expr).unwrap();
        assert_eq!(stats.iterations, 0);
        assert!(stats.pass_invocations.is_empty());
    }

    #[test]
    fn test_single_noop_pass() {
        let mut expr = dummy_expr();
        let pass = test_pass("NoOp", 0);
        let stats = run_pipeline(&[pass], &mut expr).unwrap();
        assert_eq!(stats.iterations, 1);
        assert_eq!(stats.pass_invocations[0], ("NoOp".to_string(), 1));
    }

    #[test]
    fn test_single_changing_pass() {
        let mut expr = dummy_expr();
        let pass = test_pass("Changing", 1);
        let stats = run_pipeline(&[pass], &mut expr).unwrap();
        assert_eq!(stats.iterations, 2);
        assert_eq!(stats.pass_invocations[0], ("Changing".to_string(), 2));
    }

    #[test]
    fn test_fixed_point_terminates() {
        let mut expr = dummy_expr();
        let n = 5;
        let pass = test_pass("N-Times", n);
        let stats = run_pipeline(&[pass], &mut expr).unwrap();
        assert_eq!(stats.iterations, n + 1);
        assert_eq!(stats.pass_invocations[0], ("N-Times".to_string(), n + 1));
    }

    #[test]
    fn test_pipeline_stats() {
        let mut expr = dummy_expr();
        let pass1 = test_pass("P1", 2);
        let pass2 = test_pass("P2", 1);
        let stats = run_pipeline(&[pass1, pass2], &mut expr).unwrap();
        // Sweep 1: P1 and P2 both change. Sweep 2: only P1 changes.
        // Sweep 3: nothing changes, so the pipeline stops. Both passes ran in
        // every sweep.
        assert_eq!(stats.iterations, 3);
        assert_eq!(stats.pass_invocations[0], ("P1".to_string(), 3));
        assert_eq!(stats.pass_invocations[1], ("P2".to_string(), 3));
    }

    #[test]
    fn test_pipeline_just_under_limit_succeeds() {
        let mut expr = dummy_expr();
        let pass = test_pass("NearLimit", MAX_PIPELINE_ITERATIONS - 1);
        let stats = run_pipeline(&[pass], &mut expr).unwrap();
        // MAX - 1 changing sweeps plus one quiet sweep is exactly MAX sweeps.
        assert_eq!(stats.iterations, MAX_PIPELINE_ITERATIONS);
    }

    #[test]
    fn test_pipeline_at_limit_errors() {
        let mut expr = dummy_expr();
        let pass = test_pass("AtLimit", MAX_PIPELINE_ITERATIONS);
        let err = run_pipeline(&[pass], &mut expr).unwrap_err();
        assert!(err.contains("exceeded maximum iterations"));
        assert!(err.contains("AtLimit"));
    }

    #[test]
    fn test_infinite_loop_returns_err() {
        let mut expr = dummy_expr();
        let result = run_pipeline(&[Box::new(InfinitePass)], &mut expr);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("exceeded maximum iterations"));
    }

    #[test]
    fn test_run_pass_to_fixpoint() {
        let mut expr = dummy_expr();
        let n = 3;
        let pass = TestPass {
            name: "N-Times".to_string(),
            changes_remaining: Cell::new(n),
        };
        let changes = run_pass_to_fixpoint(&pass, &mut expr).unwrap();
        assert_eq!(changes, n);
    }

    #[test]
    fn test_run_pass_to_fixpoint_noop() {
        let mut expr = dummy_expr();
        let pass = test_pass("NoOp", 0);
        assert_eq!(run_pass_to_fixpoint(&*pass, &mut expr).unwrap(), 0);
    }

    #[test]
    fn test_run_pass_to_fixpoint_just_under_limit() {
        let mut expr = dummy_expr();
        let pass = test_pass("NearLimit", MAX_PIPELINE_ITERATIONS - 1);
        let changes = run_pass_to_fixpoint(&*pass, &mut expr).unwrap();
        assert_eq!(changes, MAX_PIPELINE_ITERATIONS - 1);
    }

    #[test]
    fn test_run_pass_to_fixpoint_infinite_returns_err() {
        let mut expr = dummy_expr();
        let err = run_pass_to_fixpoint(&InfinitePass, &mut expr).unwrap_err();
        assert!(err.contains("Infinite"));
        assert!(err.contains("exceeded maximum iterations"));
    }
}