use std::num::Wrapping;
#[cfg(feature = "num-bigint")]
mod num;
/// A binary operation over values of type `N`.
///
/// Only [`combine`](Operation::combine) has to be written by implementors; every other
/// method has a default expressed in terms of it. All defaults keep the operand order
/// fixed (`a` is always the left operand, `b` the right one), so they are also correct
/// for non-commutative operations.
pub trait Operation<N> {
    /// Combines the left operand `a` with the right operand `b` into a new value.
    fn combine(&self, a: &N, b: &N) -> N;

    /// Replaces `a` with `combine(a, b)`.
    #[inline]
    fn combine_mut(&self, a: &mut N, b: &N) {
        *a = self.combine(&*a, b);
    }

    /// Replaces `b` with `combine(a, b)`.
    #[inline]
    fn combine_mut2(&self, a: &N, b: &mut N) {
        *b = self.combine(a, &*b);
    }

    /// Combines an owned left operand with a borrowed right operand, reusing `a`
    /// as the output through [`combine_mut`](Operation::combine_mut).
    #[inline]
    fn combine_left(&self, a: N, b: &N) -> N {
        let mut left = a;
        self.combine_mut(&mut left, b);
        left
    }

    /// Combines a borrowed left operand with an owned right operand, reusing `b`
    /// as the output through [`combine_mut2`](Operation::combine_mut2).
    #[inline]
    fn combine_right(&self, a: &N, b: N) -> N {
        let mut right = b;
        self.combine_mut2(a, &mut right);
        right
    }

    /// Combines two owned operands by delegating to
    /// [`combine_left`](Operation::combine_left).
    #[inline]
    fn combine_both(&self, a: N, b: N) -> N {
        self.combine_left(a, &b)
    }
}
/// Marker trait: implementing it asserts that the operation's `combine` gives the same
/// result regardless of operand order, so callers may reorder operands.
pub trait Commutative<N>: Operation<N> {}
/// Provides the identity element of an operation on `N`.
pub trait Identity<N> {
    /// Returns the value that leaves any other value unchanged when combined with it
    /// (for example `0` for `Add` and `1` for `Mul` in this module).
    fn identity(&self) -> N;
}
/// Operations whose combination with a value can be undone.
pub trait Invertible<N> {
    /// Removes `b`'s contribution from `a` in place, undoing a previous `combine(a, b)`.
    /// The implementations in this module use `-` for `Add`, `/` for floating-point
    /// `Mul` and `^` for `Xor`; floating-point results are subject to rounding.
    fn uncombine(&self, a: &mut N, b: &N);
}
/// Addition; implemented for the integer primitives, their `Wrapping` forms and `f32`/`f64`.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug, Default, Hash)]
pub struct Add;

/// Multiplication; implemented for the integer primitives, their `Wrapping` forms and
/// `f32`/`f64`.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug, Default, Hash)]
pub struct Mul;

/// Bitwise and for integers, logical and for `bool`.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug, Default, Hash)]
pub struct And;

/// Bitwise or for integers, logical or for `bool`.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug, Default, Hash)]
pub struct Or;

/// Exclusive or, for integers and `bool`.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug, Default, Hash)]
pub struct Xor;

/// Minimum of two integers.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug, Default, Hash)]
pub struct Min;

