use core::convert::Infallible;
use core::{fmt, ops};
#[cfg(feature = "arbitrary")]
use arbitrary::{Arbitrary, Unstructured};
use NumOpResult as R;
use crate::{Amount, FeeRate, SignedAmount, Weight};
/// Result of a numeric operation that may fail.
///
/// This is used directly as the `Output` of the arithmetic operator impls in this
/// module, so a failed step can be chained through further arithmetic and checked
/// once at the end.
#[must_use]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum NumOpResult<T> {
    /// The operation succeeded and produced this value.
    Valid(T),
    /// The operation failed; the error records which operation it was.
    Error(NumOpError),
}
impl<T> NumOpResult<T> {
#[inline]
pub fn map<U, F: FnOnce(T) -> U>(self, op: F) -> NumOpResult<U> {
match self {
Self::Valid(t) => NumOpResult::Valid(op(t)),
Self::Error(e) => NumOpResult::Error(e),
}
}
}
impl<T: fmt::Debug> NumOpResult<T> {
    /// Returns the contained valid value, consuming `self`.
    ///
    /// # Panics
    ///
    /// Panics with exactly `msg` as the panic message if `self` is an error.
    #[inline]
    #[track_caller]
    pub fn expect(self, msg: &str) -> T {
        if let Self::Valid(value) = self {
            value
        } else {
            panic!("{}", msg)
        }
    }

    /// Returns the contained valid value, consuming `self`.
    ///
    /// # Panics
    ///
    /// Panics if `self` is an error; the message includes the error's `Debug` output.
    #[inline]
    #[track_caller]
    pub fn unwrap(self) -> T {
        // The panic stays in this `#[track_caller]` body (not a closure) so the
        // reported location is the caller's.
        match self.into_result() {
            Ok(value) => value,
            Err(e) => panic!("tried to unwrap an invalid numeric result: {:?}", e),
        }
    }

    /// Returns the contained error, consuming `self`.
    ///
    /// # Panics
    ///
    /// Panics if `self` is valid; the message includes the value's `Debug` output.
    #[inline]
    #[track_caller]
    pub fn unwrap_err(self) -> NumOpError {
        match self.into_result() {
            Ok(value) => panic!("tried to unwrap a valid numeric result: {:?}", value),
            Err(e) => e,
        }
    }

    /// Returns the contained valid value, or `default` if `self` is an error.
    #[inline]
    #[track_caller]
    pub fn unwrap_or(self, default: T) -> T { self.ok().unwrap_or(default) }

    /// Returns the contained valid value, or the result of calling `f` if `self` is an error.
    ///
    /// `f` is only called in the error case.
    #[inline]
    #[track_caller]
    pub fn unwrap_or_else<F>(self, f: F) -> T
    where
        F: FnOnce() -> T,
    {
        self.ok().unwrap_or_else(f)
    }

    /// Converts `self` into an `Option<T>`, discarding any error.
    #[inline]
    pub fn ok(self) -> Option<T> { self.into_result().ok() }

    /// Converts `self` into a standard `Result<T, NumOpError>`.
    ///
    /// # Errors
    ///
    /// Returns `Err` carrying the contained [`NumOpError`] if `self` is an error.
    #[inline]
    pub fn into_result(self) -> Result<T, NumOpError> {
        match self {
            Self::Error(e) => Err(e),
            Self::Valid(value) => Ok(value),
        }
    }

    /// Calls `op` with the valid value and returns its result; an error is passed through.
    ///
    /// `op` is not called if `self` is already an error.
    #[inline]
    pub fn and_then<F>(self, op: F) -> Self
    where
        F: FnOnce(T) -> Self,
    {
        match self.into_result() {
            Ok(value) => op(value),
            Err(e) => Self::Error(e),
        }
    }

    /// Returns `true` if `self` holds a valid value.
    #[inline]
    pub fn is_valid(&self) -> bool { matches!(self, Self::Valid(_)) }

