Skip to main content

padlock_source/
fixgen.rs

1// padlock-source/src/fixgen.rs
2//
3// Generate reordered struct source text, unified diffs, and in-place rewrites.
4
5use padlock_core::ir::{StructLayout, optimal_order};
6use similar::{ChangeTag, TextDiff};
7
8/// Render a reordered C/C++ struct definition as source text.
9///
10/// Uses the field names already present in the layout — type names come from
11/// the `TypeInfo::Primitive/Opaque` name stored during source parsing.
12pub fn generate_c_fix(layout: &StructLayout) -> String {
13    let optimal = optimal_order(layout);
14    let mut out = format!("struct {} {{\n", layout.name);
15    for field in &optimal {
16        let ty = field_type_name(field);
17        out.push_str(&format!("    {ty} {};\n", field.name));
18    }
19    out.push_str("};\n");
20    out
21}
22
23/// Render a reordered Rust struct definition as source text.
24pub fn generate_rust_fix(layout: &StructLayout) -> String {
25    let optimal = optimal_order(layout);
26    let mut out = format!("struct {} {{\n", layout.name);
27    for field in &optimal {
28        let ty = field_type_name(field);
29        out.push_str(&format!("    {}: {ty},\n", field.name));
30    }
31    out.push_str("}\n");
32    out
33}
34
35/// Render a reordered Go struct definition as source text.
36pub fn generate_go_fix(layout: &StructLayout) -> String {
37    let optimal = optimal_order(layout);
38    let mut out = format!("type {} struct {{\n", layout.name);
39    for field in &optimal {
40        let ty = field_type_name(field);
41        out.push_str(&format!("\t{}\t{ty}\n", field.name));
42    }
43    out.push_str("}\n");
44    out
45}
46
47/// Produce a unified diff between `original` and `fixed` source text.
48pub fn unified_diff(original: &str, fixed: &str, context_lines: usize) -> String {
49    if original == fixed {
50        return String::from("(no changes)\n");
51    }
52    let diff = TextDiff::from_lines(original, fixed);
53    let mut out = String::new();
54    for (idx, group) in diff.grouped_ops(context_lines).iter().enumerate() {
55        if idx > 0 {
56            out.push_str("...\n");
57        }
58        for op in group {
59            for change in diff.iter_changes(op) {
60                let prefix = match change.tag() {
61                    ChangeTag::Delete => "-",
62                    ChangeTag::Insert => "+",
63                    ChangeTag::Equal => " ",
64                };
65                out.push_str(&format!("{prefix} {}", change.value()));
66                if !change.value().ends_with('\n') {
67                    out.push('\n');
68                }
69            }
70        }
71    }
72    out
73}
74
75// ── span finders ──────────────────────────────────────────────────────────────
76
/// Find the end of the brace-delimited block that opens in `s`.
///
/// `s` is expected to begin with `{`. Braces are counted purely textually:
/// every `{` opens a level and every `}` closes one, with no awareness of
/// comments, string literals or character literals.
///
/// Returns the byte index one past the `}` that brings the nesting depth back
/// to zero. Returns `None` when the input ends before the block is closed, or
/// when a `}` appears while no block is open (unbalanced input) — the latter
/// used to underflow the depth counter.
fn match_braces(s: &str) -> Option<usize> {
    let mut depth = 0usize;
    // `{` and `}` are ASCII, and UTF-8 never reuses ASCII byte values inside a
    // multi-byte sequence, so scanning bytes yields the same offsets as
    // scanning chars.
    for (i, byte) in s.bytes().enumerate() {
        match byte {
            b'{' => depth += 1,
            b'}' => {
                // A closer with nothing open means the text is unbalanced;
                // report "no match" instead of underflowing.
                depth = depth.checked_sub(1)?;
                if depth == 0 {
                    return Some(i + 1);
                }
            }
            _ => {}
        }
    }
    None
}
95
/// Step over a `;` that follows `pos`, if there is one on the same line.
///
/// Whitespace other than `\n` between `pos` and the `;` is skipped. When the
/// first character that is either a newline or non-whitespace is `;`, the
/// returned offset is one past it; in every other case (including reaching the
/// end of `source`) `pos` is returned unchanged.
///
/// `pos` must be a valid char boundary within `source`, since the tail is
/// obtained by slicing.
fn consume_semicolon(source: &str, pos: usize) -> usize {
    let tail = &source[pos..];
    let first_significant = tail
        .char_indices()
        .find(|&(_, ch)| ch == '\n' || !ch.is_whitespace());
    match first_significant {
        Some((offset, ';')) => pos + offset + 1,
        _ => pos,
    }
}
110
111/// Find the byte range of a named struct/union in C/C++ source.
112/// The range covers from `struct/union Name` through the closing `};`.
113pub fn find_c_struct_span(source: &str, struct_name: &str) -> Option<std::ops::Range<usize>> {
114    for kw in &["struct", "union"] {
115        let needle = format!("{kw} {struct_name}");
116        let mut search_from = 0usize;
117        while let Some(rel) = source[search_from..].find(&needle) {
118            let start = search_from + rel;
119            let after_name = start + needle.len();
120            // Ensure the character after the name is a boundary (space, `{`, newline)
121            let boundary = source[after_name..].chars().next();
122            if matches!(
123                boundary,
124                Some('{') | Some('\n') | Some('\r') | Some(' ') | Some('\t') | None
125            ) {
126                // Find the opening brace (may have whitespace between name and `{`)
127                if let Some(brace_rel) = source[after_name..].find('{') {
128                    let brace_start = after_name + brace_rel;
129                    // Verify no word characters between name end and brace
130                    if source[after_name..brace_start]
131                        .chars()
132                        .all(|c| c.is_whitespace())
133                        && let Some(body_len) = match_braces(&source[brace_start..])
134                    {
135                        let end = consume_semicolon(source, brace_start + body_len);
136                        return Some(start..end);
137                    }
138                }
139            }
140            search_from = start + 1;
141        }
142    }
143    None
144}
145
146/// Find the byte range of a named struct in Rust source (`struct Name { ... }`).
147pub fn find_rust_struct_span(source: &str, struct_name: &str) -> Option<std::ops::Range<usize>> {
148    let needle = format!("struct {struct_name}");
149    let mut search_from = 0usize;
150    while let Some(rel) = source[search_from..].find(&needle) {
151        let start = search_from + rel;
152        let after_name = start + needle.len();
153        let boundary = source[after_name..].chars().next();
154        if matches!(
155            boundary,
156            Some('{') | Some('\n') | Some('\r') | Some(' ') | Some('\t') | None
157        ) && let Some(brace_rel) = source[after_name..].find('{')
158        {
159            let brace_start = after_name + brace_rel;
160            if source[after_name..brace_start]
161                .chars()
162                .all(|c| c.is_whitespace())
163                && let Some(body_len) = match_braces(&source[brace_start..])
164            {
165                // Rust structs have no trailing `;` (unit structs do, but we skip those)
166                return Some(start..brace_start + body_len);
167            }
168        }
169        search_from = start + 1;
170    }
171    None
172}
173
174/// Find the byte range of a named struct in Go source (`type Name struct { ... }`).
175pub fn find_go_struct_span(source: &str, struct_name: &str) -> Option<std::ops::Range<usize>> {
176    let needle = format!("type {struct_name} struct");
177    let mut search_from = 0usize;
178    while let Some(rel) = source[search_from..].find(&needle) {
179        let start = search_from + rel;
180        let after_kw = start + needle.len();
181        if let Some(brace_rel) = source[after_kw..].find('{') {
182            let brace_start = after_kw + brace_rel;
183            if source[after_kw..brace_start]
184                .chars()
185                .all(|c| c.is_whitespace())
186                && let Some(body_len) = match_braces(&source[brace_start..])
187            {
188                return Some(start..brace_start + body_len);
189            }
190        }
191        search_from = start + 1;
192    }
193    None
194}
195
196// ── in-place rewriters ────────────────────────────────────────────────────────
197
198/// Apply C/C++ struct reorderings in-place, returning the modified source.
199/// Each layout in `layouts` is looked up by name; matched structs are replaced
200/// with the optimally-ordered definition. Replacements are applied back-to-front
201/// so byte offsets remain valid.
202pub fn apply_fixes_c(source: &str, layouts: &[&StructLayout]) -> String {
203    apply_fixes(source, layouts, find_c_struct_span, generate_c_fix)
204}
205
206/// Apply Rust struct reorderings in-place, returning the modified source.
207pub fn apply_fixes_rust(source: &str, layouts: &[&StructLayout]) -> String {
208    apply_fixes(source, layouts, find_rust_struct_span, generate_rust_fix)
209}
210
211/// Apply Go struct reorderings in-place, returning the modified source.
212pub fn apply_fixes_go(source: &str, layouts: &[&StructLayout]) -> String {
213    apply_fixes(source, layouts, find_go_struct_span, generate_go_fix)
214}
215
216/// Render a reordered Zig struct definition as source text.
217/// Zig structs are declared as `const Name = struct { ... };`.
218/// If the layout is packed, the output uses `packed struct`.
219pub fn generate_zig_fix(layout: &StructLayout) -> String {
220    let optimal = optimal_order(layout);
221    let qualifier = if layout.is_packed { "packed " } else { "" };
222    let mut out = format!("const {} = {}struct {{\n", layout.name, qualifier);
223    for field in &optimal {
224        let ty = field_type_name(field);
225        out.push_str(&format!("    {}: {ty},\n", field.name));
226    }
227    out.push_str("};\n");
228    out
229}
230
231/// Find the byte range of a named Zig struct in source.
232/// Matches `const Name = [packed|extern ]struct { ... };`.
233pub fn find_zig_struct_span(source: &str, struct_name: &str) -> Option<std::ops::Range<usize>> {
234    // Match `const Name =` (with optional whitespace variations)
235    let needle = format!("const {struct_name}");
236    let mut search_from = 0usize;
237    while let Some(rel) = source[search_from..].find(&needle) {
238        let start = search_from + rel;
239        let after_name = start + needle.len();
240        // Must be followed by whitespace then `=`
241        let rest = source[after_name..].trim_start();
242        if !rest.starts_with('=') {
243            search_from = start + 1;
244            continue;
245        }
246        // Find `struct` keyword after `=`
247        let after_eq = after_name + source[after_name..].find('=')? + 1;
248        let after_eq_rest = &source[after_eq..];
249        // Skip optional `packed` or `extern` modifiers
250        if let Some(struct_rel) = after_eq_rest.find("struct") {
251            // Check no non-whitespace/identifier characters between = and struct
252            // (beyond optional packed/extern modifiers)
253            let prefix = &after_eq_rest[..struct_rel];
254            let prefix_clean = prefix.trim();
255            if prefix_clean.is_empty() || prefix_clean == "packed" || prefix_clean == "extern" {
256                let struct_kw_end = after_eq + struct_rel + "struct".len();
257                if let Some(brace_rel) = source[struct_kw_end..].find('{') {
258                    let brace_start = struct_kw_end + brace_rel;
259                    if source[struct_kw_end..brace_start]
260                        .chars()
261                        .all(|c| c.is_whitespace())
262                        && let Some(body_len) = match_braces(&source[brace_start..])
263                    {
264                        let end = consume_semicolon(source, brace_start + body_len);
265                        return Some(start..end);
266                    }
267                }
268            }
269        }
270        search_from = start + 1;
271    }
272    None
273}
274
275/// Apply Zig struct reorderings in-place, returning the modified source.
276pub fn apply_fixes_zig(source: &str, layouts: &[&StructLayout]) -> String {
277    apply_fixes(source, layouts, find_zig_struct_span, generate_zig_fix)
278}
279
280fn apply_fixes(
281    source: &str,
282    layouts: &[&StructLayout],
283    find_span: fn(&str, &str) -> Option<std::ops::Range<usize>>,
284    generate: fn(&StructLayout) -> String,
285) -> String {
286    // Collect (start, end, replacement) for each matching layout
287    let mut replacements: Vec<(usize, usize, String)> = layouts
288        .iter()
289        .filter_map(|layout| {
290            let span = find_span(source, &layout.name)?;
291            let fixed = generate(layout);
292            Some((span.start, span.end, fixed))
293        })
294        .collect();
295
296    // Sort by start offset ascending, then apply in reverse so offsets stay valid
297    replacements.sort_by_key(|(start, _, _)| *start);
298
299    let mut result = source.to_string();
300    for (start, end, fixed) in replacements.into_iter().rev() {
301        result.replace_range(start..end, &fixed);
302    }
303    result
304}
305
306fn field_type_name(field: &padlock_core::ir::Field) -> &str {
307    match &field.ty {
308        padlock_core::ir::TypeInfo::Primitive { name, .. }
309        | padlock_core::ir::TypeInfo::Opaque { name, .. } => name.as_str(),
310        padlock_core::ir::TypeInfo::Pointer { .. } => "void*",
311        padlock_core::ir::TypeInfo::Array { .. } => "/* array */",
312        padlock_core::ir::TypeInfo::Struct(l) => l.name.as_str(),
313    }
314}
315
316// ── tests ─────────────────────────────────────────────────────────────────────
317
#[cfg(test)]
mod tests {
    use super::*;
    use padlock_core::ir::test_fixtures::connection_layout;

