Skip to main content

chipi_core/
codegen_python.rs

1//! Shared Python code generation helpers.
2//!
3//! Provides reusable functions for emitting Python code from the chipi IR.
4//! Used by the IDA backend (and future Binary Ninja backend).
5
6use std::collections::HashMap;
7use std::fmt::Write;
8
9use crate::tree::DecodeNode;
10use crate::types::*;
11
/// Display-formatting options consulted when rendering operands in generated Python code.
#[derive(Clone, Debug, Default)]
pub struct DisplayConfig {
    /// Display prefix to use for each type alias name (e.g., "gpr" -> "r", "fpr" -> "f").
    pub type_prefixes: HashMap<String, String>,
}
18
19/// Emit `_fmt_signed_hex` and `_fmt_hex` helper functions if needed.
20pub fn emit_display_format_helpers(out: &mut String, def: &ValidatedDef) {
21    if needs_display_format(def) {
22        writeln!(out, "def _fmt_signed_hex(v):").unwrap();
23        writeln!(out, "    if v < 0:").unwrap();
24        writeln!(out, "        return f\"-0x{{-v:x}}\"").unwrap();
25        writeln!(out, "    return f\"0x{{v:x}}\"").unwrap();
26        writeln!(out).unwrap();
27        writeln!(out).unwrap();
28        writeln!(out, "def _fmt_hex(v):").unwrap();
29        writeln!(out, "    return f\"0x{{v:x}}\"").unwrap();
30        writeln!(out).unwrap();
31        writeln!(out).unwrap();
32    }
33}
34
35/// Check if any field uses display(hex) or display(signed_hex).
36fn needs_display_format(def: &ValidatedDef) -> bool {
37    for instr in &def.instructions {
38        for field in &instr.resolved_fields {
39            if field.resolved_type.display_format.is_some() {
40                return true;
41            }
42        }
43    }
44    for sd in &def.sub_decoders {
45        for instr in &sd.instructions {
46            for field in &instr.resolved_fields {
47                if field.resolved_type.display_format.is_some() {
48                    return true;
49                }
50            }
51        }
52    }
53    false
54}
55
/// Append the Python `_sign_extend(val, bits)` helper to `out`, followed by two blank lines.
pub fn emit_sign_extend_helper(out: &mut String) {
    out.push_str(concat!(
        "def _sign_extend(val, bits):\n",
        "    mask = 1 << (bits - 1)\n",
        "    return (val ^ mask) - mask\n",
        "\n",
        "\n",
    ));
}
64
/// Append the Python `_rotate_right` and `_rotate_left` helpers to `out`.
///
/// Both helpers take `(val, amt, width=32)`, reduce `amt` modulo `width`, and mask the
/// input and the result to `width` bits. Each definition is followed by two blank lines.
pub fn emit_rotate_helpers(out: &mut String) {
    // (name suffix, shift applied by `amt`, shift applied by `width - amt`)
    let variants = [("right", ">>", "<<"), ("left", "<<", ">>")];

    for &(direction, by_amt, by_rest) in variants.iter() {
        writeln!(out, "def _rotate_{}(val, amt, width=32):", direction).unwrap();
        out.push_str("    amt = amt % width\n");
        out.push_str("    mask = (1 << width) - 1\n");
        out.push_str("    val = val & mask\n");
        writeln!(
            out,
            "    return ((val {} amt) | (val {} (width - amt))) & mask",
            by_amt, by_rest
        )
        .unwrap();
        out.push_str("\n\n");
    }
}
90
91/// Generate Python dict-based map (lookup) functions.
92pub fn emit_map_functions_python(out: &mut String, maps: &[MapDef]) {
93    for map_def in maps {
94        let params: Vec<&str> = map_def.params.iter().map(|s| s.as_str()).collect();
95        writeln!(out, "def {}({}):", map_def.name, params.join(", ")).unwrap();
96
97        // Find the wildcard/default entry
98        let default_entry = map_def
99            .entries
100            .iter()
101            .find(|e| e.keys.len() == 1 && e.keys[0] == MapKey::Wildcard);
102
103        // Build the lookup dict
104        writeln!(out, "    _MAP = {{").unwrap();
105        for entry in &map_def.entries {
106            if entry.keys.iter().any(|k| *k == MapKey::Wildcard) {
107                continue; // skip wildcard, it's the default
108            }
109            let key = if entry.keys.len() == 1 {
110                format_map_key_python(&entry.keys[0])
111            } else {
112                let keys: Vec<String> = entry.keys.iter().map(format_map_key_python).collect();
113                format!("({})", keys.join(", "))
114            };
115            let value = format_pieces_to_python_str_simple(&entry.output);
116            writeln!(out, "        {}: {},", key, value).unwrap();
117        }
118        writeln!(out, "    }}").unwrap();
119
120        let lookup_key = if params.len() == 1 {
121            params[0].to_string()
122        } else {
123            format!("({})", params.join(", "))
124        };
125
126        let default_val = if let Some(entry) = default_entry {
127            format_pieces_to_python_str_simple(&entry.output)
128        } else {
129            "\"???\"".to_string()
130        };
131
132        writeln!(out, "    return _MAP.get({}, {})", lookup_key, default_val).unwrap();
133        writeln!(out).unwrap();
134        writeln!(out).unwrap();
135    }
136}
137
138fn format_map_key_python(key: &MapKey) -> String {
139    match key {
140        MapKey::Value(v) => format!("{}", v),
141        MapKey::Wildcard => "None".to_string(),
142    }
143}
144
145/// Convert format pieces to a simple Python string expression (no IDA-specific formatting).
146/// Returns a Python expression that evaluates to a string.
147fn format_pieces_to_python_str_simple(pieces: &[FormatPiece]) -> String {
148    if pieces.is_empty() {
149        return "\"\"".to_string();
150    }
151
152    // Check if all pieces are literals
153    let all_literal = pieces.iter().all(|p| matches!(p, FormatPiece::Literal(_)));
154    if all_literal {
155        let mut s = String::new();
156        for piece in pieces {
157            if let FormatPiece::Literal(lit) = piece {
158                s.push_str(lit);
159            }
160        }
161        return format!("\"{}\"", escape_python_str(&s));
162    }
163
164    // Build f-string
165    let mut parts = String::new();
166    parts.push_str("f\"");
167    for piece in pieces {
168        match piece {
169            FormatPiece::Literal(lit) => {
170                parts.push_str(&escape_python_fstr(lit));
171            }
172            FormatPiece::FieldRef { expr, spec } => {
173                parts.push('{');
174                parts.push_str(&expr_to_python_simple(expr));
175                if let Some(spec) = spec {
176                    parts.push(':');
177                    parts.push_str(spec);
178                }
179                parts.push('}');
180            }
181        }
182    }
183    parts.push('"');
184    parts
185}
186
187fn expr_to_python_simple(expr: &FormatExpr) -> String {
188    match expr {
189        FormatExpr::Field(name) => name.clone(),
190        FormatExpr::IntLiteral(val) => format!("{}", val),
191        FormatExpr::Arithmetic { left, op, right } => {
192            let l = expr_to_python_simple(left);
193            let r = expr_to_python_simple(right);
194            let op_str = arith_op_str(op);
195            format!("({} {} {})", l, op_str, r)
196        }
197        FormatExpr::MapCall { map_name, args } => {
198            let arg_strs: Vec<String> = args.iter().map(|a| expr_to_python_simple(a)).collect();
199            format!("{}({})", map_name, arg_strs.join(", "))
200        }
201        FormatExpr::Ternary {
202            field,
203            if_nonzero,
204            if_zero,
205        } => {
206            let else_val = if_zero.as_deref().unwrap_or("");
207            format!(
208                "(\"{}\" if {} else \"{}\")",
209                escape_python_str(if_nonzero),
210                field,
211                escape_python_str(else_val)
212            )
213        }
214        FormatExpr::BuiltinCall { func, args } => {
215            let arg_strs: Vec<String> = args.iter().map(|a| expr_to_python_simple(a)).collect();
216            match func {
217                BuiltinFunc::RotateRight => {
218                    format!(
219                        "_rotate_right({}, {})",
220                        arg_strs.first().map(|s| s.as_str()).unwrap_or("0"),
221                        arg_strs.get(1).map(|s| s.as_str()).unwrap_or("0")
222                    )
223                }
224                BuiltinFunc::RotateLeft => {
225                    format!(
226                        "_rotate_left({}, {})",
227                        arg_strs.first().map(|s| s.as_str()).unwrap_or("0"),
228                        arg_strs.get(1).map(|s| s.as_str()).unwrap_or("0")
229                    )
230                }
231            }
232        }
233        FormatExpr::SubDecoderAccess { field, fragment } => {
234            format!("{}[\"{}\"]", field, fragment)
235        }
236    }
237}
238
239fn arith_op_str(op: &ArithOp) -> &'static str {
240    match op {
241        ArithOp::Add => "+",
242        ArithOp::Sub => "-",
243        ArithOp::Mul => "*",
244        ArithOp::Div => "//",
245        ArithOp::Mod => "%",
246    }
247}
248
/// Escape a string for use inside a double-quoted Python string literal.
///
/// Backslashes and double quotes are backslash-escaped. Line feeds and carriage
/// returns are emitted as `\n` / `\r` escape sequences, because a raw line break
/// would terminate a single-line Python literal and make the generated code invalid.
pub fn escape_python_str(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            other => escaped.push(other),
        }
    }
    escaped
}
253
/// Escape a string for use as literal text inside a double-quoted Python f-string.
///
/// Backslashes and double quotes are backslash-escaped and braces are doubled so they
/// are not read as placeholders. Line feeds and carriage returns are emitted as `\n` /
/// `\r` escape sequences, because a raw line break would terminate a single-line
/// Python literal and make the generated code invalid.
pub fn escape_python_fstr(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '{' => escaped.push_str("{{"),
            '}' => escaped.push_str("}}"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            other => escaped.push(other),
        }
    }
    escaped
}
261
262/// Generate the Python `_decode(data)` function body using the decision tree.
263///
264/// Returns the complete function as a string. The function returns
265/// `(itype_const, fields_dict, byte_size)` or `None`.
266pub fn emit_decode_function(
267    out: &mut String,
268    def: &ValidatedDef,
269    tree: &DecodeNode,
270    itype_prefix: &str,
271) {
272    let unit_bytes = def.config.width / 8;
273    let endian = match def.config.endian {
274        ByteEndian::Big => "big",
275        ByteEndian::Little => "little",
276    };
277    let variable_length = def.instructions.iter().any(|i| i.unit_count() > 1);
278
279    writeln!(out, "def _decode(data):").unwrap();
280    writeln!(out, "    if len(data) < {}:", unit_bytes).unwrap();
281    writeln!(out, "        return None").unwrap();
282    writeln!(
283        out,
284        "    opcode = int.from_bytes(data[0:{}], byteorder=\"{}\")",
285        unit_bytes, endian
286    )
287    .unwrap();
288
289    emit_tree_python(
290        out,
291        tree,
292        def,
293        1,
294        itype_prefix,
295        variable_length,
296        unit_bytes,
297        endian,
298    );
299
300    writeln!(out).unwrap();
301    writeln!(out).unwrap();
302}
303
304/// Recursively emit Python decision tree code.
305fn emit_tree_python(
306    out: &mut String,
307    node: &DecodeNode,
308    def: &ValidatedDef,
309    indent: usize,
310    itype_prefix: &str,
311    variable_length: bool,
312    unit_bytes: u32,
313    endian: &str,
314) {
315    let pad = "    ".repeat(indent);
316    match node {
317        DecodeNode::Leaf { instruction_index } => {
318            let instr = &def.instructions[*instruction_index];
319            if let Some(guard) = leaf_guard_python(instr, unit_bytes, endian) {
320                writeln!(out, "{}if {}:", pad, guard).unwrap();
321                emit_return_decoded(
322                    out,
323                    instr,
324                    itype_prefix,
325                    indent + 1,
326                    variable_length,
327                    unit_bytes,
328                    endian,
329                );
330                writeln!(out, "{}else:", pad).unwrap();
331                writeln!(out, "{}    return None", pad).unwrap();
332            } else {
333                emit_return_decoded(
334                    out,
335                    instr,
336                    itype_prefix,
337                    indent,
338                    variable_length,
339                    unit_bytes,
340                    endian,
341                );
342            }
343        }
344        DecodeNode::PriorityLeaves { candidates } => {
345            for (i, &idx) in candidates.iter().enumerate() {
346                let instr = &def.instructions[idx];
347                let guard = leaf_guard_python(instr, unit_bytes, endian);
348
349                if i == 0 {
350                    if let Some(guard_expr) = guard {
351                        writeln!(out, "{}if {}:", pad, guard_expr).unwrap();
352                        emit_return_decoded(
353                            out,
354                            instr,
355                            itype_prefix,
356                            indent + 1,
357                            variable_length,
358                            unit_bytes,
359                            endian,
360                        );
361                    } else {
362                        emit_return_decoded(
363                            out,
364                            instr,
365                            itype_prefix,
366                            indent,
367                            variable_length,
368                            unit_bytes,
369                            endian,
370                        );
371                        break;
372                    }
373                } else if i == candidates.len() - 1 {
374                    // Last candidate
375                    if let Some(guard_expr) = guard {
376                        writeln!(out, "{}elif {}:", pad, guard_expr).unwrap();
377                        emit_return_decoded(
378                            out,
379                            instr,
380                            itype_prefix,
381                            indent + 1,
382                            variable_length,
383                            unit_bytes,
384                            endian,
385                        );
386                        writeln!(out, "{}else:", pad).unwrap();
387                        writeln!(out, "{}    return None", pad).unwrap();
388                    } else {
389                        writeln!(out, "{}else:", pad).unwrap();
390                        emit_return_decoded(
391                            out,
392                            instr,
393                            itype_prefix,
394                            indent + 1,
395                            variable_length,
396                            unit_bytes,
397                            endian,
398                        );
399                    }
400                } else {
401                    // Middle
402                    let guard_expr = guard.unwrap_or_else(|| "True".to_string());
403                    writeln!(out, "{}elif {}:", pad, guard_expr).unwrap();
404                    emit_return_decoded(
405                        out,
406                        instr,
407                        itype_prefix,
408                        indent + 1,
409                        variable_length,
410                        unit_bytes,
411                        endian,
412                    );
413                }
414            }
415        }
416        DecodeNode::Fail => {
417            writeln!(out, "{}return None", pad).unwrap();
418        }
419        DecodeNode::Branch {
420            range,
421            arms,
422            default,
423        } => {
424            let extract = extract_expr_python("opcode", &[*range], unit_bytes, endian);
425            let var_name = format!("_v{}", indent);
426            writeln!(out, "{}{} = {}", pad, var_name, extract).unwrap();
427
428            let mut first = true;
429            for (value, child) in arms {
430                if first {
431                    writeln!(out, "{}if {} == {:#x}:", pad, var_name, value).unwrap();
432                    first = false;
433                } else {
434                    writeln!(out, "{}elif {} == {:#x}:", pad, var_name, value).unwrap();
435                }
436                emit_tree_python(
437                    out,
438                    child,
439                    def,
440                    indent + 1,
441                    itype_prefix,
442                    variable_length,
443                    unit_bytes,
444                    endian,
445                );
446            }
447
448            // Default arm
449            if !arms.is_empty() {
450                writeln!(out, "{}else:", pad).unwrap();
451                emit_tree_python(
452                    out,
453                    default,
454                    def,
455                    indent + 1,
456                    itype_prefix,
457                    variable_length,
458                    unit_bytes,
459                    endian,
460                );
461            } else {
462                emit_tree_python(
463                    out,
464                    default,
465                    def,
466                    indent,
467                    itype_prefix,
468                    variable_length,
469                    unit_bytes,
470                    endian,
471                );
472            }
473        }
474    }
475}
476
477/// Emit Python code to return a decoded instruction tuple.
478fn emit_return_decoded(
479    out: &mut String,
480    instr: &ValidatedInstruction,
481    itype_prefix: &str,
482    indent: usize,
483    variable_length: bool,
484    unit_bytes: u32,
485    endian: &str,
486) {
487    let pad = "    ".repeat(indent);
488    let unit_count = instr.unit_count();
489    let bytes_consumed = unit_count * unit_bytes;
490    let itype_const = format!("{}_{}", itype_prefix, instr.name.to_ascii_uppercase());
491
492    if variable_length && unit_count > 1 {
493        writeln!(out, "{}if len(data) < {}:", pad, bytes_consumed).unwrap();
494        writeln!(out, "{}    return None", pad).unwrap();
495    }
496
497    if instr.resolved_fields.is_empty() {
498        writeln!(
499            out,
500            "{}return ({}, {{}}, {})",
501            pad, itype_const, bytes_consumed
502        )
503        .unwrap();
504    } else {
505        // Extract fields
506        for field in &instr.resolved_fields {
507            let extract = extract_expr_python("opcode", &field.ranges, unit_bytes, endian);
508            let expr = apply_transforms_python(&extract, &field.resolved_type);
509            writeln!(out, "{}{} = {}", pad, field.name, expr).unwrap();
510        }
511
512        // Build fields dict
513        let field_names: Vec<&str> = instr
514            .resolved_fields
515            .iter()
516            .map(|f| f.name.as_str())
517            .collect();
518        let dict_entries: Vec<String> = field_names
519            .iter()
520            .map(|n| format!("\"{}\": {}", n, n))
521            .collect();
522        writeln!(
523            out,
524            "{}return ({}, {{{}}}, {})",
525            pad,
526            itype_const,
527            dict_entries.join(", "),
528            bytes_consumed
529        )
530        .unwrap();
531    }
532}
533
534/// Generate a Python expression to extract bits from ranges.
535pub fn extract_expr_python(
536    var: &str,
537    ranges: &[BitRange],
538    unit_bytes: u32,
539    endian: &str,
540) -> String {
541    if ranges.is_empty() {
542        return "0".to_string();
543    }
544
545    if ranges.len() == 1 {
546        let range = ranges[0];
547        let source = unit_read_python(range.unit, unit_bytes, endian);
548        let source = if range.unit == 0 {
549            var.to_string()
550        } else {
551            source
552        };
553
554        let width = range.width();
555        let shift = range.end;
556        let mask = (1u64 << width) - 1;
557
558        if shift == 0 {
559            format!("({} & {:#x})", source, mask)
560        } else {
561            format!("(({} >> {}) & {:#x})", source, shift, mask)
562        }
563    } else {
564        let mut parts = Vec::new();
565        let mut accumulated_width = 0u32;
566
567        for range in ranges {
568            let source = if range.unit == 0 {
569                var.to_string()
570            } else {
571                unit_read_python(range.unit, unit_bytes, endian)
572            };
573
574            let width = range.width();
575            let shift = range.end;
576            let mask = (1u64 << width) - 1;
577
578            let extracted = if shift == 0 {
579                format!("({} & {:#x})", source, mask)
580            } else {
581                format!("(({} >> {}) & {:#x})", source, shift, mask)
582            };
583
584            if accumulated_width > 0 {
585                parts.push(format!("({} << {})", extracted, accumulated_width));
586            } else {
587                parts.push(extracted);
588            }
589
590            accumulated_width += width;
591        }
592
593        parts.join(" | ")
594    }
595}
596
/// Generate a Python expression that reads unit number `unit` of the instruction.
///
/// Unit 0 is the already-decoded `opcode` variable; any other unit is read from the
/// `data` buffer at byte offset `unit * unit_bytes` using the given byte order.
fn unit_read_python(unit: u32, unit_bytes: u32, endian: &str) -> String {
    match unit {
        0 => String::from("opcode"),
        _ => {
            let start = unit * unit_bytes;
            format!(
                "int.from_bytes(data[{}:{}], byteorder=\"{}\")",
                start,
                start + unit_bytes,
                endian
            )
        }
    }
}
610
611/// Apply chipi transforms to a Python expression.
612pub fn apply_transforms_python(extract_expr: &str, resolved: &ResolvedFieldType) -> String {
613    let mut expr = extract_expr.to_string();
614
615    for transform in &resolved.transforms {
616        match transform {
617            Transform::SignExtend(n) => {
618                expr = format!("_sign_extend({}, {})", expr, n);
619            }
620            Transform::ZeroExtend(_) => {
621                // No-op in Python (arbitrary precision)
622            }
623            Transform::ShiftLeft(n) => {
624                expr = format!("(({}) << {})", expr, n);
625            }
626        }
627    }
628
629    // Handle sub-decoder fields
630    if let Some(ref sd_name) = resolved.sub_decoder {
631        let decode_fn = format!("_decode_{}", to_snake_case(sd_name));
632        return format!("{}({})", decode_fn, expr);
633    }
634
635    expr
636}
637
638/// Compute a guard condition for a leaf node in Python.
639/// Returns `None` if no guard is needed.
640pub fn leaf_guard_python(
641    instr: &ValidatedInstruction,
642    unit_bytes: u32,
643    endian: &str,
644) -> Option<String> {
645    let fixed_bits = instr.fixed_bits();
646    if fixed_bits.is_empty() {
647        return None;
648    }
649
650    let mut units_map: HashMap<u32, Vec<(u32, Bit)>> = HashMap::new();
651    for (unit, hw_bit, bit) in fixed_bits {
652        units_map.entry(unit).or_default().push((hw_bit, bit));
653    }
654
655    let mut conditions = Vec::new();
656
657    for (unit, bits) in &units_map {
658        let (mask, value) = compute_mask_value(bits);
659        if mask != 0 {
660            let source = if *unit == 0 {
661                "opcode".to_string()
662            } else {
663                unit_read_python(*unit, unit_bytes, endian)
664            };
665            conditions.push(format!("{} & {:#x} == {:#x}", source, mask, value));
666        }
667    }
668
669    if conditions.is_empty() {
670        None
671    } else {
672        Some(conditions.join(" and "))
673    }
674}
675
676/// Compute a bitmask and expected value from fixed bits.
677fn compute_mask_value(fixed_bits: &[(u32, Bit)]) -> (u64, u64) {
678    let mut mask: u64 = 0;
679    let mut value: u64 = 0;
680    for &(bit_pos, bit_val) in fixed_bits {
681        if bit_val == Bit::Wildcard {
682            continue;
683        }
684        mask |= 1u64 << bit_pos;
685        if bit_val == Bit::One {
686            value |= 1u64 << bit_pos;
687        }
688    }
689    (mask, value)
690}
691
692/// Convert format pieces to a Python expression (concatenated string).
693/// `fields_var` is the variable name for the fields dict (e.g., "fields").
694pub fn format_pieces_to_python_expr(
695    pieces: &[FormatPiece],
696    fields: &[ResolvedField],
697    fields_var: &str,
698    display: &DisplayConfig,
699) -> String {
700    if pieces.is_empty() {
701        return "\"\"".to_string();
702    }
703
704    let all_literal = pieces.iter().all(|p| matches!(p, FormatPiece::Literal(_)));
705    if all_literal {
706        let mut s = String::new();
707        for piece in pieces {
708            if let FormatPiece::Literal(lit) = piece {
709                s.push_str(lit);
710            }
711        }
712        return format!("\"{}\"", escape_python_str(&s));
713    }
714
715    // Build an f-string-style expression
716    let mut result = String::from("f\"");
717    for piece in pieces {
718        match piece {
719            FormatPiece::Literal(lit) => {
720                result.push_str(&escape_python_fstr(lit));
721            }
722            FormatPiece::FieldRef { expr, spec } => {
723                if spec.is_some() {
724                    // Explicit format spec overrides display hints
725                    result.push('{');
726                    result.push_str(&expr_to_python(expr, fields, fields_var, display));
727                    if let Some(spec) = spec {
728                        result.push(':');
729                        result.push_str(spec);
730                    }
731                    result.push('}');
732                } else if let Some(wrapper) = resolve_display_wrapper(expr, fields, display) {
733                    // Apply display format (hex, signed_hex, or type prefix)
734                    result.push('{');
735                    result.push_str(&wrapper);
736                    result.push('}');
737                } else {
738                    result.push('{');
739                    result.push_str(&expr_to_python(expr, fields, fields_var, display));
740                    result.push('}');
741                }
742            }
743        }
744    }
745    result.push('"');
746    result
747}
748
749/// Resolve display formatting for a format expression.
750/// Returns a Python expression string with the display wrapper applied, or None.
751fn resolve_display_wrapper(
752    expr: &FormatExpr,
753    fields: &[ResolvedField],
754    display: &DisplayConfig,
755) -> Option<String> {
756    // Find the display format from the expression's primary field
757    let (display_fmt, alias_name) = resolve_display_info(expr, fields)?;
758
759    let raw_expr = expr_to_python(expr, fields, "fields", display);
760
761    // Check type prefix first (e.g., gpr -> "r")
762    if let Some(alias) = &alias_name {
763        if let Some(prefix) = display.type_prefixes.get(alias) {
764            return Some(format!(
765                "\"{}\" + str({})",
766                escape_python_str(prefix),
767                raw_expr
768            ));
769        }
770    }
771
772    // Check display format hint
773    match display_fmt? {
774        DisplayFormat::SignedHex => Some(format!("_fmt_signed_hex({})", raw_expr)),
775        DisplayFormat::Hex => Some(format!("_fmt_hex({})", raw_expr)),
776    }
777}
778
779/// Extract display format and alias name from a format expression.
780/// Looks at the primary field in the expression.
781fn resolve_display_info(
782    expr: &FormatExpr,
783    fields: &[ResolvedField],
784) -> Option<(Option<DisplayFormat>, Option<String>)> {
785    match expr {
786        FormatExpr::Field(name) => {
787            let field = fields.iter().find(|f| f.name == *name)?;
788            Some((
789                field.resolved_type.display_format,
790                field.resolved_type.alias_name.clone(),
791            ))
792        }
793        FormatExpr::Arithmetic { left, right, .. } => {
794            // Try left first, then right
795            resolve_display_info(left, fields).or_else(|| resolve_display_info(right, fields))
796        }
797        _ => None,
798    }
799}
800
801/// Convert a FormatExpr to a Python expression string.
802pub fn expr_to_python(
803    expr: &FormatExpr,
804    fields: &[ResolvedField],
805    fields_var: &str,
806    display: &DisplayConfig,
807) -> String {
808    match expr {
809        FormatExpr::Field(name) => {
810            format!("{}[\"{}\"]", fields_var, name)
811        }
812        FormatExpr::Ternary {
813            field,
814            if_nonzero,
815            if_zero,
816        } => {
817            let else_val = if_zero.as_deref().unwrap_or("");
818            format!(
819                "(\"{}\" if {}[\"{}\"] else \"{}\")",
820                escape_python_str(if_nonzero),
821                fields_var,
822                field,
823                escape_python_str(else_val)
824            )
825        }
826        FormatExpr::Arithmetic { left, op, right } => {
827            let l = expr_to_python(left, fields, fields_var, display);
828            let r = expr_to_python(right, fields, fields_var, display);
829            let op_str = arith_op_str(op);
830            format!("({} {} {})", l, op_str, r)
831        }
832        FormatExpr::IntLiteral(val) => format!("{}", val),
833        FormatExpr::MapCall { map_name, args } => {
834            let arg_strs: Vec<String> = args
835                .iter()
836                .map(|a| expr_to_python(a, fields, fields_var, display))
837                .collect();
838            format!("{}({})", map_name, arg_strs.join(", "))
839        }
840        FormatExpr::BuiltinCall { func, args } => {
841            let arg_strs: Vec<String> = args
842                .iter()
843                .map(|a| expr_to_python(a, fields, fields_var, display))
844                .collect();
845            match func {
846                BuiltinFunc::RotateRight => {
847                    format!(
848                        "_rotate_right({}, {})",
849                        arg_strs.first().map(|s| s.as_str()).unwrap_or("0"),
850                        arg_strs.get(1).map(|s| s.as_str()).unwrap_or("0")
851                    )
852                }
853                BuiltinFunc::RotateLeft => {
854                    format!(
855                        "_rotate_left({}, {})",
856                        arg_strs.first().map(|s| s.as_str()).unwrap_or("0"),
857                        arg_strs.get(1).map(|s| s.as_str()).unwrap_or("0")
858                    )
859                }
860            }
861        }
862        FormatExpr::SubDecoderAccess { field, fragment } => {
863            format!("{}[\"{}\"][\"{}\"]", fields_var, field, fragment)
864        }
865    }
866}
867
868/// Generate a Python guard condition from a Guard.
869pub fn emit_guard_python(
870    guard: &Guard,
871    _fields: &[ResolvedField],
872    fields_var: &str,
873    _display: &DisplayConfig,
874) -> String {
875    let conditions: Vec<String> = guard
876        .conditions
877        .iter()
878        .map(|cond| {
879            let left = guard_operand_to_python(&cond.left, fields_var);
880            let right = guard_operand_to_python(&cond.right, fields_var);
881            let op = match cond.op {
882                CompareOp::Eq => "==",
883                CompareOp::Ne => "!=",
884                CompareOp::Lt => "<",
885                CompareOp::Le => "<=",
886                CompareOp::Gt => ">",
887                CompareOp::Ge => ">=",
888            };
889            format!("{} {} {}", left, op, right)
890        })
891        .collect();
892
893    conditions.join(" and ")
894}
895
896fn guard_operand_to_python(operand: &GuardOperand, fields_var: &str) -> String {
897    match operand {
898        GuardOperand::Field(name) => format!("{}[\"{}\"]", fields_var, name),
899        GuardOperand::Literal(val) => format!("{}", val),
900        GuardOperand::Expr { left, op, right } => {
901            let l = guard_operand_to_python(left, fields_var);
902            let r = guard_operand_to_python(right, fields_var);
903            let op_str = arith_op_str(op);
904            format!("({} {} {})", l, op_str, r)
905        }
906    }
907}
908
909/// Generate a Python sub-decoder dispatch function.
910pub fn emit_subdecoder_python(out: &mut String, sd: &ValidatedSubDecoder) {
911    let fn_name = format!("_decode_{}", to_snake_case(&sd.name));
912    let width = sd.width;
913    let _unit_bytes = width / 8;
914
915    // Generate map functions for sub-decoder-local maps
916    emit_map_functions_python(out, &sd.maps);
917
918    writeln!(out, "def {}(val):", fn_name).unwrap();
919
920    // Build a simple dispatch table using if/elif on masked values
921    for (i, instr) in sd.instructions.iter().enumerate() {
922        let (mask, value) = compute_instruction_mask_value(instr);
923
924        let keyword = if i == 0 { "if" } else { "elif" };
925        writeln!(out, "    {} val & {:#x} == {:#x}:", keyword, mask, value).unwrap();
926
927        // Extract fields
928        for field in &instr.resolved_fields {
929            let extract = extract_field_from_val(&field.ranges, width);
930            let expr = apply_transforms_python(&extract, &field.resolved_type);
931            writeln!(out, "        {} = {}", field.name, expr).unwrap();
932        }
933
934        // Build fragment dict
935        let mut frag_entries = Vec::new();
936        for frag in &instr.fragments {
937            let frag_expr =
938                format_pieces_to_python_subdecoder_str(&frag.pieces, &instr.resolved_fields);
939            frag_entries.push(format!("\"{}\": {}", frag.name, frag_expr));
940        }
941        writeln!(out, "        return {{{}}}", frag_entries.join(", ")).unwrap();
942    }
943
944    writeln!(out, "    return None").unwrap();
945    writeln!(out).unwrap();
946    writeln!(out).unwrap();
947}
948
949/// Compute mask/value for a sub-decoder instruction.
950fn compute_instruction_mask_value(instr: &ValidatedSubInstruction) -> (u64, u64) {
951    let mut mask: u64 = 0;
952    let mut value: u64 = 0;
953
954    for seg in &instr.segments {
955        if let Segment::Fixed {
956            ranges, pattern, ..
957        } = seg
958        {
959            let mut bit_idx = 0;
960            for range in ranges {
961                for i in 0..range.width() {
962                    if bit_idx < pattern.len() {
963                        let bit = pattern[bit_idx];
964                        if bit != Bit::Wildcard {
965                            let hw_bit = range.start - i;
966                            mask |= 1u64 << hw_bit;
967                            if bit == Bit::One {
968                                value |= 1u64 << hw_bit;
969                            }
970                        }
971                        bit_idx += 1;
972                    }
973                }
974            }
975        }
976    }
977
978    (mask, value)
979}
980
981/// Extract a field value from `val` for sub-decoder (single-unit, no data[] reads).
982fn extract_field_from_val(ranges: &[BitRange], _width: u32) -> String {
983    if ranges.is_empty() {
984        return "0".to_string();
985    }
986
987    if ranges.len() == 1 {
988        let range = ranges[0];
989        let width = range.width();
990        let shift = range.end;
991        let mask = (1u64 << width) - 1;
992
993        if shift == 0 {
994            format!("(val & {:#x})", mask)
995        } else {
996            format!("((val >> {}) & {:#x})", shift, mask)
997        }
998    } else {
999        let mut parts = Vec::new();
1000        let mut accumulated_width = 0u32;
1001
1002        for range in ranges {
1003            let width = range.width();
1004            let shift = range.end;
1005            let mask = (1u64 << width) - 1;
1006
1007            let extracted = if shift == 0 {
1008                format!("(val & {:#x})", mask)
1009            } else {
1010                format!("((val >> {}) & {:#x})", shift, mask)
1011            };
1012
1013            if accumulated_width > 0 {
1014                parts.push(format!("({} << {})", extracted, accumulated_width));
1015            } else {
1016                parts.push(extracted);
1017            }
1018
1019            accumulated_width += width;
1020        }
1021
1022        parts.join(" | ")
1023    }
1024}
1025
1026/// Convert format pieces for a sub-decoder fragment to a Python string expression.
1027/// Fields are referenced directly by name (not via a dict).
1028fn format_pieces_to_python_subdecoder_str(
1029    pieces: &[FormatPiece],
1030    fields: &[ResolvedField],
1031) -> String {
1032    if pieces.is_empty() {
1033        return "\"\"".to_string();
1034    }
1035
1036    let all_literal = pieces.iter().all(|p| matches!(p, FormatPiece::Literal(_)));
1037    if all_literal {
1038        let mut s = String::new();
1039        for piece in pieces {
1040            if let FormatPiece::Literal(lit) = piece {
1041                s.push_str(lit);
1042            }
1043        }
1044        return format!("\"{}\"", escape_python_str(&s));
1045    }
1046
1047    let mut result = String::from("f\"");
1048    for piece in pieces {
1049        match piece {
1050            FormatPiece::Literal(lit) => {
1051                result.push_str(&escape_python_fstr(lit));
1052            }
1053            FormatPiece::FieldRef { expr, spec } => {
1054                result.push('{');
1055                result.push_str(&expr_to_python_direct(expr, fields));
1056                if let Some(spec) = spec {
1057                    result.push(':');
1058                    result.push_str(spec);
1059                }
1060                result.push('}');
1061            }
1062        }
1063    }
1064    result.push('"');
1065    result
1066}
1067
1068/// Convert a FormatExpr to a Python expression where fields are local variables.
1069fn expr_to_python_direct(expr: &FormatExpr, _fields: &[ResolvedField]) -> String {
1070    match expr {
1071        FormatExpr::Field(name) => name.clone(),
1072        FormatExpr::Ternary {
1073            field,
1074            if_nonzero,
1075            if_zero,
1076        } => {
1077            let else_val = if_zero.as_deref().unwrap_or("");
1078            format!(
1079                "(\"{}\" if {} else \"{}\")",
1080                escape_python_str(if_nonzero),
1081                field,
1082                escape_python_str(else_val)
1083            )
1084        }
1085        FormatExpr::Arithmetic { left, op, right } => {
1086            let l = expr_to_python_direct(left, _fields);
1087            let r = expr_to_python_direct(right, _fields);
1088            let op_str = arith_op_str(op);
1089            format!("({} {} {})", l, op_str, r)
1090        }
1091        FormatExpr::IntLiteral(val) => format!("{}", val),
1092        FormatExpr::MapCall { map_name, args } => {
1093            let arg_strs: Vec<String> = args
1094                .iter()
1095                .map(|a| expr_to_python_direct(a, _fields))
1096                .collect();
1097            format!("{}({})", map_name, arg_strs.join(", "))
1098        }
1099        FormatExpr::BuiltinCall { func, args } => {
1100            let arg_strs: Vec<String> = args
1101                .iter()
1102                .map(|a| expr_to_python_direct(a, _fields))
1103                .collect();
1104            match func {
1105                BuiltinFunc::RotateRight => {
1106                    format!(
1107                        "_rotate_right({}, {})",
1108                        arg_strs.first().map(|s| s.as_str()).unwrap_or("0"),
1109                        arg_strs.get(1).map(|s| s.as_str()).unwrap_or("0")
1110                    )
1111                }
1112                BuiltinFunc::RotateLeft => {
1113                    format!(
1114                        "_rotate_left({}, {})",
1115                        arg_strs.first().map(|s| s.as_str()).unwrap_or("0"),
1116                        arg_strs.get(1).map(|s| s.as_str()).unwrap_or("0")
1117                    )
1118                }
1119            }
1120        }
1121        FormatExpr::SubDecoderAccess { field, fragment } => {
1122            format!("{}[\"{}\"]", field, fragment)
1123        }
1124    }
1125}
1126
/// Convert a name to snake_case.
///
/// Every ASCII uppercase letter is lowercased, and each one except a leading
/// letter is preceded by an underscore. All other characters are copied as-is,
/// so consecutive capitals are separated individually (`"AB"` -> `"a_b"`).
pub fn to_snake_case(name: &str) -> String {
    name.chars()
        .enumerate()
        .fold(String::new(), |mut snake, (index, ch)| {
            let starts_new_word = index > 0 && ch.is_ascii_uppercase();
            if starts_new_word {
                snake.push('_');
            }
            snake.push(ch.to_ascii_lowercase());
            snake
        })
}
1138
1139/// Check if any instruction uses rotate builtins.
1140pub fn needs_rotate_helpers(def: &ValidatedDef) -> bool {
1141    for instr in &def.instructions {
1142        for fl in &instr.format_lines {
1143            for piece in &fl.pieces {
1144                if let FormatPiece::FieldRef { expr, .. } = piece {
1145                    if expr_uses_rotate(expr) {
1146                        return true;
1147                    }
1148                }
1149            }
1150        }
1151    }
1152    false
1153}
1154
1155fn expr_uses_rotate(expr: &FormatExpr) -> bool {
1156    match expr {
1157        FormatExpr::BuiltinCall { func, .. } => {
1158            matches!(func, BuiltinFunc::RotateRight | BuiltinFunc::RotateLeft)
1159        }
1160        FormatExpr::Arithmetic { left, right, .. } => {
1161            expr_uses_rotate(left) || expr_uses_rotate(right)
1162        }
1163        _ => false,
1164    }
1165}
1166
1167/// Check if any instruction uses sign_extend transforms.
1168pub fn needs_sign_extend(def: &ValidatedDef) -> bool {
1169    for instr in &def.instructions {
1170        for field in &instr.resolved_fields {
1171            for transform in &field.resolved_type.transforms {
1172                if matches!(transform, Transform::SignExtend(_)) {
1173                    return true;
1174                }
1175            }
1176        }
1177    }
1178    // Also check sub-decoders
1179    for sd in &def.sub_decoders {
1180        for instr in &sd.instructions {
1181            for field in &instr.resolved_fields {
1182                for transform in &field.resolved_type.transforms {
1183                    if matches!(transform, Transform::SignExtend(_)) {
1184                        return true;
1185                    }
1186                }
1187            }
1188        }
1189    }
1190    false
1191}
1192
1193/// Generate the `_format_insn(itype, fields)` Python function.
1194/// Returns (mnemonic_str, operands_str).
1195pub fn emit_format_function(
1196    out: &mut String,
1197    def: &ValidatedDef,
1198    itype_prefix: &str,
1199    display: &DisplayConfig,
1200) {
1201    writeln!(out, "def _format_insn(itype, fields):").unwrap();
1202    writeln!(
1203        out,
1204        "    \"\"\"Format an instruction. Returns (mnemonic, operands) strings.\"\"\""
1205    )
1206    .unwrap();
1207
1208    for (i, instr) in def.instructions.iter().enumerate() {
1209        let itype_const = format!("{}_{}", itype_prefix, instr.name.to_ascii_uppercase());
1210        let keyword = if i == 0 { "if" } else { "elif" };
1211        writeln!(out, "    {} itype == {}:", keyword, itype_const).unwrap();
1212
1213        if instr.format_lines.is_empty() {
1214            // Fallback: mnemonic is instruction name, operands are field values
1215            if instr.resolved_fields.is_empty() {
1216                writeln!(out, "        return \"{}\", \"\"", instr.name).unwrap();
1217            } else {
1218                let field_strs: Vec<String> = instr
1219                    .resolved_fields
1220                    .iter()
1221                    .map(|f| format!("str(fields[\"{}\"])", f.name))
1222                    .collect();
1223                writeln!(
1224                    out,
1225                    "        return \"{}\", \", \".join([{}])",
1226                    instr.name,
1227                    field_strs.join(", ")
1228                )
1229                .unwrap();
1230            }
1231        } else {
1232            emit_format_lines_python(out, instr, 2, display);
1233        }
1234    }
1235
1236    writeln!(out, "    return \"???\", \"\"").unwrap();
1237    writeln!(out).unwrap();
1238    writeln!(out).unwrap();
1239}
1240
1241/// Emit format lines for a single instruction as Python code.
1242fn emit_format_lines_python(
1243    out: &mut String,
1244    instr: &ValidatedInstruction,
1245    indent: usize,
1246    display: &DisplayConfig,
1247) {
1248    let pad = "    ".repeat(indent);
1249
1250    if instr.format_lines.len() == 1 && instr.format_lines[0].guard.is_none() {
1251        let fl = &instr.format_lines[0];
1252        let (mnemonic, operands) = split_format_pieces(&fl.pieces);
1253        let mnemonic_expr =
1254            format_pieces_to_python_expr(&mnemonic, &instr.resolved_fields, "fields", display);
1255        let operands_expr =
1256            format_pieces_to_python_expr(&operands, &instr.resolved_fields, "fields", display);
1257        writeln!(out, "{}return {}, {}", pad, mnemonic_expr, operands_expr).unwrap();
1258        return;
1259    }
1260
1261    // Multiple format lines with guards
1262    for (i, fl) in instr.format_lines.iter().enumerate() {
1263        let (mnemonic, operands) = split_format_pieces(&fl.pieces);
1264        let mnemonic_expr =
1265            format_pieces_to_python_expr(&mnemonic, &instr.resolved_fields, "fields", display);
1266        let operands_expr =
1267            format_pieces_to_python_expr(&operands, &instr.resolved_fields, "fields", display);
1268
1269        if let Some(guard) = &fl.guard {
1270            let guard_code = emit_guard_python(guard, &instr.resolved_fields, "fields", display);
1271            if i == 0 {
1272                writeln!(out, "{}if {}:", pad, guard_code).unwrap();
1273            } else {
1274                writeln!(out, "{}elif {}:", pad, guard_code).unwrap();
1275            }
1276            writeln!(
1277                out,
1278                "{}    return {}, {}",
1279                pad, mnemonic_expr, operands_expr
1280            )
1281            .unwrap();
1282        } else {
1283            if i > 0 {
1284                writeln!(out, "{}else:", pad).unwrap();
1285                writeln!(
1286                    out,
1287                    "{}    return {}, {}",
1288                    pad, mnemonic_expr, operands_expr
1289                )
1290                .unwrap();
1291            } else {
1292                writeln!(out, "{}return {}, {}", pad, mnemonic_expr, operands_expr).unwrap();
1293            }
1294        }
1295    }
1296}
1297
1298/// Split format pieces into (mnemonic_pieces, operand_pieces).
1299/// The mnemonic is everything before the first space; operands are the rest.
1300fn split_format_pieces(pieces: &[FormatPiece]) -> (Vec<FormatPiece>, Vec<FormatPiece>) {
1301    let mut mnemonic = Vec::new();
1302    let mut operands = Vec::new();
1303    let mut found_space = false;
1304
1305    for piece in pieces {
1306        if found_space {
1307            operands.push(piece.clone());
1308        } else {
1309            match piece {
1310                FormatPiece::Literal(lit) => {
1311                    if let Some(pos) = lit.find(' ') {
1312                        // Split this literal at the first space
1313                        let before = &lit[..pos];
1314                        let after = &lit[pos + 1..];
1315                        if !before.is_empty() {
1316                            mnemonic.push(FormatPiece::Literal(before.to_string()));
1317                        }
1318                        if !after.is_empty() {
1319                            operands.push(FormatPiece::Literal(after.to_string()));
1320                        }
1321                        found_space = true;
1322                    } else {
1323                        mnemonic.push(piece.clone());
1324                    }
1325                }
1326                _ => {
1327                    // Field ref before any space - part of mnemonic (unusual but possible)
1328                    mnemonic.push(piece.clone());
1329                }
1330            }
1331        }
1332    }
1333
1334    (mnemonic, operands)
1335}