1use padlock_core::ir::{StructLayout, optimal_order};
6use similar::{ChangeTag, TextDiff};
7
8pub fn generate_c_fix(layout: &StructLayout) -> String {
13 let optimal = optimal_order(layout);
14 let mut out = format!("struct {} {{\n", layout.name);
15 for field in &optimal {
16 let ty = field_type_name(field);
17 out.push_str(&format!(" {ty} {};\n", field.name));
18 }
19 out.push_str("};\n");
20 out
21}
22
23pub fn generate_rust_fix(layout: &StructLayout) -> String {
25 let optimal = optimal_order(layout);
26 let mut out = format!("struct {} {{\n", layout.name);
27 for field in &optimal {
28 let ty = field_type_name(field);
29 out.push_str(&format!(" {}: {ty},\n", field.name));
30 }
31 out.push_str("}\n");
32 out
33}
34
35pub fn generate_go_fix(layout: &StructLayout) -> String {
37 let optimal = optimal_order(layout);
38 let mut out = format!("type {} struct {{\n", layout.name);
39 for field in &optimal {
40 let ty = field_type_name(field);
41 out.push_str(&format!("\t{}\t{ty}\n", field.name));
42 }
43 out.push_str("}\n");
44 out
45}
46
47pub fn unified_diff(original: &str, fixed: &str, context_lines: usize) -> String {
49 if original == fixed {
50 return String::from("(no changes)\n");
51 }
52 let diff = TextDiff::from_lines(original, fixed);
53 let mut out = String::new();
54 for (idx, group) in diff.grouped_ops(context_lines).iter().enumerate() {
55 if idx > 0 {
56 out.push_str("...\n");
57 }
58 for op in group {
59 for change in diff.iter_changes(op) {
60 let prefix = match change.tag() {
61 ChangeTag::Delete => "-",
62 ChangeTag::Insert => "+",
63 ChangeTag::Equal => " ",
64 };
65 out.push_str(&format!("{prefix} {}", change.value()));
66 if !change.value().ends_with('\n') {
67 out.push('\n');
68 }
69 }
70 }
71 }
72 out
73}
74
/// Returns the byte length of the balanced `{ ... }` group at the start of `s`.
///
/// `s` is expected to begin with `{`. The returned length includes the
/// matching `}`, so `&s[..len]` is the whole group.
///
/// Braces are ignored inside:
/// - `//` line comments;
/// - `/* */` block comments;
/// - double-quoted string literals (with backslash escapes).
///
/// This keeps a doc comment or string mentioning `}` from ending the body early.
///
/// Returns `None` in three cases:
/// - the group is never closed;
/// - a block comment is unterminated;
/// - a `}` appears before any `{` (instead of underflowing the depth counter).
fn match_braces(s: &str) -> Option<usize> {
    let bytes = s.as_bytes();
    let mut depth = 0usize;
    let mut i = 0usize;
    // All delimiters are ASCII, so every index produced here is a char boundary.
    while i < bytes.len() {
        match bytes[i] {
            b'{' => depth += 1,
            b'}' => {
                depth = depth.checked_sub(1)?;
                if depth == 0 {
                    return Some(i + 1);
                }
            }
            b'/' if bytes.get(i + 1) == Some(&b'/') => {
                // Line comment: resume at the newline (or stop at end of input).
                i = bytes[i..]
                    .iter()
                    .position(|&b| b == b'\n')
                    .map_or(bytes.len(), |p| i + p);
                continue;
            }
            b'/' if bytes.get(i + 1) == Some(&b'*') => {
                // Block comment: jump past the closing `*/`.
                let close = s[i + 2..].find("*/")?;
                i += 2 + close + 2;
                continue;
            }
            b'"' => {
                // String literal: advance to the closing quote, skipping escaped chars.
                i += 1;
                while i < bytes.len() && bytes[i] != b'"' {
                    if bytes[i] == b'\\' {
                        i += 1;
                    }
                    i += 1;
                }
            }
            _ => {}
        }
        i += 1;
    }
    None
}
95
/// Extends a span end `pos` over a directly following `;`.
///
/// Spaces and tabs (any whitespace except `\n`) may sit between `pos` and the
/// semicolon. Returns the offset just past the `;` when one is found that way.
/// Otherwise `pos` is returned unchanged, including when a newline comes first.
fn consume_semicolon(source: &str, pos: usize) -> usize {
    let tail = &source[pos..];
    let skipped = tail
        .char_indices()
        .find(|&(_, c)| !c.is_whitespace() || c == '\n')
        .map_or(tail.len(), |(idx, _)| idx);
    match tail[skipped..].chars().next() {
        Some(';') => pos + skipped + 1,
        _ => pos,
    }
}
110
111pub fn find_c_struct_span(source: &str, struct_name: &str) -> Option<std::ops::Range<usize>> {
114 for kw in &["struct", "union"] {
115 let needle = format!("{kw} {struct_name}");
116 let mut search_from = 0usize;
117 while let Some(rel) = source[search_from..].find(&needle) {
118 let start = search_from + rel;
119 let after_name = start + needle.len();
120 let boundary = source[after_name..].chars().next();
122 if matches!(
123 boundary,
124 Some('{') | Some('\n') | Some('\r') | Some(' ') | Some('\t') | None
125 ) {
126 if let Some(brace_rel) = source[after_name..].find('{') {
128 let brace_start = after_name + brace_rel;
129 if source[after_name..brace_start]
131 .chars()
132 .all(|c| c.is_whitespace())
133 && let Some(body_len) = match_braces(&source[brace_start..])
134 {
135 let end = consume_semicolon(source, brace_start + body_len);
136 return Some(start..end);
137 }
138 }
139 }
140 search_from = start + 1;
141 }
142 }
143 None
144}
145
146pub fn find_rust_struct_span(source: &str, struct_name: &str) -> Option<std::ops::Range<usize>> {
148 let needle = format!("struct {struct_name}");
149 let mut search_from = 0usize;
150 while let Some(rel) = source[search_from..].find(&needle) {
151 let start = search_from + rel;
152 let after_name = start + needle.len();
153 let boundary = source[after_name..].chars().next();
154 if matches!(
155 boundary,
156 Some('{') | Some('\n') | Some('\r') | Some(' ') | Some('\t') | None
157 ) && let Some(brace_rel) = source[after_name..].find('{')
158 {
159 let brace_start = after_name + brace_rel;
160 if source[after_name..brace_start]
161 .chars()
162 .all(|c| c.is_whitespace())
163 && let Some(body_len) = match_braces(&source[brace_start..])
164 {
165 return Some(start..brace_start + body_len);
167 }
168 }
169 search_from = start + 1;
170 }
171 None
172}
173
174pub fn find_go_struct_span(source: &str, struct_name: &str) -> Option<std::ops::Range<usize>> {
176 let needle = format!("type {struct_name} struct");
177 let mut search_from = 0usize;
178 while let Some(rel) = source[search_from..].find(&needle) {
179 let start = search_from + rel;
180 let after_kw = start + needle.len();
181 if let Some(brace_rel) = source[after_kw..].find('{') {
182 let brace_start = after_kw + brace_rel;
183 if source[after_kw..brace_start]
184 .chars()
185 .all(|c| c.is_whitespace())
186 && let Some(body_len) = match_braces(&source[brace_start..])
187 {
188 return Some(start..brace_start + body_len);
189 }
190 }
191 search_from = start + 1;
192 }
193 None
194}
195
196pub fn apply_fixes_c(source: &str, layouts: &[&StructLayout]) -> String {
203 apply_fixes(source, layouts, find_c_struct_span, generate_c_fix)
204}
205
206pub fn apply_fixes_rust(source: &str, layouts: &[&StructLayout]) -> String {
208 apply_fixes(source, layouts, find_rust_struct_span, generate_rust_fix)
209}
210
211pub fn apply_fixes_go(source: &str, layouts: &[&StructLayout]) -> String {
213 apply_fixes(source, layouts, find_go_struct_span, generate_go_fix)
214}
215
216pub fn generate_zig_fix(layout: &StructLayout) -> String {
220 let optimal = optimal_order(layout);
221 let qualifier = if layout.is_packed { "packed " } else { "" };
222 let mut out = format!("const {} = {}struct {{\n", layout.name, qualifier);
223 for field in &optimal {
224 let ty = field_type_name(field);
225 out.push_str(&format!(" {}: {ty},\n", field.name));
226 }
227 out.push_str("};\n");
228 out
229}
230
231pub fn find_zig_struct_span(source: &str, struct_name: &str) -> Option<std::ops::Range<usize>> {
234 let needle = format!("const {struct_name}");
236 let mut search_from = 0usize;
237 while let Some(rel) = source[search_from..].find(&needle) {
238 let start = search_from + rel;
239 let after_name = start + needle.len();
240 let rest = source[after_name..].trim_start();
242 if !rest.starts_with('=') {
243 search_from = start + 1;
244 continue;
245 }
246 let after_eq = after_name + source[after_name..].find('=')? + 1;
248 let after_eq_rest = &source[after_eq..];
249 if let Some(struct_rel) = after_eq_rest.find("struct") {
251 let prefix = &after_eq_rest[..struct_rel];
254 let prefix_clean = prefix.trim();
255 if prefix_clean.is_empty() || prefix_clean == "packed" || prefix_clean == "extern" {
256 let struct_kw_end = after_eq + struct_rel + "struct".len();
257 if let Some(brace_rel) = source[struct_kw_end..].find('{') {
258 let brace_start = struct_kw_end + brace_rel;
259 if source[struct_kw_end..brace_start]
260 .chars()
261 .all(|c| c.is_whitespace())
262 && let Some(body_len) = match_braces(&source[brace_start..])
263 {
264 let end = consume_semicolon(source, brace_start + body_len);
265 return Some(start..end);
266 }
267 }
268 }
269 }
270 search_from = start + 1;
271 }
272 None
273}
274
275pub fn apply_fixes_zig(source: &str, layouts: &[&StructLayout]) -> String {
277 apply_fixes(source, layouts, find_zig_struct_span, generate_zig_fix)
278}
279
280fn apply_fixes(
281 source: &str,
282 layouts: &[&StructLayout],
283 find_span: fn(&str, &str) -> Option<std::ops::Range<usize>>,
284 generate: fn(&StructLayout) -> String,
285) -> String {
286 let mut replacements: Vec<(usize, usize, String)> = layouts
288 .iter()
289 .filter_map(|layout| {
290 let span = find_span(source, &layout.name)?;
291 let fixed = generate(layout);
292 Some((span.start, span.end, fixed))
293 })
294 .collect();
295
296 replacements.sort_by_key(|(start, _, _)| *start);
298
299 let mut result = source.to_string();
300 for (start, end, fixed) in replacements.into_iter().rev() {
301 result.replace_range(start..end, &fixed);
302 }
303 result
304}
305
306fn field_type_name(field: &padlock_core::ir::Field) -> &str {
307 match &field.ty {
308 padlock_core::ir::TypeInfo::Primitive { name, .. }
309 | padlock_core::ir::TypeInfo::Opaque { name, .. } => name.as_str(),
310 padlock_core::ir::TypeInfo::Pointer { .. } => "void*",
311 padlock_core::ir::TypeInfo::Array { .. } => "/* array */",
312 padlock_core::ir::TypeInfo::Struct(l) => l.name.as_str(),
313 }
314}
315
// Unit tests for the fix generators, span finders and in-place rewriters.
// Most tests use the shared `connection_layout()` fixture from padlock_core.
// The Zig round-trip test instead parses real source via `crate::parse_source_str`.
#[cfg(test)]
mod tests {
    use super::*;
    use padlock_core::ir::test_fixtures::connection_layout;