    /// Byte offset of the first occurrence of `needle`; panics when absent.
    fn offset_of(haystack: &str, needle: &str) -> usize {
        haystack.find(needle).unwrap()
    }

    // ── generators and diff ───────────────────────────────────────────────────

    #[test]
    fn c_fix_starts_with_struct() {
        let rendered = generate_c_fix(&connection_layout());
        assert!(rendered.starts_with("struct Connection {"));
    }

    #[test]
    fn c_fix_contains_all_fields() {
        let rendered = generate_c_fix(&connection_layout());
        for field in ["timeout", "port", "is_active", "is_tls"] {
            assert!(rendered.contains(field));
        }
    }

    #[test]
    fn c_fix_puts_largest_align_first() {
        let rendered = generate_c_fix(&connection_layout());
        assert!(offset_of(&rendered, "timeout") < offset_of(&rendered, "is_active"));
    }

    #[test]
    fn rust_fix_uses_colon_syntax() {
        let rendered = generate_rust_fix(&connection_layout());
        assert!(rendered.contains(": f64"));
    }

    #[test]
    fn unified_diff_marks_changes() {
        let before = "struct T { char a; double b; };\n";
        let after = "struct T { double b; char a; };\n";
        let diff = unified_diff(before, after, 1);
        assert!(diff.chars().any(|c| c == '-' || c == '+'));
    }

