use crate::{
array::{Array, ArrayType},
bundle::{Bundle, BundleField, BundleType},
enum_::{Enum, EnumType, EnumVariant},
expr::{
ops::{self, EnumLiteral},
CastBitsTo, CastTo, CastToBits, Expr, ExprEnum, HdlPartialEq, ToExpr,
},
hdl,
int::UInt,
intern::{Intern, Interned, Memoize},
memory::{DynPortType, Mem, MemPort},
module::{
transform::visit::{Fold, Folder},
Block, Id, Module, NameId, ScopedNameId, Stmt, StmtConnect, StmtIf, StmtMatch, StmtWire,
},
source_location::SourceLocation,
ty::{CanonicalType, Type},
wire::Wire,
};
use core::fmt;
use hashbrown::HashMap;
/// Errors produced by [`simplify_enums`].
#[derive(Debug)]
pub enum SimplifyEnumsError {
    /// An enum type encountered by the pass reports
    /// `is_castable_from_bits == false` in its type properties, so it cannot be
    /// lowered to a bit-level representation.
    EnumIsNotCastableFromBits { enum_type: Enum },
}

impl fmt::Display for SimplifyEnumsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Single-variant enum: the pattern is irrefutable.
        let Self::EnumIsNotCastableFromBits { enum_type } = self;
        write!(
            f,
            "simplify_enums failed: enum type is not castable from bits: {enum_type:?}"
        )
    }
}

impl std::error::Error for SimplifyEnumsError {}

/// Lets `?` propagate a [`SimplifyEnumsError`] out of functions returning
/// [`std::io::Result`]; the error is wrapped with [`std::io::ErrorKind::Other`].
impl From<SimplifyEnumsError> for std::io::Error {
    fn from(error: SimplifyEnumsError) -> Self {
        Self::new(std::io::ErrorKind::Other, error)
    }
}
/// Returns `true` if `ty` is an enum type, or transitively contains one as an
/// array element type or a bundle field type.
///
/// Results are memoized per canonical type through [`Memoize`].
fn contains_any_enum_types(ty: CanonicalType) -> bool {
    /// Zero-sized memoization key for this query.
    #[derive(Copy, Clone, PartialEq, Eq, Hash)]
    struct ContainsEnumMemo;

    impl Memoize for ContainsEnumMemo {
        type Input = CanonicalType;
        type InputOwned = CanonicalType;
        type Output = bool;

        fn inner(self, input: &Self::Input) -> Self::Output {
            match *input {
                CanonicalType::Enum(_) => true,
                CanonicalType::Array(array) => contains_any_enum_types(array.element()),
                CanonicalType::Bundle(bundle) => {
                    let fields = bundle.fields();
                    fields.iter().any(|field| contains_any_enum_types(field.ty))
                }
                // Scalar types can never contain an enum.
                CanonicalType::UInt(_)
                | CanonicalType::SInt(_)
                | CanonicalType::Bool(_)
                | CanonicalType::AsyncReset(_)
                | CanonicalType::SyncReset(_)
                | CanonicalType::Reset(_)
                | CanonicalType::Clock(_) => false,
            }
        }
    }

    ContainsEnumMemo.get_owned(ty)
}
/// Bundle used to represent a lowered enum value: the discriminant (`tag`)
/// followed by the payload bits (`body`).
#[hdl]
struct TagAndBody<Tag, Body> {
    // Discriminant: either a payload-free enum or a `UInt`, depending on the
    // chosen lowering.
    tag: Tag,
    // Payload bits of the active variant; zero when the variant has no payload.
    body: Body,
}
/// How a particular enum type is lowered by this pass (cached per enum type in
/// `State::enum_types`).
#[derive(Clone, Debug)]
enum EnumTypeState {
    /// Bundle of a payload-free tag enum plus a `UInt` body holding the payload
    /// bits (`SimplifyToEnumsWithNoBody` applied to an enum with payloads).
    TagEnumAndBody(TagAndBody<Enum, UInt>),
    /// Bundle of a `UInt` discriminant plus a `UInt` body
    /// (`ReplaceWithBundleOfUInts`).
    TagUIntAndBody(TagAndBody<UInt, UInt>),
    /// A single `UInt` of the enum's full bit width (`ReplaceWithUInt`).
    UInt(UInt),
    /// The enum is kept as-is (`SimplifyToEnumsWithNoBody` applied to an enum
    /// whose variants have no payloads).
    Unchanged,
}
/// Per-module state; one is pushed while each module's body is being folded.
struct ModuleState {
    /// Name of the module being folded; generated names are scoped to it.
    module_name: NameId,
}