    /// Returns `true` if `self` holds an error.
    #[inline]
    pub fn is_error(&self) -> bool { matches!(self, Self::Error(_)) }
}
// Binary `+` / `-` for `NumOpResult<T>`, with either another `NumOpResult<T>` or a bare `T`
// on the right-hand side. `impl_op_for_references!` is defined in `crate::internal_macros`
// (not visible here); judging by its name it presumably also emits the matching impls for
// reference operands — NOTE(review): confirm against the macro definition.
// Only `//` comments are used inside the invocation so the macro input tokens are unchanged.
crate::internal_macros::impl_op_for_references! {
// `NumOpResult<T> + NumOpResult<T>`.
impl<T> ops::Add<NumOpResult<T>> for NumOpResult<T>
where
(T: Copy + ops::Add<Output = NumOpResult<T>>)
{
type Output = NumOpResult<T>;
fn add(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(R::Valid(lhs), R::Valid(rhs)) => lhs + rhs,
// If either operand is already an error, that error is discarded and replaced
// by a fresh `MathOp::Add` error.
(_, _) => R::Error(NumOpError::while_doing(MathOp::Add)),
}
}
}
// `NumOpResult<T> + T`: delegates to `T + NumOpResult<T>` (required by the bound below and
// implemented elsewhere), i.e. it relies on addition being commutative.
impl<T> ops::Add<T> for NumOpResult<T>
where
(T: Copy + ops::Add<NumOpResult<T>, Output = NumOpResult<T>>)
{
type Output = NumOpResult<T>;
fn add(self, rhs: T) -> Self::Output { rhs + self }
}
// `NumOpResult<T> - NumOpResult<T>`.
impl<T> ops::Sub<NumOpResult<T>> for NumOpResult<T>
where
(T: Copy + ops::Sub<Output = NumOpResult<T>>)
{
type Output = NumOpResult<T>;
fn sub(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(R::Valid(lhs), R::Valid(rhs)) => lhs - rhs,
// As with addition, any existing error is replaced by a fresh `MathOp::Sub` error.
(_, _) => R::Error(NumOpError::while_doing(MathOp::Sub)),
}
}
}
// `NumOpResult<T> - T`. Subtraction is not commutative, so unlike `Add<T>` this cannot
// delegate to the reversed operation.
impl<T> ops::Sub<T> for NumOpResult<T>
where
(T: Copy + ops::Sub<Output = NumOpResult<T>>)
{
type Output = NumOpResult<T>;
fn sub(self, rhs: T) -> Self::Output {
match self {
R::Valid(amount) => amount - rhs,
// Unlike the impls above, an existing error is returned unchanged.
R::Error(_) => self,
}
}
}
}
impl<T: ops::AddAssign> ops::AddAssign<T> for NumOpResult<T> {
fn add_assign(&mut self, rhs: T) {
if let Self::Valid(ref mut lhs) = self {
*lhs += rhs;
}
}
}
impl<T: ops::AddAssign + Copy> ops::AddAssign<Self> for NumOpResult<T> {
fn add_assign(&mut self, rhs: Self) {
match (&self, rhs) {
(Self::Valid(_), Self::Valid(rhs)) => *self += rhs,
(_, _) => *self = Self::Error(NumOpError::while_doing(MathOp::Add)),
}
}
}
impl<T: ops::SubAssign> ops::SubAssign<T> for NumOpResult<T> {
fn sub_assign(&mut self, rhs: T) {
if let Self::Valid(ref mut lhs) = self {
*lhs -= rhs;
}
}
}
impl<T: ops::SubAssign + Copy> ops::SubAssign<Self> for NumOpResult<T> {
fn sub_assign(&mut self, rhs: Self) {
match (&self, rhs) {
(Self::Valid(_), Self::Valid(rhs)) => *self -= rhs,
(_, _) => *self = Self::Error(NumOpError::while_doing(MathOp::Sub)),
}
}
}
pub(crate) trait OptionExt<T> {
fn valid_or_error(self, op: MathOp) -> NumOpResult<T>;
}
macro_rules! impl_opt_ext {
($($ty:ident),* $(,)?) => {
$(
impl OptionExt<$ty> for Option<$ty> {
#[inline]
fn valid_or_error(self, op: MathOp) -> NumOpResult<$ty> {
match self {
Some(amount) => R::Valid(amount),
None => R::Error(NumOpError(op)),
}
}
}
)*
}
}
impl_opt_ext!(Amount, SignedAmount, u64, i64, FeeRate, Weight);
/// Error produced when a numeric operation gives an invalid result.
///
/// It records which [`MathOp`] failed; see [`NumOpError::operation`].
#[non_exhaustive]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct NumOpError(MathOp);

impl NumOpError {
    /// Creates an error for a failure while performing `op`.
    pub(crate) const fn while_doing(op: MathOp) -> Self { NumOpError(op) }