/// Maximum of two integers.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug, Default, Hash)]
pub struct Max;
// Implements `Operation<$ty>` for `$op`, where `combine` applies the binary operator
// token `$infix` to the dereferenced operands (left operand first). `$doc` becomes
// the doc string of the generated `combine` method.
macro_rules! impl_operation_infix {
    ($op:ty, $ty:ty, $infix:tt, $doc:expr) => {
        impl Operation<$ty> for $op {
            #[doc = $doc]
            #[inline]
            fn combine(&self, a: &$ty, b: &$ty) -> $ty {
                let (lhs, rhs) = (*a, *b);
                lhs $infix rhs
            }
        }
    };
}
// Implements `Operation<$ty>` for `$op`, where `combine` calls the two-argument
// function `$func` on the dereferenced operands (left operand first). `$doc` becomes
// the doc string of the generated `combine` method.
macro_rules! impl_operation_prefix {
    ($op:ty, $ty:ty, $func:expr, $doc:expr) => {
        impl Operation<$ty> for $op {
            #[doc = $doc]
            #[inline]
            fn combine(&self, a: &$ty, b: &$ty) -> $ty {
                let (lhs, rhs) = (*a, *b);
                ($func)(lhs, rhs)
            }
        }
    };
}
// Implements `Identity<$ty>` for `$op`, returning the expression `$value` (evaluated
// on every call). `$doc` becomes the doc string of the generated `identity` method.
macro_rules! impl_identity {
    ($op:ty, $ty:ty, $value:expr, $doc:expr) => {
        impl Identity<$ty> for $op {
            #[doc = $doc]
            #[inline]
            fn identity(&self) -> $ty {
                $value
            }
        }
    };
}
// Implements `Invertible<$ty>` for `$op`: `uncombine` overwrites `a` with
// `a $infix b`. `$doc` becomes the doc string of the generated `uncombine` method.
macro_rules! impl_inverse {
    ($op:ty, $ty:ty, $infix:tt, $doc:expr) => {
        impl Invertible<$ty> for $op {
            #[doc = $doc]
            #[inline]
            fn uncombine(&self, a: &mut $ty, b: &$ty) {
                let current = *a;
                *a = current $infix *b;
            }
        }
    };
}
// Implements all integer operations of this module for the unsigned primitive `$ty`,
// plus the `Wrapping<$ty>` variants of `Add` and `Mul`. `$ty` is taken as a single
// token (`tt`) so it can be spliced into the path `std::$ty::MAX`.
//
// For `Add`, only `Wrapping<$ty>` implements `Invertible` here; plain `$ty` does not.
macro_rules! impl_unsigned_primitive {
    ($ty:tt) => {
        impl_operation_infix!(Add, $ty, +, "Returns the sum.");
        impl_identity!(Add, $ty, 0, "Returns zero.");
        impl Commutative<$ty> for Add {}
        impl_operation_infix!(Add, Wrapping<$ty>, +, "Returns the sum.");
        impl_identity!(Add, Wrapping<$ty>, Wrapping(0), "Returns zero.");
        impl_inverse!(Add, Wrapping<$ty>, -, "Returns the difference.");
        impl Commutative<Wrapping<$ty>> for Add {}
        impl_operation_infix!(Xor, $ty, ^, "Returns the bitwise exclusive or.");
        impl_identity!(Xor, $ty, 0, "Returns zero.");
        impl_inverse!(Xor, $ty, ^, "Returns the bitwise exclusive or.");
        impl Commutative<$ty> for Xor {}
        impl_operation_infix!(Mul, $ty, *, "Returns the product.");
        impl_identity!(Mul, $ty, 1, "Returns one.");
        impl Commutative<$ty> for Mul {}
        impl_operation_infix!(Mul, Wrapping<$ty>, *, "Returns the product.");
        impl_identity!(Mul, Wrapping<$ty>, Wrapping(1), "Returns one.");
        impl Commutative<Wrapping<$ty>> for Mul {}
        impl_operation_infix!(And, $ty, &, "Returns the bitwise and.");
        impl_identity!(And, $ty, std::$ty::MAX, "Returns the largest possible value.");
        impl Commutative<$ty> for And {}
        // `Or` must use the bitwise-or operator `|`. With `&`, `Or` behaved like `And`,
        // and its identity `0` was no longer an identity (`x & 0 == 0`).
        impl_operation_infix!(Or, $ty, |, "Returns the bitwise or.");
        impl_identity!(Or, $ty, 0, "Returns zero.");
        impl Commutative<$ty> for Or {}
        impl_operation_prefix!(Min, $ty, std::cmp::min, "Returns the minimum.");
        impl_identity!(Min, $ty, std::$ty::MAX, "Returns the largest possible value.");
        impl Commutative<$ty> for Min {}
        impl_operation_prefix!(Max, $ty, std::cmp::max, "Returns the maximum.");
        impl_identity!(Max, $ty, 0, "Returns zero.");
        impl Commutative<$ty> for Max {}
    }
}
// Instantiate the unsigned-integer impls for every fixed-width unsigned primitive,
// widest first.
impl_unsigned_primitive!(u128);
impl_unsigned_primitive!(u64);
impl_unsigned_primitive!(u32);
impl_unsigned_primitive!(u16);
impl_unsigned_primitive!(u8);
// Pointer-sized unsigned integer.
impl_unsigned_primitive!(usize);
// Implements all integer operations of this module for the signed primitive `$ty`,
// plus the `Wrapping<$ty>` variants of `Add` and `Mul`. `$ty` is taken as a single
// token (`tt`) so it can be spliced into the paths `std::$ty::MAX` / `std::$ty::MIN`.
//
// The identity of `And` is `-1` (all bits set in two's complement). Unlike the
// unsigned macro, `Add` is `Invertible` for plain `$ty` as well as `Wrapping<$ty>`.
macro_rules! impl_signed_primitive {
    ($ty:tt) => {
        impl_operation_infix!(Add, $ty, +, "Returns the sum.");
        impl_identity!(Add, $ty, 0, "Returns zero.");
        impl_inverse!(Add, $ty, -, "Returns the difference.");
        impl Commutative<$ty> for Add {}
        impl_operation_infix!(Add, Wrapping<$ty>, +, "Returns the sum.");
        impl_identity!(Add, Wrapping<$ty>, Wrapping(0), "Returns zero.");
        impl_inverse!(Add, Wrapping<$ty>, -, "Returns the difference.");
        impl Commutative<Wrapping<$ty>> for Add {}
        impl_operation_infix!(Xor, $ty, ^, "Returns the bitwise exclusive or.");
        impl_identity!(Xor, $ty, 0, "Returns zero.");
        impl_inverse!(Xor, $ty, ^, "Returns the bitwise exclusive or.");
        impl Commutative<$ty> for Xor {}
        impl_operation_infix!(Mul, $ty, *, "Returns the product.");
        impl_identity!(Mul, $ty, 1, "Returns one.");
        impl Commutative<$ty> for Mul {}
        impl_operation_infix!(Mul, Wrapping<$ty>, *, "Returns the product.");
        impl_identity!(Mul, Wrapping<$ty>, Wrapping(1), "Returns one.");
        impl Commutative<Wrapping<$ty>> for Mul {}
        impl_operation_infix!(And, $ty, &, "Returns the bitwise and.");
        impl_identity!(And, $ty, -1, "Returns negative one.");
        impl Commutative<$ty> for And {}
        // `Or` must use the bitwise-or operator `|`. With `&`, `Or` behaved like `And`,
        // and its identity `0` was no longer an identity (`x & 0 == 0`).
        impl_operation_infix!(Or, $ty, |, "Returns the bitwise or.");
        impl_identity!(Or, $ty, 0, "Returns zero.");
        impl Commutative<$ty> for Or {}
        impl_operation_prefix!(Min, $ty, std::cmp::min, "Returns the minimum.");
        impl_identity!(Min, $ty, std::$ty::MAX, "Returns the largest possible value.");
        impl Commutative<$ty> for Min {}
        impl_operation_prefix!(Max, $ty, std::cmp::max, "Returns the maximum.");
        impl_identity!(Max, $ty, std::$ty::MIN, "Returns the smallest possible value.");
        impl Commutative<$ty> for Max {}
    }
}
// Instantiate the signed-integer impls for every fixed-width signed primitive,
// widest first.
impl_signed_primitive!(i128);
impl_signed_primitive!(i64);
impl_signed_primitive!(i32);
impl_signed_primitive!(i16);
impl_signed_primitive!(i8);
// Pointer-sized signed integer.
impl_signed_primitive!(isize);
impl_operation_infix!(Add, f32, +, "Returns the sum.");
impl_inverse!(Add, f32, -, "Returns the difference.");
impl_identity!(Add, f32, 0.0, "Returns zero.");
impl Commutative<f32> for Add {}
impl_operation_infix!(Mul, f32, *, "Returns the product.");
impl_inverse!(Mul, f32, /, "Returns the ratio.");
impl_identity!(Mul, f32, 1.0, "Returns one.");
impl Commutative<f32> for Mul {}
impl_operation_infix!(Add, f64, +, "Returns the sum.");
impl_inverse!(Add, f64, -, "Returns the difference.");
impl_identity!(Add, f64, 0.0, "Returns zero.");
impl Commutative<f64> for Add {}
impl_operation_infix!(Mul, f64, *, "Returns the product.");
impl_inverse!(Mul, f64, /, "Returns the ratio.");
impl_identity!(Mul, f64, 1.0, "Returns one.");
impl Commutative<f64> for Mul {}
/// Floating-point minimum that ignores NaN.
///
/// If exactly one operand is NaN, the other operand is returned. NaN is produced only
/// when both operands are NaN. That is why the identity is NaN: combining NaN with a
/// number yields the number. When the operands compare equal, the right operand is
/// returned.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug, Default, Hash)]
pub struct MinIgnoreNaN;
impl_identity!(MinIgnoreNaN, f32, std::f32::NAN, "Returns NaN.");
impl_identity!(MinIgnoreNaN, f64, std::f64::NAN, "Returns NaN.");
impl Commutative<f32> for MinIgnoreNaN {}
impl Commutative<f64> for MinIgnoreNaN {}
impl Operation<f32> for MinIgnoreNaN {
    fn combine(&self, a: &f32, b: &f32) -> f32 {
        // Keep the left operand when the right one is NaN or strictly larger.
        let keep_left = b.is_nan() || *b > *a;
        if keep_left { *a } else { *b }
    }
}
impl Operation<f64> for MinIgnoreNaN {
    fn combine(&self, a: &f64, b: &f64) -> f64 {
        // Keep the left operand when the right one is NaN or strictly larger.
        let keep_left = b.is_nan() || *b > *a;
        if keep_left { *a } else { *b }
    }
}
/// Floating-point minimum that propagates NaN.
///
/// If either operand is NaN, the result is NaN. The identity is positive infinity.
/// When the operands compare equal, the right operand is returned.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug, Default, Hash)]
pub struct MinTakeNaN;
impl_identity!(MinTakeNaN, f32, std::f32::INFINITY, "Returns infinity.");
impl_identity!(MinTakeNaN, f64, std::f64::INFINITY, "Returns infinity.");
impl Commutative<f32> for MinTakeNaN {}
impl Commutative<f64> for MinTakeNaN {}
impl Operation<f32> for MinTakeNaN {
    fn combine(&self, a: &f32, b: &f32) -> f32 {
        // A NaN on the left wins outright; a NaN on the right falls through to `*b`.
        let keep_left = a.is_nan() || *a < *b;
        if keep_left { *a } else { *b }
    }
}
impl Operation<f64> for MinTakeNaN {
    fn combine(&self, a: &f64, b: &f64) -> f64 {
        // A NaN on the left wins outright; a NaN on the right falls through to `*b`.
        let keep_left = a.is_nan() || *a < *b;
        if keep_left { *a } else { *b }
    }
}
/// Floating-point maximum that ignores NaN.
///
/// If exactly one operand is NaN, the other operand is returned. NaN is produced only
/// when both operands are NaN, which is why NaN serves as the identity. When the
/// operands compare equal, the right operand is returned.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug, Default, Hash)]
pub struct MaxIgnoreNaN;
impl_identity!(MaxIgnoreNaN, f32, std::f32::NAN, "Returns NaN.");
impl_identity!(MaxIgnoreNaN, f64, std::f64::NAN, "Returns NaN.");
impl Commutative<f32> for MaxIgnoreNaN {}
impl Commutative<f64> for MaxIgnoreNaN {}
impl Operation<f32> for MaxIgnoreNaN {
    fn combine(&self, a: &f32, b: &f32) -> f32 {
        // Keep the left operand when the right one is NaN or strictly smaller.
        let keep_left = b.is_nan() || *a > *b;
        if keep_left { *a } else { *b }
    }
}
impl Operation<f64> for MaxIgnoreNaN {
    fn combine(&self, a: &f64, b: &f64) -> f64 {
        // Keep the left operand when the right one is NaN or strictly smaller.
        let keep_left = b.is_nan() || *a > *b;
        if keep_left { *a } else { *b }
    }
}
/// Floating-point maximum that propagates NaN.
///
/// If either operand is NaN, the result is NaN. The identity is negative infinity.
/// When the operands compare equal, the right operand is returned.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug, Default, Hash)]
pub struct MaxTakeNaN;
impl_identity!(MaxTakeNaN, f32, std::f32::NEG_INFINITY, "Returns negative infinity.");
impl_identity!(MaxTakeNaN, f64, std::f64::NEG_INFINITY, "Returns negative infinity.");
impl Commutative<f32> for MaxTakeNaN {}
impl Commutative<f64> for MaxTakeNaN {}
impl Operation<f32> for MaxTakeNaN {
    fn combine(&self, a: &f32, b: &f32) -> f32 {
        // A NaN on the left wins outright; a NaN on the right falls through to `*b`.
        let keep_left = a.is_nan() || *a > *b;
        if keep_left { *a } else { *b }
    }
}
impl Operation<f64> for MaxTakeNaN {
    fn combine(&self, a: &f64, b: &f64) -> f64 {
        // A NaN on the left wins outright; a NaN on the right falls through to `*b`.
        let keep_left = a.is_nan() || *a > *b;
        if keep_left { *a } else { *b }
    }
}
impl_operation_infix!(And, bool, &&, "Returns the boolean and.");
impl_identity!(And, bool, true, "Returns `true`.");
impl_operation_infix!(Or, bool, ||, "Returns the boolean or.");
impl_identity!(Or, bool, false, "Returns `false`.");
impl_operation_infix!(Xor, bool, ^, "Returns the boolean xor.");
impl_inverse!(Xor, bool, ^, "Returns the boolean xor.");
impl_identity!(Xor, bool, false, "Returns `false`.");
#[cfg(test)]
mod tests {
    use std::{f32, i32, u32};
    use crate::ops::*;

