use num_traits::NumOps;
use std::{fmt, str::FromStr};
use crate::{
error::{ErrorKind, ErrorLocation, OpErrors},
PrimitiveType, Type,
};
use arithmetic_parser::{BinaryOp, UnaryOp};
mod constraints;
mod substitutions;
pub(crate) use self::constraints::CompleteConstraints;
pub use self::constraints::{
Constraint, ConstraintSet, LinearType, Linearity, ObjectSafeConstraint, Ops, StructConstraint,
};
pub use self::substitutions::Substitutions;
/// Maps literal values of type `Val` to primitive types.
pub trait MapPrimitiveType<Val> {
/// Primitive type assigned to literals.
type Prim: PrimitiveType;
/// Returns the primitive type of the literal `lit`.
fn type_of_literal(&self, lit: &Val) -> Self::Prim;
}
/// Type inference logic for unary and binary operations.
///
/// The implementations in this module (`BoolArithmetic`, `NumArithmetic`) unify types and
/// create type variables through `substitutions`, and record problems in `errors`.
pub trait TypeArithmetic<Prim: PrimitiveType> {
/// Processes the unary operation described by `context` and returns its result type.
fn process_unary_op(
&self,
substitutions: &mut Substitutions<Prim>,
context: &UnaryOpContext<Prim>,
errors: OpErrors<'_, Prim>,
) -> Type<Prim>;
/// Processes the binary operation described by `context` and returns its result type.
fn process_binary_op(
&self,
substitutions: &mut Substitutions<Prim>,
context: &BinaryOpContext<Prim>,
errors: OpErrors<'_, Prim>,
) -> Type<Prim>;
}
/// Context of a unary operation: the operator and the type of its operand.
#[derive(Debug, Clone)]
pub struct UnaryOpContext<Prim: PrimitiveType> {
/// Operator being applied.
pub op: UnaryOp,
/// Type of the operand.
pub arg: Type<Prim>,
}
/// Context of a binary operation: the operator and the types of both operands.
#[derive(Debug, Clone)]
pub struct BinaryOpContext<Prim: PrimitiveType> {
/// Operator being applied.
pub op: BinaryOp,
/// Type of the left-hand operand.
pub lhs: Type<Prim>,
/// Type of the right-hand operand.
pub rhs: Type<Prim>,
}
/// Primitive type that has a value representing Booleans.
pub trait WithBoolean: PrimitiveType {
/// The primitive value representing the Boolean type.
const BOOL: Self;
}
/// Arithmetic supporting only Boolean operations: unary `Not`, equality comparisons
/// (`Eq`, `NotEq`) and logical `And` / `Or`.
///
/// Any other operation is reported as unsupported.
#[derive(Debug, Clone, Copy, Default)]
pub struct BoolArithmetic;
impl<Prim: WithBoolean> TypeArithmetic<Prim> for BoolArithmetic {
    /// Handles unary operations. Only `UnaryOp::Not` is supported.
    ///
    /// For `Not`, the operand is unified with the Boolean type and the result is Boolean.
    /// Any other operator records an "unsupported" error and yields a fresh type variable,
    /// so that inference can continue.
    fn process_unary_op(
        &self,
        substitutions: &mut Substitutions<Prim>,
        context: &UnaryOpContext<Prim>,
        mut errors: OpErrors<'_, Prim>,
    ) -> Type<Prim> {
        match context.op {
            UnaryOp::Not => {
                substitutions.unify(&Type::BOOL, &context.arg, errors);
                Type::BOOL
            }
            op => {
                errors.push(ErrorKind::unsupported(op));
                substitutions.new_type_var()
            }
        }
    }

    /// Handles binary operations:
    ///
    /// - `Eq` / `NotEq`: both operands are unified with each other, and the result is Boolean.
    /// - `And` / `Or`: each operand is unified with the Boolean type (errors are attributed
    ///   to the corresponding operand), and the result is Boolean.
    /// - Anything else records an "unsupported" error and yields a fresh type variable.
    fn process_binary_op(
        &self,
        substitutions: &mut Substitutions<Prim>,
        context: &BinaryOpContext<Prim>,
        mut errors: OpErrors<'_, Prim>,
    ) -> Type<Prim> {
        match context.op {
            BinaryOp::Eq | BinaryOp::NotEq => {
                substitutions.unify(&context.lhs, &context.rhs, errors);
                Type::BOOL
            }
            BinaryOp::And | BinaryOp::Or => {
                substitutions.unify(
                    &Type::BOOL,
                    &context.lhs,
                    errors.with_location(ErrorLocation::Lhs),
                );
                substitutions.unify(
                    &Type::BOOL,
                    &context.rhs,
                    errors.with_location(ErrorLocation::Rhs),
                );
                Type::BOOL
            }
            _ => {
                errors.push(ErrorKind::unsupported(context.op));
                substitutions.new_type_var()
            }
        }
    }
}
/// Constraints used by `NumArithmetic::unify_binary_op` to check operand types.
#[derive(Debug)]
pub struct OpConstraintSettings<'a, Prim: PrimitiveType> {
/// Constraint applied to both operands when `is_primitive()` reports a primitive type for
/// exactly one operand and a non-primitive type for the other.
pub lin: &'a dyn Constraint<Prim>,
/// Constraint applied to both operands in all other cases.
pub ops: &'a dyn Constraint<Prim>,
}
// `Clone` and `Copy` are implemented manually rather than derived: `#[derive]` would add
// `Prim: Clone` / `Prim: Copy` bounds, although the struct only stores references.
impl<Prim: PrimitiveType> Clone for OpConstraintSettings<'_, Prim> {
    fn clone(&self) -> Self {
        // The type is `Copy`, so the canonical `Clone` implementation is a plain copy.
        *self
    }
}

