Skip to main content

padlock_source/frontends/
rust.rs

1// padlock-source/src/frontends/rust.rs
2//
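// Extracts struct and enum layouts from Rust source using syn + the Visit API.
// Sizes are approximated from type names using the target arch config.
// Fields are placed in declaration order (repr(C)-style) and repr(packed) is
// honoured; generic structs and enums are skipped because their layout depends
// on unknown type arguments.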
6
7use padlock_core::arch::ArchConfig;
8use padlock_core::ir::{AccessPattern, Field, StructLayout, TypeInfo};
9use quote::ToTokens;
10use syn::{Fields, ItemEnum, ItemStruct, Type, visit::Visit};
11
12// ── attribute guard extraction ────────────────────────────────────────────────
13
14/// Extract a lock guard name from field attributes.
15///
16/// Recognised forms:
17/// - `#[lock_protected_by = "mu"]`
18/// - `#[protected_by = "mu"]`
19/// - `#[guarded_by("mu")]` or `#[guarded_by(mu)]`
20/// - `#[pt_guarded_by("mu")]` or `#[pt_guarded_by(mu)]` (pointer variant)
21pub fn extract_guard_from_attrs(attrs: &[syn::Attribute]) -> Option<String> {
22    for attr in attrs {
23        let path = attr.path();
24        // Name-value form: #[lock_protected_by = "mu"] / #[protected_by = "mu"]
25        if (path.is_ident("lock_protected_by") || path.is_ident("protected_by"))
26            && let syn::Meta::NameValue(nv) = &attr.meta
27            && let syn::Expr::Lit(syn::ExprLit {
28                lit: syn::Lit::Str(s),
29                ..
30            }) = &nv.value
31        {
32            return Some(s.value());
33        }
34        // List form: #[guarded_by("mu")] / #[guarded_by(mu)] / #[pt_guarded_by(...)]
35        if path.is_ident("guarded_by") || path.is_ident("pt_guarded_by") {
36            // Try string literal first
37            if let Ok(s) = attr.parse_args::<syn::LitStr>() {
38                return Some(s.value());
39            }
40            // Fall back to bare identifier
41            if let Ok(id) = attr.parse_args::<syn::Ident>() {
42                return Some(id.to_string());
43            }
44        }
45    }
46    None
47}
48
49// ── type resolution ───────────────────────────────────────────────────────────
50
51fn rust_type_size_align(ty: &Type, arch: &'static ArchConfig) -> (usize, usize, TypeInfo) {
52    match ty {
53        Type::Path(tp) => {
54            let name = tp
55                .path
56                .segments
57                .last()
58                .map(|s| s.ident.to_string())
59                .unwrap_or_default();
60            let (size, align) = primitive_size_align(&name, arch);
61            (size, align, TypeInfo::Primitive { name, size, align })
62        }
63        Type::Ptr(_) | Type::Reference(_) => {
64            let s = arch.pointer_size;
65            (s, s, TypeInfo::Pointer { size: s, align: s })
66        }
67        Type::Array(arr) => {
68            let (elem_size, elem_align, elem_ty) = rust_type_size_align(&arr.elem, arch);
69            let count = array_len_from_expr(&arr.len);
70            let size = elem_size * count;
71            (
72                size,
73                elem_align,
74                TypeInfo::Array {
75                    element: Box::new(elem_ty),
76                    count,
77                    size,
78                    align: elem_align,
79                },
80            )
81        }
82        _ => {
83            let s = arch.pointer_size;
84            (
85                s,
86                s,
87                TypeInfo::Opaque {
88                    name: "(unknown)".into(),
89                    size: s,
90                    align: s,
91                },
92            )
93        }
94    }
95}
96
97fn primitive_size_align(name: &str, arch: &'static ArchConfig) -> (usize, usize) {
98    let ps = arch.pointer_size;
99    match name {
100        // ── language primitives ───────────────────────────────────────────────
101        "bool" | "u8" | "i8" => (1, 1),
102        "u16" | "i16" => (2, 2),
103        "u32" | "i32" | "f32" => (4, 4),
104        "u64" | "i64" | "f64" => (8, 8),
105        "u128" | "i128" => (16, 16),
106        "usize" | "isize" => (ps, ps),
107        "char" => (4, 4), // Rust char is a Unicode scalar (4 bytes)
108
109        // ── std atomics ───────────────────────────────────────────────────────
110        "AtomicBool" | "AtomicU8" | "AtomicI8" => (1, 1),
111        "AtomicU16" | "AtomicI16" => (2, 2),
112        "AtomicU32" | "AtomicI32" => (4, 4),
113        "AtomicU64" | "AtomicI64" => (8, 8),
114        "AtomicUsize" | "AtomicIsize" | "AtomicPtr" => (ps, ps),
115
116        // ── heap-allocated collections: ptr + len + cap (3 words) ────────────
117        // Size is independent of the element type T (generic arg already stripped).
118        "Vec" | "String" | "OsString" | "CString" | "PathBuf" => (3 * ps, ps),
119        "VecDeque" | "LinkedList" | "BinaryHeap" => (3 * ps, ps),
120        "HashMap" | "HashSet" | "BTreeMap" | "BTreeSet" => (3 * ps, ps),
121
122        // ── single-pointer smart pointers ─────────────────────────────────────
123        "Box" | "Rc" | "Arc" | "Weak" | "NonNull" | "Cell" => (ps, ps),
124
125        // ── interior-mutability / sync wrappers ───────────────────────────────
126        // Size depends on T but pointer-size is a reasonable approximation for
127        // display purposes; use binary analysis for precise results.
128        "RefCell" | "Mutex" | "RwLock" => (ps, ps),
129
130        // ── channels ─────────────────────────────────────────────────────────
131        "Sender" | "Receiver" | "SyncSender" => (ps, ps),
132
133        // ── zero-sized types ──────────────────────────────────────────────────
134        "PhantomData" | "PhantomPinned" => (0, 1),
135
136        // ── common fixed-size stdlib types ────────────────────────────────────
137        // Duration: u64 secs (8B) + u32 nanos (4B) → 12B + 4B trailing = 16B
138        "Duration" => (16, 8),
139        "Instant" | "SystemTime" => (16, 8),
140
141        // ── Pin<T> wraps T, pointer-size approximation ────────────────────────
142        "Pin" => (ps, ps),
143
144        // ── x86 SSE / AVX / AVX-512 SIMD types ───────────────────────────────
145        "__m64" => (8, 8),
146        "__m128" | "__m128d" | "__m128i" => (16, 16),
147        "__m256" | "__m256d" | "__m256i" => (32, 32),
148        "__m512" | "__m512d" | "__m512i" => (64, 64),
149
150        // ── Rust portable SIMD / packed_simd types ────────────────────────────
151        "f32x4" | "i32x4" | "u32x4" => (16, 16),
152        "f64x2" | "i64x2" | "u64x2" => (16, 16),
153        "f32x8" | "i32x8" | "u32x8" => (32, 32),
154        "f64x4" | "i64x4" | "u64x4" => (32, 32),
155        "f32x16" | "i32x16" | "u32x16" => (64, 64),
156
157        // ── unknown / third-party / generic type params (T, E, …) ────────────
158        _ => (ps, ps),
159    }
160}
161
162fn array_len_from_expr(expr: &syn::Expr) -> usize {
163    if let syn::Expr::Lit(syn::ExprLit {
164        lit: syn::Lit::Int(n),
165        ..
166    }) = expr
167    {
168        n.base10_parse::<usize>().unwrap_or(0)
169    } else {
170        0
171    }
172}
173
174// ── struct repr detection ─────────────────────────────────────────────────────
175
176fn is_packed(attrs: &[syn::Attribute]) -> bool {
177    attrs
178        .iter()
179        .any(|a| a.path().is_ident("repr") && a.to_token_stream().to_string().contains("packed"))
180}
181
182fn simulate_rust_layout(
183    name: String,
184    fields: &[(String, Type)],
185    packed: bool,
186    arch: &'static ArchConfig,
187) -> StructLayout {
188    let mut offset = 0usize;
189    let mut struct_align = 1usize;
190    let mut out_fields: Vec<Field> = Vec::new();
191
192    for (fname, ty) in fields {
193        let (size, align, type_info) = rust_type_size_align(ty, arch);
194        let effective_align = if packed { 1 } else { align };
195
196        if effective_align > 0 {
197            offset = offset.next_multiple_of(effective_align);
198        }
199        struct_align = struct_align.max(effective_align);
200
201        out_fields.push(Field {
202            name: fname.clone(),
203            ty: type_info,
204            offset,
205            size,
206            align: effective_align,
207            source_file: None,
208            source_line: None,
209            access: AccessPattern::Unknown,
210        });
211        offset += size;
212    }
213
214    if !packed && struct_align > 0 {
215        offset = offset.next_multiple_of(struct_align);
216    }
217
218    StructLayout {
219        name,
220        total_size: offset,
221        align: struct_align,
222        fields: out_fields,
223        source_file: None,
224        source_line: None,
225        arch,
226        is_packed: packed,
227        is_union: false,
228    }
229}
230
231// ── visitor ───────────────────────────────────────────────────────────────────
232
233struct StructVisitor {
234    arch: &'static ArchConfig,
235    layouts: Vec<StructLayout>,
236}
237
238impl<'ast> Visit<'ast> for StructVisitor {
239    fn visit_item_struct(&mut self, node: &'ast ItemStruct) {
240        syn::visit::visit_item_struct(self, node); // recurse into nested items
241
242        // Generic structs (e.g. `struct Foo<T>`) cannot be accurately laid out
243        // without knowing the concrete type arguments. Skip them rather than
244        // producing wrong field sizes for the type parameters.
245        if !node.generics.params.is_empty() {
246            return;
247        }
248
249        let name = node.ident.to_string();
250        let packed = is_packed(&node.attrs);
251
252        // Collect (field_name, type, optional_guard)
253        let fields: Vec<(String, Type, Option<String>)> = match &node.fields {
254            Fields::Named(nf) => nf
255                .named
256                .iter()
257                .map(|f| {
258                    let fname = f.ident.as_ref().map(|i| i.to_string()).unwrap_or_default();
259                    let guard = extract_guard_from_attrs(&f.attrs);
260                    (fname, f.ty.clone(), guard)
261                })
262                .collect(),
263            Fields::Unnamed(uf) => uf
264                .unnamed
265                .iter()
266                .enumerate()
267                .map(|(i, f)| {
268                    let guard = extract_guard_from_attrs(&f.attrs);
269                    (format!("_{i}"), f.ty.clone(), guard)
270                })
271                .collect(),
272            Fields::Unit => vec![],
273        };
274
275        let name_ty: Vec<(String, Type)> = fields
276            .iter()
277            .map(|(n, t, _)| (n.clone(), t.clone()))
278            .collect();
279        let mut layout = simulate_rust_layout(name, &name_ty, packed, self.arch);
280        layout.source_line = Some(node.ident.span().start().line as u32);
281
282        // Apply explicit guard annotations; these take precedence over the
283        // heuristic type-name pass in concurrency.rs (which skips non-Unknown fields).
284        for (i, (_, _, guard)) in fields.iter().enumerate() {
285            if let Some(g) = guard {
286                layout.fields[i].access = AccessPattern::Concurrent {
287                    guard: Some(g.clone()),
288                    is_atomic: false,
289                };
290            }
291        }
292
293        self.layouts.push(layout);
294    }
295
296    fn visit_item_enum(&mut self, node: &'ast ItemEnum) {
297        syn::visit::visit_item_enum(self, node);
298
299        // Skip generic enums (layout depends on unknown type arguments)
300        if !node.generics.params.is_empty() {
301            return;
302        }
303
304        let name = node.ident.to_string();
305        let n_variants = node.variants.len();
306        if n_variants == 0 {
307            return;
308        }
309
310        // Discriminant size: smallest integer that fits the variant count.
311        // Rust defaults to isize but uses the minimal repr in practice.
312        let disc_size: usize = if n_variants <= 256 {
313            1
314        } else if n_variants <= 65536 {
315            2
316        } else {
317            4
318        };
319
320        // Check if all variants are unit (C-like enum, no payload)
321        let all_unit = node
322            .variants
323            .iter()
324            .all(|v| matches!(v.fields, Fields::Unit));
325
326        if all_unit {
327            // Pure discriminant — no payload storage
328            let layout = StructLayout {
329                name,
330                total_size: disc_size,
331                align: disc_size,
332                fields: vec![Field {
333                    name: "__discriminant".to_string(),
334                    ty: TypeInfo::Primitive {
335                        name: format!("u{}", disc_size * 8),
336                        size: disc_size,
337                        align: disc_size,
338                    },
339                    offset: 0,
340                    size: disc_size,
341                    align: disc_size,
342                    source_file: None,
343                    source_line: None,
344                    access: AccessPattern::Unknown,
345                }],
346                source_file: None,
347                source_line: Some(node.ident.span().start().line as u32),
348                arch: self.arch,
349                is_packed: false,
350                is_union: false,
351            };
352            self.layouts.push(layout);
353            return;
354        }
355
356        // Data enum: find the maximum variant payload size and alignment.
357        let mut max_payload_size = 0usize;
358        let mut max_payload_align = 1usize;
359
360        for variant in &node.variants {
361            let var_fields: Vec<(String, Type)> = match &variant.fields {
362                Fields::Named(nf) => nf
363                    .named
364                    .iter()
365                    .map(|f| {
366                        let n = f.ident.as_ref().map(|i| i.to_string()).unwrap_or_default();
367                        (n, f.ty.clone())
368                    })
369                    .collect(),
370                Fields::Unnamed(uf) => uf
371                    .unnamed
372                    .iter()
373                    .enumerate()
374                    .map(|(i, f)| (format!("_{i}"), f.ty.clone()))
375                    .collect(),
376                Fields::Unit => vec![],
377            };
378
379            if !var_fields.is_empty() {
380                let var_layout = simulate_rust_layout(String::new(), &var_fields, false, self.arch);
381                if var_layout.total_size > max_payload_size {
382                    max_payload_size = var_layout.total_size;
383                }
384                max_payload_align = max_payload_align.max(var_layout.align);
385            }
386        }
387
388        // Conservative model: payload first at offset 0, discriminant immediately after.
389        // Rust's actual layout is compiler-controlled (niche optimisation etc.);
390        // this model gives a safe upper-bound for padding analysis.
391        let payload_align = max_payload_align.max(1);
392        let disc_offset = max_payload_size;
393        let total_before_pad = disc_offset + disc_size;
394        let total_align = payload_align.max(disc_size);
395        let total_size = total_before_pad.next_multiple_of(total_align);
396
397        let mut fields: Vec<Field> = Vec::new();
398        if max_payload_size > 0 {
399            fields.push(Field {
400                name: "__payload".to_string(),
401                ty: TypeInfo::Opaque {
402                    name: format!("largest_variant_payload ({}B)", max_payload_size),
403                    size: max_payload_size,
404                    align: payload_align,
405                },
406                offset: 0,
407                size: max_payload_size,
408                align: payload_align,
409                source_file: None,
410                source_line: None,
411                access: AccessPattern::Unknown,
412            });
413        }
414        fields.push(Field {
415            name: "__discriminant".to_string(),
416            ty: TypeInfo::Primitive {
417                name: format!("u{}", disc_size * 8),
418                size: disc_size,
419                align: disc_size,
420            },
421            offset: disc_offset,
422            size: disc_size,
423            align: disc_size,
424            source_file: None,
425            source_line: None,
426            access: AccessPattern::Unknown,
427        });
428
429        self.layouts.push(StructLayout {
430            name,
431            total_size,
432            align: total_align,
433            fields,
434            source_file: None,
435            source_line: Some(node.ident.span().start().line as u32),
436            arch: self.arch,
437            is_packed: false,
438            is_union: false,
439        });
440    }
441}
442
443// ── public API ────────────────────────────────────────────────────────────────
444
445pub fn parse_rust(source: &str, arch: &'static ArchConfig) -> anyhow::Result<Vec<StructLayout>> {
446    let file: syn::File = syn::parse_str(source)?;
447    let mut visitor = StructVisitor {
448        arch,
449        layouts: Vec::new(),
450    };
451    visitor.visit_file(&file);
452    Ok(visitor.layouts)
453}
454
455// ── tests ─────────────────────────────────────────────────────────────────────
456
// Unit tests for the Rust frontend. Every test parses an inline Rust snippet
// with `parse_rust` for the x86-64 System V arch config (`X86_64_SYSV`) and
// checks the simulated layout. Areas covered:
// - field sizes, offsets and padding;
// - `#[repr(packed)]`;
// - guard attributes and false-sharing detection;
// - stdlib type sizes;
// - skipping of generic items;
// - the enum payload/discriminant model.
#[cfg(test)]
mod tests {
    use super::*;
    use padlock_core::arch::X86_64_SYSV;

    // ── basic struct layout ───────────────────────────────────────────────────

    #[test]
    fn parse_simple_struct() {
        let src = "struct Foo { a: u8, b: u64, c: u32 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        let l = &layouts[0];
        assert_eq!(l.name, "Foo");
        assert_eq!(l.fields.len(), 3);
        assert_eq!(l.fields[0].size, 1); // u8
        assert_eq!(l.fields[1].size, 8); // u64
        assert_eq!(l.fields[2].size, 4); // u32
    }

    #[test]
    fn layout_includes_padding() {
        // u8 then u64: 7 bytes padding inserted
        let src = "struct T { a: u8, b: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.fields[0].offset, 0);
        assert_eq!(l.fields[1].offset, 8); // u64 aligned to 8
        assert_eq!(l.total_size, 16);
        let gaps = padlock_core::ir::find_padding(l);
        assert_eq!(gaps[0].bytes, 7);
    }

    #[test]
    fn multiple_structs_parsed() {
        let src = "struct A { x: u32 } struct B { y: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 2);
    }

    #[test]
    fn packed_struct_no_padding() {
        let src = "#[repr(packed)] struct P { a: u8, b: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert!(l.is_packed);
        assert_eq!(l.fields[1].offset, 1); // no padding, b immediately after a
        let gaps = padlock_core::ir::find_padding(l);
        assert!(gaps.is_empty());
    }

    #[test]
    fn pointer_field_uses_arch_size() {
        let src = "struct S { p: *const u8 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 8); // 64-bit pointer
    }

    // ── attribute guard extraction ─────────────────────────────────────────────

    #[test]
    fn lock_protected_by_attr_sets_guard() {
        let src = r#"
struct Cache {
    #[lock_protected_by = "mu"]
    readers: u64,
    mu: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let readers = &layouts[0].fields[0];
        assert_eq!(readers.name, "readers");
        if let AccessPattern::Concurrent { guard, .. } = &readers.access {
            assert_eq!(guard.as_deref(), Some("mu"));
        } else {
            panic!("expected Concurrent, got {:?}", readers.access);
        }
    }

    #[test]
    fn guarded_by_string_attr_sets_guard() {
        let src = r#"
struct S {
    #[guarded_by("lock")]
    value: u32,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        if let AccessPattern::Concurrent { guard, .. } = &layouts[0].fields[0].access {
            assert_eq!(guard.as_deref(), Some("lock"));
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn guarded_by_ident_attr_sets_guard() {
        let src = r#"
struct S {
    #[guarded_by(mu)]
    count: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        if let AccessPattern::Concurrent { guard, .. } = &layouts[0].fields[0].access {
            assert_eq!(guard.as_deref(), Some("mu"));
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn protected_by_attr_sets_guard() {
        let src = r#"
struct S {
    #[protected_by = "lock_a"]
    x: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        if let AccessPattern::Concurrent { guard, .. } = &layouts[0].fields[0].access {
            assert_eq!(guard.as_deref(), Some("lock_a"));
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn different_guards_on_same_cache_line_is_false_sharing() {
        // readers and writers are at offsets 0 and 8 — same cache line (line 0).
        // They have different explicit guards → confirmed false sharing.
        let src = r#"
struct HotPath {
    #[lock_protected_by = "mu_a"]
    readers: u64,
    #[lock_protected_by = "mu_b"]
    writers: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(padlock_core::analysis::false_sharing::has_false_sharing(
            &layouts[0]
        ));
    }

    #[test]
    fn same_guard_on_same_cache_line_is_not_false_sharing() {
        // Both fields share guard "mu", so co-location is not flagged.
        let src = r#"
struct Safe {
    #[lock_protected_by = "mu"]
    a: u64,
    #[lock_protected_by = "mu"]
    b: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(!padlock_core::analysis::false_sharing::has_false_sharing(
            &layouts[0]
        ));
    }

    #[test]
    fn unannotated_field_stays_unknown() {
        let src = "struct S { x: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(matches!(
            layouts[0].fields[0].access,
            AccessPattern::Unknown
        ));
    }

    // ── stdlib type sizes ─────────────────────────────────────────────────────

    #[test]
    fn vec_field_has_three_pointer_size() {
        // Vec<T> is always ptr + len + cap regardless of T
        let src = "struct S { items: Vec<u64> }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 24); // 3 × 8 on x86-64
    }

    #[test]
    fn string_field_has_three_pointer_size() {
        let src = "struct S { name: String }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 24);
    }

    #[test]
    fn box_field_has_pointer_size() {
        let src = "struct S { inner: Box<u64> }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 8);
    }

    #[test]
    fn arc_field_has_pointer_size() {
        let src = "struct S { shared: Arc<Vec<u8>> }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 8);
    }

    #[test]
    fn phantom_data_is_zero_sized() {
        let src = "struct S { a: u64, _marker: PhantomData<u8> }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let marker = layouts[0]
            .fields
            .iter()
            .find(|f| f.name == "_marker")
            .unwrap();
        assert_eq!(marker.size, 0);
    }

    #[test]
    fn duration_field_is_16_bytes() {
        let src = "struct S { timeout: Duration }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 16);
    }

    #[test]
    fn atomic_u64_has_correct_size() {
        let src = "struct S { counter: AtomicU64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 8);
    }

    #[test]
    fn atomic_bool_has_correct_size() {
        let src = "struct S { flag: AtomicBool }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 1);
    }

    // ── generic struct skipping ───────────────────────────────────────────────

    #[test]
    fn generic_struct_is_skipped() {
        // Cannot accurately lay out struct Foo<T> without knowing T.
        let src = "struct Wrapper<T> { value: T, count: usize }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(
            layouts.is_empty(),
            "generic structs should be skipped; got {:?}",
            layouts.iter().map(|l| &l.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn generic_struct_with_multiple_params_is_skipped() {
        let src = "struct Pair<A, B> { first: A, second: B }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(layouts.is_empty());
    }

    #[test]
    fn non_generic_struct_still_parsed_when_generic_sibling_exists() {
        let src = r#"
struct Generic<T> { value: T }
struct Concrete { a: u32, b: u64 }
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        assert_eq!(layouts[0].name, "Concrete");
    }

    // ── enum data variant support ─────────────────────────────────────────────

    #[test]
    fn unit_enum_is_just_discriminant() {
        let src = "enum Color { Red, Green, Blue }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        let l = &layouts[0];
        assert_eq!(l.name, "Color");
        assert_eq!(l.total_size, 1); // 3 variants → u8 discriminant
        assert_eq!(l.fields.len(), 1);
        assert_eq!(l.fields[0].name, "__discriminant");
    }

    #[test]
    fn unit_enum_with_many_variants_uses_u16_discriminant() {
        // Build an enum with 300 variants (> 256)
        let variants: String = (0..300)
            .map(|i| format!("V{i}"))
            .collect::<Vec<_>>()
            .join(", ");
        let src = format!("enum Big {{ {variants} }}");
        let layouts = parse_rust(&src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.total_size, 2); // needs u16
        assert_eq!(l.fields[0].size, 2);
    }

    #[test]
    fn data_enum_total_size_covers_largest_variant() {
        // Quit: no payload; Move: {x: i32, y: i32} = 8B; Write: String = 24B
        // Max payload = 24B (String), disc = 1B → total = 32B (aligned to 8)
        let src = r#"
enum Message {
    Quit,
    Move { x: i32, y: i32 },
    Write(String),
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.name, "Message");
        // __payload (24B, align 8) + __discriminant (1B) → padded to 32B
        assert_eq!(l.total_size, 32);
        assert_eq!(l.fields.len(), 2);
        let payload = l.fields.iter().find(|f| f.name == "__payload").unwrap();
        assert_eq!(payload.size, 24); // String = 3×pointer
    }

    #[test]
    fn generic_enum_is_skipped() {
        let src = "enum Wrapper<T> { Some(T), None }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(
            layouts.is_empty(),
            "generic enums should be skipped; got {:?}",
            layouts.iter().map(|l| &l.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn empty_enum_is_skipped() {
        let src = "enum Never {}";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(layouts.is_empty());
    }

    #[test]
    fn enum_with_only_unit_variants_has_no_payload_field() {
        let src = "enum Dir { North, South, East, West }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(!layouts[0].fields.iter().any(|f| f.name == "__payload"));
    }

    #[test]
    fn data_enum_and_sibling_struct_both_parsed() {
        let src = r#"
enum Status { Ok, Err(u32) }
struct Conn { port: u16, status: u32 }
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 2);
        assert!(layouts.iter().any(|l| l.name == "Status"));
        assert!(layouts.iter().any(|l| l.name == "Conn"));
    }

    // ── bad weather: enums ────────────────────────────────────────────────────

    #[test]
    fn enum_with_only_zero_sized_variants_has_payload_size_zero() {
        // All unit variants → treated as unit enum, total = disc_size
        let src = "enum E { A, B }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.total_size, 1);
    }

    #[test]
    fn enum_mixed_unit_and_data_includes_max_payload() {
        // Mix: unit variant + data variant; payload comes from data variant
        let src = "enum E { Nothing, Data(u64) }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        let payload = l.fields.iter().find(|f| f.name == "__payload").unwrap();
        assert_eq!(payload.size, 8); // u64
    }
}