impl ModuleState {
    /// Builds a new name `name` scoped to the current module, paired with a
    /// freshly created [`Id`] (presumably what keeps repeated uses of the same
    /// `name` distinct).
    fn gen_name(&mut self, name: &str) -> ScopedNameId {
        let unique_name = NameId(name.intern(), Id::new());
        ScopedNameId(self.module_name, unique_name)
    }
}
/// Folder state for the `simplify_enums` pass.
struct State {
    /// Cache of the lowering chosen for each enum type seen so far.
    enum_types: HashMap<Enum, EnumTypeState>,
    /// Original (unfolded) memory ports whose uses are redirected to a
    /// replacement wire; populated in `fold_block`, consulted in `fold_expr_enum`.
    replacement_mem_ports: HashMap<MemPort<DynPortType>, Wire<CanonicalType>>,
    /// Which lowering strategy to apply.
    kind: SimplifyEnumsKind,
    /// One entry per module currently being folded (innermost last); used to
    /// scope names of generated wires.
    module_state_stack: Vec<ModuleState>,
}
impl State {
    /// Returns how `enum_type` is lowered, computing the answer on first use and
    /// caching it in `self.enum_types`.
    ///
    /// The result depends on `self.kind` and on whether any variant has a payload:
    /// - `SimplifyToEnumsWithNoBody` with payloads → [`EnumTypeState::TagEnumAndBody`]
    /// - `SimplifyToEnumsWithNoBody` without payloads → [`EnumTypeState::Unchanged`]
    /// - `ReplaceWithBundleOfUInts` → [`EnumTypeState::TagUIntAndBody`]
    /// - `ReplaceWithUInt` → [`EnumTypeState::UInt`]
    ///
    /// # Errors
    /// Returns [`SimplifyEnumsError::EnumIsNotCastableFromBits`] if the enum's type
    /// properties say it is not castable from bits. Failures are not cached, so
    /// the check is repeated on every lookup of such a type.
    fn get_or_make_enum_type_state(
        &mut self,
        enum_type: Enum,
    ) -> Result<EnumTypeState, SimplifyEnumsError> {
        if let Some(retval) = self.enum_types.get(&enum_type) {
            return Ok(retval.clone());
        }
        if !enum_type.type_properties().is_castable_from_bits {
            return Err(SimplifyEnumsError::EnumIsNotCastableFromBits { enum_type });
        }
        let has_body = enum_type
            .variants()
            .iter()
            .any(|variant| variant.ty.is_some());
        let retval = match (self.kind, has_body) {
            (SimplifyEnumsKind::SimplifyToEnumsWithNoBody, true) => {
                // Keep an enum for the tag (same variant names, payloads removed);
                // the payload bits move into the separate `body` field.
                EnumTypeState::TagEnumAndBody(TagAndBody {
                    tag: Enum::new(Interned::from_iter(enum_type.variants().iter().map(|v| {
                        EnumVariant {
                            name: v.name,
                            ty: None,
                        }
                    }))),
                    // Every bit that isn't discriminant is payload.
                    body: UInt::new_dyn(
                        enum_type.type_properties().bit_width - enum_type.discriminant_bit_width(),
                    ),
                })
            }
            // Already payload-free: nothing to simplify.
            (SimplifyEnumsKind::SimplifyToEnumsWithNoBody, false) => EnumTypeState::Unchanged,
            (SimplifyEnumsKind::ReplaceWithBundleOfUInts, _) => {
                EnumTypeState::TagUIntAndBody(TagAndBody {
                    tag: UInt::new_dyn(enum_type.discriminant_bit_width()),
                    body: UInt::new_dyn(
                        enum_type.type_properties().bit_width - enum_type.discriminant_bit_width(),
                    ),
                })
            }
            (SimplifyEnumsKind::ReplaceWithUInt, _) => {
                EnumTypeState::UInt(UInt::new_dyn(enum_type.type_properties().bit_width))
            }
        };
        self.enum_types.insert(enum_type, retval.clone());
        Ok(retval)
    }

    /// Lowers an enum literal of `unfolded_enum_type` selecting `variant_index`.
    ///
    /// `folded_variant_value` is the already-folded payload (or `None` for a
    /// variant without one). In the tag+body lowerings the payload bits are cast
    /// to the body width; a missing payload produces an all-zero body. In the
    /// `UInt` lowering the tag/body bundle is cast to bits, which must agree with
    /// `handle_match`/`handle_variant_access` reading the tag from bits
    /// `..discriminant_bit_width()` and the body from the bits above it.
    #[hdl]
    fn handle_enum_literal(
        &mut self,
        unfolded_enum_type: Enum,
        variant_index: usize,
        folded_variant_value: Option<Expr<CanonicalType>>,
    ) -> Result<Expr<CanonicalType>, SimplifyEnumsError> {
        Ok(
            match self.get_or_make_enum_type_state(unfolded_enum_type)? {
                EnumTypeState::TagEnumAndBody(TagAndBody { tag, body }) => Expr::canonical(
                    #[hdl]
                    TagAndBody {
                        tag: EnumLiteral::new_by_index(tag, variant_index, None),
                        body: match folded_variant_value {
                            Some(variant_value) => variant_value.cast_to_bits().cast_to(body),
                            None => body.zero().to_expr(),
                        },
                    },
                ),
                EnumTypeState::TagUIntAndBody(TagAndBody { tag, body }) => Expr::canonical(
                    #[hdl]
                    TagAndBody {
                        tag: tag.from_int_wrapping(variant_index),
                        body: match folded_variant_value {
                            Some(folded_variant_value) => {
                                folded_variant_value.cast_to_bits().cast_to(body)
                            }
                            None => body.zero().to_expr(),
                        },
                    },
                ),
                EnumTypeState::UInt(_) => {
                    let tag = UInt[unfolded_enum_type.discriminant_bit_width()];
                    let body = UInt[unfolded_enum_type.type_properties().bit_width - tag.width()];
                    Expr::canonical(
                        (#[hdl]
                        TagAndBody {
                            tag: tag.from_int_wrapping(variant_index),
                            body: match folded_variant_value {
                                Some(folded_variant_value) => {
                                    folded_variant_value.cast_to_bits().cast_to(body)
                                }
                                None => body.zero().to_expr(),
                            },
                        })
                        .cast_to_bits(),
                    )
                }
                EnumTypeState::Unchanged => Expr::canonical(
                    ops::EnumLiteral::new_by_index(
                        unfolded_enum_type,
                        variant_index,
                        folded_variant_value,
                    )
                    .to_expr(),
                ),
            },
        )
    }