impl<Prim: PrimitiveType> Copy for OpConstraintSettings<'_, Prim> {}
/// Primitive types used by `NumArithmetic`: numbers and Booleans.
///
/// Displays as, and parses from, `Num` / `Bool`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Num {
    /// Numeric type. `NumArithmetic` assigns this type to every literal.
    Num,
    /// Boolean type.
    Bool,
}
impl fmt::Display for Num {
    /// Writes the variant name (`Num` or `Bool`) verbatim, without padding.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Num => "Num",
            Self::Bool => "Bool",
        };
        formatter.write_str(name)
    }
}
impl FromStr for Num {
    type Err = anyhow::Error;

    /// Parses the exact strings `Num` and `Bool`; anything else is an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed = match s {
            "Num" => Self::Num,
            "Bool" => Self::Bool,
            _ => return Err(anyhow::anyhow!("Expected `Num` or `Bool`")),
        };
        Ok(parsed)
    }
}
impl PrimitiveType for Num {
    /// Returns the constraints known for `Num`. These are `Linearity`, inserted as an
    /// object-safe constraint, and `Ops`.
    fn well_known_constraints() -> ConstraintSet<Self> {
        let mut set = ConstraintSet::<Self>::default();
        set.insert_object_safe(Linearity);
        set.insert(Ops);
        set
    }
}
impl WithBoolean for Num {
const BOOL: Self = Self::Bool;
}
impl LinearType for Num {
    /// `Num` is linear; `Bool` is not.
    fn is_linear(&self) -> bool {
        match self {
            Self::Num => true,
            Self::Bool => false,
        }
    }
}
/// Arithmetic on numbers (`Num`) and Booleans.
///
/// Comparison operations (`Ge`, `Le`, `Lt`, `Gt`) are supported only by instances created
/// with `NumArithmetic::with_comparisons()`.
#[derive(Debug, Clone)]
pub struct NumArithmetic {
/// Whether comparison operations are allowed; when set, they require `Num` operands.
comparisons_enabled: bool,
}
impl NumArithmetic {
    /// Creates an instance that does not support comparisons; `Ge`, `Le`, `Lt` and `Gt`
    /// are reported as unsupported operations.
    pub const fn without_comparisons() -> Self {
        Self {
            comparisons_enabled: false,
        }
    }

    /// Creates an instance that supports comparisons between `Num` operands.
    pub const fn with_comparisons() -> Self {
        Self {
            comparisons_enabled: true,
        }
    }

    /// Checks the operand types of an arithmetic binary operation and returns its result type.
    ///
    /// - If `is_primitive()` reports a primitive type for exactly one operand and a
    ///   non-primitive type for the other (e.g., a scalar multiplied by a tuple), both
    ///   operands are checked against `settings.lin`. The resolved type of the
    ///   non-primitive operand is returned.
    /// - Otherwise, both operands are checked against `settings.ops`. If both pass, the
    ///   operand types are unified. The LHS type (as given in `context`) is returned if it
    ///   passed the check; otherwise the RHS type is returned.
    ///
    /// Errors are attributed to the operand (`Lhs` / `Rhs`) they relate to, except
    /// unification errors between the operands.
    pub fn unify_binary_op<Prim: PrimitiveType>(
        substitutions: &mut Substitutions<Prim>,
        context: &BinaryOpContext<Prim>,
        mut errors: OpErrors<'_, Prim>,
        settings: OpConstraintSettings<'_, Prim>,
    ) -> Type<Prim> {
        let lhs_ty = &context.lhs;
        let rhs_ty = &context.rhs;
        let resolved_lhs_ty = substitutions.fast_resolve(lhs_ty);
        let resolved_rhs_ty = substitutions.fast_resolve(rhs_ty);

        match (resolved_lhs_ty.is_primitive(), resolved_rhs_ty.is_primitive()) {
            (Some(lhs_is_prim), Some(rhs_is_prim)) if lhs_is_prim != rhs_is_prim => {
                // The non-primitive operand determines the output type. It is cloned up
                // front because the resolved types (presumably) borrow `substitutions`,
                // which the visitors below borrow mutably.
                let output_ty = if lhs_is_prim {
                    resolved_rhs_ty.clone()
                } else {
                    resolved_lhs_ty.clone()
                };
                settings
                    .lin
                    .visitor(substitutions, errors.with_location(ErrorLocation::Lhs))
                    .visit_type(lhs_ty);
                settings
                    .lin
                    .visitor(substitutions, errors.with_location(ErrorLocation::Rhs))
                    .visit_type(rhs_ty);
                output_ty
            }

            _ => {
                let lhs_is_valid = errors.with_location(ErrorLocation::Lhs).check(|errors| {
                    settings
                        .ops
                        .visitor(substitutions, errors)
                        .visit_type(lhs_ty);
                });
                let rhs_is_valid = errors.with_location(ErrorLocation::Rhs).check(|errors| {
                    settings
                        .ops
                        .visitor(substitutions, errors)
                        .visit_type(rhs_ty);
                });

                // Only unify the operands if both are valid, to avoid piling up errors
                // that stem from an already reported problem.
                if lhs_is_valid && rhs_is_valid {
                    substitutions.unify(lhs_ty, rhs_ty, errors);
                }
                if lhs_is_valid {
                    lhs_ty.clone()
                } else {
                    rhs_ty.clone()
                }
            }
        }
    }

