Skip to main content

padlock_source/frontends/
rust.rs

1// padlock-source/src/frontends/rust.rs
2//
3// Extracts struct layouts from Rust source using syn + the Visit API.
4// Sizes are approximated from type names using the target arch config.
// Structs and enums are handled (plain, repr(C), repr(packed), repr(align(N))); items with generic type parameters are skipped.
6
7use padlock_core::arch::ArchConfig;
8use padlock_core::ir::{AccessPattern, Field, StructLayout, TypeInfo};
9use quote::ToTokens;
10use syn::{Fields, ItemEnum, ItemStruct, Type, visit::Visit};
11
12// ── attribute guard extraction ────────────────────────────────────────────────
13
14/// Extract a lock guard name from field attributes.
15///
16/// Recognised forms:
17/// - `#[lock_protected_by = "mu"]`
18/// - `#[protected_by = "mu"]`
19/// - `#[guarded_by("mu")]` or `#[guarded_by(mu)]`
20/// - `#[pt_guarded_by("mu")]` or `#[pt_guarded_by(mu)]` (pointer variant)
21pub fn extract_guard_from_attrs(attrs: &[syn::Attribute]) -> Option<String> {
22    for attr in attrs {
23        let path = attr.path();
24        // Name-value form: #[lock_protected_by = "mu"] / #[protected_by = "mu"]
25        if (path.is_ident("lock_protected_by") || path.is_ident("protected_by"))
26            && let syn::Meta::NameValue(nv) = &attr.meta
27            && let syn::Expr::Lit(syn::ExprLit {
28                lit: syn::Lit::Str(s),
29                ..
30            }) = &nv.value
31        {
32            return Some(s.value());
33        }
34        // List form: #[guarded_by("mu")] / #[guarded_by(mu)] / #[pt_guarded_by(...)]
35        if path.is_ident("guarded_by") || path.is_ident("pt_guarded_by") {
36            // Try string literal first
37            if let Ok(s) = attr.parse_args::<syn::LitStr>() {
38                return Some(s.value());
39            }
40            // Fall back to bare identifier
41            if let Ok(id) = attr.parse_args::<syn::Ident>() {
42                return Some(id.to_string());
43            }
44        }
45    }
46    None
47}
48
49// ── type resolution ───────────────────────────────────────────────────────────
50
51fn rust_type_size_align(ty: &Type, arch: &'static ArchConfig) -> (usize, usize, TypeInfo) {
52    match ty {
53        Type::Path(tp) => {
54            let name = tp
55                .path
56                .segments
57                .last()
58                .map(|s| s.ident.to_string())
59                .unwrap_or_default();
60            let (size, align) = primitive_size_align(&name, arch);
61            (size, align, TypeInfo::Primitive { name, size, align })
62        }
63        Type::Ptr(_) | Type::Reference(_) => {
64            let s = arch.pointer_size;
65            (s, s, TypeInfo::Pointer { size: s, align: s })
66        }
67        Type::Array(arr) => {
68            let (elem_size, elem_align, elem_ty) = rust_type_size_align(&arr.elem, arch);
69            let count = array_len_from_expr(&arr.len);
70            let size = elem_size * count;
71            (
72                size,
73                elem_align,
74                TypeInfo::Array {
75                    element: Box::new(elem_ty),
76                    count,
77                    size,
78                    align: elem_align,
79                },
80            )
81        }
82        _ => {
83            let s = arch.pointer_size;
84            (
85                s,
86                s,
87                TypeInfo::Opaque {
88                    name: "(unknown)".into(),
89                    size: s,
90                    align: s,
91                },
92            )
93        }
94    }
95}
96
97fn primitive_size_align(name: &str, arch: &'static ArchConfig) -> (usize, usize) {
98    let ps = arch.pointer_size;
99    match name {
100        // ── language primitives ───────────────────────────────────────────────
101        "bool" | "u8" | "i8" => (1, 1),
102        "u16" | "i16" => (2, 2),
103        "u32" | "i32" | "f32" => (4, 4),
104        "u64" | "i64" | "f64" => (8, 8),
105        "u128" | "i128" => (16, 16),
106        "usize" | "isize" => (ps, ps),
107        "char" => (4, 4), // Rust char is a Unicode scalar (4 bytes)
108
109        // ── std atomics ───────────────────────────────────────────────────────
110        "AtomicBool" | "AtomicU8" | "AtomicI8" => (1, 1),
111        "AtomicU16" | "AtomicI16" => (2, 2),
112        "AtomicU32" | "AtomicI32" => (4, 4),
113        "AtomicU64" | "AtomicI64" => (8, 8),
114        "AtomicUsize" | "AtomicIsize" | "AtomicPtr" => (ps, ps),
115
116        // ── heap-allocated collections: ptr + len + cap (3 words) ────────────
117        // Size is independent of the element type T (generic arg already stripped).
118        "Vec" | "String" | "OsString" | "CString" | "PathBuf" => (3 * ps, ps),
119        "VecDeque" | "LinkedList" | "BinaryHeap" => (3 * ps, ps),
120        "HashMap" | "HashSet" | "BTreeMap" | "BTreeSet" => (3 * ps, ps),
121
122        // ── single-pointer smart pointers ─────────────────────────────────────
123        "Box" | "Rc" | "Arc" | "Weak" | "NonNull" | "Cell" => (ps, ps),
124
125        // ── interior-mutability / sync wrappers ───────────────────────────────
126        // Size depends on T but pointer-size is a reasonable approximation for
127        // display purposes; use binary analysis for precise results.
128        "RefCell" | "Mutex" | "RwLock" => (ps, ps),
129
130        // ── channels ─────────────────────────────────────────────────────────
131        "Sender" | "Receiver" | "SyncSender" => (ps, ps),
132
133        // ── zero-sized types ──────────────────────────────────────────────────
134        "PhantomData" | "PhantomPinned" => (0, 1),
135
136        // ── common fixed-size stdlib types ────────────────────────────────────
137        // Duration: u64 secs (8B) + u32 nanos (4B) → 12B + 4B trailing = 16B
138        "Duration" => (16, 8),
139        "Instant" | "SystemTime" => (16, 8),
140
141        // ── Pin<T> wraps T, pointer-size approximation ────────────────────────
142        "Pin" => (ps, ps),
143
144        // ── x86 SSE / AVX / AVX-512 SIMD types ───────────────────────────────
145        "__m64" => (8, 8),
146        "__m128" | "__m128d" | "__m128i" => (16, 16),
147        "__m256" | "__m256d" | "__m256i" => (32, 32),
148        "__m512" | "__m512d" | "__m512i" => (64, 64),
149
150        // ── Rust portable SIMD / packed_simd types ────────────────────────────
151        "f32x4" | "i32x4" | "u32x4" => (16, 16),
152        "f64x2" | "i64x2" | "u64x2" => (16, 16),
153        "f32x8" | "i32x8" | "u32x8" => (32, 32),
154        "f64x4" | "i64x4" | "u64x4" => (32, 32),
155        "f32x16" | "i32x16" | "u32x16" => (64, 64),
156
157        // ── unknown / third-party / generic type params (T, E, …) ────────────
158        _ => (ps, ps),
159    }
160}
161
162fn array_len_from_expr(expr: &syn::Expr) -> usize {
163    if let syn::Expr::Lit(syn::ExprLit {
164        lit: syn::Lit::Int(n),
165        ..
166    }) = expr
167    {
168        n.base10_parse::<usize>().unwrap_or(0)
169    } else {
170        0
171    }
172}
173
174// ── struct repr detection ─────────────────────────────────────────────────────
175
176fn is_packed(attrs: &[syn::Attribute]) -> bool {
177    attrs
178        .iter()
179        .any(|a| a.path().is_ident("repr") && a.to_token_stream().to_string().contains("packed"))
180}
181
182/// Extract the alignment from `#[repr(align(N))]`. Returns `None` if not present.
183fn repr_align(attrs: &[syn::Attribute]) -> Option<usize> {
184    for attr in attrs {
185        if !attr.path().is_ident("repr") {
186            continue;
187        }
188        let ts = attr.to_token_stream().to_string();
189        // Look for `align ( N )` in the token stream string.
190        // The tokeniser adds spaces: "repr (align (64))" etc.
191        if let Some(start) = ts.find("align") {
192            let after = ts[start..].trim_start_matches("align").trim_start();
193            if after.starts_with('(') {
194                let inner = after.trim_start_matches('(');
195                let num_str: String = inner.chars().take_while(|c| c.is_ascii_digit()).collect();
196                if let Ok(n) = num_str.parse::<usize>()
197                    && n > 0
198                    && n.is_power_of_two()
199                {
200                    return Some(n);
201                }
202            }
203        }
204    }
205    None
206}
207
208fn simulate_rust_layout(
209    name: String,
210    fields: &[(String, Type)],
211    packed: bool,
212    forced_align: Option<usize>,
213    arch: &'static ArchConfig,
214) -> StructLayout {
215    let mut offset = 0usize;
216    let mut struct_align = 1usize;
217    let mut out_fields: Vec<Field> = Vec::new();
218
219    for (fname, ty) in fields {
220        let (size, align, type_info) = rust_type_size_align(ty, arch);
221        let effective_align = if packed { 1 } else { align };
222
223        if effective_align > 0 {
224            offset = offset.next_multiple_of(effective_align);
225        }
226        struct_align = struct_align.max(effective_align);
227
228        out_fields.push(Field {
229            name: fname.clone(),
230            ty: type_info,
231            offset,
232            size,
233            align: effective_align,
234            source_file: None,
235            source_line: None,
236            access: AccessPattern::Unknown,
237        });
238        offset += size;
239    }
240
241    // Apply repr(align(N)): raise minimum alignment and add trailing padding.
242    if let Some(fa) = forced_align
243        && fa > struct_align
244    {
245        struct_align = fa;
246    }
247
248    if !packed && struct_align > 0 {
249        offset = offset.next_multiple_of(struct_align);
250    }
251
252    StructLayout {
253        name,
254        total_size: offset,
255        align: struct_align,
256        fields: out_fields,
257        source_file: None,
258        source_line: None,
259        arch,
260        is_packed: packed,
261        is_union: false,
262    }
263}
264
265// ── visitor ───────────────────────────────────────────────────────────────────
266
/// AST visitor that accumulates a [`StructLayout`] for each struct and enum
/// item it lays out while walking a parsed file.
struct StructVisitor {
    /// Target architecture; its pointer size determines the size of
    /// pointers, `usize`/`isize`, and pointer-based containers.
    arch: &'static ArchConfig,
    /// Layouts collected so far, in the order the items were visited.
    layouts: Vec<StructLayout>,
}
271
272impl<'ast> Visit<'ast> for StructVisitor {
273    fn visit_item_struct(&mut self, node: &'ast ItemStruct) {
274        syn::visit::visit_item_struct(self, node); // recurse into nested items
275
276        // Generic structs (e.g. `struct Foo<T>`) cannot be accurately laid out
277        // without knowing the concrete type arguments. Skip them rather than
278        // producing wrong field sizes for the type parameters.
279        if !node.generics.params.is_empty() {
280            return;
281        }
282
283        let name = node.ident.to_string();
284        let packed = is_packed(&node.attrs);
285        let forced_align = repr_align(&node.attrs);
286
287        // Collect (field_name, type, optional_guard)
288        let fields: Vec<(String, Type, Option<String>)> = match &node.fields {
289            Fields::Named(nf) => nf
290                .named
291                .iter()
292                .map(|f| {
293                    let fname = f.ident.as_ref().map(|i| i.to_string()).unwrap_or_default();
294                    let guard = extract_guard_from_attrs(&f.attrs);
295                    (fname, f.ty.clone(), guard)
296                })
297                .collect(),
298            Fields::Unnamed(uf) => uf
299                .unnamed
300                .iter()
301                .enumerate()
302                .map(|(i, f)| {
303                    let guard = extract_guard_from_attrs(&f.attrs);
304                    (format!("_{i}"), f.ty.clone(), guard)
305                })
306                .collect(),
307            Fields::Unit => vec![],
308        };
309
310        let name_ty: Vec<(String, Type)> = fields
311            .iter()
312            .map(|(n, t, _)| (n.clone(), t.clone()))
313            .collect();
314        let mut layout = simulate_rust_layout(name, &name_ty, packed, forced_align, self.arch);
315        layout.source_line = Some(node.ident.span().start().line as u32);
316
317        // Apply explicit guard annotations; these take precedence over the
318        // heuristic type-name pass in concurrency.rs (which skips non-Unknown fields).
319        for (i, (_, _, guard)) in fields.iter().enumerate() {
320            if let Some(g) = guard {
321                layout.fields[i].access = AccessPattern::Concurrent {
322                    guard: Some(g.clone()),
323                    is_atomic: false,
324                };
325            }
326        }
327
328        self.layouts.push(layout);
329    }
330
331    fn visit_item_enum(&mut self, node: &'ast ItemEnum) {
332        syn::visit::visit_item_enum(self, node);
333
334        // Skip generic enums (layout depends on unknown type arguments)
335        if !node.generics.params.is_empty() {
336            return;
337        }
338
339        let name = node.ident.to_string();
340        let n_variants = node.variants.len();
341        if n_variants == 0 {
342            return;
343        }
344
345        // Discriminant size: smallest integer that fits the variant count.
346        // Rust defaults to isize but uses the minimal repr in practice.
347        let disc_size: usize = if n_variants <= 256 {
348            1
349        } else if n_variants <= 65536 {
350            2
351        } else {
352            4
353        };
354
355        // Check if all variants are unit (C-like enum, no payload)
356        let all_unit = node
357            .variants
358            .iter()
359            .all(|v| matches!(v.fields, Fields::Unit));
360
361        if all_unit {
362            // Pure discriminant — no payload storage
363            let layout = StructLayout {
364                name,
365                total_size: disc_size,
366                align: disc_size,
367                fields: vec![Field {
368                    name: "__discriminant".to_string(),
369                    ty: TypeInfo::Primitive {
370                        name: format!("u{}", disc_size * 8),
371                        size: disc_size,
372                        align: disc_size,
373                    },
374                    offset: 0,
375                    size: disc_size,
376                    align: disc_size,
377                    source_file: None,
378                    source_line: None,
379                    access: AccessPattern::Unknown,
380                }],
381                source_file: None,
382                source_line: Some(node.ident.span().start().line as u32),
383                arch: self.arch,
384                is_packed: false,
385                is_union: false,
386            };
387            self.layouts.push(layout);
388            return;
389        }
390
391        // Data enum: find the maximum variant payload size and alignment.
392        let mut max_payload_size = 0usize;
393        let mut max_payload_align = 1usize;
394
395        for variant in &node.variants {
396            let var_fields: Vec<(String, Type)> = match &variant.fields {
397                Fields::Named(nf) => nf
398                    .named
399                    .iter()
400                    .map(|f| {
401                        let n = f.ident.as_ref().map(|i| i.to_string()).unwrap_or_default();
402                        (n, f.ty.clone())
403                    })
404                    .collect(),
405                Fields::Unnamed(uf) => uf
406                    .unnamed
407                    .iter()
408                    .enumerate()
409                    .map(|(i, f)| (format!("_{i}"), f.ty.clone()))
410                    .collect(),
411                Fields::Unit => vec![],
412            };
413
414            if !var_fields.is_empty() {
415                let var_layout =
416                    simulate_rust_layout(String::new(), &var_fields, false, None, self.arch);
417                if var_layout.total_size > max_payload_size {
418                    max_payload_size = var_layout.total_size;
419                }
420                max_payload_align = max_payload_align.max(var_layout.align);
421            }
422        }
423
424        // Conservative model: payload first at offset 0, discriminant immediately after.
425        // Rust's actual layout is compiler-controlled (niche optimisation etc.);
426        // this model gives a safe upper-bound for padding analysis.
427        let payload_align = max_payload_align.max(1);
428        let disc_offset = max_payload_size;
429        let total_before_pad = disc_offset + disc_size;
430        let total_align = payload_align.max(disc_size);
431        let total_size = total_before_pad.next_multiple_of(total_align);
432
433        let mut fields: Vec<Field> = Vec::new();
434        if max_payload_size > 0 {
435            fields.push(Field {
436                name: "__payload".to_string(),
437                ty: TypeInfo::Opaque {
438                    name: format!("largest_variant_payload ({}B)", max_payload_size),
439                    size: max_payload_size,
440                    align: payload_align,
441                },
442                offset: 0,
443                size: max_payload_size,
444                align: payload_align,
445                source_file: None,
446                source_line: None,
447                access: AccessPattern::Unknown,
448            });
449        }
450        fields.push(Field {
451            name: "__discriminant".to_string(),
452            ty: TypeInfo::Primitive {
453                name: format!("u{}", disc_size * 8),
454                size: disc_size,
455                align: disc_size,
456            },
457            offset: disc_offset,
458            size: disc_size,
459            align: disc_size,
460            source_file: None,
461            source_line: None,
462            access: AccessPattern::Unknown,
463        });
464
465        self.layouts.push(StructLayout {
466            name,
467            total_size,
468            align: total_align,
469            fields,
470            source_file: None,
471            source_line: Some(node.ident.span().start().line as u32),
472            arch: self.arch,
473            is_packed: false,
474            is_union: false,
475        });
476    }
477}
478
479// ── public API ────────────────────────────────────────────────────────────────
480
481pub fn parse_rust(source: &str, arch: &'static ArchConfig) -> anyhow::Result<Vec<StructLayout>> {
482    let file: syn::File = syn::parse_str(source)?;
483    let mut visitor = StructVisitor {
484        arch,
485        layouts: Vec::new(),
486    };
487    visitor.visit_file(&file);
488    Ok(visitor.layouts)
489}
490
491// ── tests ─────────────────────────────────────────────────────────────────────
492
#[cfg(test)]
mod tests {
    // Unit tests for the Rust source frontend. Every layout is computed for
    // x86-64 SysV, where pointers and `usize` are 8 bytes.
    use super::*;
    use padlock_core::arch::X86_64_SYSV;

