Skip to main content

pi/
hostcall_rewrite.rs

1//! Constrained hostcall rewrite planner for hot-path marshalling.
2
/// The family a [`HostcallRewritePlan`] belongs to.
///
/// [`HostcallRewriteEngine::select_plan`] never inspects the kind when ranking plans.
/// It only uses the kind to tell apart two equally cheap candidates.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HostcallRewritePlanKind {
    /// Canonical plan kind; presumably the unrewritten baseline path (confirm with callers).
    BaselineCanonical,
    /// Fast-opcode fusion rewrite kind.
    FastOpcodeFusion,
}
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq)]
10pub struct HostcallRewritePlan {
11    pub kind: HostcallRewritePlanKind,
12    pub estimated_cost: u32,
13    pub rule_id: &'static str,
14}
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq)]
17pub struct HostcallRewriteDecision {
18    pub selected: HostcallRewritePlan,
19    pub expected_cost_delta: i64,
20    pub fallback_reason: Option<&'static str>,
21}
22
/// Chooses between a baseline hostcall plan and cheaper rewrite candidates.
///
/// The engine carries a single switch. When it is off, every selection falls back to
/// the baseline.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct HostcallRewriteEngine {
    /// Whether `select_plan` may pick a rewrite at all.
    enabled: bool,
}
27
28impl HostcallRewriteEngine {
29    #[must_use]
30    pub const fn new(enabled: bool) -> Self {
31        Self { enabled }
32    }
33
34    #[must_use]
35    pub fn from_env() -> Self {
36        Self::from_opt(std::env::var("PI_HOSTCALL_EGRAPH_REWRITE").ok().as_deref())
37    }
38
39    #[must_use]
40    pub fn from_opt(value: Option<&str>) -> Self {
41        let enabled = value.is_none_or(|v| {
42            !matches!(
43                v.trim().to_ascii_lowercase().as_str(),
44                "0" | "false" | "off" | "disabled"
45            )
46        });
47        Self::new(enabled)
48    }
49
50    #[must_use]
51    pub const fn enabled(&self) -> bool {
52        self.enabled
53    }
54
55    #[must_use]
56    pub fn select_plan(
57        &self,
58        baseline: HostcallRewritePlan,
59        candidates: &[HostcallRewritePlan],
60    ) -> HostcallRewriteDecision {
61        if !self.enabled {
62            return HostcallRewriteDecision {
63                selected: baseline,
64                expected_cost_delta: 0,
65                fallback_reason: Some("rewrite_disabled"),
66            };
67        }
68
69        let mut best: Option<HostcallRewritePlan> = None;
70        let mut ambiguous = false;
71        for candidate in candidates {
72            if candidate.estimated_cost >= baseline.estimated_cost {
73                continue;
74            }
75            match best {
76                None => best = Some(*candidate),
77                Some(current) => {
78                    if candidate.estimated_cost < current.estimated_cost {
79                        best = Some(*candidate);
80                        ambiguous = false;
81                    } else if candidate.estimated_cost == current.estimated_cost
82                        && (candidate.kind != current.kind || candidate.rule_id != current.rule_id)
83                    {
84                        ambiguous = true;
85                    }
86                }
87            }
88        }
89
90        let Some(selected) = best else {
91            return HostcallRewriteDecision {
92                selected: baseline,
93                expected_cost_delta: 0,
94                fallback_reason: Some("no_better_candidate"),
95            };
96        };
97
98        if ambiguous {
99            return HostcallRewriteDecision {
100                selected: baseline,
101                expected_cost_delta: 0,
102                fallback_reason: Some("ambiguous_min_cost"),
103            };
104        }
105
106        HostcallRewriteDecision {
107            selected,
108            expected_cost_delta: i64::from(baseline.estimated_cost)
109                - i64::from(selected.estimated_cost),
110            fallback_reason: None,
111        }
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    const BASELINE: HostcallRewritePlan = HostcallRewritePlan {
        kind: HostcallRewritePlanKind::BaselineCanonical,
        estimated_cost: 100,
        rule_id: "baseline",
    };

    const FAST_FUSION: HostcallRewritePlan = HostcallRewritePlan {
        kind: HostcallRewritePlanKind::FastOpcodeFusion,
        estimated_cost: 35,
        rule_id: "fuse_hash_dispatch_fast_opcode",
    };

    /// Builds a `FastOpcodeFusion` candidate with the given cost and rule id.
    fn fusion(estimated_cost: u32, rule_id: &'static str) -> HostcallRewritePlan {
        HostcallRewritePlan {
            kind: HostcallRewritePlanKind::FastOpcodeFusion,
            estimated_cost,
            rule_id,
        }
    }

    #[test]
    fn rewrite_engine_selects_unique_lower_cost_plan() {
        let engine = HostcallRewriteEngine::new(true);
        let decision = engine.select_plan(BASELINE, &[FAST_FUSION]);
        assert_eq!(decision.selected, FAST_FUSION);
        assert_eq!(decision.expected_cost_delta, 65);
        assert!(decision.fallback_reason.is_none());
    }

    #[test]
    fn rewrite_engine_rejects_when_disabled() {
        let engine = HostcallRewriteEngine::new(false);
        let decision = engine.select_plan(BASELINE, &[FAST_FUSION]);
        assert_eq!(decision.selected, BASELINE);
        assert_eq!(decision.expected_cost_delta, 0);
        assert_eq!(decision.fallback_reason, Some("rewrite_disabled"));
    }

    #[test]
    fn rewrite_engine_rejects_ambiguous_min_cost_candidates() {
        let engine = HostcallRewriteEngine::new(true);
        let alt = fusion(35, "fuse_validate_dispatch_fast_opcode");
        let decision = engine.select_plan(BASELINE, &[FAST_FUSION, alt]);
        assert_eq!(decision.selected, BASELINE);
        assert_eq!(decision.fallback_reason, Some("ambiguous_min_cost"));
    }

    #[test]
    fn rewrite_engine_tie_detected_regardless_of_order() {
        let engine = HostcallRewriteEngine::new(true);
        let alt = fusion(35, "fuse_validate_dispatch_fast_opcode");
        let decision = engine.select_plan(BASELINE, &[alt, FAST_FUSION]);
        assert_eq!(decision.selected, BASELINE);
        assert_eq!(decision.fallback_reason, Some("ambiguous_min_cost"));
    }

    #[test]
    fn rewrite_engine_tie_on_kind_alone_is_ambiguous() {
        // Same cost and rule id, different kind: still two distinct plans.
        let engine = HostcallRewriteEngine::new(true);
        let other_kind = HostcallRewritePlan {
            kind: HostcallRewritePlanKind::BaselineCanonical,
            ..FAST_FUSION
        };
        let decision = engine.select_plan(BASELINE, &[FAST_FUSION, other_kind]);
        assert_eq!(decision.selected, BASELINE);
        assert_eq!(decision.fallback_reason, Some("ambiguous_min_cost"));
    }

    #[test]
    fn rewrite_engine_cheaper_candidate_clears_earlier_tie() {
        let engine = HostcallRewriteEngine::new(true);
        let candidates = [fusion(40, "tie_a"), fusion(40, "tie_b"), fusion(30, "winner")];
        let decision = engine.select_plan(BASELINE, &candidates);
        assert_eq!(decision.selected, candidates[2]);
        assert_eq!(decision.expected_cost_delta, 70);
        assert!(decision.fallback_reason.is_none());
    }

    #[test]
    fn rewrite_engine_tie_above_minimum_is_not_ambiguous() {
        let engine = HostcallRewriteEngine::new(true);
        let candidates = [fusion(50, "tie_a"), fusion(50, "tie_b"), FAST_FUSION];
        let decision = engine.select_plan(BASELINE, &candidates);
        assert_eq!(decision.selected, FAST_FUSION);
        assert!(decision.fallback_reason.is_none());
    }

    #[test]
    fn rewrite_engine_rejects_no_better_candidate() {
        let engine = HostcallRewriteEngine::new(true);
        let worse = fusion(120, "slow_rule");
        let equal = fusion(100, "equal_rule");
        let decision = engine.select_plan(BASELINE, &[worse, equal]);
        assert_eq!(decision.selected, BASELINE);
        assert_eq!(decision.expected_cost_delta, 0);
        assert_eq!(decision.fallback_reason, Some("no_better_candidate"));
    }

    #[test]
    fn rewrite_engine_selects_cheapest_among_multiple_candidates() {
        let engine = HostcallRewriteEngine::new(true);
        let mid = fusion(50, "mid_rule");
        let cheapest = fusion(20, "cheapest_rule");
        let decision = engine.select_plan(BASELINE, &[mid, FAST_FUSION, cheapest]);
        assert_eq!(decision.selected, cheapest);
        assert_eq!(decision.expected_cost_delta, 80);
        assert!(decision.fallback_reason.is_none());
    }

    #[test]
    fn rewrite_engine_empty_candidates_returns_no_better() {
        let engine = HostcallRewriteEngine::new(true);
        let decision = engine.select_plan(BASELINE, &[]);
        assert_eq!(decision.selected, BASELINE);
        assert_eq!(decision.fallback_reason, Some("no_better_candidate"));
    }

    #[test]
    fn rewrite_engine_ambiguity_resolved_by_same_kind_and_rule() {
        let engine = HostcallRewriteEngine::new(true);
        // Same kind AND same rule_id = NOT ambiguous (they're the same plan).
        let dup = fusion(35, "fuse_hash_dispatch_fast_opcode");
        let decision = engine.select_plan(BASELINE, &[FAST_FUSION, dup]);
        assert_eq!(decision.selected, FAST_FUSION);
        assert!(decision.fallback_reason.is_none());
    }

    #[test]
    fn rewrite_engine_accessors() {
        assert!(HostcallRewriteEngine::new(true).enabled());
        assert!(!HostcallRewriteEngine::new(false).enabled());
    }

    #[test]
    fn plan_kind_variants_distinct() {
        assert_ne!(
            HostcallRewritePlanKind::BaselineCanonical,
            HostcallRewritePlanKind::FastOpcodeFusion
        );
    }

    #[test]
    fn from_env_returns_valid_engine() {
        // Smoke test: from_env() should not panic regardless of env state.
        let engine = HostcallRewriteEngine::from_env();
        let _ = engine.enabled();
    }

    #[test]
    fn from_opt_disabled_by_known_off_values() {
        for value in ["0", "false", "off", "disabled", "FALSE", "OFF", "Disabled"] {
            let engine = HostcallRewriteEngine::from_opt(Some(value));
            assert!(!engine.enabled(), "should be disabled for '{value}'");
        }
    }

    #[test]
    fn from_opt_trims_surrounding_whitespace() {
        for value in [" off", "false\n", "\t0 ", "  Disabled  "] {
            let engine = HostcallRewriteEngine::from_opt(Some(value));
            assert!(!engine.enabled(), "should be disabled for {value:?}");
        }
    }

    #[test]
    fn from_opt_enabled_for_other_values_and_none() {
        // None (env var unset) -> enabled.
        assert!(HostcallRewriteEngine::from_opt(None).enabled());

        // Any non-disabled value, including the empty string, -> enabled.
        for value in ["1", "true", "on", "yes", "anything_else", "", "   "] {
            let engine = HostcallRewriteEngine::from_opt(Some(value));
            assert!(engine.enabled(), "should be enabled for {value:?}");
        }
    }

    #[test]
    fn rewrite_decision_cost_delta_correct_sign() {
        let engine = HostcallRewriteEngine::new(true);
        let decision = engine.select_plan(BASELINE, &[FAST_FUSION]);
        assert!(
            decision.expected_cost_delta > 0,
            "positive delta means improvement"
        );
        assert_eq!(
            decision.expected_cost_delta,
            i64::from(BASELINE.estimated_cost) - i64::from(FAST_FUSION.estimated_cost)
        );
    }

    // ── Property tests ──

    mod proptest_rewrite {
        use super::*;
        use proptest::prelude::*;

        fn arb_kind() -> impl Strategy<Value = HostcallRewritePlanKind> {
            prop::sample::select(vec![
                HostcallRewritePlanKind::BaselineCanonical,
                HostcallRewritePlanKind::FastOpcodeFusion,
            ])
        }

        fn arb_rule_id() -> impl Strategy<Value = &'static str> {
            prop::sample::select(vec!["rule_a", "rule_b", "rule_c"])
        }

        // A narrow cost range makes cost ties (and thus the ambiguity branch) common.
        // Varied rule ids ensure ties can differ by rule as well as by kind.
        fn arb_plan() -> impl Strategy<Value = HostcallRewritePlan> {
            (arb_kind(), 0..50u32, arb_rule_id()).prop_map(|(kind, cost, rule_id)| {
                HostcallRewritePlan {
                    kind,
                    estimated_cost: cost,
                    rule_id,
                }
            })
        }

        proptest! {
            #[test]
            fn selected_cost_never_exceeds_baseline(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
            ) {
                let engine = HostcallRewriteEngine::new(true);
                let decision = engine.select_plan(baseline, &candidates);
                prop_assert!(
                    decision.selected.estimated_cost <= baseline.estimated_cost,
                    "selected cost {} must not exceed baseline {}",
                    decision.selected.estimated_cost,
                    baseline.estimated_cost,
                );
            }

            #[test]
            fn cost_delta_is_nonnegative(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
            ) {
                let engine = HostcallRewriteEngine::new(true);
                let decision = engine.select_plan(baseline, &candidates);
                prop_assert!(
                    decision.expected_cost_delta >= 0,
                    "cost delta must be non-negative, got {}",
                    decision.expected_cost_delta,
                );
            }

            #[test]
            fn cost_delta_equals_baseline_minus_selected(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
            ) {
                let engine = HostcallRewriteEngine::new(true);
                let decision = engine.select_plan(baseline, &candidates);
                let expected_delta = i64::from(baseline.estimated_cost)
                    - i64::from(decision.selected.estimated_cost);
                prop_assert_eq!(
                    decision.expected_cost_delta, expected_delta,
                    "delta must equal baseline - selected"
                );
            }

            #[test]
            fn rewrite_only_when_strictly_cheaper_and_minimal(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
            ) {
                let engine = HostcallRewriteEngine::new(true);
                let decision = engine.select_plan(baseline, &candidates);
                if decision.fallback_reason.is_none() {
                    prop_assert!(decision.selected.estimated_cost < baseline.estimated_cost);
                    prop_assert!(decision.expected_cost_delta > 0);
                    prop_assert!(
                        candidates
                            .iter()
                            .all(|c| c.estimated_cost >= decision.selected.estimated_cost),
                        "selected plan must be a minimum-cost candidate"
                    );
                } else {
                    prop_assert_eq!(decision.selected, baseline);
                    prop_assert_eq!(decision.expected_cost_delta, 0);
                }
            }

            #[test]
            fn select_plan_is_order_independent(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
            ) {
                let engine = HostcallRewriteEngine::new(true);
                let forward = engine.select_plan(baseline, &candidates);
                let mut reversed = candidates.clone();
                reversed.reverse();
                let backward = engine.select_plan(baseline, &reversed);
                prop_assert_eq!(forward, backward);
            }

            #[test]
            fn disabled_engine_always_returns_baseline(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
            ) {
                let engine = HostcallRewriteEngine::new(false);
                let decision = engine.select_plan(baseline, &candidates);
                prop_assert_eq!(decision.selected, baseline);
                prop_assert_eq!(decision.expected_cost_delta, 0);
                prop_assert_eq!(decision.fallback_reason, Some("rewrite_disabled"));
            }

            #[test]
            fn select_plan_is_deterministic(
                baseline in arb_plan(),
                candidates in prop::collection::vec(arb_plan(), 0..10),
                enabled in any::<bool>(),
            ) {
                let engine = HostcallRewriteEngine::new(enabled);
                let d1 = engine.select_plan(baseline, &candidates);
                let d2 = engine.select_plan(baseline, &candidates);
                prop_assert_eq!(d1.selected, d2.selected);
                prop_assert_eq!(d1.expected_cost_delta, d2.expected_cost_delta);
                prop_assert_eq!(d1.fallback_reason, d2.fallback_reason);
            }
        }
    }
}