Skip to main content

padlock_source/
lib.rs

1// padlock-source/src/lib.rs
2
3pub mod concurrency;
4pub mod fixgen;
5pub mod frontends;
6
7use std::collections::HashMap;
8use std::path::Path;
9
10use padlock_core::arch::ArchConfig;
11use padlock_core::ir::{StructLayout, TypeInfo};
12
/// Source languages padlock can extract struct layouts from.
///
/// See [`detect_language`] for how a file extension is mapped to a variant and
/// [`parse_source_str`] for the frontend each variant dispatches to.
///
/// The enum is fieldless, so it derives `Eq` and `Hash` in addition to
/// `PartialEq`; this lets callers use it as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SourceLanguage {
    /// C source or header.
    C,
    /// C++ source or header.
    Cpp,
    /// Rust source.
    Rust,
    /// Go source.
    Go,
}
20
21/// Detect language from file extension.
22pub fn detect_language(path: &Path) -> Option<SourceLanguage> {
23    match path.extension().and_then(|e| e.to_str()) {
24        Some("c") | Some("h") => Some(SourceLanguage::C),
25        Some("cpp") | Some("cc") | Some("cxx") | Some("hpp") => Some(SourceLanguage::Cpp),
26        Some("rs") => Some(SourceLanguage::Rust),
27        Some("go") => Some(SourceLanguage::Go),
28        _ => None,
29    }
30}
31
32/// Parse a source file and return struct layouts.
33pub fn parse_source(path: &Path, arch: &'static ArchConfig) -> anyhow::Result<Vec<StructLayout>> {
34    let lang = detect_language(path)
35        .ok_or_else(|| anyhow::anyhow!("unsupported file type: {}", path.display()))?;
36    let source = std::fs::read_to_string(path)?;
37    let mut layouts = parse_source_str(&source, &lang, arch)?;
38    let file_str = path.to_string_lossy().into_owned();
39    for layout in &mut layouts {
40        layout.source_file = Some(file_str.clone());
41    }
42    Ok(layouts)
43}
44
45/// Parse source text directly (useful for tests and piped input).
46pub fn parse_source_str(
47    source: &str,
48    lang: &SourceLanguage,
49    arch: &'static ArchConfig,
50) -> anyhow::Result<Vec<StructLayout>> {
51    let mut layouts = match lang {
52        SourceLanguage::C => frontends::c_cpp::parse_c(source, arch)?,
53        SourceLanguage::Cpp => frontends::c_cpp::parse_cpp(source, arch)?,
54        SourceLanguage::Rust => frontends::rust::parse_rust(source, arch)?,
55        SourceLanguage::Go => frontends::go::parse_go(source, arch)?,
56    };
57
58    // Resolve fields whose type names match other structs in this file.
59    // This makes nested struct sizes accurate (instead of defaulting to pointer size).
60    resolve_nested_structs(&mut layouts);
61
62    // Annotate concurrency patterns
63    for layout in &mut layouts {
64        concurrency::annotate_concurrency(layout, lang);
65    }
66
67    // Remove structs explicitly opted out via `// padlock:ignore`
68    layouts.retain(|layout| !is_padlock_ignored(source, &layout.name));
69
70    Ok(layouts)
71}
72
73// ── nested struct resolution ──────────────────────────────────────────────────
74
/// Returns `true` if `name` is a well-known primitive type name in any
/// supported language.
///
/// Nested-struct resolution consults this so that a user-defined struct can
/// never shadow a built-in type. The comparison is an exact, case-sensitive
/// match against a fixed list; names shared between languages (`bool`, `int`,
/// `char`, ...) appear only once, under the first language that lists them.
fn is_known_primitive(name: &str) -> bool {
    matches!(
        name,
        // Rust primitives
        "bool" | "u8" | "i8" | "u16" | "i16" | "u32" | "i32" | "f32" | "u64" | "i64" | "f64"
            | "u128" | "i128" | "usize" | "isize" | "char" | "str"
            // C/C++ primitives
            | "int" | "long" | "short" | "float" | "double" | "void"
            | "int8_t" | "uint8_t" | "int16_t" | "uint16_t" | "int32_t" | "uint32_t"
            | "int64_t" | "uint64_t" | "size_t" | "ssize_t" | "ptrdiff_t"
            | "intptr_t" | "uintptr_t" | "_Bool"
            // Go primitives (`bool` and `int` are covered by the lists above)
            | "uint" | "uintptr"
            | "int8" | "uint8" | "byte" | "int16" | "uint16" | "int32" | "uint32"
            | "int64" | "uint64" | "float32" | "float64" | "complex64" | "complex128"
            | "rune" | "string" | "error"
            // SIMD
            | "__m64" | "__m128" | "__m128d" | "__m128i"
            | "__m256" | "__m256d" | "__m256i"
            | "__m512" | "__m512d" | "__m512i"
    )
}
98
99/// Resolve fields whose type name matches another parsed struct.
100///
101/// Runs in a loop until stable to handle transitive nesting (struct A contains
102/// B which contains C). In practice, 2–3 iterations suffice for typical code.
103fn resolve_nested_structs(layouts: &mut [StructLayout]) {
104    loop {
105        // Build name → (total_size, align) from whatever we have so far.
106        let known: HashMap<String, (usize, usize)> = layouts
107            .iter()
108            .map(|l| (l.name.clone(), (l.total_size, l.align)))
109            .collect();
110
111        let mut changed_any = false;
112
113        for layout in layouts.iter_mut() {
114            let mut changed = false;
115
116            for field in layout.fields.iter_mut() {
117                // Extract the type name from Primitive or Opaque variants.
118                // Struct/Pointer/Array variants are already correctly sized.
119                let type_name: String = match &field.ty {
120                    TypeInfo::Primitive { name, .. } | TypeInfo::Opaque { name, .. } => {
121                        name.clone()
122                    }
123                    _ => continue,
124                };
125
126                // Never shadow built-in primitives.
127                if is_known_primitive(&type_name) {
128                    continue;
129                }
130
131                // Don't resolve a struct to itself (circular).
132                if type_name == layout.name {
133                    continue;
134                }
135
136                if let Some(&(struct_size, struct_align)) = known.get(&type_name) {
137                    // Only update if the size would change — avoids infinite loops
138                    // for pointer-sized structs that already have the right size.
139                    if field.size == struct_size && field.align == struct_align {
140                        continue;
141                    }
142                    let eff_align = if layout.is_packed { 1 } else { struct_align };
143                    field.ty = TypeInfo::Opaque {
144                        name: type_name,
145                        size: struct_size,
146                        align: struct_align,
147                    };
148                    field.size = struct_size;
149                    field.align = eff_align;
150                    changed = true;
151                }
152            }
153
154            if changed {
155                resimulate_layout(layout);
156                changed_any = true;
157            }
158        }
159
160        if !changed_any {
161            break;
162        }
163    }
164}
165
166/// Re-simulate field offsets and total_size after field sizes have been updated.
167fn resimulate_layout(layout: &mut StructLayout) {
168    if layout.is_union {
169        for field in layout.fields.iter_mut() {
170            field.offset = 0;
171        }
172        let max_size = layout.fields.iter().map(|f| f.size).max().unwrap_or(0);
173        let max_align = layout.fields.iter().map(|f| f.align).max().unwrap_or(1);
174        layout.total_size = if max_align > 0 {
175            max_size.next_multiple_of(max_align)
176        } else {
177            max_size
178        };
179        layout.align = max_align;
180        return;
181    }
182
183    let packed = layout.is_packed;
184    let mut offset = 0usize;
185    let mut struct_align = 1usize;
186
187    for field in layout.fields.iter_mut() {
188        let eff_align = if packed { 1 } else { field.align };
189        if eff_align > 0 {
190            offset = offset.next_multiple_of(eff_align);
191        }
192        field.offset = offset;
193        offset += field.size;
194        struct_align = struct_align.max(eff_align);
195    }
196
197    if !packed && struct_align > 0 {
198        offset = offset.next_multiple_of(struct_align);
199    }
200
201    layout.total_size = offset;
202    layout.align = struct_align;
203}
204
/// Returns `true` if a `// padlock:ignore` comment appears on the line
/// immediately before (or inline on the same line as) the struct/union/type
/// declaration for `struct_name`.
///
/// This allows callers to suppress analysis for a specific struct by writing:
/// ```c
/// // padlock:ignore
/// struct MySpecialLayout { ... };
/// ```
///
/// The check is purely textual: `source` is searched for `struct <name>`,
/// `union <name>` and `type <name>` (single space). An occurrence only counts
/// when it is delimited on BOTH sides by a non-identifier character (or the
/// start/end of the text), so `struct FooBar` does not match `Foo`, and the
/// tail of a longer word such as "construct Foo" or "prototype Foo" is not
/// mistaken for a declaration.
///
/// For each occurrence:
/// - if the line containing it mentions `padlock:ignore`, the struct is ignored;
/// - otherwise, if the previous line, once trimmed, starts with `//` and
///   mentions `padlock:ignore`, the struct is ignored. A code line with a
///   trailing annotation does not count, so an inline annotation on one struct
///   never suppresses the struct that follows it.
fn is_padlock_ignored(source: &str, struct_name: &str) -> bool {
    const MARKER: &str = "padlock:ignore";
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';

    // Keywords that introduce named type definitions across all supported languages.
    for keyword in ["struct", "union", "type"] {
        let needle = format!("{keyword} {struct_name}");
        let mut search = 0usize;

        while let Some(rel) = source[search..].find(&needle) {
            let abs = search + rel;
            // The needle starts with an ASCII letter, so `abs + 1` is always a
            // char boundary; stepping by one keeps overlapping matches visible.
            search = abs + 1;

            // Word boundary on both sides of `<keyword> <name>`.
            let boundary_before = source[..abs]
                .chars()
                .next_back()
                .is_none_or(|c| !is_ident_char(c));
            let boundary_after = source[abs + needle.len()..]
                .chars()
                .next()
                .is_none_or(|c| !is_ident_char(c));
            if !(boundary_before && boundary_after) {
                continue;
            }

            // Inline annotation on the declaration's own line.
            let line_start = source[..abs].rfind('\n').map_or(0, |i| i + 1);
            let line_end = source[abs..].find('\n').map_or(source.len(), |i| abs + i);
            if source[line_start..line_end].contains(MARKER) {
                return true;
            }

            // Annotation on the immediately preceding line, accepted only if
            // that line is a pure `//` comment.
            if line_start > 0 {
                let prev_end = line_start - 1;
                let prev_start = source[..prev_end].rfind('\n').map_or(0, |i| i + 1);
                let prev_line = source[prev_start..prev_end].trim();
                if prev_line.starts_with("//") && prev_line.contains(MARKER) {
                    return true;
                }
            }
        }
    }
    false
}
255
256// ── tests ─────────────────────────────────────────────────────────────────────
257
/// Unit tests for language detection, `// padlock:ignore` handling and
/// nested-struct resolution. All layout expectations use the x86-64 SysV
/// architecture description.
#[cfg(test)]
mod tests {
    use super::*;
    use padlock_core::arch::X86_64_SYSV;