    /// Returns `true` if the failed operation is classified as an overflow
    /// (see [`MathOp::is_overflow`]).
    pub fn is_overflow(self) -> bool { self.operation().is_overflow() }

    /// Returns `true` if the failed operation is classified as a division by zero
    /// (see [`MathOp::is_div_by_zero`]).
    pub fn is_div_by_zero(self) -> bool { self.operation().is_div_by_zero() }

    /// Returns the operation that failed.
    pub fn operation(self) -> MathOp {
        let NumOpError(op) = self;
        op
    }
}

impl fmt::Display for NumOpError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let op = self.operation();
        write!(f, "math operation '{}' gave an invalid numeric result", op)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for NumOpError {}
/// A mathematical operation, as recorded in a [`NumOpError`].
#[non_exhaustive]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum MathOp {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Remainder.
    Rem,
    /// Negation.
    Neg,
    /// Hidden, uninhabited variant: it wraps [`Infallible`], so no value of it can exist.
    #[doc(hidden)]
    _DoNotUse(Infallible),
}

impl MathOp {
    /// Returns `true` if a failure of this operation is classified as an overflow.
    ///
    /// This holds for `Add`, `Sub`, `Mul` and `Neg`.
    pub fn is_overflow(self) -> bool {
        // Every variant is listed explicitly so that adding a variant forces this
        // classification to be revisited.
        match self {
            Self::Add | Self::Sub | Self::Mul | Self::Neg => true,
            Self::Div | Self::Rem => false,
            Self::_DoNotUse(never) => match never {},
        }
    }

    /// Returns `true` if a failure of this operation is classified as a division by zero.
    ///
    /// This holds for `Div` and `Rem`, which are exactly the operations for which
    /// [`MathOp::is_overflow`] returns `false`.
    pub fn is_div_by_zero(self) -> bool {
        match self {
            Self::Div | Self::Rem => true,
            Self::Add | Self::Sub | Self::Mul | Self::Neg => false,
            Self::_DoNotUse(never) => match never {},
        }
    }

    /// Returns `true` if this is [`MathOp::Add`].
    pub fn is_addition(self) -> bool { matches!(self, Self::Add) }

    /// Returns `true` if this is [`MathOp::Sub`].
    pub fn is_subtraction(self) -> bool { matches!(self, Self::Sub) }

    /// Returns `true` if this is [`MathOp::Mul`].
    pub fn is_multiplication(self) -> bool { matches!(self, Self::Mul) }

    /// Returns `true` if this is [`MathOp::Neg`].
    pub fn is_negation(self) -> bool { matches!(self, Self::Neg) }
}

