Skip to main content

normalize_surface_syntax/ir/
structure_eq.rs

//! Structural equality for IR types.
//!
//! `structure_eq` compares IR trees while ignoring "surface hints" - fields that
//! capture language-specific details but don't affect program semantics.
//!
//! # Hint Fields (normalized during comparison)
//!
//! - `span` on every node - source locations never affect equality
//! - `Stmt::Let { mutable, type_annotation }` - Lua doesn't distinguish const/let,
//!   and type annotations are not compared
//! - `Stmt::Destructure { mutable }` - likewise a surface hint
//! - `Expr::Member { computed }` - ignored when the property is a literal
//!   (e.g. `obj.foo` vs `obj["foo"]`)
//! - Function and method parameters - only parameter names are compared
//!
//! # Core Fields (must match exactly)
//!
//! - All names, values, operators
//! - Control flow structure
//! - Expression trees

17use super::{Expr, Function, Method, Pat, PatField, Program, Stmt, TemplatePart};
18
/// Equality that disregards surface hints.
///
/// Where `PartialEq` compares every field, implementors of this trait skip
/// fields that only record how a construct happened to be written in one
/// particular source language (spans, const/let flags, and so on). This lets
/// trees produced by different front ends still compare equal.
pub trait StructureEq {
    /// Returns `true` when `self` and `other` share the same structure,
    /// with surface hint fields left out of the comparison.
    fn structure_eq(&self, other: &Self) -> bool;
}
27
28impl StructureEq for Program {
29    fn structure_eq(&self, other: &Self) -> bool {
30        self.body.len() == other.body.len()
31            && self
32                .body
33                .iter()
34                .zip(&other.body)
35                .all(|(a, b)| a.structure_eq(b))
36    }
37}
38
39impl StructureEq for Stmt {
40    fn structure_eq(&self, other: &Self) -> bool {
41        match (self, other) {
42            (Stmt::Expr(a), Stmt::Expr(b)) => a.structure_eq(b),
43
44            // Ignore `mutable`, `type_annotation`, and `span` - they are surface hints
45            (
46                Stmt::Let {
47                    name: n1,
48                    init: i1,
49                    mutable: _,
50                    type_annotation: _,
51                    span: _,
52                },
53                Stmt::Let {
54                    name: n2,
55                    init: i2,
56                    mutable: _,
57                    type_annotation: _,
58                    span: _,
59                },
60            ) => n1 == n2 && option_structure_eq(i1.as_ref(), i2.as_ref()),
61
62            // Destructure: ignore span; mutable is a surface hint
63            (
64                Stmt::Destructure {
65                    pat: p1,
66                    value: v1,
67                    mutable: _,
68                    span: _,
69                },
70                Stmt::Destructure {
71                    pat: p2,
72                    value: v2,
73                    mutable: _,
74                    span: _,
75                },
76            ) => p1.structure_eq(p2) && v1.structure_eq(v2),
77
78            (Stmt::Block(a), Stmt::Block(b)) => vec_structure_eq(a, b),
79
80            (
81                Stmt::If {
82                    test: t1,
83                    consequent: c1,
84                    alternate: a1,
85                    span: _,
86                },
87                Stmt::If {
88                    test: t2,
89                    consequent: c2,
90                    alternate: a2,
91                    span: _,
92                },
93            ) => {
94                t1.structure_eq(t2)
95                    && c1.structure_eq(c2.as_ref())
96                    && match (a1, a2) {
97                        (None, None) => true,
98                        (Some(a), Some(b)) => a.structure_eq(b.as_ref()),
99                        _ => false,
100                    }
101            }
102
103            (
104                Stmt::While {
105                    test: t1,
106                    body: b1,
107                    span: _,
108                },
109                Stmt::While {
110                    test: t2,
111                    body: b2,
112                    span: _,
113                },
114            ) => t1.structure_eq(t2) && b1.structure_eq(b2.as_ref()),
115
116            (
117                Stmt::For {
118                    init: i1,
119                    test: t1,
120                    update: u1,
121                    body: b1,
122                    span: _,
123                },
124                Stmt::For {
125                    init: i2,
126                    test: t2,
127                    update: u2,
128                    body: b2,
129                    span: _,
130                },
131            ) => {
132                (match (i1, i2) {
133                    (None, None) => true,
134                    (Some(a), Some(b)) => a.structure_eq(b.as_ref()),
135                    _ => false,
136                }) && option_structure_eq(t1.as_ref(), t2.as_ref())
137                    && option_structure_eq(u1.as_ref(), u2.as_ref())
138                    && b1.structure_eq(b2.as_ref())
139            }
140
141            (
142                Stmt::ForIn {
143                    variable: v1,
144                    iterable: i1,
145                    body: b1,
146                    span: _,
147                },
148                Stmt::ForIn {
149                    variable: v2,
150                    iterable: i2,
151                    body: b2,
152                    span: _,
153                },
154            ) => v1 == v2 && i1.structure_eq(i2) && b1.structure_eq(b2.as_ref()),
155
156            (Stmt::Return(a), Stmt::Return(b)) => option_structure_eq(a.as_ref(), b.as_ref()),
157
158            (Stmt::Break, Stmt::Break) => true,
159            (Stmt::Continue, Stmt::Continue) => true,
160
161            (
162                Stmt::TryCatch {
163                    body: b1,
164                    catch_param: cp1,
165                    catch_body: cb1,
166                    finally_body: fb1,
167                    span: _,
168                },
169                Stmt::TryCatch {
170                    body: b2,
171                    catch_param: cp2,
172                    catch_body: cb2,
173                    finally_body: fb2,
174                    span: _,
175                },
176            ) => {
177                b1.structure_eq(b2.as_ref())
178                    && cp1 == cp2
179                    && match (cb1, cb2) {
180                        (None, None) => true,
181                        (Some(a), Some(b)) => a.structure_eq(b.as_ref()),
182                        _ => false,
183                    }
184                    && match (fb1, fb2) {
185                        (None, None) => true,
186                        (Some(a), Some(b)) => a.structure_eq(b.as_ref()),
187                        _ => false,
188                    }
189            }
190
191            (Stmt::Function(a), Stmt::Function(b)) => a.structure_eq(b),
192
193            // Comments: compare text and block flag; ignore span
194            (
195                Stmt::Comment {
196                    text: t1,
197                    block: b1,
198                    span: _,
199                },
200                Stmt::Comment {
201                    text: t2,
202                    block: b2,
203                    span: _,
204                },
205            ) => t1 == t2 && b1 == b2,
206
207            // Import: compare source and names (ignore span)
208            (
209                Stmt::Import {
210                    source: s1,
211                    names: n1,
212                    span: _,
213                },
214                Stmt::Import {
215                    source: s2,
216                    names: n2,
217                    span: _,
218                },
219            ) => s1 == s2 && n1 == n2,
220
221            // Export: compare names and source (ignore span)
222            (
223                Stmt::Export {
224                    names: n1,
225                    source: s1,
226                    span: _,
227                },
228                Stmt::Export {
229                    names: n2,
230                    source: s2,
231                    span: _,
232                },
233            ) => n1 == n2 && s1 == s2,
234
235            // Class: compare name, extends, and methods (ignore span)
236            (
237                Stmt::Class {
238                    name: n1,
239                    extends: e1,
240                    methods: m1,
241                    span: _,
242                },
243                Stmt::Class {
244                    name: n2,
245                    extends: e2,
246                    methods: m2,
247                    span: _,
248                },
249            ) => {
250                n1 == n2
251                    && e1 == e2
252                    && m1.len() == m2.len()
253                    && m1.iter().zip(m2).all(|(a, b)| a.structure_eq(b))
254            }
255
256            _ => false,
257        }
258    }
259}
260
261impl StructureEq for Pat {
262    fn structure_eq(&self, other: &Self) -> bool {
263        match (self, other) {
264            (Pat::Ident(a), Pat::Ident(b)) => a == b,
265            (Pat::Object(a), Pat::Object(b)) => {
266                a.len() == b.len() && a.iter().zip(b).all(|(x, y)| x.structure_eq(y))
267            }
268            (Pat::Array(a_elems, a_rest), Pat::Array(b_elems, b_rest)) => {
269                a_rest == b_rest
270                    && a_elems.len() == b_elems.len()
271                    && a_elems.iter().zip(b_elems).all(|(x, y)| match (x, y) {
272                        (None, None) => true,
273                        (Some(p), Some(q)) => p.structure_eq(q),
274                        _ => false,
275                    })
276            }
277            (Pat::Rest(a), Pat::Rest(b)) => a.structure_eq(b),
278            _ => false,
279        }
280    }
281}
282
283impl StructureEq for PatField {
284    fn structure_eq(&self, other: &Self) -> bool {
285        self.key == other.key
286            && self.pat.structure_eq(&other.pat)
287            && option_structure_eq(self.default.as_ref(), other.default.as_ref())
288    }
289}
290
291impl StructureEq for Expr {
292    fn structure_eq(&self, other: &Self) -> bool {
293        match (self, other) {
294            (Expr::Literal(a), Expr::Literal(b)) => a == b,
295            (Expr::Ident(a), Expr::Ident(b)) => a == b,
296
297            (
298                Expr::Binary {
299                    left: l1,
300                    op: o1,
301                    right: r1,
302                    span: _,
303                },
304                Expr::Binary {
305                    left: l2,
306                    op: o2,
307                    right: r2,
308                    span: _,
309                },
310            ) => o1 == o2 && l1.structure_eq(l2) && r1.structure_eq(r2),
311
312            (
313                Expr::Unary {
314                    op: o1,
315                    expr: e1,
316                    span: _,
317                },
318                Expr::Unary {
319                    op: o2,
320                    expr: e2,
321                    span: _,
322                },
323            ) => o1 == o2 && e1.structure_eq(e2),
324
325            (
326                Expr::Call {
327                    callee: c1,
328                    args: a1,
329                    span: _,
330                },
331                Expr::Call {
332                    callee: c2,
333                    args: a2,
334                    span: _,
335                },
336            ) => c1.structure_eq(c2) && vec_structure_eq(a1, a2),
337
338            // Normalize `computed` when property is a string literal; ignore `span`
339            (
340                Expr::Member {
341                    object: o1,
342                    property: p1,
343                    computed: _,
344                    span: _,
345                },
346                Expr::Member {
347                    object: o2,
348                    property: p2,
349                    computed: _,
350                    span: _,
351                },
352            ) => o1.structure_eq(o2) && p1.structure_eq(p2),
353
354            (Expr::Array(a), Expr::Array(b)) => vec_structure_eq(a, b),
355
356            (Expr::Object(a), Expr::Object(b)) => {
357                a.len() == b.len()
358                    && a.iter()
359                        .zip(b)
360                        .all(|((k1, v1), (k2, v2))| k1 == k2 && v1.structure_eq(v2))
361            }
362
363            (Expr::Function(a), Expr::Function(b)) => a.structure_eq(b),
364
365            (
366                Expr::Conditional {
367                    test: t1,
368                    consequent: c1,
369                    alternate: a1,
370                    span: _,
371                },
372                Expr::Conditional {
373                    test: t2,
374                    consequent: c2,
375                    alternate: a2,
376                    span: _,
377                },
378            ) => t1.structure_eq(t2) && c1.structure_eq(c2) && a1.structure_eq(a2),
379
380            (
381                Expr::Assign {
382                    target: t1,
383                    value: v1,
384                    span: _,
385                },
386                Expr::Assign {
387                    target: t2,
388                    value: v2,
389                    span: _,
390                },
391            ) => t1.structure_eq(t2) && v1.structure_eq(v2),
392
393            (Expr::TemplateLiteral(a), Expr::TemplateLiteral(b)) => {
394                a.len() == b.len() && a.iter().zip(b).all(|(x, y)| x.structure_eq(y))
395            }
396
397            _ => false,
398        }
399    }
400}
401
402impl StructureEq for TemplatePart {
403    fn structure_eq(&self, other: &Self) -> bool {
404        match (self, other) {
405            (TemplatePart::Text(a), TemplatePart::Text(b)) => a == b,
406            (TemplatePart::Expr(a), TemplatePart::Expr(b)) => a.structure_eq(b),
407            _ => false,
408        }
409    }
410}
411
412impl StructureEq for Function {
413    fn structure_eq(&self, other: &Self) -> bool {
414        self.name == other.name
415            && self.params.len() == other.params.len()
416            && self
417                .params
418                .iter()
419                .zip(&other.params)
420                .all(|(a, b)| a.name == b.name)
421            && vec_structure_eq(&self.body, &other.body)
422    }
423}
424
425impl StructureEq for Method {
426    fn structure_eq(&self, other: &Self) -> bool {
427        self.name == other.name
428            && self.is_static == other.is_static
429            && self.params.len() == other.params.len()
430            && self
431                .params
432                .iter()
433                .zip(&other.params)
434                .all(|(a, b)| a.name == b.name)
435            && vec_structure_eq(&self.body, &other.body)
436    }
437}
438
439// Helper functions
440
441fn vec_structure_eq<T: StructureEq>(a: &[T], b: &[T]) -> bool {
442    a.len() == b.len() && a.iter().zip(b).all(|(x, y)| x.structure_eq(y))
443}
444
445fn option_structure_eq<T: StructureEq>(a: Option<&T>, b: Option<&T>) -> bool {
446    match (a, b) {
447        (None, None) => true,
448        (Some(x), Some(y)) => x.structure_eq(y),
449        _ => false,
450    }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// `let`/`const` declaration of `name`, differing only in `init` and `mutable`.
    fn decl(name: &'static str, init: Expr, mutable: bool) -> Stmt {
        Stmt::Let {
            name: name.into(),
            init: Some(init),
            mutable,
            type_annotation: None,
            span: None,
        }
    }

    /// `obj.foo` when `computed` is false, `obj["foo"]` when it is true.
    fn foo_member(computed: bool) -> Expr {
        Expr::Member {
            object: Box::new(Expr::ident("obj")),
            property: Box::new(Expr::string("foo")),
            computed,
            span: None,
        }
    }

    #[test]
    fn test_mutable_is_ignored() {
        let immutable = decl("x", Expr::number(42), false);
        let mutable = decl("x", Expr::number(42), true);

        assert!(immutable.structure_eq(&mutable));
        // Plain `PartialEq` still sees the hint.
        assert_ne!(immutable, mutable);
    }

    #[test]
    fn test_computed_is_ignored() {
        let dot = foo_member(false);
        let bracket = foo_member(true);

        assert!(dot.structure_eq(&bracket));
        // Plain `PartialEq` still sees the hint.
        assert_ne!(dot, bracket);
    }

    #[test]
    fn test_different_names_not_equal() {
        let x = decl("x", Expr::number(1), false);
        let y = decl("y", Expr::number(1), false);

        assert!(!x.structure_eq(&y));
    }

    #[test]
    fn test_program_equality() {
        let with_const = Program {
            body: vec![Stmt::const_decl("x", Expr::number(1))],
        };
        let with_let = Program {
            body: vec![Stmt::let_decl("x", Some(Expr::number(1)))],
        };

        assert!(with_const.structure_eq(&with_let));
    }
}