use num_traits::Float;
use crate::errors::LogProbSubtractionError;
use super::LogProb;
use core::ops::{Add, AddAssign, Mul, Sub, SubAssign};
impl<T: Add> Add for LogProb<T> {
type Output = LogProb<T::Output>;
#[inline]
fn add(self, other: Self) -> Self::Output {
LogProb((self.0).add(other.0))
}
}
impl<'a, T> Add<&'a Self> for LogProb<T>
where
T: Add<&'a T>,
{
type Output = LogProb<<T as Add<&'a T>>::Output>;
#[inline]
fn add(self, other: &'a Self) -> Self::Output {
LogProb((self.0).add(&other.0))
}
}
impl<'a, T> Add<LogProb<T>> for &'a LogProb<T>
where
&'a T: Add<T>,
{
type Output = LogProb<<&'a T as Add<T>>::Output>;
#[inline]
fn add(self, other: LogProb<T>) -> Self::Output {
LogProb((self.0).add(other.0))
}
}
/// `&LogProb + &LogProb`: the right operand's inner value is copied out, then
/// added to the borrowed left inner value.
impl<'a, 'b, T> Add<&'b LogProb<T>> for &'a LogProb<T>
where
    &'a T: Add<T>,
    T: Copy,
{
    type Output = LogProb<<&'a T as Add<T>>::Output>;

    #[inline]
    fn add(self, other: &'b LogProb<T>) -> Self::Output {
        let lhs: &'a T = &self.0;
        let rhs: T = other.0;
        LogProb(lhs + rhs)
    }
}
impl<T: AddAssign> AddAssign for LogProb<T> {
#[inline]
fn add_assign(&mut self, other: Self) {
(self.0).add_assign(other.0);
}
}
/// `LogProb += &LogProb`: accumulates a borrowed log-probability in place.
impl<'a, T> AddAssign<&'a Self> for LogProb<T>
where
    T: AddAssign<&'a T>,
{
    #[inline]
    fn add_assign(&mut self, other: &'a Self) {
        let rhs: &'a T = &other.0;
        self.0 += rhs;
    }
}
/// `LogProb -= LogProb`: divides the underlying probabilities in place.
///
/// In debug builds this panics when `self > other` or when `other` is not
/// finite. Otherwise (and in release builds) the result is produced by
/// `saturating_sub`.
impl<T> SubAssign for LogProb<T>
where
    T: Sub + Float + SubAssign,
{
    #[inline]
    fn sub_assign(&mut self, other: Self) {
        debug_assert!(*self <= other, "Numerator is greater than denominator");
        debug_assert!(other.0.is_finite(), "Division by zero in prob space");
        let quotient = self.saturating_sub(other);
        *self = quotient;
    }
}
/// `LogProb -= &LogProb`: divides by a borrowed probability in place.
///
/// The borrowed denominator is copied out first. Debug builds assert
/// `self <= other` and a finite `other`. The stored result comes from
/// `saturating_sub`.
impl<'a, T> SubAssign<&'a Self> for LogProb<T>
where
    T: Sub + Float + SubAssign<T>,
{
    #[inline]
    fn sub_assign(&mut self, other: &'a Self) {
        let denominator: Self = *other;
        debug_assert!(*self <= denominator, "Numerator is greater than denominator");
        debug_assert!(denominator.0.is_finite(), "Division by zero in prob space");
        let quotient = self.saturating_sub(denominator);
        *self = quotient;
    }
}
impl<T: Sub + Float> Sub for LogProb<T> {
type Output = LogProb<<T as Sub>::Output>;
#[inline]
fn sub(self, rhs: Self) -> Self::Output {
debug_assert!(self <= rhs, "Numerator is greater than denominator");
debug_assert!(rhs.0.is_finite(), "Division by zero in prob space");
self.saturating_sub(rhs)
}
}
/// `LogProb - &LogProb`: divides by a borrowed probability.
///
/// Debug builds assert `self <= *rhs` and a finite `rhs`. The result is
/// computed by `saturating_sub` on a copy of the borrowed denominator.
impl<'a, T> Sub<&'a Self> for LogProb<T>
where
    T: Sub<&'a T> + Float,
{
    type Output = LogProb<T>;

    #[inline]
    fn sub(self, rhs: &'a Self) -> Self::Output {
        let denominator: Self = *rhs;
        debug_assert!(self <= denominator, "Numerator is greater than denominator");
        debug_assert!(denominator.0.is_finite(), "Division by zero in prob space");
        self.saturating_sub(denominator)
    }
}
impl<'a, T> Sub<LogProb<T>> for &'a LogProb<T>
where
&'a T: Sub<T>,
T: Float,
{
type Output = LogProb<T>;
#[inline]
fn sub(self, rhs: LogProb<T>) -> Self::Output {
debug_assert!(*self <= rhs, "Numerator is greater than denominator");
debug_assert!(rhs.0.is_finite(), "Division by zero in prob space");
self.saturating_sub(rhs)
}
}
impl<'a, 'b, T> Sub<&'b LogProb<T>> for &'a LogProb<T>
where
&'a T: Sub<T>,
T: Copy + Float,
{
type Output = LogProb<T>;
#[inline]
fn sub(self, rhs: &'b LogProb<T>) -> Self::Output {
debug_assert!(self <= rhs, "Numerator is greater than denominator");
debug_assert!(rhs.0.is_finite(), "Division by zero in prob space");
self.saturating_sub(*rhs)
}
}
/// Explicit subtraction variants. Subtracting log-probabilities divides the
/// underlying probabilities, so the quotient is only a valid probability when
/// the numerator does not exceed the denominator and the denominator is not 0
/// (i.e. not `-inf` in log space).
impl<T: Float> LogProb<T> {
    /// Subtracts `rhs` from `self` in log space, reporting why the quotient
    /// would not be a valid log-probability.
    ///
    /// # Errors
    ///
    /// * [`LogProbSubtractionError::NumeratorBiggerThanDenominator`] if
    ///   `self > rhs`, since the quotient would exceed a probability of 1.
    /// * [`LogProbSubtractionError::DivideByZero`] if `rhs` is not finite
    ///   (a log-probability of `-inf` is a probability of 0).
    ///
    /// The numerator check runs first, so it wins when both conditions hold.
    #[inline]
    pub fn try_sub(&self, rhs: LogProb<T>) -> Result<LogProb<T>, LogProbSubtractionError> {
        if *self > rhs {
            Err(LogProbSubtractionError::NumeratorBiggerThanDenominator)
        } else if !rhs.0.is_finite() {
            Err(LogProbSubtractionError::DivideByZero)
        } else {
            Ok(LogProb(self.0 - rhs.0))
        }
    }