    // ── basic struct layout ───────────────────────────────────────────────────

    #[test]
    fn parse_simple_struct() {
        let src = "struct Foo { a: u8, b: u64, c: u32 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        let l = &layouts[0];
        assert_eq!(l.name, "Foo");
        assert_eq!(l.fields.len(), 3);
        assert_eq!(l.fields[0].size, 1); // u8
        assert_eq!(l.fields[1].size, 8); // u64
        assert_eq!(l.fields[2].size, 4); // u32
    }

    #[test]
    fn layout_includes_padding() {
        // u8 then u64: 7 bytes padding inserted
        let src = "struct T { a: u8, b: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.fields[0].offset, 0);
        assert_eq!(l.fields[1].offset, 8); // u64 aligned to 8
        assert_eq!(l.total_size, 16);
        let gaps = padlock_core::ir::find_padding(l);
        assert_eq!(gaps[0].bytes, 7);
    }

    #[test]
    fn multiple_structs_parsed() {
        let src = "struct A { x: u32 } struct B { y: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 2);
    }

    #[test]
    fn packed_struct_no_padding() {
        // repr(packed) forces every field to alignment 1.
        let src = "#[repr(packed)] struct P { a: u8, b: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert!(l.is_packed);
        assert_eq!(l.fields[1].offset, 1); // no padding, b immediately after a
        let gaps = padlock_core::ir::find_padding(l);
        assert!(gaps.is_empty());
    }

