1use crate::{UtilsError, UtilsResult};
4use scirs2_core::ndarray::Array1;
5use scirs2_core::numeric::{Float, FromPrimitive};
6use std::cmp::Ordering;
7
/// Frequently used mathematical constants and floating-point magnitude bounds.
pub mod constants {
    // --- Mathematical constants (re-exported from the standard library) ---

    /// The circle constant π.
    pub const PI: f64 = std::f64::consts::PI;
    /// Euler's number e.
    pub const E: f64 = std::f64::consts::E;
    /// Natural logarithm of 2.
    pub const LN_2: f64 = std::f64::consts::LN_2;
    /// Natural logarithm of 10.
    pub const LN_10: f64 = std::f64::consts::LN_10;
    /// Square root of 2.
    pub const SQRT_2: f64 = std::f64::consts::SQRT_2;
    /// Square root of π (not provided by `std::f64::consts`).
    pub const SQRT_PI: f64 = 1.772_453_850_905_516;

    // --- Machine epsilons ---

    /// Machine epsilon for `f32`.
    pub const EPS_F32: f32 = f32::EPSILON;
    /// Machine epsilon for `f64`.
    pub const EPS_F64: f64 = f64::EPSILON;

    // --- Conventional "tiny" and "huge" magnitudes ---

    /// A small positive `f32` magnitude (1e-30).
    pub const TINY_F32: f32 = 1.0e-30;
    /// A small positive `f64` magnitude (1e-100).
    pub const TINY_F64: f64 = 1.0e-100;
    /// A large `f32` magnitude (1e30).
    pub const HUGE_F32: f32 = 1.0e30;
    /// A large `f64` magnitude (1e100).
    pub const HUGE_F64: f64 = 1.0e100;
}
23
24pub struct NumericalPrecision;
26
27impl NumericalPrecision {
28 pub fn epsilon<T: Float>() -> T {
30 T::epsilon()
31 }
32
33 pub fn tiny<T: Float>() -> T {
35 T::from(1e-30).unwrap_or_else(|| T::epsilon())
36 }
37
38 pub fn huge<T: Float>() -> T {
40 T::from(1e30).unwrap_or_else(|| T::max_value())
41 }
42
43 pub fn is_zero<T: Float>(value: T, eps: Option<T>) -> bool {
45 let tolerance = eps.unwrap_or_else(|| T::epsilon() * T::from(10).unwrap());
46 value.abs() < tolerance
47 }
48
49 pub fn approx_eq<T: Float>(a: T, b: T, eps: Option<T>) -> bool {
51 let tolerance = eps.unwrap_or_else(|| T::epsilon() * T::from(10).unwrap());
52 (a - b).abs() < tolerance
53 }
54
55 pub fn rel_eq<T: Float>(a: T, b: T, rel_tol: Option<T>) -> bool {
57 let tolerance = rel_tol.unwrap_or_else(|| T::from(1e-9).unwrap());
58 let max_val = a.abs().max(b.abs());
59 if max_val < T::epsilon() {
60 return true; }
62 (a - b).abs() / max_val < tolerance
63 }
64
65 pub fn safe_cmp<T: Float>(a: T, b: T, eps: Option<T>) -> Ordering {
67 if Self::approx_eq(a, b, eps) {
68 Ordering::Equal
69 } else if a < b {
70 Ordering::Less
71 } else {
72 Ordering::Greater
73 }
74 }
75}
76
77pub struct OverflowDetection;
79
80impl OverflowDetection {
81 pub fn near_overflow<T: Float>(value: T) -> bool {
83 let max_val = T::max_value();
84 value.abs() > max_val / T::from(1000).unwrap()
85 }
86
87 pub fn near_underflow<T: Float>(value: T) -> bool {
89 let min_val = T::min_positive_value();
90 value.abs() < min_val * T::from(10).unwrap() && !value.is_zero()
91 }
92
93 pub fn safe_add<T: Float>(a: T, b: T) -> UtilsResult<T> {
95 if Self::near_overflow(a) || Self::near_overflow(b) {
96 return Err(UtilsError::InvalidParameter(
97 "Addition would cause overflow".to_string(),
98 ));
99 }
100 let result = a + b;
101 if !result.is_finite() {
102 return Err(UtilsError::InvalidParameter(
103 "Addition resulted in non-finite value".to_string(),
104 ));
105 }
106 Ok(result)
107 }
108
109 pub fn safe_mul<T: Float>(a: T, b: T) -> UtilsResult<T> {
111 if Self::near_overflow(a) && !NumericalPrecision::is_zero(b, None) {
112 return Err(UtilsError::InvalidParameter(
113 "Multiplication would cause overflow".to_string(),
114 ));
115 }
116 let result = a * b;
117 if !result.is_finite() {
118 return Err(UtilsError::InvalidParameter(
119 "Multiplication resulted in non-finite value".to_string(),
120 ));
121 }
122 Ok(result)
123 }
124
125 pub fn safe_div<T: Float>(a: T, b: T) -> UtilsResult<T> {
127 if NumericalPrecision::is_zero(b, None) {
128 return Err(UtilsError::InvalidParameter("Division by zero".to_string()));
129 }
130 if Self::near_underflow(b) && !NumericalPrecision::is_zero(a, None) {
131 return Err(UtilsError::InvalidParameter(
132 "Division would cause overflow".to_string(),
133 ));
134 }
135 let result = a / b;
136 if !result.is_finite() {
137 return Err(UtilsError::InvalidParameter(
138 "Division resulted in non-finite value".to_string(),
139 ));
140 }
141 Ok(result)
142 }
143}
144
145pub struct SpecialFunctions;
147
148impl SpecialFunctions {
149 pub fn logistic<T: Float>(x: T) -> T {
151 let one = T::one();
152 one / (one + (-x).exp())
153 }
154
155 pub fn logsumexp<T: Float>(x: &[T]) -> T {
157 if x.is_empty() {
158 return T::neg_infinity();
159 }
160
161 let max_val = x.iter().copied().fold(T::neg_infinity(), T::max);
162 if !max_val.is_finite() {
163 return max_val;
164 }
165
166 let sum_exp: T = x
167 .iter()
168 .map(|&val| (val - max_val).exp())
169 .fold(T::zero(), |acc, val| acc + val);
170
171 max_val + sum_exp.ln()
172 }
173
174 pub fn softmax<T: Float>(x: &[T]) -> Vec<T> {
176 if x.is_empty() {
177 return Vec::new();
178 }
179
180 let max_val = x.iter().copied().fold(T::neg_infinity(), T::max);
181 let exp_vals: Vec<T> = x.iter().map(|&val| (val - max_val).exp()).collect();
182
183 let sum_exp: T = exp_vals
184 .iter()
185 .copied()
186 .fold(T::zero(), |acc, val| acc + val);
187
188 exp_vals.into_iter().map(|val| val / sum_exp).collect()
189 }
190
191 pub fn log_softmax<T: Float>(x: &[T]) -> Vec<T> {
193 let log_sum_exp = Self::logsumexp(x);
194 x.iter().map(|&val| val - log_sum_exp).collect()
195 }
196
197 pub fn gamma(x: f64) -> f64 {
199 if x == 1.0 || x == 2.0 {
201 1.0
202 } else if x == 3.0 {
203 2.0
204 } else if x == 4.0 {
205 6.0
206 } else if x > 1.0 {
207 (x - 1.0) * Self::gamma(x - 1.0)
209 } else {
210 1.0 / x }
213 }
214
215 pub fn lgamma(x: f64) -> f64 {
217 Self::gamma(x).ln()
218 }
219
220 pub fn gamma_inc(a: f64, x: f64) -> f64 {
222 if x < 0.0 || a <= 0.0 {
223 return 0.0;
224 }
225
226 if x < a + 1.0 {
228 let mut sum = 1.0;
229 let mut term = 1.0;
230 let mut n = 1.0;
231
232 for _ in 0..100 {
233 term *= x / (a + n - 1.0);
234 sum += term;
235 if term.abs() < 1e-15 {
236 break;
237 }
238 n += 1.0;
239 }
240
241 sum * x.powf(a) * (-x).exp() / Self::gamma(a)
242 } else {
243 Self::gamma(a) * (1.0 - Self::gamma_inc_cf(a, x))
245 }
246 }
247
248 fn gamma_inc_cf(a: f64, x: f64) -> f64 {
250 let mut b = x + 1.0 - a;
251 let mut c = 1e30;
252 let mut d = 1.0 / b;
253 let mut h = d;
254
255 for i in 1..=100 {
256 let an = -i as f64 * (i as f64 - a);
257 b += 2.0;
258 d = an * d + b;
259 if d.abs() < 1e-30 {
260 d = 1e-30;
261 }
262 c = b + an / c;
263 if c.abs() < 1e-30 {
264 c = 1e-30;
265 }
266 d = 1.0 / d;
267 let del = d * c;
268 h *= del;
269 if (del - 1.0).abs() < 1e-15 {
270 break;
271 }
272 }
273
274 h * x.powf(a) * (-x).exp()
275 }
276
277 pub fn beta(a: f64, b: f64) -> f64 {
279 (Self::gamma(a) * Self::gamma(b)) / Self::gamma(a + b)
280 }
281
282 pub fn erf(x: f64) -> f64 {
284 const A1: f64 = 0.254829592;
286 const A2: f64 = -0.284496736;
287 const A3: f64 = 1.421413741;
288 const A4: f64 = -1.453152027;
289 const A5: f64 = 1.061405429;
290 const P: f64 = 0.3275911;
291
292 let sign = if x >= 0.0 { 1.0 } else { -1.0 };
293 let x = x.abs();
294
295 let t = 1.0 / (1.0 + P * x);
296 let y = 1.0 - (((((A5 * t + A4) * t) + A3) * t + A2) * t + A1) * t * (-x * x).exp();
297
298 sign * y
299 }
300
301 pub fn erfc(x: f64) -> f64 {
303 1.0 - Self::erf(x)
304 }
305}
306
307pub struct RobustArrayOps;
309
310impl RobustArrayOps {
311 pub fn robust_sum<T: Float + FromPrimitive>(arr: &Array1<T>) -> T {
313 let mut sum = T::zero();
315 let mut c = T::zero(); for &value in arr.iter() {
318 let y = value - c;
319 let t = sum + y;
320 c = (t - sum) - y;
321 sum = t;
322 }
323
324 sum
325 }
326
327 pub fn robust_mean<T: Float + FromPrimitive>(arr: &Array1<T>) -> UtilsResult<T> {
329 if arr.is_empty() {
330 return Err(UtilsError::EmptyInput);
331 }
332
333 let sum = Self::robust_sum(arr);
334 let n = T::from(arr.len()).unwrap();
335
336 OverflowDetection::safe_div(sum, n)
337 }
338
339 pub fn robust_variance<T: Float + FromPrimitive>(
341 arr: &Array1<T>,
342 ddof: usize,
343 ) -> UtilsResult<T> {
344 if arr.len() <= ddof {
345 return Err(UtilsError::InsufficientData {
346 min: ddof + 1,
347 actual: arr.len(),
348 });
349 }
350
351 let mean = Self::robust_mean(arr)?;
352 let mut sum_sq = T::zero();
353 let mut c = T::zero(); for &value in arr.iter() {
356 let diff = value - mean;
357 let sq_diff = diff * diff;
358 let y = sq_diff - c;
359 let t = sum_sq + y;
360 c = (t - sum_sq) - y;
361 sum_sq = t;
362 }
363
364 let n = T::from(arr.len() - ddof).unwrap();
365 OverflowDetection::safe_div(sum_sq, n)
366 }
367
368 pub fn robust_std<T: Float + FromPrimitive>(arr: &Array1<T>, ddof: usize) -> UtilsResult<T> {
370 let variance = Self::robust_variance(arr, ddof)?;
371 Ok(variance.sqrt())
372 }
373
374 pub fn robust_dot<T: Float + FromPrimitive>(a: &Array1<T>, b: &Array1<T>) -> UtilsResult<T> {
376 if a.len() != b.len() {
377 return Err(UtilsError::ShapeMismatch {
378 expected: vec![a.len()],
379 actual: vec![b.len()],
380 });
381 }
382
383 let mut sum = T::zero();
384 let mut c = T::zero(); for (&x, &y) in a.iter().zip(b.iter()) {
387 let product = OverflowDetection::safe_mul(x, y)?;
388 let corrected = product - c;
389 let temp = sum + corrected;
390 c = (temp - sum) - corrected;
391 sum = temp;
392 }
393
394 Ok(sum)
395 }
396
397 pub fn robust_norm<T: Float + FromPrimitive>(arr: &Array1<T>) -> UtilsResult<T> {
399 if arr.is_empty() {
400 return Ok(T::zero());
401 }
402
403 let max_abs = arr.iter().map(|&x| x.abs()).fold(T::zero(), T::max);
405
406 if NumericalPrecision::is_zero(max_abs, None) {
407 return Ok(T::zero());
408 }
409
410 let mut sum_sq = T::zero();
411 let mut c = T::zero(); for &value in arr.iter() {
414 let scaled = OverflowDetection::safe_div(value, max_abs)?;
415 let sq = OverflowDetection::safe_mul(scaled, scaled)?;
416 let y = sq - c;
417 let t = sum_sq + y;
418 c = (t - sum_sq) - y;
419 sum_sq = t;
420 }
421
422 let norm_scaled = sum_sq.sqrt();
423 OverflowDetection::safe_mul(norm_scaled, max_abs)
424 }
425}
426
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_numerical_precision() {
        // Absolute-tolerance zero detection.
        let (negligible, noticeable) = (1e-16, 1e-6);
        assert!(NumericalPrecision::is_zero(negligible, None));
        assert!(!NumericalPrecision::is_zero(noticeable, None));

        // Absolute-tolerance equality.
        let one = 1.0;
        assert!(NumericalPrecision::approx_eq(one, one + 1e-15, None));
        assert!(!NumericalPrecision::approx_eq(one, 1.1, None));

        // Relative-tolerance equality.
        let rel_tol = Some(1e-6);
        assert!(NumericalPrecision::rel_eq(1000.0, 1000.0001, rel_tol));
        assert!(!NumericalPrecision::rel_eq(1000.0, 1001.0, rel_tol));
    }

    #[test]
    fn test_overflow_detection() {
        let half_max = f64::MAX / 2.0;

        assert!(OverflowDetection::safe_add(half_max, half_max).is_err());
        assert!(OverflowDetection::safe_add(1.0, 2.0).is_ok());

        assert!(OverflowDetection::safe_mul(half_max, 2.0).is_err());
        assert!(OverflowDetection::safe_mul(2.0, 3.0).is_ok());

        assert!(OverflowDetection::safe_div(1.0, 0.0).is_err());
        assert!(OverflowDetection::safe_div(1.0, f64::MIN_POSITIVE).is_err());
        let quotient = OverflowDetection::safe_div(6.0, 2.0).unwrap();
        assert_relative_eq!(quotient, 3.0);
    }

    #[test]
    fn test_special_functions() {
        // Logistic sigmoid: centre value and saturation on both sides.
        let midpoint = SpecialFunctions::logistic(0.0);
        assert_relative_eq!(midpoint, 0.5, epsilon = 1e-10);
        assert!(SpecialFunctions::logistic(10.0) > 0.99);
        assert!(SpecialFunctions::logistic(-10.0) < 0.01);

        // logsumexp against the naive formula, and softmax normalisation.
        let x = [1.0, 2.0, 3.0];
        let naive_lse = x.iter().map(|v: &f64| v.exp()).sum::<f64>().ln();
        assert_relative_eq!(SpecialFunctions::logsumexp(&x), naive_lse, epsilon = 1e-10);

        let total: f64 = SpecialFunctions::softmax(&x).iter().sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-10);

        // Gamma at small integers: Γ(n) = (n - 1)!.
        for &(arg, factorial) in &[(1.0, 1.0), (2.0, 1.0), (3.0, 2.0), (4.0, 6.0)] {
            assert_relative_eq!(SpecialFunctions::gamma(arg), factorial, epsilon = 1e-8);
        }

        // Error function: value at zero and odd symmetry.
        assert_relative_eq!(SpecialFunctions::erf(0.0), 0.0, epsilon = 1e-9);
        assert!(SpecialFunctions::erf(1.0) > 0.8);
        assert!(SpecialFunctions::erf(-1.0) < -0.8);
    }

    #[test]
    fn test_robust_array_ops() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];

        assert_relative_eq!(RobustArrayOps::robust_sum(&data), 15.0, epsilon = 1e-10);
        assert_relative_eq!(
            RobustArrayOps::robust_mean(&data).unwrap(),
            3.0,
            epsilon = 1e-10
        );

        let sample_var = RobustArrayOps::robust_variance(&data, 1).unwrap();
        assert_relative_eq!(sample_var, 2.5, epsilon = 1e-10);
        let sample_std = RobustArrayOps::robust_std(&data, 1).unwrap();
        assert_relative_eq!(sample_std, 2.5_f64.sqrt(), epsilon = 1e-10);

        let u = array![1.0, 2.0, 3.0];
        let v = array![4.0, 5.0, 6.0];
        assert_relative_eq!(
            RobustArrayOps::robust_dot(&u, &v).unwrap(),
            32.0,
            epsilon = 1e-10
        );

        let norm = RobustArrayOps::robust_norm(&u).unwrap();
        assert_relative_eq!(norm, 14.0_f64.sqrt(), epsilon = 1e-10);
    }
}