    /// Extracts the payload of variant `variant_index` from `folded_base_expr`,
    /// a value already in the lowered representation of `unfolded_enum_type`.
    ///
    /// The payload bits are taken from the low end of the body and cast back to
    /// the (folded) variant type. Variants without a payload yield `()`.
    fn handle_variant_access(
        &mut self,
        unfolded_enum_type: Enum,
        folded_base_expr: Expr<CanonicalType>,
        variant_index: usize,
    ) -> Result<Expr<CanonicalType>, SimplifyEnumsError> {
        let unfolded_variant_type = unfolded_enum_type.variants()[variant_index].ty;
        Ok(
            match self.get_or_make_enum_type_state(unfolded_enum_type)? {
                EnumTypeState::TagEnumAndBody(_) | EnumTypeState::TagUIntAndBody(_) => {
                    match unfolded_variant_type {
                        Some(variant_type) => Expr::canonical(
                            Expr::<TagAndBody<CanonicalType, UInt>>::from_canonical(
                                folded_base_expr,
                            )
                            .body[..variant_type.bit_width()]
                            .cast_bits_to(variant_type.fold(self)?),
                        ),
                        None => Expr::canonical(().to_expr()),
                    }
                }
                EnumTypeState::UInt(_) => match unfolded_variant_type {
                    Some(variant_type) => {
                        let base_int = Expr::<UInt>::from_canonical(folded_base_expr);
                        let variant_type_bit_width = variant_type.bit_width();
                        // Skip the discriminant in the low bits, then take the payload.
                        Expr::canonical(
                            base_int[unfolded_enum_type.discriminant_bit_width()..]
                                [..variant_type_bit_width]
                                .cast_bits_to(variant_type.fold(self)?),
                        )
                    }
                    None => Expr::canonical(().to_expr()),
                },
                EnumTypeState::Unchanged => match unfolded_variant_type {
                    Some(_) => ops::VariantAccess::new_by_index(
                        Expr::from_canonical(folded_base_expr),
                        variant_index,
                    )
                    .to_expr(),
                    None => Expr::canonical(().to_expr()),
                },
            },
        )
    }

    /// Builds the statement that dispatches on `folded_expr` (a lowered value of
    /// `unfolded_enum_type`) to `folded_blocks`, which holds one block per
    /// variant in variant order.
    ///
    /// Enum-tagged lowerings keep a `match` on the tag; integer-tagged lowerings
    /// become an `if`/`else` chain via [`match_int_tag`].
    fn handle_match(
        &mut self,
        unfolded_enum_type: Enum,
        folded_expr: Expr<CanonicalType>,
        source_location: SourceLocation,
        folded_blocks: &[Block],
    ) -> Result<Stmt, SimplifyEnumsError> {
        match self.get_or_make_enum_type_state(unfolded_enum_type)? {
            EnumTypeState::TagEnumAndBody(_) => Ok(StmtMatch {
                expr: Expr::<TagAndBody<Enum, UInt>>::from_canonical(folded_expr).tag,
                source_location,
                blocks: folded_blocks.intern(),
            }
            .into()),
            EnumTypeState::TagUIntAndBody(_) => {
                let int_tag_expr = Expr::<TagAndBody<UInt, UInt>>::from_canonical(folded_expr).tag;
                Ok(match_int_tag(int_tag_expr, source_location, folded_blocks).into())
            }
            EnumTypeState::UInt(_) => {
                // The discriminant occupies the low bits of the flattened value.
                let int_tag_expr = Expr::<UInt>::from_canonical(folded_expr)
                    [..unfolded_enum_type.discriminant_bit_width()];
                Ok(match_int_tag(int_tag_expr, source_location, folded_blocks).into())
            }
            EnumTypeState::Unchanged => Ok(StmtMatch {
                expr: Expr::from_canonical(folded_expr),
                source_location,
                blocks: folded_blocks.intern(),
            }
            .into()),
        }
    }

    /// Expands a connect between two arrays of differing element types into one
    /// connect per index, appended to `output_stmts`.
    ///
    /// # Panics
    /// Panics if the arrays have different lengths.
    fn handle_stmt_connect_array(
        &mut self,
        unfolded_lhs_ty: Array,
        unfolded_rhs_ty: Array,
        folded_lhs: Expr<Array>,
        folded_rhs: Expr<Array>,
        source_location: SourceLocation,
        output_stmts: &mut Vec<Stmt>,
    ) -> Result<(), SimplifyEnumsError> {
        assert_eq!(unfolded_lhs_ty.len(), unfolded_rhs_ty.len());
        let unfolded_lhs_element_ty = unfolded_lhs_ty.element();
        let unfolded_rhs_element_ty = unfolded_rhs_ty.element();
        for array_index in 0..unfolded_lhs_ty.len() {
            self.handle_stmt_connect(
                unfolded_lhs_element_ty,
                unfolded_rhs_element_ty,
                folded_lhs[array_index],
                folded_rhs[array_index],
                source_location,
                output_stmts,
            )?;
        }
        Ok(())
    }

    /// Expands a connect between two bundles into one connect per field,
    /// appended to `output_stmts`. Flipped fields are connected in the opposite
    /// direction.
    ///
    /// # Panics
    /// Panics if the bundles differ in field count, field names, or flippedness.
    fn handle_stmt_connect_bundle(
        &mut self,
        unfolded_lhs_ty: Bundle,
        unfolded_rhs_ty: Bundle,
        folded_lhs: Expr<Bundle>,
        folded_rhs: Expr<Bundle>,
        source_location: SourceLocation,
        output_stmts: &mut Vec<Stmt>,
    ) -> Result<(), SimplifyEnumsError> {
        let unfolded_lhs_fields = unfolded_lhs_ty.fields();
        let unfolded_rhs_fields = unfolded_rhs_ty.fields();
        assert_eq!(unfolded_lhs_fields.len(), unfolded_rhs_fields.len());
        for (
            field_index,
            (
                &BundleField {
                    name,
                    flipped,
                    ty: unfolded_lhs_field_ty,
                },
                unfolded_rhs_field,
            ),
        ) in unfolded_lhs_fields
            .iter()
            .zip(&unfolded_rhs_fields)
            .enumerate()
        {
            assert_eq!(name, unfolded_rhs_field.name);
            assert_eq!(flipped, unfolded_rhs_field.flipped);
            let folded_lhs_field =
                ops::FieldAccess::new_by_index(folded_lhs, field_index).to_expr();
            let folded_rhs_field =
                ops::FieldAccess::new_by_index(folded_rhs, field_index).to_expr();
            if flipped {
                self.handle_stmt_connect(
                    unfolded_rhs_field.ty,
                    unfolded_lhs_field_ty,
                    folded_rhs_field,
                    folded_lhs_field,
                    source_location,
                    output_stmts,
                )?;
            } else {
                self.handle_stmt_connect(
                    unfolded_lhs_field_ty,
                    unfolded_rhs_field.ty,
                    folded_lhs_field,
                    folded_rhs_field,
                    source_location,
                    output_stmts,
                )?;
            }
        }
        Ok(())
    }