    /// Like [`try_sub`](Self::try_sub), but discards the reason for failure.
    ///
    /// Returns `None` exactly when `try_sub` would return an error.
    #[must_use]
    #[inline]
    pub fn checked_sub(&self, rhs: LogProb<T>) -> Option<LogProb<T>> {
        // Delegate so the validity rules are defined in one place only.
        self.try_sub(rhs).ok()
    }

    /// Subtracts `rhs` from `self` in log space, clamping the result so it never
    /// exceeds `0` (a probability of 1).
    ///
    /// Unlike [`try_sub`](Self::try_sub), this never reports an error.
    /// Invalid quotients are clamped instead:
    /// * `self > rhs` yields `0`;
    /// * a finite `self` with `rhs == -inf` gives `+inf`, which is clamped to `0`;
    /// * `self == rhs == -inf` gives NaN, which the `min` call also maps to `0`
    ///   (for `f32`/`f64`, `min` returns the non-NaN operand).
    #[must_use]
    #[inline]
    pub fn saturating_sub(&self, rhs: LogProb<T>) -> LogProb<T> {
        LogProb((self.0 - rhs.0).min(T::zero()))
    }

    /// Subtracts `rhs` from `self` in log space with no validation or clamping.
    ///
    /// # Safety
    ///
    /// The body performs only a floating-point subtraction. The `unsafe` marks a
    /// contract the caller must uphold: `self <= rhs` and `rhs` is finite.
    /// If that contract is violated, the result may be positive, infinite or
    /// NaN, i.e. not a valid log-probability.
    #[must_use]
    #[inline]
    pub unsafe fn unchecked_sub(&self, rhs: LogProb<T>) -> LogProb<T> {
        LogProb(self.0 - rhs.0)
    }
}
/// Implements multiplication between `LogProb<$float>` and `$unsigned` in all
/// eight owned/borrowed operand combinations. `$unsigned` must convert to
/// `$float` losslessly via `From`.
///
/// Multiplying a log-probability by an integer `k` raises the probability to the
/// power `k`. Every combination delegates to the by-value
/// `$unsigned * LogProb<$float>` impl, so the scaling rule lives in one place.
macro_rules! impl_mul {
    ($unsigned: ty, $float: ty) => {
        impl Mul<LogProb<$float>> for $unsigned {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: LogProb<$float>) -> Self::Output {
                // `p^0 == 1` for every `p`, including `p == 0`. Without this
                // branch, `0.0 * -inf` would yield NaN instead of log(1) == 0.
                if self == 0 {
                    return LogProb(0.0);
                }
                let s: $float = self.into();
                LogProb(s * rhs.0)
            }
        }
        impl Mul<$unsigned> for LogProb<$float> {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: $unsigned) -> Self::Output {
                rhs * self
            }
        }
        impl<'a> Mul<&'a $unsigned> for LogProb<$float> {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: &'a $unsigned) -> Self::Output {
                let k: $unsigned = *rhs;
                k * self
            }
        }
        impl<'a> Mul<&'a LogProb<$float>> for $unsigned {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: &'a LogProb<$float>) -> Self::Output {
                self * LogProb(rhs.0)
            }
        }
        impl Mul<LogProb<$float>> for &$unsigned {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: LogProb<$float>) -> Self::Output {
                let k: $unsigned = *self;
                k * rhs
            }
        }
        impl Mul<$unsigned> for &LogProb<$float> {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: $unsigned) -> Self::Output {
                rhs * LogProb(self.0)
            }
        }
        impl Mul<&LogProb<$float>> for &$unsigned {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: &LogProb<$float>) -> Self::Output {
                let k: $unsigned = *self;
                k * LogProb(rhs.0)
            }
        }
        impl Mul<&$unsigned> for &LogProb<$float> {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: &$unsigned) -> Self::Output {
                let k: $unsigned = *rhs;
                k * LogProb(self.0)
            }
        }
    };
}
impl_mul!(u8, f32);
impl_mul!(u8, f64);
impl_mul!(u16, f32);
impl_mul!(u16, f64);
impl_mul!(u32, f64);
/// Same as `impl_mul!`, but for integer types whose conversion to `$float` can
/// lose precision. These use an `as` cast instead of `From`.
///
/// Only the by-value `$unsigned * LogProb<$float>` impl performs the cast, so
/// only it carries the `cast_precision_loss` expectation. The other seven
/// combinations delegate to it.
macro_rules! impl_mul_lossy {
    ($unsigned: ty, $float: ty) => {
        #[expect(clippy::cast_precision_loss)]
        impl Mul<LogProb<$float>> for $unsigned {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: LogProb<$float>) -> Self::Output {
                // `p^0 == 1` for every `p`, including `p == 0`. Without this
                // branch, `0.0 * -inf` would yield NaN instead of log(1) == 0.
                if self == 0 {
                    return LogProb(0.0);
                }
                let s: $float = self as $float;
                LogProb(s * rhs.0)
            }
        }
        impl Mul<$unsigned> for LogProb<$float> {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: $unsigned) -> Self::Output {
                rhs * self
            }
        }
        impl Mul<&$unsigned> for LogProb<$float> {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: &$unsigned) -> Self::Output {
                let k: $unsigned = *rhs;
                k * self
            }
        }
        impl Mul<&LogProb<$float>> for $unsigned {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: &LogProb<$float>) -> Self::Output {
                self * LogProb(rhs.0)
            }
        }
        impl Mul<LogProb<$float>> for &$unsigned {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: LogProb<$float>) -> Self::Output {
                let k: $unsigned = *self;
                k * rhs
            }
        }
        impl Mul<$unsigned> for &LogProb<$float> {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: $unsigned) -> Self::Output {
                rhs * LogProb(self.0)
            }
        }
        impl Mul<&LogProb<$float>> for &$unsigned {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: &LogProb<$float>) -> Self::Output {
                let k: $unsigned = *self;
                k * LogProb(rhs.0)
            }
        }
        impl Mul<&$unsigned> for &LogProb<$float> {
            type Output = LogProb<$float>;
            #[inline]
            fn mul(self, rhs: &$unsigned) -> Self::Output {
                let k: $unsigned = *rhs;
                k * LogProb(self.0)
            }
        }
    };
}
impl_mul_lossy!(usize, f64);
impl_mul_lossy!(usize, f32);
impl_mul_lossy!(u64, f64);
impl_mul_lossy!(u64, f32);
impl_mul_lossy!(u32, f32);
impl_mul_lossy!(u128, f32);
impl_mul_lossy!(u128, f64);