/// Which family of hostcall plan a [`HostcallRewritePlan`] belongs to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HostcallRewritePlanKind {
    /// The canonical, un-rewritten plan (used for the baseline in this file's tests).
    BaselineCanonical,
    /// A plan produced by a fast-opcode fusion rewrite rule.
    FastOpcodeFusion,
}

9#[derive(Debug, Clone, Copy, PartialEq, Eq)]
10pub struct HostcallRewritePlan {
11 pub kind: HostcallRewritePlanKind,
12 pub estimated_cost: u32,
13 pub rule_id: &'static str,
14}
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
17pub struct HostcallRewriteDecision {
18 pub selected: HostcallRewritePlan,
19 pub expected_cost_delta: i64,
20 pub fallback_reason: Option<&'static str>,
21}
22
/// Chooses between a baseline hostcall plan and cheaper rewrite candidates.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct HostcallRewriteEngine {
    /// When `false`, plan selection always keeps the baseline.
    enabled: bool,
}

28impl HostcallRewriteEngine {
29 #[must_use]
30 pub const fn new(enabled: bool) -> Self {
31 Self { enabled }
32 }
33
34 #[must_use]
35 pub fn from_env() -> Self {
36 Self::from_opt(std::env::var("PI_HOSTCALL_EGRAPH_REWRITE").ok().as_deref())
37 }
38
39 #[must_use]
40 pub fn from_opt(value: Option<&str>) -> Self {
41 let enabled = value.is_none_or(|v| {
42 !matches!(
43 v.trim().to_ascii_lowercase().as_str(),
44 "0" | "false" | "off" | "disabled"
45 )
46 });
47 Self::new(enabled)
48 }
49
50 #[must_use]
51 pub const fn enabled(&self) -> bool {
52 self.enabled
53 }
54
55 #[must_use]
56 pub fn select_plan(
57 &self,
58 baseline: HostcallRewritePlan,
59 candidates: &[HostcallRewritePlan],
60 ) -> HostcallRewriteDecision {
61 if !self.enabled {
62 return HostcallRewriteDecision {
63 selected: baseline,
64 expected_cost_delta: 0,
65 fallback_reason: Some("rewrite_disabled"),
66 };
67 }
68
69 let mut best: Option<HostcallRewritePlan> = None;
70 let mut ambiguous = false;
71 for candidate in candidates {
72 if candidate.estimated_cost >= baseline.estimated_cost {
73 continue;
74 }
75 match best {
76 None => best = Some(*candidate),
77 Some(current) => {
78 if candidate.estimated_cost < current.estimated_cost {
79 best = Some(*candidate);
80 ambiguous = false;
81 } else if candidate.estimated_cost == current.estimated_cost
82 && (candidate.kind != current.kind || candidate.rule_id != current.rule_id)
83 {
84 ambiguous = true;
85 }
86 }
87 }
88 }
89
90 let Some(selected) = best else {
91 return HostcallRewriteDecision {
92 selected: baseline,
93 expected_cost_delta: 0,
94 fallback_reason: Some("no_better_candidate"),
95 };
96 };
97
98 if ambiguous {
99 return HostcallRewriteDecision {
100 selected: baseline,
101 expected_cost_delta: 0,
102 fallback_reason: Some("ambiguous_min_cost"),
103 };
104 }
105
106 HostcallRewriteDecision {
107 selected,
108 expected_cost_delta: i64::from(baseline.estimated_cost)
109 - i64::from(selected.estimated_cost),
110 fallback_reason: None,
111 }
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    const BASELINE: HostcallRewritePlan = HostcallRewritePlan {
        kind: HostcallRewritePlanKind::BaselineCanonical,
        estimated_cost: 100,
        rule_id: "baseline",
    };

    const FAST_FUSION: HostcallRewritePlan = HostcallRewritePlan {
        kind: HostcallRewritePlanKind::FastOpcodeFusion,
        estimated_cost: 35,
        rule_id: "fuse_hash_dispatch_fast_opcode",
    };

    #[test]
    fn rewrite_engine_selects_unique_lower_cost_plan() {
        let engine = HostcallRewriteEngine::new(true);
        let decision = engine.select_plan(BASELINE, &[FAST_FUSION]);
        assert_eq!(decision.selected, FAST_FUSION);
        assert_eq!(decision.expected_cost_delta, 65);
        assert!(decision.fallback_reason.is_none());
    }

    #[test]
    fn rewrite_engine_rejects_when_disabled() {
        let engine = HostcallRewriteEngine::new(false);
        let decision = engine.select_plan(BASELINE, &[FAST_FUSION]);
        assert_eq!(decision.selected, BASELINE);
        assert_eq!(decision.expected_cost_delta, 0);
        assert_eq!(decision.fallback_reason, Some("rewrite_disabled"));
    }

    #[test]
    fn rewrite_engine_rejects_ambiguous_min_cost_candidates() {
        let engine = HostcallRewriteEngine::new(true);
        let alt = HostcallRewritePlan {
            kind: HostcallRewritePlanKind::FastOpcodeFusion,
            estimated_cost: 35,
            rule_id: "fuse_validate_dispatch_fast_opcode",
        };
        let decision = engine.select_plan(BASELINE, &[FAST_FUSION, alt]);
        assert_eq!(decision.selected, BASELINE);
        assert_eq!(decision.fallback_reason, Some("ambiguous_min_cost"));
    }

    #[test]
    fn rewrite_engine_rejects_no_better_candidate() {
        let engine = HostcallRewriteEngine::new(true);
        let worse = HostcallRewritePlan {
            kind: HostcallRewritePlanKind::FastOpcodeFusion,
            estimated_cost: 120,
            rule_id: "slow_rule",
        };
        let equal = HostcallRewritePlan {
            kind: HostcallRewritePlanKind::FastOpcodeFusion,
            estimated_cost: 100,
            rule_id: "equal_rule",
        };
        let decision = engine.select_plan(BASELINE, &[worse, equal]);
        assert_eq!(decision.selected, BASELINE);
        assert_eq!(decision.expected_cost_delta, 0);
        assert_eq!(decision.fallback_reason, Some("no_better_candidate"));
    }

    #[test]
    fn rewrite_engine_selects_cheapest_among_multiple_candidates() {
        let engine = HostcallRewriteEngine::new(true);
        let mid = HostcallRewritePlan {
            kind: HostcallRewritePlanKind::FastOpcodeFusion,
            estimated_cost: 50,
            rule_id: "mid_rule",
        };
        let cheapest = HostcallRewritePlan {
            kind: HostcallRewritePlanKind::FastOpcodeFusion,
            estimated_cost: 20,
            rule_id: "cheapest_rule",
        };
        let decision = engine.select_plan(BASELINE, &[mid, FAST_FUSION, cheapest]);
        assert_eq!(decision.selected, cheapest);
        assert_eq!(decision.expected_cost_delta, 80);
        assert!(decision.fallback_reason.is_none());
    }

    #[test]
    fn rewrite_engine_empty_candidates_returns_no_better() {
        let engine = HostcallRewriteEngine::new(true);
        let decision = engine.select_plan(BASELINE, &[]);
        assert_eq!(decision.selected, BASELINE);
        assert_eq!(decision.fallback_reason, Some("no_better_candidate"));
    }

    #[test]
    fn rewrite_engine_ambiguity_resolved_by_same_kind_and_rule() {
        let engine = HostcallRewriteEngine::new(true);
        let dup = HostcallRewritePlan {
            kind: HostcallRewritePlanKind::FastOpcodeFusion,
            estimated_cost: 35,
            rule_id: "fuse_hash_dispatch_fast_opcode",
        };
        let decision = engine.select_plan(BASELINE, &[FAST_FUSION, dup]);
        assert_eq!(decision.selected, FAST_FUSION);
        assert!(decision.fallback_reason.is_none());
    }

    #[test]
    fn rewrite_engine_tie_is_cleared_by_later_cheaper_candidate() {
        // Regression guard: a tie at one cost level must not block a strictly
        // cheaper candidate that appears later in the list.
        let engine = HostcallRewriteEngine::new(true);
        let alt = HostcallRewritePlan {
            kind: HostcallRewritePlanKind::FastOpcodeFusion,
            estimated_cost: 35,
            rule_id: "fuse_validate_dispatch_fast_opcode",
        };
        let cheapest = HostcallRewritePlan {
            kind: HostcallRewritePlanKind::FastOpcodeFusion,
            estimated_cost: 20,
            rule_id: "cheapest_rule",
        };
        let decision = engine.select_plan(BASELINE, &[FAST_FUSION, alt, cheapest]);
        assert_eq!(decision.selected, cheapest);
        assert_eq!(decision.expected_cost_delta, 80);
        assert!(decision.fallback_reason.is_none());
    }

    #[test]
    fn rewrite_engine_tie_differing_only_by_kind_is_ambiguous() {
        let engine = HostcallRewriteEngine::new(true);
        let same_rule_other_kind = HostcallRewritePlan {
            kind: HostcallRewritePlanKind::BaselineCanonical,
            estimated_cost: 35,
            rule_id: "fuse_hash_dispatch_fast_opcode",
        };
        let decision = engine.select_plan(BASELINE, &[FAST_FUSION, same_rule_other_kind]);
        assert_eq!(decision.selected, BASELINE);
        assert_eq!(decision.expected_cost_delta, 0);
        assert_eq!(decision.fallback_reason, Some("ambiguous_min_cost"));
    }

    #[test]
    fn rewrite_engine_accessors() {
        let enabled = HostcallRewriteEngine::new(true);
        assert!(enabled.enabled());
        let disabled = HostcallRewriteEngine::new(false);
        assert!(!disabled.enabled());
    }

    #[test]
    fn plan_kind_variants_distinct() {
        assert_ne!(
            HostcallRewritePlanKind::BaselineCanonical,
            HostcallRewritePlanKind::FastOpcodeFusion
        );
    }

    #[test]
    fn from_env_returns_valid_engine() {
        // Smoke test only: the process environment is not controlled here, so this
        // just checks that reading it does not panic.
        let engine = HostcallRewriteEngine::from_env();
        let _ = engine.enabled();
    }

    #[test]
    fn from_opt_disabled_by_known_off_values() {
        for value in ["0", "false", "off", "disabled", "FALSE", "OFF", "Disabled"] {
            let engine = HostcallRewriteEngine::from_opt(Some(value));
            assert!(!engine.enabled(), "should be disabled for '{value}'");
        }
    }

    #[test]
    fn from_opt_trims_surrounding_whitespace() {
        for value in [" off", "false ", "\t0\n", "  DISABLED  "] {
            let engine = HostcallRewriteEngine::from_opt(Some(value));
            assert!(!engine.enabled(), "should be disabled for {value:?}");
        }
    }

    #[test]
    fn from_opt_enabled_for_other_values_and_none() {
        assert!(HostcallRewriteEngine::from_opt(None).enabled());

        for value in ["1", "true", "on", "yes", "anything_else"] {
            let engine = HostcallRewriteEngine::from_opt(Some(value));
            assert!(engine.enabled(), "should be enabled for '{value}'");
        }
    }

    #[test]
    fn from_opt_empty_string_keeps_rewrite_enabled() {
        // Only the explicit off-values disable the engine; an empty or blank
        // value does not.
        assert!(HostcallRewriteEngine::from_opt(Some("")).enabled());
        assert!(HostcallRewriteEngine::from_opt(Some("   ")).enabled());
    }

    #[test]
    fn rewrite_decision_cost_delta_correct_sign() {
        let engine = HostcallRewriteEngine::new(true);
        let decision = engine.select_plan(BASELINE, &[FAST_FUSION]);
        assert!(
            decision.expected_cost_delta > 0,
            "positive delta means improvement"
        );
        assert_eq!(
            decision.expected_cost_delta,
            i64::from(BASELINE.estimated_cost) - i64::from(FAST_FUSION.estimated_cost)
        );
    }

    mod proptest_rewrite {
        use super::*;
        use proptest::prelude::*;

        fn arb_kind() -> impl Strategy<Value = HostcallRewritePlanKind> {
            prop::sample::select(vec![
                HostcallRewritePlanKind::BaselineCanonical,
                HostcallRewritePlanKind::FastOpcodeFusion,
            ])
        }

        // Two rule ids, so that ties differing only by rule id are also generated.
        fn arb_rule_id() -> impl Strategy<Value = &'static str> {
            prop::sample::select(vec!["rule_a", "rule_b"])
        }

        fn arb_plan() -> impl Strategy<Value = HostcallRewritePlan> {
            (arb_kind(), 0..1000u32, arb_rule_id()).prop_map(|(kind, cost, rule_id)| {
                HostcallRewritePlan {
                    kind,
                    estimated_cost: cost,
                    rule_id,
                }
            })
        }

        proptest! {
            #[test]
            fn selected_cost_never_exceeds_baseline(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
            ) {
                let engine = HostcallRewriteEngine::new(true);
                let decision = engine.select_plan(baseline, &candidates);
                prop_assert!(
                    decision.selected.estimated_cost <= baseline.estimated_cost,
                    "selected cost {} must not exceed baseline {}",
                    decision.selected.estimated_cost,
                    baseline.estimated_cost
                );
            }

            #[test]
            fn cost_delta_is_nonnegative(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
            ) {
                let engine = HostcallRewriteEngine::new(true);
                let decision = engine.select_plan(baseline, &candidates);
                prop_assert!(
                    decision.expected_cost_delta >= 0,
                    "cost delta must be non-negative, got {}",
                    decision.expected_cost_delta
                );
            }

            #[test]
            fn cost_delta_equals_baseline_minus_selected(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
            ) {
                let engine = HostcallRewriteEngine::new(true);
                let decision = engine.select_plan(baseline, &candidates);
                let expected_delta = i64::from(baseline.estimated_cost)
                    - i64::from(decision.selected.estimated_cost);
                prop_assert_eq!(
                    decision.expected_cost_delta, expected_delta,
                    "delta must equal baseline - selected"
                );
            }

            #[test]
            fn disabled_engine_always_returns_baseline(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
            ) {
                let engine = HostcallRewriteEngine::new(false);
                let decision = engine.select_plan(baseline, &candidates);
                prop_assert_eq!(decision.selected, baseline);
                prop_assert_eq!(decision.expected_cost_delta, 0);
                prop_assert_eq!(decision.fallback_reason, Some("rewrite_disabled"));
            }

            #[test]
            fn fallback_keeps_baseline_with_zero_delta(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
                enabled in any::<bool>(),
            ) {
                let engine = HostcallRewriteEngine::new(enabled);
                let decision = engine.select_plan(baseline, &candidates);
                if decision.fallback_reason.is_some() {
                    prop_assert_eq!(decision.selected, baseline);
                    prop_assert_eq!(decision.expected_cost_delta, 0);
                }
            }

            #[test]
            fn accepted_rewrite_is_the_unique_cheapest_candidate(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
            ) {
                let engine = HostcallRewriteEngine::new(true);
                let decision = engine.select_plan(baseline, &candidates);
                if decision.fallback_reason.is_none() {
                    let chosen = decision.selected;
                    prop_assert!(chosen.estimated_cost < baseline.estimated_cost);
                    prop_assert!(decision.expected_cost_delta > 0);
                    for candidate in &candidates {
                        prop_assert!(candidate.estimated_cost >= chosen.estimated_cost);
                        if candidate.estimated_cost == chosen.estimated_cost {
                            prop_assert_eq!(*candidate, chosen);
                        }
                    }
                }
            }

            #[test]
            fn select_plan_is_deterministic(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
                enabled in any::<bool>(),
            ) {
                let engine = HostcallRewriteEngine::new(enabled);
                let d1 = engine.select_plan(baseline, &candidates);
                let d2 = engine.select_plan(baseline, &candidates);
                prop_assert_eq!(d1, d2);
            }
        }
    }
}