    /// Checks NaN handling of the four NaN-aware min/max operations.
    #[test]
    fn ops_nan() {
        let nan = f32::NAN;
        let inf = f32::INFINITY;
        let neg_inf = f32::NEG_INFINITY;

        // MaxIgnoreNaN: a single NaN operand is skipped; only NaN vs NaN yields NaN.
        assert_eq!(MaxIgnoreNaN.combine_both(0.0, 1.0), 1.0);
        assert_eq!(MaxIgnoreNaN.combine_both(1.0, 0.0), 1.0);
        assert_eq!(MaxIgnoreNaN.combine_both(nan, 1.0), 1.0);
        assert_eq!(MaxIgnoreNaN.combine_both(1.0, nan), 1.0);
        assert_eq!(MaxIgnoreNaN.combine_both(nan, neg_inf), neg_inf);
        assert_eq!(MaxIgnoreNaN.combine_both(neg_inf, nan), neg_inf);
        assert!(MaxIgnoreNaN.combine_both(nan, nan).is_nan());

        // MinIgnoreNaN: mirror image of the above.
        assert_eq!(MinIgnoreNaN.combine_both(0.0, 1.0), 0.0);
        assert_eq!(MinIgnoreNaN.combine_both(1.0, 0.0), 0.0);
        assert_eq!(MinIgnoreNaN.combine_both(nan, 1.0), 1.0);
        assert_eq!(MinIgnoreNaN.combine_both(1.0, nan), 1.0);
        assert_eq!(MinIgnoreNaN.combine_both(nan, inf), inf);
        assert_eq!(MinIgnoreNaN.combine_both(inf, nan), inf);
        assert!(MinIgnoreNaN.combine_both(nan, nan).is_nan());

        // Every pairing below contains at least one NaN.
        let with_nan = [(nan, inf), (inf, nan), (nan, neg_inf), (neg_inf, nan), (nan, nan)];

        // MaxTakeNaN: any NaN operand makes the result NaN.
        assert_eq!(MaxTakeNaN.combine_both(0.0, 1.0), 1.0);
        assert_eq!(MaxTakeNaN.combine_both(1.0, 0.0), 1.0);
        for &(x, y) in &with_nan {
            assert!(MaxTakeNaN.combine_both(x, y).is_nan());
        }

        // MinTakeNaN: any NaN operand makes the result NaN.
        assert_eq!(MinTakeNaN.combine_both(0.0, 1.0), 0.0);
        assert_eq!(MinTakeNaN.combine_both(1.0, 0.0), 0.0);
        for &(x, y) in &with_nan {
            assert!(MinTakeNaN.combine_both(x, y).is_nan());
        }
    }

