use std::ops::{Add, Div, Mul, Neg};
use rand::{Rng, RngExt};
use crate::{
Expression, Type, TypeError, Val,
grammar::{BooleanExpr, IntegerExpr, NaturalExpr},
};
pub type Float = f64;
#[derive(Debug, Clone)]
pub enum FloatExpr<V>
where
V: Clone,
{
Const(Float),
Var(V),
Nat(NaturalExpr<V>),
Int(IntegerExpr<V>),
Rand(Box<(FloatExpr<V>, FloatExpr<V>)>),
Opposite(Box<FloatExpr<V>>),
Sum(Vec<FloatExpr<V>>),
Product(Vec<FloatExpr<V>>),
Div(Box<(FloatExpr<V>, FloatExpr<V>)>),
Ite(Box<(BooleanExpr<V>, FloatExpr<V>, FloatExpr<V>)>),
}
impl<V> FloatExpr<V>
where
V: Clone,
{
pub fn is_constant(&self) -> bool {
match self {
FloatExpr::Const(_) => true,
FloatExpr::Var(_) | FloatExpr::Rand(_) => false,
FloatExpr::Nat(natural_expr) => natural_expr.is_constant(),
FloatExpr::Int(integer_expr) => integer_expr.is_constant(),
FloatExpr::Opposite(float_expr) => float_expr.is_constant(),
FloatExpr::Sum(float_exprs) | FloatExpr::Product(float_exprs) => {
float_exprs.iter().all(FloatExpr::is_constant)
}
FloatExpr::Div(args) => {
let (lhs, rhs) = args.as_ref();
lhs.is_constant() && rhs.is_constant()
}
FloatExpr::Ite(args) => {
let (ite, lhs, rhs) = args.as_ref();
ite.is_constant() && lhs.is_constant() && rhs.is_constant()
}
}
}
pub fn eval<R: Rng>(&self, vars: &dyn Fn(&V) -> Val, rng: &mut R) -> Float {
match self {
FloatExpr::Const(float) => *float,
FloatExpr::Var(var) => {
if let Val::Float(float) = vars(var) {
float
} else {
panic!("type mismatch: expected float variable")
}
}
FloatExpr::Nat(natural_expr) => natural_expr.eval(vars, rng) as f64,
FloatExpr::Int(integer_expr) => integer_expr.eval(vars, rng) as f64,
FloatExpr::Rand(bounds) => {
let (lower_bound_expr, upper_bound_expr) = bounds.as_ref();
let lower_bound = lower_bound_expr.eval(vars, rng);
let upper_bound = upper_bound_expr.eval(vars, rng);
rng.random_range(lower_bound..upper_bound)
}
FloatExpr::Opposite(float_expr) => -float_expr.eval(vars, rng),
FloatExpr::Sum(float_exprs) => {
float_exprs.iter().map(|expr| expr.eval(vars, rng)).sum()
}
FloatExpr::Product(float_exprs) => float_exprs
.iter()
.map(|expr| expr.eval(vars, rng))
.product(),
FloatExpr::Div(args) => {
let (lhs_expr, rhs_expr) = args.as_ref();
let lhs = lhs_expr.eval(vars, rng);
let rhs = rhs_expr.eval(vars, rng);
lhs / rhs
}
FloatExpr::Ite(args) => {
let (ite, lhs, rhs) = args.as_ref();
if ite.eval(vars, rng) {
lhs.eval(vars, rng)
} else {
rhs.eval(vars, rng)
}
}
}
}
pub(crate) fn map<W: Clone>(self, map: &dyn Fn(V) -> W) -> FloatExpr<W> {
match self {
FloatExpr::Const(float) => FloatExpr::Const(float),
FloatExpr::Var(var) => FloatExpr::Var(map(var)),
FloatExpr::Nat(natural_expr) => FloatExpr::Nat(natural_expr.map(map)),
FloatExpr::Int(integer_expr) => FloatExpr::Int(integer_expr.map(map)),
FloatExpr::Rand(bounds) => {
let (lower_bound, upper_bound) = *bounds;
FloatExpr::Rand(Box::new((lower_bound.map(map), upper_bound.map(map))))
}
FloatExpr::Opposite(float_expr) => FloatExpr::Opposite(Box::new(float_expr.map(map))),
FloatExpr::Sum(float_exprs) => {
FloatExpr::Sum(float_exprs.into_iter().map(|expr| expr.map(map)).collect())
}
FloatExpr::Product(float_exprs) => {
FloatExpr::Product(float_exprs.into_iter().map(|expr| expr.map(map)).collect())
}
FloatExpr::Div(args) => {
let (lhs, rhs) = *args;
FloatExpr::Div(Box::new((lhs.map(map), rhs.map(map))))
}
FloatExpr::Ite(args) => {
let (r#if, then, r#else) = *args;
FloatExpr::Ite(Box::new((r#if.map(map), then.map(map), r#else.map(map))))
}
}
}
pub(crate) fn context(&self, vars: &dyn Fn(V) -> Option<Type>) -> Result<(), TypeError> {
match self {
FloatExpr::Const(_) => Ok(()),
FloatExpr::Var(v) => matches!(vars(v.clone()), Some(Type::Float))
.then_some(())
.ok_or(TypeError::TypeMismatch),
FloatExpr::Nat(natural_expr) => natural_expr.context(vars),
FloatExpr::Int(integer_expr) => integer_expr.context(vars),
FloatExpr::Rand(exprs) | FloatExpr::Div(exprs) => {
exprs.0.context(vars).and_then(|()| exprs.1.context(vars))
}
FloatExpr::Opposite(integer_expr) => integer_expr.context(vars),
FloatExpr::Sum(integer_exprs) | FloatExpr::Product(integer_exprs) => {
integer_exprs.iter().try_for_each(|expr| expr.context(vars))
}
FloatExpr::Ite(exprs) => exprs
.0
.context(vars)
.and_then(|()| exprs.1.context(vars))
.and_then(|()| exprs.2.context(vars)),
}
}
}
impl<V> From<Float> for FloatExpr<V>
where
V: Clone,
{
fn from(value: Float) -> Self {
Self::Const(value)
}
}
impl<V> TryFrom<Expression<V>> for FloatExpr<V>
where
V: Clone,
{
type Error = TypeError;
fn try_from(value: Expression<V>) -> Result<Self, Self::Error> {
match value {
Expression::Boolean(_) => Err(TypeError::TypeMismatch),
Expression::Natural(natural_expr) => Ok(FloatExpr::Nat(natural_expr)),
Expression::Integer(integer_expr) => Ok(FloatExpr::Int(integer_expr)),
Expression::Float(float_expr) => Ok(float_expr),
}
}
}
impl<V> From<NaturalExpr<V>> for FloatExpr<V>
where
V: Clone,
{
fn from(value: NaturalExpr<V>) -> Self {
Self::Nat(value)
}
}
impl<V> From<IntegerExpr<V>> for FloatExpr<V>
where
V: Clone,
{
fn from(value: IntegerExpr<V>) -> Self {
if let IntegerExpr::Nat(nat_expr) = value {
Self::Nat(nat_expr)
} else {
Self::Int(value)
}
}
}
impl<V> Add for FloatExpr<V>
where
V: Clone,
{
type Output = Self;
fn add(mut self, mut rhs: Self) -> Self::Output {
if let FloatExpr::Sum(ref mut exprs) = self {
if let FloatExpr::Sum(rhs_exprs) = rhs {
exprs.extend(rhs_exprs);
} else {
exprs.push(rhs);
}
self
} else if let FloatExpr::Sum(ref mut rhs_exprs) = rhs {
rhs_exprs.push(self);
rhs
} else {
FloatExpr::Sum(vec![self, rhs])
}
}
}
impl<V> Mul for FloatExpr<V>
where
V: Clone,
{
type Output = Self;
fn mul(mut self, mut rhs: Self) -> Self::Output {
if let FloatExpr::Product(ref mut exprs) = self {
if let FloatExpr::Product(rhs_exprs) = rhs {
exprs.extend(rhs_exprs);
} else {
exprs.push(rhs);
}
self
} else if let FloatExpr::Product(ref mut rhs_exprs) = rhs {
rhs_exprs.push(self);
rhs
} else {
FloatExpr::Product(vec![self, rhs])
}
}
}
impl<V> Neg for FloatExpr<V>
where
V: Clone,
{
type Output = Self;
fn neg(self) -> Self::Output {
if let FloatExpr::Opposite(expr) = self {
*expr
} else {
FloatExpr::Opposite(Box::new(self))
}
}
}
impl<V> Div for FloatExpr<V>
where
V: Clone,
{
type Output = Self;
fn div(self, rhs: Self) -> Self::Output {
FloatExpr::Div(Box::new((self, rhs)))
}
}