    #[test]
    fn unified_diff_identical_is_no_changes() {
        let diff = unified_diff("x\n", "x\n", 3);
        assert_eq!(diff, "(no changes)\n");
    }

    // ── span finders ──────────────────────────────────────────────────────────

    #[test]
    fn find_c_struct_span_basic() {
        let src = "struct Foo { int x; char y; };\nstruct Bar { double z; };\n";
        let span = find_c_struct_span(src, "Foo").unwrap();
        let matched = &src[span];
        assert!(matched.starts_with("struct Foo"));
        assert!(!matched.contains("Bar"));
    }

    #[test]
    fn find_c_struct_span_missing_returns_none() {
        let src = "struct Other { int x; };";
        assert!(find_c_struct_span(src, "Missing").is_none());
    }

    #[test]
    fn find_rust_struct_span_basic() {
        let src = "struct Foo {\n    x: u32,\n    y: u8,\n}\n";
        let span = find_rust_struct_span(src, "Foo").unwrap();
        assert!(src[span].starts_with("struct Foo"));
    }

    #[test]
    fn find_go_struct_span_basic() {
        let src = "type Foo struct {\n\tX int32\n\tY bool\n}\n";
        let span = find_go_struct_span(src, "Foo").unwrap();
        assert!(src[span].starts_with("type Foo struct"));
    }