    /// Processes a unary operation:
    ///
    /// - `Not` is delegated to [`BoolArithmetic`].
    /// - `Neg` checks the operand against `constraints` and returns the operand type.
    /// - Anything else records an "unsupported" error and yields a fresh type variable.
    pub fn process_unary_op<Prim: WithBoolean>(
        substitutions: &mut Substitutions<Prim>,
        context: &UnaryOpContext<Prim>,
        mut errors: OpErrors<'_, Prim>,
        constraints: &impl Constraint<Prim>,
    ) -> Type<Prim> {
        match context.op {
            UnaryOp::Not => BoolArithmetic.process_unary_op(substitutions, context, errors),
            UnaryOp::Neg => {
                constraints
                    .visitor(substitutions, errors)
                    .visit_type(&context.arg);
                context.arg.clone()
            }
            _ => {
                errors.push(ErrorKind::unsupported(context.op));
                substitutions.new_type_var()
            }
        }
    }

    /// Processes a binary operation:
    ///
    /// - `And`, `Or`, `Eq` and `NotEq` are delegated to [`BoolArithmetic`].
    /// - `Add`, `Sub`, `Mul`, `Div` and `Power` are handled by [`Self::unify_binary_op()`].
    /// - `Ge`, `Le`, `Lt` and `Gt` unify both operands with `comparable_type` if it is
    ///   provided, or record an "unsupported" error otherwise. The result is Boolean
    ///   in either case.
    /// - Anything else records an "unsupported" error and yields a fresh type variable.
    pub fn process_binary_op<Prim: WithBoolean>(
        substitutions: &mut Substitutions<Prim>,
        context: &BinaryOpContext<Prim>,
        mut errors: OpErrors<'_, Prim>,
        comparable_type: Option<Prim>,
        settings: OpConstraintSettings<'_, Prim>,
    ) -> Type<Prim> {
        match context.op {
            BinaryOp::And | BinaryOp::Or | BinaryOp::Eq | BinaryOp::NotEq => {
                BoolArithmetic.process_binary_op(substitutions, context, errors)
            }

            BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Power => {
                Self::unify_binary_op(substitutions, context, errors, settings)
            }

            BinaryOp::Ge | BinaryOp::Le | BinaryOp::Lt | BinaryOp::Gt => {
                if let Some(ty) = comparable_type {
                    let ty = Type::Prim(ty);
                    substitutions.unify(
                        &ty,
                        &context.lhs,
                        errors.with_location(ErrorLocation::Lhs),
                    );
                    substitutions.unify(
                        &ty,
                        &context.rhs,
                        errors.with_location(ErrorLocation::Rhs),
                    );
                } else {
                    errors.push(ErrorKind::unsupported(context.op));
                }
                Type::BOOL
            }

            _ => {
                errors.push(ErrorKind::unsupported(context.op));
                substitutions.new_type_var()
            }
        }
    }
}
impl<Val> MapPrimitiveType<Val> for NumArithmetic
where
    Val: Clone + NumOps + PartialEq,
{
    type Prim = Num;

    /// Types every literal as `Num`, regardless of its value.
    fn type_of_literal(&self, _literal: &Val) -> Self::Prim {
        Num::Num
    }
}
impl TypeArithmetic<Num> for NumArithmetic {
fn process_unary_op<'a>(
&self,
substitutions: &mut Substitutions<Num>,
context: &UnaryOpContext<Num>,
errors: OpErrors<'_, Num>,
) -> Type<Num> {
Self::process_unary_op(substitutions, context, errors, &Linearity)
}
fn process_binary_op<'a>(
&self,
substitutions: &mut Substitutions<Num>,
context: &BinaryOpContext<Num>,
errors: OpErrors<'_, Num>,
) -> Type<Num> {
const OP_SETTINGS: OpConstraintSettings<'static, Num> = OpConstraintSettings {
lin: &Linearity,
ops: &Ops,
};
let comparable_type = if self.comparisons_enabled {
Some(Num::Num)
} else {
None
};
Self::process_binary_op(substitutions, context, errors, comparable_type, OP_SETTINGS)
}
}