Skip to main content

padlock_source/frontends/
zig.rs

1// padlock-source/src/frontends/zig.rs
2//
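// Extracts struct and union layouts from Zig source using tree-sitter-zig.
// Handles regular, extern, and packed structs, plus bare and tagged unions.
// Sizes use Zig's platform-native alignment rules (same as C on the target arch);
// fields are laid out in declaration order, even for regular structs whose real
// layout the Zig compiler is free to reorder.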
6
7use padlock_core::arch::ArchConfig;
8use padlock_core::ir::{Field, StructLayout, TypeInfo};
9use tree_sitter::{Node, Parser};
10
11// ── type resolution ───────────────────────────────────────────────────────────
12
13fn zig_type_size_align(ty: &str, arch: &'static ArchConfig) -> (usize, usize) {
14    match ty.trim() {
15        "bool" => (1, 1),
16        "u8" | "i8" => (1, 1),
17        "u16" | "i16" | "f16" => (2, 2),
18        "u32" | "i32" | "f32" => (4, 4),
19        "u64" | "i64" | "f64" => (8, 8),
20        "u128" | "i128" | "f128" => (16, 16),
21        // f80 is the x87 80-bit float; stored as 10 bytes, aligned to 16 on x86-64
22        "f80" => (10, 16),
23        "usize" | "isize" => (arch.pointer_size, arch.pointer_size),
24        "void" | "anyopaque" => (0, 1),
25        // comptime-only or type-erased — treat as pointer-sized
26        "type" | "anytype" | "comptime_int" | "comptime_float" => {
27            (arch.pointer_size, arch.pointer_size)
28        }
29        _ => (arch.pointer_size, arch.pointer_size),
30    }
31}
32
33/// Determine size/align of a type node, dispatching by node kind.
34fn type_node_size_align(source: &str, node: Node<'_>, arch: &'static ArchConfig) -> (usize, usize) {
35    match node.kind() {
36        "builtin_type" | "identifier" => {
37            let text = source[node.byte_range()].trim();
38            zig_type_size_align(text, arch)
39        }
40        // *T — single pointer
41        "pointer_type" => (arch.pointer_size, arch.pointer_size),
42        // ?T — optional; if T is a pointer the optional is pointer-sized (null = 0),
43        // otherwise it is T + 1 byte tag, rounded up. Approximate as pointer-sized.
44        "nullable_type" => {
45            // Check if the inner type is a pointer — if so, null-pointer optimisation applies
46            if let Some(inner) = find_child_by_kinds(node, &["pointer_type"]) {
47                let _ = inner; // pointer optionals are pointer-sized
48                (arch.pointer_size, arch.pointer_size)
49            } else if let Some(inner) = find_first_type_child(source, node) {
50                let (sz, al) = type_node_size_align(source, inner, arch);
51                // Add 1 byte tag, round up to alignment
52                let tagged = (sz + 1).next_multiple_of(al.max(1));
53                (tagged, al.max(1))
54            } else {
55                (arch.pointer_size, arch.pointer_size)
56            }
57        }
58        // []T — slice = (ptr, len)
59        "slice_type" => (arch.pointer_size * 2, arch.pointer_size),
60        // [N]T — array; try to parse N and recursively get element size
61        "array_type" => {
62            if let Some((count, elem_sz, elem_al)) = parse_array_type(source, node, arch) {
63                (elem_sz * count, elem_al)
64            } else {
65                (arch.pointer_size, arch.pointer_size)
66            }
67        }
68        // error union E!T — approximate as two words
69        "error_union" => (arch.pointer_size * 2, arch.pointer_size),
70        _ => (arch.pointer_size, arch.pointer_size),
71    }
72}
73
74/// For `[N]T` nodes, return `Some((count, elem_size, elem_align))`.
75fn parse_array_type(
76    source: &str,
77    node: Node<'_>,
78    arch: &'static ArchConfig,
79) -> Option<(usize, usize, usize)> {
80    // array_type children: [ integer_literal ] type_expr
81    let mut count: Option<usize> = None;
82    let mut elem: Option<(usize, usize)> = None;
83
84    for i in 0..node.child_count() {
85        let child = node.child(i)?;
86        match child.kind() {
87            "integer" | "integer_literal" => {
88                let text = source[child.byte_range()].trim();
89                count = text.parse::<usize>().ok();
90            }
91            "builtin_type" | "identifier" | "pointer_type" | "slice_type" | "array_type"
92            | "nullable_type" => {
93                elem = Some(type_node_size_align(source, child, arch));
94            }
95            _ => {}
96        }
97    }
98
99    let count = count?;
100    let (esz, eal) = elem.unwrap_or((arch.pointer_size, arch.pointer_size));
101    Some((count, esz, eal))
102}
103
104fn find_child_by_kinds<'a>(node: Node<'a>, kinds: &[&str]) -> Option<Node<'a>> {
105    for i in 0..node.child_count() {
106        if let Some(c) = node.child(i)
107            && kinds.contains(&c.kind())
108        {
109            return Some(c);
110        }
111    }
112    None
113}
114
115fn find_first_type_child<'a>(source: &str, node: Node<'a>) -> Option<Node<'a>> {
116    let _ = source;
117    for i in 0..node.child_count() {
118        if let Some(c) = node.child(i) {
119            match c.kind() {
120                "builtin_type" | "identifier" | "pointer_type" | "slice_type" | "array_type"
121                | "nullable_type" | "error_union" => return Some(c),
122                _ => {}
123            }
124        }
125    }
126    None
127}
128
129// ── tree-sitter walker ────────────────────────────────────────────────────────
130
131fn extract_structs(source: &str, root: Node<'_>, arch: &'static ArchConfig) -> Vec<StructLayout> {
132    let mut layouts = Vec::new();
133    let mut stack = vec![root];
134
135    while let Some(node) = stack.pop() {
136        for i in (0..node.child_count()).rev() {
137            if let Some(c) = node.child(i) {
138                stack.push(c);
139            }
140        }
141
142        if node.kind() == "variable_declaration"
143            && let Some(layout) = parse_variable_declaration(source, node, arch)
144        {
145            layouts.push(layout);
146        }
147    }
148    layouts
149}
150
151fn parse_variable_declaration(
152    source: &str,
153    node: Node<'_>,
154    arch: &'static ArchConfig,
155) -> Option<StructLayout> {
156    let source_line = node.start_position().row as u32 + 1;
157    let mut name: Option<String> = None;
158    let mut struct_node: Option<Node> = None;
159    let mut union_node: Option<Node> = None;
160
161    for i in 0..node.child_count() {
162        let child = node.child(i)?;
163        match child.kind() {
164            "identifier" => {
165                // The first identifier after `const`/`var` is the name
166                if name.is_none() {
167                    name = Some(source[child.byte_range()].to_string());
168                }
169            }
170            "struct_declaration" => struct_node = Some(child),
171            "union_declaration" => union_node = Some(child),
172            _ => {}
173        }
174    }
175
176    let name = name?;
177    if let Some(sn) = struct_node {
178        parse_struct_declaration(source, sn, name, arch, source_line)
179    } else if let Some(un) = union_node {
180        parse_union_declaration(source, un, name, arch, source_line)
181    } else {
182        None
183    }
184}
185
186/// Parse a Zig `union { ... }` or `union(enum) { ... }` declaration.
187///
188/// Layout rules:
189/// - All fields share the same storage (offset 0), total = max(field sizes).
190/// - Tagged unions add a synthetic `__tag` discriminant field; its size is
191///   the smallest integer that covers the variant count.
192/// - The struct is emitted with `is_union = true`.
193fn parse_union_declaration(
194    source: &str,
195    node: Node<'_>,
196    name: String,
197    arch: &'static ArchConfig,
198    source_line: u32,
199) -> Option<StructLayout> {
200    let mut is_tagged = false;
201    let mut raw_fields: Vec<(String, String, usize, usize)> = Vec::new();
202
203    for i in 0..node.child_count() {
204        let child = node.child(i)?;
205        match child.kind() {
206            // `union(enum)` — `enum` keyword is a direct child
207            "enum" => is_tagged = true,
208            // `union(SomeEnum)` — identifier naming the explicit tag type
209            // We detect this by seeing an identifier inside the `(...)` group.
210            // Mark it as tagged regardless of the tag type.
211            "container_field" => {
212                if let Some(f) = parse_container_field(source, child, arch, false) {
213                    raw_fields.push(f);
214                }
215            }
216            _ => {}
217        }
218    }
219
220    if raw_fields.is_empty() {
221        return None;
222    }
223
224    // Union layout: all fields at offset 0; total = max field size rounded to alignment.
225    let max_size = raw_fields
226        .iter()
227        .map(|(_, _, sz, _)| *sz)
228        .max()
229        .unwrap_or(0);
230    let max_align = raw_fields
231        .iter()
232        .map(|(_, _, _, al)| *al)
233        .max()
234        .unwrap_or(1);
235    let total_size = if max_align > 0 {
236        max_size.next_multiple_of(max_align)
237    } else {
238        max_size
239    };
240
241    let mut fields: Vec<Field> = raw_fields
242        .into_iter()
243        .map(|(fname, type_text, size, align)| Field {
244            name: fname,
245            ty: TypeInfo::Primitive {
246                name: type_text,
247                size,
248                align,
249            },
250            offset: 0,
251            size,
252            align,
253            source_file: None,
254            source_line: None,
255            access: padlock_core::ir::AccessPattern::Unknown,
256        })
257        .collect();
258
259    // Tagged union: add a synthetic `__tag` discriminant field.
260    // Its size is the smallest integer type that holds all variant indices.
261    if is_tagged {
262        let n = fields.len();
263        let tag_size: usize = if n <= 256 {
264            1
265        } else if n <= 65536 {
266            2
267        } else {
268            4
269        };
270        fields.push(Field {
271            name: "__tag".to_string(),
272            ty: TypeInfo::Primitive {
273                name: format!("u{}", tag_size * 8),
274                size: tag_size,
275                align: tag_size,
276            },
277            offset: total_size, // tag lives after the union payload
278            size: tag_size,
279            align: tag_size,
280            source_file: None,
281            source_line: None,
282            access: padlock_core::ir::AccessPattern::Unknown,
283        });
284    }
285
286    let struct_align = max_align; // tag alignment is usually smaller than payload
287
288    let final_size = if is_tagged {
289        let tag_size = fields.last().map(|f| f.size).unwrap_or(0);
290        (total_size + tag_size).next_multiple_of(struct_align.max(1))
291    } else {
292        total_size
293    };
294
295    Some(StructLayout {
296        name,
297        total_size: final_size,
298        align: struct_align,
299        fields,
300        source_file: None,
301        source_line: Some(source_line),
302        arch,
303        is_packed: false,
304        is_union: true,
305    })
306}
307
308fn parse_struct_declaration(
309    source: &str,
310    node: Node<'_>,
311    name: String,
312    arch: &'static ArchConfig,
313    source_line: u32,
314) -> Option<StructLayout> {
315    let mut is_packed = false;
316    let mut is_extern = false;
317    // (field_name, type_text, size, align)
318    let mut raw_fields: Vec<(String, String, usize, usize)> = Vec::new();
319
320    for i in 0..node.child_count() {
321        let child = node.child(i)?;
322        match child.kind() {
323            "packed" => is_packed = true,
324            "extern" => is_extern = true,
325            "container_field" => {
326                if let Some(f) = parse_container_field(source, child, arch, is_packed) {
327                    raw_fields.push(f);
328                }
329            }
330            _ => {}
331        }
332    }
333
334    if raw_fields.is_empty() {
335        return None;
336    }
337
338    // Regular Zig structs have implementation-defined layout (reordering allowed).
339    // Only extern and packed structs have stable C-compatible / bit-exact layout.
340    // For analysis purposes we simulate the declared order for all variants,
341    // since that is what the developer sees and intends to reason about.
342    let mut offset = 0usize;
343    let mut struct_align = 1usize;
344    let mut fields: Vec<Field> = Vec::new();
345
346    for (fname, type_text, size, align) in raw_fields {
347        let eff_align = if is_packed { 1 } else { align };
348        if eff_align > 0 {
349            offset = offset.next_multiple_of(eff_align);
350        }
351        struct_align = struct_align.max(eff_align);
352        fields.push(Field {
353            name: fname,
354            ty: TypeInfo::Primitive {
355                name: type_text,
356                size,
357                align,
358            },
359            offset,
360            size,
361            align: eff_align,
362            source_file: None,
363            source_line: None,
364            access: padlock_core::ir::AccessPattern::Unknown,
365        });
366        offset += size;
367    }
368
369    if !is_packed && struct_align > 0 {
370        offset = offset.next_multiple_of(struct_align);
371    }
372
373    let _ = is_extern; // affects ABI guarantees, not layout simulation
374
375    Some(StructLayout {
376        name,
377        total_size: offset,
378        align: struct_align,
379        fields,
380        source_file: None,
381        source_line: Some(source_line),
382        arch,
383        is_packed,
384        is_union: false,
385    })
386}
387
388/// Parse a `container_field` node and return `(name, type_text, size, align)`.
389fn parse_container_field(
390    source: &str,
391    node: Node<'_>,
392    arch: &'static ArchConfig,
393    is_packed: bool,
394) -> Option<(String, String, usize, usize)> {
395    let mut field_name: Option<String> = None;
396    let mut type_text: Option<String> = None;
397    let mut size_align: Option<(usize, usize)> = None;
398
399    for i in 0..node.child_count() {
400        let child = node.child(i)?;
401        match child.kind() {
402            "identifier" if field_name.is_none() => {
403                field_name = Some(source[child.byte_range()].to_string());
404            }
405            "builtin_type" | "pointer_type" | "nullable_type" | "slice_type" | "array_type"
406            | "error_union" => {
407                let text = source[child.byte_range()].to_string();
408                size_align = Some(type_node_size_align(source, child, arch));
409                type_text = Some(text);
410            }
411            "identifier" => {
412                // Second identifier = type name (e.g. a named struct type)
413                let text = source[child.byte_range()].trim().to_string();
414                size_align = Some(zig_type_size_align(&text, arch));
415                type_text = Some(text);
416            }
417            _ => {}
418        }
419    }
420
421    // Discard fields with empty names — tree-sitter-zig emits a zero-length
422    // identifier node for `union {}` (empty union body), which is not a real field.
423    let name = field_name.filter(|n| !n.is_empty())?;
424    let ty = type_text.unwrap_or_else(|| "anyopaque".to_string());
425    let (mut size, align) = size_align.unwrap_or((arch.pointer_size, arch.pointer_size));
426
427    if is_packed && size == 0 {
428        size = 0; // void fields in packed structs stay 0
429    }
430
431    Some((name, ty, size, align))
432}
433
434// ── public API ────────────────────────────────────────────────────────────────
435
436pub fn parse_zig(source: &str, arch: &'static ArchConfig) -> anyhow::Result<Vec<StructLayout>> {
437    let mut parser = Parser::new();
438    parser.set_language(&tree_sitter_zig::LANGUAGE.into())?;
439    let tree = parser
440        .parse(source, None)
441        .ok_or_else(|| anyhow::anyhow!("tree-sitter-zig parse failed"))?;
442    Ok(extract_structs(source, tree.root_node(), arch))
443}
444
445// ── tests ─────────────────────────────────────────────────────────────────────
446
#[cfg(test)]
mod tests {
    use super::*;
    use padlock_core::arch::X86_64_SYSV;

    /// Parse `src` for x86-64 SysV; a parser setup failure fails the test.
    fn layouts_of(src: &str) -> Vec<padlock_core::ir::StructLayout> {
        parse_zig(src, &X86_64_SYSV).unwrap()
    }

    #[test]
    fn parse_simple_zig_struct() {
        let layouts = layouts_of("const Point = struct { x: u32, y: u32 };");
        assert_eq!(layouts.len(), 1);
        let point = &layouts[0];
        assert_eq!(point.name, "Point");
        assert_eq!(point.fields.len(), 2);
        assert_eq!(point.total_size, 8);
    }

    #[test]
    fn zig_layout_with_padding() {
        let layouts = layouts_of("const T = struct { a: bool, b: u64 };");
        assert_eq!(layouts.len(), 1);
        let t = &layouts[0];
        assert_eq!(t.fields[0].offset, 0, "bool starts the struct");
        assert_eq!(t.fields[1].offset, 8, "u64 follows 7 bytes of padding");
        assert_eq!(t.total_size, 16);
    }

    #[test]
    fn zig_packed_struct_no_padding() {
        let layouts = layouts_of("const Packed = packed struct { a: u8, b: u32 };");
        assert_eq!(layouts.len(), 1);
        let packed = &layouts[0];
        assert!(packed.is_packed);
        assert_eq!(packed.fields[0].offset, 0);
        assert_eq!(packed.fields[1].offset, 1, "no padding after the u8");
        assert_eq!(packed.total_size, 5);
    }

    #[test]
    fn zig_extern_struct_detected() {
        // C layout: x (4B) at 0, 4B of padding, y (8B) at 8.
        let layouts = layouts_of("const Extern = extern struct { x: i32, y: f64 };");
        assert_eq!(layouts.len(), 1);
        let ext = &layouts[0];
        assert_eq!(ext.fields[0].offset, 0);
        assert_eq!(ext.fields[1].offset, 8);
        assert_eq!(ext.total_size, 16);
    }

    #[test]
    fn zig_pointer_field_is_pointer_sized() {
        let layouts = layouts_of("const S = struct { ptr: *u8 };");
        let ptr = &layouts[0].fields[0];
        assert_eq!(ptr.size, 8);
        assert_eq!(ptr.align, 8);
    }

    #[test]
    fn zig_optional_pointer_is_pointer_sized() {
        let layouts = layouts_of("const S = struct { opt: ?*u8 };");
        assert_eq!(layouts[0].fields[0].size, 8);
    }

    #[test]
    fn zig_slice_is_two_words() {
        let layouts = layouts_of("const S = struct { buf: []u8 };");
        assert_eq!(layouts[0].fields[0].size, 16, "ptr + len");
    }

    #[test]
    fn zig_usize_follows_arch() {
        let layouts = layouts_of("const S = struct { n: usize };");
        assert_eq!(layouts[0].fields[0].size, 8);
    }

    #[test]
    fn zig_multiple_structs_parsed() {
        let layouts = layouts_of("const A = struct { x: u8 };\nconst B = struct { y: u64 };");
        assert_eq!(layouts.len(), 2);
        for wanted in ["A", "B"] {
            assert!(layouts.iter().any(|l| l.name == wanted));
        }
    }

    #[test]
    fn zig_array_field_size() {
        let layouts = layouts_of("const S = struct { buf: [4]u32 };");
        assert_eq!(layouts[0].fields[0].size, 16, "4 * 4");
    }

    // ── union / tagged union ──────────────────────────────────────────────────

    #[test]
    fn zig_bare_union_parsed_as_union() {
        let layouts = layouts_of("const U = union { a: u8, b: u32 };");
        assert_eq!(layouts.len(), 1);
        let u = &layouts[0];
        assert_eq!(u.name, "U");
        assert!(u.is_union, "union should have is_union=true");
    }

    #[test]
    fn zig_bare_union_total_size_is_max_field() {
        // a: u8 (1B), b: u32 (4B) → max = 4B, aligned to 4
        let layouts = layouts_of("const U = union { a: u8, b: u32 };");
        assert_eq!(layouts[0].total_size, 4);
    }

    #[test]
    fn zig_union_all_fields_at_offset_zero() {
        let layouts = layouts_of("const U = union { a: u8, b: u64 };");
        for field in &layouts[0].fields {
            assert_eq!(
                field.offset, 0,
                "union field '{}' should be at offset 0",
                field.name
            );
        }
    }

    #[test]
    fn zig_tagged_union_has_tag_field() {
        let layouts = layouts_of("const T = union(enum) { ok: u32, err: void };");
        let has_tag = layouts[0].fields.iter().any(|f| f.name == "__tag");
        assert!(has_tag, "tagged union should have a synthetic __tag field");
    }

    #[test]
    fn zig_tagged_union_size_includes_tag() {
        // ok: u32 (4B), err: void (0B) → payload 4B; 2 variants → 1B tag;
        // payload + tag = 5B, rounded up to align 4 = 8B.
        let layouts = layouts_of("const T = union(enum) { ok: u32, err: void };");
        assert_eq!(layouts[0].total_size, 8);
    }

    #[test]
    fn zig_union_with_largest_field_u64() {
        // a: u8 (1B), b: u64 (8B), c: u32 (4B) → max = 8B, align = 8
        let layouts = layouts_of("const U = union { a: u8, b: u64, c: u32 };");
        let u = &layouts[0];
        assert_eq!(u.total_size, 8);
        assert_eq!(u.align, 8);
    }

    #[test]
    fn zig_struct_and_union_in_same_file() {
        let layouts =
            layouts_of("const S = struct { x: u32 };\nconst U = union { a: u8, b: u32 };");
        assert_eq!(layouts.len(), 2);
        assert!(layouts.iter().any(|l| l.name == "S" && !l.is_union));
        assert!(layouts.iter().any(|l| l.name == "U" && l.is_union));
    }

    // ── bad weather: unions ───────────────────────────────────────────────────

    #[test]
    fn zig_empty_union_returns_none() {
        // An empty union body produces no layout at all.
        let layouts = layouts_of("const E = union {};");
        assert!(layouts.is_empty(), "empty union should produce no layout");
    }

    #[test]
    fn zig_union_no_padding_finding() {
        // All union fields sit at offset 0, so no inter-field padding exists.
        let layouts = layouts_of("const U = union { a: u8, b: u64 };");
        let gaps = padlock_core::ir::find_padding(&layouts[0]);
        assert!(
            gaps.is_empty(),
            "unions should have no padding gaps: {:?}",
            gaps
        );
    }
}