use super::CoefficientField;
use super::coeff_expr::{CoeffExpr, CoeffExprKind};
use abstalg::{
AbelianGroup, CommuntativeMonoid, Domain, Field, Monoid, SemiRing, Semigroup, UnitaryRing,
};
use cas_compute::primitive::int;
use cas_compute::symbolic::{expr::Primary, expr::SymExpr, simplify};
use num_bigint::BigInt;
use num_rational::BigRational;
use num_traits::Zero;
use proc_macro2::Span;
use quote::{ToTokens, quote};
use rug::{Float, Integer};
use std::{
fmt,
panic::{self, AssertUnwindSafe},
};
use syn::{
Expr as SynExpr, ExprCall, ExprLit, ExprMethodCall, ExprParen, ExprUnary, Ident, Lit, LitFloat,
LitInt, Type, UnOp, parse_quote, punctuated::Punctuated, token::Comma,
};
/// An element of [`CasRsField`]: a symbolic expression together with the
/// scalar type used when the expression is rendered back to Rust syntax.
#[derive(Clone, Debug)]
pub struct CasRsExpr {
// Symbolic form of the coefficient.
expr: SymExpr,
// Scalar type handed to `sym_expr_to_syn` when rendering this element.
scalar_ty: Type,
}
/// Coefficient field whose elements are [`CasRsExpr`] values; every element it
/// creates carries a clone of `scalar_ty`.
#[derive(Clone, Debug)]
pub struct CasRsField {
// Scalar type copied into each element produced by this field.
scalar_ty: Type,
}
impl CasRsField {
pub fn new(scalar_ty: Type) -> Self {
Self { scalar_ty }
}
pub fn wrap_expr(&self, expr: SynExpr) -> CasRsExpr {
let coeff = CoeffExpr::from_syn(&expr);
let sym = coeff_expr_to_sym(&coeff);
self.wrap_sym(sym)
}
pub fn wrap_numeric_expr(&self, expr: SynExpr) -> CasRsExpr {
self.wrap_expr(expr)
}
fn wrap_sym(&self, expr: SymExpr) -> CasRsExpr {
let simplified = panic::catch_unwind(AssertUnwindSafe(|| simplify(&expr)))
.unwrap_or_else(|_| expr.clone());
CasRsExpr {
expr: simplified,
scalar_ty: self.scalar_ty.clone(),
}
}
fn zero_literal(&self) -> CasRsExpr {
self.wrap_sym(SymExpr::Primary(Primary::Integer(int(0))))
}
fn one_literal(&self) -> CasRsExpr {
self.wrap_sym(SymExpr::Primary(Primary::Integer(int(1))))
}
fn make_literal(&self, value: i64) -> CasRsExpr {
self.wrap_sym(SymExpr::Primary(Primary::Integer(int(value))))
}
fn add_inner(&self, lhs: &CasRsExpr, rhs: &CasRsExpr) -> CasRsExpr {
self.wrap_sym(lhs.expr.clone() + rhs.expr.clone())
}
fn mul_inner(&self, lhs: &CasRsExpr, rhs: &CasRsExpr) -> CasRsExpr {
self.wrap_sym(lhs.expr.clone() * rhs.expr.clone())
}
fn neg_inner(&self, elem: &CasRsExpr) -> CasRsExpr {
self.wrap_sym(-elem.expr.clone())
}
fn repeat_mul(&self, factor: i128, elem: &CasRsExpr) -> CasRsExpr {
if factor == 0 {
return self.zero_literal();
}
let coeff = SymExpr::Primary(Primary::Integer(int(factor as i64)));
self.wrap_sym(coeff * elem.expr.clone())
}
fn is_zero_expr(expr: &SymExpr) -> bool {
matches!(expr, SymExpr::Primary(Primary::Integer(value)) if value == &int(0))
}
}
impl CasRsExpr {
    /// Renders the symbolic expression as a Rust expression for this
    /// element's scalar type.
    fn to_syn(&self) -> syn::Result<SynExpr> {
        let Self { expr, scalar_ty } = self;
        sym_expr_to_syn(expr, scalar_ty)
    }