    /// Expands a connect between two distinct enum types into a `match` on the
    /// rhs, appended to `output_stmts`.
    ///
    /// For each variant, the generated block:
    /// 1. declares a fresh `__connect_variant_body` wire of the folded lhs payload
    ///    type (only if the variant has a payload),
    /// 2. connects that wire from the rhs payload (recursively expanded),
    /// 3. connects `folded_lhs` from an lhs enum literal of the same variant that
    ///    carries the wire.
    ///
    /// # Panics
    /// Panics if the enums differ in variant count, variant names, or in which
    /// variants have payloads.
    fn handle_stmt_connect_enum(
        &mut self,
        unfolded_lhs_ty: Enum,
        unfolded_rhs_ty: Enum,
        folded_lhs: Expr<CanonicalType>,
        folded_rhs: Expr<CanonicalType>,
        source_location: SourceLocation,
        output_stmts: &mut Vec<Stmt>,
    ) -> Result<(), SimplifyEnumsError> {
        let unfolded_lhs_variants = unfolded_lhs_ty.variants();
        let unfolded_rhs_variants = unfolded_rhs_ty.variants();
        assert_eq!(unfolded_lhs_variants.len(), unfolded_rhs_variants.len());
        let mut folded_blocks = vec![];
        for (
            variant_index,
            (
                &EnumVariant {
                    name,
                    ty: unfolded_lhs_variant_ty,
                },
                unfolded_rhs_variant,
            ),
        ) in unfolded_lhs_variants
            .iter()
            .zip(&unfolded_rhs_variants)
            .enumerate()
        {
            // Statements of this variant's match arm. Deliberately NOT named
            // `output_stmts`, so it cannot be confused with the caller's vector.
            let mut variant_stmts = vec![];
            assert_eq!(name, unfolded_rhs_variant.name);
            assert_eq!(
                unfolded_lhs_variant_ty.is_some(),
                unfolded_rhs_variant.ty.is_some()
            );
            let folded_variant_value =
                if let (Some(unfolded_lhs_variant_ty), Some(unfolded_rhs_variant_ty)) =
                    (unfolded_lhs_variant_ty, unfolded_rhs_variant.ty)
                {
                    let lhs_wire = Wire::new_unchecked(
                        self.module_state_stack
                            .last_mut()
                            .unwrap()
                            .gen_name("__connect_variant_body"),
                        source_location,
                        unfolded_lhs_variant_ty.fold(self)?,
                    );
                    variant_stmts.push(
                        StmtWire {
                            annotations: Interned::default(),
                            wire: lhs_wire,
                        }
                        .into(),
                    );
                    let lhs_wire = lhs_wire.to_expr();
                    let folded_rhs_variant =
                        self.handle_variant_access(unfolded_rhs_ty, folded_rhs, variant_index)?;
                    self.handle_stmt_connect(
                        unfolded_lhs_variant_ty,
                        unfolded_rhs_variant_ty,
                        lhs_wire,
                        folded_rhs_variant,
                        source_location,
                        &mut variant_stmts,
                    )?;
                    Some(lhs_wire)
                } else {
                    None
                };
            variant_stmts.push(
                StmtConnect {
                    lhs: folded_lhs,
                    rhs: self.handle_enum_literal(
                        unfolded_lhs_ty,
                        variant_index,
                        folded_variant_value,
                    )?,
                    source_location,
                }
                .into(),
            );
            folded_blocks.push(Block {
                memories: Interned::default(),
                stmts: Intern::intern_owned(variant_stmts),
            });
        }
        // The match itself goes into the caller's statement list.
        output_stmts.push(self.handle_match(
            unfolded_rhs_ty,
            folded_rhs,
            source_location,
            &folded_blocks,
        )?);
        Ok(())
    }

