use crate::{N, T};
use std::cmp::{max, min};
use std::ops::Div;
/// An arithmetic operator paired with a value of type `T`.
///
/// NOTE(review): the `T` payload is presumably the result the operator is
/// expected to produce from its operands; confirm against the code that
/// evaluates these constraints.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ArithmeticConstraint {
    /// Constraint built on a [`CommutativeOperator`].
    // NOTE(review): `dead_code` is allowed here presumably because this
    // variant is not constructed anywhere in the crate; confirm it is still needed.
    #[allow(dead_code)]
    CommutativeConstraint(CommutativeOperator, T),
    /// Constraint built on a [`NonCommutativeOperator`].
    NonCommutativeConstraint(NonCommutativeOperator, T),
}
/// Binary operators whose result does not depend on operand order.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CommutativeOperator {
    /// Addition; identity element `0`.
    Add,
    /// Multiplication; identity element `1`.
    Multiply,
}

impl CommutativeOperator {
    /// Combines every value in `ns` with this operator, converting each to `T` first.
    ///
    /// The fold starts from [`Self::identity`], so an empty slice yields `0`
    /// for `Add` and `1` for `Multiply`. Overflow follows the ordinary `+`/`*`
    /// rules on `T` (panic with overflow checks on, wrap otherwise).
    #[must_use]
    pub fn apply_to_tuple(self, ns: &[N]) -> T {
        ns.iter()
            .fold(self.identity(), |acc, &n| self.apply_to_pair(acc, T::from(n)))
    }

    /// Applies this operator to the two operands `x` and `y`.
    #[must_use]
    pub const fn apply_to_pair(self, x: T, y: T) -> T {
        match self {
            Self::Multiply => x * y,
            Self::Add => x + y,
        }
    }

    /// Returns the identity element of this operator (`0` for `Add`, `1` for `Multiply`).
    #[must_use]
    pub const fn identity(self) -> T {
        match self {
            Self::Multiply => 1,
            Self::Add => 0,
        }
    }

    /// Returns the other commutative operator: `Add` maps to `Multiply`, and vice versa.
    #[must_use]
    pub const fn dual(self) -> Self {
        match self {
            Self::Multiply => Self::Add,
            Self::Add => Self::Multiply,
        }
    }
}
/// Binary operators whose plain result depends on operand order.
///
/// [`NonCommutativeOperator::apply`] removes that dependence by always working
/// from the larger and the smaller of its two operands.
#[derive(Debug, Hash, Eq, PartialEq, Copy, Clone)]
pub enum NonCommutativeOperator {
    /// Subtraction, evaluated as the absolute difference of the operands.
    Subtract,
    /// Division, evaluated as the larger operand divided by the smaller.
    Divide,
}

impl NonCommutativeOperator {
    /// Applies the operator to `a` and `b` without regard to their order.
    ///
    /// * `Subtract` yields `a.abs_diff(b)`.
    /// * `Divide` yields `max(a, b) / min(a, b)`. For primitive integer `N`,
    ///   this is truncating division (e.g. 7 and 3 give 2). An inexact pair is
    ///   therefore not distinguished from an exact one.
    ///   NOTE(review): confirm callers check divisibility if they need it.
    ///
    /// # Panics
    ///
    /// For primitive integer `N`, `Divide` panics when the smaller operand is
    /// zero (integer division by zero).
    #[must_use]
    pub fn apply(self, a: N, b: N) -> T {
        match self {
            Self::Subtract => T::from(a.abs_diff(b)),
            Self::Divide => {
                let larger = max(a, b);
                let smaller = min(a, b);
                T::from(larger.div(smaller))
            }
        }
    }
}