Skip to main content

tidepool_optimize/
partial.rs

1use std::collections::HashMap;
2use tidepool_eval::{Changed, Pass};
3use tidepool_repr::{Alt, AltCon, CoreExpr, CoreFrame, DataConId, Literal, PrimOpKind, VarId};
4
5/// A value that might be known during partial evaluation.
6#[derive(Debug, Clone, PartialEq, Eq)]
7enum PartialValue {
8    /// The value is statically known.
9    Known(KnownValue),
10    /// The value is only known at runtime.
11    Unknown,
12}
13
14/// A statically known value.
15#[derive(Debug, Clone, PartialEq, Eq)]
16enum KnownValue {
17    /// A literal value.
18    Lit(Literal),
19    /// A data constructor with known fields.
20    Con(DataConId, Vec<KnownValue>),
21}
22
23/// Environment mapping variables to their partial values.
24type PartialEnv = HashMap<VarId, PartialValue>;
25
/// First-order partial evaluation pass.
///
/// Propagates statically known literals and constructor values through
/// `let` bindings, folds integer primops whose arguments are known, and
/// resolves `case` expressions whose scrutinee is known. Lambdas and
/// applications are left in the residual program untouched.
///
/// The pass is stateless, so it is a cheap `Copy` unit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct PartialEval;
28
29impl Pass for PartialEval {
30    fn run(&self, expr: &mut CoreExpr) -> Changed {
31        if expr.nodes.is_empty() {
32            return false;
33        }
34        let mut new_nodes = Vec::new();
35        let (root_idx, _) = partial_eval_at(
36            expr,
37            expr.nodes.len() - 1,
38            &PartialEnv::new(),
39            &mut new_nodes,
40        );
41        let new_expr = CoreExpr { nodes: new_nodes }.extract_subtree(root_idx);
42        if new_expr != *expr {
43            *expr = new_expr;
44            true
45        } else {
46            false
47        }
48    }
49    fn name(&self) -> &str {
50        "PartialEval"
51    }
52}
53
54/// Recursively partially evaluate an expression at a given index.
55fn partial_eval_at(
56    expr: &CoreExpr,
57    idx: usize,
58    env: &PartialEnv,
59    new_nodes: &mut Vec<CoreFrame<usize>>,
60) -> (usize, PartialValue) {
61    match &expr.nodes[idx] {
62        CoreFrame::Var(v) => match env.get(v) {
63            Some(PartialValue::Known(kv)) => {
64                let ni = emit_known(kv, new_nodes);
65                (ni, PartialValue::Known(kv.clone()))
66            }
67            _ => {
68                let ni = new_nodes.len();
69                new_nodes.push(CoreFrame::Var(*v));
70                (ni, PartialValue::Unknown)
71            }
72        },
73        CoreFrame::Lit(lit) => {
74            let ni = new_nodes.len();
75            new_nodes.push(CoreFrame::Lit(lit.clone()));
76            (ni, PartialValue::Known(KnownValue::Lit(lit.clone())))
77        }
78        CoreFrame::Con { tag, fields } => {
79            let mut fi = Vec::new();
80            let mut fv = Vec::new();
81            for &f in fields {
82                let (i, v) = partial_eval_at(expr, f, env, new_nodes);
83                fi.push(i);
84                fv.push(v);
85            }
86            let ni = new_nodes.len();
87            new_nodes.push(CoreFrame::Con {
88                tag: *tag,
89                fields: fi,
90            });
91            let mut known_fields = Vec::new();
92            for v in fv {
93                if let PartialValue::Known(k) = v {
94                    known_fields.push(k);
95                } else {
96                    return (ni, PartialValue::Unknown);
97                }
98            }
99            (ni, PartialValue::Known(KnownValue::Con(*tag, known_fields)))
100        }
101        CoreFrame::LetNonRec { binder, rhs, body } => {
102            let (rhs_i, rhs_v) = partial_eval_at(expr, *rhs, env, new_nodes);
103            let mut new_env = env.clone();
104            new_env.insert(*binder, rhs_v.clone());
105            if matches!(rhs_v, PartialValue::Known(_)) {
106                // Known RHS: evaluate body with known binder, skip the let
107                partial_eval_at(expr, *body, &new_env, new_nodes)
108            } else {
109                let (body_i, body_v) = partial_eval_at(expr, *body, &new_env, new_nodes);
110                let ni = new_nodes.len();
111                new_nodes.push(CoreFrame::LetNonRec {
112                    binder: *binder,
113                    rhs: rhs_i,
114                    body: body_i,
115                });
116                (ni, body_v)
117            }
118        }
119        CoreFrame::LetRec { bindings, body } => {
120            let mut new_env = env.clone();
121            for (b, _) in bindings {
122                new_env.insert(*b, PartialValue::Unknown);
123            }
124            let mut nb = Vec::new();
125            for (b, r) in bindings {
126                let (ri, _) = partial_eval_at(expr, *r, &new_env, new_nodes);
127                nb.push((*b, ri));
128            }
129            let (bi, bv) = partial_eval_at(expr, *body, &new_env, new_nodes);
130            let ni = new_nodes.len();
131            new_nodes.push(CoreFrame::LetRec {
132                bindings: nb,
133                body: bi,
134            });
135            (ni, bv)
136        }
137        CoreFrame::Case {
138            scrutinee,
139            binder,
140            alts,
141        } => {
142            let (si, sv) = partial_eval_at(expr, *scrutinee, env, new_nodes);
143            match &sv {
144                PartialValue::Known(KnownValue::Con(tag, field_vals)) => {
145                    let matched = alts
146                        .iter()
147                        .find(|a| matches!(&a.con, AltCon::DataAlt(t) if t == tag))
148                        .or_else(|| alts.iter().find(|a| matches!(&a.con, AltCon::Default)));
149                    if let Some(alt) = matched {
150                        let mut new_env = env.clone();
151                        new_env.insert(*binder, sv.clone());
152                        if let AltCon::DataAlt(_) = &alt.con {
153                            for (b, fv) in alt.binders.iter().zip(field_vals.iter()) {
154                                new_env.insert(*b, PartialValue::Known(fv.clone()));
155                            }
156                        }
157                        partial_eval_at(expr, alt.body, &new_env, new_nodes)
158                    } else {
159                        emit_residual_case(expr, si, binder, alts, env, new_nodes)
160                    }
161                }
162                PartialValue::Known(KnownValue::Lit(lit)) => {
163                    let matched = alts
164                        .iter()
165                        .find(|a| matches!(&a.con, AltCon::LitAlt(l) if l == lit))
166                        .or_else(|| alts.iter().find(|a| matches!(&a.con, AltCon::Default)));
167                    if let Some(alt) = matched {
168                        let mut new_env = env.clone();
169                        new_env.insert(*binder, sv.clone());
170                        partial_eval_at(expr, alt.body, &new_env, new_nodes)
171                    } else {
172                        emit_residual_case(expr, si, binder, alts, env, new_nodes)
173                    }
174                }
175                PartialValue::Unknown => emit_residual_case(expr, si, binder, alts, env, new_nodes),
176            }
177        }
178        CoreFrame::PrimOp { op, args } => {
179            let mut ai = Vec::new();
180            let mut av = Vec::new();
181            for &a in args {
182                let (i, v) = partial_eval_at(expr, a, env, new_nodes);
183                ai.push(i);
184                av.push(v);
185            }
186            if let Some(result) = try_eval_primop(*op, &av) {
187                let ni = new_nodes.len();
188                new_nodes.push(CoreFrame::Lit(result.clone()));
189                (ni, PartialValue::Known(KnownValue::Lit(result)))
190            } else {
191                let ni = new_nodes.len();
192                new_nodes.push(CoreFrame::PrimOp { op: *op, args: ai });
193                (ni, PartialValue::Unknown)
194            }
195        }
196        CoreFrame::App { fun, arg } => {
197            let (fi, _) = partial_eval_at(expr, *fun, env, new_nodes);
198            let (ai, _) = partial_eval_at(expr, *arg, env, new_nodes);
199            let ni = new_nodes.len();
200            new_nodes.push(CoreFrame::App { fun: fi, arg: ai });
201            (ni, PartialValue::Unknown)
202        }
203        CoreFrame::Lam { binder, body } => {
204            let (bi, _) = partial_eval_at(expr, *body, env, new_nodes);
205            let ni = new_nodes.len();
206            new_nodes.push(CoreFrame::Lam {
207                binder: *binder,
208                body: bi,
209            });
210            (ni, PartialValue::Unknown)
211        }
212        CoreFrame::Join {
213            label,
214            params,
215            rhs,
216            body,
217        } => {
218            let (ri, _) = partial_eval_at(expr, *rhs, env, new_nodes);
219            let (bi, bv) = partial_eval_at(expr, *body, env, new_nodes);
220            let ni = new_nodes.len();
221            new_nodes.push(CoreFrame::Join {
222                label: *label,
223                params: params.clone(),
224                rhs: ri,
225                body: bi,
226            });
227            (ni, bv)
228        }
229        CoreFrame::Jump { label, args } => {
230            let mut ai = Vec::new();
231            for &a in args {
232                let (i, _) = partial_eval_at(expr, a, env, new_nodes);
233                ai.push(i);
234            }
235            let ni = new_nodes.len();
236            new_nodes.push(CoreFrame::Jump {
237                label: *label,
238                args: ai,
239            });
240            (ni, PartialValue::Unknown)
241        }
242    }
243}
244
245/// Emit nodes for a known value into the new nodes vector.
246fn emit_known(kv: &KnownValue, new_nodes: &mut Vec<CoreFrame<usize>>) -> usize {
247    match kv {
248        KnownValue::Lit(lit) => {
249            let ni = new_nodes.len();
250            new_nodes.push(CoreFrame::Lit(lit.clone()));
251            ni
252        }
253        KnownValue::Con(tag, fields) => {
254            let fi: Vec<usize> = fields.iter().map(|k| emit_known(k, new_nodes)).collect();
255            let ni = new_nodes.len();
256            new_nodes.push(CoreFrame::Con {
257                tag: *tag,
258                fields: fi,
259            });
260            ni
261        }
262    }
263}
264
265/// Emit a residual case expression when the scrutinee is unknown.
266fn emit_residual_case(
267    expr: &CoreExpr,
268    scrut_idx: usize,
269    binder: &VarId,
270    alts: &[Alt<usize>],
271    env: &PartialEnv,
272    new_nodes: &mut Vec<CoreFrame<usize>>,
273) -> (usize, PartialValue) {
274    let mut new_env = env.clone();
275    new_env.insert(*binder, PartialValue::Unknown);
276    let mut new_alts = Vec::new();
277    for alt in alts {
278        let mut alt_env = new_env.clone();
279        for b in &alt.binders {
280            alt_env.insert(*b, PartialValue::Unknown);
281        }
282        let (bi, _) = partial_eval_at(expr, alt.body, &alt_env, new_nodes);
283        new_alts.push(Alt {
284            con: alt.con.clone(),
285            binders: alt.binders.clone(),
286            body: bi,
287        });
288    }
289    let ni = new_nodes.len();
290    new_nodes.push(CoreFrame::Case {
291        scrutinee: scrut_idx,
292        binder: *binder,
293        alts: new_alts,
294    });
295    (ni, PartialValue::Unknown)
296}
297
298/// Try to evaluate a primitive operation on partially known arguments.
299fn try_eval_primop(op: PrimOpKind, args: &[PartialValue]) -> Option<Literal> {
300    let lits: Vec<&Literal> = args
301        .iter()
302        .filter_map(|a| match a {
303            PartialValue::Known(KnownValue::Lit(l)) => Some(l),
304            _ => None,
305        })
306        .collect();
307    if lits.len() != args.len() {
308        return None;
309    }
310    match op {
311        PrimOpKind::IntAdd => {
312            if let [Literal::LitInt(a), Literal::LitInt(b)] = &lits[..] {
313                Some(Literal::LitInt(a.wrapping_add(*b)))
314            } else {
315                None
316            }
317        }
318        PrimOpKind::IntSub => {
319            if let [Literal::LitInt(a), Literal::LitInt(b)] = &lits[..] {
320                Some(Literal::LitInt(a.wrapping_sub(*b)))
321            } else {
322                None
323            }
324        }
325        PrimOpKind::IntMul => {
326            if let [Literal::LitInt(a), Literal::LitInt(b)] = &lits[..] {
327                Some(Literal::LitInt(a.wrapping_mul(*b)))
328            } else {
329                None
330            }
331        }
332        PrimOpKind::IntNegate => {
333            if let [Literal::LitInt(a)] = &lits[..] {
334                Some(Literal::LitInt(a.wrapping_neg()))
335            } else {
336                None
337            }
338        }
339        PrimOpKind::IntEq => int_cmp(&lits, |a, b| a == b),
340        PrimOpKind::IntNe => int_cmp(&lits, |a, b| a != b),
341        PrimOpKind::IntLt => int_cmp(&lits, |a, b| a < b),
342        PrimOpKind::IntLe => int_cmp(&lits, |a, b| a <= b),
343        PrimOpKind::IntGt => int_cmp(&lits, |a, b| a > b),
344        PrimOpKind::IntGe => int_cmp(&lits, |a, b| a >= b),
345        _ => None,
346    }
347}
348
349/// Helper for integer comparison primops.
350fn int_cmp(lits: &[&Literal], f: impl Fn(i64, i64) -> bool) -> Option<Literal> {
351    if let [Literal::LitInt(a), Literal::LitInt(b)] = lits {
352        Some(Literal::LitInt(if f(*a, *b) { 1 } else { 0 }))
353    } else {
354        None
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::env::Env;
    use tidepool_eval::eval;
    use tidepool_eval::heap::VecHeap;
    use tidepool_eval::value::Value;
    use tidepool_repr::{Alt, AltCon, CoreFrame, DataConId, Literal, PrimOpKind, VarId};

    #[test]
    fn test_partial_all_known() {
        // let x = 1 in let y = 2 in PrimOp(IntAdd, [x, y])
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
            CoreFrame::Lit(Literal::LitInt(2)), // 1
            CoreFrame::Var(VarId(1)),           // 2: x
            CoreFrame::Var(VarId(2)),           // 3: y
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![2, 3],
            }, // 4
            CoreFrame::LetNonRec {
                binder: VarId(2),
                rhs: 1,
                body: 4,
            }, // 5
            CoreFrame::LetNonRec {
                binder: VarId(1),
                rhs: 0,
                body: 5,
            }, // 6
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(3)));
    }

    #[test]
    fn test_partial_all_unknown() {
        // Var(VarId(1))
        let nodes = vec![CoreFrame::Var(VarId(1))];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        let changed = pass.run(&mut expr);

        assert!(!changed);
        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Var(VarId(1)));
    }

    #[test]
    fn test_partial_case_known_con() {
        // let x = Con(1, [Lit(42)]) in case x of w { DataAlt(1) [y] -> y }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            }, // 1
            CoreFrame::Var(VarId(1)), // 2: x
            CoreFrame::Var(VarId(2)), // 3: y
            CoreFrame::Case {
                scrutinee: 2,
                binder: VarId(3),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(2)],
                    body: 3,
                }],
            }, // 4
            CoreFrame::LetNonRec {
                binder: VarId(1),
                rhs: 1,
                body: 4,
            }, // 5
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42)));
    }

    #[test]
    fn test_partial_case_known_lit_default() {
        // case Lit(2) of w { LitAlt(1) -> Lit(10); Default -> Lit(20) }  ==>  Lit(20)
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(2)),  // 0
            CoreFrame::Lit(Literal::LitInt(10)), // 1
            CoreFrame::Lit(Literal::LitInt(20)), // 2
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(1),
                alts: vec![
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(1)),
                        binders: vec![],
                        body: 1,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 2,
                    },
                ],
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(20)));
    }

    #[test]
    fn test_partial_unknown_scrutinee() {
        // case Var(x) of w { Default -> Lit(42) } must stay a residual case.
        let nodes = vec![
            CoreFrame::Var(VarId(1)),            // 0
            CoreFrame::Lit(Literal::LitInt(42)), // 1
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(2),
                alts: vec![Alt {
                    con: AltCon::Default,
                    binders: vec![],
                    body: 1,
                }],
            }, // 2
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        match expr.nodes.last().unwrap() {
            CoreFrame::Case {
                scrutinee, alts, ..
            } => {
                assert_eq!(expr.nodes[*scrutinee], CoreFrame::Var(VarId(1)));
                assert_eq!(alts.len(), 1);
                assert_eq!(
                    expr.nodes[alts[0].body],
                    CoreFrame::Lit(Literal::LitInt(42))
                );
            }
            other => panic!("expected residual Case, got {:?}", other),
        }
    }

    #[test]
    fn test_partial_known_var_under_lambda() {
        // let x = 5 in \y -> x   ==>   \y -> 5
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(5)), // 0
            CoreFrame::Var(VarId(1)),           // 1: x
            CoreFrame::Lam {
                binder: VarId(2),
                body: 1,
            }, // 2
            CoreFrame::LetNonRec {
                binder: VarId(1),
                rhs: 0,
                body: 2,
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        let changed = pass.run(&mut expr);

        assert!(changed);
        match expr.nodes.last().unwrap() {
            CoreFrame::Lam { binder, body } => {
                assert_eq!(*binder, VarId(2));
                assert_eq!(expr.nodes[*body], CoreFrame::Lit(Literal::LitInt(5)));
            }
            other => panic!("expected Lam, got {:?}", other),
        }
    }

    #[test]
    fn test_partial_primop_fold() {
        // PrimOp(IntAdd, [Lit(1), Lit(2)])
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
            CoreFrame::Lit(Literal::LitInt(2)), // 1
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![0, 1],
            }, // 2
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(3)));
    }

    #[test]
    fn test_partial_primop_cmp_fold() {
        // PrimOp(IntLt, [Lit(1), Lit(2)])  ==>  Lit(1)  (true)
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
            CoreFrame::Lit(Literal::LitInt(2)), // 1
            CoreFrame::PrimOp {
                op: PrimOpKind::IntLt,
                args: vec![0, 1],
            }, // 2
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(1)));
    }

    #[test]
    fn test_partial_primop_unknown_arg() {
        // PrimOp(IntAdd, [Lit(1), Var(x)])
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
            CoreFrame::Var(VarId(1)),           // 1
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![0, 1],
            }, // 2
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert!(matches!(
            expr.nodes.last().unwrap(),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                ..
            }
        ));
    }

    #[test]
    fn test_partial_preserves_eval() {
        // let x = 10 in let y = 20 in PrimOp(IntAdd, [x, y])
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(10)), // 0
            CoreFrame::Lit(Literal::LitInt(20)), // 1
            CoreFrame::Var(VarId(1)),            // 2
            CoreFrame::Var(VarId(2)),            // 3
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![2, 3],
            }, // 4
            CoreFrame::LetNonRec {
                binder: VarId(2),
                rhs: 1,
                body: 4,
            }, // 5
            CoreFrame::LetNonRec {
                binder: VarId(1),
                rhs: 0,
                body: 5,
            }, // 6
        ];
        let mut expr = CoreExpr { nodes };

        let mut heap_before = VecHeap::new();
        let val_before = eval(&expr, &Env::new(), &mut heap_before).unwrap();

        let pass = PartialEval;
        pass.run(&mut expr);

        let mut heap_after = VecHeap::new();
        let val_after = eval(&expr, &Env::new(), &mut heap_after).unwrap();

        if let (Value::Lit(Literal::LitInt(n1)), Value::Lit(Literal::LitInt(n2))) =
            (val_before, val_after)
        {
            assert_eq!(n1, 30);
            assert_eq!(n2, 30);
        } else {
            panic!("Expected LitInt(30)");
        }
    }

    #[test]
    fn test_partial_nested_let() {
        // let x = 1 in let y = PrimOp(IntAdd, [x, Lit(2)]) in PrimOp(IntAdd, [y, Lit(3)])
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
            CoreFrame::Var(VarId(1)),           // 1: x
            CoreFrame::Lit(Literal::LitInt(2)), // 2
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![1, 2],
            }, // 3: x + 2
            CoreFrame::Var(VarId(2)),           // 4: y
            CoreFrame::Lit(Literal::LitInt(3)), // 5
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![4, 5],
            }, // 6: y + 3
            CoreFrame::LetNonRec {
                binder: VarId(2),
                rhs: 3,
                body: 6,
            }, // 7
            CoreFrame::LetNonRec {
                binder: VarId(1),
                rhs: 0,
                body: 7,
            }, // 8
        ];
        let mut expr = CoreExpr { nodes };
        let pass = PartialEval;
        pass.run(&mut expr);

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(6)));
    }
}