    // ── apply_fixes ───────────────────────────────────────────────────────────

    #[test]
    fn apply_fixes_c_reorders_in_place() {
        // The source declares bool/double/bool/int; after the fix the double
        // member should come first.
        let src = "struct Connection { bool is_active; double timeout; bool is_tls; int port; };\n";
        let layout = connection_layout();
        let fixed = apply_fixes_c(src, &[&layout]);
        assert!(
            offset_of(&fixed, "timeout") < offset_of(&fixed, "is_active"),
            "double should appear before bool after reorder"
        );
    }

    #[test]
    fn apply_fixes_rust_reorders_in_place() {
        let src = "struct Connection {\n    is_active: bool,\n    timeout: f64,\n    is_tls: bool,\n    port: i32,\n}\n";
        let layout = connection_layout();
        let fixed = apply_fixes_rust(src, &[&layout]);
        assert!(offset_of(&fixed, "timeout") < offset_of(&fixed, "is_active"));
    }

    #[test]
    fn go_fix_uses_tab_syntax() {
        let rendered = generate_go_fix(&connection_layout());
        assert!(rendered.starts_with("type Connection struct"));
        assert!(rendered.contains('\t'));
    }

    // ── Zig ───────────────────────────────────────────────────────────────────

    #[test]
    fn zig_fix_uses_const_struct_syntax() {
        let rendered = generate_zig_fix(&connection_layout());
        assert!(rendered.starts_with("const Connection = struct {"));
        assert!(rendered.ends_with("};\n"));
    }

    #[test]
    fn find_zig_struct_span_basic() {
        let src = "const S = struct {\n    x: u32,\n    y: u8,\n};\n";
        let span = find_zig_struct_span(src, "S").unwrap();
        assert!(src[span].starts_with("const S = struct"));
    }

    #[test]
    fn find_zig_struct_span_packed() {
        let src = "const S = packed struct {\n    x: u32,\n    y: u8,\n};\n";
        let span = find_zig_struct_span(src, "S").unwrap();
        assert!(src[span].contains("packed struct"));
    }

    #[test]
    fn find_zig_struct_span_missing_returns_none() {
        let src = "const Other = struct { x: u8 };\n";
        assert!(find_zig_struct_span(src, "Missing").is_none());
    }

    #[test]
    fn apply_fixes_zig_reorders_in_place() {
        use crate::parse_source_str;
        use padlock_core::arch::X86_64_SYSV;

        let src = "const S = struct {\n    a: u8,\n    b: u64,\n};\n";
        let layouts = parse_source_str(src, &crate::SourceLanguage::Zig, &X86_64_SYSV).unwrap();
        let first = &layouts[0];
        let fixed = apply_fixes_zig(src, &[first]);
        // `b` is declared as u64 and `a` as u8, so `b` should move ahead of `a`.
        assert!(
            offset_of(&fixed, "b:") < offset_of(&fixed, "a:"),
            "u64 field should come before u8 after reorder"
        );
    }
}
473}