Skip to main content

padlock_source/frontends/
go.rs

1// padlock-source/src/frontends/go.rs
2//
3// Extracts struct layouts from Go source using tree-sitter-go.
4// Sizes use Go's platform-native alignment rules (same as C on the target arch).
5
6use padlock_core::arch::ArchConfig;
7use padlock_core::ir::{AccessPattern, Field, StructLayout, TypeInfo};
8use tree_sitter::{Node, Parser};
9
10// ── type resolution ───────────────────────────────────────────────────────────
11
12fn go_type_size_align(ty: &str, arch: &'static ArchConfig) -> (usize, usize) {
13    match ty.trim() {
14        "bool" => (1, 1),
15        "int8" | "uint8" | "byte" => (1, 1),
16        "int16" | "uint16" => (2, 2),
17        "int32" | "uint32" | "rune" | "float32" => (4, 4),
18        "int64" | "uint64" | "float64" | "complex64" => (8, 8),
19        "complex128" => (16, 16),
20        "int" | "uint" => (arch.pointer_size, arch.pointer_size),
21        "uintptr" => (arch.pointer_size, arch.pointer_size),
22        "string" => (arch.pointer_size * 2, arch.pointer_size), // ptr + len
23        ty if ty.starts_with("[]") => (arch.pointer_size * 3, arch.pointer_size), // ptr+len+cap
24        ty if ty.starts_with("map[") || ty.starts_with("chan ") => {
25            (arch.pointer_size, arch.pointer_size)
26        }
27        ty if ty.starts_with('*') => (arch.pointer_size, arch.pointer_size),
28        // Interface types: two-word fat pointer (type pointer + data pointer)
29        "error" | "interface{}" | "any" => (arch.pointer_size * 2, arch.pointer_size),
30        _ => (arch.pointer_size, arch.pointer_size),
31    }
32}
33
34// ── tree-sitter walker ────────────────────────────────────────────────────────
35
36fn extract_structs(source: &str, root: Node<'_>, arch: &'static ArchConfig) -> Vec<StructLayout> {
37    let mut layouts = Vec::new();
38    let mut stack = vec![root];
39
40    while let Some(node) = stack.pop() {
41        for i in (0..node.child_count()).rev() {
42            if let Some(c) = node.child(i) {
43                stack.push(c);
44            }
45        }
46
47        // type_declaration → type_spec → struct_type
48        if node.kind() == "type_declaration"
49            && let Some(layout) = parse_type_declaration(source, node, arch)
50        {
51            layouts.push(layout);
52        }
53    }
54    layouts
55}
56
57fn parse_type_declaration(
58    source: &str,
59    node: Node<'_>,
60    arch: &'static ArchConfig,
61) -> Option<StructLayout> {
62    let source_line = node.start_position().row as u32 + 1;
63    // type_declaration has a type_spec child
64    for i in 0..node.child_count() {
65        let child = node.child(i)?;
66        if child.kind() == "type_spec" {
67            return parse_type_spec(source, child, arch, source_line);
68        }
69    }
70    None
71}
72
73fn parse_type_spec(
74    source: &str,
75    node: Node<'_>,
76    arch: &'static ArchConfig,
77    source_line: u32,
78) -> Option<StructLayout> {
79    let mut name: Option<String> = None;
80    let mut struct_node: Option<Node> = None;
81
82    for i in 0..node.child_count() {
83        let child = node.child(i)?;
84        match child.kind() {
85            "type_identifier" => name = Some(source[child.byte_range()].to_string()),
86            "struct_type" => struct_node = Some(child),
87            _ => {}
88        }
89    }
90
91    let name = name?;
92    let struct_node = struct_node?;
93    parse_struct_type(source, struct_node, name, arch, source_line)
94}
95
96fn parse_struct_type(
97    source: &str,
98    node: Node<'_>,
99    name: String,
100    arch: &'static ArchConfig,
101    source_line: u32,
102) -> Option<StructLayout> {
103    let mut raw_fields: Vec<(String, String, Option<String>)> = Vec::new();
104
105    for i in 0..node.child_count() {
106        let child = node.child(i)?;
107        if child.kind() == "field_declaration_list" {
108            for j in 0..child.child_count() {
109                let field_node = child.child(j)?;
110                if field_node.kind() == "field_declaration" {
111                    collect_field_declarations(source, field_node, &mut raw_fields);
112                }
113            }
114        }
115    }
116
117    if raw_fields.is_empty() {
118        return None;
119    }
120
121    // Simulate layout
122    let mut offset = 0usize;
123    let mut struct_align = 1usize;
124    let mut fields: Vec<Field> = Vec::new();
125
126    for (fname, ty_name, guard) in raw_fields {
127        let (size, align) = go_type_size_align(&ty_name, arch);
128        if align > 0 {
129            offset = offset.next_multiple_of(align);
130        }
131        struct_align = struct_align.max(align);
132        let access = if let Some(g) = guard {
133            AccessPattern::Concurrent {
134                guard: Some(g),
135                is_atomic: false,
136            }
137        } else {
138            AccessPattern::Unknown
139        };
140        fields.push(Field {
141            name: fname,
142            ty: TypeInfo::Primitive {
143                name: ty_name,
144                size,
145                align,
146            },
147            offset,
148            size,
149            align,
150            source_file: None,
151            source_line: None,
152            access,
153        });
154        offset += size;
155    }
156    if struct_align > 0 {
157        offset = offset.next_multiple_of(struct_align);
158    }
159
160    Some(StructLayout {
161        name,
162        total_size: offset,
163        align: struct_align,
164        fields,
165        source_file: None,
166        source_line: Some(source_line),
167        arch,
168        is_packed: false,
169        is_union: false,
170        is_repr_rust: false,
171    })
172}
173
/// Pulls the guard (lock) name out of a Go `//` line comment.
///
/// `comment` must start with `//` once surrounding whitespace is removed;
/// otherwise `None` is returned. The comment text, trimmed, must then begin
/// with one of these markers:
/// - `padlock:guard=mu`
/// - `guarded_by: mu` or `guarded_by = mu`
/// - `+checklocksprotects:mu` (gVisor-style)
///
/// Everything after the marker, trimmed, is returned as the guard name.
/// A marker followed by nothing but whitespace yields `None`.
pub fn extract_guard_from_go_comment(comment: &str) -> Option<String> {
    // Tried in order; the markers are mutually exclusive prefixes.
    const MARKERS: [&str; 4] = [
        "padlock:guard=",
        "guarded_by:",
        "guarded_by =",
        "+checklocksprotects:",
    ];

    let text = comment.trim().strip_prefix("//")?.trim();

    MARKERS
        .iter()
        .filter_map(|&marker| text.strip_prefix(marker))
        .map(str::trim)
        .find(|guard| !guard.is_empty())
        .map(str::to_string)
}
211
212/// Find the trailing line comment on the same source line as `node`.
213fn trailing_comment_on_line(source: &str, node: Node<'_>) -> Option<String> {
214    // The node's end byte is just past the last token on the field line.
215    // Read the rest of that line from the source.
216    let end = node.end_byte();
217    if end >= source.len() {
218        return None;
219    }
220    let rest = &source[end..];
221    // Take only up to the next newline
222    let line = rest.lines().next().unwrap_or("");
223    // Look for `//` in that remainder
224    line.find("//").map(|pos| line[pos..].to_string())
225}
226
227fn collect_field_declarations(
228    source: &str,
229    node: Node<'_>,
230    out: &mut Vec<(String, String, Option<String>)>,
231) {
232    // field_declaration: field_identifier+ type [comment]
233    // OR embedded type (anonymous field): TypeName [comment]
234    let mut field_names: Vec<String> = Vec::new();
235    let mut ty_text: Option<String> = None;
236
237    for i in 0..node.child_count() {
238        if let Some(child) = node.child(i) {
239            match child.kind() {
240                "field_identifier" => field_names.push(source[child.byte_range()].to_string()),
241                "type_identifier" | "pointer_type" | "qualified_type" | "slice_type"
242                | "map_type" | "channel_type" | "array_type" | "interface_type" => {
243                    ty_text = Some(source[child.byte_range()].trim().to_string());
244                }
245                _ => {}
246            }
247        }
248    }
249
250    let guard =
251        trailing_comment_on_line(source, node).and_then(|c| extract_guard_from_go_comment(&c));
252
253    if !field_names.is_empty() {
254        if let Some(ty) = ty_text {
255            // Normal named fields
256            for name in field_names {
257                out.push((name, ty.clone(), guard.clone()));
258            }
259        }
260    } else if let Some(ty) = ty_text {
261        // Embedded (anonymous) field: `sync.Mutex` or `Base`.
262        // Go field name is the unqualified type name.
263        // The nested-struct resolution pass in lib.rs will later fill in
264        // the correct size/align from other parsed struct layouts.
265        let simple_name = ty.split('.').next_back().unwrap_or(&ty).to_string();
266        out.push((simple_name, ty, guard));
267    }
268}
269
270// ── public API ────────────────────────────────────────────────────────────────
271
272pub fn parse_go(source: &str, arch: &'static ArchConfig) -> anyhow::Result<Vec<StructLayout>> {
273    let mut parser = Parser::new();
274    parser.set_language(&tree_sitter_go::LANGUAGE.into())?;
275    let tree = parser
276        .parse(source, None)
277        .ok_or_else(|| anyhow::anyhow!("tree-sitter-go parse failed"))?;
278    Ok(extract_structs(source, tree.root_node(), arch))
279}
280
281// ── tests ─────────────────────────────────────────────────────────────────────
282
// Unit tests for the Go frontend. All layouts are computed for the
// `X86_64_SYSV` architecture, so the expected sizes below assume 8-byte words.
#[cfg(test)]
mod tests {
    use super::*;
    use padlock_core::arch::X86_64_SYSV;