    /// Checks that `And.identity()` is a right identity for signed and unsigned integers.
    #[test]
    fn ops_and_identity() {
        (-200i32..=200).for_each(|i| assert_eq!(And.combine_both(i, And.identity()), i));
        for &x in &[i32::MAX, i32::MIN, 0i32] {
            assert_eq!(And.combine_both(x, And.identity()), x);
        }
        for &x in &[0u32, u32::MAX] {
            assert_eq!(And.combine_both(x, And.identity()), x);
        }
    }
}
/// Combines two operations into one that acts component-wise on 2-tuples.
///
/// `Pair<A, B>` implements `Operation<(TA, TB)>` whenever `A: Operation<TA>` and
/// `B: Operation<TB>`. The first tuple components are combined with `A` and the second
/// with `B`. `Invertible`, `Commutative` and `Identity` lift the same way.
#[derive(Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Debug, Default, Hash)]
pub struct Pair<A, B> {
    a: A,
    b: B,
}

impl<A, B> Pair<A, B> {
    /// Creates a pair from the operation for the first component (`a`) and the
    /// operation for the second component (`b`).
    pub fn wrap(a: A, b: B) -> Pair<A, B> {
        Pair { a, b }
    }

    /// Returns the two wrapped operations as `(a, b)`.
    pub fn into_inner(self) -> (A, B) {
        (self.a, self.b)
    }
}

