use std::str::FromStr;
use machine_check_common::{
ir_common::{IrReference, IrStdBinaryOp, IrStdUnaryOp, IrTypeArray},
Signedness,
};
use syn::{
punctuated::Punctuated, spanned::Spanned, token::Comma, Expr, ExprBinary, ExprCall, ExprField,
ExprIndex, ExprLit, ExprReference, ExprStruct, ExprUnary, GenericArgument, Lit, Member, Path,
PathArguments, PathSegment, UnOp,
};
use syn_path::path;
use crate::{
into_wir::{
from_syn::{path::fold_path, ty::fold_type},
Error, ErrorType,
},
util::{create_expr_call, create_expr_ident, create_expr_path, ArgType},
wir::{
WArrayBaseExpr, WBlock, WCall, WCallArg, WExpr, WExprField, WExprHighCall, WExprReference,
WExprStruct, WHighMckExt, WHighMckNew, WHighStdInto, WIdent, WIfCondition, WIndexedExpr,
WIndexedIdent, WMacroableStmt, WNoIfPolarity, WPartialBasicType, WSpan, WStdBinary,
WStdUnary, WStmtAssign, WStmtIf, WType, ZTac, MCK_HIGH_BITVECTOR_ARRAY_NEW,
MCK_HIGH_BITVECTOR_NEW, MCK_HIGH_EXT, MCK_HIGH_SIGNED_NEW, MCK_HIGH_UNSIGNED_NEW,
STD_CLONE, STD_INTO,
},
};
use super::FunctionFolder;
impl super::FunctionFolder {
    /// Folds a right-hand-side expression into a (possibly indexed) WIR expression.
    ///
    /// Any statements needed to compute subexpressions (e.g. assignments to
    /// temporaries) are appended to `stmts`.
    ///
    /// # Errors
    /// Returns an error if the expression uses an unsupported construct.
    pub fn fold_right_expr(
        &mut self,
        expr: Expr,
        stmts: &mut Vec<WMacroableStmt<ZTac>>,
    ) -> Result<WIndexedExpr<WExprHighCall>, Error> {
        RightExprFolder {
            fn_folder: self,
            stmts,
        }
        .fold_right_expr(expr)
    }

