#![warn(
anonymous_parameters,
missing_copy_implementations,
missing_debug_implementations,
missing_docs,
rust_2018_idioms,
nonstandard_style,
single_use_lifetimes,
rustdoc::broken_intra_doc_links,
trivial_casts,
trivial_numeric_casts,
unreachable_pub,
unused_extern_crates,
unused_qualifications,
variant_size_differences
)]
#![no_std]
#[cfg(feature = "std")]
extern crate std;
#[cfg(feature = "alloc")]
extern crate alloc;
use core::borrow::Borrow;
use core::hash::Hash;
use num_traits::Float;
mod errors;
pub use errors::{
FloatIsNanOrPositive, FloatIsNanOrPositiveInfinity, ProbabilitiesSumToGreaterThanOne,
};
mod adding;
mod math;
#[cfg(feature = "alloc")]
mod softmax;
#[cfg(feature = "alloc")]
pub use softmax::{softmax, Softmax};
/// A natural-log probability: a floating-point value in `[-inf, 0]`.
///
/// `-inf` represents probability zero and `0` represents probability one.
/// Use [`LogProb::new`] (or the `TryFrom<f32>` / `TryFrom<f64>` impls) to wrap an
/// existing log value, and [`LogProb::from_raw_prob`] to convert a plain probability.
/// These constructors reject NaN and positive values.
///
/// The derived `Default` wraps `T::default()`. For `f32`/`f64` that is `0.0`,
/// i.e. probability one.
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize` reads
/// the inner value directly and does not run the NaN/positive validation. Deserialized
/// data could therefore violate the invariant that `Eq`/`Ord` rely on. Consider
/// validating on deserialize (e.g. via serde's `try_from` container attribute).
#[derive(Copy, Clone, PartialEq, PartialOrd, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(transparent)]
pub struct LogProb<T>(T);
pub use adding::{log_sum_exp, log_sum_exp_clamped, log_sum_exp_float, LogSumExp};
impl<T: Float> LogProb<T> {
    /// Wraps a value that is already a natural-log probability.
    ///
    /// Any value in `[-inf, 0]` is accepted, including `-0.0`.
    ///
    /// # Errors
    ///
    /// Returns [`FloatIsNanOrPositive`] if `val` is NaN or strictly positive
    /// (including `+inf`).
    pub fn new(val: T) -> Result<Self, FloatIsNanOrPositive> {
        // `is_sign_positive` alone would also reject `+0.0`, which is ln(1) and valid.
        if val.is_nan() || (!val.is_zero() && val.is_sign_positive()) {
            Err(FloatIsNanOrPositive)
        } else {
            Ok(LogProb(val))
        }
    }

    /// Builds a log probability from a raw probability by taking its natural log.
    ///
    /// `0` maps to probability zero (`-inf`) and `1` maps to probability one (`0`).
    ///
    /// # Errors
    ///
    /// Returns [`FloatIsNanOrPositive`] if `ln(val)` is NaN or positive, i.e. when
    /// `val` is NaN, less than zero, or greater than one.
    pub fn from_raw_prob(val: T) -> Result<Self, FloatIsNanOrPositive> {
        // Same validation as `new`; delegate so the two can never drift apart.
        Self::new(val.ln())
    }

    /// The log probability of an impossible event (`-inf`).
    #[must_use]
    pub fn prob_of_zero() -> Self {
        LogProb(T::neg_infinity())
    }

    /// The log probability of a certain event (`0`).
    #[must_use]
    pub fn prob_of_one() -> Self {
        LogProb(T::zero())
    }

    /// Consumes the wrapper and returns the stored natural-log value.
    #[inline]
    #[must_use]
    pub const fn into_inner(self) -> T {
        self.0
    }

    /// Converts back to a plain probability, `exp(self)`, in `[0, 1]`.
    #[inline]
    #[must_use]
    pub fn raw_prob(&self) -> T {
        self.0.exp()
    }

    /// Returns the log of the complementary probability, `ln(1 - p)`.
    ///
    /// Uses the two-branch "log1mexp" formulation, so the result stays accurate
    /// both when `p` is close to one (where `1 - exp(x)` would cancel to zero)
    /// and when `p` is close to zero.
    #[must_use]
    pub fn opposite_prob(&self) -> Self {
        let x = self.0;
        // ln(2), computed as ln_1p(1) so no extra `FloatConst` bound is needed.
        let ln_2 = T::one().ln_1p();
        if x > -ln_2 {
            // p > 1/2: `exp_m1` keeps precision where `1 - exp(x)` would round to 0.
            LogProb((-x.exp_m1()).ln())
        } else {
            // p <= 1/2: exp(x) is at most 1/2, so `ln_1p(-exp(x))` is accurate.
            LogProb((-x.exp()).ln_1p())
        }
    }
}
impl<T> core::fmt::Display for LogProb<T>
where
    T: Float + core::fmt::Display,
{
    /// Renders the stored natural-log value exactly as `T` would. The formatter is
    /// forwarded unchanged, so width/precision flags apply to the inner value.
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Display::fmt(&self.0, f)
    }
}
impl Borrow<f32> for LogProb<f32> {
#[inline]
fn borrow(&self) -> &f32 {
&self.0
}
}
impl Borrow<f64> for LogProb<f64> {
#[inline]
fn borrow(&self) -> &f64 {
&self.0
}
}
// Floats are not `Eq` in general because NaN != NaN. The validated constructors
// reject NaN, so equality on `LogProb` is expected to be reflexive.
impl<T> Eq for LogProb<T> where T: Float {}
// The derived `PartialOrd` delegates to `T::partial_cmp`, and this `Ord` uses that
// same comparison, so the two stay consistent. That is why clippy's
// `derive_ord_xor_partial_ord` lint can safely be silenced here.
#[allow(clippy::derive_ord_xor_partial_ord)]
impl<T: Float> Ord for LogProb<T> {
    /// Totally orders log probabilities by their stored value.
    ///
    /// `-0.0` and `0.0` compare equal, matching the derived `PartialEq`.
    ///
    /// # Panics
    ///
    /// Panics if either stored value is NaN. The validated constructors never
    /// produce NaN, so this signals a broken invariant (for example, a value that
    /// was deserialized without validation).
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.0
            .partial_cmp(&other.0)
            .expect("LogProb invariant violated: stored value is NaN")
    }
}
impl<T: Hash> Hash for LogProb<T> {
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
self.0.hash(state);
}
}
impl From<LogProb<f32>> for f32 {
#[inline]
fn from(f: LogProb<f32>) -> f32 {
f.0
}
}
impl From<LogProb<f64>> for f64 {
#[inline]
fn from(f: LogProb<f64>) -> f64 {
f.0
}
}
impl TryFrom<f64> for LogProb<f64> {
type Error = FloatIsNanOrPositive;
fn try_from(value: f64) -> Result<Self, Self::Error> {
LogProb::new(value)
}
}
impl TryFrom<f32> for LogProb<f32> {
type Error = FloatIsNanOrPositive;
fn try_from(value: f32) -> Result<Self, Self::Error> {
LogProb::new(value)
}
}