// Every method is forwarded component-wise so that each side uses its own
// operation's implementation of that method (including any overridden in-place
// or by-value variants) instead of the generic defaults.
impl<TA, TB, A: Operation<TA>, B: Operation<TB>> Operation<(TA, TB)> for Pair<A, B> {
    #[inline]
    fn combine(&self, a: &(TA, TB), b: &(TA, TB)) -> (TA, TB) {
        (self.a.combine(&a.0, &b.0), self.b.combine(&a.1, &b.1))
    }
    #[inline]
    fn combine_mut(&self, a: &mut (TA, TB), b: &(TA, TB)) {
        self.a.combine_mut(&mut a.0, &b.0);
        self.b.combine_mut(&mut a.1, &b.1);
    }
    #[inline]
    fn combine_mut2(&self, a: &(TA, TB), b: &mut (TA, TB)) {
        self.a.combine_mut2(&a.0, &mut b.0);
        self.b.combine_mut2(&a.1, &mut b.1);
    }
    #[inline]
    fn combine_left(&self, a: (TA, TB), b: &(TA, TB)) -> (TA, TB) {
        (self.a.combine_left(a.0, &b.0), self.b.combine_left(a.1, &b.1))
    }
    #[inline]
    fn combine_right(&self, a: &(TA, TB), b: (TA, TB)) -> (TA, TB) {
        (self.a.combine_right(&a.0, b.0), self.b.combine_right(&a.1, b.1))
    }
    #[inline]
    fn combine_both(&self, a: (TA, TB), b: (TA, TB)) -> (TA, TB) {
        (self.a.combine_both(a.0, b.0), self.b.combine_both(a.1, b.1))
    }
}

impl<TA, TB, A: Invertible<TA>, B: Invertible<TB>> Invertible<(TA, TB)> for Pair<A, B> {
    /// Undoes each component with its own operation's `uncombine`.
    #[inline]
    fn uncombine(&self, a: &mut (TA, TB), b: &(TA, TB)) {
        self.a.uncombine(&mut a.0, &b.0);
        self.b.uncombine(&mut a.1, &b.1);
    }
}

/// A pair is commutative when both component operations are.
impl<TA, TB, A: Commutative<TA>, B: Commutative<TB>> Commutative<(TA, TB)> for Pair<A, B> {}

impl<TA, TB, A: Identity<TA>, B: Identity<TB>> Identity<(TA, TB)> for Pair<A, B> {
    /// Returns the tuple of both component identities.
    #[inline]
    fn identity(&self) -> (TA, TB) {
        (self.a.identity(), self.b.identity())
    }
}