scirs2_stats/distributions/
stable.rs1use crate::error::{StatsError, StatsResult};
41use crate::sampling::SampleableDistribution;
42use scirs2_core::numeric::{Float, NumCast};
43use scirs2_core::random::prelude::*;
44use scirs2_core::random::rand_distributions::Distribution as _;
45use scirs2_core::random::Uniform as RandUniform;
46
47pub struct StableDistribution<F: Float> {
61 pub alpha: F,
63 pub beta: F,
65 pub gamma: F,
67 pub delta: F,
69 uniform_distr: RandUniform<f64>,
70}
71
72impl<F: Float + NumCast + std::fmt::Display> StableDistribution<F> {
73 pub fn new(alpha: F, beta: F, gamma: F, delta: F) -> StatsResult<Self> {
82 let alpha_f64: f64 = NumCast::from(alpha).unwrap_or(0.0);
83 let beta_f64: f64 = NumCast::from(beta).unwrap_or(0.0);
84 let gamma_f64: f64 = NumCast::from(gamma).unwrap_or(0.0);
85
86 if alpha_f64 <= 0.0 || alpha_f64 > 2.0 {
87 return Err(StatsError::DomainError(
88 "Stability index alpha must be in (0, 2]".to_string(),
89 ));
90 }
91 if beta_f64 < -1.0 || beta_f64 > 1.0 {
92 return Err(StatsError::DomainError(
93 "Skewness beta must be in [-1, 1]".to_string(),
94 ));
95 }
96 if gamma_f64 <= 0.0 {
97 return Err(StatsError::DomainError(
98 "Scale gamma must be positive".to_string(),
99 ));
100 }
101
102 let uniform_distr = RandUniform::new(0.0_f64, 1.0_f64).map_err(|_| {
103 StatsError::ComputationError("Failed to create uniform distribution".to_string())
104 })?;
105
106 Ok(Self {
107 alpha,
108 beta,
109 gamma,
110 delta,
111 uniform_distr,
112 })
113 }
114
115 pub fn standard(alpha: F, beta: F) -> StatsResult<Self> {
117 Self::new(alpha, beta, F::one(), F::zero())
118 }
119
120 fn cms_sample_standard<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
123 let alpha_f64: f64 = NumCast::from(self.alpha).unwrap_or(2.0);
124 let beta_f64: f64 = NumCast::from(self.beta).unwrap_or(0.0);
125 let pi = std::f64::consts::PI;
126
127 let u_raw: f64 = self.uniform_distr.sample(rng);
129 let w_raw: f64 = self.uniform_distr.sample(rng);
130
131 let u = pi * (u_raw - 0.5); let w = -w_raw.ln().max(-1e300); if (alpha_f64 - 1.0).abs() < 1e-10 {
135 let b_term = if beta_f64.abs() < 1e-15 {
137 0.0
138 } else {
139 let zeta = beta_f64;
140 zeta * (pi / 2.0 + beta_f64 * u).tan()
141 };
142 let term1 = (pi / 2.0 + beta_f64 * u) * u.tan();
143 let term2 = beta_f64 * (w * u.cos() / (pi / 2.0 + beta_f64 * u)).ln();
144 (term1 + term2) * (2.0 / pi) + b_term
145 } else {
146 let zeta = -beta_f64 * (alpha_f64 * pi / 2.0).tan();
148 let xi = (1.0_f64 / alpha_f64) * (-zeta).atan();
149
150 let sin_a_u_xi = (alpha_f64 * (u + xi)).sin();
151 let cos_u = u.cos();
152
153 if cos_u.abs() < 1e-300 {
154 return zeta; }
156
157 let a = sin_a_u_xi / cos_u.powf(1.0 / alpha_f64);
158 let cos_diff = ((1.0 - alpha_f64) * u - alpha_f64 * xi).cos();
159
160 let b_arg = cos_diff / w;
161 if b_arg <= 0.0 {
162 return zeta + a;
163 }
164 let b = b_arg.powf((1.0 - alpha_f64) / alpha_f64);
165
166 a * b - zeta
167 }
168 }
169
170 fn sample_one<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
172 let alpha_f64: f64 = NumCast::from(self.alpha).unwrap_or(2.0);
173 let gamma_f64: f64 = NumCast::from(self.gamma).unwrap_or(1.0);
174 let delta_f64: f64 = NumCast::from(self.delta).unwrap_or(0.0);
175 let beta_f64: f64 = NumCast::from(self.beta).unwrap_or(0.0);
176
177 let z = self.cms_sample_standard(rng);
178
179 if (alpha_f64 - 1.0).abs() < 1e-10 {
181 gamma_f64 * z
182 + delta_f64
183 + (2.0 / std::f64::consts::PI) * beta_f64 * gamma_f64 * gamma_f64.ln()
184 } else {
185 gamma_f64 * z + delta_f64
186 }
187 }
188
189 fn log_char_fn(&self, t: f64) -> (f64, f64) {
191 let alpha_f64: f64 = NumCast::from(self.alpha).unwrap_or(2.0);
192 let beta_f64: f64 = NumCast::from(self.beta).unwrap_or(0.0);
193 let gamma_f64: f64 = NumCast::from(self.gamma).unwrap_or(1.0);
194 let delta_f64: f64 = NumCast::from(self.delta).unwrap_or(0.0);
195 let pi = std::f64::consts::PI;
196
197 let sign_t = if t > 0.0 {
198 1.0
199 } else if t < 0.0 {
200 -1.0
201 } else {
202 0.0
203 };
204 let abs_t = t.abs();
205 let g_t = gamma_f64.powf(alpha_f64) * abs_t.powf(alpha_f64);
206
207 let (re_log_phi, im_log_phi) = if (alpha_f64 - 1.0).abs() < 1e-10 {
208 let re = -gamma_f64 * abs_t;
209 let im =
210 delta_f64 * t - (2.0 / pi) * beta_f64 * gamma_f64 * sign_t * abs_t.ln().max(-700.0);
211 (re, im)
212 } else {
213 let tan_term = (alpha_f64 * pi / 2.0).tan();
214 let re = -g_t;
215 let im = delta_f64 * t + g_t * beta_f64 * sign_t * tan_term;
216 (re, im)
217 };
218
219 (re_log_phi, im_log_phi)
220 }
221
222 fn pdf_by_inversion(&self, x: F) -> F {
226 let x_f64: f64 = NumCast::from(x).unwrap_or(0.0);
227 let pi = std::f64::consts::PI;
228
229 let n_points = 4096_usize;
231 let t_max = 50.0_f64;
232 let dt = t_max / n_points as f64;
233
234 let mut integral = 0.0_f64;
235
236 for k in 1..n_points {
238 let t = k as f64 * dt;
239 let (re_log, im_log) = self.log_char_fn(t);
240 let amp = re_log.exp();
241 let phase = im_log - t * x_f64;
242 let re_integrand = amp * phase.cos();
243
244 let weight = if k == 0 || k == n_points - 1 {
245 0.5
246 } else {
247 1.0
248 };
249 integral += weight * re_integrand * dt;
250 }
251
252 let pdf_val = integral / pi;
253 F::from(pdf_val.max(0.0)).unwrap_or(F::zero())
254 }
255
256 fn cdf_by_integration(&self, x: F) -> F {
258 let x_f64: f64 = NumCast::from(x).unwrap_or(0.0);
259
260 let delta_f64: f64 = NumCast::from(self.delta).unwrap_or(0.0);
262 let gamma_f64: f64 = NumCast::from(self.gamma).unwrap_or(1.0);
263
264 let x_low = delta_f64 - 100.0 * gamma_f64;
265 let n_steps = 2000_usize;
266 let h = (x_f64 - x_low) / n_steps as f64;
267
268 if h <= 0.0 {
269 return F::zero();
270 }
271
272 let mut sum = 0.0_f64;
273 for k in 0..=n_steps {
274 let xi = x_low + k as f64 * h;
275 let xi_f = F::from(xi).unwrap_or(F::zero());
276 let pdf_val: f64 = NumCast::from(self.pdf(xi_f)).unwrap_or(0.0);
277 let weight = if k == 0 || k == n_steps { 0.5 } else { 1.0 };
278 sum += weight * pdf_val * h;
279 }
280
281 F::from(sum.clamp(0.0, 1.0)).unwrap_or(F::zero())
282 }
283
284 pub fn pdf(&self, x: F) -> F {
289 let alpha_f64: f64 = NumCast::from(self.alpha).unwrap_or(2.0);
290
291 if (alpha_f64 - 2.0).abs() < 1e-10 {
292 return self.pdf_normal(x);
294 }
295
296 if (alpha_f64 - 1.0).abs() < 1e-10 {
297 let beta_f64: f64 = NumCast::from(self.beta).unwrap_or(0.0);
298 if beta_f64.abs() < 1e-10 {
299 return self.pdf_cauchy(x);
300 }
301 }
302
303 self.pdf_by_inversion(x)
304 }
305
306 pub fn cdf(&self, x: F) -> F {
308 let alpha_f64: f64 = NumCast::from(self.alpha).unwrap_or(2.0);
309
310 if (alpha_f64 - 2.0).abs() < 1e-10 {
311 return self.cdf_normal(x);
312 }
313 if (alpha_f64 - 1.0).abs() < 1e-10 {
314 let beta_f64: f64 = NumCast::from(self.beta).unwrap_or(0.0);
315 if beta_f64.abs() < 1e-10 {
316 return self.cdf_cauchy(x);
317 }
318 }
319
320 self.cdf_by_integration(x)
321 }
322
323 fn pdf_normal(&self, x: F) -> F {
326 let delta_f64: f64 = NumCast::from(self.delta).unwrap_or(0.0);
327 let gamma_f64: f64 = NumCast::from(self.gamma).unwrap_or(1.0);
328 let x_f64: f64 = NumCast::from(x).unwrap_or(0.0);
329 let sigma = 2.0_f64.sqrt() * gamma_f64;
330 let z = (x_f64 - delta_f64) / sigma;
331 let pdf = (-0.5 * z * z).exp() / (sigma * (2.0 * std::f64::consts::PI).sqrt());
332 F::from(pdf).unwrap_or(F::zero())
333 }
334
335 fn cdf_normal(&self, x: F) -> F {
336 let delta_f64: f64 = NumCast::from(self.delta).unwrap_or(0.0);
337 let gamma_f64: f64 = NumCast::from(self.gamma).unwrap_or(1.0);
338 let x_f64: f64 = NumCast::from(x).unwrap_or(0.0);
339 let sigma = 2.0_f64.sqrt() * gamma_f64;
340 let z = (x_f64 - delta_f64) / sigma;
341 let cdf = 0.5 * (1.0 + erf_approx(z / 2.0_f64.sqrt()));
342 F::from(cdf).unwrap_or(F::zero())
343 }
344
345 fn pdf_cauchy(&self, x: F) -> F {
346 let delta_f64: f64 = NumCast::from(self.delta).unwrap_or(0.0);
347 let gamma_f64: f64 = NumCast::from(self.gamma).unwrap_or(1.0);
348 let x_f64: f64 = NumCast::from(x).unwrap_or(0.0);
349 let z = (x_f64 - delta_f64) / gamma_f64;
350 let pdf = 1.0 / (std::f64::consts::PI * gamma_f64 * (1.0 + z * z));
351 F::from(pdf).unwrap_or(F::zero())
352 }
353
354 fn cdf_cauchy(&self, x: F) -> F {
355 let delta_f64: f64 = NumCast::from(self.delta).unwrap_or(0.0);
356 let gamma_f64: f64 = NumCast::from(self.gamma).unwrap_or(1.0);
357 let x_f64: f64 = NumCast::from(x).unwrap_or(0.0);
358 let z = (x_f64 - delta_f64) / gamma_f64;
359 let cdf = 0.5 + z.atan() / std::f64::consts::PI;
360 F::from(cdf).unwrap_or(F::zero())
361 }
362
363 pub fn rvs<R: Rng + ?Sized>(&self, n: usize, rng: &mut R) -> StatsResult<Vec<F>> {
365 let mut samples = Vec::with_capacity(n);
366 for _ in 0..n {
367 let s = self.sample_one(rng);
368 let f_s = F::from(s).ok_or_else(|| {
369 StatsError::ComputationError("Failed to convert sample to F".to_string())
370 })?;
371 samples.push(f_s);
372 }
373 Ok(samples)
374 }
375
376 pub fn mean(&self) -> Option<F> {
380 let alpha_f64: f64 = NumCast::from(self.alpha).unwrap_or(0.0);
381 if alpha_f64 > 1.0 {
382 Some(self.delta)
383 } else {
384 None
385 }
386 }
387
388 pub fn variance(&self) -> Option<F> {
392 let alpha_f64: f64 = NumCast::from(self.alpha).unwrap_or(0.0);
393 if (alpha_f64 - 2.0).abs() < 1e-10 {
394 let two = F::from(2.0).unwrap_or(F::one() + F::one());
395 Some(two * self.gamma * self.gamma)
396 } else {
397 None
398 }
399 }
400}
401
/// Approximates the error function `erf(x)` with Abramowitz & Stegun formula 7.1.26
/// (maximum absolute error about 1.5e-7).
fn erf_approx(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // Coefficients a5, a4, a3, a2, a1 — highest order first, for Horner evaluation.
    const COEFFS_DESC: [f64; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];

    let t = 1.0 / (1.0 + P * x.abs());
    let horner = COEFFS_DESC
        .iter()
        .fold(0.0_f64, |acc, &coef| acc * t + coef);
    let poly = t * horner;
    let magnitude = 1.0 - poly * (-x * x).exp();
    // erf is odd; the approximation is built for |x| and mirrored.
    if x >= 0.0 {
        magnitude
    } else {
        -magnitude
    }
}
411
412impl<F: Float + NumCast + std::fmt::Display> SampleableDistribution<F> for StableDistribution<F> {
413 fn rvs(&self, size: usize) -> StatsResult<Vec<F>> {
414 use scirs2_core::random::rngs::SmallRng;
415 use scirs2_core::random::SeedableRng;
416 let seed = std::time::SystemTime::now()
417 .duration_since(std::time::UNIX_EPOCH)
418 .map(|d| d.as_nanos() as u64)
419 .unwrap_or(0x9e3779b97f4a7c15);
420 let mut rng = SmallRng::seed_from_u64(seed);
421 self.rvs(size, &mut rng)
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::{rngs::SmallRng, SeedableRng};

    #[test]
    fn test_normal_special_case() {
        // alpha = 2 is Gaussian with variance 2 * gamma^2, so pdf(0) = 1 / (2 sqrt(pi)).
        let dist = StableDistribution::new(2.0f64, 0.0, 1.0, 0.0).expect("valid params");
        let density = dist.pdf(0.0);
        let want = 1.0 / (2.0 * std::f64::consts::PI.sqrt());
        assert!((density - want).abs() < 1e-6, "pdf_0={}", density);
    }

    #[test]
    fn test_cauchy_special_case() {
        // alpha = 1, beta = 0 is the Cauchy distribution, so pdf(0) = 1 / pi.
        let dist = StableDistribution::new(1.0f64, 0.0, 1.0, 0.0).expect("valid params");
        let density = dist.pdf(0.0);
        let want = std::f64::consts::FRAC_1_PI;
        assert!((density - want).abs() < 1e-6, "pdf_0={}", density);
    }

    #[test]
    fn test_sampling() {
        let mut rng = SmallRng::seed_from_u64(42);
        let dist = StableDistribution::new(1.5f64, 0.0, 1.0, 0.0).expect("valid params");
        let mut draws: Vec<f64> = dist.rvs(500, &mut rng).expect("sampling should succeed");
        assert_eq!(draws.len(), 500);

        draws.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap_or(std::cmp::Ordering::Equal));
        let median = draws[250];
        assert!(
            median.abs() < 2.0,
            "median {} far from 0 for symmetric stable",
            median
        );
    }

    #[test]
    fn test_mean_variance() {
        let gaussian = StableDistribution::new(2.0f64, 0.0, 1.0, 0.0).expect("valid");
        assert!(gaussian.mean().is_some());
        assert!(gaussian.variance().is_some());

        let cauchy = StableDistribution::new(1.0f64, 0.0, 1.0, 0.0).expect("valid");
        assert!(cauchy.mean().is_none());
        assert!(cauchy.variance().is_none());

        let heavy = StableDistribution::new(1.5f64, 0.0, 1.0, 5.0).expect("valid");
        assert_eq!(heavy.mean().expect("mean should exist"), 5.0_f64);
        assert!(heavy.variance().is_none());
    }

    #[test]
    fn test_invalid_params() {
        // alpha = 0 and alpha > 2 are outside (0, 2].
        assert!(StableDistribution::new(0.0f64, 0.0, 1.0, 0.0).is_err());
        assert!(StableDistribution::new(2.5f64, 0.0, 1.0, 0.0).is_err());
        // |beta| > 1.
        assert!(StableDistribution::new(1.5f64, 1.5, 1.0, 0.0).is_err());
        // Non-positive scale.
        assert!(StableDistribution::new(1.5f64, 0.0, -1.0, 0.0).is_err());
    }
}