logprob/
lib.rs

//! This crate defines a basic [`LogProb`] wrapper for floats. The struct is designed so
//! that only values that are coherent for a log-probability are acceptable. This means that
//! [`LogProb`] can store:
//! - Any finite negative float value (e.g. -0.23, -32535.05, -66.0).
//! - Negative infinity (corresponding to 0.0 probability).
//! - 0.0 *and* -0.0.
7//!
//! If any other value is passed, [`LogProb::new`] returns a [`FloatIsNanOrPositive`] error.
//! You can also construct a new [`LogProb`] from a raw probability in \[0,1\] by using
//! [`LogProb::from_raw_prob`].
11//!
//! The crate also supports adding log probabilities, which is equivalent to taking the product
//! of their corresponding raw probabilities:
14//!
15//! ```
16//! use logprob::LogProb;
17//! let x = LogProb::from_raw_prob(0.5).unwrap();
18//! let y = LogProb::from_raw_prob(0.5).unwrap();
19//! let z = x + y;
20//! assert_eq!(z, LogProb::from_raw_prob(0.25).unwrap());
21//! ```
22//!
//! It is also possible to multiply a [`LogProb`] by an unsigned integer `n`, which corresponds
//! to raising the underlying raw probability to the power `n`.
25//! ```
26//! # use logprob::LogProb;
27//! let x = LogProb::from_raw_prob(0.5_f64).unwrap();
28//! let y: u8 = 2;
29//! let z = x * y;
30//! assert_eq!(z, LogProb::from_raw_prob(0.25).unwrap());
31//! ```
32//!
//! Finally, the crate also includes reasonably efficient implementations of
//! [LogSumExp](https://en.wikipedia.org/wiki/LogSumExp) so that one can take the sum of
//! raw probabilities directly with [`LogProb`].
36//!
37//! ```
38//! # use logprob::LogProb;
39//! let x = LogProb::from_raw_prob(0.5_f64).unwrap();
40//! let y = LogProb::from_raw_prob(0.25).unwrap();
41//! let z = x.add_log_prob(y).unwrap();
42//! assert_eq!(z, LogProb::from_raw_prob(0.75).unwrap());
43//! ```
44//!
//! This also works for slices and iterators, by importing [`log_sum_exp`] or the
//! [`LogSumExp`] trait, respectively. Note that for empty slices or iterators, the
//! functions return a [`LogProb`] of negative infinity, corresponding to 0 probability.
48//! ```
49//! # use logprob::LogProb;
50//! use logprob::{LogSumExp, log_sum_exp};
51//! let x = LogProb::from_raw_prob(0.5_f64).unwrap();
52//! let y = LogProb::from_raw_prob(0.25).unwrap();
53//! let z = [x,y].iter().log_sum_exp().unwrap();
54//! assert_eq!(z, LogProb::from_raw_prob(0.75).unwrap());
55//! let v = log_sum_exp(&[x,y]).unwrap();
//! assert_eq!(v, LogProb::from_raw_prob(0.75).unwrap());
57//! ```
58//!
//! By default, both [`log_sum_exp`] and [`LogProb::add_log_prob`] return a
//! [`ProbabilitiesSumToGreaterThanOne`] error if the sum exceeds the largest possible
//! [`LogProb`] value (0.0, i.e. a probability of one). However, one can use the `clamped` or
//! `float` versions of these functions to return either a value clamped at 0.0 or the
//! underlying float value, which may be greater than 0.0.
64//! ```
65//! # use logprob::LogProb;
66//! # use logprob::{LogSumExp, log_sum_exp};
67//! let x = LogProb::from_raw_prob(0.5_f64).unwrap();
68//! let y = LogProb::from_raw_prob(0.75).unwrap();
69//! let z = [x,y].into_iter().log_sum_exp_clamped();
70//! assert_eq!(z, LogProb::new(0.0).unwrap());
71//! let z = [x,y].into_iter().log_sum_exp_float();
72//! approx::assert_relative_eq!(z, (1.25_f64).ln());
73//!
74//! ```
75//!
76
77#![warn(
78    anonymous_parameters,
79    missing_copy_implementations,
80    missing_debug_implementations,
81    missing_docs,
82    rust_2018_idioms,
83    nonstandard_style,
84    single_use_lifetimes,
85    rustdoc::broken_intra_doc_links,
86    trivial_casts,
87    trivial_numeric_casts,
88    unreachable_pub,
89    unused_extern_crates,
90    unused_qualifications,
91    variant_size_differences
92)]
93
94use std::borrow::Borrow;
95
96use num_traits::Float;
97mod errors;
98pub use errors::{
99    FloatIsNanOrPositive, FloatIsNanOrPositiveInfinity, ProbabilitiesSumToGreaterThanOne,
100};
101use serde::{Deserialize, Serialize};
102mod adding;
103mod math;
104mod softmax;
105pub use softmax::{softmax, Softmax};
106
107#[derive(Copy, Clone, PartialEq, PartialOrd, Debug, Default, Serialize, Deserialize)]
108
109///Struct that can only hold float values that correspond to negative log
110///probabilities.
111#[repr(transparent)]
112pub struct LogProb<T>(T);
113pub use adding::{log_sum_exp, log_sum_exp_clamped, log_sum_exp_float, LogSumExp};
114
115impl<T: Float> LogProb<T> {
116    ///Construct a new [`LogProb`] that is guaranteed to be negative (or +0.0).
117    pub fn new(val: T) -> Result<Self, FloatIsNanOrPositive> {
118        if val.is_nan() || (!val.is_zero() && val.is_sign_positive()) {
119            Err(FloatIsNanOrPositive)
120        } else {
121            Ok(LogProb(val))
122        }
123    }
124
125    ///Construct a new [`LogProb`] that is guaranteed to be negative (or +0.0) from a value in [0.0, 1.0].
126    pub fn from_raw_prob(val: T) -> Result<Self, FloatIsNanOrPositive> {
127        let val = val.ln();
128        if val.is_nan() || (!val.is_zero() && val.is_sign_positive()) {
129            Err(FloatIsNanOrPositive)
130        } else {
131            Ok(LogProb(val))
132        }
133    }
134
135    ///Constructs a new LogProb which corresponds to a probability of zero (e.g. neg infinity)
136    pub fn prob_of_zero() -> Self {
137        LogProb(T::neg_infinity())
138    }
139
140    ///Constructs a new LogProb which corresponds to a probability of one (e.g. the log prob is
141    ///equal to 0)
142    pub fn prob_of_one() -> Self {
143        LogProb(T::zero())
144    }
145
146    /// Gets out the value.
147    #[inline]
148    pub fn into_inner(self) -> T {
149        self.0
150    }
151
152    /// Get the equivalent non-log probability
153    /// ```
154    /// # use logprob::LogProb;
155    /// let x = LogProb::from_raw_prob(0.25).unwrap();
156    /// assert_eq!(x.raw_prob(), 0.25);
157    /// ```
158    #[inline]
159    pub fn raw_prob(&self) -> T {
160        self.0.exp()
161    }
162
163    /// Calculates the probability of the complement of this log-probability
164    /// ```
165    /// # use logprob::LogProb;
166    /// let x = LogProb::from_raw_prob(0.25).unwrap();
167    /// let y = LogProb::from_raw_prob(0.75).unwrap();
168    /// assert_eq!(x.opposite_prob(), y);
169    /// ```
170    pub fn opposite_prob(&self) -> Self {
171        LogProb((-self.0.exp()).ln_1p())
172    }
173}
174
175impl<T: Float + std::fmt::Display> std::fmt::Display for LogProb<T> {
176    #[inline]
177    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178        self.0.fmt(f)
179    }
180}
181
182impl Borrow<f32> for LogProb<f32> {
183    #[inline]
184    fn borrow(&self) -> &f32 {
185        &self.0
186    }
187}
188
189impl Borrow<f64> for LogProb<f64> {
190    #[inline]
191    fn borrow(&self) -> &f64 {
192        &self.0
193    }
194}
195
196impl<T: Float> Eq for LogProb<T> {}
197
198#[allow(clippy::derive_ord_xor_partial_ord)]
199impl<T: Float> Ord for LogProb<T> {
200    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
201        self.0.partial_cmp(&other.0).unwrap()
202    }
203}
204
205impl From<LogProb<f32>> for f32 {
206    #[inline]
207    fn from(f: LogProb<f32>) -> f32 {
208        f.0
209    }
210}
211
212impl From<LogProb<f64>> for f64 {
213    #[inline]
214    fn from(f: LogProb<f64>) -> f64 {
215        f.0
216    }
217}
218impl TryFrom<f64> for LogProb<f64> {
219    type Error = FloatIsNanOrPositive;
220
221    fn try_from(value: f64) -> Result<Self, Self::Error> {
222        LogProb::new(value)
223    }
224}
225
226impl TryFrom<f32> for LogProb<f32> {
227    type Error = FloatIsNanOrPositive;
228
229    fn try_from(value: f32) -> Result<Self, Self::Error> {
230        LogProb::new(value)
231    }
232}