1use crate::{CyaneaError, Result};
8
9#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
14pub struct LogProb(pub f64);
15
16impl LogProb {
17 pub fn from_prob(p: f64) -> Result<Self> {
23 if p <= 0.0 || p > 1.0 {
24 return Err(CyaneaError::InvalidInput(
25 "LogProb::from_prob: p must be in (0, 1]".into(),
26 ));
27 }
28 Ok(Self(p.ln()))
29 }
30
31 pub fn to_prob(self) -> f64 {
33 self.0.exp()
34 }
35
36 pub fn ln_add(self, other: Self) -> Self {
40 if self.0 == f64::NEG_INFINITY {
41 return other;
42 }
43 if other.0 == f64::NEG_INFINITY {
44 return self;
45 }
46 let (max, min) = if self.0 >= other.0 {
47 (self.0, other.0)
48 } else {
49 (other.0, self.0)
50 };
51 Self(max + (min - max).exp().ln_1p())
52 }
53
54 pub fn ln_mul(self, other: Self) -> Self {
56 Self(self.0 + other.0)
57 }
58
59 pub const fn certain() -> Self {
61 Self(0.0)
62 }
63
64 pub const fn impossible() -> Self {
66 Self(f64::NEG_INFINITY)
67 }
68}
69
70#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
74pub struct PhredProb(pub f64);
75
76const PHRED_SCALE: f64 = 10.0 / core::f64::consts::LN_10;
78
79impl PhredProb {
80 pub fn from_phred(q: f64) -> Result<Self> {
86 if q < 0.0 {
87 return Err(CyaneaError::InvalidInput(
88 "PhredProb::from_phred: q must be non-negative".into(),
89 ));
90 }
91 Ok(Self(q))
92 }
93
94 pub fn from_prob(p: f64) -> Result<Self> {
100 if p <= 0.0 || p > 1.0 {
101 return Err(CyaneaError::InvalidInput(
102 "PhredProb::from_prob: p must be in (0, 1]".into(),
103 ));
104 }
105 Ok(Self(-10.0 * p.log10()))
106 }
107
108 pub fn to_phred(self) -> f64 {
110 self.0
111 }
112
113 pub fn to_prob(self) -> f64 {
115 10.0_f64.powf(-self.0 / 10.0)
116 }
117}
118
119impl From<PhredProb> for LogProb {
120 fn from(phred: PhredProb) -> Self {
122 Self(-phred.0 / PHRED_SCALE)
123 }
124}
125
126impl From<LogProb> for PhredProb {
127 fn from(lp: LogProb) -> Self {
129 Self(-lp.0 * PHRED_SCALE)
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Default absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-10;

    /// Panics unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        let diff = (actual - expected).abs();
        assert!(
            diff < tol,
            "expected {} but got {} (|diff| = {}, tol = {})",
            expected,
            actual,
            diff,
            tol
        );
    }

    // ---- LogProb ----------------------------------------------------------

    #[test]
    fn logprob_from_prob_one() {
        let certain = LogProb::from_prob(1.0).unwrap();
        assert_close(certain.0, 0.0, TOL);
    }

    #[test]
    fn logprob_from_prob_half() {
        let half = LogProb::from_prob(0.5).unwrap();
        assert_close(half.0, 0.5_f64.ln(), TOL);
    }

    #[test]
    fn logprob_roundtrip() {
        let p = 0.001;
        let back = LogProb::from_prob(p).unwrap().to_prob();
        assert_close(back, p, TOL);
    }

    #[test]
    fn logprob_invalid() {
        for &bad in &[0.0, -0.5, 1.5] {
            assert!(LogProb::from_prob(bad).is_err());
        }
    }

    #[test]
    fn logprob_certain_impossible() {
        let certain = LogProb::certain();
        let impossible = LogProb::impossible();
        assert_eq!(certain.0, 0.0);
        assert_eq!(certain.to_prob(), 1.0);
        assert_eq!(impossible.0, f64::NEG_INFINITY);
        assert_eq!(impossible.to_prob(), 0.0);
    }

    #[test]
    fn logprob_ln_mul() {
        let half = LogProb::from_prob(0.5).unwrap();
        assert_close(half.ln_mul(half).to_prob(), 0.25, TOL);
    }

    #[test]
    fn logprob_ln_add() {
        let a = LogProb::from_prob(0.3).unwrap();
        let b = LogProb::from_prob(0.2).unwrap();
        assert_close(a.ln_add(b).to_prob(), 0.5, TOL);
    }

    #[test]
    fn logprob_ln_add_identity() {
        let a = LogProb::from_prob(0.7).unwrap();
        let nothing = LogProb::impossible();
        assert_close(a.ln_add(nothing).to_prob(), 0.7, TOL);
        assert_close(nothing.ln_add(a).to_prob(), 0.7, TOL);
    }

    // ---- PhredProb --------------------------------------------------------

    #[test]
    fn phredprob_from_phred() {
        let q30 = PhredProb::from_phred(30.0).unwrap();
        assert_close(q30.to_prob(), 0.001, 1e-10);
    }

    #[test]
    fn phredprob_from_prob() {
        let q = PhredProb::from_prob(0.001).unwrap();
        assert_close(q.to_phred(), 30.0, 1e-8);
    }

    #[test]
    fn phredprob_roundtrip() {
        let p = PhredProb::from_phred(20.0).unwrap().to_prob();
        let q = PhredProb::from_prob(p).unwrap();
        assert_close(q.to_phred(), 20.0, 1e-8);
    }

    #[test]
    fn phredprob_invalid() {
        assert!(PhredProb::from_phred(-1.0).is_err());
        assert!(PhredProb::from_prob(0.0).is_err());
        assert!(PhredProb::from_prob(1.5).is_err());
    }

    // ---- Conversions ------------------------------------------------------

    #[test]
    fn convert_phred_to_logprob() {
        let lp = LogProb::from(PhredProb::from_phred(30.0).unwrap());
        assert_close(lp.to_prob(), 0.001, 1e-10);
    }

    #[test]
    fn convert_logprob_to_phred() {
        let phred = PhredProb::from(LogProb::from_prob(0.001).unwrap());
        assert_close(phred.to_phred(), 30.0, 1e-6);
    }

    #[test]
    fn phred_logprob_roundtrip_conversion() {
        let lp = LogProb::from(PhredProb::from_phred(25.0).unwrap());
        let back = PhredProb::from(lp);
        assert_close(back.to_phred(), 25.0, 1e-10);
    }
}