    #[test]
    fn pointer_field_uses_arch_size() {
        let src = "struct S { p: *const u8 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 8); // 64-bit pointer
    }

    // ── attribute guard extraction ─────────────────────────────────────────────

    #[test]
    fn lock_protected_by_attr_sets_guard() {
        // Name-value form: the guard comes from the string literal.
        let src = r#"
struct Cache {
    #[lock_protected_by = "mu"]
    readers: u64,
    mu: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let readers = &layouts[0].fields[0];
        assert_eq!(readers.name, "readers");
        if let AccessPattern::Concurrent { guard, .. } = &readers.access {
            assert_eq!(guard.as_deref(), Some("mu"));
        } else {
            panic!("expected Concurrent, got {:?}", readers.access);
        }
    }

    #[test]
    fn guarded_by_string_attr_sets_guard() {
        // List form with a string literal argument.
        let src = r#"
struct S {
    #[guarded_by("lock")]
    value: u32,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        if let AccessPattern::Concurrent { guard, .. } = &layouts[0].fields[0].access {
            assert_eq!(guard.as_deref(), Some("lock"));
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn guarded_by_ident_attr_sets_guard() {
        // List form with a bare identifier argument.
        let src = r#"
struct S {
    #[guarded_by(mu)]
    count: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        if let AccessPattern::Concurrent { guard, .. } = &layouts[0].fields[0].access {
            assert_eq!(guard.as_deref(), Some("mu"));
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn protected_by_attr_sets_guard() {
        // Short name-value spelling.
        let src = r#"
struct S {
    #[protected_by = "lock_a"]
    x: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        if let AccessPattern::Concurrent { guard, .. } = &layouts[0].fields[0].access {
            assert_eq!(guard.as_deref(), Some("lock_a"));
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn different_guards_on_same_cache_line_is_false_sharing() {
        // readers and writers are at offsets 0 and 8 — same cache line (line 0).
        // They have different explicit guards → confirmed false sharing.
        let src = r#"
struct HotPath {
    #[lock_protected_by = "mu_a"]
    readers: u64,
    #[lock_protected_by = "mu_b"]
    writers: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(padlock_core::analysis::false_sharing::has_false_sharing(
            &layouts[0]
        ));
    }

    #[test]
    fn same_guard_on_same_cache_line_is_not_false_sharing() {
        // Both fields share one guard, so co-location is not flagged.
        let src = r#"
struct Safe {
    #[lock_protected_by = "mu"]
    a: u64,
    #[lock_protected_by = "mu"]
    b: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(!padlock_core::analysis::false_sharing::has_false_sharing(
            &layouts[0]
        ));
    }

    #[test]
    fn unannotated_field_stays_unknown() {
        let src = "struct S { x: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(matches!(
            layouts[0].fields[0].access,
            AccessPattern::Unknown
        ));
    }

    // ── stdlib type sizes ─────────────────────────────────────────────────────

    #[test]
    fn vec_field_has_three_pointer_size() {
        // Vec<T> is always ptr + len + cap regardless of T
        let src = "struct S { items: Vec<u64> }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 24); // 3 × 8 on x86-64
    }

    #[test]
    fn string_field_has_three_pointer_size() {
        let src = "struct S { name: String }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 24);
    }

    #[test]
    fn box_field_has_pointer_size() {
        let src = "struct S { inner: Box<u64> }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 8);
    }

    #[test]
    fn arc_field_has_pointer_size() {
        // The nested Vec<u8> argument does not change Arc's own size.
        let src = "struct S { shared: Arc<Vec<u8>> }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 8);
    }

    #[test]
    fn phantom_data_is_zero_sized() {
        let src = "struct S { a: u64, _marker: PhantomData<u8> }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let marker = layouts[0]
            .fields
            .iter()
            .find(|f| f.name == "_marker")
            .unwrap();
        assert_eq!(marker.size, 0);
    }

    #[test]
    fn duration_field_is_16_bytes() {
        let src = "struct S { timeout: Duration }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 16);
    }

    #[test]
    fn atomic_u64_has_correct_size() {
        let src = "struct S { counter: AtomicU64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 8);
    }

    #[test]
    fn atomic_bool_has_correct_size() {
        let src = "struct S { flag: AtomicBool }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 1);
    }

    // ── generic struct skipping ───────────────────────────────────────────────

    #[test]
    fn generic_struct_is_skipped() {
        // Cannot accurately lay out struct Foo<T> without knowing T.
        let src = "struct Wrapper<T> { value: T, count: usize }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(
            layouts.is_empty(),
            "generic structs should be skipped; got {:?}",
            layouts.iter().map(|l| &l.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn generic_struct_with_multiple_params_is_skipped() {
        let src = "struct Pair<A, B> { first: A, second: B }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(layouts.is_empty());
    }

    #[test]
    fn non_generic_struct_still_parsed_when_generic_sibling_exists() {
        // Skipping one generic item must not suppress its siblings.
        let src = r#"
struct Generic<T> { value: T }
struct Concrete { a: u32, b: u64 }
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        assert_eq!(layouts[0].name, "Concrete");
    }

    // ── enum data variant support ─────────────────────────────────────────────

    #[test]
    fn unit_enum_is_just_discriminant() {
        let src = "enum Color { Red, Green, Blue }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        let l = &layouts[0];
        assert_eq!(l.name, "Color");
        assert_eq!(l.total_size, 1); // 3 variants → u8 discriminant
        assert_eq!(l.fields.len(), 1);
        assert_eq!(l.fields[0].name, "__discriminant");
    }

    #[test]
    fn unit_enum_with_many_variants_uses_u16_discriminant() {
        // Build an enum with 300 variants (> 256)
        let variants: String = (0..300)
            .map(|i| format!("V{i}"))
            .collect::<Vec<_>>()
            .join(", ");
        let src = format!("enum Big {{ {variants} }}");
        let layouts = parse_rust(&src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.total_size, 2); // needs u16
        assert_eq!(l.fields[0].size, 2);
    }

    #[test]
    fn data_enum_total_size_covers_largest_variant() {
        // Quit: no payload; Move: {x: i32, y: i32} = 8B; Write: String = 24B
        // Max payload = 24B (String), disc = 1B → total = 32B (aligned to 8)
        let src = r#"
enum Message {
    Quit,
    Move { x: i32, y: i32 },
    Write(String),
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.name, "Message");
        // __payload (24B, align 8) + __discriminant (1B) → padded to 32B
        assert_eq!(l.total_size, 32);
        assert_eq!(l.fields.len(), 2);
        let payload = l.fields.iter().find(|f| f.name == "__payload").unwrap();
        assert_eq!(payload.size, 24); // String = 3×pointer
    }

    #[test]
    fn generic_enum_is_skipped() {
        let src = "enum Wrapper<T> { Some(T), None }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(
            layouts.is_empty(),
            "generic enums should be skipped; got {:?}",
            layouts.iter().map(|l| &l.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn empty_enum_is_skipped() {
        // A variant-less enum has no discriminant to report.
        let src = "enum Never {}";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(layouts.is_empty());
    }

    #[test]
    fn enum_with_only_unit_variants_has_no_payload_field() {
        let src = "enum Dir { North, South, East, West }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(!layouts[0].fields.iter().any(|f| f.name == "__payload"));
    }

    #[test]
    fn data_enum_and_sibling_struct_both_parsed() {
        let src = r#"
enum Status { Ok, Err(u32) }
struct Conn { port: u16, status: u32 }
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 2);
        assert!(layouts.iter().any(|l| l.name == "Status"));
        assert!(layouts.iter().any(|l| l.name == "Conn"));
    }

    // ── bad weather: enums ────────────────────────────────────────────────────

    #[test]
    fn enum_with_only_zero_sized_variants_has_payload_size_zero() {
        // All unit variants → treated as unit enum, total = disc_size
        let src = "enum E { A, B }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.total_size, 1);
    }

    #[test]
    fn enum_mixed_unit_and_data_includes_max_payload() {
        // Mix: unit variant + data variant; payload comes from data variant
        let src = "enum E { Nothing, Data(u64) }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        let payload = l.fields.iter().find(|f| f.name == "__payload").unwrap();
        assert_eq!(payload.size, 8); // u64
    }

    // ── repr(align(N)) ────────────────────────────────────────────────────────

    #[test]
    fn repr_align_raises_struct_alignment() {
        let src = "#[repr(align(64))]\nstruct CacheLine { a: u8, b: u32 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(
            l.align, 64,
            "repr(align(64)) must set struct alignment to 64"
        );
        assert_eq!(l.total_size, 64, "size must be padded to 64 bytes");
    }

    #[test]
    fn repr_align_does_not_shrink_natural_alignment() {
        // repr(align(1)) on a struct whose natural align is 8 — must keep 8
        let src = "#[repr(align(1))]\nstruct S { a: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(
            l.align, 8,
            "natural align must not be reduced below repr(align)"
        );
    }

    #[test]
    fn repr_align_adds_trailing_padding() {
        // u8 + u32 = 5 bytes natural, padded to 8 with align(8)
        let src = "#[repr(align(8))]\nstruct S { a: u8, b: u32 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.total_size, 8);
    }

    #[test]
    fn no_repr_align_has_natural_size() {
        // Baseline: without repr(align), just natural padding
        let src = "struct S { a: u8, b: u32 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        // a:1 + 3 pad + b:4 = 8; align=4
        assert_eq!(l.total_size, 8);
        assert_eq!(l.align, 4);
    }

    // ── tuple structs ─────────────────────────────────────────────────────────

    #[test]
    fn tuple_struct_fields_named_by_index() {
        let src = "struct Pair(u64, u8);";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.fields[0].name, "_0");
        assert_eq!(l.fields[1].name, "_1");
    }

    #[test]
    fn tuple_struct_layout_follows_alignment() {
        // u64 then u8: no padding before u64, 7 bytes trailing
        let src = "struct S(u64, u8);";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.fields[0].offset, 0);
        assert_eq!(l.fields[0].size, 8);
        assert_eq!(l.fields[1].offset, 8);
        assert_eq!(l.fields[1].size, 1);
        assert_eq!(l.total_size, 16);
    }

    #[test]
    fn tuple_struct_with_padding_waste_detected() {
        // u8 then u64: 7 bytes padding
        let src = "struct S(u8, u64);";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.fields[0].offset, 0); // u8 at 0
        assert_eq!(l.fields[1].offset, 8); // u64 aligned to 8
        assert_eq!(l.total_size, 16);
        let gaps = padlock_core::ir::find_padding(l);
        assert_eq!(gaps[0].bytes, 7);
    }
}