Skip to main content

padlock_source/frontends/
go.rs

// padlock-source/src/frontends/go.rs
//
// Extracts struct layouts from Go source using tree-sitter-go.
// Sizes follow the Go (gc) compiler's layout rules for the target arch, which
// largely, but not exactly, match C's (e.g. 64-bit values on 32-bit targets).
5
6use padlock_core::arch::ArchConfig;
7use padlock_core::ir::{AccessPattern, Field, StructLayout, TypeInfo};
8use tree_sitter::{Node, Parser};
9
10// ── type resolution ───────────────────────────────────────────────────────────
11
12fn go_type_size_align(ty: &str, arch: &'static ArchConfig) -> (usize, usize) {
13    match ty.trim() {
14        "bool" => (1, 1),
15        "int8" | "uint8" | "byte" => (1, 1),
16        "int16" | "uint16" => (2, 2),
17        "int32" | "uint32" | "rune" | "float32" => (4, 4),
18        "int64" | "uint64" | "float64" | "complex64" => (8, 8),
19        "complex128" => (16, 16),
20        "int" | "uint" => (arch.pointer_size, arch.pointer_size),
21        "uintptr" => (arch.pointer_size, arch.pointer_size),
22        "string" => (arch.pointer_size * 2, arch.pointer_size), // ptr + len
23        ty if ty.starts_with("[]") => (arch.pointer_size * 3, arch.pointer_size), // ptr+len+cap
24        ty if ty.starts_with("map[") || ty.starts_with("chan ") => {
25            (arch.pointer_size, arch.pointer_size)
26        }
27        ty if ty.starts_with('*') => (arch.pointer_size, arch.pointer_size),
28        // Interface types: two-word fat pointer (type pointer + data pointer)
29        "error" | "interface{}" | "any" => (arch.pointer_size * 2, arch.pointer_size),
30        _ => (arch.pointer_size, arch.pointer_size),
31    }
32}
33
34// ── tree-sitter walker ────────────────────────────────────────────────────────
35
36fn extract_structs(source: &str, root: Node<'_>, arch: &'static ArchConfig) -> Vec<StructLayout> {
37    let mut layouts = Vec::new();
38    let mut stack = vec![root];
39
40    while let Some(node) = stack.pop() {
41        for i in (0..node.child_count()).rev() {
42            if let Some(c) = node.child(i) {
43                stack.push(c);
44            }
45        }
46
47        // type_declaration → type_spec → struct_type
48        if node.kind() == "type_declaration"
49            && let Some(layout) = parse_type_declaration(source, node, arch)
50        {
51            layouts.push(layout);
52        }
53    }
54    layouts
55}
56
57fn parse_type_declaration(
58    source: &str,
59    node: Node<'_>,
60    arch: &'static ArchConfig,
61) -> Option<StructLayout> {
62    let source_line = node.start_position().row as u32 + 1;
63    let decl_start_byte = node.start_byte();
64    // type_declaration has a type_spec child
65    for i in 0..node.child_count() {
66        let child = node.child(i)?;
67        if child.kind() == "type_spec" {
68            return parse_type_spec(source, child, arch, source_line, decl_start_byte);
69        }
70    }
71    None
72}
73
74fn parse_type_spec(
75    source: &str,
76    node: Node<'_>,
77    arch: &'static ArchConfig,
78    source_line: u32,
79    decl_start_byte: usize,
80) -> Option<StructLayout> {
81    let mut name: Option<String> = None;
82    let mut struct_node: Option<Node> = None;
83
84    for i in 0..node.child_count() {
85        let child = node.child(i)?;
86        match child.kind() {
87            "type_identifier" => name = Some(source[child.byte_range()].to_string()),
88            "struct_type" => struct_node = Some(child),
89            _ => {}
90        }
91    }
92
93    let name = name?;
94    let struct_node = struct_node?;
95    parse_struct_type(
96        source,
97        struct_node,
98        name,
99        arch,
100        source_line,
101        decl_start_byte,
102    )
103}
104
105fn parse_struct_type(
106    source: &str,
107    node: Node<'_>,
108    name: String,
109    arch: &'static ArchConfig,
110    source_line: u32,
111    decl_start_byte: usize,
112) -> Option<StructLayout> {
113    let mut raw_fields: Vec<(String, String, Option<String>, u32)> = Vec::new();
114
115    for i in 0..node.child_count() {
116        let child = node.child(i)?;
117        if child.kind() == "field_declaration_list" {
118            for j in 0..child.child_count() {
119                let field_node = child.child(j)?;
120                if field_node.kind() == "field_declaration" {
121                    collect_field_declarations(source, field_node, &mut raw_fields);
122                }
123            }
124        }
125    }
126
127    if raw_fields.is_empty() {
128        return None;
129    }
130
131    // Simulate layout
132    let mut offset = 0usize;
133    let mut struct_align = 1usize;
134    let mut fields: Vec<Field> = Vec::new();
135
136    for (fname, ty_name, guard, field_line) in raw_fields {
137        let (size, align) = go_type_size_align(&ty_name, arch);
138        if align > 0 {
139            offset = offset.next_multiple_of(align);
140        }
141        struct_align = struct_align.max(align);
142        let access = if let Some(g) = guard {
143            AccessPattern::Concurrent {
144                guard: Some(g),
145                is_atomic: false,
146                is_annotated: true,
147            }
148        } else {
149            AccessPattern::Unknown
150        };
151        fields.push(Field {
152            name: fname,
153            ty: TypeInfo::Primitive {
154                name: ty_name,
155                size,
156                align,
157            },
158            offset,
159            size,
160            align,
161            source_file: None,
162            source_line: Some(field_line),
163            access,
164        });
165        offset += size;
166    }
167    if struct_align > 0 {
168        offset = offset.next_multiple_of(struct_align);
169    }
170
171    Some(StructLayout {
172        name,
173        total_size: offset,
174        align: struct_align,
175        fields,
176        source_file: None,
177        source_line: Some(source_line),
178        arch,
179        is_packed: false,
180        is_union: false,
181        is_repr_rust: false,
182        suppressed_findings: super::suppress::suppressed_from_preceding_source(
183            source,
184            decl_start_byte,
185        ),
186    })
187}
188
/// Extract a guard name from a Go field's trailing line comment.
///
/// `comment` is the raw comment text starting with `//`; surrounding
/// whitespace is ignored. Recognised forms:
/// - `// padlock:guard=mu`
/// - `// guarded_by: mu` (also `guarded_by = mu`, `guarded_by=mu`)
/// - `// +checklocks:mu` (gVisor checklocks field annotation)
/// - `// +checklocksprotects:mu`
///
/// The guard is the first whitespace-delimited token after the marker, so
/// trailing prose (`// padlock:guard=mu protects the counters`) is not folded
/// into the guard name. Returns `None` if the text is not a `//` comment, no
/// marker matches, or the marker is not followed by a name.
pub fn extract_guard_from_go_comment(comment: &str) -> Option<String> {
    let body = comment.trim().strip_prefix("//")?.trim();

    // The markers have distinct prefixes, so at most one of them can match.
    let rest = body
        .strip_prefix("padlock:guard=")
        .or_else(|| {
            // `guarded_by` accepts `:` or `=`, with optional spaces before it.
            let after = body.strip_prefix("guarded_by")?.trim_start();
            after.strip_prefix(':').or_else(|| after.strip_prefix('='))
        })
        .or_else(|| body.strip_prefix("+checklocksprotects:"))
        .or_else(|| body.strip_prefix("+checklocks:"))?;

    rest.split_whitespace().next().map(str::to_string)
}
226
227/// Find the trailing line comment on the same source line as `node`.
228fn trailing_comment_on_line(source: &str, node: Node<'_>) -> Option<String> {
229    // The node's end byte is just past the last token on the field line.
230    // Read the rest of that line from the source.
231    let end = node.end_byte();
232    if end >= source.len() {
233        return None;
234    }
235    let rest = &source[end..];
236    // Take only up to the next newline
237    let line = rest.lines().next().unwrap_or("");
238    // Look for `//` in that remainder
239    line.find("//").map(|pos| line[pos..].to_string())
240}
241
242fn collect_field_declarations(
243    source: &str,
244    node: Node<'_>,
245    out: &mut Vec<(String, String, Option<String>, u32)>,
246) {
247    // field_declaration: field_identifier+ type [comment]
248    // OR embedded type (anonymous field): TypeName [comment]
249    let mut field_names: Vec<String> = Vec::new();
250    let mut ty_text: Option<String> = None;
251    let field_line = node.start_position().row as u32 + 1;
252
253    for i in 0..node.child_count() {
254        if let Some(child) = node.child(i) {
255            match child.kind() {
256                "field_identifier" => field_names.push(source[child.byte_range()].to_string()),
257                "type_identifier" | "pointer_type" | "qualified_type" | "slice_type"
258                | "map_type" | "channel_type" | "array_type" | "interface_type" => {
259                    ty_text = Some(source[child.byte_range()].trim().to_string());
260                }
261                _ => {}
262            }
263        }
264    }
265
266    let guard =
267        trailing_comment_on_line(source, node).and_then(|c| extract_guard_from_go_comment(&c));
268
269    if !field_names.is_empty() {
270        if let Some(ty) = ty_text {
271            // Normal named fields
272            for name in field_names {
273                out.push((name, ty.clone(), guard.clone(), field_line));
274            }
275        }
276    } else if let Some(ty) = ty_text {
277        // Embedded (anonymous) field: `sync.Mutex` or `Base`.
278        // Go field name is the unqualified type name.
279        // The nested-struct resolution pass in lib.rs will later fill in
280        // the correct size/align from other parsed struct layouts.
281        let simple_name = ty.split('.').next_back().unwrap_or(&ty).to_string();
282        out.push((simple_name, ty, guard, field_line));
283    }
284}
285
286// ── public API ────────────────────────────────────────────────────────────────
287
288pub fn parse_go(source: &str, arch: &'static ArchConfig) -> anyhow::Result<Vec<StructLayout>> {
289    let mut parser = Parser::new();
290    parser.set_language(&tree_sitter_go::LANGUAGE.into())?;
291    let tree = parser
292        .parse(source, None)
293        .ok_or_else(|| anyhow::anyhow!("tree-sitter-go parse failed"))?;
294    Ok(extract_structs(source, tree.root_node(), arch))
295}
296
297// ── tests ─────────────────────────────────────────────────────────────────────
298
#[cfg(test)]
mod tests {
    use super::*;
    use padlock_core::arch::X86_64_SYSV;