    // --- generators ---

    #[test]
    fn c_fix_starts_with_struct() {
        let out = generate_c_fix(&connection_layout());
        assert!(out.starts_with("struct Connection {"));
    }

    #[test]
    fn c_fix_contains_all_fields() {
        let out = generate_c_fix(&connection_layout());
        assert!(out.contains("timeout"));
        assert!(out.contains("port"));
        assert!(out.contains("is_active"));
        assert!(out.contains("is_tls"));
    }

    // The fixture's `timeout` field is expected to be ordered before `is_active`.
    #[test]
    fn c_fix_puts_largest_align_first() {
        let out = generate_c_fix(&connection_layout());
        let timeout_pos = out.find("timeout").unwrap();
        let is_active_pos = out.find("is_active").unwrap();
        assert!(timeout_pos < is_active_pos);
    }

    #[test]
    fn rust_fix_uses_colon_syntax() {
        let out = generate_rust_fix(&connection_layout());
        assert!(out.contains(": f64"));
    }

    // --- diff rendering ---

    #[test]
    fn unified_diff_marks_changes() {
        let orig = "struct T { char a; double b; };\n";
        let fixed = "struct T { double b; char a; };\n";
        let diff = unified_diff(orig, fixed, 1);
        assert!(diff.contains('-') || diff.contains('+'));
    }

