Skip to main content

padlock_source/frontends/
rust.rs

1// padlock-source/src/frontends/rust.rs
2//
3// Extracts struct layouts from Rust source using syn + the Visit API.
4// Sizes are approximated from type names using the target arch config.
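// Fields are laid out in declaration order (repr(C) rules, with repr(packed)
// honoured); plain repr(Rust) structs are shown as declared, not as rustc would
// reorder them. Structs whose layout depends on generic type arguments are
// skipped.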
6
7use padlock_core::arch::ArchConfig;
8use padlock_core::ir::{AccessPattern, Field, StructLayout, TypeInfo};
9use quote::ToTokens;
10use syn::{Fields, ItemStruct, Type, visit::Visit};
11
12// ── attribute guard extraction ────────────────────────────────────────────────
13
14/// Extract a lock guard name from field attributes.
15///
16/// Recognised forms:
17/// - `#[lock_protected_by = "mu"]`
18/// - `#[protected_by = "mu"]`
19/// - `#[guarded_by("mu")]` or `#[guarded_by(mu)]`
20/// - `#[pt_guarded_by("mu")]` or `#[pt_guarded_by(mu)]` (pointer variant)
21pub fn extract_guard_from_attrs(attrs: &[syn::Attribute]) -> Option<String> {
22    for attr in attrs {
23        let path = attr.path();
24        // Name-value form: #[lock_protected_by = "mu"] / #[protected_by = "mu"]
25        if (path.is_ident("lock_protected_by") || path.is_ident("protected_by"))
26            && let syn::Meta::NameValue(nv) = &attr.meta
27            && let syn::Expr::Lit(syn::ExprLit {
28                lit: syn::Lit::Str(s),
29                ..
30            }) = &nv.value
31        {
32            return Some(s.value());
33        }
34        // List form: #[guarded_by("mu")] / #[guarded_by(mu)] / #[pt_guarded_by(...)]
35        if path.is_ident("guarded_by") || path.is_ident("pt_guarded_by") {
36            // Try string literal first
37            if let Ok(s) = attr.parse_args::<syn::LitStr>() {
38                return Some(s.value());
39            }
40            // Fall back to bare identifier
41            if let Ok(id) = attr.parse_args::<syn::Ident>() {
42                return Some(id.to_string());
43            }
44        }
45    }
46    None
47}
48
49// ── type resolution ───────────────────────────────────────────────────────────
50
51fn rust_type_size_align(ty: &Type, arch: &'static ArchConfig) -> (usize, usize, TypeInfo) {
52    match ty {
53        Type::Path(tp) => {
54            let name = tp
55                .path
56                .segments
57                .last()
58                .map(|s| s.ident.to_string())
59                .unwrap_or_default();
60            let (size, align) = primitive_size_align(&name, arch);
61            (size, align, TypeInfo::Primitive { name, size, align })
62        }
63        Type::Ptr(_) | Type::Reference(_) => {
64            let s = arch.pointer_size;
65            (s, s, TypeInfo::Pointer { size: s, align: s })
66        }
67        Type::Array(arr) => {
68            let (elem_size, elem_align, elem_ty) = rust_type_size_align(&arr.elem, arch);
69            let count = array_len_from_expr(&arr.len);
70            let size = elem_size * count;
71            (
72                size,
73                elem_align,
74                TypeInfo::Array {
75                    element: Box::new(elem_ty),
76                    count,
77                    size,
78                    align: elem_align,
79                },
80            )
81        }
82        _ => {
83            let s = arch.pointer_size;
84            (
85                s,
86                s,
87                TypeInfo::Opaque {
88                    name: "(unknown)".into(),
89                    size: s,
90                    align: s,
91                },
92            )
93        }
94    }
95}
96
97fn primitive_size_align(name: &str, arch: &'static ArchConfig) -> (usize, usize) {
98    let ps = arch.pointer_size;
99    match name {
100        // ── language primitives ───────────────────────────────────────────────
101        "bool" | "u8" | "i8" => (1, 1),
102        "u16" | "i16" => (2, 2),
103        "u32" | "i32" | "f32" => (4, 4),
104        "u64" | "i64" | "f64" => (8, 8),
105        "u128" | "i128" => (16, 16),
106        "usize" | "isize" => (ps, ps),
107        "char" => (4, 4), // Rust char is a Unicode scalar (4 bytes)
108
109        // ── std atomics ───────────────────────────────────────────────────────
110        "AtomicBool" | "AtomicU8" | "AtomicI8" => (1, 1),
111        "AtomicU16" | "AtomicI16" => (2, 2),
112        "AtomicU32" | "AtomicI32" => (4, 4),
113        "AtomicU64" | "AtomicI64" => (8, 8),
114        "AtomicUsize" | "AtomicIsize" | "AtomicPtr" => (ps, ps),
115
116        // ── heap-allocated collections: ptr + len + cap (3 words) ────────────
117        // Size is independent of the element type T (generic arg already stripped).
118        "Vec" | "String" | "OsString" | "CString" | "PathBuf" => (3 * ps, ps),
119        "VecDeque" | "LinkedList" | "BinaryHeap" => (3 * ps, ps),
120        "HashMap" | "HashSet" | "BTreeMap" | "BTreeSet" => (3 * ps, ps),
121
122        // ── single-pointer smart pointers ─────────────────────────────────────
123        "Box" | "Rc" | "Arc" | "Weak" | "NonNull" | "Cell" => (ps, ps),
124
125        // ── interior-mutability / sync wrappers ───────────────────────────────
126        // Size depends on T but pointer-size is a reasonable approximation for
127        // display purposes; use binary analysis for precise results.
128        "RefCell" | "Mutex" | "RwLock" => (ps, ps),
129
130        // ── channels ─────────────────────────────────────────────────────────
131        "Sender" | "Receiver" | "SyncSender" => (ps, ps),
132
133        // ── zero-sized types ──────────────────────────────────────────────────
134        "PhantomData" | "PhantomPinned" => (0, 1),
135
136        // ── common fixed-size stdlib types ────────────────────────────────────
137        // Duration: u64 secs (8B) + u32 nanos (4B) → 12B + 4B trailing = 16B
138        "Duration" => (16, 8),
139        "Instant" | "SystemTime" => (16, 8),
140
141        // ── Pin<T> wraps T, pointer-size approximation ────────────────────────
142        "Pin" => (ps, ps),
143
144        // ── x86 SSE / AVX / AVX-512 SIMD types ───────────────────────────────
145        "__m64" => (8, 8),
146        "__m128" | "__m128d" | "__m128i" => (16, 16),
147        "__m256" | "__m256d" | "__m256i" => (32, 32),
148        "__m512" | "__m512d" | "__m512i" => (64, 64),
149
150        // ── Rust portable SIMD / packed_simd types ────────────────────────────
151        "f32x4" | "i32x4" | "u32x4" => (16, 16),
152        "f64x2" | "i64x2" | "u64x2" => (16, 16),
153        "f32x8" | "i32x8" | "u32x8" => (32, 32),
154        "f64x4" | "i64x4" | "u64x4" => (32, 32),
155        "f32x16" | "i32x16" | "u32x16" => (64, 64),
156
157        // ── unknown / third-party / generic type params (T, E, …) ────────────
158        _ => (ps, ps),
159    }
160}
161
162fn array_len_from_expr(expr: &syn::Expr) -> usize {
163    if let syn::Expr::Lit(syn::ExprLit {
164        lit: syn::Lit::Int(n),
165        ..
166    }) = expr
167    {
168        n.base10_parse::<usize>().unwrap_or(0)
169    } else {
170        0
171    }
172}
173
174// ── struct repr detection ─────────────────────────────────────────────────────
175
176fn is_packed(attrs: &[syn::Attribute]) -> bool {
177    attrs
178        .iter()
179        .any(|a| a.path().is_ident("repr") && a.to_token_stream().to_string().contains("packed"))
180}
181
182fn simulate_rust_layout(
183    name: String,
184    fields: &[(String, Type)],
185    packed: bool,
186    arch: &'static ArchConfig,
187) -> StructLayout {
188    let mut offset = 0usize;
189    let mut struct_align = 1usize;
190    let mut out_fields: Vec<Field> = Vec::new();
191
192    for (fname, ty) in fields {
193        let (size, align, type_info) = rust_type_size_align(ty, arch);
194        let effective_align = if packed { 1 } else { align };
195
196        if effective_align > 0 {
197            offset = offset.next_multiple_of(effective_align);
198        }
199        struct_align = struct_align.max(effective_align);
200
201        out_fields.push(Field {
202            name: fname.clone(),
203            ty: type_info,
204            offset,
205            size,
206            align: effective_align,
207            source_file: None,
208            source_line: None,
209            access: AccessPattern::Unknown,
210        });
211        offset += size;
212    }
213
214    if !packed && struct_align > 0 {
215        offset = offset.next_multiple_of(struct_align);
216    }
217
218    StructLayout {
219        name,
220        total_size: offset,
221        align: struct_align,
222        fields: out_fields,
223        source_file: None,
224        source_line: None,
225        arch,
226        is_packed: packed,
227        is_union: false,
228    }
229}
230
231// ── visitor ───────────────────────────────────────────────────────────────────
232
233struct StructVisitor {
234    arch: &'static ArchConfig,
235    layouts: Vec<StructLayout>,
236}
237
238impl<'ast> Visit<'ast> for StructVisitor {
239    fn visit_item_struct(&mut self, node: &'ast ItemStruct) {
240        syn::visit::visit_item_struct(self, node); // recurse into nested items
241
242        // Generic structs (e.g. `struct Foo<T>`) cannot be accurately laid out
243        // without knowing the concrete type arguments. Skip them rather than
244        // producing wrong field sizes for the type parameters.
245        if !node.generics.params.is_empty() {
246            return;
247        }
248
249        let name = node.ident.to_string();
250        let packed = is_packed(&node.attrs);
251
252        // Collect (field_name, type, optional_guard)
253        let fields: Vec<(String, Type, Option<String>)> = match &node.fields {
254            Fields::Named(nf) => nf
255                .named
256                .iter()
257                .map(|f| {
258                    let fname = f.ident.as_ref().map(|i| i.to_string()).unwrap_or_default();
259                    let guard = extract_guard_from_attrs(&f.attrs);
260                    (fname, f.ty.clone(), guard)
261                })
262                .collect(),
263            Fields::Unnamed(uf) => uf
264                .unnamed
265                .iter()
266                .enumerate()
267                .map(|(i, f)| {
268                    let guard = extract_guard_from_attrs(&f.attrs);
269                    (format!("_{i}"), f.ty.clone(), guard)
270                })
271                .collect(),
272            Fields::Unit => vec![],
273        };
274
275        let name_ty: Vec<(String, Type)> = fields
276            .iter()
277            .map(|(n, t, _)| (n.clone(), t.clone()))
278            .collect();
279        let mut layout = simulate_rust_layout(name, &name_ty, packed, self.arch);
280        layout.source_line = Some(node.ident.span().start().line as u32);
281
282        // Apply explicit guard annotations; these take precedence over the
283        // heuristic type-name pass in concurrency.rs (which skips non-Unknown fields).
284        for (i, (_, _, guard)) in fields.iter().enumerate() {
285            if let Some(g) = guard {
286                layout.fields[i].access = AccessPattern::Concurrent {
287                    guard: Some(g.clone()),
288                    is_atomic: false,
289                };
290            }
291        }
292
293        self.layouts.push(layout);
294    }
295}
296
297// ── public API ────────────────────────────────────────────────────────────────
298
299pub fn parse_rust(source: &str, arch: &'static ArchConfig) -> anyhow::Result<Vec<StructLayout>> {
300    let file: syn::File = syn::parse_str(source)?;
301    let mut visitor = StructVisitor {
302        arch,
303        layouts: Vec::new(),
304    };
305    visitor.visit_file(&file);
306    Ok(visitor.layouts)
307}
308
309// ── tests ─────────────────────────────────────────────────────────────────────
310
#[cfg(test)]
mod tests {
    use super::*;
    use padlock_core::arch::X86_64_SYSV;

    #[test]
    fn parse_simple_struct() {
        let src = "struct Foo { a: u8, b: u64, c: u32 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        let l = &layouts[0];
        assert_eq!(l.name, "Foo");
        assert_eq!(l.fields.len(), 3);
        assert_eq!(l.fields[0].size, 1); // u8
        assert_eq!(l.fields[1].size, 8); // u64
        assert_eq!(l.fields[2].size, 4); // u32
    }

    #[test]
    fn layout_includes_padding() {
        // u8 then u64: 7 bytes padding inserted
        let src = "struct T { a: u8, b: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.fields[0].offset, 0);
        assert_eq!(l.fields[1].offset, 8); // u64 aligned to 8
        assert_eq!(l.total_size, 16);
        let gaps = padlock_core::ir::find_padding(l);
        assert_eq!(gaps[0].bytes, 7);
    }

    #[test]
    fn multiple_structs_parsed() {
        let src = "struct A { x: u32 } struct B { y: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 2);
    }

    #[test]
    fn packed_struct_no_padding() {
        let src = "#[repr(packed)] struct P { a: u8, b: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert!(l.is_packed);
        assert_eq!(l.fields[1].offset, 1); // no padding, b immediately after a
        let gaps = padlock_core::ir::find_padding(l);
        assert!(gaps.is_empty());
    }

    #[test]
    fn pointer_field_uses_arch_size() {
        let src = "struct S { p: *const u8 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 8); // 64-bit pointer
    }

    // ── attribute guard extraction ─────────────────────────────────────────────

    #[test]
    fn lock_protected_by_attr_sets_guard() {
        let src = r#"
struct Cache {
    #[lock_protected_by = "mu"]
    readers: u64,
    mu: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let readers = &layouts[0].fields[0];
        assert_eq!(readers.name, "readers");
        if let AccessPattern::Concurrent { guard, .. } = &readers.access {
            assert_eq!(guard.as_deref(), Some("mu"));
        } else {
            panic!("expected Concurrent, got {:?}", readers.access);
        }
    }

    #[test]
    fn guarded_by_string_attr_sets_guard() {
        let src = r#"
struct S {
    #[guarded_by("lock")]
    value: u32,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        if let AccessPattern::Concurrent { guard, .. } = &layouts[0].fields[0].access {
            assert_eq!(guard.as_deref(), Some("lock"));
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn guarded_by_ident_attr_sets_guard() {
        let src = r#"
struct S {
    #[guarded_by(mu)]
    count: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        if let AccessPattern::Concurrent { guard, .. } = &layouts[0].fields[0].access {
            assert_eq!(guard.as_deref(), Some("mu"));
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn protected_by_attr_sets_guard() {
        let src = r#"
struct S {
    #[protected_by = "lock_a"]
    x: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        if let AccessPattern::Concurrent { guard, .. } = &layouts[0].fields[0].access {
            assert_eq!(guard.as_deref(), Some("lock_a"));
        } else {
            panic!("expected Concurrent");
        }
    }

    #[test]
    fn pt_guarded_by_ident_attr_sets_guard() {
        let src = r#"
struct S {
    #[pt_guarded_by(mu)]
    data: *const u64,
    mu: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        if let AccessPattern::Concurrent { guard, .. } = &layouts[0].fields[0].access {
            assert_eq!(guard.as_deref(), Some("mu"));
        } else {
            panic!("expected Concurrent");
        }
        // The guard field itself carries no annotation.
        assert!(matches!(
            layouts[0].fields[1].access,
            AccessPattern::Unknown
        ));
    }

    #[test]
    fn different_guards_on_same_cache_line_is_false_sharing() {
        // readers and writers are at offsets 0 and 8 — same cache line (line 0).
        // They have different explicit guards → confirmed false sharing.
        let src = r#"
struct HotPath {
    #[lock_protected_by = "mu_a"]
    readers: u64,
    #[lock_protected_by = "mu_b"]
    writers: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(padlock_core::analysis::false_sharing::has_false_sharing(
            &layouts[0]
        ));
    }

    #[test]
    fn same_guard_on_same_cache_line_is_not_false_sharing() {
        let src = r#"
struct Safe {
    #[lock_protected_by = "mu"]
    a: u64,
    #[lock_protected_by = "mu"]
    b: u64,
}
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(!padlock_core::analysis::false_sharing::has_false_sharing(
            &layouts[0]
        ));
    }

    #[test]
    fn unannotated_field_stays_unknown() {
        let src = "struct S { x: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(matches!(
            layouts[0].fields[0].access,
            AccessPattern::Unknown
        ));
    }

    // ── stdlib type sizes ─────────────────────────────────────────────────────

    #[test]
    fn vec_field_has_three_pointer_size() {
        // Vec<T> is always ptr + len + cap regardless of T
        let src = "struct S { items: Vec<u64> }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 24); // 3 × 8 on x86-64
    }

    #[test]
    fn string_field_has_three_pointer_size() {
        let src = "struct S { name: String }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 24);
    }

    #[test]
    fn box_field_has_pointer_size() {
        let src = "struct S { inner: Box<u64> }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 8);
    }

    #[test]
    fn arc_field_has_pointer_size() {
        let src = "struct S { shared: Arc<Vec<u8>> }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 8);
    }

    #[test]
    fn phantom_data_is_zero_sized() {
        let src = "struct S { a: u64, _marker: PhantomData<u8> }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let marker = layouts[0]
            .fields
            .iter()
            .find(|f| f.name == "_marker")
            .unwrap();
        assert_eq!(marker.size, 0);
    }

    #[test]
    fn duration_field_is_16_bytes() {
        let src = "struct S { timeout: Duration }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 16);
    }

    #[test]
    fn atomic_u64_has_correct_size() {
        let src = "struct S { counter: AtomicU64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 8);
    }

    #[test]
    fn atomic_bool_has_correct_size() {
        let src = "struct S { flag: AtomicBool }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 1);
    }

    // ── additional layout coverage ────────────────────────────────────────────

    #[test]
    fn trailing_padding_rounds_total_size() {
        // u64 then u8: 7 bytes of tail padding round the size up to align 8.
        let src = "struct T { a: u64, b: u8 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.fields[1].offset, 8);
        assert_eq!(l.total_size, 16);
        assert_eq!(l.align, 8);
    }

    #[test]
    fn packed_struct_has_no_tail_padding() {
        let src = "#[repr(packed)] struct P { a: u8, b: u64 }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].total_size, 9);
        assert_eq!(layouts[0].align, 1);
    }

    #[test]
    fn tuple_struct_fields_are_numbered() {
        let src = "struct P(u8, u32);";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        let names: Vec<&str> = l.fields.iter().map(|f| f.name.as_str()).collect();
        assert_eq!(names, ["_0", "_1"]);
        assert_eq!(l.fields[1].offset, 4);
        assert_eq!(l.total_size, 8);
    }

    #[test]
    fn unit_struct_is_empty() {
        let src = "struct U;";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        assert!(layouts[0].fields.is_empty());
        assert_eq!(layouts[0].total_size, 0);
    }

    #[test]
    fn array_field_multiplies_element_size() {
        let src = "struct A { buf: [u32; 4] }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 16);
        assert_eq!(layouts[0].fields[0].align, 4);
    }

    #[test]
    fn struct_in_nested_module_is_found() {
        let src = "mod inner { struct Nested { x: u32 } }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        assert_eq!(layouts[0].name, "Nested");
    }

    // ── generic struct skipping ───────────────────────────────────────────────

    #[test]
    fn generic_struct_is_skipped() {
        // Cannot accurately lay out struct Foo<T> without knowing T.
        let src = "struct Wrapper<T> { value: T, count: usize }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(
            layouts.is_empty(),
            "generic structs should be skipped; got {:?}",
            layouts.iter().map(|l| &l.name).collect::<Vec<_>>()
        );
    }

    #[test]
    fn generic_struct_with_multiple_params_is_skipped() {
        let src = "struct Pair<A, B> { first: A, second: B }";
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert!(layouts.is_empty());
    }

    #[test]
    fn non_generic_struct_still_parsed_when_generic_sibling_exists() {
        let src = r#"
struct Generic<T> { value: T }
struct Concrete { a: u32, b: u64 }
"#;
        let layouts = parse_rust(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        assert_eq!(layouts[0].name, "Concrete");
    }
}