use std::collections::HashMap;
use std::fmt::Write;
use crate::tree::DecodeNode;
use crate::types::*;
/// Returns `true` when at least one instruction spans more than one unit,
/// meaning the generated decoder must check buffer length per instruction.
fn needs_variable_length_decode(def: &ValidatedDef) -> bool {
    !def.instructions.iter().all(|instr| instr.unit_count() <= 1)
}
/// Generates the complete Rust source for the decoder described by `def`.
///
/// The emitted file contains, in order: a header comment and `use` lines
/// (plus any user imports), the `SignedHex` display helpers, the map
/// functions, the `<Name>Instruction` enum, the `<Name>Format` trait with
/// default formatting methods, an inherent `impl` with `decode`, `write_asm`
/// and `display`, the `DisplayWith` adapter and its `Display` impl, the
/// `Default<Name>Format` struct, and a `Display` impl for the enum that uses
/// that default formatter.
///
/// `tree` is the decision tree whose code becomes the body of `decode`.
pub fn generate_code(def: &ValidatedDef, tree: &DecodeNode) -> String {
let mut out = String::new();
writeln!(out, "// Auto-generated by chipi. Do not edit.").unwrap();
writeln!(out).unwrap();
writeln!(out, "use std::fmt;").unwrap();
writeln!(out, "use std::marker::PhantomData;").unwrap();
writeln!(out).unwrap();
// User-declared imports, followed by a blank line only if there were any.
for imp in &def.imports {
writeln!(out, "use {};", imp.path).unwrap();
}
if !def.imports.is_empty() {
writeln!(out).unwrap();
}
// The decoder reads the buffer in units of `width / 8` bytes.
let word_type = word_type_for_width(def.config.width);
let unit_bytes = def.config.width / 8;
let variable_length = needs_variable_length_decode(def);
let enum_name = format!("{}Instruction", def.config.name);
let trait_name = format!("{}Format", def.config.name);
let default_struct = format!("Default{}Format", def.config.name);
let display_with = "DisplayWith";
// Suffix for the generated `from_be_bytes` / `from_le_bytes` calls.
let endian_suffix = match def.config.endian {
ByteEndian::Big => "be",
ByteEndian::Little => "le",
};
generate_display_helpers(&mut out, def);
generate_map_functions(&mut out, def);
// The instruction enum: unit variants for field-less instructions,
// struct variants otherwise.
writeln!(out, "#[derive(Debug, Clone, Copy, PartialEq, Eq)]").unwrap();
writeln!(out, "pub enum {} {{", enum_name).unwrap();
for instr in &def.instructions {
let variant_name = to_pascal_case(&instr.name);
if instr.resolved_fields.is_empty() {
writeln!(out, " {},", variant_name).unwrap();
} else {
let fields: Vec<String> = instr
.resolved_fields
.iter()
.map(|f| {
let rust_type = field_rust_type(f);
format!("{}: {}", f.name, rust_type)
})
.collect();
writeln!(out, " {} {{ {} }},", variant_name, fields.join(", ")).unwrap();
}
}
writeln!(out, "}}").unwrap();
writeln!(out).unwrap();
generate_format_trait(&mut out, def, &enum_name, &trait_name);
writeln!(out, "impl {} {{", enum_name).unwrap();
// `decode` returns the instruction and the number of bytes it consumed,
// or `None` if the buffer is shorter than one unit or nothing matches.
writeln!(out, " #[inline]").unwrap();
writeln!(
out,
" pub fn decode(data: &[u8]) -> Option<(Self, usize)> {{"
)
.unwrap();
writeln!(out, " if data.len() < {} {{ return None; }}", unit_bytes).unwrap();
// The first unit is always read into `opcode`; later units are read on demand.
writeln!(
out,
" let opcode = {}::from_{}_bytes(data[0..{}].try_into().unwrap());",
word_type, endian_suffix, unit_bytes
)
.unwrap();
emit_tree(&mut out, tree, def, &enum_name, 2, variable_length, &word_type);
writeln!(out, " }}").unwrap();
writeln!(out).unwrap();
generate_write_asm(&mut out, def, &enum_name, &trait_name);
writeln!(out).unwrap();
// `display::<F>()` returns an adapter that formats with a custom formatter.
writeln!(out, " #[allow(dead_code)]").unwrap();
writeln!(
out,
" pub fn display<F: {}>(&self) -> {}<'_, F> {{",
trait_name, display_with
)
.unwrap();
writeln!(
out,
" {} {{ insn: self, _phantom: PhantomData }}",
display_with
)
.unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, "}}").unwrap();
writeln!(out).unwrap();
// The adapter struct returned by `display`.
writeln!(out, "#[allow(dead_code)]").unwrap();
writeln!(
out,
"pub struct {}<'a, F: {}> {{",
display_with, trait_name
)
.unwrap();
writeln!(out, " insn: &'a {},", enum_name).unwrap();
writeln!(out, " _phantom: PhantomData<F>,").unwrap();
writeln!(out, "}}").unwrap();
writeln!(out).unwrap();
writeln!(
out,
"impl<F: {}> fmt::Display for {}<'_, F> {{",
trait_name, display_with
)
.unwrap();
writeln!(
out,
" fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {{"
)
.unwrap();
writeln!(out, " self.insn.write_asm::<F>(f)").unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, "}}").unwrap();
writeln!(out).unwrap();
// A formatter that uses only the trait's default methods, plus the enum's
// plain `Display` impl built on it.
writeln!(out, "pub struct {};", default_struct).unwrap();
writeln!(out, "impl {} for {} {{}}", trait_name, default_struct).unwrap();
writeln!(out).unwrap();
writeln!(out, "impl fmt::Display for {} {{", enum_name).unwrap();
writeln!(
out,
" fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {{"
)
.unwrap();
writeln!(out, " self.write_asm::<{}>(f)", default_struct).unwrap();
writeln!(out, " }}").unwrap();
writeln!(out, "}}").unwrap();
out
}
/// Emits decode logic for `node` as a tail expression of type
/// `Option<(Instruction, usize)>`, indented by `indent`.
///
/// `Leaf` nodes emit their (optionally guarded) constructor. `PriorityLeaves`
/// emit an `if / else if / else` chain that tries candidates in order.
/// `Branch` nodes emit a `match` on a bit range of `opcode`, with one arm per
/// value plus a `_` default arm rendered by `emit_arm`.
fn emit_tree(
    out: &mut String,
    node: &DecodeNode,
    def: &ValidatedDef,
    enum_name: &str,
    indent: usize,
    variable_length: bool,
    word_type: &str,
) {
    let unit_bytes = def.config.width / 8;
    let endian_suffix = match def.config.endian {
        ByteEndian::Big => "be",
        ByteEndian::Little => "le",
    };
    let pad = " ".repeat(indent);
    match node {
        DecodeNode::Leaf { instruction_index } => {
            let instr = &def.instructions[*instruction_index];
            if let Some(guard) = leaf_guard(instr, word_type, unit_bytes, endian_suffix) {
                writeln!(out, "{}if {} {{", pad, guard).unwrap();
                emit_some(out, instr, enum_name, &format!("{} ", pad), variable_length, word_type, unit_bytes, endian_suffix);
                writeln!(out, "{}}} else {{", pad).unwrap();
                writeln!(out, "{} None", pad).unwrap();
                writeln!(out, "{}}}", pad).unwrap();
            } else {
                emit_some(out, instr, enum_name, &pad, variable_length, word_type, unit_bytes, endian_suffix);
            }
        }
        DecodeNode::PriorityLeaves { candidates } => {
            let inner_pad = format!("{} ", pad);
            let last = candidates.len().saturating_sub(1);
            // True while an `if` chain has been opened but not yet closed.
            // Only the "single guarded candidate" case leaves it open after the loop.
            let mut chain_open = false;
            for (i, &idx) in candidates.iter().enumerate() {
                let instr = &def.instructions[idx];
                let guard = leaf_guard(instr, word_type, unit_bytes, endian_suffix);
                if i == 0 {
                    match guard {
                        Some(guard_expr) => {
                            writeln!(out, "{}if {} {{", pad, guard_expr).unwrap();
                            emit_some(out, instr, enum_name, &inner_pad, variable_length, word_type, unit_bytes, endian_suffix);
                            chain_open = true;
                        }
                        None => {
                            // An unguarded first candidate always matches; the rest are unreachable.
                            emit_some(out, instr, enum_name, &pad, variable_length, word_type, unit_bytes, endian_suffix);
                            break;
                        }
                    }
                } else if i == last {
                    writeln!(out, "{}}} else {{", pad).unwrap();
                    if let Some(guard_expr) = guard {
                        writeln!(out, "{} if {} {{", pad, guard_expr).unwrap();
                        emit_some(out, instr, enum_name, &inner_pad, variable_length, word_type, unit_bytes, endian_suffix);
                        writeln!(out, "{} }} else {{", pad).unwrap();
                        writeln!(out, "{} None", pad).unwrap();
                        writeln!(out, "{} }}", pad).unwrap();
                    } else {
                        emit_some(out, instr, enum_name, &inner_pad, variable_length, word_type, unit_bytes, endian_suffix);
                    }
                    writeln!(out, "{}}}", pad).unwrap();
                    chain_open = false;
                } else {
                    writeln!(out, "{}}} else if {} {{", pad, guard.unwrap_or_else(|| "true".to_string())).unwrap();
                    emit_some(out, instr, enum_name, &inner_pad, variable_length, word_type, unit_bytes, endian_suffix);
                }
            }
            if candidates.is_empty() {
                // No candidates: nothing can match.
                writeln!(out, "{}None", pad).unwrap();
            } else if chain_open {
                // A lone guarded candidate: close the chain with a `None` fallback.
                writeln!(out, "{}}} else {{", pad).unwrap();
                writeln!(out, "{} None", pad).unwrap();
                writeln!(out, "{}}}", pad).unwrap();
            }
        }
        DecodeNode::Fail => {
            writeln!(out, "{}None", pad).unwrap();
        }
        DecodeNode::Branch {
            range,
            arms,
            default,
        } => {
            let extract_expr = extract_expression("opcode", &[*range], word_type, unit_bytes, endian_suffix);
            writeln!(out, "{}match {} {{", pad, extract_expr).unwrap();
            for (value, child) in arms {
                emit_arm(out, child, def, enum_name, indent + 1, &format!("{:#x}", value), variable_length, word_type);
            }
            emit_arm(out, default, def, enum_name, indent + 1, "_", variable_length, word_type);
            writeln!(out, "{}}}", pad).unwrap();
        }
    }
}
/// Emits one arm (`pattern => ...`) of a generated `match` for `node`.
///
/// `pattern` is either a hex literal for a specific branch value or `"_"`
/// for the default arm. A guarded leaf on a literal arm is emitted as
/// `pattern if guard => ...`. When that guard fails, Rust's match semantics
/// continue with the remaining arms, ultimately the `_` default. A guarded
/// `_` arm is followed by `_ => None` so the match stays exhaustive.
fn emit_arm(
    out: &mut String,
    node: &DecodeNode,
    def: &ValidatedDef,
    enum_name: &str,
    indent: usize,
    pattern: &str,
    variable_length: bool,
    word_type: &str,
) {
    let unit_bytes = def.config.width / 8;
    let endian_suffix = match def.config.endian {
        ByteEndian::Big => "be",
        ByteEndian::Little => "le",
    };
    let pad = " ".repeat(indent);
    match node {
        DecodeNode::Fail => {
            writeln!(out, "{}{} => None,", pad, pattern).unwrap();
        }
        DecodeNode::Leaf { instruction_index } => {
            let instr = &def.instructions[*instruction_index];
            match leaf_guard(instr, word_type, unit_bytes, endian_suffix) {
                Some(guard) => {
                    write!(out, "{}{} if {} => ", pad, pattern, guard).unwrap();
                    emit_some_inline(out, instr, enum_name, indent, variable_length, word_type, unit_bytes, endian_suffix);
                    if pattern == "_" {
                        // A guarded wildcard does not make the match exhaustive.
                        writeln!(out, "{}{} => None,", pad, pattern).unwrap();
                    }
                }
                None => {
                    write!(out, "{}{} => ", pad, pattern).unwrap();
                    emit_some_inline(out, instr, enum_name, indent, variable_length, word_type, unit_bytes, endian_suffix);
                }
            }
        }
        DecodeNode::PriorityLeaves { candidates } => {
            writeln!(out, "{}{} => {{", pad, pattern).unwrap();
            let inner_pad = " ".repeat(indent + 1);
            let body_pad = format!("{} ", inner_pad);
            let last = candidates.len().saturating_sub(1);
            // True while an `if` chain has been opened but not yet closed.
            let mut chain_open = false;
            for (i, &idx) in candidates.iter().enumerate() {
                let instr = &def.instructions[idx];
                let guard = leaf_guard(instr, word_type, unit_bytes, endian_suffix);
                if i == 0 {
                    match guard {
                        Some(guard_expr) => {
                            writeln!(out, "{}if {} {{", inner_pad, guard_expr).unwrap();
                            emit_some(out, instr, enum_name, &body_pad, variable_length, word_type, unit_bytes, endian_suffix);
                            chain_open = true;
                        }
                        None => {
                            // An unguarded first candidate always matches; the rest are unreachable.
                            emit_some(out, instr, enum_name, &inner_pad, variable_length, word_type, unit_bytes, endian_suffix);
                            break;
                        }
                    }
                } else if i == last {
                    writeln!(out, "{}}} else {{", inner_pad).unwrap();
                    if let Some(guard_expr) = guard {
                        writeln!(out, "{} if {} {{", inner_pad, guard_expr).unwrap();
                        emit_some(out, instr, enum_name, &body_pad, variable_length, word_type, unit_bytes, endian_suffix);
                        writeln!(out, "{} }} else {{", inner_pad).unwrap();
                        writeln!(out, "{} None", inner_pad).unwrap();
                        writeln!(out, "{} }}", inner_pad).unwrap();
                    } else {
                        emit_some(out, instr, enum_name, &body_pad, variable_length, word_type, unit_bytes, endian_suffix);
                    }
                    writeln!(out, "{}}}", inner_pad).unwrap();
                    chain_open = false;
                } else {
                    writeln!(out, "{}}} else if {} {{", inner_pad, guard.unwrap_or_else(|| "true".to_string())).unwrap();
                    emit_some(out, instr, enum_name, &body_pad, variable_length, word_type, unit_bytes, endian_suffix);
                }
            }
            if candidates.is_empty() {
                // No candidates: the arm can only fail.
                writeln!(out, "{}None", inner_pad).unwrap();
            } else if chain_open {
                // A lone guarded candidate: close the chain with a `None` fallback.
                writeln!(out, "{}}} else {{", inner_pad).unwrap();
                writeln!(out, "{} None", inner_pad).unwrap();
                writeln!(out, "{}}}", inner_pad).unwrap();
            }
            writeln!(out, "{}}}", pad).unwrap();
        }
        DecodeNode::Branch {
            range,
            arms,
            default,
        } => {
            writeln!(out, "{}{} => {{", pad, pattern).unwrap();
            let extract_expr = extract_expression("opcode", &[*range], word_type, unit_bytes, endian_suffix);
            let inner_pad = " ".repeat(indent + 1);
            writeln!(out, "{}match {} {{", inner_pad, extract_expr).unwrap();
            for (value, child) in arms {
                emit_arm(out, child, def, enum_name, indent + 2, &format!("{:#x}", value), variable_length, word_type);
            }
            emit_arm(out, default, def, enum_name, indent + 2, "_", variable_length, word_type);
            writeln!(out, "{}}}", inner_pad).unwrap();
            writeln!(out, "{}}}", pad).unwrap();
        }
    }
}
/// Builds the boolean guard that checks every fixed (non-wildcard) bit of
/// `instr`, or `None` if the instruction has no effective fixed bits.
///
/// Each unit with fixed bits contributes one `word & mask == value` check,
/// joined with `&&`. Units are emitted in ascending order, so the generated
/// code is identical from run to run. If any checked unit lies beyond the
/// first, a `data.len() >= N` test is placed first. `&&` short-circuits, so
/// a truncated buffer makes the guard false instead of panicking on an
/// out-of-bounds slice.
fn leaf_guard(instr: &ValidatedInstruction, word_type: &str, unit_bytes: u32, endian_suffix: &str) -> Option<String> {
    let fixed_bits = instr.fixed_bits();
    if fixed_bits.is_empty() {
        return None;
    }
    // BTreeMap (not HashMap): its iteration order is deterministic, so the
    // emitted condition order does not change between generator runs.
    let mut units_map: std::collections::BTreeMap<u32, Vec<(u32, Bit)>> = std::collections::BTreeMap::new();
    for (unit, hw_bit, bit) in fixed_bits {
        units_map.entry(unit).or_default().push((hw_bit, bit));
    }
    let mut conditions = Vec::new();
    let mut max_unit = 0u32;
    for (unit, bits) in &units_map {
        let (mask, value) = compute_mask_value(bits);
        if mask != 0 {
            max_unit = max_unit.max(*unit);
            let source = unit_read_expr(*unit, word_type, unit_bytes, endian_suffix);
            conditions.push(format!("{} & {:#x} == {:#x}", source, mask, value));
        }
    }
    if conditions.is_empty() {
        return None;
    }
    if max_unit > 0 {
        // Unit 0 (`opcode`) is already length-checked by `decode`; later units are not.
        conditions.insert(0, format!("data.len() >= {}", (max_unit + 1) * unit_bytes));
    }
    Some(conditions.join(" && "))
}
/// Writes a single-line `Some((Variant { .. }, bytes))` expression followed
/// by `,\n`, for use as the right-hand side of a match arm.
///
/// For multi-unit instructions in a variable-length decoder, the value is
/// wrapped in `if data.len() >= N { .. } else { None }`. `_indent` is unused
/// because the expression is written inline after the arm pattern.
fn emit_some_inline(
    out: &mut String,
    instr: &ValidatedInstruction,
    enum_name: &str,
    _indent: usize,
    variable_length: bool,
    word_type: &str,
    unit_bytes: u32,
    endian_suffix: &str,
) {
    let unit_count = instr.unit_count();
    let consumed = unit_count * unit_bytes;
    let variant_path = format!("{}::{}", enum_name, to_pascal_case(&instr.name));

    let constructor = if instr.resolved_fields.is_empty() {
        variant_path
    } else {
        let initializers: Vec<String> = instr
            .resolved_fields
            .iter()
            .map(|field| {
                let raw = extract_expression("opcode", &field.ranges, word_type, unit_bytes, endian_suffix);
                format!("{}: {}", field.name, apply_transforms(&raw, &field.resolved_type))
            })
            .collect();
        format!("{} {{ {} }}", variant_path, initializers.join(", "))
    };
    let value = format!("Some(({}, {}))", constructor, consumed);

    if variable_length && unit_count > 1 {
        writeln!(out, "if data.len() >= {} {{ {} }} else {{ None }},", consumed, value).unwrap();
    } else {
        writeln!(out, "{},", value).unwrap();
    }
}
/// Emits the multi-line `Some((..))` result for `instr` at indentation `pad`.
///
/// In a variable-length decoder, a multi-unit instruction is wrapped in a
/// `data.len()` check that yields `None` when the buffer is too short.
fn emit_some(
    out: &mut String,
    instr: &ValidatedInstruction,
    enum_name: &str,
    pad: &str,
    variable_length: bool,
    word_type: &str,
    unit_bytes: u32,
    endian_suffix: &str,
) {
    let units = instr.unit_count();
    let consumed = units * unit_bytes;
    let needs_length_check = variable_length && units > 1;
    if !needs_length_check {
        emit_some_inner(out, instr, enum_name, pad, word_type, unit_bytes, endian_suffix, consumed);
        return;
    }
    writeln!(out, "{}if data.len() >= {} {{", pad, consumed).unwrap();
    let nested = format!("{} ", pad);
    emit_some_inner(out, instr, enum_name, &nested, word_type, unit_bytes, endian_suffix, consumed);
    writeln!(out, "{}}} else {{", pad).unwrap();
    writeln!(out, "{} None", pad).unwrap();
    writeln!(out, "{}}}", pad).unwrap();
}
/// Writes `Some((Enum::Variant, bytes))` or, for instructions with fields,
/// a multi-line struct-variant constructor with one line per field.
fn emit_some_inner(
    out: &mut String,
    instr: &ValidatedInstruction,
    enum_name: &str,
    pad: &str,
    word_type: &str,
    unit_bytes: u32,
    endian_suffix: &str,
    bytes_consumed: u32,
) {
    let variant = to_pascal_case(&instr.name);
    let fields = &instr.resolved_fields;
    if fields.is_empty() {
        writeln!(out, "{}Some(({}::{}, {}))", pad, enum_name, variant, bytes_consumed).unwrap();
        return;
    }
    writeln!(out, "{}Some(({}::{} {{", pad, enum_name, variant).unwrap();
    for field in fields {
        let raw = extract_expression("opcode", &field.ranges, word_type, unit_bytes, endian_suffix);
        let value = apply_transforms(&raw, &field.resolved_type);
        writeln!(out, "{} {}: {},", pad, field.name, value).unwrap();
    }
    writeln!(out, "{}}}, {}))", pad, bytes_consumed).unwrap();
}
/// Folds `(bit_position, bit)` pairs into a `(mask, value)` pair.
///
/// Every non-wildcard bit sets its position in `mask`. Positions whose bit
/// is `Bit::One` are also set in `value`. Wildcards are ignored.
fn compute_mask_value(fixed_bits: &[(u32, Bit)]) -> (u64, u64) {
    fixed_bits
        .iter()
        .filter(|&&(_, bit)| bit != Bit::Wildcard)
        .fold((0u64, 0u64), |(mask, value), &(pos, bit)| {
            let flag = 1u64 << pos;
            let value = if bit == Bit::One { value | flag } else { value };
            (mask | flag, value)
        })
}
/// Returns the generated-code expression that reads unit `unit` of the input.
///
/// Unit 0 is already held in the `opcode` local. Later units are decoded
/// from the `data` slice at byte offset `unit * unit_bytes`.
fn unit_read_expr(unit: u32, word_type: &str, unit_bytes: u32, endian_suffix: &str) -> String {
    match unit {
        0 => "opcode".to_string(),
        n => {
            let first_byte = n * unit_bytes;
            format!(
                "{}::from_{}_bytes(data[{}..{}].try_into().unwrap())",
                word_type,
                endian_suffix,
                first_byte,
                first_byte + unit_bytes
            )
        }
    }
}
/// Builds the generated-code expression that extracts the bits in `ranges`.
///
/// A single range yields `src & mask` or `(src >> shift) & mask` in the word
/// type. Multiple ranges are concatenated: the first range occupies the low
/// bits and each later range is shifted left by the total width before it.
/// Each multi-range part is widened to `u64` before shifting. Shifting in a
/// narrow word type would overflow (e.g. a `u8` part shifted by 8, or a
/// `u16` part shifted by 16). Callers cast the result to the field type.
/// Unit 0 reads from `var`; other units read from `data` via
/// `unit_read_expr`. An empty `ranges` yields `"0"`.
fn extract_expression(var: &str, ranges: &[BitRange], word_type: &str, unit_bytes: u32, endian_suffix: &str) -> String {
    // Mask with the low `width` bits set; avoids the `1 << 64` overflow.
    fn low_mask(width: u32) -> u64 {
        if width >= 64 {
            u64::MAX
        } else {
            (1u64 << width) - 1
        }
    }
    let source_for = |range: &BitRange| {
        if range.unit == 0 {
            var.to_string()
        } else {
            unit_read_expr(range.unit, word_type, unit_bytes, endian_suffix)
        }
    };
    match ranges {
        [] => "0".to_string(),
        [range] => {
            let source = source_for(range);
            let mask = low_mask(range.width());
            let shift = range.end;
            if shift == 0 {
                format!("{} & {:#x}", source, mask)
            } else {
                format!("({} >> {}) & {:#x}", source, shift, mask)
            }
        }
        _ => {
            let mut parts = Vec::with_capacity(ranges.len());
            let mut accumulated_width = 0u32;
            for range in ranges {
                let source = source_for(range);
                let width = range.width();
                let shift = range.end;
                let mask = low_mask(width);
                let extracted = if shift == 0 {
                    format!("({} & {:#x})", source, mask)
                } else {
                    format!("(({} >> {}) & {:#x})", source, shift, mask)
                };
                let widened = format!("({} as u64)", extracted);
                if accumulated_width > 0 {
                    parts.push(format!("({} << {})", widened, accumulated_width));
                } else {
                    parts.push(widened);
                }
                accumulated_width += width;
            }
            parts.join(" | ")
        }
    }
}
/// Applies the field's transforms, in order, to a raw extraction expression.
/// The result is then converted to the field's Rust type.
///
/// `SignExtend(n)` casts to the signed counterpart of the base type, then
/// shifts left and arithmetically right by `bits - n`. `ShiftLeft(n)`
/// appends `<< n`. `ZeroExtend` emits nothing. The final conversion is
/// `Wrapper::from(.. as base)` for wrapper types, `!= 0` for `bool`, and a
/// plain `as base` cast otherwise.
fn apply_transforms(extract_expr: &str, resolved: &ResolvedFieldType) -> String {
    let transformed = resolved
        .transforms
        .iter()
        .fold(extract_expr.to_string(), |acc, transform| match transform {
            Transform::SignExtend(n) => {
                let bits = type_bits(&resolved.base_type);
                let signed = signed_type_for(&resolved.base_type);
                format!(
                    "(((({}) as {}) << ({} - {})) >> ({} - {}))",
                    acc, signed, bits, n, bits, n
                )
            }
            Transform::ZeroExtend(_) => acc,
            Transform::ShiftLeft(n) => format!("(({}) << {})", acc, n),
        });
    match &resolved.wrapper_type {
        Some(wrapper) => format!("{}::from(({}) as {})", wrapper, transformed, resolved.base_type),
        None if resolved.base_type == "bool" => format!("({}) != 0", transformed),
        None => format!("({}) as {}", transformed, resolved.base_type),
    }
}
/// Emits the `SignedHex<T>` wrapper and its `Display` impls when any field
/// uses a hex display format.
///
/// `SignedHex` formats pull in impls for `i8`/`i16`/`i32`, which render
/// negative values as `-0x..`. `Hex` formats pull in impls for
/// `u8`/`u16`/`u32`. In both, zero is rendered as `0`.
fn generate_display_helpers(out: &mut String, def: &ValidatedDef) {
    // Header of one impl, up to and including the shared zero case.
    fn open_impl(out: &mut String, ty: &str) {
        writeln!(out, "impl fmt::Display for SignedHex<{}> {{", ty).unwrap();
        writeln!(out, " fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {{").unwrap();
        writeln!(out, " if self.0 == 0 {{ write!(f, \"0\") }}").unwrap();
    }
    // Closes the `fmt` method and the impl, then adds a blank line.
    fn close_impl(out: &mut String) {
        writeln!(out, " }}").unwrap();
        writeln!(out, "}}").unwrap();
        writeln!(out).unwrap();
    }

    let mut need_signed_hex = false;
    let mut need_hex = false;
    let formats = def
        .instructions
        .iter()
        .flat_map(|instr| instr.resolved_fields.iter())
        .filter_map(|field| field.resolved_type.display_format);
    for format in formats {
        match format {
            DisplayFormat::SignedHex => need_signed_hex = true,
            DisplayFormat::Hex => need_hex = true,
        }
    }
    if !(need_signed_hex || need_hex) {
        return;
    }

    writeln!(out, "struct SignedHex<T>(T);").unwrap();
    writeln!(out).unwrap();
    if need_signed_hex {
        for ty in &["i8", "i16", "i32"] {
            open_impl(out, ty);
            writeln!(out, " else if self.0 > 0 {{ write!(f, \"0x{{:X}}\", self.0) }}").unwrap();
            writeln!(out, " else {{ write!(f, \"-0x{{:X}}\", (self.0 as i64).wrapping_neg()) }}").unwrap();
            close_impl(out);
        }
    }
    if need_hex {
        for ty in &["u8", "u16", "u32"] {
            open_impl(out, ty);
            writeln!(out, " else {{ write!(f, \"0x{{:X}}\", self.0) }}").unwrap();
            close_impl(out);
        }
    }
}
/// Emits one free function per user-defined map (`fn name(params) -> T`).
///
/// The return type is `String` if any entry interpolates a field reference,
/// otherwise `&'static str`. Parameter types come from
/// `infer_map_param_types` and default to `i64` when no call site constrains
/// them. Entries become `match` arms in declaration order. An entry whose
/// keys are all wildcards becomes the trailing `_` arm; if several exist, the
/// last one wins. Without one, `_` yields `"???"`.
fn generate_map_functions(out: &mut String, def: &ValidatedDef) {
if def.maps.is_empty() {
return;
}
let map_param_types = infer_map_param_types(def);
for map in &def.maps {
// Any field reference in any entry forces a `String` return type.
let has_interpolation = map.entries.iter().any(|entry| {
entry
.output
.iter()
.any(|p| matches!(p, FormatPiece::FieldRef { .. }))
});
let param_types: Vec<String> = map
.params
.iter()
.enumerate()
.map(|(i, _)| {
map_param_types
.get(&map.name)
.and_then(|types| types.get(i))
.cloned()
.unwrap_or_else(|| "i64".to_string())
})
.collect();
let return_type = if has_interpolation {
"String"
} else {
"&'static str"
};
let params: Vec<String> = map
.params
.iter()
.zip(param_types.iter())
.map(|(name, ty)| format!("{}: {}", name, ty))
.collect();
writeln!(out, "fn {}({}) -> {} {{", map.name, params.join(", "), return_type).unwrap();
// One parameter is matched directly; several are matched as a tuple.
if map.params.len() == 1 {
writeln!(out, " match {} {{", map.params[0]).unwrap();
} else {
let tuple: Vec<&str> = map.params.iter().map(|s| s.as_str()).collect();
writeln!(out, " match ({}) {{", tuple.join(", ")).unwrap();
}
// The all-wildcard entry is held back so it is emitted last, as `_`.
let mut default_entry = None;
for entry in &map.entries {
let is_all_wildcard = entry.keys.iter().all(|k| matches!(k, MapKey::Wildcard));
if is_all_wildcard {
default_entry = Some(entry);
continue;
}
let pattern = if map.params.len() == 1 {
format_map_key(&entry.keys[0])
} else {
let keys: Vec<String> = entry.keys.iter().map(|k| format_map_key(k)).collect();
format!("({})", keys.join(", "))
};
let output = format_map_output(&entry.output, has_interpolation);
writeln!(out, " {} => {},", pattern, output).unwrap();
}
if let Some(entry) = default_entry {
let output = format_map_output(&entry.output, has_interpolation);
writeln!(out, " _ => {},", output).unwrap();
} else {
// No explicit default: fall back to a placeholder of the right type.
if has_interpolation {
writeln!(out, " _ => String::from(\"???\"),").unwrap();
} else {
writeln!(out, " _ => \"???\",").unwrap();
}
}
writeln!(out, " }}").unwrap();
writeln!(out, "}}").unwrap();
writeln!(out).unwrap();
}
}
/// Renders one map key as a Rust match pattern: the value itself, or `_`.
fn format_map_key(key: &MapKey) -> String {
    match key {
        MapKey::Wildcard => String::from("_"),
        MapKey::Value(v) => v.to_string(),
    }
}
/// Renders a map entry's output as a Rust expression.
///
/// Without interpolation this is a `&'static str` literal. With
/// interpolation it is `format!(..)` when the entry references fields, or
/// `String::from(..)` when it is purely literal. All text is emitted through
/// `{:?}`, so quotes, backslashes and control characters become valid Rust
/// string escapes. Braces are doubled only in the `format!` template. The
/// `String::from` case receives the raw literal text, so braces are not
/// doubled there.
fn format_map_output(pieces: &[FormatPiece], has_interpolation: bool) -> String {
    // Concatenated literal text, exactly as written.
    let mut plain = String::new();
    for piece in pieces {
        if let FormatPiece::Literal(lit) = piece {
            plain.push_str(lit);
        }
    }
    if !has_interpolation {
        return format!("{:?}", plain);
    }
    let mut fmt_str = String::new();
    let mut args = Vec::new();
    for piece in pieces {
        match piece {
            FormatPiece::Literal(lit) => {
                for ch in lit.chars() {
                    match ch {
                        '{' => fmt_str.push_str("{{"),
                        '}' => fmt_str.push_str("}}"),
                        _ => fmt_str.push(ch),
                    }
                }
            }
            FormatPiece::FieldRef { expr, spec } => {
                if let Some(spec) = spec {
                    fmt_str.push_str(&format!("{{:{}}}", spec));
                } else {
                    fmt_str.push_str("{}");
                }
                args.push(expr_to_rust(expr, &[]));
            }
        }
    }
    if args.is_empty() {
        // No field references, so `plain` holds all of the entry's text.
        format!("String::from({:?})", plain)
    } else {
        format!("format!({:?}, {})", fmt_str, args.join(", "))
    }
}
/// Infers the Rust type of each map parameter from the map's call sites in
/// instruction format lines.
///
/// Instructions are scanned in order. A later call site that yields a
/// concrete type overwrites an earlier one.
fn infer_map_param_types(def: &ValidatedDef) -> HashMap<String, Vec<String>> {
    let mut inferred: HashMap<String, Vec<String>> = HashMap::new();
    for instr in &def.instructions {
        let field_types: HashMap<&str, &ResolvedFieldType> = instr
            .resolved_fields
            .iter()
            .map(|field| (field.name.as_str(), &field.resolved_type))
            .collect();
        let field_refs = instr
            .format_lines
            .iter()
            .flat_map(|line| line.pieces.iter())
            .filter_map(|piece| match piece {
                FormatPiece::FieldRef { expr, .. } => Some(expr),
                _ => None,
            });
        for expr in field_refs {
            collect_map_call_types(expr, &field_types, &mut inferred);
        }
    }
    inferred
}
fn infer_expr_type(
expr: &FormatExpr,
field_types: &HashMap<&str, &ResolvedFieldType>,
) -> Option<String> {
match expr {
FormatExpr::Field(name) => field_types.get(name.as_str()).map(|ft| {
if let Some(ref wrapper) = ft.wrapper_type {
wrapper.clone()
} else {
ft.base_type.clone()
}
}),
FormatExpr::Arithmetic { left, .. } => infer_expr_type(left, field_types),
FormatExpr::IntLiteral(_) => Some("i64".to_string()),
_ => None,
}
}
/// Recursively records argument types for every map call within `expr`.
///
/// The first call to a map creates its slots, defaulting to `i64`, with one
/// slot per argument of that call. Later calls only fill slots that already
/// exist, and only where a type can be inferred.
fn collect_map_call_types(
    expr: &FormatExpr,
    field_types: &HashMap<&str, &ResolvedFieldType>,
    result: &mut HashMap<String, Vec<String>>,
) {
    match expr {
        FormatExpr::MapCall { map_name, args } => {
            let slots = result
                .entry(map_name.clone())
                .or_insert_with(|| vec!["i64".to_string(); args.len()]);
            for (slot, arg) in slots.iter_mut().zip(args.iter()) {
                if let Some(rust_type) = infer_expr_type(arg, field_types) {
                    *slot = rust_type;
                }
            }
            // Map calls may be nested inside arguments.
            for arg in args.iter() {
                collect_map_call_types(arg, field_types, result);
            }
        }
        FormatExpr::Arithmetic { left, right, .. } => {
            collect_map_call_types(left, field_types, result);
            collect_map_call_types(right, field_types, result);
        }
        FormatExpr::BuiltinCall { args, .. } => {
            for arg in args.iter() {
                collect_map_call_types(arg, field_types, result);
            }
        }
        _ => {}
    }
}
/// Emits the `<Name>Format` trait: one `fmt_<instr>` method per instruction.
/// Each method takes the instruction's fields plus the `Formatter` and has a
/// default body produced by `generate_format_body`. `_enum_name` is
/// currently unused.
fn generate_format_trait(
    out: &mut String,
    def: &ValidatedDef,
    _enum_name: &str,
    trait_name: &str,
) {
    writeln!(out, "pub trait {} {{", trait_name).unwrap();
    for instr in &def.instructions {
        let params = trait_method_params(&instr.resolved_fields);
        // Separator between the field params and the trailing `f` param.
        let separator = if params.is_empty() { "" } else { ", " };
        writeln!(
            out,
            " fn fmt_{}({}{}f: &mut fmt::Formatter) -> fmt::Result {{",
            instr.name, params, separator
        )
        .unwrap();
        generate_format_body(out, instr, 2);
        writeln!(out, " }}").unwrap();
    }
    writeln!(out, "}}").unwrap();
    writeln!(out).unwrap();
}
/// Renders the field parameter list of a trait formatting method.
/// Wrapper-typed fields are passed by reference, all others by value.
fn trait_method_params(fields: &[ResolvedField]) -> String {
    fields
        .iter()
        .map(|field| {
            let borrow = if field.resolved_type.wrapper_type.is_some() { "&" } else { "" };
            format!("{}: {}{}", field.name, borrow, field_rust_type(field))
        })
        .collect::<Vec<_>>()
        .join(", ")
}
/// Emits the default body of an instruction's trait formatting method.
///
/// With no format lines, the output is the instruction name followed by its
/// comma-separated fields. Otherwise the format lines become an
/// `if / else if / else` chain on their guards. The first unguarded line
/// ends the chain as its `else`; later lines would be unreachable. If every
/// line is guarded, an `else { Ok(()) }` is appended. Without it, the
/// generated `if` would lack an `else` and fail to type-check as a
/// `fmt::Result`.
fn generate_format_body(out: &mut String, instr: &ValidatedInstruction, indent: usize) {
    let pad = " ".repeat(indent);
    if instr.format_lines.is_empty() {
        if instr.resolved_fields.is_empty() {
            writeln!(out, "{}write!(f, \"{}\")", pad, instr.name).unwrap();
        } else {
            let field_names: Vec<&str> = instr
                .resolved_fields
                .iter()
                .map(|f| f.name.as_str())
                .collect();
            let placeholders: Vec<&str> = field_names.iter().map(|_| "{}").collect();
            let fmt_str = format!("{} {}", instr.name, placeholders.join(", "));
            writeln!(
                out,
                "{}write!(f, \"{}\", {})",
                pad,
                fmt_str,
                field_names.join(", ")
            )
            .unwrap();
        }
        return;
    }
    let inner_pad = format!("{} ", pad);
    // True once an `if` has been opened that still needs closing.
    let mut chain_open = false;
    for fl in &instr.format_lines {
        match &fl.guard {
            Some(guard) => {
                let guard_code = generate_guard_code(guard, &instr.resolved_fields);
                if chain_open {
                    writeln!(out, "{}}} else if {} {{", pad, guard_code).unwrap();
                } else {
                    writeln!(out, "{}if {} {{", pad, guard_code).unwrap();
                    chain_open = true;
                }
                emit_write_call(out, &fl.pieces, &instr.resolved_fields, &inner_pad);
            }
            None => {
                if chain_open {
                    writeln!(out, "{}}} else {{", pad).unwrap();
                    emit_write_call(out, &fl.pieces, &instr.resolved_fields, &inner_pad);
                    writeln!(out, "{}}}", pad).unwrap();
                } else {
                    emit_write_call(out, &fl.pieces, &instr.resolved_fields, &pad);
                }
                return;
            }
        }
    }
    if chain_open {
        // Every line was guarded: print nothing when no guard holds.
        writeln!(out, "{}}} else {{", pad).unwrap();
        writeln!(out, "{} Ok(())", pad).unwrap();
        writeln!(out, "{}}}", pad).unwrap();
    }
}
/// Emits one `write!(f, "...", args..)` call for a format line.
///
/// Literal braces are doubled so they survive as format-string text. Field
/// references become `{}` or `{:spec}` placeholders. Fields with a hex
/// display format, and no explicit spec, are wrapped in `SignedHex(..)`.
/// The template is emitted with `{:?}`, so quotes, backslashes and control
/// characters in literal text become valid Rust string escapes.
fn emit_write_call(
    out: &mut String,
    pieces: &[FormatPiece],
    fields: &[ResolvedField],
    pad: &str,
) {
    let mut fmt_str = String::new();
    let mut args = Vec::new();
    for piece in pieces {
        match piece {
            FormatPiece::Literal(lit) => {
                for ch in lit.chars() {
                    match ch {
                        '{' => fmt_str.push_str("{{"),
                        '}' => fmt_str.push_str("}}"),
                        _ => fmt_str.push(ch),
                    }
                }
            }
            FormatPiece::FieldRef { expr, spec } => {
                if let Some(spec) = spec {
                    fmt_str.push_str(&format!("{{:{}}}", spec));
                    args.push(expr_to_rust(expr, fields));
                } else if let Some(display_fmt) = resolve_display_format(expr, fields) {
                    fmt_str.push_str("{}");
                    let rust_expr = expr_to_rust(expr, fields);
                    match display_fmt {
                        DisplayFormat::SignedHex | DisplayFormat::Hex => {
                            args.push(format!("SignedHex({})", rust_expr));
                        }
                    }
                } else {
                    fmt_str.push_str("{}");
                    args.push(expr_to_rust(expr, fields));
                }
            }
        }
    }
    if args.is_empty() {
        writeln!(out, "{}write!(f, {:?})", pad, fmt_str).unwrap();
    } else {
        writeln!(out, "{}write!(f, {:?}, {})", pad, fmt_str, args.join(", ")).unwrap();
    }
}
/// Returns the display format declared on the field that `expr` names
/// directly. Any expression other than a bare field reference yields `None`.
fn resolve_display_format(expr: &FormatExpr, fields: &[ResolvedField]) -> Option<DisplayFormat> {
    match expr {
        FormatExpr::Field(name) => fields
            .iter()
            .find(|field| &field.name == name)
            .and_then(|field| field.resolved_type.display_format),
        _ => None,
    }
}
/// Translates a format expression into Rust source for the formatting method.
///
/// - `bool` fields are cast to `u8`.
/// - Ternaries become `if cond { "a" } else { "b" }`; string literals are
///   emitted with `{:?}` so they are always valid, escaped Rust literals.
/// - Arithmetic and ternary sub-expressions used as operands are
///   parenthesized, so the generated code keeps the tree's grouping. For
///   example, `a - (b - c)` does not collapse to `a - b - c`, and
///   `(a + b) as u32` does not become `a + b as u32`.
fn expr_to_rust(expr: &FormatExpr, fields: &[ResolvedField]) -> String {
    // Renders `e` for use as an operand, adding parentheses around compound forms.
    fn operand(e: &FormatExpr, fields: &[ResolvedField]) -> String {
        let rendered = expr_to_rust(e, fields);
        if matches!(e, FormatExpr::Arithmetic { .. } | FormatExpr::Ternary { .. }) {
            format!("({})", rendered)
        } else {
            rendered
        }
    }
    match expr {
        FormatExpr::Field(name) => {
            let field = fields.iter().find(|f| f.name == *name);
            if let Some(f) = field {
                if f.resolved_type.base_type == "bool" {
                    format!("{} as u8", name)
                } else {
                    name.clone()
                }
            } else {
                name.clone()
            }
        }
        FormatExpr::Ternary {
            field,
            if_nonzero,
            if_zero,
        } => {
            let f = fields.iter().find(|f| f.name == *field);
            let cond = if let Some(f) = f {
                if f.resolved_type.base_type == "bool" {
                    field.clone()
                } else if f.resolved_type.wrapper_type.is_some() {
                    // Wrapper fields are references here; convert to the base type to compare.
                    format!("Into::<{}>::into(*{}) != 0", f.resolved_type.base_type, field)
                } else {
                    format!("{} != 0", field)
                }
            } else {
                format!("{} != 0", field)
            };
            format!(
                "if {} {{ {:?} }} else {{ {:?} }}",
                cond,
                if_nonzero,
                if_zero.as_deref().unwrap_or("")
            )
        }
        FormatExpr::Arithmetic { left, op, right } => {
            let op_str = match op {
                ArithOp::Add => "+",
                ArithOp::Sub => "-",
                ArithOp::Mul => "*",
                ArithOp::Div => "/",
                ArithOp::Mod => "%",
            };
            format!("{} {} {}", operand(left, fields), op_str, operand(right, fields))
        }
        FormatExpr::IntLiteral(val) => format!("{}", val),
        FormatExpr::MapCall { map_name, args } => {
            let arg_strs: Vec<String> = args.iter().map(|a| expr_to_rust(a, fields)).collect();
            format!("{}({})", map_name, arg_strs.join(", "))
        }
        FormatExpr::BuiltinCall { func, args } => {
            // Arguments are followed by `as u32`, so compound ones need parentheses.
            let arg_strs: Vec<String> = args.iter().map(|a| operand(a, fields)).collect();
            let first = arg_strs.get(0).map(|s| s.as_str()).unwrap_or("0");
            let second = arg_strs.get(1).map(|s| s.as_str()).unwrap_or("0");
            match func {
                BuiltinFunc::RotateRight => {
                    format!("({} as u32).rotate_right({} as u32)", first, second)
                }
                BuiltinFunc::RotateLeft => {
                    format!("({} as u32).rotate_left({} as u32)", first, second)
                }
            }
        }
    }
}
/// Translates a format-line guard into a Rust boolean expression.
/// Each condition is rendered separately and the results are joined with `&&`.
///
/// Special cases when the left operand is a known field:
/// - Wrapper-typed fields are references in the trait method (see
///   `trait_method_params`), so they are dereferenced. A literal on the right
///   is converted with `Wrapper::from(<lit><base_type>)`; any other right
///   operand is dereferenced too.
/// - `bool` fields compared with a literal via `==`/`!=` collapse to `x` or
///   `!x`. Any nonzero literal counts as `true`.
///
/// NOTE(review): `*right` assumes a wrapper-typed left field is only ever
/// compared with a literal or another wrapper-typed field; confirm the
/// validator enforces this. An empty condition list yields an empty string.
fn generate_guard_code(guard: &Guard, fields: &[ResolvedField]) -> String {
let conditions: Vec<String> = guard
.conditions
.iter()
.map(|cond| {
let left = guard_operand_to_rust(&cond.left, fields);
let right = guard_operand_to_rust(&cond.right, fields);
let op = match cond.op {
CompareOp::Eq => "==",
CompareOp::Ne => "!=",
CompareOp::Lt => "<",
CompareOp::Le => "<=",
CompareOp::Gt => ">",
CompareOp::Ge => ">=",
};
// Only a bare field on the left can trigger the wrapper/bool special cases.
let left_field = match &cond.left {
GuardOperand::Field(name) => fields.iter().find(|f| f.name == *name),
_ => None,
};
if let Some(f) = left_field {
if let Some(ref wrapper) = f.resolved_type.wrapper_type {
if let GuardOperand::Literal(val) = &cond.right {
return format!(
"*{} {} {}::from({}{})",
left, op, wrapper, val, f.resolved_type.base_type
);
}
return format!("*{} {} *{}", left, op, right);
} else if f.resolved_type.base_type == "bool" {
if let GuardOperand::Literal(val) = &cond.right {
// `x == 0` / `x != nonzero` mean "false"; the other two mean "true".
match (cond.op, *val) {
(CompareOp::Eq, 0) => return format!("!{}", left),
(CompareOp::Eq, _) => return left.clone(),
(CompareOp::Ne, 0) => return left.clone(),
(CompareOp::Ne, _) => return format!("!{}", left),
_ => {}
}
}
}
}
format!("{} {} {}", left, op, right)
})
.collect();
conditions.join(" && ")
}
/// Renders a guard operand as Rust source.
///
/// Fields render as their bare name; any dereferencing of wrapper types is
/// left to the caller. Literals render as themselves. Compound expressions
/// are fully parenthesized.
fn guard_operand_to_rust(operand: &GuardOperand, fields: &[ResolvedField]) -> String {
    match operand {
        GuardOperand::Field(name) => name.clone(),
        GuardOperand::Literal(val) => val.to_string(),
        GuardOperand::Expr { left, op, right } => {
            let symbol = match op {
                ArithOp::Add => "+",
                ArithOp::Sub => "-",
                ArithOp::Mul => "*",
                ArithOp::Div => "/",
                ArithOp::Mod => "%",
            };
            let lhs = guard_operand_to_rust(left, fields);
            let rhs = guard_operand_to_rust(right, fields);
            format!("({} {} {})", lhs, symbol, rhs)
        }
    }
}
/// Emits `write_asm::<F>`, which destructures `self` and forwards each
/// variant's fields to the matching `F::fmt_<instr>` trait method.
/// Non-wrapper fields are copied out with `*`; wrapper fields are passed by
/// reference.
fn generate_write_asm(
    out: &mut String,
    def: &ValidatedDef,
    enum_name: &str,
    trait_name: &str,
) {
    writeln!(
        out,
        " pub fn write_asm<F: {}>(&self, f: &mut fmt::Formatter) -> fmt::Result {{",
        trait_name
    )
    .unwrap();
    writeln!(out, " match self {{").unwrap();
    for instr in &def.instructions {
        let fields = &instr.resolved_fields;
        let variant = format!("{}::{}", enum_name, to_pascal_case(&instr.name));
        let pattern = if fields.is_empty() {
            variant
        } else {
            let bindings: Vec<&str> = fields.iter().map(|field| field.name.as_str()).collect();
            format!("{} {{ {} }}", variant, bindings.join(", "))
        };
        let mut call_args: Vec<String> = fields
            .iter()
            .map(|field| match field.resolved_type.wrapper_type {
                Some(_) => field.name.clone(),
                None => format!("*{}", field.name),
            })
            .collect();
        call_args.push("f".to_string());
        writeln!(
            out,
            " {} => F::fmt_{}({}),",
            pattern,
            instr.name,
            call_args.join(", ")
        )
        .unwrap();
    }
    writeln!(out, " }}").unwrap();
    writeln!(out, " }}").unwrap();
}
/// Maps the configured instruction-unit width (in bits) to the unsigned
/// Rust type used to hold one unit.
///
/// A 64-bit width maps to `u64`. The generated decoder reads `width / 8`
/// bytes per unit, so returning `u32` there would make
/// `try_into().unwrap()` panic at runtime. Other unsupported widths still
/// fall back to `u32`, as before.
fn word_type_for_width(width: u32) -> &'static str {
    match width {
        8 => "u8",
        16 => "u16",
        32 => "u32",
        64 => "u64",
        _ => "u32",
    }
}
/// Returns the signed integer type with the same width as `base`.
/// Sign extension casts to this type before its arithmetic right shift.
///
/// 64-bit bases map to `i64`. Previously they fell through to `i32`, which
/// silently truncated sign extension of `i64`/`u64` fields. Unknown types
/// still default to `i32`. Keep this in sync with `type_bits` below.
pub(crate) fn signed_type_for(base: &str) -> &'static str {
    match base {
        "u8" | "i8" => "i8",
        "u16" | "i16" => "i16",
        "u32" | "i32" => "i32",
        "u64" | "i64" => "i64",
        _ => "i32",
    }
}
/// Returns the bit width of the integer type `base`.
/// Unknown types default to 32, matching `signed_type_for`.
pub(crate) fn type_bits(base: &str) -> u32 {
    match base {
        "u8" | "i8" => 8,
        "u16" | "i16" => 16,
        "u32" | "i32" => 32,
        "u64" | "i64" => 64,
        _ => 32,
    }
}
fn field_rust_type(field: &ResolvedField) -> String {
if let Some(ref wrapper) = field.resolved_type.wrapper_type {
wrapper.clone()
} else {
field.resolved_type.base_type.clone()
}
}
/// Converts a snake_case instruction name to PascalCase.
///
/// Underscores are dropped. The first ASCII letter of each segment is
/// uppercased and every following character is ASCII-lowercased, so
/// `"ld_b_c"` becomes `"LdBC"` and `"ADD"` becomes `"Add"`.
pub fn to_pascal_case(name: &str) -> String {
    name.split('_')
        .flat_map(|segment| {
            let mut chars = segment.chars();
            let head = chars.next().map(|c| c.to_ascii_uppercase());
            head.into_iter().chain(chars.map(|c| c.to_ascii_lowercase()))
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Snake_case names become PascalCase; other letters are lowercased.
    #[test]
    fn test_to_pascal_case() {
        let cases = [("addi", "Addi"), ("ld_b_c", "LdBC"), ("ADD", "Add"), ("nop", "Nop")];
        for &(input, expected) in cases.iter() {
            assert_eq!(to_pascal_case(input), expected);
        }
    }

    /// Single-range extraction in the first unit, a read from the second
    /// unit, and a two-range field spanning both units.
    #[test]
    fn test_extract_expression() {
        let low_six = BitRange::new(5, 0);
        assert_eq!(extract_expression("opcode", &[low_six], "u32", 4, "be"), "opcode & 0x3f");

        let top_six = BitRange::new(31, 26);
        assert_eq!(
            extract_expression("opcode", &[top_six], "u32", 4, "be"),
            "(opcode >> 26) & 0x3f"
        );

        let second_unit = BitRange::new_in_unit(1, 15, 0);
        assert_eq!(
            extract_expression("opcode", &[second_unit], "u16", 2, "be"),
            "u16::from_be_bytes(data[2..4].try_into().unwrap()) & 0xffff"
        );

        let low_byte = BitRange::new_in_unit(0, 7, 0);
        let high_byte = BitRange::new_in_unit(1, 15, 8);
        let combined = extract_expression("opcode", &[low_byte, high_byte], "u16", 2, "be");
        assert!(combined.contains("opcode & 0xff"));
        assert!(combined.contains("u16::from_be_bytes(data[2..4].try_into().unwrap())"));
    }
}