    #[test]
    fn unified_diff_identical_is_no_changes() {
        assert_eq!(unified_diff("x\n", "x\n", 3), "(no changes)\n");
    }

    // --- span finders ---

    // The span must cover only the requested struct, not the following one.
    #[test]
    fn find_c_struct_span_basic() {
        let src = "struct Foo { int x; char y; };\nstruct Bar { double z; };\n";
        let span = find_c_struct_span(src, "Foo").unwrap();
        let text = &src[span];
        assert!(text.starts_with("struct Foo"));
        assert!(!text.contains("Bar"));
    }

    #[test]
    fn find_c_struct_span_missing_returns_none() {
        let src = "struct Other { int x; };";
        assert!(find_c_struct_span(src, "Missing").is_none());
    }

    #[test]
    fn find_rust_struct_span_basic() {
        let src = "struct Foo {\n x: u32,\n y: u8,\n}\n";
        let span = find_rust_struct_span(src, "Foo").unwrap();
        assert!(src[span].starts_with("struct Foo"));
    }

    #[test]
    fn find_go_struct_span_basic() {
        let src = "type Foo struct {\n\tX int32\n\tY bool\n}\n";
        let span = find_go_struct_span(src, "Foo").unwrap();
        assert!(src[span].starts_with("type Foo struct"));
    }

