use std::collections::HashMap;
use std::fmt::Write;
use crate::codegen::{signed_type_for, type_bits};
use crate::types::*;
/// Translates a hardware bit index (bit 0 = least significant) into the index
/// the DSL uses for `bit_order`.
///
/// For `Lsb0` the two numberings coincide; for `Msb0` the index is mirrored
/// across the instruction `width`. The mapping is its own inverse.
fn hw_to_dsl(hw_bit: u32, width: u32, bit_order: BitOrder) -> u32 {
    match bit_order {
        BitOrder::Lsb0 => hw_bit,
        BitOrder::Msb0 => (width - 1) - hw_bit,
    }
}
/// One accessor method to emit on the generated instruction type.
#[derive(Debug, Clone)]
struct FieldAccessor {
/// Name of the generated method: the field name, or the field name plus a
/// DSL bit-range suffix when the field has conflicting definitions.
fn_name: String,
/// Bit ranges in hardware numbering (bit 0 = least significant) that are
/// concatenated to form the field; the first range supplies the lowest bits.
ranges: Vec<BitRange>,
/// Rust type returned by the accessor (`bool` is produced by a `!= 0` test).
base_type: String,
/// Transforms applied, in order, to the extracted bits.
transforms: Vec<Transform>,
/// The field declaration rendered in DSL syntax, used in the method's doc comment.
chipi_type: String,
}
/// Gathers one accessor per distinct field across every instruction in `def`.
///
/// Fields are grouped by name. Definitions that agree on bit ranges, base type
/// and transforms are merged into a single accessor named after the field.
/// When a name has several distinct definitions, one accessor is generated per
/// definition, named `<field>_<dsl ranges>` (e.g. `rd_6_10`), and a warning
/// describing the conflict is recorded. If two such definitions share the same
/// ranges, the later one additionally gets its base type (and, if still
/// ambiguous, a counter) appended so every method name is unique.
///
/// Returns the accessors sorted by method name and the warnings in field-name
/// order, so the generated output is reproducible.
fn collect_fields(def: &ValidatedDef) -> (Vec<FieldAccessor>, Vec<String>) {
    let width = def.config.width;
    let bit_order = def.config.bit_order;

    let make_accessor = |fn_name: String, field: &ResolvedField| FieldAccessor {
        fn_name,
        ranges: field.ranges.clone(),
        base_type: field.resolved_type.base_type.clone(),
        transforms: field.resolved_type.transforms.clone(),
        chipi_type: format_chipi_type(field, width, bit_order),
    };

    let mut by_name: HashMap<String, Vec<&ResolvedField>> = HashMap::new();
    for instr in &def.instructions {
        for field in &instr.resolved_fields {
            by_name.entry(field.name.clone()).or_default().push(field);
        }
    }
    // `HashMap` iteration order is randomized per process. The warnings end up
    // in the generated file, so walk the groups in name order to keep the
    // output identical from build to build.
    let mut groups: Vec<(String, Vec<&ResolvedField>)> = by_name.into_iter().collect();
    groups.sort_by(|a, b| a.0.cmp(&b.0));

    let mut accessors = Vec::new();
    let mut warnings = Vec::new();
    for (name, fields) in groups {
        // Distinct definitions in first-seen (instruction) order. `wrapper_type`
        // is not part of the comparison.
        let mut unique: Vec<&ResolvedField> = Vec::new();
        for field in fields {
            let already_seen = unique.iter().any(|seen| {
                seen.ranges == field.ranges
                    && seen.resolved_type.base_type == field.resolved_type.base_type
                    && seen.resolved_type.transforms == field.resolved_type.transforms
            });
            if !already_seen {
                unique.push(field);
            }
        }

        if unique.len() == 1 {
            accessors.push(make_accessor(name, unique[0]));
        } else {
            warnings.push(format!(
                "field '{}' has {} conflicting definitions - generating {} variants with bit ranges in names",
                name,
                unique.len(),
                unique.len()
            ));
            // Names already handed out for this field. Definitions sharing bit
            // ranges but differing in type/transforms would otherwise produce
            // duplicate methods, i.e. generated code that does not compile.
            let mut taken: Vec<String> = Vec::with_capacity(unique.len());
            for field in unique {
                let ranged = format!(
                    "{}_{}",
                    name,
                    format_range_suffix(&field.ranges, width, bit_order)
                );
                let fn_name = if taken.contains(&ranged) {
                    // Keep the tag a valid identifier fragment.
                    let type_tag: String = field
                        .resolved_type
                        .base_type
                        .chars()
                        .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
                        .collect();
                    let typed = format!("{}_{}", ranged, type_tag);
                    let mut candidate = typed.clone();
                    let mut counter = 2u32;
                    while taken.contains(&candidate) {
                        candidate = format!("{}_{}", typed, counter);
                        counter += 1;
                    }
                    candidate
                } else {
                    ranged
                };
                taken.push(fn_name.clone());
                accessors.push(make_accessor(fn_name, field));
            }
        }
    }
    accessors.sort_by(|a, b| a.fn_name.cmp(&b.fn_name));
    (accessors, warnings)
}
/// Renders `ranges` as an identifier suffix in DSL bit numbering.
///
/// Each range becomes `<first>_<second>` (its `start` and `end` bits converted
/// with [`hw_to_dsl`]), and multiple ranges are joined with `_`, e.g.
/// `6_10_16_20`. An empty slice yields an empty string.
fn format_range_suffix(ranges: &[BitRange], width: u32, bit_order: BitOrder) -> String {
    let mut pieces: Vec<String> = Vec::with_capacity(ranges.len());
    for range in ranges {
        let first = hw_to_dsl(range.start, width, bit_order);
        let second = hw_to_dsl(range.end, width, bit_order);
        pieces.push(format!("{}_{}", first, second));
    }
    // Joining a single piece returns it unchanged, so one range needs no
    // special case.
    pieces.join("_")
}
fn format_chipi_type(field: &ResolvedField, width: u32, bit_order: BitOrder) -> String {
let ranges_str = field
.ranges
.iter()
.map(|r| {
let dsl_start = hw_to_dsl(r.start, width, bit_order);
let dsl_end = hw_to_dsl(r.end, width, bit_order);
if dsl_start == dsl_end {
format!("[{}]", dsl_start)
} else {
format!("[{}:{}]", dsl_start, dsl_end)
}
})
.collect::<Vec<_>>()
.join("");
let type_str = if let Some(ref wrapper) = field.resolved_type.wrapper_type {
format!("{} as {}", field.resolved_type.base_type, wrapper)
} else {
field.resolved_type.base_type.clone()
};
format!("{}: {}{}", field.name, type_str, ranges_str)
}
/// Builds the Rust expression used as the body of a field accessor.
///
/// `raw_expr` is the expression yielding the raw instruction word (the
/// generator passes `self.0`). The field's bits are extracted and masked, its
/// transforms are applied in declaration order, and the result is converted
/// to the field's base type. `bool` fields are produced by a `!= 0` test.
fn accessor_body(field: &FieldAccessor, raw_expr: &str) -> String {
    let mut expr = if let [range] = field.ranges.as_slice() {
        let width = range.width();
        let shift = range.end;
        // `1u64 << 64` overflows (a panic in debug builds, mask 0 in release),
        // so a full 64-bit range falls back to an all-ones mask.
        let mask = 1u64.checked_shl(width).map_or(u64::MAX, |bit| bit - 1);
        if shift == 0 {
            format!("{} & {:#x}", raw_expr, mask)
        } else {
            format!("({} >> {}) & {:#x}", raw_expr, shift, mask)
        }
    } else {
        build_multi_range_extract(raw_expr, &field.ranges)
    };

    for transform in &field.transforms {
        match transform {
            Transform::SignExtend(n) => {
                // Cast to the type from `signed_type_for`, then shift the
                // field's top bit up to the type's MSB and back down. This
                // assumes that type is signed, so `>>` replicates the sign bit.
                let signed = signed_type_for(&field.base_type);
                let bits = type_bits(&field.base_type);
                expr = format!(
                    "(((({}) as {}) << ({} - {})) >> ({} - {}))",
                    expr, signed, bits, n, bits, n
                );
            }
            // Masked extraction already yields a zero-extended value.
            Transform::ZeroExtend(_) => {}
            Transform::ShiftLeft(n) => {
                expr = format!("(({}) << {})", expr, n);
            }
        }
    }

    if field.base_type == "bool" {
        format!("({}) != 0", expr)
    } else {
        format!("({}) as {}", expr, field.base_type)
    }
}
/// Builds an expression that concatenates several bit ranges of `raw_expr`.
///
/// Each range is shifted down and masked. The first range lands in the
/// least-significant bits of the result, and every later range is shifted left
/// by the combined width of the ranges before it. The pieces are OR-ed
/// together. An empty `ranges` slice yields an empty string.
fn build_multi_range_extract(raw_expr: &str, ranges: &[BitRange]) -> String {
    let mut parts = Vec::with_capacity(ranges.len());
    let mut accumulated_width = 0u32;
    for range in ranges {
        let width = range.width();
        let shift = range.end;
        // `1u64 << 64` overflows (a panic in debug builds, mask 0 in release),
        // so a full 64-bit range falls back to an all-ones mask.
        let mask = 1u64.checked_shl(width).map_or(u64::MAX, |bit| bit - 1);
        let extracted = if shift == 0 {
            format!("({} & {:#x})", raw_expr, mask)
        } else {
            format!("(({} >> {}) & {:#x})", raw_expr, shift, mask)
        };
        if accumulated_width > 0 {
            parts.push(format!("({} << {})", extracted, accumulated_width));
        } else {
            parts.push(extracted);
        }
        accumulated_width += width;
    }
    parts.join(" | ")
}
/// Generates the source of an instruction newtype named `struct_name` that
/// wraps a `u32` and exposes one `#[inline]` accessor per collected field.
///
/// Each accessor is preceded by a doc comment showing the field's DSL
/// declaration. Warnings from field collection are embedded as `// NOTES:`
/// comments at the top of the generated code and are also returned to the
/// caller alongside the code.
pub fn generate_instr_type(def: &ValidatedDef, struct_name: &str) -> (String, Vec<String>) {
    let (accessors, notes) = collect_fields(def);

    let mut code = String::new();
    writeln!(code, "// Auto-generated by chipi. Do not edit.").unwrap();
    writeln!(code).unwrap();

    if !notes.is_empty() {
        writeln!(code, "// NOTES:").unwrap();
        notes
            .iter()
            .for_each(|note| writeln!(code, "// {}", note).unwrap());
        writeln!(code).unwrap();
    }

    writeln!(code, "pub struct {}(pub u32);", struct_name).unwrap();
    writeln!(code).unwrap();
    writeln!(code, "#[rustfmt::skip]").unwrap();
    writeln!(code, "impl {} {{", struct_name).unwrap();
    for accessor in &accessors {
        let body = accessor_body(accessor, "self.0");
        writeln!(code, " /// Field: `{}`", accessor.chipi_type).unwrap();
        writeln!(
            code,
            " #[inline] pub fn {}(&self) -> {} {{ {} }}",
            accessor.fn_name, accessor.base_type, body
        )
        .unwrap();
    }
    writeln!(code, "}}").unwrap();

    (code, notes)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser;
    use crate::validate;

    /// Parses and validates `source`, panicking on any error.
    fn validated(source: &str) -> ValidatedDef {
        let def = parser::parse(source, "test.chipi").unwrap();
        validate::validate(&def).unwrap()
    }

    #[test]
    fn test_basic_field_generation() {
        let def = validated(
            r#"
decoder Test { width = 32 bit_order = msb0 }
addi [0:5]=001110 rd:u8[6:10] ra:u8[11:15] simm:i32[16:31]
"#,
        );
        let (code, _warnings) = generate_instr_type(&def, "Instruction");
        assert!(code.contains("pub fn ra("));
        assert!(code.contains("pub fn rd("));
        assert!(code.contains("pub fn simm("));
        assert!(code.contains("pub struct Instruction(pub u32);"));
        assert!(code.contains("/// Field:"));
    }

    #[test]
    fn test_deduplication() {
        let def = validated(
            r#"
decoder Test { width = 32 bit_order = msb0 }
addi [0:5]=001110 rd:u8[6:10] ra:u8[11:15] simm:i32[16:31]
addis [0:5]=001111 rd:u8[6:10] ra:u8[11:15] simm:i32[16:31]
"#,
        );
        let (code, _warnings) = generate_instr_type(&def, "Instruction");
        assert_eq!(code.matches("pub fn rd(").count(), 1);
        assert_eq!(code.matches("pub fn ra(").count(), 1);
        assert_eq!(code.matches("pub fn simm(").count(), 1);
    }

    #[test]
    fn test_bool_field() {
        let def = validated(
            r#"
decoder Test { width = 32 bit_order = msb0 }
bx [0:5]=010010 li:i32[6:29] aa:bool[30] lk:bool[31]
"#,
        );
        let (code, _warnings) = generate_instr_type(&def, "Instruction");
        assert!(code.contains("-> bool"));
        assert!(code.contains("!= 0"));
    }

    #[test]
    fn test_conflicting_fields_generate_variants() {
        let def = validated(
            r#"
decoder Test { width = 32 bit_order = msb0 }
foo [0:5]=000001 rd:u8[6:10]
bar [0:5]=000010 rd:u8[11:15]
"#,
        );
        let (code, warnings) = generate_instr_type(&def, "Instruction");
        assert!(!warnings.is_empty());
        assert!(warnings[0].contains("rd"));
        assert!(warnings[0].contains("conflicting"));
        // Variant suffixes use DSL bit numbering (the numbers written in the
        // source), so `rd:u8[6:10]` becomes `rd_6_10`, not the hardware bits.
        assert!(code.contains("pub fn rd_6_10("));
        assert!(code.contains("pub fn rd_11_15("));
    }

    #[test]
    fn test_sign_extend_transform() {
        let def = validated(
            r#"
decoder Test { width = 32 bit_order = msb0 }
type simm16 = i32 { sign_extend(16) }
addi [0:5]=001110 rd:u8[6:10] simm:simm16[16:31]
"#,
        );
        let (code, _warnings) = generate_instr_type(&def, "Instruction");
        assert!(code.contains("pub fn simm(&self) -> i32"));
        assert!(code.contains("<<") && code.contains(">>"));
    }

    #[test]
    fn test_shift_left_transform() {
        let def = validated(
            r#"
decoder Test { width = 32 bit_order = msb0 }
type addr = i32 { shift_left(2) }
bx [0:5]=010010 li:addr[6:29]
"#,
        );
        let (code, _warnings) = generate_instr_type(&def, "Instruction");
        assert!(code.contains("pub fn li(&self) -> i32"));
        assert!(code.contains("<< 2"));
    }
}