    /// Parse `src` as Go for x86-64 SysV; a parser failure fails the test.
    fn parse_x86(src: &str) -> Vec<StructLayout> {
        parse_go(src, &X86_64_SYSV).unwrap()
    }

    /// Look up a layout by struct name, failing the test with `name` if absent.
    fn by_name<'a>(layouts: &'a [StructLayout], name: &str) -> &'a StructLayout {
        layouts.iter().find(|l| l.name == name).expect(name)
    }

    /// Guard of a field that must be annotated as a concurrent access.
    fn concurrent_guard<'a>(field: &'a Field, label: &str) -> Option<&'a str> {
        let AccessPattern::Concurrent { guard, .. } = &field.access else {
            panic!("expected Concurrent for {label}, got {:?}", field.access);
        };
        guard.as_deref()
    }

    /// Whether padlock's false-sharing analysis flags `layout`.
    fn false_sharing(layout: &StructLayout) -> bool {
        padlock_core::analysis::false_sharing::has_false_sharing(layout)
    }

    #[test]
    fn parse_simple_go_struct() {
        let src = r#"
package main
type Point struct {
    X int32
    Y int32
}
"#;
        let layouts = parse_x86(src);
        assert_eq!(layouts.len(), 1);
        let point = &layouts[0];
        assert_eq!(point.name, "Point");
        assert_eq!(point.fields.len(), 2);
    }

    #[test]
    fn go_layout_with_padding() {
        let layouts = parse_x86("package p\ntype T struct { A bool; B int64 }");
        assert_eq!(layouts.len(), 1);
        let fields = &layouts[0].fields;
        assert_eq!(fields[0].offset, 0);
        assert_eq!(fields[1].offset, 8); // bool (1) + 7 pad → 8
    }

    #[test]
    fn go_string_is_two_words() {
        let layouts = parse_x86("package p\ntype S struct { Name string }");
        assert_eq!(layouts[0].fields[0].size, 16); // ptr + len
    }

    // ── Go guard comment extraction ────────────────────────────────────────────

    #[test]
    fn extract_guard_padlock_form() {
        let guard = extract_guard_from_go_comment("// padlock:guard=mu");
        assert_eq!(guard, Some("mu".to_string()));
    }

    #[test]
    fn extract_guard_guarded_by_form() {
        let guard = extract_guard_from_go_comment("// guarded_by: counter_lock");
        assert_eq!(guard, Some("counter_lock".to_string()));
    }

    #[test]
    fn extract_guard_checklocksprotects_form() {
        let guard = extract_guard_from_go_comment("// +checklocksprotects:mu");
        assert_eq!(guard, Some("mu".to_string()));
    }

    #[test]
    fn extract_guard_no_match_returns_none() {
        for plain in ["// just a comment", "// TODO: fix this"] {
            assert!(extract_guard_from_go_comment(plain).is_none());
        }
    }

    #[test]
    fn go_struct_padlock_guard_annotation_sets_concurrent() {
        let src = r#"package p
type Cache struct {
    Readers int64 // padlock:guard=mu
    Writers int64 // padlock:guard=other_mu
    Mu      sync.Mutex
}
"#;
        let layouts = parse_x86(src);
        let cache = &layouts[0];
        // Readers and Writers should be Concurrent with different guards
        assert_eq!(concurrent_guard(&cache.fields[0], "Readers"), Some("mu"));
        assert_eq!(concurrent_guard(&cache.fields[1], "Writers"), Some("other_mu"));
    }

    #[test]
    fn go_struct_different_guards_same_cache_line_is_false_sharing() {
        let src = r#"package p
type HotPath struct {
    Readers int64 // padlock:guard=lock_a
    Writers int64 // padlock:guard=lock_b
}
"#;
        assert!(false_sharing(&parse_x86(src)[0]));
    }

    #[test]
    fn go_struct_same_guard_is_not_false_sharing() {
        let src = r#"package p
type Safe struct {
    A int64 // padlock:guard=mu
    B int64 // padlock:guard=mu
}
"#;
        assert!(!false_sharing(&parse_x86(src)[0]));
    }

    // ── interface{} / any sizing ───────────────────────────────────────────────

    #[test]
    fn interface_field_is_two_words() {
        // interface{} is a fat pointer: (type pointer, data pointer) = 2×pointer
        let layouts = parse_x86("package p\ntype S struct { V interface{} }");
        let v = &layouts[0].fields[0];
        assert_eq!(v.size, 16); // 2 × 8B on x86-64
        assert_eq!(v.align, 8);
    }

    #[test]
    fn any_field_is_two_words() {
        // `any` is an alias for `interface{}` since Go 1.18
        let layouts = parse_x86("package p\ntype S struct { V any }");
        let v = &layouts[0].fields[0];
        assert_eq!(v.size, 16); // 2 × 8B on x86-64
        assert_eq!(v.align, 8);
    }

    #[test]
    fn interface_field_same_size_as_error() {
        // `error` was already two-word; interface{} must match
        let iface = parse_x86("package p\ntype S struct { V interface{} }");
        let err = parse_x86("package p\ntype S struct { V error }");
        assert_eq!(iface[0].fields[0].size, err[0].fields[0].size);
    }

    #[test]
    fn struct_with_mixed_interface_and_ints_has_correct_layout() {
        // interface{} at offset 0 (size 16, align 8) then int64 at offset 16
        let layouts = parse_x86("package p\ntype S struct { V interface{}; N int64 }");
        let s = &layouts[0];
        let (v, n) = (&s.fields[0], &s.fields[1]);
        assert_eq!(v.offset, 0);
        assert_eq!(v.size, 16);
        assert_eq!(n.offset, 16);
        assert_eq!(s.total_size, 24);
    }

    // ── embedded struct support ───────────────────────────────────────────────

    #[test]
    fn embedded_struct_field_uses_type_name_as_field_name() {
        // `Base` is an embedded field — Go uses the type name as the field name.
        let src = r#"package p
type Base struct { X int32 }
type Derived struct {
    Base
    Y int32
}
"#;
        let layouts = parse_x86(src);
        let derived = by_name(&layouts, "Derived");
        // Must have a field named "Base"
        let has_base = derived.fields.iter().any(|f| f.name == "Base");
        assert!(has_base, "embedded field should be named 'Base'");
    }

    #[test]
    fn embedded_qualified_type_uses_unqualified_name() {
        // `sync.Mutex` embedded — field name should be "Mutex"
        let src = r#"package p
type Safe struct {
    sync.Mutex
    Value int64
}
"#;
        let layouts = parse_x86(src);
        let safe = by_name(&layouts, "Safe");
        let has_mutex = safe.fields.iter().any(|f| f.name == "Mutex");
        assert!(has_mutex, "embedded sync.Mutex should produce field named 'Mutex'");
    }

    #[test]
    fn embedded_field_has_non_zero_size_from_resolution() {
        // After lib.rs nested-struct resolution, Base's size should be filled in.
        // We test via parse_source_str which triggers resolution.
        use crate::{SourceLanguage, parse_source_str};

        let src = r#"package p
type Inner struct { A int64; B int64 }
type Outer struct {
    Inner
    C int32
}
"#;
        let layouts = parse_source_str(src, &SourceLanguage::Go, &X86_64_SYSV).unwrap();
        let outer = by_name(&layouts, "Outer");
        let inner_field = outer
            .fields
            .iter()
            .find(|f| f.name == "Inner")
            .expect("Inner field");
        // Inner struct is 16 bytes (two int64s)
        assert_eq!(
            inner_field.size, 16,
            "embedded Inner field should be resolved to 16 bytes"
        );
    }

    #[test]
    fn struct_with_no_embedded_fields_unaffected() {
        let layouts = parse_x86("package p\ntype S struct { A int32; B int64 }");
        let names: Vec<&str> = layouts[0].fields.iter().map(|f| f.name.as_str()).collect();
        assert_eq!(names.len(), 2);
        assert_eq!(names[0], "A");
        assert_eq!(names[1], "B");
    }

    // ── bad weather: embedded fields ──────────────────────────────────────────

    #[test]
    fn embedded_unknown_type_falls_back_to_pointer_size() {
        // If the embedded type is not defined in the file, size = pointer_size
        let layouts = parse_x86("package p\ntype S struct { external.Type\nX int32 }");
        let s = by_name(&layouts, "S");
        let emb = s
            .fields
            .iter()
            .find(|f| f.name == "Type")
            .expect("Type field");
        // Falls back to pointer size (8 on x86_64) since type is unknown
        assert_eq!(emb.size, 8);
    }
}