    // ── language detection ─────────────────────────────────────────────────────

    #[test]
    fn detect_c_extensions() {
        assert_eq!(detect_language(Path::new("foo.c")), Some(SourceLanguage::C));
        assert_eq!(detect_language(Path::new("foo.h")), Some(SourceLanguage::C));
    }

    #[test]
    fn detect_cpp_extensions() {
        assert_eq!(
            detect_language(Path::new("foo.cpp")),
            Some(SourceLanguage::Cpp)
        );
        assert_eq!(
            detect_language(Path::new("foo.cc")),
            Some(SourceLanguage::Cpp)
        );
        assert_eq!(
            detect_language(Path::new("foo.cxx")),
            Some(SourceLanguage::Cpp)
        );
        assert_eq!(
            detect_language(Path::new("foo.hpp")),
            Some(SourceLanguage::Cpp)
        );
    }

    #[test]
    fn detect_rust_extension() {
        assert_eq!(
            detect_language(Path::new("foo.rs")),
            Some(SourceLanguage::Rust)
        );
    }

    #[test]
    fn detect_go_extension() {
        assert_eq!(
            detect_language(Path::new("foo.go")),
            Some(SourceLanguage::Go)
        );
    }

    #[test]
    fn detect_unknown_is_none() {
        assert_eq!(detect_language(Path::new("foo.py")), None);
        assert_eq!(detect_language(Path::new("foo")), None);
    }