    /// Appends statements connecting `folded_lhs` from `folded_rhs` to
    /// `output_stmts`.
    ///
    /// If the unfolded types are identical, or neither contains an enum, a single
    /// `StmtConnect` is emitted. Otherwise the connect is expanded per array
    /// element, per bundle field, or per enum variant, based on the lhs type.
    fn handle_stmt_connect(
        &mut self,
        unfolded_lhs_ty: CanonicalType,
        unfolded_rhs_ty: CanonicalType,
        folded_lhs: Expr<CanonicalType>,
        folded_rhs: Expr<CanonicalType>,
        source_location: SourceLocation,
        output_stmts: &mut Vec<Stmt>,
    ) -> Result<(), SimplifyEnumsError> {
        let needs_expansion = unfolded_lhs_ty != unfolded_rhs_ty
            && (contains_any_enum_types(unfolded_lhs_ty)
                || contains_any_enum_types(unfolded_rhs_ty));
        if !needs_expansion {
            output_stmts.push(
                StmtConnect {
                    lhs: folded_lhs,
                    rhs: folded_rhs,
                    source_location,
                }
                .into(),
            );
            return Ok(());
        }
        match unfolded_lhs_ty {
            CanonicalType::Array(unfolded_lhs_ty) => self.handle_stmt_connect_array(
                unfolded_lhs_ty,
                Array::from_canonical(unfolded_rhs_ty),
                Expr::from_canonical(folded_lhs),
                Expr::from_canonical(folded_rhs),
                source_location,
                output_stmts,
            ),
            CanonicalType::Enum(unfolded_lhs_ty) => self.handle_stmt_connect_enum(
                unfolded_lhs_ty,
                Enum::from_canonical(unfolded_rhs_ty),
                folded_lhs,
                folded_rhs,
                source_location,
                output_stmts,
            ),
            CanonicalType::Bundle(unfolded_lhs_ty) => self.handle_stmt_connect_bundle(
                unfolded_lhs_ty,
                Bundle::from_canonical(unfolded_rhs_ty),
                Expr::from_canonical(folded_lhs),
                Expr::from_canonical(folded_rhs),
                source_location,
                output_stmts,
            ),
            // A scalar lhs contains no enum, so expansion would require connecting
            // it to a different, enum-containing rhs type.
            CanonicalType::UInt(_)
            | CanonicalType::SInt(_)
            | CanonicalType::Bool(_)
            | CanonicalType::AsyncReset(_)
            | CanonicalType::SyncReset(_)
            | CanonicalType::Reset(_)
            | CanonicalType::Clock(_) => unreachable!(),
        }
    }
}
/// Appends statements to `stmts` that connect `lhs` from `rhs`, recursing
/// through bundles and arrays wherever the two sides' types differ.
///
/// Used when wiring a rebuilt memory port to its replacement wire.
/// - When one side is a bundle and the other a single `UInt`/`Bool`, the scalar
///   is paired with every field of the bundle in turn. This presumably fans a
///   single mask bit out to a bundle of mask bits.
/// - Flipped fields in a bundle/bundle pair are connected in the opposite
///   direction.
///
/// # Panics
/// Panics if a bundle field paired with a scalar is flipped, or if the pair of
/// types is not one of the supported shapes.
fn connect_port(
    stmts: &mut Vec<Stmt>,
    lhs: Expr<CanonicalType>,
    rhs: Expr<CanonicalType>,
    source_location: SourceLocation,
) {
    let lhs_ty = Expr::ty(lhs);
    let rhs_ty = Expr::ty(rhs);
    match (lhs_ty, rhs_ty) {
        // Identical types need no expansion.
        _ if lhs_ty == rhs_ty => stmts.push(
            StmtConnect {
                lhs,
                rhs,
                source_location,
            }
            .into(),
        ),
        (CanonicalType::Bundle(lhs_bundle), CanonicalType::UInt(_) | CanonicalType::Bool(_)) => {
            let lhs_bundle_expr = Expr::<Bundle>::from_canonical(lhs);
            for field in lhs_bundle.fields() {
                assert!(!field.flipped);
                let lhs_field = Expr::field(lhs_bundle_expr, &field.name);
                connect_port(stmts, lhs_field, rhs, source_location);
            }
        }
        (CanonicalType::UInt(_) | CanonicalType::Bool(_), CanonicalType::Bundle(rhs_bundle)) => {
            let rhs_bundle_expr = Expr::<Bundle>::from_canonical(rhs);
            for field in rhs_bundle.fields() {
                assert!(!field.flipped);
                let rhs_field = Expr::field(rhs_bundle_expr, &field.name);
                connect_port(stmts, lhs, rhs_field, source_location);
            }
        }
        (CanonicalType::Bundle(lhs_bundle), CanonicalType::Bundle(_)) => {
            let lhs_bundle_expr = Expr::<Bundle>::from_canonical(lhs);
            let rhs_bundle_expr = Expr::<Bundle>::from_canonical(rhs);
            for field in lhs_bundle.fields() {
                let lhs_field = Expr::field(lhs_bundle_expr, &field.name);
                let rhs_field = Expr::field(rhs_bundle_expr, &field.name);
                if field.flipped {
                    connect_port(stmts, rhs_field, lhs_field, source_location);
                } else {
                    connect_port(stmts, lhs_field, rhs_field, source_location);
                }
            }
        }
        (CanonicalType::Array(lhs_array), CanonicalType::Array(_)) => {
            let lhs_array_expr = Expr::<Array>::from_canonical(lhs);
            let rhs_array_expr = Expr::<Array>::from_canonical(rhs);
            for element_index in 0..lhs_array.len() {
                connect_port(
                    stmts,
                    lhs_array_expr[element_index],
                    rhs_array_expr[element_index],
                    source_location,
                );
            }
        }
        (CanonicalType::Bundle(_), _)
        | (CanonicalType::Enum(_), _)
        | (CanonicalType::Array(_), _)
        | (CanonicalType::UInt(_), _)
        | (CanonicalType::SInt(_), _)
        | (CanonicalType::Bool(_), _)
        | (CanonicalType::Clock(_), _)
        | (CanonicalType::AsyncReset(_), _)
        | (CanonicalType::SyncReset(_), _)
        | (CanonicalType::Reset(_), _) => unreachable!(
            "trying to connect memory ports:\n{:?}\n{:?}",
            lhs_ty,
            rhs_ty,
        ),
    }
}
/// Lowers a `match` over an integer tag into a chain of nested `if`/`else`
/// statements.
///
/// - `folded_blocks[i]` runs when `int_tag_expr == i`.
/// - The final block is the fallthrough `else` of the innermost `if`, so it is
///   selected without comparing the tag.
/// - With exactly one block, `if true { block }` is returned.
/// - With no blocks, `if true {}` is returned.
fn match_int_tag(
    int_tag_expr: Expr<UInt>,
    source_location: SourceLocation,
    folded_blocks: &[Block],
) -> StmtIf {
    let tag_is = |variant_index: usize| {
        int_tag_expr.cmp_eq(Expr::ty(int_tag_expr).from_int_wrapping(variant_index))
    };
    match folded_blocks {
        [] => StmtIf {
            cond: true.to_expr(),
            source_location,
            blocks: [Block::default(), Block::default()],
        },
        [only_block] => StmtIf {
            cond: true.to_expr(),
            source_location,
            blocks: [*only_block, Block::default()],
        },
        [earlier_blocks @ .., next_to_last_block, last_block] => {
            // Innermost test: the next-to-last variant, else the last one.
            let innermost = StmtIf {
                cond: tag_is(earlier_blocks.len()),
                source_location,
                blocks: [*next_to_last_block, *last_block],
            };
            // Wrap outward so variant 0 ends up tested first.
            earlier_blocks
                .iter()
                .enumerate()
                .rev()
                .fold(innermost, |else_branch, (variant_index, &block)| StmtIf {
                    cond: tag_is(variant_index),
                    source_location,
                    blocks: [
                        block,
                        Block {
                            memories: Default::default(),
                            stmts: [Stmt::from(else_branch)][..].intern(),
                        },
                    ],
                })
        }
    }
}
/// The `simplify_enums` pass itself. It folds a module, rewriting enum types,
/// enum literals, variant accesses, `match` statements, connects, and memories
/// according to `self.kind`.
impl Folder for State {
    type Error = SimplifyEnumsError;
    // Enum types are handled in `fold_canonical_type`, which never falls back to
    // the default fold for them, so this is not expected to be reached.
    fn fold_enum(&mut self, _v: Enum) -> Result<Enum, Self::Error> {
        unreachable!()
    }
    // Push a `ModuleState` while this module is folded, so that `gen_name`
    // produces names scoped to it.
    fn fold_module<T: BundleType>(&mut self, v: Module<T>) -> Result<Module<T>, Self::Error> {
        self.module_state_stack.push(ModuleState {
            module_name: v.name_id(),
        });
        let retval = Fold::default_fold(v, self);
        // Pop before returning (including on error) so the stack stays balanced.
        self.module_state_stack.pop();
        retval
    }
    fn fold_expr_enum(&mut self, op: ExprEnum) -> Result<ExprEnum, Self::Error> {
        match op {
            // Lowered using the *unfolded* enum type (`op.ty()`) plus the folded
            // payload.
            ExprEnum::EnumLiteral(op) => {
                let folded_variant_value = op.variant_value().map(|v| v.fold(self)).transpose()?;
                Ok(*Expr::expr_enum(self.handle_enum_literal(
                    op.ty(),
                    op.variant_index(),
                    folded_variant_value,
                )?))
            }
            // Lowered using the unfolded enum type of the base plus the folded
            // base expression.
            ExprEnum::VariantAccess(op) => {
                let folded_base_expr = Expr::canonical(op.base()).fold(self)?;
                Ok(*Expr::expr_enum(self.handle_variant_access(
                    Expr::ty(op.base()),
                    folded_base_expr,
                    op.variant_index(),
                )?))
            }
            // Ports that `fold_block` replaced with a wire are redirected to that
            // wire; other ports are folded normally.
            ExprEnum::MemPort(mem_port) => Ok(
                if let Some(&wire) = self.replacement_mem_ports.get(&mem_port) {
                    ExprEnum::Wire(wire)
                } else {
                    ExprEnum::MemPort(mem_port.fold(self)?)
                },
            ),
            // All remaining expressions use the default structural fold. Any
            // types they carry are presumably lowered via `fold_canonical_type`.
            ExprEnum::UIntLiteral(_)
            | ExprEnum::SIntLiteral(_)
            | ExprEnum::BoolLiteral(_)
            | ExprEnum::BundleLiteral(_)
            | ExprEnum::ArrayLiteral(_)
            | ExprEnum::Uninit(_)
            | ExprEnum::NotU(_)
            | ExprEnum::NotS(_)
            | ExprEnum::NotB(_)
            | ExprEnum::Neg(_)
            | ExprEnum::BitAndU(_)
            | ExprEnum::BitAndS(_)
            | ExprEnum::BitAndB(_)
            | ExprEnum::BitOrU(_)
            | ExprEnum::BitOrS(_)
            | ExprEnum::BitOrB(_)
            | ExprEnum::BitXorU(_)
            | ExprEnum::BitXorS(_)
            | ExprEnum::BitXorB(_)
            | ExprEnum::AddU(_)
            | ExprEnum::AddS(_)
            | ExprEnum::SubU(_)
            | ExprEnum::SubS(_)
            | ExprEnum::MulU(_)
            | ExprEnum::MulS(_)
            | ExprEnum::DivU(_)
            | ExprEnum::DivS(_)
            | ExprEnum::RemU(_)
            | ExprEnum::RemS(_)
            | ExprEnum::DynShlU(_)
            | ExprEnum::DynShlS(_)
            | ExprEnum::DynShrU(_)
            | ExprEnum::DynShrS(_)
            | ExprEnum::FixedShlU(_)
            | ExprEnum::FixedShlS(_)
            | ExprEnum::FixedShrU(_)
            | ExprEnum::FixedShrS(_)
            | ExprEnum::CmpLtB(_)
            | ExprEnum::CmpLeB(_)
            | ExprEnum::CmpGtB(_)
            | ExprEnum::CmpGeB(_)
            | ExprEnum::CmpEqB(_)
            | ExprEnum::CmpNeB(_)
            | ExprEnum::CmpLtU(_)
            | ExprEnum::CmpLeU(_)
            | ExprEnum::CmpGtU(_)
            | ExprEnum::CmpGeU(_)
            | ExprEnum::CmpEqU(_)
            | ExprEnum::CmpNeU(_)
            | ExprEnum::CmpLtS(_)
            | ExprEnum::CmpLeS(_)
            | ExprEnum::CmpGtS(_)
            | ExprEnum::CmpGeS(_)
            | ExprEnum::CmpEqS(_)
            | ExprEnum::CmpNeS(_)
            | ExprEnum::CastUIntToUInt(_)
            | ExprEnum::CastUIntToSInt(_)
            | ExprEnum::CastSIntToUInt(_)
            | ExprEnum::CastSIntToSInt(_)
            | ExprEnum::CastBoolToUInt(_)
            | ExprEnum::CastBoolToSInt(_)
            | ExprEnum::CastUIntToBool(_)
            | ExprEnum::CastSIntToBool(_)
            | ExprEnum::CastBoolToSyncReset(_)
            | ExprEnum::CastUIntToSyncReset(_)
            | ExprEnum::CastSIntToSyncReset(_)
            | ExprEnum::CastBoolToAsyncReset(_)
            | ExprEnum::CastUIntToAsyncReset(_)
            | ExprEnum::CastSIntToAsyncReset(_)
            | ExprEnum::CastSyncResetToBool(_)
            | ExprEnum::CastSyncResetToUInt(_)
            | ExprEnum::CastSyncResetToSInt(_)
            | ExprEnum::CastSyncResetToReset(_)
            | ExprEnum::CastAsyncResetToBool(_)
            | ExprEnum::CastAsyncResetToUInt(_)
            | ExprEnum::CastAsyncResetToSInt(_)
            | ExprEnum::CastAsyncResetToReset(_)
            | ExprEnum::CastResetToBool(_)
            | ExprEnum::CastResetToUInt(_)
            | ExprEnum::CastResetToSInt(_)
            | ExprEnum::CastBoolToClock(_)
            | ExprEnum::CastUIntToClock(_)
            | ExprEnum::CastSIntToClock(_)
            | ExprEnum::CastClockToBool(_)
            | ExprEnum::CastClockToUInt(_)
            | ExprEnum::CastClockToSInt(_)
            | ExprEnum::FieldAccess(_)
            | ExprEnum::ArrayIndex(_)
            | ExprEnum::DynArrayIndex(_)
            | ExprEnum::ReduceBitAndU(_)
            | ExprEnum::ReduceBitAndS(_)
            | ExprEnum::ReduceBitOrU(_)
            | ExprEnum::ReduceBitOrS(_)
            | ExprEnum::ReduceBitXorU(_)
            | ExprEnum::ReduceBitXorS(_)
            | ExprEnum::SliceUInt(_)
            | ExprEnum::SliceSInt(_)
            | ExprEnum::CastToBits(_)
            | ExprEnum::CastBitsTo(_)
            | ExprEnum::ModuleIO(_)
            | ExprEnum::Instance(_)
            | ExprEnum::Wire(_)
            | ExprEnum::Reg(_) => op.default_fold(self),
        }
    }
    // Rebuilds memories whose element type changes, then folds the block's
    // statements. Statements generated for rebuilt memories (replacement wires
    // and their connects) come first in the resulting block.
    fn fold_block(&mut self, block: Block) -> Result<Block, Self::Error> {
        let mut memories = vec![];
        let mut stmts = vec![];
        for memory in block.memories {
            let old_element_ty = memory.array_type().element();
            let new_element_ty = memory.array_type().element().fold(self)?;
            if new_element_ty != old_element_ty {
                // Recreate every port with the new element type.
                let mut new_ports = vec![];
                for port in memory.ports() {
                    let new_port = MemPort::<DynPortType>::new_unchecked(
                        port.mem_name(),
                        port.source_location(),
                        port.port_name(),
                        port.addr_type(),
                        new_element_ty,
                    );
                    new_ports.push(new_port);
                    let new_port_ty = new_port.ty();
                    // The replacement wire has the new port's field types, except
                    // the write-mask field (if any), which keeps the ORIGINAL
                    // port's mask type.
                    let mut wire_ty_fields = Vec::from_iter(new_port_ty.fields());
                    if let Some(wmask_name) = new_port.port_kind().wmask_name() {
                        // The field index found in the new port is reused for the
                        // old port; this assumes both list fields in the same order.
                        let index = *new_port_ty
                            .name_indexes()
                            .get(&wmask_name.intern())
                            .unwrap();
                        wire_ty_fields[index].ty = port.ty().fields()[index].ty;
                    }
                    let wire_ty = Bundle::new(Intern::intern_owned(wire_ty_fields));
                    // Same type as the new port: no wire is needed. Uses of the old
                    // port are folded directly in `fold_expr_enum`.
                    if wire_ty == new_port_ty {
                        continue;
                    }
                    let wire = Wire::new_unchecked(
                        self.module_state_stack
                            .last_mut()
                            .unwrap()
                            .gen_name(&format!(
                                "{}_{}",
                                memory.scoped_name().1 .0,
                                port.port_name()
                            )),
                        port.source_location(),
                        wire_ty,
                    );
                    stmts.push(
                        StmtWire {
                            annotations: Default::default(),
                            wire: wire.canonical(),
                        }
                        .into(),
                    );
                    // Connect the new port (lhs) with the wire (rhs). Mismatched
                    // fields (the mask) are expanded by `connect_port`.
                    connect_port(
                        &mut stmts,
                        Expr::canonical(new_port.to_expr()),
                        Expr::canonical(wire.to_expr()),
                        port.source_location(),
                    );
                    // Later uses of the old port are redirected to the wire.
                    self.replacement_mem_ports.insert(port, wire.canonical());
                }
                memories.push(Mem::new_unchecked(
                    memory.scoped_name(),
                    memory.source_location(),
                    ArrayType::new_dyn(new_element_ty, memory.array_type().len()),
                    memory.initial_value(),
                    Intern::intern_owned(new_ports),
                    memory.read_latency(),
                    memory.write_latency(),
                    memory.read_under_write(),
                    memory.port_annotations(),
                    memory.mem_annotations(),
                ));
            } else {
                memories.push(memory.fold(self)?);
            }
        }
        // Folded only after the memories above, so `replacement_mem_ports` already
        // holds this block's replacements when these statements are folded.
        stmts.extend_from_slice(&block.stmts.fold(self)?);
        Ok(Block {
            memories: Intern::intern_owned(memories),
            stmts: Intern::intern_owned(stmts),
        })
    }
    fn fold_stmt(&mut self, stmt: Stmt) -> Result<Stmt, Self::Error> {
        match stmt {
            // Lowered based on the *unfolded* scrutinee type.
            Stmt::Match(StmtMatch {
                expr,
                source_location,
                blocks,
            }) => {
                let folded_expr = Expr::canonical(expr).fold(self)?;
                let folded_blocks = blocks.fold(self)?;
                self.handle_match(Expr::ty(expr), folded_expr, source_location, &folded_blocks)
            }
            // May expand into several statements when enum types differ.
            Stmt::Connect(StmtConnect {
                lhs,
                rhs,
                source_location,
            }) => {
                let folded_lhs = lhs.fold(self)?;
                let folded_rhs = rhs.fold(self)?;
                let mut output_stmts = vec![];
                self.handle_stmt_connect(
                    Expr::ty(lhs),
                    Expr::ty(rhs),
                    folded_lhs,
                    folded_rhs,
                    source_location,
                    &mut output_stmts,
                )?;
                if output_stmts.len() == 1 {
                    Ok(output_stmts.pop().unwrap())
                } else {
                    // Zero or several statements: wrap them in `if true { ... }` so
                    // they can be returned as a single `Stmt`.
                    Ok(StmtIf {
                        cond: true.to_expr(),
                        source_location,
                        blocks: [
                            Block {
                                memories: Interned::default(),
                                stmts: Intern::intern_owned(output_stmts),
                            },
                            Block::default(),
                        ],
                    }
                    .into())
                }
            }
            Stmt::Formal(_) | Stmt::If(_) | Stmt::Declaration(_) => stmt.default_fold(self),
        }
    }
    // `match` statements are handled in `fold_stmt` instead.
    fn fold_stmt_match(&mut self, _v: StmtMatch) -> Result<StmtMatch, Self::Error> {
        unreachable!()
    }
    // Enum types are replaced by their lowered type; all other types use the
    // default structural fold.
    fn fold_canonical_type(
        &mut self,
        canonical_type: CanonicalType,
    ) -> Result<CanonicalType, Self::Error> {
        match canonical_type {
            CanonicalType::Enum(enum_type) => {
                Ok(match self.get_or_make_enum_type_state(enum_type)? {
                    EnumTypeState::TagEnumAndBody(ty) => ty.canonical(),
                    EnumTypeState::TagUIntAndBody(ty) => ty.canonical(),
                    EnumTypeState::UInt(ty) => ty.canonical(),
                    EnumTypeState::Unchanged => enum_type.canonical(),
                })
            }
            CanonicalType::Bundle(_)
            | CanonicalType::Array(_)
            | CanonicalType::UInt(_)
            | CanonicalType::SInt(_)
            | CanonicalType::Bool(_)
            | CanonicalType::Clock(_)
            | CanonicalType::AsyncReset(_)
            | CanonicalType::SyncReset(_)
            | CanonicalType::Reset(_) => canonical_type.default_fold(self),
        }
    }
    // Enum variants, literals and variant accesses are handled in
    // `fold_canonical_type` / `fold_expr_enum`, so these are not expected to be
    // reached.
    fn fold_enum_variant(&mut self, _v: EnumVariant) -> Result<EnumVariant, Self::Error> {
        unreachable!()
    }
    fn fold_enum_literal<T: EnumType + Fold<Self>>(
        &mut self,
        _v: ops::EnumLiteral<T>,
    ) -> Result<ops::EnumLiteral<T>, Self::Error> {
        unreachable!()
    }
    fn fold_variant_access<VariantType: Type>(
        &mut self,
        _v: ops::VariantAccess<VariantType>,
    ) -> Result<ops::VariantAccess<VariantType>, Self::Error> {
        unreachable!()
    }
}
// Selects how `simplify_enums` lowers enum types.
//
// Plain `//` comments are used on purpose: `clap::ValueEnum` picks up variant
// doc comments as CLI help text, so `///` here would change `--help` output.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, clap::ValueEnum)]
pub enum SimplifyEnumsKind {
    // Enums with payloads become a bundle of a payload-free tag enum plus a
    // `UInt` body; enums without payloads are left unchanged.
    SimplifyToEnumsWithNoBody,
    // Every enum becomes a bundle of a `UInt` tag plus a `UInt` body.
    #[clap(name = "replace-with-bundle-of-uints")]
    ReplaceWithBundleOfUInts,
    // Every enum becomes a single `UInt` of the enum's full bit width.
    #[clap(name = "replace-with-uint")]
    ReplaceWithUInt,
}
/// Runs the enum-simplification pass over `module`, lowering enum types (and
/// the expressions, `match` statements, connects and memories that use them)
/// according to `kind`.
///
/// # Errors
/// Returns [`SimplifyEnumsError::EnumIsNotCastableFromBits`] if an enum type
/// encountered during the pass is not castable from bits.
pub fn simplify_enums(
    module: Interned<Module<Bundle>>,
    kind: SimplifyEnumsKind,
) -> Result<Interned<Module<Bundle>>, SimplifyEnumsError> {
    let mut state = State {
        kind,
        enum_types: HashMap::new(),
        replacement_mem_ports: HashMap::new(),
        module_state_stack: Vec::new(),
    };
    module.fold(&mut state)
}