impl fmt::Display for MathOp {
    /// Writes the lowercase short name of the operation (e.g. `add`).
    ///
    /// Like the formatter's `write!` path, `write_str` ignores width/fill options.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match *self {
            Self::Add => "add",
            Self::Sub => "sub",
            Self::Mul => "mul",
            Self::Div => "div",
            Self::Rem => "rem",
            Self::Neg => "neg",
            Self::_DoNotUse(never) => match never {},
        };
        f.write_str(name)
    }
}
#[cfg(feature = "arbitrary")]
impl<'a, T: Arbitrary<'a>> Arbitrary<'a> for NumOpResult<T> {
    /// Builds an arbitrary `NumOpResult<T>`.
    ///
    /// A leading `bool` picks the variant; the remaining input builds its payload.
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        let valid = bool::arbitrary(u)?;
        let res = if valid {
            Self::Valid(T::arbitrary(u)?)
        } else {
            Self::Error(NumOpError(MathOp::arbitrary(u)?))
        };
        Ok(res)
    }
}
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for MathOp {
    /// Picks one of the six constructible operations (`_DoNotUse` is uninhabited).
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        let op = match u.int_in_range(0..=5)? {
            0 => Self::Add,
            1 => Self::Sub,
            2 => Self::Mul,
            3 => Self::Div,
            4 => Self::Rem,
            _ => Self::Neg,
        };
        Ok(op)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `NumOpResult`, `NumOpError` and `MathOp`.
    use super::{MathOp, NumOpError, NumOpResult};
    use crate::{Amount, FeeRate, Weight};

    #[test]
    fn mathop_predicates() {
        assert!(MathOp::Add.is_overflow());
        assert!(MathOp::Sub.is_overflow());
        assert!(MathOp::Mul.is_overflow());
        assert!(MathOp::Neg.is_overflow());
        assert!(!MathOp::Div.is_overflow());
        assert!(!MathOp::Rem.is_overflow());
        assert!(MathOp::Div.is_div_by_zero());
        assert!(MathOp::Rem.is_div_by_zero());
        assert!(!MathOp::Add.is_div_by_zero());
        assert!(!MathOp::Sub.is_div_by_zero());
        assert!(!MathOp::Mul.is_div_by_zero());
        assert!(!MathOp::Neg.is_div_by_zero());
        assert!(MathOp::Add.is_addition());
        assert!(!MathOp::Sub.is_addition());
        assert!(MathOp::Sub.is_subtraction());
        assert!(!MathOp::Add.is_subtraction());
        assert!(MathOp::Mul.is_multiplication());
        assert!(!MathOp::Div.is_multiplication());
        assert!(MathOp::Neg.is_negation());
        assert!(!MathOp::Add.is_negation());
    }

    #[test]
    fn mathop_map() {
        let res = NumOpResult::Valid(Amount::from_sat_u32(100));
        let new_value = res.map(|val| (val / FeeRate::from_sat_per_kwu(10)).unwrap());
        assert_eq!(new_value, NumOpResult::Valid(Weight::from_wu(10_000)));

        let res = NumOpResult::<Weight>::Error(NumOpError::while_doing(MathOp::Add));
        let res_err = res.map(|_| {
            panic!("map should not evaluate for wrapped error values");
        });
        assert_eq!(res_err, res);
    }

    #[test]
    fn mathop_expect() {
        let amounts = [
            Amount::from_sat_u32(0),
            Amount::from_sat_u32(10_000_000),
            Amount::from_sat_u32(u32::MAX),
        ];
        for amount in amounts {
            assert_eq!(
                NumOpResult::Valid(amount).expect("unreachable"),
                NumOpResult::Valid(amount).unwrap(),
            );
            assert_eq!(NumOpResult::Valid(amount).expect("unreachable"), amount);
        }
    }

    #[test]
    #[should_panic(expected = "test error message")]
    fn mathop_expect_panics_on_error() {
        NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add))
            .expect("test error message");
    }

    #[test]
    fn mathop_unwrap() {
        let amounts = [
            Amount::from_sat_u32(0),
            Amount::from_sat_u32(10_000_000),
            Amount::from_sat_u32(u32::MAX),
        ];
        for amount in amounts {
            assert_eq!(NumOpResult::Valid(amount).unwrap(), amount);
        }

        let weights = [Weight::from_wu(0), Weight::from_wu(16_384_000), Weight::from_wu(u64::MAX)];
        for weight in weights {
            assert_eq!(NumOpResult::Valid(weight).unwrap(), weight);
        }
    }

    // `expected = ""` would match any panic; pin the actual message prefix instead.
    #[test]
    #[should_panic(expected = "tried to unwrap an invalid numeric result")]
    fn mathop_unwrap_panics_on_err() {
        NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add)).unwrap();
    }

    #[test]
    fn mathop_unwrap_err() {
        let errs = [
            NumOpError::while_doing(MathOp::Add),
            NumOpError::while_doing(MathOp::Sub),
            NumOpError::while_doing(MathOp::Mul),
            NumOpError::while_doing(MathOp::Div),
            NumOpError::while_doing(MathOp::Neg),
            NumOpError::while_doing(MathOp::Rem),
        ];
        for err in errs {
            assert_eq!(NumOpResult::<Amount>::Error(err).unwrap_err(), err);
        }
    }

    #[test]
    #[should_panic(expected = "tried to unwrap a valid numeric result")]
    fn mathop_unwrap_err_panics_on_valid() {
        let value = Amount::from_sat_u32(150);
        NumOpResult::<Amount>::Valid(value).unwrap_err();
    }

    #[test]
    fn mathop_unwrap_or() {
        let base_amount = Amount::from_sat_u32(100);
        let res = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add));
        let res_default = res.unwrap_or(base_amount);
        assert_eq!(res_default, base_amount);

        let res = NumOpResult::Valid(base_amount);
        let new_amount = res.unwrap_or(Amount::from_sat_u32(50));
        assert_eq!(new_amount, base_amount);
    }

    #[test]
    fn mathop_unwrap_or_else() {
        let base_amount = Amount::from_sat_u32(100);
        let res = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add));
        let res_default = res.unwrap_or_else(|| base_amount);
        assert_eq!(res_default, base_amount);

        let res = NumOpResult::<Amount>::Valid(base_amount);
        let new_amount = res.unwrap_or_else(|| {
            panic!("unwrap_or_else should not evaluate for wrapped valid values");
        });
        assert_eq!(new_amount, base_amount);
    }

    #[test]
    fn mathop_ok() {
        let amt = Amount::from_sat_u32(150);
        assert_eq!(NumOpResult::Valid(amt).ok(), Some(amt));

        let err = NumOpError::while_doing(MathOp::Add);
        assert_eq!(NumOpResult::<Amount>::Error(err).ok(), None);
    }

    #[test]
    fn mathop_into_result() {
        let amt = Amount::from_sat_u32(150);
        assert_eq!(NumOpResult::Valid(amt).into_result(), Ok(amt));

        let err = NumOpError::while_doing(MathOp::Sub);
        assert_eq!(NumOpResult::<Amount>::Error(err).into_result(), Err(err));
    }

    #[test]
    fn mathop_is_valid_is_error() {
        let valid = NumOpResult::Valid(Amount::from_sat_u32(1));
        assert!(valid.is_valid());
        assert!(!valid.is_error());

        let invalid = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Div));
        assert!(invalid.is_error());
        assert!(!invalid.is_valid());
    }

    #[test]
    fn mathop_and_then() {
        let res = NumOpResult::Valid(Amount::from_sat_u32(100));
        let new_value = res.and_then(|val| val + Amount::from_sat_u32(50));
        assert_eq!(new_value, NumOpResult::Valid(Amount::from_sat_u32(150)));

        let res = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Add));
        let res_err = res.and_then(|_| {
            panic!("and_then should not evaluate for wrapped error values");
        });
        assert_eq!(res_err, res);
    }

    #[test]
    fn num_op_error_accessors() {
        let overflow = NumOpError::while_doing(MathOp::Neg);
        assert_eq!(overflow.operation(), MathOp::Neg);
        assert!(overflow.is_overflow());
        assert!(!overflow.is_div_by_zero());

        let div = NumOpError::while_doing(MathOp::Rem);
        assert_eq!(div.operation(), MathOp::Rem);
        assert!(div.is_div_by_zero());
        assert!(!div.is_overflow());
    }

    #[test]
    fn num_op_result_add_sub() {
        let a = NumOpResult::Valid(Amount::from_sat_u32(100));
        let b = NumOpResult::Valid(Amount::from_sat_u32(40));
        assert_eq!(a + b, NumOpResult::Valid(Amount::from_sat_u32(140)));
        assert_eq!(a - b, NumOpResult::Valid(Amount::from_sat_u32(60)));
        assert_eq!(a - Amount::from_sat_u32(40), NumOpResult::Valid(Amount::from_sat_u32(60)));

        let err = NumOpResult::<Amount>::Error(NumOpError::while_doing(MathOp::Mul));
        // Combining with another `NumOpResult` replaces the error with the outer operation...
        assert_eq!((a + err).unwrap_err().operation(), MathOp::Add);
        assert_eq!((err - a).unwrap_err().operation(), MathOp::Sub);
        // ...whereas `NumOpResult<T> - T` propagates the existing error unchanged.
        assert_eq!((err - Amount::from_sat_u32(1)).unwrap_err().operation(), MathOp::Mul);
    }

    #[test]
    fn num_op_result_assign_ops() {
        let mut res = NumOpResult::Valid(10_u64);
        res += 5_u64;
        assert_eq!(res, NumOpResult::Valid(15));
        res -= 3_u64;
        assert_eq!(res, NumOpResult::Valid(12));
        res += NumOpResult::Valid(8_u64);
        assert_eq!(res, NumOpResult::Valid(20));
        res -= NumOpResult::Valid(20_u64);
        assert_eq!(res, NumOpResult::Valid(0));

        res -= NumOpResult::Error(NumOpError::while_doing(MathOp::Mul));
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Sub)));
        // Once invalid, plain-`T` assignments leave the error untouched.
        res += 1_u64;
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Sub)));
        res += NumOpResult::Valid(1_u64);
        assert_eq!(res, NumOpResult::Error(NumOpError::while_doing(MathOp::Add)));
    }
}