    // ── parsing round trips ────────────────────────────────────────────────────

    #[test]
    fn parse_source_str_c_roundtrip() {
        let src = "struct Point { int x; int y; };";
        let layouts = parse_source_str(src, &SourceLanguage::C, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        assert_eq!(layouts[0].name, "Point");
    }

    #[test]
    fn parse_source_str_rust_roundtrip() {
        let src = "struct Foo { x: u32, y: u64 }";
        let layouts = parse_source_str(src, &SourceLanguage::Rust, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        assert_eq!(layouts[0].name, "Foo");
    }

    // ── padlock:ignore ─────────────────────────────────────────────────────────

    #[test]
    fn padlock_ignore_suppresses_c_struct() {
        let src = "// padlock:ignore\nstruct Hidden { int x; int y; };\nstruct Visible { int a; };";
        let layouts = parse_source_str(src, &SourceLanguage::C, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        assert_eq!(layouts[0].name, "Visible");
    }

    #[test]
    fn padlock_ignore_inline_suppresses_c_struct() {
        // An inline annotation suppresses the struct on its own line only. The
        // next struct must survive: its preceding line is code with a trailing
        // comment, not a pure `//` comment line.
        let src = "struct Hidden { int x; }; // padlock:ignore\nstruct Visible { int a; };";
        let layouts = parse_source_str(src, &SourceLanguage::C, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1, "only Visible should remain");
        assert_eq!(layouts[0].name, "Visible");
    }

    #[test]
    fn padlock_ignore_suppresses_rust_struct() {
        let src = "// padlock:ignore\nstruct Hidden { x: u32 }\nstruct Visible { a: u32 }";
        let layouts = parse_source_str(src, &SourceLanguage::Rust, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        assert_eq!(layouts[0].name, "Visible");
    }

    #[test]
    fn padlock_ignore_without_annotation_keeps_struct() {
        let src = "struct Visible { int x; int y; };";
        let layouts = parse_source_str(src, &SourceLanguage::C, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        assert_eq!(layouts[0].name, "Visible");
    }

    // ── nested struct resolution ───────────────────────────────────────────────

    #[test]
    fn nested_rust_struct_size_resolved() {
        // Inner is 8 bytes, which happens to equal the pointer-sized default an
        // unresolved field would get, so the size alone proves little; the
        // offset/total checks confirm the layout is still consistent.
        let src = "struct Inner { x: u64 }\nstruct Outer { a: u8, b: Inner }";
        let layouts = parse_source_str(src, &SourceLanguage::Rust, &X86_64_SYSV).unwrap();
        let outer = layouts.iter().find(|l| l.name == "Outer").unwrap();
        let b = outer.fields.iter().find(|f| f.name == "b").unwrap();
        assert_eq!(b.size, 8, "Inner is 8 bytes");
        assert_eq!(b.align, 8, "Inner aligns to 8");
        // Outer: u8 at 0, [7 pad], Inner at 8 → total 16
        assert_eq!(outer.total_size, 16);
    }

    #[test]
    fn nested_rust_struct_non_pointer_size_resolved() {
        // Point is two i32 = 8 bytes; Line holds two Points back to back, so
        // the second one starts at offset 8 and the total is 16 bytes.
        let src = "struct Point { x: i32, y: i32 }\nstruct Line { a: Point, b: Point }";
        let layouts = parse_source_str(src, &SourceLanguage::Rust, &X86_64_SYSV).unwrap();
        let line = layouts.iter().find(|l| l.name == "Line").unwrap();
        assert_eq!(line.total_size, 16);
        assert_eq!(line.fields[0].size, 8);
        assert_eq!(line.fields[1].size, 8);
        assert_eq!(line.fields[1].offset, 8);
    }

    #[test]
    fn nested_rust_struct_large_inner_triggers_padding() {
        // SmallHeader is a single bool: 1 byte, align 1.
        // Without resolution `h` would keep the pointer-sized default (8 bytes).
        // With resolution `h` is 1 byte, followed by 7 bytes of padding so that
        // `data` (u64, align 8) lands at offset 8 → Wrapper is 16 bytes.
        let src = "struct SmallHeader { flag: bool }\nstruct Wrapper { h: SmallHeader, data: u64 }";
        let layouts = parse_source_str(src, &SourceLanguage::Rust, &X86_64_SYSV).unwrap();
        let wrapper = layouts.iter().find(|l| l.name == "Wrapper").unwrap();
        let h = wrapper.fields.iter().find(|f| f.name == "h").unwrap();
        // SmallHeader has total_size=1, align=1
        assert_eq!(h.size, 1, "SmallHeader resolved to 1 byte");
        assert_eq!(h.align, 1);
        // data (u64, align 8) should be at offset 8 (7 bytes padding after SmallHeader)
        let data = wrapper.fields.iter().find(|f| f.name == "data").unwrap();
        assert_eq!(data.offset, 8);
        assert_eq!(wrapper.total_size, 16);
    }

    #[test]
    fn nested_c_struct_resolved() {
        let src =
            "struct Vec2 { float x; float y; };\nstruct Rect { struct Vec2 tl; struct Vec2 br; };";
        let layouts = parse_source_str(src, &SourceLanguage::C, &X86_64_SYSV).unwrap();
        let rect = layouts.iter().find(|l| l.name == "Rect").unwrap();
        // Each Vec2 is 8 bytes (two floats). Rect = 16 bytes, no padding.
        assert_eq!(rect.total_size, 16, "Rect should be 16 bytes");
        assert_eq!(rect.fields[0].size, 8);
        assert_eq!(rect.fields[1].size, 8);
        assert_eq!(rect.fields[1].offset, 8);
    }

    #[test]
    fn nested_go_struct_resolved() {
        let src = "package p\ntype Vec2 struct { X float32; Y float32 }\ntype Rect struct { TL Vec2; BR Vec2 }";
        let layouts = parse_source_str(src, &SourceLanguage::Go, &X86_64_SYSV).unwrap();
        let rect = layouts.iter().find(|l| l.name == "Rect").unwrap();
        assert_eq!(rect.total_size, 16);
        assert_eq!(rect.fields[0].size, 8);
        assert_eq!(rect.fields[1].size, 8);
        assert_eq!(rect.fields[1].offset, 8);
    }

    #[test]
    fn primitive_types_not_shadowed_by_struct_resolution() {
        // Fields of built-in primitive types must come out of the resolution
        // pass with their primitive size untouched.
        let src = "struct Wrapper { x: u64, y: bool }";
        let layouts = parse_source_str(src, &SourceLanguage::Rust, &X86_64_SYSV).unwrap();
        let w = &layouts[0];
        let x = w.fields.iter().find(|f| f.name == "x").unwrap();
        assert_eq!(x.size, 8, "u64 must stay 8 bytes");
    }

    #[test]
    fn is_padlock_ignored_does_not_match_partial_names() {
        // An annotation on "struct FooBar" must not suppress a struct named "Foo".
        assert!(!is_padlock_ignored(
            "// padlock:ignore\nstruct FooBar { int x; };",
            "Foo"
        ));
    }
}
450}