    /// Reports whether the stored expression is the literal zero, as judged by
    /// `CasRsField::is_zero_expr`.
    fn is_zero(&self) -> bool {
        let inner = &self.expr;
        CasRsField::is_zero_expr(inner)
    }
}
/// Translates a parsed coefficient expression into the symbolic
/// representation used by the simplifier.
///
/// Empty sums become the integer `0` and empty products the integer `1`.
/// Quotients are expressed as multiplication by a `-1` power, and opaque
/// sub-expressions are carried as symbols named by their token string.
fn coeff_expr_to_sym(expr: &CoeffExpr) -> SymExpr {
    match expr.kind() {
        CoeffExprKind::Literal(lit) => rational_to_sym(lit),
        CoeffExprKind::Symbol(sym) => SymExpr::Primary(Primary::Symbol(sym.clone())),
        CoeffExprKind::Sum(terms) => terms
            .iter()
            .map(coeff_expr_to_sym)
            .reduce(|total, term| total + term)
            .unwrap_or_else(|| SymExpr::Primary(Primary::Integer(int(0)))),
        CoeffExprKind::Product(factors) => factors
            .iter()
            .map(coeff_expr_to_sym)
            .reduce(|product, factor| product * factor)
            .unwrap_or_else(|| SymExpr::Primary(Primary::Integer(int(1)))),
        CoeffExprKind::Neg(inner) => -coeff_expr_to_sym(inner),
        CoeffExprKind::Quotient(numerator, denominator) => {
            // a / b is modelled as a * b^(-1).
            let dividend = coeff_expr_to_sym(numerator);
            let divisor = coeff_expr_to_sym(denominator);
            let minus_one = SymExpr::Primary(Primary::Integer(int(-1)));
            let reciprocal = SymExpr::Exp(Box::new(divisor), Box::new(minus_one));
            dividend * reciprocal
        }
        CoeffExprKind::Pow(base, exponent) => {
            let power = SymExpr::Primary(Primary::Integer(int(*exponent as i64)));
            SymExpr::Exp(Box::new(coeff_expr_to_sym(base)), Box::new(power))
        }
        // Unary functions become named calls with a single argument.
        CoeffExprKind::Sin(arg) => call_sym("sin", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Cos(arg) => call_sym("cos", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Tan(arg) => call_sym("tan", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Sinh(arg) => call_sym("sinh", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Cosh(arg) => call_sym("cosh", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Tanh(arg) => call_sym("tanh", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Asin(arg) => call_sym("asin", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Acos(arg) => call_sym("acos", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Atan(arg) => call_sym("atan", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Asinh(arg) => call_sym("asinh", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Acosh(arg) => call_sym("acosh", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Atanh(arg) => call_sym("atanh", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Exp(arg) => call_sym("exp", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Ln(arg) => call_sym("ln", vec![coeff_expr_to_sym(arg)]),
        CoeffExprKind::Sqrt(arg) => coeff_expr_to_sym(arg).sqrt(),
        CoeffExprKind::MulAdd(lhs, rhs, addend) => {
            coeff_expr_to_sym(lhs) * coeff_expr_to_sym(rhs) + coeff_expr_to_sym(addend)
        }
        CoeffExprKind::Wedge(args) => {
            call_sym("wedge", args.iter().map(coeff_expr_to_sym).collect())
        }
        CoeffExprKind::Opaque(expr) => {
            let rendered = expr.to_token_stream().to_string();
            SymExpr::Primary(Primary::Symbol(rendered))
        }
    }
}
/// Builds a symbolic call node named `name` with the given arguments.
fn call_sym(name: &str, args: Vec<SymExpr>) -> SymExpr {
    let call = Primary::Call(name.to_owned(), args);
    SymExpr::Primary(call)
}
fn rational_to_sym(r: &BigRational) -> SymExpr {
if r.is_zero() {
return SymExpr::Primary(Primary::Integer(int(0)));
}
let numer = bigint_to_integer(r.numer());
let denom = bigint_to_integer(r.denom());
let numer_expr = SymExpr::Primary(Primary::Integer(numer));
if denom == int(1) {
numer_expr
} else {
numer_expr
* SymExpr::Exp(
Box::new(SymExpr::Primary(Primary::Integer(denom))),
Box::new(SymExpr::Primary(Primary::Integer(int(-1)))),
)
}
}
/// Converts a `num_bigint::BigInt` into a `rug::Integer` by round-tripping
/// through its decimal string form.
fn bigint_to_integer(value: &BigInt) -> Integer {
    let decimal = value.to_string();
    let parsed = Integer::from_str_radix(decimal.as_str(), 10);
    parsed.expect("invalid bigint conversion")
}
/// Renders a symbolic expression as a Rust expression, dispatching on the
/// node kind to the matching renderer.
fn sym_expr_to_syn(expr: &SymExpr, scalar_ty: &Type) -> syn::Result<SynExpr> {
    match expr {
        SymExpr::Primary(leaf) => primary_to_syn(leaf, scalar_ty),
        SymExpr::Add(summands) => add_terms_to_syn(summands, scalar_ty),
        SymExpr::Mul(multiplicands) => mul_terms_to_syn(multiplicands, scalar_ty),
        SymExpr::Exp(base, power) => exp_to_syn(base, power, scalar_ty),
    }
}
/// Renders a leaf node: numbers become literals, symbols are parsed back as
/// Rust expressions, and calls are delegated to `call_primary_to_syn`.
///
/// Returns an error when a symbol or call cannot be rendered.
fn primary_to_syn(primary: &Primary, scalar_ty: &Type) -> syn::Result<SynExpr> {
    match primary {
        Primary::Integer(value) => Ok(integer_to_syn(value, scalar_ty)),
        Primary::Float(value) => Ok(float_to_syn(value, scalar_ty)),
        Primary::Symbol(sym) => syn::parse_str::<SynExpr>(sym),
        Primary::Call(name, args) => call_primary_to_syn(name, args, scalar_ty),
    }
}
fn call_primary_to_syn(name: &str, args: &[SymExpr], scalar_ty: &Type) -> syn::Result<SynExpr> {
if let Some(method_name) = method_call_name(name) {
if let Some((first, rest)) = args.split_first() {
let receiver = wrap_for_method(sym_expr_to_syn(first, scalar_ty)?);
let method_ident = Ident::new(method_name, Span::call_site());
let mut call_args = Punctuated::<SynExpr, Comma>::new();
for arg in rest {
call_args.push(sym_expr_to_syn(arg, scalar_ty)?);
}
return Ok(SynExpr::MethodCall(ExprMethodCall {
attrs: vec![],
receiver: Box::new(receiver),
dot_token: Default::default(),
method: method_ident,
turbofish: None,
paren_token: Default::default(),
args: call_args,
}));
}
}
let func: SynExpr = syn::parse_str(name)?;
let mut call_args = Punctuated::<SynExpr, Comma>::new();
for arg in args {
call_args.push(sym_expr_to_syn(arg, scalar_ty)?);
}
Ok(SynExpr::Call(ExprCall {
attrs: vec![],
func: Box::new(func),
paren_token: Default::default(),
args: call_args,
}))
}
/// Returns `Some(name)` when `name` is one of the math functions that are
/// rendered as method calls, and `None` otherwise.
fn method_call_name(name: &str) -> Option<&str> {
    const METHOD_NAMES: [&str; 15] = [
        "sin", "cos", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan", "asinh", "acosh",
        "atanh", "exp", "ln", "sqrt",
    ];
    if METHOD_NAMES.contains(&name) {
        Some(name)
    } else {
        None
    }
}
/// Renders a sum as a left-nested chain of `.add(..)` method calls.
///
/// An empty sum renders as the zero literal.
fn add_terms_to_syn(terms: &[SymExpr], scalar_ty: &Type) -> syn::Result<SynExpr> {
    match terms.split_first() {
        None => Ok(zero_literal_syn(scalar_ty)),
        Some((head, tail)) => {
            let seed = sym_expr_to_syn(head, scalar_ty)?;
            tail.iter()
                .try_fold(seed, |sum, term| -> syn::Result<SynExpr> {
                    let addend = sym_expr_to_syn(term, scalar_ty)?;
                    Ok(method_call_expr(sum, "add", vec![addend]))
                })
        }
    }
}
/// Renders a product as a left-nested chain of `.mul(..)` method calls.
///
/// An empty product renders as the one literal.
fn mul_terms_to_syn(factors: &[SymExpr], scalar_ty: &Type) -> syn::Result<SynExpr> {
    match factors.split_first() {
        None => Ok(one_literal_syn(scalar_ty)),
        Some((head, tail)) => {
            let seed = sym_expr_to_syn(head, scalar_ty)?;
            tail.iter()
                .try_fold(seed, |product, factor| -> syn::Result<SynExpr> {
                    let multiplier = sym_expr_to_syn(factor, scalar_ty)?;
                    Ok(method_call_expr(product, "mul", vec![multiplier]))
                })
        }
    }
}
/// Renders `base ^ exponent`.
///
/// Integer exponents that fit in `i32` use `powi` for float scalars and, when
/// non-negative, `pow` for other scalars. Every remaining case falls back to
/// `powf` (floats, with the exponent cast to the scalar type) or `pow` with
/// the rendered exponent expression.
fn exp_to_syn(base: &SymExpr, exponent: &SymExpr, scalar_ty: &Type) -> syn::Result<SynExpr> {
    let receiver = wrap_for_method(sym_expr_to_syn(base, scalar_ty)?);
    let kind = scalar_kind(scalar_ty);

    match (kind, exponent_as_i32(exponent)) {
        (ScalarKind::Float(_), Some(power)) => {
            let lit = LitInt::new(&power.to_string(), Span::call_site());
            return Ok(parse_quote! { #receiver.powi(#lit) });
        }
        (ScalarKind::Other, Some(power)) if power >= 0 => {
            let lit = LitInt::new(&(power as u32).to_string(), Span::call_site());
            return Ok(parse_quote! { #receiver.pow(#lit) });
        }
        // Non-integer exponents and negative powers on non-float scalars use
        // the general form below.
        _ => {}
    }

    let power_expr = sym_expr_to_syn(exponent, scalar_ty)?;
    match kind {
        ScalarKind::Float(ident) => Ok(parse_quote! { #receiver.powf(#power_expr as #ident) }),
        ScalarKind::Other => Ok(parse_quote! { #receiver.pow(#power_expr) }),
    }
}
/// Ensures `expr` is parenthesised so it can be used as a method receiver.
/// Expressions that are already parenthesised are returned unchanged.
fn wrap_for_method(expr: SynExpr) -> SynExpr {
    if matches!(expr, SynExpr::Paren(_)) {
        return expr;
    }
    SynExpr::Paren(ExprParen {
        attrs: Vec::new(),
        paren_token: Default::default(),
        expr: Box::new(expr),
    })
}
fn method_call_expr(receiver: SynExpr, method: &str, args: Vec<SynExpr>) -> SynExpr {
let receiver = wrap_for_method(receiver);
let method_ident = Ident::new(method, Span::call_site());
let mut punct = Punctuated::<SynExpr, Comma>::new();
for arg in args {
punct.push(arg);
}
SynExpr::MethodCall(ExprMethodCall {
attrs: vec![],
receiver: Box::new(receiver),
dot_token: Default::default(),
method: method_ident,
turbofish: None,
paren_token: Default::default(),
args: punct,
})
}
fn integer_to_syn(value: &Integer, scalar_ty: &Type) -> SynExpr {
if *value == 0 {
cast_literal_if_needed(literal_from_int("0"), scalar_ty)
} else if *value < 0 {
let magnitude = (-value.clone()).to_string();
let negated = SynExpr::Unary(ExprUnary {
attrs: vec![],
op: UnOp::Neg(Default::default()),
expr: Box::new(literal_from_int(&magnitude)),
});
cast_literal_if_needed(negated, scalar_ty)
} else {
cast_literal_if_needed(literal_from_int(&value.to_string()), scalar_ty)
}
}
/// Appends `as f32` / `as f64` to `expr` when the scalar type is one of those
/// floats; other scalar types get the expression back unchanged.
fn cast_literal_if_needed(expr: SynExpr, scalar_ty: &Type) -> SynExpr {
    if let ScalarKind::Float(float_ident) = scalar_kind(scalar_ty) {
        return parse_quote! { #expr as #float_ident };
    }
    expr
}
/// Wraps a decimal digit string in an unsuffixed integer literal expression.
fn literal_from_int(digits: &str) -> SynExpr {
    let lit = Lit::Int(LitInt::new(digits, Span::call_site()));
    SynExpr::Lit(ExprLit {
        attrs: Vec::new(),
        lit,
    })
}
/// Renders an arbitrary-precision float as a Rust expression.
///
/// The value is first narrowed to `f64`. Finite values are written with the
/// shortest representation that round-trips through `f64`; a fixed number of
/// decimals would collapse small magnitudes to zero and discard significant
/// digits. Non-finite values cannot be written as float literals, so they are
/// emitted as the matching `f64` associated constant.
///
/// For `f32` / `f64` scalar types the result is cast to that type.
fn float_to_syn(value: &Float, scalar_ty: &Type) -> SynExpr {
    let approx = value.to_f64();
    let expr: SynExpr = if approx.is_nan() {
        parse_quote! { f64::NAN }
    } else if approx.is_infinite() {
        if approx.is_sign_positive() {
            parse_quote! { f64::INFINITY }
        } else {
            parse_quote! { f64::NEG_INFINITY }
        }
    } else {
        // `{:?}` on `f64` always yields a valid float literal such as `2.0`
        // or `1e-15`.
        let literal = LitFloat::new(&format!("{approx:?}"), Span::call_site());
        SynExpr::Lit(ExprLit {
            attrs: vec![],
            lit: Lit::Float(literal),
        })
    };
    match scalar_kind(scalar_ty) {
        ScalarKind::Float(ident) => parse_quote! { #expr as #ident },
        ScalarKind::Other => expr,
    }
}
/// Returns the exponent as an `i32` when `expr` is an integer literal that
/// fits; returns `None` for any other expression.
fn exponent_as_i32(expr: &SymExpr) -> Option<i32> {
    match expr {
        SymExpr::Primary(Primary::Integer(value)) => value.to_i32(),
        _ => None,
    }
}
fn zero_literal_syn(_ty: &Type) -> SynExpr {
literal_from_int("0")
}
fn one_literal_syn(_ty: &Type) -> SynExpr {
literal_from_int("1")
}
/// Classification of a scalar type, as produced by `scalar_kind`.
#[derive(Clone, Copy)]
enum ScalarKind<'a> {
// The type is the bare path `f32` or `f64`; carries that identifier.
Float(&'a Ident),
// Any other type.
Other,
}
/// Classifies `ty`: a single-segment, unqualified path named `f32` or `f64`
/// is a float; everything else is `ScalarKind::Other`.
fn scalar_kind(ty: &Type) -> ScalarKind<'_> {
    match ty {
        Type::Path(type_path) if type_path.qself.is_none() => {
            let segments = &type_path.path.segments;
            match segments.first() {
                Some(only)
                    if segments.len() == 1 && (only.ident == "f32" || only.ident == "f64") =>
                {
                    ScalarKind::Float(&only.ident)
                }
                _ => ScalarKind::Other,
            }
        }
        _ => ScalarKind::Other,
    }
}
impl fmt::Display for CasRsExpr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.to_syn() {
Ok(expr) => write!(f, "{}", quote! { #expr }),
Err(_) => write!(f, "<invalid expr>"),
}
}
}
/// Emits the rendered Rust expression for this element.
///
/// If rendering fails, the error is emitted as a `compile_error!` invocation
/// instead of silently producing no tokens. An empty expansion would leave
/// the surrounding generated code malformed and hide the real cause.
impl ToTokens for CasRsExpr {
    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
        match self.to_syn() {
            Ok(expr) => expr.to_tokens(tokens),
            Err(err) => tokens.extend(err.to_compile_error()),
        }
    }
}
impl Domain for CasRsField {
    type Elem = CasRsExpr;

    /// Two elements are equal when their simplified difference is the literal
    /// zero, or otherwise when both render to identical token strings.
    ///
    /// A panic inside `simplify` is caught and the unsimplified difference is
    /// inspected instead. If either element fails to render, the elements are
    /// reported as unequal.
    fn equals(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> bool {
        let difference = elem1.expr.clone() + (-elem2.expr.clone());
        let reduced = match panic::catch_unwind(AssertUnwindSafe(|| simplify(&difference))) {
            Ok(simplified) => simplified,
            Err(_) => difference,
        };
        if CasRsField::is_zero_expr(&reduced) {
            return true;
        }

        // Structural fallback: compare the rendered token streams as text.
        match (elem1.to_syn(), elem2.to_syn()) {
            (Ok(lhs), Ok(rhs)) => {
                lhs.to_token_stream().to_string() == rhs.to_token_stream().to_string()
            }
            _ => false,
        }
    }

    /// Every `CasRsExpr` is accepted as a member of the field.
    fn contains(&self, _elem: &Self::Elem) -> bool {
        true
    }
}
impl CommuntativeMonoid for CasRsField {
fn zero(&self) -> Self::Elem {
self.zero_literal()
}
fn add(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> Self::Elem {
self.add_inner(elem1, elem2)
}
fn is_zero(&self, elem: &Self::Elem) -> bool {
elem.is_zero()
}
fn add_assign(&self, elem1: &mut Self::Elem, elem2: &Self::Elem) {
*elem1 = self.add(elem1, elem2);
}
fn double(&self, elem: &mut Self::Elem) {
*elem = self.add(elem, elem);
}
fn times(&self, num: usize, elem: &Self::Elem) -> Self::Elem {
self.repeat_mul(num as i128, elem)
}
}
impl AbelianGroup for CasRsField {
    /// Simplified symbolic negation of `elem`.
    fn neg(&self, elem: &Self::Elem) -> Self::Elem {
        self.neg_inner(elem)
    }

    /// Replaces `elem` with its negation.
    fn neg_assign(&self, elem: &mut Self::Elem) {
        let negated = self.neg_inner(elem);
        *elem = negated;
    }

    /// Computes `elem1 - elem2` as `elem1 + (-elem2)`.
    fn sub(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> Self::Elem {
        let negated_rhs = self.neg_inner(elem2);
        self.add_inner(elem1, &negated_rhs)
    }

    /// Replaces `elem1` with `elem1 - elem2`.
    fn sub_assign(&self, elem1: &mut Self::Elem, elem2: &Self::Elem) {
        let difference = self.sub(elem1, elem2);
        *elem1 = difference;
    }

    /// Returns `num * elem` for a signed multiplier.
    fn times(&self, num: isize, elem: &Self::Elem) -> Self::Elem {
        let factor = num as i128;
        self.repeat_mul(factor, elem)
    }
}
impl Semigroup for CasRsField {
    /// Simplified symbolic product of the two elements.
    fn mul(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> Self::Elem {
        self.mul_inner(elem1, elem2)
    }

    /// Replaces `elem1` with `elem1 * elem2`.
    fn mul_assign(&self, elem1: &mut Self::Elem, elem2: &Self::Elem) {
        let product = self.mul_inner(elem1, elem2);
        *elem1 = product;
    }

    /// Replaces `elem` with `elem * elem`.
    fn square(&self, elem: &mut Self::Elem) {
        let squared = self.mul_inner(elem, elem);
        *elem = squared;
    }
}
impl Monoid for CasRsField {
    /// Multiplicative identity: the integer literal one.
    fn one(&self) -> Self::Elem {
        self.one_literal()
    }

    /// Returns `elem^(-1)`, or `None` when `elem` is the literal zero.
    fn try_inv(&self, elem: &Self::Elem) -> Option<Self::Elem> {
        if elem.is_zero() {
            return None;
        }
        let minus_one = SymExpr::Primary(Primary::Integer(int(-1)));
        let reciprocal = SymExpr::Exp(Box::new(elem.expr.clone()), Box::new(minus_one));
        Some(self.wrap_sym(reciprocal))
    }
}
// `SemiRing` is implemented with an empty body; no items are overridden here.
impl SemiRing for CasRsField {}
impl UnitaryRing for CasRsField {
    /// Embeds a machine integer as an integer-literal element.
    fn int(&self, elem: isize) -> Self::Elem {
        let value = elem as i64;
        self.make_literal(value)
    }
}
impl Field for CasRsField {
    /// Multiplicative inverse of `elem`.
    ///
    /// # Panics
    ///
    /// Panics when `elem` is the literal zero.
    fn inv(&self, elem: &Self::Elem) -> Self::Elem {
        let inverse = self.try_inv(elem);
        inverse.expect("attempted to invert zero element in CasRsField")
    }

    /// Computes `elem1 / elem2` as `elem1 * elem2^(-1)`.
    ///
    /// # Panics
    ///
    /// Panics when `elem2` is the literal zero.
    fn div(&self, elem1: &Self::Elem, elem2: &Self::Elem) -> Self::Elem {
        assert!(
            !elem2.is_zero(),
            "attempted to divide by zero element in CasRsField"
        );
        let minus_one = SymExpr::Primary(Primary::Integer(int(-1)));
        let reciprocal = SymExpr::Exp(Box::new(elem2.expr.clone()), Box::new(minus_one));
        self.wrap_sym(elem1.expr.clone() * reciprocal)
    }
}
impl CoefficientField for CasRsField {
fn embed_expr(&self, expr: SynExpr) -> syn::Result<<Self as Domain>::Elem> {
Ok(self.wrap_expr(expr))
}
fn to_expr(&self, elem: &<Self as Domain>::Elem) -> syn::Result<SynExpr> {
elem.to_syn()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;
    use syn::parse_quote;

    /// `2.0` and `1.0 + 1.0` must compare equal in an `f64` field.
    #[test]
    fn numeric_equality_holds() {
        let f64_field = CasRsField::new(parse_quote!(f64));
        let two = f64_field.embed_expr(parse_quote!(2.0)).unwrap();
        let one_plus_one = f64_field.embed_expr(parse_quote!(1.0 + 1.0)).unwrap();
        assert!(f64_field.equals(&two, &one_plus_one));
    }

    /// `6.0 / 3.0` must compare equal to `2.0` in an `f32` field.
    #[test]
    fn division_produces_expected_result() {
        let f32_field = CasRsField::new(parse_quote!(f32));
        let six = f32_field.embed_expr(parse_quote!(6.0)).unwrap();
        let three = f32_field.embed_expr(parse_quote!(3.0)).unwrap();
        let two = f32_field.embed_expr(parse_quote!(2.0)).unwrap();
        let ratio = f32_field.div(&six, &three);
        assert!(f32_field.equals(&ratio, &two));
    }

    /// `(x + x) * (y - y)` must render as zero.
    #[test]
    fn to_expr_roundtrips_simple_expression() {
        let f32_field = CasRsField::new(parse_quote!(f32));
        let embedded = f32_field
            .embed_expr(parse_quote!((x + x) * (y - y)))
            .unwrap();
        let rendered = f32_field.to_expr(&embedded).unwrap();
        let tokens = quote! { #rendered }.to_string();
        assert!(
            tokens == "0" || tokens == "0 as f32",
            "unexpected tokens: {tokens}"
        );
    }
}