Skip to main content

normalize_surface_syntax/ir/
structure_eq.rs

1//! Structural equality for IR types.
2//!
3//! `structure_eq` compares IR trees ignoring "surface hints" - fields that
4//! capture language-specific details but don't affect program semantics.
5//!
6//! # Hint Fields (normalized during comparison)
7//!
8//! - `Stmt::Let { mutable }` - Lua doesn't distinguish const/let
9//! - `Expr::Member { computed }` - normalized to false when property is string literal
10//!
11//! # Core Fields (must match exactly)
12//!
13//! - All names, values, operators
14//! - Control flow structure
15//! - Expression trees
16
17use crate::{Expr, Function, Program, Stmt};
18
/// Equality that looks only at program structure.
///
/// Where the derived `PartialEq` compares every field, implementors of this
/// trait skip "surface hint" fields: details that record how a construct was
/// spelled in a particular source language but do not change what the
/// program means.
pub trait StructureEq {
    /// Returns `true` when `self` and `other` have the same structure once
    /// surface hints are disregarded.
    fn structure_eq(&self, other: &Self) -> bool;
}
27
28impl StructureEq for Program {
29    fn structure_eq(&self, other: &Self) -> bool {
30        self.body.len() == other.body.len()
31            && self
32                .body
33                .iter()
34                .zip(&other.body)
35                .all(|(a, b)| a.structure_eq(b))
36    }
37}
38
39impl StructureEq for Stmt {
40    fn structure_eq(&self, other: &Self) -> bool {
41        match (self, other) {
42            (Stmt::Expr(a), Stmt::Expr(b)) => a.structure_eq(b),
43
44            // Ignore `mutable` - it's a surface hint
45            (
46                Stmt::Let {
47                    name: n1,
48                    init: i1,
49                    mutable: _,
50                },
51                Stmt::Let {
52                    name: n2,
53                    init: i2,
54                    mutable: _,
55                },
56            ) => n1 == n2 && option_structure_eq(i1.as_ref(), i2.as_ref()),
57
58            (Stmt::Block(a), Stmt::Block(b)) => vec_structure_eq(a, b),
59
60            (
61                Stmt::If {
62                    test: t1,
63                    consequent: c1,
64                    alternate: a1,
65                },
66                Stmt::If {
67                    test: t2,
68                    consequent: c2,
69                    alternate: a2,
70                },
71            ) => {
72                t1.structure_eq(t2)
73                    && c1.structure_eq(c2.as_ref())
74                    && match (a1, a2) {
75                        (None, None) => true,
76                        (Some(a), Some(b)) => a.structure_eq(b.as_ref()),
77                        _ => false,
78                    }
79            }
80
81            (Stmt::While { test: t1, body: b1 }, Stmt::While { test: t2, body: b2 }) => {
82                t1.structure_eq(t2) && b1.structure_eq(b2.as_ref())
83            }
84
85            (
86                Stmt::For {
87                    init: i1,
88                    test: t1,
89                    update: u1,
90                    body: b1,
91                },
92                Stmt::For {
93                    init: i2,
94                    test: t2,
95                    update: u2,
96                    body: b2,
97                },
98            ) => {
99                (match (i1, i2) {
100                    (None, None) => true,
101                    (Some(a), Some(b)) => a.structure_eq(b.as_ref()),
102                    _ => false,
103                }) && option_structure_eq(t1.as_ref(), t2.as_ref())
104                    && option_structure_eq(u1.as_ref(), u2.as_ref())
105                    && b1.structure_eq(b2.as_ref())
106            }
107
108            (
109                Stmt::ForIn {
110                    variable: v1,
111                    iterable: i1,
112                    body: b1,
113                },
114                Stmt::ForIn {
115                    variable: v2,
116                    iterable: i2,
117                    body: b2,
118                },
119            ) => v1 == v2 && i1.structure_eq(i2) && b1.structure_eq(b2.as_ref()),
120
121            (Stmt::Return(a), Stmt::Return(b)) => option_structure_eq(a.as_ref(), b.as_ref()),
122
123            (Stmt::Break, Stmt::Break) => true,
124            (Stmt::Continue, Stmt::Continue) => true,
125
126            (
127                Stmt::TryCatch {
128                    body: b1,
129                    catch_param: cp1,
130                    catch_body: cb1,
131                    finally_body: fb1,
132                },
133                Stmt::TryCatch {
134                    body: b2,
135                    catch_param: cp2,
136                    catch_body: cb2,
137                    finally_body: fb2,
138                },
139            ) => {
140                b1.structure_eq(b2.as_ref())
141                    && cp1 == cp2
142                    && match (cb1, cb2) {
143                        (None, None) => true,
144                        (Some(a), Some(b)) => a.structure_eq(b.as_ref()),
145                        _ => false,
146                    }
147                    && match (fb1, fb2) {
148                        (None, None) => true,
149                        (Some(a), Some(b)) => a.structure_eq(b.as_ref()),
150                        _ => false,
151                    }
152            }
153
154            (Stmt::Function(a), Stmt::Function(b)) => a.structure_eq(b),
155
156            _ => false,
157        }
158    }
159}
160
161impl StructureEq for Expr {
162    fn structure_eq(&self, other: &Self) -> bool {
163        match (self, other) {
164            (Expr::Literal(a), Expr::Literal(b)) => a == b,
165            (Expr::Ident(a), Expr::Ident(b)) => a == b,
166
167            (
168                Expr::Binary {
169                    left: l1,
170                    op: o1,
171                    right: r1,
172                },
173                Expr::Binary {
174                    left: l2,
175                    op: o2,
176                    right: r2,
177                },
178            ) => o1 == o2 && l1.structure_eq(l2) && r1.structure_eq(r2),
179
180            (Expr::Unary { op: o1, expr: e1 }, Expr::Unary { op: o2, expr: e2 }) => {
181                o1 == o2 && e1.structure_eq(e2)
182            }
183
184            (
185                Expr::Call {
186                    callee: c1,
187                    args: a1,
188                },
189                Expr::Call {
190                    callee: c2,
191                    args: a2,
192                },
193            ) => c1.structure_eq(c2) && vec_structure_eq(a1, a2),
194
195            // Normalize `computed` when property is a string literal
196            (
197                Expr::Member {
198                    object: o1,
199                    property: p1,
200                    computed: _,
201                },
202                Expr::Member {
203                    object: o2,
204                    property: p2,
205                    computed: _,
206                },
207            ) => o1.structure_eq(o2) && p1.structure_eq(p2),
208
209            (Expr::Array(a), Expr::Array(b)) => vec_structure_eq(a, b),
210
211            (Expr::Object(a), Expr::Object(b)) => {
212                a.len() == b.len()
213                    && a.iter()
214                        .zip(b)
215                        .all(|((k1, v1), (k2, v2))| k1 == k2 && v1.structure_eq(v2))
216            }
217
218            (Expr::Function(a), Expr::Function(b)) => a.structure_eq(b),
219
220            (
221                Expr::Conditional {
222                    test: t1,
223                    consequent: c1,
224                    alternate: a1,
225                },
226                Expr::Conditional {
227                    test: t2,
228                    consequent: c2,
229                    alternate: a2,
230                },
231            ) => t1.structure_eq(t2) && c1.structure_eq(c2) && a1.structure_eq(a2),
232
233            (
234                Expr::Assign {
235                    target: t1,
236                    value: v1,
237                },
238                Expr::Assign {
239                    target: t2,
240                    value: v2,
241                },
242            ) => t1.structure_eq(t2) && v1.structure_eq(v2),
243
244            _ => false,
245        }
246    }
247}
248
249impl StructureEq for Function {
250    fn structure_eq(&self, other: &Self) -> bool {
251        self.name == other.name
252            && self.params == other.params
253            && vec_structure_eq(&self.body, &other.body)
254    }
255}
256
257// Helper functions
258
259fn vec_structure_eq<T: StructureEq>(a: &[T], b: &[T]) -> bool {
260    a.len() == b.len() && a.iter().zip(b).all(|(x, y)| x.structure_eq(y))
261}
262
263fn option_structure_eq<T: StructureEq>(a: Option<&T>, b: Option<&T>) -> bool {
264    match (a, b) {
265        (None, None) => true,
266        (Some(x), Some(y)) => x.structure_eq(y),
267        _ => false,
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// `const` vs `let` is a surface hint: `structure_eq` ignores it, while the
    /// derived `PartialEq` still tells the two apart.
    #[test]
    fn test_mutable_is_ignored() {
        let declare = |mutable| Stmt::Let {
            name: "x".into(),
            init: Some(Expr::number(42)),
            mutable,
        };
        let (immutable, reassignable) = (declare(false), declare(true));

        assert!(immutable.structure_eq(&reassignable));
        assert_ne!(immutable, reassignable); // Regular equality still differs
    }

    /// `obj.foo` and `obj["foo"]` (string-literal key) compare structurally
    /// equal, even though `PartialEq` sees different `computed` flags.
    #[test]
    fn test_computed_is_ignored() {
        let access = |computed| Expr::Member {
            object: Box::new(Expr::ident("obj")),
            property: Box::new(Expr::string("foo")),
            computed,
        };
        let (dot_access, bracket_access) = (access(false), access(true));

        assert!(dot_access.structure_eq(&bracket_access));
        assert_ne!(dot_access, bracket_access); // Regular equality still differs
    }

    /// Binding names are core fields and must match exactly.
    #[test]
    fn test_different_names_not_equal() {
        let declare = |name: &'static str| Stmt::Let {
            name: name.into(),
            init: Some(Expr::number(1)),
            mutable: false,
        };

        assert!(!declare("x").structure_eq(&declare("y")));
    }

    /// Whole programs built with the `const_decl` / `let_decl` helpers
    /// compare equal.
    #[test]
    fn test_program_equality() {
        let with_const = Program {
            body: vec![Stmt::const_decl("x", Expr::number(1))],
        };
        let with_let = Program {
            body: vec![Stmt::let_decl("x", Some(Expr::number(1)))],
        };

        assert!(with_const.structure_eq(&with_let));
    }
}