1use thiserror::Error;
2
3use crate::error::SearchError;
4use crate::search::{search_bounded_zero, search_monotone, SEARCH_BOUND};
5use crate::special::beta_inc;
6use crate::special::{beta_log, psi};
7use crate::traits::{Continuous, ContinuousCdf, Entropy, Mean, Variance};
8
/// Beta distribution on the unit interval with shape parameters `a` and `b`.
///
/// Both shapes are required to be finite and strictly positive; [`Beta::new`] and
/// [`Beta::try_new`] enforce this when building a value.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Beta {
    /// First shape parameter; the log-density carries the term `(a - 1) ln x`.
    a: f64,
    /// Second shape parameter; the log-density carries the term `(b - 1) ln(1 - x)`.
    b: f64,
}
33
34#[derive(Debug, Clone, Copy, PartialEq, Error)]
38pub enum BetaError {
39 #[error("shape parameter `a` must be positive, got {0}")]
41 ANotPositive(f64),
42 #[error("shape parameter `a` must be finite, got {0}")]
44 ANotFinite(f64),
45 #[error("shape parameter `b` must be positive, got {0}")]
47 BNotPositive(f64),
48 #[error("shape parameter `b` must be finite, got {0}")]
50 BNotFinite(f64),
51 #[error("argument x must be in [0..1], got {0}")]
53 XOutOfRange(f64),
54 #[error("probability {0} outside [0..1]")]
56 PNotInRange(f64),
57 #[error("probability {0} outside [0..1]")]
59 QNotInRange(f64),
60 #[error("p ({p}) and q ({q}) are not complementary: |p + q - 1| > 3ε")]
63 PQSumNotOne { p: f64, q: f64 },
64 #[error(transparent)]
68 Search(#[from] SearchError),
69}
70
71impl Beta {
72 #[inline]
81 pub fn new(a: f64, b: f64) -> Self {
82 Self::try_new(a, b).unwrap()
83 }
84
85 #[inline]
96 pub fn try_new(a: f64, b: f64) -> Result<Self, BetaError> {
97 if !a.is_finite() {
98 return Err(BetaError::ANotFinite(a));
99 }
100 if a <= 0.0 {
101 return Err(BetaError::ANotPositive(a));
102 }
103 if !b.is_finite() {
104 return Err(BetaError::BNotFinite(b));
105 }
106 if b <= 0.0 {
107 return Err(BetaError::BNotPositive(b));
108 }
109 Ok(Self { a, b })
110 }
111
112 #[inline]
114 pub const fn a(&self) -> f64 {
115 self.a
116 }
117
118 #[inline]
120 pub const fn b(&self) -> f64 {
121 self.b
122 }
123
124 #[inline]
129 pub fn search_a(p: f64, q: f64, x: f64, b: f64) -> Result<f64, BetaError> {
130 check_pq(p, q)?;
131 if !(0.0..=1.0).contains(&x) {
132 return Err(BetaError::XOutOfRange(x));
133 }
134 if !b.is_finite() {
135 return Err(BetaError::BNotFinite(b));
136 }
137 if b <= 0.0 {
138 return Err(BetaError::BNotPositive(b));
139 }
140 let f = |a: f64| {
141 let (cum, ccum) = beta_inc(a, b, x, 1.0 - x);
142 if p <= q {
143 cum - p
144 } else {
145 ccum - q
146 }
147 };
148 Ok(search_monotone(
152 0.0,
153 SEARCH_BOUND,
154 5.0,
155 0.0,
156 SEARCH_BOUND,
157 f,
158 )?)
159 }
160
161 #[inline]
166 pub fn search_b(p: f64, q: f64, x: f64, a: f64) -> Result<f64, BetaError> {
167 check_pq(p, q)?;
168 if !(0.0..=1.0).contains(&x) {
169 return Err(BetaError::XOutOfRange(x));
170 }
171 if !a.is_finite() {
172 return Err(BetaError::ANotFinite(a));
173 }
174 if a <= 0.0 {
175 return Err(BetaError::ANotPositive(a));
176 }
177 let f = |b: f64| {
178 let (cum, ccum) = beta_inc(a, b, x, 1.0 - x);
179 if p <= q {
180 cum - p
181 } else {
182 ccum - q
183 }
184 };
185 Ok(search_monotone(
188 0.0,
189 SEARCH_BOUND,
190 5.0,
191 0.0,
192 SEARCH_BOUND,
193 f,
194 )?)
195 }
196}
197
198#[inline]
199fn check_p(p: f64) -> Result<(), BetaError> {
200 if !(0.0..=1.0).contains(&p) || !p.is_finite() {
201 Err(BetaError::PNotInRange(p))
202 } else {
203 Ok(())
204 }
205}
206
207#[inline]
208fn check_q(q: f64) -> Result<(), BetaError> {
209 if !(0.0..=1.0).contains(&q) || !q.is_finite() {
210 Err(BetaError::QNotInRange(q))
211 } else {
212 Ok(())
213 }
214}
215
216#[inline]
217fn check_pq(p: f64, q: f64) -> Result<(), BetaError> {
218 check_p(p)?;
219 check_q(q)?;
220 if (p + q - 1.0).abs() > 3.0 * f64::EPSILON {
221 return Err(BetaError::PQSumNotOne { p, q });
222 }
223 Ok(())
224}
225
226impl ContinuousCdf for Beta {
227 type Error = BetaError;
228
229 #[inline]
230 fn cdf(&self, x: f64) -> f64 {
231 if x <= 0.0 {
232 return 0.0;
233 }
234 if x >= 1.0 {
235 return 1.0;
236 }
237 let (cum, _) = beta_inc(self.a, self.b, x, 1.0 - x);
238 cum
239 }
240
241 #[inline]
242 fn ccdf(&self, x: f64) -> f64 {
243 if x <= 0.0 {
244 return 1.0;
245 }
246 if x >= 1.0 {
247 return 0.0;
248 }
249 let (_, ccum) = beta_inc(self.a, self.b, x, 1.0 - x);
250 ccum
251 }
252
253 #[inline]
254 fn inverse_cdf(&self, p: f64) -> Result<f64, BetaError> {
255 check_p(p)?;
256 if p == 0.0 {
257 return Ok(0.0);
258 }
259 if p == 1.0 {
260 return Ok(1.0);
261 }
262 let a = self.a;
263 let b = self.b;
264 let q = 1.0 - p;
265 if p <= q {
268 let f = |x: f64| {
269 let (cum, _) = beta_inc(a, b, x, 1.0 - x);
270 cum - p
271 };
272 Ok(search_bounded_zero(0.0, 1.0, f)?)
273 } else {
274 let f = |y: f64| {
275 let (_, ccum) = beta_inc(a, b, 1.0 - y, y);
276 ccum - q
277 };
278 let y = search_bounded_zero(0.0, 1.0, f)?;
279 Ok(1.0 - y)
280 }
281 }
282}
283
284impl Beta {
285 #[inline]
292 pub fn inverse_ccdf(&self, q: f64) -> Result<f64, BetaError> {
293 check_q(q)?;
294 if q == 1.0 {
295 return Ok(0.0);
296 }
297 if q == 0.0 {
298 return Ok(1.0);
299 }
300 let a = self.a;
301 let b = self.b;
302 let p = 1.0 - q;
303 if p <= q {
309 let f = |x: f64| {
310 let (cum, _) = beta_inc(a, b, x, 1.0 - x);
311 cum - p
312 };
313 Ok(search_bounded_zero(0.0, 1.0, f)?)
314 } else {
315 let f = |y: f64| {
316 let (_, ccum) = beta_inc(a, b, 1.0 - y, y);
317 ccum - q
318 };
319 let y = search_bounded_zero(0.0, 1.0, f)?;
320 Ok(1.0 - y)
321 }
322 }
323}
324
325impl Continuous for Beta {
326 #[inline]
327 fn pdf(&self, x: f64) -> f64 {
328 if x <= 0.0 || x >= 1.0 {
329 return 0.0;
330 }
331 self.ln_pdf(x).exp()
332 }
333 #[inline]
334 fn ln_pdf(&self, x: f64) -> f64 {
335 if x <= 0.0 || x >= 1.0 {
336 return f64::NEG_INFINITY;
337 }
338 (self.a - 1.0) * x.ln() + (self.b - 1.0) * (1.0 - x).ln() - beta_log(self.a, self.b)
339 }
340}
341
342impl Mean for Beta {
343 #[inline]
344 fn mean(&self) -> f64 {
345 self.a / (self.a + self.b)
346 }
347}
348
349impl Variance for Beta {
350 #[inline]
351 fn variance(&self) -> f64 {
352 let s = self.a + self.b;
353 self.a * self.b / (s * s * (s + 1.0))
354 }
355}
356
357impl Entropy for Beta {
358 #[inline]
359 fn entropy(&self) -> f64 {
360 beta_log(self.a, self.b) - (self.a - 1.0) * psi(self.a) - (self.b - 1.0) * psi(self.b)
362 + (self.a + self.b - 2.0) * psi(self.a + self.b)
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rejects_invalid_parameters() {
        assert_eq!(Beta::try_new(0.0, 1.0), Err(BetaError::ANotPositive(0.0)));
        assert_eq!(Beta::try_new(1.0, 0.0), Err(BetaError::BNotPositive(0.0)));
        // Non-finite payloads are checked by variant only (NaN never compares equal).
        assert!(matches!(
            Beta::try_new(f64::NAN, 1.0),
            Err(BetaError::ANotFinite(_))
        ));
        assert!(matches!(
            Beta::try_new(1.0, f64::INFINITY),
            Err(BetaError::BNotFinite(_))
        ));
    }

    #[test]
    fn inverse_boundaries_and_density_edges() {
        let d = Beta::new(2.0, 3.0);

        // Distribution functions take exact values at the ends of the support.
        assert_eq!((d.cdf(0.0), d.cdf(1.0)), (0.0, 1.0));
        assert_eq!((d.ccdf(0.0), d.ccdf(1.0)), (1.0, 0.0));

        // Inverses send the extreme probabilities back to the endpoints.
        assert_eq!(d.inverse_cdf(0.0), Ok(0.0));
        assert_eq!(d.inverse_cdf(1.0), Ok(1.0));
        assert_eq!(d.inverse_ccdf(1.0), Ok(0.0));
        assert_eq!(d.inverse_ccdf(0.0), Ok(1.0));

        // Density is zero (log-density -inf) at both endpoints.
        for edge in [0.0, 1.0] {
            assert_eq!(d.pdf(edge), 0.0);
            assert_eq!(d.ln_pdf(edge), f64::NEG_INFINITY);
        }

        // Interior evaluations and summary statistics are finite.
        assert!(d.pdf(0.4).is_finite());
        assert!(d.ln_pdf(0.4).is_finite());
        assert!(matches!(d.inverse_ccdf(0.4), Ok(x) if x.is_finite()));
        for stat in [d.mean(), d.variance(), d.entropy()] {
            assert!(stat.is_finite());
        }
    }

    #[test]
    fn search_parameter_rejects_invalid_inputs() {
        assert_eq!(
            Beta::search_a(-0.1, 1.1, 0.5, 2.0),
            Err(BetaError::PNotInRange(-0.1))
        );
        assert_eq!(
            Beta::search_a(0.5, 0.5, 0.5, 0.0),
            Err(BetaError::BNotPositive(0.0))
        );
        assert_eq!(
            Beta::search_b(0.5, 0.5, 0.5, 0.0),
            Err(BetaError::ANotPositive(0.0))
        );
        assert_eq!(
            Beta::search_a(0.5, 0.5, 1.5, 2.0),
            Err(BetaError::XOutOfRange(1.5))
        );
        assert_eq!(
            Beta::search_b(0.5, 0.5, -0.1, 2.0),
            Err(BetaError::XOutOfRange(-0.1))
        );
    }
}