    // --- in-place rewriting ---

    // End-to-end: locate the C struct, regenerate it and splice it back into the source.
    #[test]
    fn apply_fixes_c_reorders_in_place() {
        let src = "struct Connection { bool is_active; double timeout; bool is_tls; int port; };\n";
        // The source field order deliberately puts a bool before the double.
        let layout = connection_layout();
        let fixed = apply_fixes_c(src, &[&layout]);
        let timeout_pos = fixed.find("timeout").unwrap();
        let is_active_pos = fixed.find("is_active").unwrap();
        assert!(
            timeout_pos < is_active_pos,
            "double should appear before bool after reorder"
        );
    }

    #[test]
    fn apply_fixes_rust_reorders_in_place() {
        let src = "struct Connection {\n is_active: bool,\n timeout: f64,\n is_tls: bool,\n port: i32,\n}\n";
        let layout = connection_layout();
        let fixed = apply_fixes_rust(src, &[&layout]);
        let timeout_pos = fixed.find("timeout").unwrap();
        let is_active_pos = fixed.find("is_active").unwrap();
        assert!(timeout_pos < is_active_pos);
    }

    #[test]
    fn go_fix_uses_tab_syntax() {
        let layout = connection_layout();
        let out = generate_go_fix(&layout);
        assert!(out.starts_with("type Connection struct"));
        assert!(out.contains('\t'));
    }

    // --- Zig ---

    #[test]
    fn zig_fix_uses_const_struct_syntax() {
        let out = generate_zig_fix(&connection_layout());
        assert!(out.starts_with("const Connection = struct {"));
        assert!(out.ends_with("};\n"));
    }

    #[test]
    fn find_zig_struct_span_basic() {
        let src = "const S = struct {\n x: u32,\n y: u8,\n};\n";
        let span = find_zig_struct_span(src, "S").unwrap();
        assert!(src[span].starts_with("const S = struct"));
    }

    // The `packed` qualifier between `=` and `struct` must still be accepted and kept in the span.
    #[test]
    fn find_zig_struct_span_packed() {
        let src = "const S = packed struct {\n x: u32,\n y: u8,\n};\n";
        let span = find_zig_struct_span(src, "S").unwrap();
        assert!(src[span].contains("packed struct"));
    }

    #[test]
    fn find_zig_struct_span_missing_returns_none() {
        let src = "const Other = struct { x: u8 };\n";
        assert!(find_zig_struct_span(src, "Missing").is_none());
    }

    // Round trip through the real Zig parser rather than the shared fixture.
    #[test]
    fn apply_fixes_zig_reorders_in_place() {
        use crate::parse_source_str;
        use padlock_core::arch::X86_64_SYSV;
        let src = "const S = struct {\n a: u8,\n b: u64,\n};\n";
        let layouts = parse_source_str(src, &crate::SourceLanguage::Zig, &X86_64_SYSV).unwrap();
        let layout = &layouts[0];
        let fixed = apply_fixes_zig(src, &[layout]);
        // After the rewrite the u64 field `b` should precede the u8 field `a`.
        let b_pos = fixed.find("b:").unwrap();
        let a_pos = fixed.find("a:").unwrap();
        assert!(
            b_pos < a_pos,
            "u64 field should come before u8 after reorder"
        );
    }
}