Skip to main content

chipi_core/
codegen_cpp.rs

1//! C++ code generation from validated definitions and decision trees.
2//!
3//! Generates a single-header C++ file with instruction enum, decode function,
4//! and disassembly formatting.
5
6use std::collections::HashMap;
7use std::fmt::Write;
8
9use crate::backend::cpp::CppOptions;
10use crate::backend::cpp::GuardStyle;
11use crate::tree::DecodeNode;
12use crate::types::*;
13
14/// Generate a complete C++ header file.
15pub fn generate_cpp_code(
16    def: &ValidatedDef,
17    tree: &DecodeNode,
18    opts: &CppOptions,
19    type_maps: &HashMap<String, String>,
20) -> String {
21    let mut out = String::new();
22
23    let default_ns = to_snake_case(&def.config.name);
24    let ns = opts.namespace.as_deref().unwrap_or(&default_ns);
25    let guard_name = format!("CHIPI_{}_HPP", ns.to_ascii_uppercase());
26    let unit_bytes = def.config.width / 8;
27    let word_type = cpp_word_type(def.config.width);
28    let endian = &def.config.endian;
29    let variable_length = def.instructions.iter().any(|i| i.unit_count() > 1);
30
31    // Header guard
32    match opts.guard_style {
33        GuardStyle::Pragma => writeln!(out, "#pragma once").unwrap(),
34        GuardStyle::Ifndef => {
35            writeln!(out, "#ifndef {}", guard_name).unwrap();
36            writeln!(out, "#define {}", guard_name).unwrap();
37        }
38    }
39    writeln!(out).unwrap();
40    writeln!(out, "// Auto-generated by https://github.com/ioncodes/chipi").unwrap();
41    writeln!(out, "// Do not edit.").unwrap();
42    writeln!(out).unwrap();
43
44    // Includes
45    writeln!(out, "#include <cstdint>").unwrap();
46    writeln!(out, "#include <cstddef>").unwrap();
47    writeln!(out, "#include <cstring>").unwrap();
48    writeln!(out, "#include <string>").unwrap();
49    writeln!(out, "#include <optional>").unwrap();
50    writeln!(out, "#include <format>").unwrap();
51    for inc in &opts.includes {
52        writeln!(out, "#include \"{}\"", inc).unwrap();
53    }
54    writeln!(out).unwrap();
55
56    writeln!(out, "namespace {} {{", ns).unwrap();
57    writeln!(out).unwrap();
58
59    // Built-in display wrapper types for type aliases with display hints
60    emit_display_types(&mut out, def, type_maps);
61
62    // Opcode enum
63    emit_opcode_enum(&mut out, def);
64
65    // Instruction struct
66    emit_instruction_struct(&mut out, def, type_maps);
67
68    // Sub-decoder types and functions
69    for sd in &def.sub_decoders {
70        emit_subdecoder(&mut out, sd, def.config.width);
71    }
72
73    // Map functions
74    emit_map_functions(&mut out, def);
75
76    // Decode function
77    emit_decode_function(
78        &mut out,
79        def,
80        tree,
81        word_type,
82        unit_bytes,
83        endian,
84        variable_length,
85        type_maps,
86    );
87
88    // Format function
89    emit_format_function(&mut out, def, type_maps, opts);
90
91    writeln!(out, "}} // namespace {}", ns).unwrap();
92
93    if opts.guard_style == GuardStyle::Ifndef {
94        writeln!(out).unwrap();
95        writeln!(out, "#endif // {}", guard_name).unwrap();
96    }
97
98    out
99}
100
/// Map a unit width in bits to the C++ unsigned integer type used to hold it.
///
/// Widths of 8 and 16 map to `uint8_t` and `uint16_t`; every other width
/// (including 32) maps to `uint32_t`.
fn cpp_word_type(width: u32) -> &'static str {
    if width == 8 {
        "uint8_t"
    } else if width == 16 {
        "uint16_t"
    } else {
        "uint32_t"
    }
}
109
/// Return the signed C++ integer type with the same width as `base`.
///
/// `base` is a type name such as `u8` or `i16`. Only the exact names
/// `u8`/`i8` and `u16`/`i16` select a narrower type; anything else
/// (including `u32`/`i32` and unrecognised names) yields `int32_t`.
fn cpp_signed_type(base: &str) -> &'static str {
    // Drop a single leading signedness marker and dispatch on the width digits.
    match base.strip_prefix(|c: char| c == 'u' || c == 'i') {
        Some("8") => "int8_t",
        Some("16") => "int16_t",
        _ => "int32_t",
    }
}
118
/// Return the C++ type used to store a field whose base type is `base`.
///
/// `bool` stays `bool`, `u1` through `u8` are stored in `uint8_t`, the other
/// fixed-width integer names map to their `<cstdint>` counterparts, and any
/// unrecognised name falls back to `uint32_t`.
fn cpp_type_for(base: &str) -> &'static str {
    const KNOWN: &[(&str, &str)] = &[
        ("bool", "bool"),
        ("u1", "uint8_t"),
        ("u2", "uint8_t"),
        ("u3", "uint8_t"),
        ("u4", "uint8_t"),
        ("u5", "uint8_t"),
        ("u6", "uint8_t"),
        ("u7", "uint8_t"),
        ("u8", "uint8_t"),
        ("i8", "int8_t"),
        ("u16", "uint16_t"),
        ("i16", "int16_t"),
        ("u32", "uint32_t"),
        ("i32", "int32_t"),
    ];

    KNOWN
        .iter()
        .find(|(name, _)| *name == base)
        .map_or("uint32_t", |&(_, cpp)| cpp)
}
131
/// Return the bit width associated with the base type name `base`.
///
/// Only `u8`/`i8` (8) and `u16`/`i16` (16) are narrower than the default;
/// every other name, including `u32`/`i32`, reports 32.
fn type_bits(base: &str) -> u32 {
    if matches!(base, "u8" | "i8") {
        8
    } else if matches!(base, "u16" | "i16") {
        16
    } else {
        32
    }
}
140
/// Convert `name` to snake case.
///
/// Every character is ASCII-lowercased, and an underscore is inserted before
/// each ASCII uppercase character except the very first character of the
/// input. Existing underscores are kept as they are.
fn to_snake_case(name: &str) -> String {
    let mut snake = String::with_capacity(name.len() + 4);
    let mut rest = name.chars();

    // The first character never gets a separator in front of it.
    if let Some(first) = rest.next() {
        snake.push(first.to_ascii_lowercase());
    }
    for c in rest {
        if c.is_ascii_uppercase() {
            snake.push('_');
        }
        snake.push(c.to_ascii_lowercase());
    }
    snake
}
151
/// Convert an underscore-separated `name` to Pascal case.
///
/// Underscores are removed; the first character of each underscore-delimited
/// segment is ASCII-uppercased and the remaining characters of the segment are
/// ASCII-lowercased.
fn to_pascal_case(name: &str) -> String {
    name.split('_')
        .fold(String::with_capacity(name.len()), |mut pascal, segment| {
            let mut chars = segment.chars();
            // Empty segments (leading, trailing or doubled underscores) add nothing.
            if let Some(head) = chars.next() {
                pascal.push(head.to_ascii_uppercase());
                pascal.extend(chars.map(|c| c.to_ascii_lowercase()));
            }
            pascal
        })
}
167
168
169/// Emit built-in wrapper types for type aliases with `display(hex)` or `display(signed_hex)`.
170/// Only emits types that aren't already overridden via `type_map`.
171fn emit_display_types(
172    out: &mut String,
173    def: &ValidatedDef,
174    type_maps: &HashMap<String, String>,
175) {
176    let mut need_signed_hex = false;
177    let mut need_hex = false;
178
179    for alias in &def.type_aliases {
180        // Skip if the user provided a type_map override
181        if type_maps.contains_key(&alias.name) {
182            continue;
183        }
184        match alias.display_format {
185            Some(DisplayFormat::SignedHex) => need_signed_hex = true,
186            Some(DisplayFormat::Hex) => need_hex = true,
187            None => {}
188        }
189    }
190
191    if need_signed_hex {
192        writeln!(out, "struct SignedHex {{").unwrap();
193        writeln!(out, "    int32_t value;").unwrap();
194        writeln!(out, "    SignedHex() = default;").unwrap();
195        writeln!(out, "    constexpr SignedHex(int32_t v) : value(v) {{}}").unwrap();
196        writeln!(out, "    bool operator==(const SignedHex&) const = default;").unwrap();
197        writeln!(out, "    bool operator==(int other) const {{ return value == other; }}").unwrap();
198        writeln!(out, "    bool operator!=(int other) const {{ return value != other; }}").unwrap();
199        writeln!(out, "    bool operator<(int other) const {{ return value < other; }}").unwrap();
200        writeln!(out, "    bool operator<=(int other) const {{ return value <= other; }}").unwrap();
201        writeln!(out, "    bool operator>(int other) const {{ return value > other; }}").unwrap();
202        writeln!(out, "    bool operator>=(int other) const {{ return value >= other; }}").unwrap();
203        writeln!(out, "    SignedHex operator-() const {{ return SignedHex(-value); }}").unwrap();
204        writeln!(out, "    friend SignedHex operator-(int lhs, SignedHex rhs) {{ return SignedHex(lhs - rhs.value); }}").unwrap();
205        writeln!(out, "    friend SignedHex operator+(int lhs, SignedHex rhs) {{ return SignedHex(lhs + rhs.value); }}").unwrap();
206        writeln!(out, "    friend SignedHex operator+(SignedHex lhs, int rhs) {{ return SignedHex(lhs.value + rhs); }}").unwrap();
207        writeln!(out, "    friend SignedHex operator-(SignedHex lhs, int rhs) {{ return SignedHex(lhs.value - rhs); }}").unwrap();
208        writeln!(out, "    friend SignedHex operator*(SignedHex lhs, int rhs) {{ return SignedHex(lhs.value * rhs); }}").unwrap();
209        writeln!(out, "}};").unwrap();
210        writeln!(out).unwrap();
211    }
212
213    if need_hex {
214        writeln!(out, "struct Hex {{").unwrap();
215        writeln!(out, "    uint32_t value;").unwrap();
216        writeln!(out, "    Hex() = default;").unwrap();
217        writeln!(out, "    constexpr Hex(uint32_t v) : value(v) {{}}").unwrap();
218        writeln!(out, "    bool operator==(const Hex&) const = default;").unwrap();
219        writeln!(out, "    bool operator==(unsigned other) const {{ return value == other; }}").unwrap();
220        writeln!(out, "}};").unwrap();
221        writeln!(out).unwrap();
222    }
223
224    writeln!(out, "}} // namespace {}", to_snake_case(&def.config.name)).unwrap();
225    writeln!(out).unwrap();
226
227    // std::formatter specializations must be outside the namespace
228    if need_signed_hex {
229        let ns = to_snake_case(&def.config.name);
230        writeln!(out, "template <> struct std::formatter<{}::SignedHex> : std::formatter<std::string> {{", ns).unwrap();
231        writeln!(out, "    auto format({}::SignedHex v, auto& ctx) const {{", ns).unwrap();
232        writeln!(out, "        if (v.value < 0)").unwrap();
233        writeln!(out, "            return std::formatter<std::string>::format(std::format(\"-0x{{:x}}\", static_cast<unsigned>(-v.value)), ctx);").unwrap();
234        writeln!(out, "        return std::formatter<std::string>::format(std::format(\"0x{{:x}}\", static_cast<unsigned>(v.value)), ctx);").unwrap();
235        writeln!(out, "    }}").unwrap();
236        writeln!(out, "}};").unwrap();
237        writeln!(out).unwrap();
238    }
239
240    if need_hex {
241        let ns = to_snake_case(&def.config.name);
242        writeln!(out, "template <> struct std::formatter<{}::Hex> : std::formatter<std::string> {{", ns).unwrap();
243        writeln!(out, "    auto format({}::Hex v, auto& ctx) const {{", ns).unwrap();
244        writeln!(out, "        return std::formatter<std::string>::format(std::format(\"0x{{:x}}\", v.value), ctx);").unwrap();
245        writeln!(out, "    }}").unwrap();
246        writeln!(out, "}};").unwrap();
247        writeln!(out).unwrap();
248    }
249
250    // Re-open the namespace
251    if need_signed_hex || need_hex {
252        writeln!(out, "namespace {} {{", to_snake_case(&def.config.name)).unwrap();
253        writeln!(out).unwrap();
254    }
255}
256
257/// Emit the Opcode enum.
258fn emit_opcode_enum(out: &mut String, def: &ValidatedDef) {
259    writeln!(out, "enum class Opcode : uint32_t {{").unwrap();
260    for (i, instr) in def.instructions.iter().enumerate() {
261        writeln!(out, "    {} = {},", to_pascal_case(&instr.name), i).unwrap();
262    }
263    writeln!(out, "}};").unwrap();
264    writeln!(out).unwrap();
265}
266
267/// Emit the Instruction struct with a tagged union for fields.
268fn emit_instruction_struct(
269    out: &mut String,
270    def: &ValidatedDef,
271    type_maps: &HashMap<String, String>,
272) {
273    writeln!(out, "struct Instruction {{").unwrap();
274    writeln!(out, "    Opcode opcode;").unwrap();
275    writeln!(out, "    uint32_t size; // bytes consumed").unwrap();
276    writeln!(out).unwrap();
277
278    // Generate a union with per-instruction field structs
279    let has_fields = def.instructions.iter().any(|i| !i.resolved_fields.is_empty());
280    if has_fields {
281        writeln!(out, "    union {{").unwrap();
282        for instr in &def.instructions {
283            if instr.resolved_fields.is_empty() {
284                continue;
285            }
286            writeln!(out, "        struct {{").unwrap();
287            for field in &instr.resolved_fields {
288                let cpp_type = field_cpp_type(field, type_maps);
289                writeln!(out, "            {} {};", cpp_type, field.name).unwrap();
290            }
291            writeln!(out, "        }} {};", instr.name).unwrap();
292        }
293        writeln!(out, "    }};").unwrap();
294    }
295
296    writeln!(out, "}};").unwrap();
297    writeln!(out).unwrap();
298}
299
300/// Get the C++ type for a field.
301fn field_cpp_type(field: &ResolvedField, type_maps: &HashMap<String, String>) -> String {
302    // Check type map first (user override)
303    if let Some(alias) = &field.resolved_type.alias_name {
304        if let Some(mapped) = type_maps.get(alias) {
305            return mapped.clone();
306        }
307    }
308
309    // Check sub-decoder
310    if let Some(ref sd_name) = field.resolved_type.sub_decoder {
311        return format!("{}Insn", to_pascal_case(sd_name));
312    }
313
314    // Check display format -> built-in wrapper types
315    match field.resolved_type.display_format {
316        Some(DisplayFormat::SignedHex) => return "SignedHex".to_string(),
317        Some(DisplayFormat::Hex) => return "Hex".to_string(),
318        None => {}
319    }
320
321    cpp_type_for(&field.resolved_type.base_type).to_string()
322}
323
324/// Emit sub-decoder struct and dispatch function.
325fn emit_subdecoder(out: &mut String, sd: &ValidatedSubDecoder, _parent_width: u32) {
326    let type_name = format!("{}Insn", to_pascal_case(&sd.name));
327    let word_type = cpp_word_type(sd.width);
328
329    // Fragment struct
330    writeln!(out, "struct {} {{", type_name).unwrap();
331    for frag_name in &sd.fragment_names {
332        writeln!(out, "    const char* {};", frag_name).unwrap();
333    }
334    writeln!(out, "}};").unwrap();
335    writeln!(out).unwrap();
336
337    // Pre-baked fragment strings for each instruction
338    // Then dispatch function
339    let fn_name = format!("decode_{}", to_snake_case(&sd.name));
340    writeln!(
341        out,
342        "inline std::optional<{}> {}({} val) {{",
343        type_name, fn_name, word_type
344    )
345    .unwrap();
346
347    for (i, instr) in sd.instructions.iter().enumerate() {
348        let (mask, value) = compute_instruction_mask_value_sub(instr);
349        let keyword = if i == 0 { "if" } else { "} else if" };
350        writeln!(
351            out,
352            "    {} ((val & {:#x}) == {:#x}) {{",
353            keyword, mask, value
354        )
355        .unwrap();
356
357        // Build fragment values
358        // For simplicity, use string literals where possible; for field-dependent
359        // fragments, generate inline formatting
360        for frag in &instr.fragments {
361            let all_literal = frag.pieces.iter().all(|p| matches!(p, FormatPiece::Literal(_)));
362            if all_literal {
363                let s: String = frag.pieces.iter().map(|p| {
364                    if let FormatPiece::Literal(lit) = p { lit.as_str() } else { "" }
365                }).collect();
366                writeln!(out, "        // {}.{} = \"{}\"", instr.name, frag.name, s).unwrap();
367            }
368        }
369
370        // Return struct with fragment values
371        let frag_values: Vec<String> = instr.fragments.iter().map(|frag| {
372            let all_literal = frag.pieces.iter().all(|p| matches!(p, FormatPiece::Literal(_)));
373            if all_literal {
374                let s: String = frag.pieces.iter().map(|p| {
375                    if let FormatPiece::Literal(lit) = p { lit.as_str() } else { "" }
376                }).collect();
377                format!("\"{}\"", s)
378            } else {
379                // For dynamic fragments, we'd need snprintf; use empty for now
380                "\"\"".to_string()
381            }
382        }).collect();
383
384        writeln!(
385            out,
386            "        return {} {{ {} }};",
387            type_name,
388            frag_values.join(", ")
389        )
390        .unwrap();
391    }
392
393    if !sd.instructions.is_empty() {
394        writeln!(out, "    }}").unwrap();
395    }
396    writeln!(out, "    return std::nullopt;").unwrap();
397    writeln!(out, "}}").unwrap();
398    writeln!(out).unwrap();
399}
400
401fn compute_instruction_mask_value_sub(instr: &ValidatedSubInstruction) -> (u64, u64) {
402    let mut mask: u64 = 0;
403    let mut value: u64 = 0;
404    for seg in &instr.segments {
405        if let Segment::Fixed { ranges, pattern, .. } = seg {
406            let mut bit_idx = 0;
407            for range in ranges {
408                for i in 0..range.width() {
409                    if bit_idx < pattern.len() {
410                        let bit = pattern[bit_idx];
411                        if bit != Bit::Wildcard {
412                            let hw_bit = range.start - i;
413                            mask |= 1u64 << hw_bit;
414                            if bit == Bit::One {
415                                value |= 1u64 << hw_bit;
416                            }
417                        }
418                        bit_idx += 1;
419                    }
420                }
421            }
422        }
423    }
424    (mask, value)
425}
426
427/// Emit map (lookup) functions.
428fn emit_map_functions(out: &mut String, def: &ValidatedDef) {
429    for map_def in &def.maps {
430        let params: Vec<String> = map_def
431            .params
432            .iter()
433            .map(|p| format!("int {}", p))
434            .collect();
435        writeln!(
436            out,
437            "inline const char* {}({}) {{",
438            map_def.name,
439            params.join(", ")
440        )
441        .unwrap();
442
443        let key_var = if map_def.params.len() == 1 {
444            map_def.params[0].clone()
445        } else {
446            // Multi-param: won't use switch
447            String::new()
448        };
449
450        if map_def.params.len() == 1 {
451            writeln!(out, "    switch ({}) {{", key_var).unwrap();
452            for entry in &map_def.entries {
453                if entry.keys.len() == 1 && entry.keys[0] == MapKey::Wildcard {
454                    continue;
455                }
456                if let Some(MapKey::Value(v)) = entry.keys.first() {
457                    let s = pieces_to_str(&entry.output);
458                    writeln!(out, "    case {}: return \"{}\";", v, s).unwrap();
459                }
460            }
461
462            let default_str = map_def
463                .entries
464                .iter()
465                .find(|e| e.keys.len() == 1 && e.keys[0] == MapKey::Wildcard)
466                .map(|e| pieces_to_str(&e.output))
467                .unwrap_or_else(|| "???".to_string());
468            writeln!(out, "    default: return \"{}\";", default_str).unwrap();
469            writeln!(out, "    }}").unwrap();
470        } else {
471            // Fallback for multi-param maps: if/else chain
472            for entry in &map_def.entries {
473                if entry.keys.iter().any(|k| *k == MapKey::Wildcard) {
474                    continue;
475                }
476                let conds: Vec<String> = entry
477                    .keys
478                    .iter()
479                    .zip(map_def.params.iter())
480                    .map(|(k, p)| {
481                        if let MapKey::Value(v) = k {
482                            format!("{} == {}", p, v)
483                        } else {
484                            "true".to_string()
485                        }
486                    })
487                    .collect();
488                let s = pieces_to_str(&entry.output);
489                writeln!(out, "    if ({}) return \"{}\";", conds.join(" && "), s).unwrap();
490            }
491            writeln!(out, "    return \"???\";").unwrap();
492        }
493
494        writeln!(out, "}}").unwrap();
495        writeln!(out).unwrap();
496    }
497
498    // Also emit sub-decoder maps
499    for sd in &def.sub_decoders {
500        for map_def in &sd.maps {
501            let params: Vec<String> = map_def
502                .params
503                .iter()
504                .map(|p| format!("int {}", p))
505                .collect();
506            writeln!(
507                out,
508                "inline const char* {}({}) {{",
509                map_def.name,
510                params.join(", ")
511            )
512            .unwrap();
513            writeln!(out, "    switch ({}) {{", map_def.params[0]).unwrap();
514            for entry in &map_def.entries {
515                if entry.keys.len() == 1 && entry.keys[0] == MapKey::Wildcard {
516                    continue;
517                }
518                if let Some(MapKey::Value(v)) = entry.keys.first() {
519                    let s = pieces_to_str(&entry.output);
520                    writeln!(out, "    case {}: return \"{}\";", v, s).unwrap();
521                }
522            }
523            let default_str = map_def
524                .entries
525                .iter()
526                .find(|e| e.keys.len() == 1 && e.keys[0] == MapKey::Wildcard)
527                .map(|e| pieces_to_str(&e.output))
528                .unwrap_or_else(|| "???".to_string());
529            writeln!(out, "    default: return \"{}\";", default_str).unwrap();
530            writeln!(out, "    }}").unwrap();
531            writeln!(out, "}}").unwrap();
532            writeln!(out).unwrap();
533        }
534    }
535}
536
537fn pieces_to_str(pieces: &[FormatPiece]) -> String {
538    let mut s = String::new();
539    for piece in pieces {
540        if let FormatPiece::Literal(lit) = piece {
541            s.push_str(lit);
542        }
543    }
544    s
545}
546
547/// Emit the decode function.
548fn emit_decode_function(
549    out: &mut String,
550    def: &ValidatedDef,
551    tree: &DecodeNode,
552    word_type: &str,
553    unit_bytes: u32,
554    endian: &ByteEndian,
555    variable_length: bool,
556    type_maps: &HashMap<String, String>,
557) {
558    writeln!(
559        out,
560        "inline std::optional<Instruction> decode(const uint8_t* data, size_t len) {{"
561    )
562    .unwrap();
563    writeln!(out, "    if (len < {}) return std::nullopt;", unit_bytes).unwrap();
564
565    // Read first unit
566    emit_word_read(out, "opcode", word_type, 0, unit_bytes, endian, 1);
567
568    emit_tree_cpp(
569        out, tree, def, 1, word_type, unit_bytes, endian, variable_length, type_maps,
570    );
571
572    writeln!(out, "}}").unwrap();
573    writeln!(out).unwrap();
574}
575
576/// Emit a word read from the data buffer.
577fn emit_word_read(
578    out: &mut String,
579    var_name: &str,
580    word_type: &str,
581    offset: u32,
582    unit_bytes: u32,
583    endian: &ByteEndian,
584    indent: usize,
585) {
586    let pad = "    ".repeat(indent);
587    match (unit_bytes, endian) {
588        (1, _) => {
589            writeln!(out, "{}{} {} = data[{}];", pad, word_type, var_name, offset).unwrap();
590        }
591        (2, ByteEndian::Big) => {
592            writeln!(
593                out,
594                "{}{} {} = (static_cast<uint16_t>(data[{}]) << 8) | data[{}];",
595                pad,
596                word_type,
597                var_name,
598                offset,
599                offset + 1
600            )
601            .unwrap();
602        }
603        (2, ByteEndian::Little) => {
604            writeln!(
605                out,
606                "{}{} {} = data[{}] | (static_cast<uint16_t>(data[{}]) << 8);",
607                pad,
608                word_type,
609                var_name,
610                offset,
611                offset + 1
612            )
613            .unwrap();
614        }
615        (4, ByteEndian::Big) => {
616            writeln!(
617                out,
618                "{}{} {} = (static_cast<uint32_t>(data[{}]) << 24) | (static_cast<uint32_t>(data[{}]) << 16) | (static_cast<uint32_t>(data[{}]) << 8) | data[{}];",
619                pad, word_type, var_name, offset, offset + 1, offset + 2, offset + 3
620            ).unwrap();
621        }
622        (4, ByteEndian::Little) => {
623            writeln!(
624                out,
625                "{}{} {} = data[{}] | (static_cast<uint32_t>(data[{}]) << 8) | (static_cast<uint32_t>(data[{}]) << 16) | (static_cast<uint32_t>(data[{}]) << 24);",
626                pad, word_type, var_name, offset, offset + 1, offset + 2, offset + 3
627            ).unwrap();
628        }
629        _ => {
630            writeln!(out, "{}// unsupported width", pad).unwrap();
631        }
632    }
633}
634
635/// Unit read expression (inline).
636fn unit_read_expr(unit: u32, _word_type: &str, unit_bytes: u32, endian: &ByteEndian) -> String {
637    if unit == 0 {
638        return "opcode".to_string();
639    }
640    let offset = unit * unit_bytes;
641    match (unit_bytes, endian) {
642        (1, _) => format!("data[{}]", offset),
643        (2, ByteEndian::Big) => format!(
644            "(static_cast<uint16_t>(data[{}]) << 8 | data[{}])",
645            offset,
646            offset + 1
647        ),
648        (2, ByteEndian::Little) => format!(
649            "(data[{}] | static_cast<uint16_t>(data[{}]) << 8)",
650            offset,
651            offset + 1
652        ),
653        (4, ByteEndian::Big) => format!(
654            "(static_cast<uint32_t>(data[{}]) << 24 | static_cast<uint32_t>(data[{}]) << 16 | static_cast<uint32_t>(data[{}]) << 8 | data[{}])",
655            offset, offset + 1, offset + 2, offset + 3
656        ),
657        (4, ByteEndian::Little) => format!(
658            "(data[{}] | static_cast<uint32_t>(data[{}]) << 8 | static_cast<uint32_t>(data[{}]) << 16 | static_cast<uint32_t>(data[{}]) << 24)",
659            offset, offset + 1, offset + 2, offset + 3
660        ),
661        _ => "0".to_string(),
662    }
663}
664
665/// Extract bits expression.
666fn extract_expr(
667    var: &str,
668    ranges: &[BitRange],
669    word_type: &str,
670    unit_bytes: u32,
671    endian: &ByteEndian,
672) -> String {
673    if ranges.is_empty() {
674        return "0".to_string();
675    }
676
677    if ranges.len() == 1 {
678        let range = ranges[0];
679        let source = if range.unit == 0 {
680            var.to_string()
681        } else {
682            unit_read_expr(range.unit, word_type, unit_bytes, endian)
683        };
684        let width = range.width();
685        let shift = range.end;
686        let mask = (1u64 << width) - 1;
687        if shift == 0 {
688            format!("({} & {:#x})", source, mask)
689        } else {
690            format!("(({} >> {}) & {:#x})", source, shift, mask)
691        }
692    } else {
693        let mut parts = Vec::new();
694        let mut accumulated = 0u32;
695        for range in ranges {
696            let source = if range.unit == 0 {
697                var.to_string()
698            } else {
699                unit_read_expr(range.unit, word_type, unit_bytes, endian)
700            };
701            let width = range.width();
702            let shift = range.end;
703            let mask = (1u64 << width) - 1;
704            let extracted = if shift == 0 {
705                format!("({} & {:#x})", source, mask)
706            } else {
707                format!("(({} >> {}) & {:#x})", source, shift, mask)
708            };
709            if accumulated > 0 {
710                parts.push(format!("({} << {})", extracted, accumulated));
711            } else {
712                parts.push(extracted);
713            }
714            accumulated += width;
715        }
716        parts.join(" | ")
717    }
718}
719
720/// Compute mask/value for leaf guard.
721fn leaf_guard(
722    instr: &ValidatedInstruction,
723    word_type: &str,
724    unit_bytes: u32,
725    endian: &ByteEndian,
726) -> Option<String> {
727    let fixed_bits = instr.fixed_bits();
728    if fixed_bits.is_empty() {
729        return None;
730    }
731
732    let mut units_map: HashMap<u32, Vec<(u32, Bit)>> = HashMap::new();
733    for (unit, hw_bit, bit) in fixed_bits {
734        units_map.entry(unit).or_default().push((hw_bit, bit));
735    }
736
737    let mut conditions = Vec::new();
738    for (unit, bits) in &units_map {
739        let (mask, value) = compute_mask_value(bits);
740        if mask != 0 {
741            let source = if *unit == 0 {
742                "opcode".to_string()
743            } else {
744                unit_read_expr(*unit, word_type, unit_bytes, endian)
745            };
746            conditions.push(format!("({} & {:#x}) == {:#x}", source, mask, value));
747        }
748    }
749
750    if conditions.is_empty() {
751        None
752    } else {
753        Some(conditions.join(" && "))
754    }
755}
756
757fn compute_mask_value(fixed_bits: &[(u32, Bit)]) -> (u64, u64) {
758    let mut mask: u64 = 0;
759    let mut value: u64 = 0;
760    for &(bit_pos, bit_val) in fixed_bits {
761        if bit_val == Bit::Wildcard {
762            continue;
763        }
764        mask |= 1u64 << bit_pos;
765        if bit_val == Bit::One {
766            value |= 1u64 << bit_pos;
767        }
768    }
769    (mask, value)
770}
771
772/// Apply transforms to an extraction expression.
773fn apply_transforms(extract: &str, resolved: &ResolvedFieldType, type_maps: &HashMap<String, String>) -> String {
774    let mut expr = extract.to_string();
775
776    for transform in &resolved.transforms {
777        match transform {
778            Transform::SignExtend(n) => {
779                let signed = cpp_signed_type(&resolved.base_type);
780                let bits = type_bits(&resolved.base_type);
781                expr = format!(
782                    "static_cast<{}>(static_cast<{}>(({}) << ({} - {})) >> ({} - {}))",
783                    cpp_type_for(&resolved.base_type), signed, expr, bits, n, bits, n
784                );
785            }
786            Transform::ZeroExtend(_) => {}
787            Transform::ShiftLeft(n) => {
788                expr = format!("(({}) << {})", expr, n);
789            }
790        }
791    }
792
793    // Sub-decoder dispatch
794    if let Some(ref sd_name) = resolved.sub_decoder {
795        let decode_fn = format!("decode_{}", to_snake_case(sd_name));
796        return format!("{}(static_cast<{}>({})).value()", decode_fn, cpp_word_type(type_bits(&resolved.base_type).min(32)), expr);
797    }
798
799    // Type map wrapper (user override)
800    if let Some(alias) = &resolved.alias_name {
801        if let Some(mapped) = type_maps.get(alias) {
802            return format!("static_cast<{}>({})", mapped, expr);
803        }
804    }
805
806    // Display format -> built-in wrapper types
807    match resolved.display_format {
808        Some(DisplayFormat::SignedHex) => {
809            return format!("SignedHex(static_cast<int32_t>({}))", expr);
810        }
811        Some(DisplayFormat::Hex) => {
812            return format!("Hex(static_cast<uint32_t>({}))", expr);
813        }
814        None => {}
815    }
816
817    if resolved.base_type == "bool" {
818        format!("({}) != 0", expr)
819    } else {
820        format!("static_cast<{}>({})", cpp_type_for(&resolved.base_type), expr)
821    }
822}
823
824/// Emit the decision tree as C++ switch/if-else.
825fn emit_tree_cpp(
826    out: &mut String,
827    node: &DecodeNode,
828    def: &ValidatedDef,
829    indent: usize,
830    word_type: &str,
831    unit_bytes: u32,
832    endian: &ByteEndian,
833    variable_length: bool,
834    type_maps: &HashMap<String, String>,
835) {
836    let pad = "    ".repeat(indent);
837    match node {
838        DecodeNode::Leaf { instruction_index } => {
839            let instr = &def.instructions[*instruction_index];
840            if let Some(guard) = leaf_guard(instr, word_type, unit_bytes, endian) {
841                writeln!(out, "{}if ({}) {{", pad, guard).unwrap();
842                emit_return_instruction(out, instr, indent + 1, word_type, unit_bytes, endian, variable_length, type_maps);
843                writeln!(out, "{}}} else {{", pad).unwrap();
844                writeln!(out, "{}    return std::nullopt;", pad).unwrap();
845                writeln!(out, "{}}}", pad).unwrap();
846            } else {
847                emit_return_instruction(out, instr, indent, word_type, unit_bytes, endian, variable_length, type_maps);
848            }
849        }
850        DecodeNode::PriorityLeaves { candidates } => {
851            for (i, &idx) in candidates.iter().enumerate() {
852                let instr = &def.instructions[idx];
853                let guard = leaf_guard(instr, word_type, unit_bytes, endian);
854                if i == 0 {
855                    if let Some(g) = guard {
856                        writeln!(out, "{}if ({}) {{", pad, g).unwrap();
857                        emit_return_instruction(out, instr, indent + 1, word_type, unit_bytes, endian, variable_length, type_maps);
858                    } else {
859                        emit_return_instruction(out, instr, indent, word_type, unit_bytes, endian, variable_length, type_maps);
860                        break;
861                    }
862                } else if i == candidates.len() - 1 {
863                    if let Some(g) = guard {
864                        writeln!(out, "{}}} else if ({}) {{", pad, g).unwrap();
865                        emit_return_instruction(out, instr, indent + 1, word_type, unit_bytes, endian, variable_length, type_maps);
866                        writeln!(out, "{}}} else {{", pad).unwrap();
867                        writeln!(out, "{}    return std::nullopt;", pad).unwrap();
868                        writeln!(out, "{}}}", pad).unwrap();
869                    } else {
870                        writeln!(out, "{}}} else {{", pad).unwrap();
871                        emit_return_instruction(out, instr, indent + 1, word_type, unit_bytes, endian, variable_length, type_maps);
872                        writeln!(out, "{}}}", pad).unwrap();
873                    }
874                } else {
875                    let g = guard.unwrap_or_else(|| "true".to_string());
876                    writeln!(out, "{}}} else if ({}) {{", pad, g).unwrap();
877                    emit_return_instruction(out, instr, indent + 1, word_type, unit_bytes, endian, variable_length, type_maps);
878                }
879            }
880        }
881        DecodeNode::Fail => {
882            writeln!(out, "{}return std::nullopt;", pad).unwrap();
883        }
884        DecodeNode::Branch { range, arms, default } => {
885            let ext = extract_expr("opcode", &[*range], word_type, unit_bytes, endian);
886            writeln!(out, "{}switch ({}) {{", pad, ext).unwrap();
887            for (value, child) in arms {
888                writeln!(out, "{}case {:#x}: {{", pad, value).unwrap();
889                emit_tree_cpp(out, child, def, indent + 1, word_type, unit_bytes, endian, variable_length, type_maps);
890                writeln!(out, "{}    break;", pad).unwrap();
891                writeln!(out, "{}}}", pad).unwrap();
892            }
893            writeln!(out, "{}default: {{", pad).unwrap();
894            emit_tree_cpp(out, default, def, indent + 1, word_type, unit_bytes, endian, variable_length, type_maps);
895            writeln!(out, "{}    break;", pad).unwrap();
896            writeln!(out, "{}}}", pad).unwrap();
897            writeln!(out, "{}}}", pad).unwrap();
898        }
899    }
900}
901
902/// Emit code to return a decoded Instruction.
903fn emit_return_instruction(
904    out: &mut String,
905    instr: &ValidatedInstruction,
906    indent: usize,
907    word_type: &str,
908    unit_bytes: u32,
909    endian: &ByteEndian,
910    variable_length: bool,
911    type_maps: &HashMap<String, String>,
912) {
913    let pad = "    ".repeat(indent);
914    let unit_count = instr.unit_count();
915    let bytes_consumed = unit_count * unit_bytes;
916    let variant = to_pascal_case(&instr.name);
917
918    if variable_length && unit_count > 1 {
919        writeln!(out, "{}if (len < {}) return std::nullopt;", pad, bytes_consumed).unwrap();
920    }
921
922    if instr.resolved_fields.is_empty() {
923        writeln!(
924            out,
925            "{}return Instruction {{ Opcode::{}, {} }};",
926            pad, variant, bytes_consumed
927        )
928        .unwrap();
929    } else {
930        writeln!(out, "{}{{", pad).unwrap();
931        writeln!(out, "{}    Instruction insn{{}};", pad).unwrap();
932        writeln!(out, "{}    insn.opcode = Opcode::{};", pad, variant).unwrap();
933        writeln!(out, "{}    insn.size = {};", pad, bytes_consumed).unwrap();
934        for field in &instr.resolved_fields {
935            let ext = extract_expr("opcode", &field.ranges, word_type, unit_bytes, endian);
936            let expr = apply_transforms(&ext, &field.resolved_type, type_maps);
937            writeln!(out, "{}    insn.{}.{} = {};", pad, instr.name, field.name, expr).unwrap();
938        }
939        writeln!(out, "{}    return insn;", pad).unwrap();
940        writeln!(out, "{}}}", pad).unwrap();
941    }
942}
943
944/// Emit the format/disassemble function using std::format.
945fn emit_format_function(
946    out: &mut String,
947    def: &ValidatedDef,
948    _type_maps: &HashMap<String, String>,
949    _opts: &CppOptions,
950) {
951    writeln!(out, "inline std::string format(const Instruction& insn) {{").unwrap();
952    writeln!(out, "    switch (insn.opcode) {{").unwrap();
953
954    for instr in &def.instructions {
955        let variant = to_pascal_case(&instr.name);
956        writeln!(out, "    case Opcode::{}: {{", variant).unwrap();
957
958        if instr.format_lines.is_empty() {
959            writeln!(out, "        return \"{}\";", instr.name).unwrap();
960        } else {
961            emit_format_lines_cpp(out, instr, 2);
962        }
963
964        writeln!(out, "    }}").unwrap();
965    }
966
967    writeln!(out, "    default: return \"???\";").unwrap();
968    writeln!(out, "    }}").unwrap();
969    writeln!(out, "}}").unwrap();
970    writeln!(out).unwrap();
971}
972
973/// Emit format lines for a single instruction using std::format.
974fn emit_format_lines_cpp(
975    out: &mut String,
976    instr: &ValidatedInstruction,
977    indent: usize,
978) {
979    let pad = "    ".repeat(indent);
980
981    if instr.format_lines.len() == 1 && instr.format_lines[0].guard.is_none() {
982        let fl = &instr.format_lines[0];
983        emit_std_format_call(out, &fl.pieces, instr, &pad);
984        return;
985    }
986
987    for (i, fl) in instr.format_lines.iter().enumerate() {
988        if let Some(guard) = &fl.guard {
989            let guard_code = guard_to_cpp(guard, instr);
990            if i == 0 {
991                writeln!(out, "{}if ({}) {{", pad, guard_code).unwrap();
992            } else {
993                writeln!(out, "{}}} else if ({}) {{", pad, guard_code).unwrap();
994            }
995            emit_std_format_call(out, &fl.pieces, instr, &format!("{}    ", pad));
996        } else {
997            if i > 0 {
998                writeln!(out, "{}}} else {{", pad).unwrap();
999            }
1000            emit_std_format_call(out, &fl.pieces, instr, &format!("{}    ", pad));
1001        }
1002    }
1003
1004    if instr.format_lines.len() > 1
1005        || instr.format_lines.first().map_or(false, |fl| fl.guard.is_some())
1006    {
1007        writeln!(out, "{}}}", pad).unwrap();
1008    }
1009}
1010
1011/// Emit a return std::format(...) call for format pieces.
1012/// Fields are passed as arguments; user types participate via std::formatter specializations.
1013fn emit_std_format_call(
1014    out: &mut String,
1015    pieces: &[FormatPiece],
1016    instr: &ValidatedInstruction,
1017    pad: &str,
1018) {
1019    let mut fmt_str = String::new();
1020    let mut args: Vec<String> = Vec::new();
1021
1022    for piece in pieces {
1023        match piece {
1024            FormatPiece::Literal(lit) => {
1025                // Escape { and } for std::format
1026                for ch in lit.chars() {
1027                    match ch {
1028                        '{' => fmt_str.push_str("{{"),
1029                        '}' => fmt_str.push_str("}}"),
1030                        _ => fmt_str.push(ch),
1031                    }
1032                }
1033            }
1034            FormatPiece::FieldRef { expr, spec } => {
1035                let cpp_expr = expr_to_cpp(expr, instr);
1036                // Build std::format placeholder
1037                if let Some(spec) = spec {
1038                    fmt_str.push_str(&format!("{{:{}}}", translate_std_format_spec(spec)));
1039                } else {
1040                    fmt_str.push_str("{}");
1041                }
1042                args.push(cpp_expr);
1043            }
1044        }
1045    }
1046
1047    if args.is_empty() {
1048        writeln!(out, "{}return \"{}\";", pad, fmt_str).unwrap();
1049    } else {
1050        writeln!(
1051            out,
1052            "{}return std::format(\"{}\", {});",
1053            pad,
1054            fmt_str,
1055            args.join(", ")
1056        )
1057        .unwrap();
1058    }
1059}
1060
/// Map a chipi/Rust format spec onto the equivalent `std::format` spec.
///
/// Currently the identity mapping: specs such as `04x`, `#06x` and `#x` are
/// passed through unchanged as an owned string.
fn translate_std_format_spec(spec: &str) -> String {
    String::from(spec)
}
1067
1068
1069/// Convert a FormatExpr to a C++ expression string.
1070fn expr_to_cpp(expr: &FormatExpr, instr: &ValidatedInstruction) -> String {
1071    match expr {
1072        FormatExpr::Field(name) => {
1073            format!("insn.{}.{}", instr.name, name)
1074        }
1075        FormatExpr::Ternary {
1076            field,
1077            if_nonzero,
1078            if_zero,
1079        } => {
1080            let else_val = if_zero.as_deref().unwrap_or("");
1081            format!(
1082                "(insn.{}.{} ? \"{}\" : \"{}\")",
1083                instr.name, field, if_nonzero, else_val
1084            )
1085        }
1086        FormatExpr::Arithmetic { left, op, right } => {
1087            let l = expr_to_cpp(left, instr);
1088            let r = expr_to_cpp(right, instr);
1089            let op_str = match op {
1090                ArithOp::Add => "+",
1091                ArithOp::Sub => "-",
1092                ArithOp::Mul => "*",
1093                ArithOp::Div => "/",
1094                ArithOp::Mod => "%",
1095            };
1096            format!("({} {} {})", l, op_str, r)
1097        }
1098        FormatExpr::IntLiteral(val) => format!("{}", val),
1099        FormatExpr::MapCall { map_name, args } => {
1100            let arg_strs: Vec<String> = args.iter().map(|a| expr_to_cpp(a, instr)).collect();
1101            format!("{}({})", map_name, arg_strs.join(", "))
1102        }
1103        FormatExpr::BuiltinCall { func, args } => {
1104            let arg_strs: Vec<String> = args.iter().map(|a| expr_to_cpp(a, instr)).collect();
1105            match func {
1106                BuiltinFunc::RotateRight => {
1107                    format!(
1108                        "((static_cast<uint32_t>({}) >> {}) | (static_cast<uint32_t>({}) << (32 - {})))",
1109                        arg_strs.first().map(|s| s.as_str()).unwrap_or("0"),
1110                        arg_strs.get(1).map(|s| s.as_str()).unwrap_or("0"),
1111                        arg_strs.first().map(|s| s.as_str()).unwrap_or("0"),
1112                        arg_strs.get(1).map(|s| s.as_str()).unwrap_or("0"),
1113                    )
1114                }
1115                BuiltinFunc::RotateLeft => {
1116                    format!(
1117                        "((static_cast<uint32_t>({}) << {}) | (static_cast<uint32_t>({}) >> (32 - {})))",
1118                        arg_strs.first().map(|s| s.as_str()).unwrap_or("0"),
1119                        arg_strs.get(1).map(|s| s.as_str()).unwrap_or("0"),
1120                        arg_strs.first().map(|s| s.as_str()).unwrap_or("0"),
1121                        arg_strs.get(1).map(|s| s.as_str()).unwrap_or("0"),
1122                    )
1123                }
1124            }
1125        }
1126        FormatExpr::SubDecoderAccess { field, fragment } => {
1127            format!("insn.{}.{}.{}", instr.name, field, fragment)
1128        }
1129    }
1130}
1131
1132/// Convert a guard condition to C++ code.
1133fn guard_to_cpp(guard: &Guard, instr: &ValidatedInstruction) -> String {
1134    let conditions: Vec<String> = guard
1135        .conditions
1136        .iter()
1137        .map(|cond| {
1138            let left = guard_operand_cpp(&cond.left, instr);
1139            let right = guard_operand_cpp(&cond.right, instr);
1140            let op = match cond.op {
1141                CompareOp::Eq => "==",
1142                CompareOp::Ne => "!=",
1143                CompareOp::Lt => "<",
1144                CompareOp::Le => "<=",
1145                CompareOp::Gt => ">",
1146                CompareOp::Ge => ">=",
1147            };
1148            format!("{} {} {}", left, op, right)
1149        })
1150        .collect();
1151    conditions.join(" && ")
1152}
1153
1154fn guard_operand_cpp(operand: &GuardOperand, instr: &ValidatedInstruction) -> String {
1155    match operand {
1156        GuardOperand::Field(name) => format!("insn.{}.{}", instr.name, name),
1157        GuardOperand::Literal(val) => format!("{}", val),
1158        GuardOperand::Expr { left, op, right } => {
1159            let l = guard_operand_cpp(left, instr);
1160            let r = guard_operand_cpp(right, instr);
1161            let op_str = match op {
1162                ArithOp::Add => "+",
1163                ArithOp::Sub => "-",
1164                ArithOp::Mul => "*",
1165                ArithOp::Div => "/",
1166                ArithOp::Mod => "%",
1167            };
1168            format!("({} {} {})", l, op_str, r)
1169        }
1170    }
1171}