use std::{collections::VecDeque, fmt};
use num_bigint::BigInt;
use num_rational::BigRational;
use num_traits::{One, Signed, ToPrimitive, Zero};
use proc_macro2::{Span, TokenStream};
use syn::{
BinOp, Expr as SynExpr, Expr, ExprBinary, ExprCast, ExprGroup, ExprLit, ExprMethodCall,
ExprParen, ExprUnary, Ident, Lit, LitFloat, LitInt, Type, UnOp, parse_quote,
punctuated::Punctuated, token::Comma,
};
/// Annotations attached to every [`CoeffExpr`] node.
///
/// The default value carries no information (`grade == None`).
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ExprMetadata {
    /// Optional grade annotation; `None` when it has not been set.
    pub grade: Option<u8>,
}

/// Textual name of a symbolic coefficient. `CoeffExpr::to_syn` parses this text back into a
/// Rust expression when emitting code.
pub type SymbolRef = String;
/// Lowers the exact rational constant `r` to a Rust expression.
///
/// * An integer (denominator 1) that fits in `i128` becomes an integer literal.
/// * Every other value is converted to `f64` (`0.0` if `to_f64` returns `None`). It is then
///   emitted as a float literal using the shortest decimal text that parses back to exactly
///   that `f64`.
/// * A negative value is emitted as a unary minus applied to a non-negative literal.
///
/// When `scalar_ty` is given, the literal is cast with `as #ty`. Unary minus binds tighter
/// than `as`, so `-0.5 as f32` means `(-0.5) as f32`.
///
/// NOTE(review): a non-finite `f64` (e.g. from an overflowing rational) has no literal
/// spelling, and `LitFloat::new` will panic on it.
fn bigrat_to_syn(r: &BigRational, scalar_ty: Option<&Type>) -> SynExpr {
    // Wraps `expr` in a unary minus.
    fn negated(expr: SynExpr) -> SynExpr {
        SynExpr::Unary(ExprUnary {
            attrs: Vec::new(),
            op: UnOp::Neg(Default::default()),
            expr: Box::new(expr),
        })
    }

    // Integer literal for `value`, negated when `value < 0`.
    fn int_expr(value: i128) -> SynExpr {
        // `unsigned_abs` is total, whereas `-value` overflows for `i128::MIN`.
        let digits = value.unsigned_abs().to_string();
        let literal = SynExpr::Lit(ExprLit {
            attrs: Vec::new(),
            lit: Lit::Int(LitInt::new(&digits, Span::call_site())),
        });
        if value < 0 {
            negated(literal)
        } else {
            literal
        }
    }

    // Float literal for `value`, negated when `value < 0`.
    fn float_expr(value: f64) -> SynExpr {
        // `{:?}` prints the shortest text that round-trips to the same `f64`. It always
        // contains a `.` or an exponent, so it is a valid float literal. A fixed `{:.N}`
        // precision would truncate values like 1/3 and flush tiny magnitudes to zero.
        // Formatting `abs()` also keeps `-0.0` from producing a signed literal token.
        let text = format!("{:?}", value.abs());
        let literal = SynExpr::Lit(ExprLit {
            attrs: Vec::new(),
            lit: Lit::Float(LitFloat::new(&text, Span::call_site())),
        });
        if value < 0.0 {
            negated(literal)
        } else {
            literal
        }
    }

    let exact_int = if r.denom().is_one() {
        r.numer().to_i128()
    } else {
        None
    };
    let literal_expr = match exact_int {
        Some(value) => int_expr(value),
        None => float_expr(r.to_f64().unwrap_or(0.0)),
    };

    match scalar_ty {
        Some(ty) => parse_quote! { #literal_expr as #ty },
        None => literal_expr,
    }
}
/// The node variants of a symbolic coefficient expression.
///
/// Each variant's comment notes how `CoeffExpr::to_syn` lowers it to Rust code.
// NOTE(review): `dead_code` is allowed presumably because some variants are not yet
// constructed elsewhere in the crate — confirm.
#[allow(dead_code)]
#[derive(Clone, Debug, PartialEq)]
pub enum CoeffExprKind {
    /// Exact rational constant, emitted as an integer or float literal.
    Literal(BigRational),
    /// Named symbol whose text is parsed back into an expression.
    Symbol(SymbolRef),
    /// N-ary sum, emitted as chained `.add(..)` calls.
    Sum(Vec<CoeffExpr>),
    /// N-ary product, emitted as chained `.mul(..)` calls.
    Product(Vec<CoeffExpr>),
    /// N-ary wedge, emitted as chained `.wedge(..)` calls.
    Wedge(Vec<CoeffExpr>),
    /// Unary negation.
    Neg(Box<CoeffExpr>),
    /// `numerator / denominator`, emitted as `.div(..)`.
    Quotient(Box<CoeffExpr>, Box<CoeffExpr>),
    /// `base` raised to a non-negative integer power. Emitted as `.powi(..)` for `f32`/`f64`
    /// scalar types and as `.pow(..)` otherwise.
    Pow(Box<CoeffExpr>, u32),
    /// `sin(arg)`.
    Sin(Box<CoeffExpr>),
    /// `cos(arg)`.
    Cos(Box<CoeffExpr>),
    /// `tan(arg)`.
    Tan(Box<CoeffExpr>),
    /// `sinh(arg)`.
    Sinh(Box<CoeffExpr>),
    /// `cosh(arg)`.
    Cosh(Box<CoeffExpr>),
    /// `tanh(arg)`.
    Tanh(Box<CoeffExpr>),
    /// `asin(arg)`.
    Asin(Box<CoeffExpr>),
    /// `acos(arg)`.
    Acos(Box<CoeffExpr>),
    /// `atan(arg)`.
    Atan(Box<CoeffExpr>),
    /// `asinh(arg)`.
    Asinh(Box<CoeffExpr>),
    /// `acosh(arg)`.
    Acosh(Box<CoeffExpr>),
    /// `atanh(arg)`.
    Atanh(Box<CoeffExpr>),
    /// `exp(arg)`.
    Exp(Box<CoeffExpr>),
    /// Natural logarithm `ln(arg)`.
    Ln(Box<CoeffExpr>),
    /// Square root `sqrt(arg)`.
    Sqrt(Box<CoeffExpr>),
    /// `lhs * rhs + addend`. Emitted as `.mul_add(..)` for `f32`/`f64` scalar types and as
    /// `.mul(..).add(..)` otherwise.
    MulAdd(Box<CoeffExpr>, Box<CoeffExpr>, Box<CoeffExpr>),
    /// A `syn` expression that was not understood; it is emitted unchanged.
    Opaque(SynExpr),
}
/// A symbolic coefficient expression: a node kind plus per-node metadata.
///
/// Equality (`PartialEq`) compares both the kind and the metadata.
// NOTE(review): `dead_code` is allowed presumably because not every field accessor is used
// yet — confirm.
#[allow(dead_code)]
#[derive(Clone, Debug, PartialEq)]
pub struct CoeffExpr {
    // The node's variant and children.
    kind: CoeffExprKind,
    // Annotations carried alongside the node.
    metadata: ExprMetadata,
}
#[allow(dead_code)]
impl CoeffExpr {
pub fn new(kind: CoeffExprKind) -> Self {
Self {
kind,
metadata: ExprMetadata::default(),
}
}
pub fn with_metadata(mut self, metadata: ExprMetadata) -> Self {
self.metadata = metadata;
self
}
pub fn metadata(&self) -> &ExprMetadata {
&self.metadata
}
pub fn metadata_mut(&mut self) -> &mut ExprMetadata {
&mut self.metadata
}
pub fn kind(&self) -> &CoeffExprKind {
&self.kind
}
pub fn into_kind(self) -> CoeffExprKind {
self.kind
}
pub fn zero() -> Self {
Self::literal(BigRational::zero())
}
pub fn one() -> Self {
Self::literal(BigRational::one())
}
pub fn literal(lit: BigRational) -> Self {
Self::new(CoeffExprKind::Literal(lit))
}
pub fn symbol(sym: SymbolRef) -> Self {
Self::new(CoeffExprKind::Symbol(sym))
}
pub fn neg(expr: CoeffExpr) -> Self {
let CoeffExpr { kind, metadata } = expr;
match kind {
CoeffExprKind::Literal(lit) => {
let mut result = CoeffExpr::literal(-lit);
result.metadata = metadata;
result
}
other => CoeffExpr {
kind: CoeffExprKind::Neg(Box::new(CoeffExpr {
kind: other,
metadata: metadata.clone(),
})),
metadata,
},
}
}
pub fn quotient(numerator: CoeffExpr, denominator: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Quotient(
Box::new(numerator),
Box::new(denominator),
))
}
pub fn pow(base: CoeffExpr, exponent: u32) -> Self {
match exponent {
0 => CoeffExpr::one(),
1 => base,
exp => Self::new(CoeffExprKind::Pow(Box::new(base), exp)),
}
}
pub fn sin(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Sin(Box::new(arg)))
}
pub fn cos(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Cos(Box::new(arg)))
}
pub fn tan(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Tan(Box::new(arg)))
}
pub fn sinh(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Sinh(Box::new(arg)))
}
pub fn cosh(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Cosh(Box::new(arg)))
}
pub fn tanh(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Tanh(Box::new(arg)))
}
pub fn asin(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Asin(Box::new(arg)))
}
pub fn acos(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Acos(Box::new(arg)))
}
pub fn atan(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Atan(Box::new(arg)))
}
pub fn asinh(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Asinh(Box::new(arg)))
}
pub fn acosh(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Acosh(Box::new(arg)))
}
pub fn atanh(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Atanh(Box::new(arg)))
}
pub fn exp(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Exp(Box::new(arg)))
}
pub fn ln(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Ln(Box::new(arg)))
}
pub fn sqrt(arg: CoeffExpr) -> Self {
Self::new(CoeffExprKind::Sqrt(Box::new(arg)))
}
pub fn mul_add(lhs: CoeffExpr, rhs: CoeffExpr, addend: CoeffExpr) -> Self {
if lhs.is_zero() || rhs.is_zero() {
return addend;
}
if addend.is_zero() {
return CoeffExpr::product(vec![lhs, rhs]);
}
Self::new(CoeffExprKind::MulAdd(
Box::new(lhs),
Box::new(rhs),
Box::new(addend),
))
}
pub fn opaque(expr: SynExpr) -> Self {
Self::new(CoeffExprKind::Opaque(expr))
}
pub fn sum<T>(terms: T) -> Self
where
T: IntoIterator<Item = CoeffExpr>,
{
let mut pending: VecDeque<CoeffExpr> = terms.into_iter().collect();
if pending.is_empty() {
return CoeffExpr::zero();
}
let mut literal_acc: Option<BigRational> = None;
let mut result_terms: Vec<CoeffExpr> = Vec::with_capacity(pending.len());
while let Some(term) = pending.pop_front() {
let CoeffExpr { kind, metadata } = term;
match kind {
CoeffExprKind::Literal(lit) => {
literal_acc = Some(match literal_acc {
Some(acc) => acc + lit.clone(),
None => lit,
});
}
CoeffExprKind::Sum(nested) => {
pending.extend(nested);
}
other => result_terms.push(CoeffExpr {
kind: other,
metadata,
}),
}
}
if let Some(lit) = literal_acc {
if result_terms.is_empty() || !lit.is_zero() {
result_terms.push(CoeffExpr::literal(lit));
}
}
match result_terms.len() {
0 => CoeffExpr::zero(),
1 => result_terms.pop().unwrap(),
_ => CoeffExpr::new(CoeffExprKind::Sum(result_terms)),
}
}
pub fn product<T>(factors: T) -> Self
where
T: IntoIterator<Item = CoeffExpr>,
{
let mut pending: VecDeque<CoeffExpr> = factors.into_iter().collect();
if pending.is_empty() {
return CoeffExpr::one();
}
let mut literal_acc: Option<BigRational> = None;
let mut result_terms: Vec<CoeffExpr> = Vec::with_capacity(pending.len());
while let Some(term) = pending.pop_front() {
let CoeffExpr { kind, metadata } = term;
match kind {
CoeffExprKind::Literal(lit) => {
literal_acc = Some(match literal_acc {
Some(acc) => acc * lit.clone(),
None => lit,
});
if let Some(acc) = &literal_acc {
if acc.is_zero() {
return CoeffExpr::zero();
}
}
}
CoeffExprKind::Product(nested) => {
pending.extend(nested);
}
CoeffExprKind::Pow(base, exp) => {
let pow_expr =
CoeffExpr::new(CoeffExprKind::Pow(base, exp)).with_metadata(metadata);
result_terms.push(pow_expr);
}
other => result_terms.push(CoeffExpr {
kind: other,
metadata,
}),
}
}
if let Some(lit) = literal_acc {
if result_terms.is_empty() {
return CoeffExpr::literal(lit);
}
if !lit.is_one() {
result_terms.insert(0, CoeffExpr::literal(lit));
}
}
fn accumulate_pow(terms: &mut Vec<(CoeffExpr, u32)>, base: CoeffExpr, exp: u32) {
if exp == 0 {
return;
}
for (existing_base, existing_exp) in terms.iter_mut() {
if *existing_base == base {
*existing_exp += exp;
return;
}
}
terms.push((base, exp));
}
let mut combined: Vec<(CoeffExpr, u32)> = Vec::with_capacity(result_terms.len());
for term in result_terms.into_iter() {
let CoeffExpr { kind, metadata } = term;
match kind {
CoeffExprKind::Pow(base, exp) => accumulate_pow(&mut combined, *base, exp),
other => accumulate_pow(
&mut combined,
CoeffExpr {
kind: other,
metadata,
},
1,
),
}
}
let mut final_terms: Vec<CoeffExpr> = Vec::with_capacity(combined.len());
for (base, exp) in combined.into_iter() {
final_terms.push(CoeffExpr::pow(base, exp));
}
match final_terms.len() {
0 => CoeffExpr::one(),
1 => final_terms.pop().unwrap(),
_ => CoeffExpr::new(CoeffExprKind::Product(final_terms)),
}
}
pub fn wedge<T>(factors: T) -> Self
where
T: IntoIterator<Item = CoeffExpr>,
{
let mut collected: Vec<CoeffExpr> = factors.into_iter().collect();
if collected.len() == 1 {
return collected.remove(0);
}
Self::new(CoeffExprKind::Wedge(collected))
}
pub fn from_syn(expr: &SynExpr) -> Self {
match expr {
SynExpr::Lit(ExprLit { lit, .. }) => match lit {
Lit::Int(lit) => match lit.base10_parse::<i128>() {
Ok(value) => CoeffExpr::literal(BigRational::from_integer(value.into())),
Err(_) => CoeffExpr::opaque(expr.clone()),
},
Lit::Float(lit) => match lit.base10_parse::<f64>() {
Ok(value) => CoeffExpr::literal(BigRational::from_float(value).unwrap()),
Err(_) => CoeffExpr::opaque(expr.clone()),
},
_ => CoeffExpr::opaque(expr.clone()),
},
SynExpr::Unary(ExprUnary {
op, expr: inner, ..
}) if matches!(op, UnOp::Neg(_)) => {
let inner = CoeffExpr::from_syn(inner);
CoeffExpr::neg(inner)
}
SynExpr::Group(ExprGroup { expr: inner, .. })
| SynExpr::Paren(ExprParen { expr: inner, .. })
| SynExpr::Cast(ExprCast { expr: inner, .. }) => CoeffExpr::from_syn(inner),
SynExpr::Binary(ExprBinary {
left, op, right, ..
}) => {
let lhs = CoeffExpr::from_syn(left);
let rhs = CoeffExpr::from_syn(right);
match op {
BinOp::Add(_) => CoeffExpr::sum(vec![lhs, rhs]),
BinOp::Sub(_) => CoeffExpr::sum(vec![lhs, CoeffExpr::neg(rhs)]),
BinOp::Mul(_) => CoeffExpr::product(vec![lhs, rhs]),
BinOp::Div(_) => CoeffExpr::quotient(lhs, rhs),
_ => CoeffExpr::opaque(expr.clone()),
}
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "pow" && args.len() == 1 => {
let base = CoeffExpr::from_syn(receiver);
let arg = &args[0];
if let SynExpr::Lit(ExprLit {
lit: Lit::Int(lit), ..
}) = arg
{
if let Ok(exp) = lit.base10_parse::<u32>() {
return CoeffExpr::pow(base, exp);
}
}
CoeffExpr::opaque(expr.clone())
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "powi" && args.len() == 1 => {
let base = CoeffExpr::from_syn(receiver);
let arg = &args[0];
if let SynExpr::Lit(ExprLit {
lit: Lit::Int(lit), ..
}) = arg
{
if let Ok(exp) = lit.base10_parse::<i32>() {
if exp >= 0 {
return CoeffExpr::pow(base, exp as u32);
} else {
let positive = CoeffExpr::pow(base.clone(), (-exp) as u32);
return CoeffExpr::quotient(CoeffExpr::one(), positive);
}
}
}
CoeffExpr::opaque(expr.clone())
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "powf" && args.len() == 1 => {
let base = CoeffExpr::from_syn(receiver);
let arg = &args[0];
if let SynExpr::Lit(ExprLit {
lit: Lit::Float(lit),
..
}) = arg
{
if let Ok(value) = lit.base10_parse::<f64>() {
if let Some(rational) = BigRational::from_float(value) {
if rational.denom().is_one() {
if rational.numer().is_negative() {
return CoeffExpr::opaque(expr.clone());
}
if let Some(exp) = rational.numer().to_u32() {
return CoeffExpr::pow(base, exp);
}
}
let half = BigRational::new(BigInt::from(1), BigInt::from(2));
if rational == half {
return CoeffExpr::sqrt(base);
}
}
}
}
CoeffExpr::opaque(expr.clone())
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "mul_add" && args.len() == 2 => {
let lhs = CoeffExpr::from_syn(receiver);
let rhs = CoeffExpr::from_syn(&args[0]);
let addend = CoeffExpr::from_syn(&args[1]);
CoeffExpr::mul_add(lhs, rhs, addend)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "sin" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::sin(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "sinh" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::sinh(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "asin" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::asin(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "acos" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::acos(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "atan" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::atan(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "asinh" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::asinh(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "acosh" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::acosh(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "atanh" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::atanh(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "cos" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::cos(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "tan" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::tan(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "cosh" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::cosh(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "exp" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::exp(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "ln" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::ln(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "sqrt" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::sqrt(arg)
}
SynExpr::MethodCall(ExprMethodCall {
receiver,
method,
args,
..
}) if method == "tanh" && args.is_empty() => {
let arg = CoeffExpr::from_syn(receiver);
CoeffExpr::tanh(arg)
}
_ => CoeffExpr::opaque(expr.clone()),
}
}
pub fn is_zero(&self) -> bool {
match &self.kind {
CoeffExprKind::Literal(lit) => lit.is_zero(),
CoeffExprKind::Quotient(numerator, _) => numerator.is_zero(),
CoeffExprKind::Pow(base, exp) => {
if *exp == 0 {
false
} else {
base.is_zero()
}
}
_ => false,
}
}
pub fn is_one(&self) -> bool {
match &self.kind {
CoeffExprKind::Literal(lit) => lit.is_one(),
CoeffExprKind::Pow(base, _) => base.is_one(),
_ => false,
}
}
pub fn canonicalize(self) -> Self {
self
}
pub fn to_syn(expr: &CoeffExpr, scalar_ty: Option<&Type>) -> SynExpr {
match &expr.kind {
CoeffExprKind::Literal(lit) => bigrat_to_syn(lit, scalar_ty),
CoeffExprKind::Symbol(name) => syn::parse_str::<SynExpr>(name).unwrap_or_else(|_| {
SynExpr::Verbatim(name.parse::<TokenStream>().unwrap_or_default())
}),
CoeffExprKind::Sum(terms) => {
fn wrap_simple(expr: syn::Expr) -> syn::Expr {
match expr {
syn::Expr::Path(_) | syn::Expr::Lit(_) => expr,
other => parse_quote! { #other },
}
}
let mut iter = terms.iter();
if let Some(first) = iter.next() {
let mut acc = wrap_simple(CoeffExpr::to_syn(first, scalar_ty));
for term in iter {
let rhs = wrap_simple(CoeffExpr::to_syn(term, scalar_ty));
acc = method_call_expr_with_args(acc, "add", vec![rhs], scalar_ty);
}
acc
} else {
parse_quote! { 0 }
}
}
CoeffExprKind::Product(factors) => {
fn wrap_for_product(expr: syn::Expr) -> syn::Expr {
match expr {
syn::Expr::Path(_) | syn::Expr::Lit(_) => expr,
other => parse_quote! { #other },
}
}
let mut iter = factors.iter();
if let Some(first) = iter.next() {
let mut acc = wrap_for_product(CoeffExpr::to_syn(first, scalar_ty));
for factor in iter {
let rhs = wrap_for_product(CoeffExpr::to_syn(factor, scalar_ty));
acc = method_call_expr_with_args(acc, "mul", vec![rhs], scalar_ty);
}
acc
} else {
parse_quote! { 1 }
}
}
CoeffExprKind::Wedge(factors) => {
let mut iter = factors.iter();
if let Some(first) = iter.next() {
let mut acc = CoeffExpr::to_syn(first, scalar_ty);
for factor in iter {
let rhs = CoeffExpr::to_syn(factor, scalar_ty);
acc = method_call_expr_with_args(acc, "wedge", vec![rhs], scalar_ty);
}
acc
} else {
parse_quote! { 0 }
}
}
CoeffExprKind::Neg(inner) => {
let inner = wrap_unary_operand(CoeffExpr::to_syn(inner, scalar_ty));
SynExpr::Unary(ExprUnary {
attrs: Vec::new(),
op: UnOp::Neg(Default::default()),
expr: Box::new(inner),
})
}
CoeffExprKind::Quotient(numerator, denominator) => {
let num = CoeffExpr::to_syn(numerator, scalar_ty);
let denom = CoeffExpr::to_syn(denominator, scalar_ty);
method_call_expr_with_args(num, "div", vec![denom], scalar_ty)
}
CoeffExprKind::Sin(arg) => emit_odd_method(arg, scalar_ty, "sin"),
CoeffExprKind::Cos(arg) => emit_even_method(arg, scalar_ty, "cos"),
CoeffExprKind::Tan(arg) => emit_odd_method(arg, scalar_ty, "tan"),
CoeffExprKind::Sinh(arg) => emit_odd_method(arg, scalar_ty, "sinh"),
CoeffExprKind::Cosh(arg) => emit_even_method(arg, scalar_ty, "cosh"),
CoeffExprKind::Tanh(arg) => emit_odd_method(arg, scalar_ty, "tanh"),
CoeffExprKind::Asin(arg) => emit_odd_method(arg, scalar_ty, "asin"),
CoeffExprKind::Acos(arg) => emit_plain_method(arg, scalar_ty, "acos"),
CoeffExprKind::Atan(arg) => emit_odd_method(arg, scalar_ty, "atan"),
CoeffExprKind::Asinh(arg) => emit_odd_method(arg, scalar_ty, "asinh"),
CoeffExprKind::Acosh(arg) => emit_plain_method(arg, scalar_ty, "acosh"),
CoeffExprKind::Atanh(arg) => emit_odd_method(arg, scalar_ty, "atanh"),
CoeffExprKind::Exp(arg) => emit_plain_method(arg, scalar_ty, "exp"),
CoeffExprKind::Ln(arg) => emit_plain_method(arg, scalar_ty, "ln"),
CoeffExprKind::Sqrt(arg) => emit_plain_method(arg, scalar_ty, "sqrt"),
CoeffExprKind::Pow(base, exponent) => {
let base_expr = wrap_method_receiver(CoeffExpr::to_syn(base, scalar_ty));
let lit = LitInt::new(&exponent.to_string(), Span::call_site());
let exp_expr: SynExpr = parse_quote! { #lit };
if scalar_ty
.and_then(|ty| CoeffExpr::float_scalar_ident(ty))
.is_some()
{
parse_quote! { #base_expr.powi(#exp_expr) }
} else {
parse_quote! { #base_expr.pow(#exp_expr) }
}
}
CoeffExprKind::MulAdd(lhs, rhs, addend) => {
let lhs_expr = wrap_method_receiver(CoeffExpr::to_syn(lhs, scalar_ty));
let rhs_expr = CoeffExpr::to_syn(rhs, scalar_ty);
let add_expr = CoeffExpr::to_syn(addend, scalar_ty);
if scalar_ty
.and_then(|ty| CoeffExpr::float_scalar_ident(ty))
.is_some()
{
parse_quote! { #lhs_expr.mul_add(#rhs_expr, #add_expr) }
} else {
let mul =
method_call_expr_with_args(lhs_expr, "mul", vec![rhs_expr], scalar_ty);
method_call_expr_with_args(mul, "add", vec![add_expr], scalar_ty)
}
}
CoeffExprKind::Opaque(expr) => expr.clone(),
}
}
fn float_scalar_ident(ty: &Type) -> Option<&syn::Ident> {
if let Type::Path(type_path) = ty {
if type_path.qself.is_none() {
let mut segments = type_path.path.segments.iter();
if let Some(segment) = segments.next() {
if segments.next().is_none()
&& (segment.ident == "f32" || segment.ident == "f64")
{
return Some(&segment.ident);
}
}
}
}
None
}
}
/// Lowers `arg` and applies the zero-argument `method` to it: `(<arg>).method()`.
fn emit_plain_method(arg: &CoeffExpr, scalar_ty: Option<&Type>, method: &str) -> SynExpr {
    method_call_expr(CoeffExpr::to_syn(arg, scalar_ty), method)
}
/// Like `emit_plain_method`, but for an argument of the form `Neg(x)` it emits
/// `-((x).method())` instead of `(-x).method()`.
///
/// Callers use this for odd functions, where `f(-x) == -f(x)`.
fn emit_odd_method(arg: &CoeffExpr, scalar_ty: Option<&Type>, method: &str) -> SynExpr {
    match arg.kind() {
        CoeffExprKind::Neg(operand) => {
            let positive_call = emit_plain_method(operand, scalar_ty, method);
            negate_expr(positive_call)
        }
        _ => emit_plain_method(arg, scalar_ty, method),
    }
}
/// Like `emit_plain_method`, but for an argument of the form `Neg(x)` it drops the
/// negation and emits `(x).method()`.
///
/// Callers use this for even functions, where `f(-x) == f(x)`.
fn emit_even_method(arg: &CoeffExpr, scalar_ty: Option<&Type>, method: &str) -> SynExpr {
    let effective_arg = match arg.kind() {
        CoeffExprKind::Neg(operand) => operand.as_ref(),
        _ => arg,
    };
    emit_plain_method(effective_arg, scalar_ty, method)
}
/// Wraps `expr` in a unary minus without adding parentheses.
fn negate_expr(expr: SynExpr) -> SynExpr {
    let operand = Box::new(expr);
    let minus = UnOp::Neg(Default::default());
    SynExpr::Unary(ExprUnary {
        op: minus,
        expr: operand,
        attrs: vec![],
    })
}
/// Builds the zero-argument call `(<receiver>).method()`.
///
/// The receiver is parenthesised unless it already is.
fn method_call_expr(receiver: SynExpr, method: &str) -> SynExpr {
    let name = Ident::new(method, Span::call_site());
    let boxed_receiver = Box::new(wrap_method_receiver(receiver));
    let call = ExprMethodCall {
        attrs: vec![],
        receiver: boxed_receiver,
        dot_token: Default::default(),
        method: name,
        turbofish: None,
        paren_token: Default::default(),
        args: Punctuated::new(),
    };
    Expr::MethodCall(call)
}
fn method_call_expr_with_args(
receiver: SynExpr,
method: &str,
args: Vec<SynExpr>,
_scalar_ty: Option<&Type>,
) -> SynExpr {
let method_ident = Ident::new(method, Span::call_site());
let mut punct = Punctuated::<SynExpr, Comma>::new();
for arg in args.into_iter() {
punct.push(arg);
}
let wrapped_receiver = wrap_method_receiver(receiver);
Expr::MethodCall(ExprMethodCall {
attrs: Vec::new(),
receiver: Box::new(wrapped_receiver),
dot_token: Default::default(),
method: method_ident,
turbofish: None,
paren_token: Default::default(),
args: punct,
})
}
/// Parenthesises `expr` for use as a method receiver, unless it is already an
/// `Expr::Paren`.
fn wrap_method_receiver(expr: SynExpr) -> SynExpr {
    if matches!(expr, SynExpr::Paren(_)) {
        return expr;
    }
    SynExpr::Paren(ExprParen {
        attrs: vec![],
        paren_token: Default::default(),
        expr: Box::new(expr),
    })
}
/// Parenthesises a method call or binary expression that will become the operand of a
/// unary minus. Every other expression (including an existing `Paren`) is returned
/// unchanged.
fn wrap_unary_operand(expr: SynExpr) -> SynExpr {
    let needs_parens = matches!(expr, SynExpr::MethodCall(_) | SynExpr::Binary(_));
    if !needs_parens {
        return expr;
    }
    SynExpr::Paren(ExprParen {
        attrs: vec![],
        paren_token: Default::default(),
        expr: Box::new(expr),
    })
}
impl fmt::Display for CoeffExpr {
    /// Renders the expression in a compact, method-chain notation.
    ///
    /// Examples: `a.add(b).add(c)`, `x.neg()`, `sin(x)`, `x ^ 3`. Opaque nodes print as
    /// `<opaque>`. Each chained call closes its own parenthesis, so the output is balanced
    /// for any number of terms.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Writes `first.method(second).method(third)…`, or `empty` when there are no items.
        fn write_chain(
            f: &mut fmt::Formatter<'_>,
            items: &[CoeffExpr],
            method: &str,
            empty: &str,
        ) -> fmt::Result {
            let mut iter = items.iter();
            match iter.next() {
                None => f.write_str(empty),
                Some(first) => {
                    write!(f, "{first}")?;
                    iter.try_for_each(|item| write!(f, ".{method}({item})"))
                }
            }
        }

        match &self.kind {
            CoeffExprKind::Literal(r) => write!(f, "{r}"),
            CoeffExprKind::Symbol(name) => write!(f, "{name}"),
            CoeffExprKind::Sum(terms) => write_chain(f, terms, "add", "0"),
            CoeffExprKind::Product(factors) => write_chain(f, factors, "mul", "1"),
            CoeffExprKind::Wedge(factors) => write_chain(f, factors, "wedge", "0"),
            CoeffExprKind::Neg(inner) => write!(f, "{inner}.neg()"),
            CoeffExprKind::Quotient(num, denom) => write!(f, "{num}.div({denom})"),
            CoeffExprKind::Sin(arg) => write!(f, "sin({arg})"),
            CoeffExprKind::Cos(arg) => write!(f, "cos({arg})"),
            CoeffExprKind::Tan(arg) => write!(f, "tan({arg})"),
            CoeffExprKind::Exp(arg) => write!(f, "exp({arg})"),
            CoeffExprKind::Ln(arg) => write!(f, "ln({arg})"),
            CoeffExprKind::Sqrt(arg) => write!(f, "sqrt({arg})"),
            CoeffExprKind::Sinh(arg) => write!(f, "sinh({arg})"),
            CoeffExprKind::Cosh(arg) => write!(f, "cosh({arg})"),
            CoeffExprKind::Tanh(arg) => write!(f, "tanh({arg})"),
            CoeffExprKind::Asin(arg) => write!(f, "asin({arg})"),
            CoeffExprKind::Acos(arg) => write!(f, "acos({arg})"),
            CoeffExprKind::Atan(arg) => write!(f, "atan({arg})"),
            CoeffExprKind::Asinh(arg) => write!(f, "asinh({arg})"),
            CoeffExprKind::Acosh(arg) => write!(f, "acosh({arg})"),
            CoeffExprKind::Atanh(arg) => write!(f, "atanh({arg})"),
            CoeffExprKind::Pow(base, exp) => write!(f, "{base} ^ {exp}"),
            CoeffExprKind::MulAdd(lhs, rhs, addend) => {
                write!(f, "{lhs}.mul({rhs}).add({addend})")
            }
            CoeffExprKind::Opaque(_) => f.write_str("<opaque>"),
        }
    }
}