    /// Folds a right-hand-side expression and forces the result into an identifier.
    ///
    /// If the expression is not directly usable as an identifier, it is evaluated
    /// into a fresh temporary whose assignment is appended to `stmts`.
    ///
    /// # Errors
    /// Returns an error if the expression uses an unsupported construct.
    pub fn force_right_expr_to_ident(
        &mut self,
        expr: Expr,
        stmts: &mut Vec<WMacroableStmt<ZTac>>,
    ) -> Result<WIdent, Error> {
        // Elided lifetimes suffice: the folder only lives for this call, exactly as in
        // `fold_right_expr` above.
        RightExprFolder {
            fn_folder: self,
            stmts,
        }
        .force_ident(expr)
    }
}
/// Helper that folds right-hand-side `syn` expressions into WIR.
///
/// Statements generated along the way (temporary assignments, the branches of
/// short-circuiting operators) are appended to `stmts`.
struct RightExprFolder<'a> {
/// Function-level folding state: local identifier lookup, temporary identifier
/// creation and the `Self` type used when folding paths and types.
fn_folder: &'a mut FunctionFolder,
/// Statement list that receives the statements generated while folding.
stmts: &'a mut Vec<WMacroableStmt<ZTac>>,
}
impl RightExprFolder<'_> {
pub fn fold_right_expr(&mut self, expr: Expr) -> Result<WIndexedExpr<WExprHighCall>, Error> {
Ok(match expr {
Expr::Call(expr_call) => {
WIndexedExpr::NonIndexed(WExpr::Call(self.fold_right_expr_call(expr_call)?))
}
Expr::Field(expr_field) => {
WIndexedExpr::NonIndexed(WExpr::Field(self.fold_right_expr_field(expr_field)?))
}
Expr::Path(_) => {
WIndexedExpr::NonIndexed(WExpr::Move(self.fn_folder.fold_expr_as_ident(expr)?))
}
Expr::Struct(expr_struct) => {
WIndexedExpr::NonIndexed(WExpr::Struct(self.fold_right_expr_struct(expr_struct)?))
}
Expr::Reference(expr_reference) => WIndexedExpr::NonIndexed(WExpr::Reference(
self.fold_right_expr_reference(expr_reference)?,
)),
Expr::Lit(expr_lit) => WIndexedExpr::NonIndexed(WExpr::Lit(expr_lit.lit, false)),
Expr::Index(expr_index) => self.fold_right_expr_index(expr_index)?,
Expr::Binary(expr_binary) => self.fold_binary(expr_binary)?,
Expr::Unary(expr_unary) => {
if let UnOp::Neg(_) = expr_unary.op {
if let Expr::Lit(expr_lit) = *expr_unary.expr {
return Ok(WIndexedExpr::NonIndexed(WExpr::Lit(expr_lit.lit, true)));
}
}
self.fold_unary(expr_unary)?
}
Expr::Paren(expr_paren) => {
self.fold_right_expr(*expr_paren.expr)?
}
Expr::Group(expr_group) => {
self.fold_right_expr(*expr_group.expr)?
}
_ => return Err(Error::unsupported_syn_construct("Expression kind", &expr)),
})
}
/// Folds a function-call expression.
///
/// The callee path, stripped of all generic arguments, is compared against the
/// known standard-library and machine-check functions. Matching calls are lowered
/// to their dedicated high-level variants; any other call becomes a plain
/// [`WCall`] whose arguments are literals or identifiers.
///
/// # Errors
/// Returns an error for non-path callees, callees with a qualified self, callees
/// that name a local identifier, and malformed arguments or generics.
fn fold_right_expr_call(&mut self, expr_call: ExprCall) -> Result<WExprHighCall, Error> {
    let callee = match &*expr_call.func {
        Expr::Path(callee) => callee,
        _ => {
            return Err(Error::unsupported_syn_construct(
                "Non-path function operand",
                &expr_call,
            ))
        }
    };
    if callee.qself.is_some() {
        return Err(Error::unsupported_syn_construct(
            "Qualified self in function operand",
            &callee,
        ));
    }
    let fn_path = &callee.path;

    // Render the callee path without generics, e.g. `::std::ops::Add::add`.
    let segment_names: Vec<String> = fn_path
        .segments
        .iter()
        .map(|segment| segment.ident.to_string())
        .collect();
    let leading = if fn_path.leading_colon.is_some() {
        "::"
    } else {
        ""
    };
    let nongeneric_path_string = format!("{}{}", leading, segment_names.join("::"));

    if let Ok(op) = IrStdUnaryOp::from_str(&nongeneric_path_string) {
        return self.create_std_unary(op, fn_path, expr_call.args);
    }
    if let Ok(op) = IrStdBinaryOp::from_str(&nongeneric_path_string) {
        return self.create_std_binary(op, fn_path, expr_call.args);
    }
    match nongeneric_path_string.as_str() {
        MCK_HIGH_EXT => return self.create_mck_ext(fn_path, expr_call.args),
        MCK_HIGH_BITVECTOR_NEW
        | MCK_HIGH_UNSIGNED_NEW
        | MCK_HIGH_SIGNED_NEW
        | MCK_HIGH_BITVECTOR_ARRAY_NEW => return self.create_mck_new(fn_path, expr_call.args),
        STD_CLONE => return self.create_std_clone(fn_path, expr_call.args),
        STD_INTO => return self.create_std_into(fn_path, expr_call.args),
        _ => {}
    }

    // An ordinary call: the callee must not be a bare identifier naming a local.
    let wir_fn_path = fold_path(fn_path.clone(), self.fn_folder.self_ty.as_ref())?;
    let is_bare_ident =
        wir_fn_path.leading_colon.is_none() && wir_fn_path.segments.len() == 1;
    if is_bare_ident
        && self
            .fn_folder
            .lookup_local_ident(&wir_fn_path.segments[0].ident)
            .is_some()
    {
        return Err(Error::unsupported_syn_construct(
            "Local ident as function operand",
            &fn_path,
        ));
    }

    let args = expr_call
        .args
        .into_iter()
        .map(|arg| self.force_call_arg(arg))
        .collect::<Result<Vec<_>, _>>()?;
    Ok(WExprHighCall::Call(WCall {
        fn_path: wir_fn_path,
        args,
    }))
}
/// Lowers a call of a path recognised by `IrStdUnaryOp::from_str` into a
/// [`WStdUnary`] applied to its single identifier operand.
///
/// # Errors
/// Returns an error if the path carries generics or the call does not have
/// exactly one argument.
fn create_std_unary(
    &mut self,
    op: IrStdUnaryOp,
    fn_path: &Path,
    args: Punctuated<Expr, Comma>,
) -> Result<WExprHighCall, Error> {
    Self::assure_nongeneric_fn_path(fn_path)?;
    self.parse_single_ident_arg(args)
        .map(|operand| WExprHighCall::StdUnary(WStdUnary { op, operand }))
}
/// Lowers a call of a path recognised by `IrStdBinaryOp::from_str` into a
/// [`WStdBinary`] applied to its two identifier operands.
///
/// # Errors
/// Returns an error if the path carries generics or the call does not have
/// exactly two arguments.
fn create_std_binary(
    &mut self,
    op: IrStdBinaryOp,
    fn_path: &Path,
    args: Punctuated<Expr, Comma>,
) -> Result<WExprHighCall, Error> {
    Self::assure_nongeneric_fn_path(fn_path)?;
    self.parse_two_ident_args(args)
        .map(|(a, b)| WExprHighCall::StdBinary(WStdBinary { op, a, b }))
}
fn create_mck_ext(
&mut self,
fn_path: &Path,
args: Punctuated<Expr, Comma>,
) -> Result<WExprHighCall, Error> {
let mut fn_path = fn_path.clone();
let second_segment = &mut fn_path.segments[1];
let width = Self::parse_single_u32_generics_opt(second_segment)?;
second_segment.arguments = syn::PathArguments::None;
Self::assure_nongeneric_fn_path(&fn_path)?;
let from = self.parse_single_ident_arg(args)?;
Ok(WExprHighCall::MckExt(WHighMckExt { width, from }))
}
fn create_mck_new(
&mut self,
fn_path: &Path,
args: Punctuated<Expr, Comma>,
) -> Result<WExprHighCall, Error> {
let mut fn_path = fn_path.clone();
let second_segment = &mut fn_path.segments[1];
if second_segment.ident.to_string().as_str() == "BitvectorArray" {
let (index_width, element_width) = Self::parse_two_u32_generics(second_segment)?;
let fill_ident = self.parse_single_ident_arg(args)?;
return Ok(WExprHighCall::MckNew(WHighMckNew::BitvectorArray(
IrTypeArray {
index_width,
element_width,
},
fill_ident,
)));
}
let width = Self::parse_single_u32_generics_opt(second_segment)?;
second_segment.arguments = syn::PathArguments::None;
let value = self.parse_single_const_arg(args)?;
let kind = match second_segment.ident.to_string().as_str() {
"Bitvector" => WHighMckNew::Bitvector(Signedness::None, width, value),
"Unsigned" => WHighMckNew::Bitvector(Signedness::Unsigned, width, value),
"Signed" => WHighMckNew::Bitvector(Signedness::Signed, width, value),
_ => panic!("Unexpected function path here"),
};
Self::assure_nongeneric_fn_path(&fn_path)?;
Ok(WExprHighCall::MckNew(kind))
}
fn create_std_into(
&mut self,
fn_path: &Path,
args: Punctuated<Expr, Comma>,
) -> Result<WExprHighCall, Error> {
let mut fn_path = fn_path.clone();
let third_segment = &mut fn_path.segments[2];
let ty = self.parse_single_type_generics(third_segment)?;
third_segment.arguments = syn::PathArguments::None;
let IrReference::None = ty.reference else {
return Err(Error::unsupported_syn_construct(
"Reference type",
&third_segment,
));
};
let (signedness, width) = match ty.inner {
WPartialBasicType::Bitvector(signedness, width) => (signedness, width),
_ => {
return Err(Error::unsupported_syn_construct(
"Non-bitvector type",
&third_segment,
))
}
};
let from = self.parse_single_ident_arg(args)?;
Ok(WExprHighCall::StdInto(WHighStdInto {
signedness,
width,
from,
}))
}
/// Parses an optional single `u32` turbofish generic of `segment`.
///
/// Yields `Ok(None)` when the segment has no generic arguments at all. Otherwise it
/// defers to [`Self::parse_single_u32_generics`].
///
/// # Errors
/// Returns an error if generic arguments are present but are not exactly one
/// turbofished `u32` literal.
fn parse_single_u32_generics_opt(segment: &PathSegment) -> Result<Option<u32>, Error> {
    match segment.arguments {
        PathArguments::None => Ok(None),
        _ => Self::parse_single_u32_generics(segment).map(Some),
    }
}
fn parse_single_u32_generics(segment: &PathSegment) -> Result<u32, Error> {
let turbofished = Self::extract_turbofished(segment)?;
if turbofished.len() != 1 {
return Err(Error::new(
ErrorType::IllegalConstruct(String::from(
"Exactly one generic argument should be used here",
)),
WSpan::from_syn(segment),
));
}
Self::parse_u32_generic(&turbofished[0])
}
fn parse_two_u32_generics(segment: &PathSegment) -> Result<(u32, u32), Error> {
let turbofished = Self::extract_turbofished(segment)?;
if turbofished.len() != 2 {
return Err(Error::new(
ErrorType::IllegalConstruct(String::from(
"Exactly 2 generic arguments should be used here",
)),
WSpan::from_syn(&segment),
));
}
let first = Self::parse_u32_generic(&turbofished[0])?;
let second = Self::parse_u32_generic(&turbofished[1])?;
Ok((first, second))
}
fn parse_single_type_generics(
&self,
segment: &PathSegment,
) -> Result<WType<WPartialBasicType>, Error> {
let turbofished = Self::extract_turbofished(segment)?;
if turbofished.len() != 1 {
return Err(Error::new(
ErrorType::IllegalConstruct(String::from(
"Exactly one generic argument should be used here",
)),
WSpan::from_syn(segment),
));
}
let arg = &turbofished[0];
let GenericArgument::Type(arg) = arg else {
return Err(Error::unsupported_construct(
"Non-type generic argument",
WSpan::from_syn(segment),
));
};
let ty = fold_type(arg.clone(), self.fn_folder.self_ty.as_ref())?;
Ok(ty)
}
fn extract_turbofished(
segment: &PathSegment,
) -> Result<&Punctuated<GenericArgument, Comma>, Error> {
let PathArguments::AngleBracketed(generic_args) = &segment.arguments else {
return Err(Error::unsupported_construct(
"This call without generic argument",
WSpan::from_syn(segment),
));
};
if generic_args.colon2_token.is_none() {
return Err(Error::new(
ErrorType::IllegalConstruct(String::from("Turbofish should be used here")),
WSpan::from_syn(segment),
));
}
Ok(&generic_args.args)
}
fn parse_u32_generic(arg: &GenericArgument) -> Result<u32, Error> {
let GenericArgument::Const(Expr::Lit(ExprLit {
lit: Lit::Int(lit_int),
..
})) = arg
else {
return Err(Error::new(
ErrorType::IllegalConstruct(String::from(
"The generic argument here should be a literal",
)),
WSpan::from_syn(arg),
));
};
let result = lit_int.base10_parse();
let Ok(result) = result else {
return Err(Error::new(
ErrorType::IllegalConstruct(String::from(
"The generic argument here should be parseable as u32",
)),
WSpan::from_syn(arg),
));
};
Ok(result)
}
/// Lowers a `clone` call on a single identifier into [`WExprHighCall::StdClone`].
///
/// # Errors
/// Returns an error if the path carries generics or the call does not have
/// exactly one argument.
fn create_std_clone(
    &mut self,
    fn_path: &Path,
    args: Punctuated<Expr, Comma>,
) -> Result<WExprHighCall, Error> {
    Self::assure_nongeneric_fn_path(fn_path)?;
    self.parse_single_ident_arg(args)
        .map(WExprHighCall::StdClone)
}
fn parse_single_const_arg(&mut self, args: Punctuated<Expr, Comma>) -> Result<i128, Error> {
if args.len() != 1 {
return Err(Error::new(
ErrorType::IllegalConstruct(String::from("Exactly 1 argument expected")),
WSpan::from_syn(&args),
));
};
let mut arg = args.iter().next().unwrap();
let mut neg = false;
if let Expr::Unary(ExprUnary {
attrs: _,
op: UnOp::Neg(_),
expr,
}) = arg
{
neg = true;
arg = expr;
}
let Expr::Lit(ExprLit {
lit: Lit::Int(lit_int),
attrs: _attrs,
}) = arg
else {
return Err(Error::unsupported_construct(
"Non-integer-literal argument here",
WSpan::from_syn(&args),
));
};
let Ok(parsed) = lit_int.base10_parse::<u128>() else {
return Err(Error::new(
ErrorType::IllegalConstruct(String::from("Argument not parseable as constant")),
WSpan::from_syn(&lit_int),
));
};
Ok(if neg {
-(parsed as i128)
} else {
parsed as i128
})
}
fn parse_single_ident_arg(&mut self, args: Punctuated<Expr, Comma>) -> Result<WIdent, Error> {
if args.len() != 1 {
return Err(Error::new(
ErrorType::IllegalConstruct(String::from("Exactly 1 argument expected")),
WSpan::from_syn(&args),
));
};
self.force_ident(args.into_iter().next().unwrap())
}
fn parse_two_ident_args(
&mut self,
args: Punctuated<Expr, Comma>,
) -> Result<(WIdent, WIdent), Error> {
if args.len() != 2 {
return Err(Error::new(
ErrorType::IllegalConstruct(String::from("Exactly 2 arguments expected")),
WSpan::from_syn(&args),
));
};
let mut iter = args.into_iter();
let a = self.force_ident(iter.next().unwrap())?;
let b = self.force_ident(iter.next().unwrap())?;
Ok((a, b))
}
fn assure_nongeneric_fn_path(fn_path: &Path) -> Result<(), Error> {
for segment in &fn_path.segments {
if !segment.arguments.is_none() {
return Err(Error::unsupported_construct(
"Unexpected generics",
WSpan::from_syn(segment),
));
};
}
Ok(())
}
fn fold_right_expr_field(&mut self, expr_field: ExprField) -> Result<WExprField, Error> {
let base = self.fn_folder.fold_expr_as_ident(*expr_field.base)?;
let member = Self::extract_member(expr_field.member)?;
Ok(WExprField { base, member })
}
/// Folds a struct-literal expression such as `Point { x: a, y: b }`.
///
/// Each field value is forced into an identifier (through a temporary if needed),
/// in source order. The struct path is folded with the enclosing `Self` type
/// available.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the path has a qualified self;
/// - struct-update syntax (`..base`) is used;
/// - a member is unnamed;
/// - a field value or the path cannot be folded.
fn fold_right_expr_struct(&mut self, expr_struct: ExprStruct) -> Result<WExprStruct, Error> {
    if expr_struct.qself.is_some() {
        // `qself` is a *qualified* self (`<T as Trait>::...`). The message previously
        // said "Quantified self"; it now matches the one used for call operands.
        return Err(Error::unsupported_construct(
            "Qualified self",
            WSpan::from_syn(&expr_struct),
        ));
    }
    if expr_struct.rest.is_some() {
        return Err(Error::unsupported_construct(
            "Struct expressions with base",
            WSpan::from_syn(&expr_struct.rest),
        ));
    }
    let mut args = Vec::new();
    for field in expr_struct.fields {
        let member_ident = Self::extract_member(field.member)?;
        let member_value = self.force_ident(field.expr)?;
        args.push((member_ident, member_value))
    }
    Ok(WExprStruct {
        type_path: fold_path(expr_struct.path, self.fn_folder.self_ty.as_ref())?,
        fields: args,
    })
}
fn fold_right_expr_reference(
&mut self,
expr_reference: ExprReference,
) -> Result<WExprReference, Error> {
Ok(match *expr_reference.expr {
Expr::Path(expr_path) => {
WExprReference::Ident(self.fn_folder.fold_expr_as_ident(Expr::Path(expr_path))?)
}
Expr::Field(expr_field) => {
let member = Self::extract_member(expr_field.member)?;
WExprReference::Field(WExprField {
base: self.force_ident(*expr_field.base)?,
member,
})
}
_ => {
return Err(Error::unsupported_construct(
"Expression kind inside reference",
WSpan::from_syn(&expr_reference.expr),
))
}
})
}
fn fold_right_expr_index(
&mut self,
expr_index: ExprIndex,
) -> Result<WIndexedExpr<WExprHighCall>, Error> {
let array_base = match *expr_index.expr {
Expr::Path(expr_path) => {
WArrayBaseExpr::Ident(self.fn_folder.fold_expr_as_ident(Expr::Path(expr_path))?)
}
Expr::Field(expr_field) => {
let field_base = self.force_ident(*expr_field.base)?;
let member = Self::extract_member(expr_field.member)?;
WArrayBaseExpr::Field(WExprField {
base: field_base,
member,
})
}
_ => {
return Err(Error::unsupported_construct(
"Expression kind as array base",
WSpan::from_syn(&expr_index.expr),
))
}
};
let index_ident = self.force_ident(*expr_index.index)?;
Ok(WIndexedExpr::Indexed(array_base, index_ident))
}
/// Converts a named struct member into a WIR identifier.
///
/// # Errors
/// Returns an error for unnamed (tuple-index) members such as `.0`.
fn extract_member(member: Member) -> Result<WIdent, Error> {
    match member {
        Member::Unnamed(index) => Err(Error::unsupported_construct(
            "Unnamed members",
            WSpan::from_syn(&index),
        )),
        Member::Named(name) => Ok(WIdent::from_syn_ident(name)),
    }
}
fn force_call_arg(&mut self, expr: Expr) -> Result<WCallArg, Error> {
if let Expr::Lit(lit) = expr {
return Ok(WCallArg::Literal(lit.lit));
}
Ok(WCallArg::Ident(self.force_ident(expr)?))
}
/// Returns `expr` as an identifier.
///
/// Uses `fold_expr_as_ident` directly when it succeeds. Otherwise the expression is
/// evaluated into a fresh temporary via [`Self::move_through_temp`].
///
/// # Errors
/// Returns an error if the expression cannot be folded at all.
fn force_ident(&mut self, expr: Expr) -> Result<WIdent, Error> {
    match self.fn_folder.fold_expr_as_ident(expr.clone()) {
        Ok(direct) => Ok(direct),
        Err(_) => self.move_through_temp(expr),
    }
}
fn move_through_temp(&mut self, expr: Expr) -> Result<WIdent, Error> {
let expr_span = expr.span();
let expr = match expr {
syn::Expr::Path(_) => {
return self.fn_folder.fold_expr_as_ident(expr);
}
syn::Expr::Paren(paren) => {
return self.move_through_temp(*paren.expr);
}
syn::Expr::Lit(ExprLit {
lit: Lit::Bool(lit),
..
}) => {
WIndexedExpr::NonIndexed(WExpr::Call(WExprHighCall::BooleanNew(lit.value)))
}
_ => {
self.fold_right_expr(expr)?
}
};
let tmp_ident = self
.fn_folder
.ident_creator
.create_temporary_ident(expr_span);
self.stmts.push(WMacroableStmt::Assign(WStmtAssign {
left: WIndexedIdent::NonIndexed(tmp_ident.clone()),
right: expr,
}));
Ok(tmp_ident)
}
/// Unconditionally assigns the folded `expr` to a fresh temporary, even if it is
/// already an identifier, and returns that temporary.
///
/// The temporary is created before `expr` is folded, so it is allocated ahead of any
/// temporaries that folding `expr` itself creates.
///
/// # Errors
/// Returns an error if the expression cannot be folded.
fn force_assign_to_temp(&mut self, expr: Expr) -> Result<WIdent, Error> {
    let span = expr.span();
    let temp = self.fn_folder.ident_creator.create_temporary_ident(span);
    let right = self.fold_right_expr(expr)?;
    let left = WIndexedIdent::NonIndexed(temp.clone());
    self.stmts
        .push(WMacroableStmt::Assign(WStmtAssign { left, right }));
    Ok(temp)
}
fn fold_unary(&mut self, expr_unary: ExprUnary) -> Result<WIndexedExpr<WExprHighCall>, Error> {
let path = match expr_unary.op {
syn::UnOp::Deref(_) => {
return Err(Error::unsupported_syn_construct(
"Dereference",
&expr_unary.op,
))
}
syn::UnOp::Not(_) => path!(::std::ops::Not::not),
syn::UnOp::Neg(_) => path!(::std::ops::Neg::neg),
_ => {
return Err(Error::unsupported_syn_construct(
"Unary operator",
&expr_unary.op,
));
}
};
let call = create_expr_call(
create_expr_path(path),
vec![(ArgType::Normal, *expr_unary.expr)],
);
self.fold_right_expr(call)
}
/// Folds a binary expression.
///
/// `&&` and `||` are handed to [`Self::fold_short_circuiting`]. Every other
/// supported operator is rewritten into a call of the matching `std::ops` /
/// `std::cmp` trait method (e.g. `a + b` becomes `::std::ops::Add::add(a, b)`),
/// which is then folded like any other call.
///
/// # Errors
/// Returns an error for compound-assignment operators and any other unsupported
/// operator, or if folding fails.
fn fold_binary(
    &mut self,
    expr_binary: ExprBinary,
) -> Result<WIndexedExpr<WExprHighCall>, Error> {
    let ExprBinary {
        left, op, right, ..
    } = expr_binary;
    let trait_method = match op {
        // Short-circuiting logic.
        syn::BinOp::And(_) => return self.fold_short_circuiting(true, *left, *right),
        syn::BinOp::Or(_) => return self.fold_short_circuiting(false, *left, *right),
        // Arithmetic.
        syn::BinOp::Add(_) => path!(::std::ops::Add::add),
        syn::BinOp::Sub(_) => path!(::std::ops::Sub::sub),
        syn::BinOp::Mul(_) => path!(::std::ops::Mul::mul),
        syn::BinOp::Div(_) => path!(::std::ops::Div::div),
        syn::BinOp::Rem(_) => path!(::std::ops::Rem::rem),
        // Bitwise logic and shifts.
        syn::BinOp::BitAnd(_) => path!(::std::ops::BitAnd::bitand),
        syn::BinOp::BitOr(_) => path!(::std::ops::BitOr::bitor),
        syn::BinOp::BitXor(_) => path!(::std::ops::BitXor::bitxor),
        syn::BinOp::Shl(_) => path!(::std::ops::Shl::shl),
        syn::BinOp::Shr(_) => path!(::std::ops::Shr::shr),
        // Comparisons.
        syn::BinOp::Eq(_) => path!(::std::cmp::PartialEq::eq),
        syn::BinOp::Ne(_) => path!(::std::cmp::PartialEq::ne),
        syn::BinOp::Lt(_) => path!(::std::cmp::PartialOrd::lt),
        syn::BinOp::Le(_) => path!(::std::cmp::PartialOrd::le),
        syn::BinOp::Gt(_) => path!(::std::cmp::PartialOrd::gt),
        syn::BinOp::Ge(_) => path!(::std::cmp::PartialOrd::ge),
        syn::BinOp::AddAssign(_)
        | syn::BinOp::SubAssign(_)
        | syn::BinOp::MulAssign(_)
        | syn::BinOp::DivAssign(_)
        | syn::BinOp::RemAssign(_)
        | syn::BinOp::BitXorAssign(_)
        | syn::BinOp::BitAndAssign(_)
        | syn::BinOp::BitOrAssign(_)
        | syn::BinOp::ShlAssign(_)
        | syn::BinOp::ShrAssign(_) => {
            return Err(Error::unsupported_syn_construct(
                "Assignment operators",
                &op,
            ))
        }
        _ => return Err(Error::unsupported_syn_construct("Binary operator", &op)),
    };
    let arguments = vec![(ArgType::Normal, *left), (ArgType::Normal, *right)];
    self.fold_right_expr(create_expr_call(create_expr_path(trait_method), arguments))
}
/// Folds a short-circuiting `&&` (`is_and == true`) or `||` (`is_and == false`).
///
/// Steps:
/// 1. The left operand is forced into an identifier and copied into a fresh result
///    temporary.
/// 2. The right operand is evaluated and assigned to that temporary only inside the
///    branch where Rust would evaluate it: the `then` block for `&&`, the `else`
///    block for `||`, conditioned on the left operand.
///
/// All statements needed to compute the right operand are placed inside that
/// branch. They are therefore skipped when the left operand already decides the
/// result, as with Rust's `&&`/`||`.
///
/// # Errors
/// Returns an error if either operand cannot be folded.
fn fold_short_circuiting(
    &mut self,
    is_and: bool,
    left: Expr,
    right: Expr,
) -> Result<WIndexedExpr<WExprHighCall>, Error> {
    let left = self.force_ident(left)?;
    let tmp_result = self.force_assign_to_temp(create_expr_ident(left.to_syn_ident()))?;

    // Fold the right operand, then move the statements it generated out of the
    // enclosing list into the conditional branch. Previously they stayed in the
    // enclosing list and so ran unconditionally, breaking short-circuiting.
    let outer_len = self.stmts.len();
    let right = self.force_ident(right);
    let mut right_stmts = self.stmts.split_off(outer_len);
    let right = right?;
    right_stmts.push(WMacroableStmt::Assign(WStmtAssign {
        left: WIndexedIdent::NonIndexed(tmp_result.clone()),
        right: WIndexedExpr::NonIndexed(WExpr::Move(right)),
    }));

    let right_assign_block = WBlock { stmts: right_stmts };
    let empty_block = WBlock { stmts: Vec::new() };
    let (then_block, else_block) = if is_and {
        (right_assign_block, empty_block)
    } else {
        (empty_block, right_assign_block)
    };
    self.stmts.push(WMacroableStmt::If(WStmtIf {
        condition: WIfCondition {
            polarity: WNoIfPolarity,
            ident: left,
        },
        then_block,
        else_block,
    }));
    Ok(WIndexedExpr::NonIndexed(WExpr::Move(tmp_result)))
}
}