#![no_std]
use num_traits::{int::PrimInt, sign::Signed, identities::{ConstOne, ConstZero}, CheckedAdd};
/// How a quotient with a non-zero remainder is turned into an integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Rounding {
    /// Round to the nearest integer; exact halves round away from zero
    /// (`5 / 2 -> 3`, `-5 / 2 -> -3`). This is the default mode.
    #[default]
    Round,
    /// Round toward negative infinity (`-5 / 2 -> -3`).
    Floor,
    /// Round toward positive infinity (`5 / 2 -> 3`).
    Ceiling,
    /// Drop the fractional part, like Rust's integer `/` (`-5 / 2 -> -2`).
    TowardsZero,
    /// Round away from zero (`5 / 3 -> 2`, `-5 / 3 -> -2`).
    AwayFromZero,
}
use core::ops::{AddAssign, SubAssign};
/// Primitive signed integers accepted by the division helpers in this module.
///
/// This file implements it for the built-in signed integer primitives.
pub trait PrimSignedInt:
    PrimInt + Signed + ConstOne + ConstZero + AddAssign + SubAssign
{
    /// Returns `|self|` as an unsigned primitive, so even `Self::min_value()` has a
    /// representable magnitude.
    ///
    /// The concrete type is opaque; results of calls on the same `Self` type share it
    /// and can be compared or added with each other.
    fn unsigned_abs(self) -> impl PrimInt;
}
impl PrimSignedInt for i8 {
fn unsigned_abs(self) -> impl PrimInt {
i8::unsigned_abs(self)
}
}
impl PrimSignedInt for i16 {
fn unsigned_abs(self) -> impl PrimInt {
i16::unsigned_abs(self)
}
}
impl PrimSignedInt for i32 {
fn unsigned_abs(self) -> impl PrimInt {
i32::unsigned_abs(self)
}
}
impl PrimSignedInt for i64 {
fn unsigned_abs(self) -> impl PrimInt {
i64::unsigned_abs(self)
}
}
impl PrimSignedInt for i128 {
fn unsigned_abs(self) -> impl PrimInt {
i128::unsigned_abs(self)
}
}
/// Returns `true` when `|a| >= |b|`.
///
/// Magnitudes are compared as unsigned values, so `I::min_value()` needs no special case.
fn cmp_abs_ge<I: PrimSignedInt>(a: I, b: I) -> bool {
    let (mag_a, mag_b) = (a.unsigned_abs(), b.unsigned_abs());
    mag_a >= mag_b
}
/// Returns `true` when `2 * |a| >= |b|`, i.e. `|a|` is at least half of `|b|`.
///
/// If doubling `|a|` overflows the unsigned type, `2 * |a|` exceeds every representable
/// magnitude, so the comparison is `true`.
fn cmp_abs_half_ge<I: PrimSignedInt>(a: I, b: I) -> bool {
    let mag = a.unsigned_abs();
    let Some(doubled) = mag.checked_add(&mag) else {
        return true;
    };
    doubled >= b.unsigned_abs()
}
/// Returns `true` when `a + b` is representable and `|a + b| < |b|`.
///
/// An overflowing sum yields `false`.
fn add_cmp_abs_lt<I: PrimSignedInt>(a: I, b: I) -> bool {
    a.checked_add(&b).is_some_and(|sum| !cmp_abs_ge(sum, b))
}
/// Returns `true` when `a - b` is representable and `|a - b| < |b|`.
///
/// An overflowing difference yields `false`.
fn sub_cmp_abs_lt<I: PrimSignedInt>(a: I, b: I) -> bool {
    a.checked_sub(&b).is_some_and(|diff| !cmp_abs_ge(diff, b))
}
/// Returns `true` when `a` and `b` have the same sign, with zero counted as non-negative.
///
/// The sign bit of `a ^ b` is set exactly when the sign bits of `a` and `b` differ.
fn same_sign<I>(a: I, b: I) -> bool
where
    I: PrimSignedInt,
{
    // Test for "not negative" rather than "positive": when `a == b` (including `0, 0`)
    // the XOR is zero, which is not positive but the operands clearly share a sign.
    !(a ^ b).is_negative()
}
/// Divides `left` by `right`, rounding the quotient according to `rounding`.
///
/// `Rounding::Round` rounds exact halves away from zero (`5 / 2 -> 3`, `-5 / 2 -> -3`).
///
/// Returns `None` whenever `left.checked_div(&right)` does, i.e. on division by zero or
/// when the quotient overflows (`I::min_value() / -1`).
pub fn checked_divide_with_rounding<I>(
    left: I,
    right: I,
    rounding: Rounding,
) -> Option<I>
where
    I: PrimSignedInt,
{
    // `checked_div` rejects exactly the inputs for which `%` below would panic.
    let q = left.checked_div(&right)?;
    let remain = left % right;
    if remain.is_zero() {
        return Some(q);
    }
    // `q` is the truncated quotient; the exact quotient lies strictly between `q` and
    // the next integer away from zero. A non-zero remainder implies `|right| >= 2`,
    // so stepping `q` by one cannot overflow.
    let positive = same_sign(left, right);
    let away = if positive { q + I::ONE } else { q - I::ONE };
    Some(match rounding {
        Rounding::TowardsZero => q,
        Rounding::AwayFromZero => away,
        Rounding::Floor => {
            if positive { q } else { away }
        }
        Rounding::Ceiling => {
            if positive { away } else { q }
        }
        Rounding::Round => {
            // Round away when the remainder is at least half the divisor.
            if cmp_abs_half_ge(remain, right) { away } else { q }
        }
    })
}
/// Divides `left` by `right`, carrying the rounding error across calls in `cum_error`.
///
/// `*cum_error` accumulates whatever the returned quotients have not yet accounted for.
/// Start from `*cum_error == 0` and call repeatedly with the same `right` and `rounding`.
/// The quotients returned so far then sum to the running total of the `left` values
/// divided by `right` under `rounding`. `*cum_error` equals that total minus `right`
/// times the summed quotients. This file's tests check this property.
///
/// Returns `None`, leaving `cum_error` untouched, whenever `left.checked_div(&right)`
/// does (division by zero or quotient overflow).
///
/// NOTE(review): for `Round`, `TowardsZero` and `AwayFromZero`, the correction direction
/// follows the sign of each individual `left`, not the sign of the running total.
/// Feeding `left` values of mixed sign into one accumulator can therefore drift.
/// Example with `TowardsZero` and `right == 3`: `left == 2` returns 0, then
/// `left == -4` returns -1, yet the total `-2 / 3` truncates to 0. `Floor` and
/// `Ceiling` appear unaffected.
pub fn checked_divide_with_cum_error<I>(
    left: I,
    right: I,
    rounding: Rounding,
    cum_error: &mut I,
) -> Option<I>
where
    I: PrimSignedInt,
{
    let mut q = left.checked_div(&right)?;
    let remain = left % right;
    if remain.is_zero() {
        // Exact division: nothing to add to the accumulated error.
        return Some(q);
    }
    // A non-zero remainder implies `|right| >= 2`, so `q ± 1` below cannot overflow.
    let Some(tmpsum) = cum_error.checked_add(&remain) else {
        // The accumulator overflowed in the direction of `remain`. Move one whole
        // `right` out of it and into the quotient instead. `remain - right` and
        // `remain + right` point back the other way, so these updates cannot overflow.
        if same_sign(left, right) {
            *cum_error += remain - right;
            return Some(q + I::ONE);
        } else {
            *cum_error += remain + right;
            return Some(q - I::ONE);
        }
    };
    *cum_error = tmpsum;
    match rounding {
        Rounding::Floor => {
            if same_sign(left, right) {
                // A full divisor's worth of error has built up: carry it into `q`.
                if cmp_abs_ge(*cum_error, right) {
                    *cum_error -= right;
                    q += I::ONE;
                }
            } else if add_cmp_abs_lt(*cum_error, right) {
                *cum_error += right;
                q -= I::ONE;
            }
        }
        Rounding::Ceiling => {
            if same_sign(left, right) {
                if sub_cmp_abs_lt(*cum_error, right) {
                    *cum_error -= right;
                    q += I::ONE;
                }
            } else if cmp_abs_ge(*cum_error, right) {
                *cum_error += right;
                q -= I::ONE;
            }
        }
        Rounding::Round => {
            // Step away from zero once the carried error reaches half the divisor.
            if cmp_abs_half_ge(*cum_error, right) {
                if same_sign(left, right) {
                    *cum_error -= right;
                    q += I::ONE;
                } else {
                    *cum_error += right;
                    q -= I::ONE;
                }
            }
        }
        Rounding::TowardsZero => {
            if cmp_abs_ge(*cum_error, right) {
                if same_sign(left, right) {
                    *cum_error -= right;
                    q += I::ONE;
                } else {
                    *cum_error += right;
                    q -= I::ONE;
                }
            }
        }
        Rounding::AwayFromZero => {
            if same_sign(left, right) {
                if sub_cmp_abs_lt(*cum_error, right) {
                    *cum_error -= right;
                    q += I::ONE;
                }
            } else if add_cmp_abs_lt(*cum_error, right) {
                *cum_error += right;
                q -= I::ONE;
            }
        }
    }
    Some(q)
}
/// Divides `left` by `right` using the given `rounding` mode.
///
/// When `cum_error` is provided, this delegates to [`checked_divide_with_cum_error`],
/// which folds this division's remainder into the running error. Otherwise it
/// delegates to [`checked_divide_with_rounding`].
pub fn checked_divide<I: PrimSignedInt>(
    left: I,
    right: I,
    rounding: Rounding,
    cum_error: Option<&mut I>,
) -> Option<I> {
    if let Some(accumulator) = cum_error {
        return checked_divide_with_cum_error(left, right, rounding, accumulator);
    }
    checked_divide_with_rounding(left, right, rounding)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every rounding mode, so each scenario below runs under all of them.
    const ALL_MODES: [Rounding; 5] = [
        Rounding::Floor,
        Rounding::Ceiling,
        Rounding::Round,
        Rounding::AwayFromZero,
        Rounding::TowardsZero,
    ];

    /// Feeds a Fibonacci-like sequence seeded with `a, a` through
    /// `checked_divide_with_cum_error`. After every term, it checks that the summed
    /// quotients and the carried error match one `checked_divide_with_rounding` of the
    /// running total.
    fn test_rounding_fib(a: i32, b: i32, rounding: Rounding) {
        let mut cum_error = 0_i32;
        let mut i0 = a;
        let mut i1 = a;
        let mut isum = 0_i32;
        let mut ret = 0_i32;
        while isum.abs() < 1_000_000_000 {
            let ix = i0 + i1;
            ret += checked_divide_with_cum_error(ix, b, rounding, &mut cum_error).unwrap();
            i0 = i1;
            i1 = ix;
            isum += ix;
            let q = checked_divide_with_rounding(isum, b, rounding).unwrap();
            let r = isum - q * b;
            assert_eq!(q, ret);
            assert_eq!(r, cum_error);
        }
    }

    /// Runs the Fibonacci check under every rounding mode.
    fn do_test(a: i32, b: i32) {
        for rounding in ALL_MODES {
            test_rounding_fib(a, b, rounding);
        }
    }

    /// Runs `do_test` for all four sign combinations of seed and divisor.
    fn test(b: i32) {
        do_test(1, b);
        do_test(-1, b);
        do_test(1, -b);
        do_test(-1, -b);
    }

    #[test]
    fn many_test() {
        // No `println!` here: the crate is `#![no_std]`, so the macro is not in scope.
        for b in 1..100 {
            test(b * 3 + 6);
            test(b * 13 + 7);
            test(b * 113 + 8);
            test(b * 1113 + 9);
            test(b * 11113 + 9);
            test(b * 111113 + 9);
            test(b * 1111113 + 9);
        }
    }

    /// Divides `a` by `b` four times under each mode, sharing one error accumulator per
    /// mode. It is intended for inputs whose running sum exceeds `i32`. Each partial
    /// result is compared against the exact answer computed in `i64`.
    fn do_test_overflow(a: i32, b: i32) {
        let wide_b = i64::from(b);
        for rounding in ALL_MODES {
            let mut cum_error = 0_i32;
            let mut total_q = 0_i64;
            for k in 1..=4_i64 {
                let q = checked_divide_with_cum_error(a, b, rounding, &mut cum_error).unwrap();
                total_q += i64::from(q);
                let wide_sum = i64::from(a) * k;
                let expected = checked_divide_with_rounding(wide_sum, wide_b, rounding).unwrap();
                assert_eq!(total_q, expected, "{rounding:?} a={a} b={b} k={k}");
                assert_eq!(
                    i64::from(cum_error),
                    wide_sum - expected * wide_b,
                    "{rounding:?} a={a} b={b} k={k}"
                );
            }
        }
    }

    #[test]
    fn test_overflow() {
        do_test_overflow(i32::MAX - 100, i32::MAX);
        do_test_overflow(i32::MIN + 100, i32::MIN);
        do_test_overflow(i32::MAX - 100, i32::MIN);
        do_test_overflow(i32::MIN + 100, i32::MAX);
        do_test_overflow(i32::MAX / 2, i32::MAX);
        do_test_overflow(i32::MIN / 2, i32::MIN);
        do_test_overflow(i32::MAX / 2, i32::MIN);
        do_test_overflow(i32::MIN / 2, i32::MAX);
        do_test_overflow(i32::MAX / 3, i32::MAX);
        do_test_overflow(i32::MIN / 3, i32::MIN);
        do_test_overflow(i32::MAX / 3, i32::MIN);
        do_test_overflow(i32::MIN / 3, i32::MAX);
        do_test_overflow(i32::MAX / 3 * 2, i32::MAX);
        do_test_overflow(i32::MIN / 3 * 2, i32::MIN);
        do_test_overflow(i32::MAX / 3 * 2, i32::MIN);
        do_test_overflow(i32::MIN / 3 * 2, i32::MAX);
    }
}