    /// A two-field struct is found, named correctly, and has both fields.
    #[test]
    fn parse_simple_go_struct() {
        let src = r#"
package main
type Point struct {
    X int32
    Y int32
}
"#;
        let layouts = parse_go(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        assert_eq!(layouts[0].name, "Point");
        assert_eq!(layouts[0].fields.len(), 2);
    }

    /// An `int64` following a `bool` is pushed to the next 8-byte boundary.
    #[test]
    fn go_layout_with_padding() {
        let src = "package p\ntype T struct { A bool; B int64 }";
        let layouts = parse_go(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts.len(), 1);
        let l = &layouts[0];
        assert_eq!(l.fields[0].offset, 0);
        assert_eq!(l.fields[1].offset, 8); // bool (1) + 7 pad → 8
    }

    /// A `string` field occupies two words.
    #[test]
    fn go_string_is_two_words() {
        let src = "package p\ntype S struct { Name string }";
        let layouts = parse_go(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 16); // ptr + len
    }

    // ── Go guard comment extraction ────────────────────────────────────────────

    /// `padlock:guard=<name>` yields the guard name.
    #[test]
    fn extract_guard_padlock_form() {
        assert_eq!(
            extract_guard_from_go_comment("// padlock:guard=mu"),
            Some("mu".to_string())
        );
    }

    /// `guarded_by: <name>` yields the guard name.
    #[test]
    fn extract_guard_guarded_by_form() {
        assert_eq!(
            extract_guard_from_go_comment("// guarded_by: counter_lock"),
            Some("counter_lock".to_string())
        );
    }

    /// `+checklocksprotects:<name>` yields the guard name.
    #[test]
    fn extract_guard_checklocksprotects_form() {
        assert_eq!(
            extract_guard_from_go_comment("// +checklocksprotects:mu"),
            Some("mu".to_string())
        );
    }

    /// Ordinary comments without a guard marker produce no guard.
    #[test]
    fn extract_guard_no_match_returns_none() {
        assert!(extract_guard_from_go_comment("// just a comment").is_none());
        assert!(extract_guard_from_go_comment("// TODO: fix this").is_none());
    }

    /// Trailing guard comments mark their fields as `Concurrent`, each with
    /// the guard named on its own line.
    #[test]
    fn go_struct_padlock_guard_annotation_sets_concurrent() {
        let src = r#"package p
type Cache struct {
    Readers int64 // padlock:guard=mu
    Writers int64 // padlock:guard=other_mu
    Mu      sync.Mutex
}
"#;
        let layouts = parse_go(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        // Readers and Writers should be Concurrent with different guards
        if let AccessPattern::Concurrent { guard, .. } = &l.fields[0].access {
            assert_eq!(guard.as_deref(), Some("mu"));
        } else {
            panic!(
                "expected Concurrent for Readers, got {:?}",
                l.fields[0].access
            );
        }
        if let AccessPattern::Concurrent { guard, .. } = &l.fields[1].access {
            assert_eq!(guard.as_deref(), Some("other_mu"));
        } else {
            panic!(
                "expected Concurrent for Writers, got {:?}",
                l.fields[1].access
            );
        }
    }

    /// Two adjacent fields with different guards are reported as false sharing.
    #[test]
    fn go_struct_different_guards_same_cache_line_is_false_sharing() {
        let src = r#"package p
type HotPath struct {
    Readers int64 // padlock:guard=lock_a
    Writers int64 // padlock:guard=lock_b
}
"#;
        let layouts = parse_go(src, &X86_64_SYSV).unwrap();
        assert!(padlock_core::analysis::false_sharing::has_false_sharing(
            &layouts[0]
        ));
    }

    /// Two adjacent fields sharing one guard are not reported as false sharing.
    #[test]
    fn go_struct_same_guard_is_not_false_sharing() {
        let src = r#"package p
type Safe struct {
    A int64 // padlock:guard=mu
    B int64 // padlock:guard=mu
}
"#;
        let layouts = parse_go(src, &X86_64_SYSV).unwrap();
        assert!(!padlock_core::analysis::false_sharing::has_false_sharing(
            &layouts[0]
        ));
    }

    // ── interface{} / any sizing ───────────────────────────────────────────────

    /// An `interface{}` field is two words with word alignment.
    #[test]
    fn interface_field_is_two_words() {
        // interface{} is a fat pointer: (type pointer, data pointer) = 2×pointer
        let src = "package p\ntype S struct { V interface{} }";
        let layouts = parse_go(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 16); // 2 × 8B on x86-64
        assert_eq!(layouts[0].fields[0].align, 8);
    }

    /// An `any` field is sized exactly like `interface{}`.
    #[test]
    fn any_field_is_two_words() {
        // `any` is an alias for `interface{}` since Go 1.18
        let src = "package p\ntype S struct { V any }";
        let layouts = parse_go(src, &X86_64_SYSV).unwrap();
        assert_eq!(layouts[0].fields[0].size, 16); // 2 × 8B on x86-64
        assert_eq!(layouts[0].fields[0].align, 8);
    }

    /// `interface{}` and `error` fields get the same size.
    #[test]
    fn interface_field_same_size_as_error() {
        // `error` was already two-word; interface{} must match
        let src_iface = "package p\ntype S struct { V interface{} }";
        let src_err = "package p\ntype S struct { V error }";
        let iface = parse_go(src_iface, &X86_64_SYSV).unwrap();
        let err = parse_go(src_err, &X86_64_SYSV).unwrap();
        assert_eq!(iface[0].fields[0].size, err[0].fields[0].size);
    }

    /// Offsets and total size are correct when an interface precedes an int64.
    #[test]
    fn struct_with_mixed_interface_and_ints_has_correct_layout() {
        // interface{} at offset 0 (size 16, align 8) then int64 at offset 16
        let src = "package p\ntype S struct { V interface{}; N int64 }";
        let layouts = parse_go(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.fields[0].offset, 0);
        assert_eq!(l.fields[0].size, 16);
        assert_eq!(l.fields[1].offset, 16);
        assert_eq!(l.total_size, 24);
    }

    // ── embedded struct support ───────────────────────────────────────────────

    /// An embedded local type shows up as a field named after the type.
    #[test]
    fn embedded_struct_field_uses_type_name_as_field_name() {
        // `Base` is an embedded field — Go uses the type name as the field name.
        let src = r#"package p
type Base struct { X int32 }
type Derived struct {
    Base
    Y int32
}
"#;
        let layouts = parse_go(src, &X86_64_SYSV).unwrap();
        let derived = layouts
            .iter()
            .find(|l| l.name == "Derived")
            .expect("Derived");
        // Must have a field named "Base"
        assert!(
            derived.fields.iter().any(|f| f.name == "Base"),
            "embedded field should be named 'Base'"
        );
    }

    /// An embedded package-qualified type drops the package from the field name.
    #[test]
    fn embedded_qualified_type_uses_unqualified_name() {
        // `sync.Mutex` embedded — field name should be "Mutex"
        let src = r#"package p
type Safe struct {
    sync.Mutex
    Value int64
}
"#;
        let layouts = parse_go(src, &X86_64_SYSV).unwrap();
        let l = layouts.iter().find(|l| l.name == "Safe").expect("Safe");
        assert!(
            l.fields.iter().any(|f| f.name == "Mutex"),
            "embedded sync.Mutex should produce field named 'Mutex'"
        );
    }

    /// Going through the crate-level `parse_source_str` entry point, an
    /// embedded struct defined in the same source takes that struct's size.
    #[test]
    fn embedded_field_has_non_zero_size_from_resolution() {
        // After lib.rs nested-struct resolution, Base's size should be filled in.
        // We test via parse_source_str which triggers resolution.
        let src = r#"package p
type Inner struct { A int64; B int64 }
type Outer struct {
    Inner
    C int32
}
"#;
        use crate::{SourceLanguage, parse_source_str};
        let layouts = parse_source_str(src, &SourceLanguage::Go, &X86_64_SYSV).unwrap();
        let outer = layouts.iter().find(|l| l.name == "Outer").expect("Outer");
        let inner_field = outer
            .fields
            .iter()
            .find(|f| f.name == "Inner")
            .expect("Inner field");
        // Inner struct is 16 bytes (two int64s)
        assert_eq!(
            inner_field.size, 16,
            "embedded Inner field should be resolved to 16 bytes"
        );
    }

    /// Structs with only named fields keep their field count, names and order.
    #[test]
    fn struct_with_no_embedded_fields_unaffected() {
        let src = "package p\ntype S struct { A int32; B int64 }";
        let layouts = parse_go(src, &X86_64_SYSV).unwrap();
        let l = &layouts[0];
        assert_eq!(l.fields.len(), 2);
        assert_eq!(l.fields[0].name, "A");
        assert_eq!(l.fields[1].name, "B");
    }

    // ── bad weather: embedded fields ──────────────────────────────────────────

    /// An embedded type that is not defined in the source gets the one-word
    /// fallback size.
    #[test]
    fn embedded_unknown_type_falls_back_to_pointer_size() {
        // If the embedded type is not defined in the file, size = pointer_size
        let src = "package p\ntype S struct { external.Type\nX int32 }";
        let layouts = parse_go(src, &X86_64_SYSV).unwrap();
        let l = layouts.iter().find(|l| l.name == "S").expect("S");
        let emb = l
            .fields
            .iter()
            .find(|f| f.name == "Type")
            .expect("Type field");
        // Falls back to pointer size (8 on x86_64) since type is unknown
        assert_eq!(emb.size, 8);
    }
}
545}