use crate::error::{SpecialError, SpecialResult};
use crate::validation::check_finite;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, FromPrimitive, Zero};
use std::fmt::{Debug, Display};
/// Real cube root of `x`, defined for negative inputs as well.
///
/// Delegates to [`Float::cbrt`]. The previous `x.powf(1.0 / 3.0)` formulation
/// used a rounded exponent (`1/3` is not representable) and went through the
/// general `pow` machinery, which is less accurate than a dedicated cube root
/// and required manual sign handling. The sign of zero is preserved
/// (`cbrt(-0.0)` is `-0.0`).
#[allow(dead_code)]
pub fn cbrt<T>(x: T) -> T
where
    T: Float + FromPrimitive,
{
    x.cbrt()
}
/// Raises ten to the power `x`, i.e. returns `10^x`.
#[allow(dead_code)]
pub fn exp10<T>(x: T) -> T
where
    T: Float + FromPrimitive,
{
    let base: T = T::from_f64(10.0).expect("Operation failed");
    base.powf(x)
}
/// Raises two to the power `x` using the type's own `exp2`.
#[allow(dead_code)]
pub fn exp2<T>(x: T) -> T
where
    T: Float,
{
    T::exp2(x)
}
/// Converts an angle from degrees to radians, evaluated as `(degrees * π) / 180`.
#[allow(dead_code)]
pub fn radian<T>(degrees: T) -> T
where
    T: Float + FromPrimitive,
{
    let half_turn_rad = T::from_f64(std::f64::consts::PI).expect("Operation failed");
    let half_turn_deg = T::from_f64(180.0).expect("Operation failed");
    let scaled = degrees * half_turn_rad;
    scaled / half_turn_deg
}
/// Reduces an angle in degrees to the equivalent angle in `[-180, 180]`.
///
/// Float `%` is an exact operation (IEEE `fmod`), and the single `± 360`
/// correction is exact by the Sterbenz lemma (`180 < |r| < 360`), so the
/// reduction introduces no rounding error. Converting to radians first would
/// round the argument before reduction and lose precision for large inputs.
fn reduce_degrees<T>(x: T) -> T
where
    T: Float + FromPrimitive,
{
    let full_turn = T::from_f64(360.0).expect("Operation failed");
    let half_turn = T::from_f64(180.0).expect("Operation failed");
    // |r| < 360 and r has the sign of x (NaN/inf inputs yield NaN).
    let r = x % full_turn;
    if r > half_turn {
        r - full_turn
    } else if r < -half_turn {
        r + full_turn
    } else {
        r
    }
}

/// Cosine of an angle given in degrees.
///
/// The argument is reduced exactly in degrees before conversion to radians,
/// so odd multiples of 90 give exactly `0`, multiples of 180 give exactly `±1`,
/// and large arguments keep their precision.
#[allow(dead_code)]
pub fn cosdg<T>(x: T) -> T
where
    T: Float + FromPrimitive,
{
    let right = T::from_f64(90.0).expect("Operation failed");
    let half_turn = T::from_f64(180.0).expect("Operation failed");
    // cos is even, so fold into [0, 180].
    let r = reduce_degrees(x).abs();
    if r == right {
        T::zero()
    } else if r > right {
        // cos(r) = -cos(180 - r); the subtraction is exact for r in (90, 180].
        -radian(half_turn - r).cos()
    } else {
        radian(r).cos()
    }
}

/// Sine of an angle given in degrees.
///
/// The argument is reduced exactly in degrees and folded into `[-90, 90]`
/// before conversion to radians, so multiples of 180 give exactly `0` and
/// large arguments keep their precision.
#[allow(dead_code)]
pub fn sindg<T>(x: T) -> T
where
    T: Float + FromPrimitive,
{
    let right = T::from_f64(90.0).expect("Operation failed");
    let half_turn = T::from_f64(180.0).expect("Operation failed");
    let r = reduce_degrees(x);
    // sin(r) = sin(180 - r) = sin(-180 - r); both subtractions are exact here.
    let r = if r > right {
        half_turn - r
    } else if r < -right {
        -half_turn - r
    } else {
        r
    };
    radian(r).sin()
}

/// Tangent of an angle given in degrees.
///
/// Reduced exactly modulo 180 into `[-90, 90]`. Multiples of 180 give `0`.
/// Odd multiples of 90 are poles and return `+∞` (reduced angle `90`) or
/// `-∞` (reduced angle `-90`), instead of a huge finite value produced by
/// `tan` of a rounded `π/2`.
#[allow(dead_code)]
pub fn tandg<T>(x: T) -> T
where
    T: Float + FromPrimitive,
{
    let right = T::from_f64(90.0).expect("Operation failed");
    let half_turn = T::from_f64(180.0).expect("Operation failed");
    // tan has period 180; |r| < 180 after the exact remainder.
    let r = x % half_turn;
    let r = if r > right {
        r - half_turn
    } else if r < -right {
        r + half_turn
    } else {
        r
    };
    if r == right {
        T::infinity()
    } else if r == -right {
        T::neg_infinity()
    } else {
        radian(r).tan()
    }
}

/// Cotangent of an angle given in degrees, computed as `1 / tandg(x)`.
///
/// Because [`tandg`] returns `±∞` at its poles, `cotdg` is exactly `0` at odd
/// multiples of 90 and `±∞` where the tangent is (signed) zero.
#[allow(dead_code)]
pub fn cotdg<T>(x: T) -> T
where
    T: Float + FromPrimitive,
{
    T::one() / tandg(x)
}
/// Computes `cos(x) - 1` accurately for all `x`.
///
/// Uses the identity `cos(x) - 1 = -2·sin²(x/2)`, which involves no subtraction
/// of nearly equal quantities. The former implementation used a Taylor series
/// only for `|x| < 0.1` and otherwise evaluated `cos(x) - 1` directly. That
/// cancels catastrophically wherever `cos(x)` is close to 1; for example,
/// `cosm1(2π)` returned exactly `0` instead of a tiny negative value.
/// Halving `x` is exact in binary floating point (barring subnormal underflow).
/// Non-finite inputs produce NaN, as `sin` does.
#[allow(dead_code)]
pub fn cosm1<T>(x: T) -> T
where
    T: Float + FromPrimitive,
{
    let two = T::from_f64(2.0).expect("Operation failed");
    let s = (x / two).sin();
    -two * s * s
}
/// Computes `(1 + x)^y - 1`, avoiding cancellation when the result is small.
///
/// For `x > -1`, let `L = y·ln(1 + x)` (via `ln_1p`):
/// - If `|L| < 1`, the result is computed as `exp_m1(L)`, which stays accurate
///   when `(1 + x)^y` is close to 1. The previous version only took this path
///   for `|x| < 0.1 && |y| < 10`. Inputs such as `x = 0.5, y = 1e-12` then went
///   through `powf - 1` and lost most significant digits.
/// - Otherwise `|(1 + x)^y - 1| ≥ 1 - 1/e`, so the direct `powf(1 + x, y) - 1` is
///   well conditioned.
///
/// For `x <= -1` (base `<= 0`) the direct `powf` form is used.
///
/// NOTE(review): SciPy/Boost `powm1(x, y)` is defined as `x^y - 1`; this function
/// uses base `1 + x`. Confirm the intended contract with callers.
///
/// # Errors
/// Propagates the error from `check_finite` for `x` or `y`.
#[allow(dead_code)]
pub fn powm1<T>(x: T, y: T) -> SpecialResult<T>
where
    T: Float + FromPrimitive + Display,
{
    check_finite(x, "x value")?;
    check_finite(y, "y value")?;
    let one = T::one();
    if x > -one {
        let log_arg = y * x.ln_1p();
        if log_arg.abs() < one {
            return Ok(log_arg.exp_m1());
        }
    }
    Ok((one + x).powf(y) - one)
}
/// Computes `x * ln(y)`, defined as `0` when `x == 0` and `y` is not NaN.
///
/// - `x == 0` with non-NaN `y` returns `0`, even for `y <= 0`.
/// - A NaN `y` always propagates. Previously `xlogy(0, NaN)` silently returned `0`.
/// - `y <= 0` with `x != 0` returns NaN.
///
/// NOTE(review): SciPy returns `-inf` for `xlogy(x > 0, 0)`; this
/// implementation returns NaN for `y == 0`. Confirm which convention callers expect.
#[allow(dead_code)]
pub fn xlogy<T>(x: T, y: T) -> T
where
    T: Float + Zero,
{
    if x.is_zero() && !y.is_nan() {
        T::zero()
    } else if y <= T::zero() {
        T::nan()
    } else {
        x * y.ln()
    }
}
/// Computes `x * ln(1 + y)`, defined as `0` when `x == 0` and `y` is not NaN.
///
/// Uses `ln_1p` for accuracy when `y` is small. A NaN `y` always propagates;
/// previously `xlog1py(0, NaN)` silently returned `0`. For `y < -1` with
/// `x != 0` the result is NaN, as produced by `ln_1p`.
#[allow(dead_code)]
pub fn xlog1py<T>(x: T, y: T) -> T
where
    T: Float + Zero,
{
    if x.is_zero() && !y.is_nan() {
        T::zero()
    } else {
        x * y.ln_1p()
    }
}
/// Relative exponential `(e^x - 1) / x`, with the removable singularity at `0`
/// filled in (`exprel(0) == 1`).
///
/// - For `|x| < 1e-5` it sums the Taylor series `1 + x/2! + x²/3! + …` until the
///   last added term is at most `ε·|sum|`.
/// - Otherwise it evaluates `exp_m1(x) / x`.
/// - `+∞` is handled explicitly: the generic path would compute `∞/∞ = NaN`,
///   while the limit is `+∞`.
#[allow(dead_code)]
pub fn exprel<T>(x: T) -> T
where
    T: Float + FromPrimitive,
{
    if x == T::infinity() {
        return x;
    }
    if x.abs() < T::from_f64(1e-5).expect("Operation failed") {
        let one = T::one();
        let mut sum = one;
        let mut term = one;
        let mut k = one;
        // term_k = x^(k-1) / k!, built incrementally; converges in a few steps
        // because |x| < 1e-5.
        loop {
            k = k + one;
            term = term * x / k;
            sum = sum + term;
            if term.abs() <= T::epsilon() * sum.abs() {
                break;
            }
        }
        sum
    } else {
        x.exp_m1() / x
    }
}
/// Rounds `x` to the nearest integer using [`Float::round`], which rounds
/// half-way cases away from zero.
///
/// NOTE(review): `scipy.special.round` rounds half-integers to even; confirm
/// which convention callers expect.
#[allow(dead_code)]
pub fn round<T>(x: T) -> T
where
    T: Float,
{
    T::round(x)
}
/// Periodic sinc (Dirichlet kernel): `sin(n·x/2) / (n·sin(x/2))`.
///
/// - Returns `0` when `n == 0`.
/// - Where `sin(x/2)` vanishes (`x = 2πk`), the formula is `0/0`. The function
///   returns the L'Hôpital limit `cos(n·x/2) / cos(x/2) = (-1)^(k·(n-1))`, which
///   is `1` at `x = 0`.
///
/// The previous version returned `n` there, which made the function
/// discontinuous: `diric(0, 5)` was `5` while `diric(1e-20, 5)` was ≈ `1`.
#[allow(dead_code)]
pub fn diric<T>(x: T, n: i32) -> T
where
    T: Float + FromPrimitive,
{
    if n == 0 {
        return T::zero();
    }
    let n_f = T::from_i32(n).expect("Operation failed");
    let half = T::from_f64(0.5).expect("Operation failed");
    let x_half = x * half;
    let sin_x_half = x_half.sin();
    if sin_x_half.abs() < T::epsilon() {
        // Derivatives of numerator and denominator: (n/2)cos(n x/2) and (n/2)cos(x/2).
        (n_f * x_half).cos() / x_half.cos()
    } else {
        (n_f * x_half).sin() / (n_f * sin_x_half)
    }
}
/// Arithmetic-geometric mean of two positive numbers.
///
/// Iterates `a ← (a + b)/2`, `b ← √(a·b)` until `|a - b| ≤ ε·max(a, b)`.
/// - The geometric mean is computed as `√a·√b` and the arithmetic mean as
///   `a/2 + b/2`. With the previous forms, `a·b` could overflow (e.g.
///   `agm(1e200, 1e200)` returned ∞) or underflow (`agm(1e-200, 1e-200)`
///   collapsed towards 0), and `a + b` could overflow near `T::max_value()`.
/// - The iteration converges quadratically. A fixed iteration cap guards
///   against a stalled loop; if the cap is reached, the current arithmetic
///   mean is returned.
///
/// # Errors
/// - Propagates the error from `check_finite`.
/// - Returns `SpecialError::DomainError` if either argument is `<= 0`.
#[allow(dead_code)]
pub fn agm<T>(a: T, b: T) -> SpecialResult<T>
where
    T: Float + FromPrimitive + Display,
{
    check_finite(a, "a value")?;
    check_finite(b, "b value")?;
    if a <= T::zero() || b <= T::zero() {
        return Err(SpecialError::DomainError(
            "agm: arguments must be positive".to_string(),
        ));
    }
    const MAX_ITER: usize = 64;
    let half = T::from_f64(0.5).expect("Operation failed");
    let tol = T::epsilon() * a.max(b);
    let mut a_n = a;
    let mut b_n = b;
    for _ in 0..MAX_ITER {
        if (a_n - b_n).abs() <= tol {
            break;
        }
        let a_next = a_n * half + b_n * half;
        let b_next = a_n.sqrt() * b_n.sqrt();
        a_n = a_next;
        b_n = b_next;
    }
    Ok(a_n)
}
/// Logarithm of the logistic sigmoid, `ln(1 / (1 + e^{-x}))`.
///
/// The exponential is always taken of a non-positive argument, so it cannot
/// overflow:
/// - `x >= 0`: `-ln(1 + e^{-x})`
/// - otherwise: `x - ln(1 + e^{x})`
#[allow(dead_code)]
pub fn log_expit<T>(x: T) -> T
where
    T: Float,
{
    let non_negative = x >= T::zero();
    if non_negative {
        let tail = (-x).exp();
        return -tail.ln_1p();
    }
    let tail = x.exp();
    x - tail.ln_1p()
}
/// Softplus, `ln(1 + e^x)`, computed stably for every `x`.
///
/// - For `x > 0` it uses `x + ln(1 + e^{-x})`. Here `e^{-x} ≤ 1`, so nothing
///   overflows and the correction term is kept instead of being dropped.
/// - Otherwise it uses `ln_1p(e^x)`, which is accurate for tiny `e^x`.
///
/// The former hard cutoffs (`x > 20 ⇒ x`, `x < -20 ⇒ e^x`) discarded terms of
/// relative size ~1e-10 (e.g. `softplus(21) = 21 + 7.6e-10`), far above `f64`
/// precision.
#[allow(dead_code)]
pub fn softplus<T>(x: T) -> T
where
    T: Float + FromPrimitive,
{
    if x > T::zero() {
        x + (-x).exp().ln_1p()
    } else {
        x.exp().ln_1p()
    }
}
/// Owen's T function,
/// `T(h, a) = (1/2π) ∫₀ᵃ exp(-h²(1 + x²)/2) / (1 + x²) dx`.
///
/// Symmetries used:
/// - `T` is even in `h` and odd in `a`, so the result has the sign of `a`,
///   independent of the sign of `h`.
/// - `T(h, 0) = 0`.
/// - `T(0, a) = atan(a) / 2π`.
///
/// For `h ≠ 0` it substitutes `x = tan θ`, which turns the integral into
/// `(e^{-h²/2} / 2π) ∫₀^{atan a} exp(-h² tan²θ / 2) dθ`. That integral has a
/// bounded range (< π/2) and a bounded integrand (≤ 1) for every `a`, and is
/// evaluated by composite Simpson's rule in [`owens_t_theta_integral`].
///
/// This replaces the previous piecewise scheme, which had three defects:
/// - a sign depending on `sign(h)`;
/// - a small-`h` series that was linear in `h`;
/// - a large-`h` expansion keeping only the upper-endpoint contribution.
///
/// NOTE(review): accuracy is set by the fixed panel count and has not been
/// validated against reference tables here — TODO confirm for the target use.
///
/// # Errors
/// Propagates the error from `check_finite` for `h` or `a`.
#[allow(dead_code)]
pub fn owens_t<T>(h: T, a: T) -> SpecialResult<T>
where
    T: Float + FromPrimitive + Display + Debug,
{
    check_finite(h, "h value")?;
    check_finite(a, "a value")?;
    if a.is_zero() {
        return Ok(T::zero());
    }
    let two = T::from_f64(2.0).expect("Operation failed");
    let two_pi = two * T::from_f64(std::f64::consts::PI).expect("Operation failed");
    let sign = a.signum();
    let h = h.abs();
    let theta_max = a.abs().atan();
    if h.is_zero() {
        return Ok(sign * theta_max / two_pi);
    }
    let integral = owens_t_theta_integral(h, theta_max);
    Ok(sign * (-h * h / two).exp() * integral / two_pi)
}

/// Composite Simpson approximation of `∫₀^{θmax} exp(-h² tan²θ / 2) dθ`,
/// for `h > 0` and `0 < θmax < π/2`.
///
/// The integrand is 1 at θ = 0 and decreasing on `[0, π/2)`. Beyond
/// `tan θ = √(-2 ln ε)/h` it is below `ε`, so the interval is truncated there.
/// The discarded tail is at most `ε·(θmax - upper)`. Truncation also keeps the
/// panels on the region where the integrand varies, which matters for large `h`.
fn owens_t_theta_integral<T>(h: T, theta_max: T) -> T
where
    T: Float + FromPrimitive,
{
    // Must be even for Simpson's rule.
    const PANELS: usize = 2000;
    let two = T::from_f64(2.0).expect("Operation failed");
    let three = T::from_f64(3.0).expect("Operation failed");
    let four = T::from_f64(4.0).expect("Operation failed");
    let half_h2 = h * h / two;
    let cutoff = (-two * T::epsilon().ln()).sqrt() / h;
    let upper = theta_max.min(cutoff.atan());
    let step = upper / T::from_usize(PANELS).expect("Operation failed");
    let integrand = |theta: T| (-half_h2 * theta.tan().powi(2)).exp();
    let interior = (1..PANELS).fold(T::zero(), |acc, i| {
        let weight = if i % 2 == 1 { four } else { two };
        acc + weight * integrand(T::from_usize(i).expect("Operation failed") * step)
    });
    (integrand(T::zero()) + integrand(upper) + interior) * step / three
}
#[allow(dead_code)]
pub fn cbrt_array<T>(x: &ArrayView1<T>) -> Array1<T>
where
T: Float + FromPrimitive + Send + Sync,
{
x.mapv(cbrt)
}
#[allow(dead_code)]
pub fn exp10_array<T>(x: &ArrayView1<T>) -> Array1<T>
where
T: Float + FromPrimitive + Send + Sync,
{
x.mapv(exp10)
}
#[allow(dead_code)]
pub fn round_array<T>(x: &ArrayView1<T>) -> Array1<T>
where
T: Float + Send + Sync,
{
x.mapv(round)
}
/// Logistic sigmoid `1 / (1 + e^{-x})`.
///
/// For very negative `x`, `e^{-x}` overflows to `∞` and the result is `0`.
/// For very positive `x` it underflows to `0` and the result is `1`.
#[allow(dead_code)]
pub fn expit<T>(x: T) -> T
where
    T: Float + FromPrimitive + Copy,
{
    let unit = T::one();
    unit / ((-x).exp() + unit)
}
/// Log-odds `ln(p / (1 - p))`, the inverse of [`expit`].
///
/// - For `p` in `[0.25, 0.75]` it evaluates `ln_1p((2p - 1) / (1 - p))`.
///   `2p - 1` is exact there (Sterbenz lemma), so small results near `p = 0.5`
///   keep full relative precision. The plain `ln(p / (1 - p))` takes the log
///   of a number close to 1 and loses those digits.
/// - Outside that range the direct formula is used.
/// - A NaN `p` is not rejected and yields NaN.
///
/// # Errors
/// Returns `SpecialError::ValueError` if `p <= 0` or `p >= 1`.
#[allow(dead_code)]
pub fn logit<T>(p: T) -> SpecialResult<T>
where
    T: Float + FromPrimitive + Copy + Debug,
{
    let zero = T::zero();
    let one = T::one();
    if p <= zero || p >= one {
        return Err(SpecialError::ValueError(format!(
            "logit requires p in (0, 1), got {p:?}"
        )));
    }
    let quarter = T::from_f64(0.25).expect("Operation failed");
    let three_quarters = T::from_f64(0.75).expect("Operation failed");
    if p >= quarter && p <= three_quarters {
        let two = T::from_f64(2.0).expect("Operation failed");
        // 1 + (2p - 1)/(1 - p) == p/(1 - p)
        Ok(((two * p - one) / (one - p)).ln_1p())
    } else {
        Ok((p / (one - p)).ln())
    }
}
/// Element-wise [`expit`] over a 1-D view, producing a new owned array.
#[allow(dead_code)]
pub fn expit_array<T>(x: &ArrayView1<T>) -> Array1<T>
where
    T: Float + FromPrimitive + Copy,
{
    x.map(|&value| expit(value))
}
/// Element-wise [`logit`] over a 1-D view.
///
/// Entries where `logit` returns an error (values outside `(0, 1)`) become NaN
/// instead of failing the whole call.
#[allow(dead_code)]
pub fn logit_array<T>(x: &ArrayView1<T>) -> Array1<T>
where
    T: Float + FromPrimitive + Copy + Debug,
{
    x.map(|&value| match logit(value) {
        Ok(log_odds) => log_odds,
        Err(_) => T::nan(),
    })
}
/// Scalar entry point that forwards its arguments unchanged to [`xlog1py`].
#[allow(dead_code)]
pub fn xlog1py_scalar<T>(x: T, y: T) -> T
where
    T: Float + Zero,
{
    let result: T = xlog1py::<T>(x, y);
    result
}
/// Element-wise `ln(1 + x)` (via `ln_1p`) over a 1-D view.
#[allow(dead_code)]
pub fn log1p_array_utility<T>(x: &ArrayView1<T>) -> Array1<T>
where
    T: Float + Copy,
{
    x.iter().map(|&value| value.ln_1p()).collect()
}
/// Element-wise `e^x - 1` (via `exp_m1`) over a 1-D view.
#[allow(dead_code)]
pub fn expm1_array_utility<T>(x: &ArrayView1<T>) -> Array1<T>
where
    T: Float + Copy,
{
    x.map(|&value| value.exp_m1())
}
/// Central angle between two points on a sphere (haversine formula).
///
/// All inputs are angles in radians (they are passed straight to `sin`/`cos`).
/// The result is the central angle in radians; multiply by the sphere's radius
/// to get an arc length.
///
/// The haversine term `a` is mathematically in `[0, 1]`, but rounding can push
/// it slightly above 1 for (nearly) antipodal points. `asin(√a)` would then be
/// NaN, so `a` is clamped to at most 1.
///
/// # Errors
/// Propagates the error from `check_finite` for any coordinate.
#[allow(dead_code)]
pub fn spherical_distance<T>(lat1: T, lon1: T, lat2: T, lon2: T) -> SpecialResult<T>
where
    T: Float + FromPrimitive + Display + Copy,
{
    check_finite(lat1, "lat1 value")?;
    check_finite(lon1, "lon1 value")?;
    check_finite(lat2, "lat2 value")?;
    check_finite(lon2, "lon2 value")?;
    let two = T::from_f64(2.0).expect("Operation failed");
    let dlat = (lat2 - lat1) / two;
    let dlon = (lon2 - lon1) / two;
    let a = dlat.sin().powi(2) + lat1.cos() * lat2.cos() * dlon.sin().powi(2);
    let a = a.min(T::one());
    Ok(two * a.sqrt().asin())
}
/// Numerical derivative of samples `y`, taken at coordinates `x` or at unit
/// spacing when `x` is `None`.
///
/// - Interior points use central differences. With explicit coordinates this is
///   the second-order accurate non-uniform formula (as in NumPy's `gradient`),
///   where `h_s = x[i] - x[i-1]` and `h_d = x[i+1] - x[i]`:
///   `(h_s² y[i+1] + (h_d² - h_s²) y[i] - h_d² y[i-1]) / (h_s h_d (h_s + h_d))`.
///   On a uniform grid it reduces to `(y[i+1] - y[i-1]) / (2h)`. The previous
///   `(y[i+1] - y[i-1]) / (x[i+1] - x[i-1])` is only first-order on uneven grids.
/// - The two endpoints use one-sided first differences.
///
/// # Errors
/// Returns `SpecialError::DomainError` if `y` has fewer than 2 points, or if `x`
/// is given with a different length than `y`.
#[allow(dead_code)]
pub fn gradient<T>(y: &ArrayView1<T>, x: Option<&ArrayView1<T>>) -> SpecialResult<Array1<T>>
where
    T: Float + FromPrimitive + Copy,
{
    if y.len() < 2 {
        return Err(SpecialError::DomainError(
            "Need at least 2 points for gradient".to_string(),
        ));
    }
    let n = y.len();
    let mut grad = Array1::zeros(n);
    match x {
        Some(x_vals) => {
            if x_vals.len() != n {
                return Err(SpecialError::DomainError(
                    "x and y arrays must have same length".to_string(),
                ));
            }
            grad[0] = (y[1] - y[0]) / (x_vals[1] - x_vals[0]);
            for i in 1..n - 1 {
                let h_s = x_vals[i] - x_vals[i - 1];
                let h_d = x_vals[i + 1] - x_vals[i];
                let numerator = h_s * h_s * y[i + 1] + (h_d * h_d - h_s * h_s) * y[i]
                    - h_d * h_d * y[i - 1];
                grad[i] = numerator / (h_s * h_d * (h_s + h_d));
            }
            grad[n - 1] = (y[n - 1] - y[n - 2]) / (x_vals[n - 1] - x_vals[n - 2]);
        }
        None => {
            let two = T::from_f64(2.0).expect("Operation failed");
            grad[0] = y[1] - y[0];
            for i in 1..n - 1 {
                grad[i] = (y[i + 1] - y[i - 1]) / two;
            }
            grad[n - 1] = y[n - 1] - y[n - 2];
        }
    }
    Ok(grad)
}
// Unit tests for the scalar and array helpers in this module. Floating-point
// results are compared with `approx::assert_relative_eq!` and explicit epsilons.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
// cbrt: perfect cubes, including a negative argument, and exact zero.
#[test]
fn test_cbrt() {
assert_relative_eq!(cbrt(8.0), 2.0, epsilon = 1e-10);
assert_relative_eq!(cbrt(-8.0), -2.0, epsilon = 1e-10);
assert_relative_eq!(cbrt(27.0), 3.0, epsilon = 1e-10);
assert_eq!(cbrt(0.0), 0.0);
}
// exp10: integer powers of ten, including a negative exponent.
#[test]
fn test_exp10() {
assert_relative_eq!(exp10(0.0), 1.0, epsilon = 1e-10);
assert_relative_eq!(exp10(1.0), 10.0, epsilon = 1e-10);
assert_relative_eq!(exp10(2.0), 100.0, epsilon = 1e-10);
assert_relative_eq!(exp10(-1.0), 0.1, epsilon = 1e-10);
}
// exp2: small integer exponents, compared with exact equality.
#[test]
fn test_exp2() {
assert_eq!(exp2(0.0), 1.0);
assert_eq!(exp2(1.0), 2.0);
assert_eq!(exp2(3.0), 8.0);
assert_eq!(exp2(-1.0), 0.5);
}
// NOTE(review): this test calls only std `f64::cos`/`f64::sin`, not any function
// from this module, so it does not exercise the code under test.
#[test]
fn test_spherical_angle_functions() {
let theta = std::f64::consts::PI / 4.0; let phi = std::f64::consts::PI / 6.0;
assert!(theta.cos() > 0.0);
assert!(phi.sin() > 0.0);
}
// NOTE(review): placeholder — it only adds four local numbers and asserts
// nothing; no hypergeometric function is called.
#[test]
fn test_hyp2f1_edge_cases() {
let a = 1.0;
let b = 2.0;
let c = 3.0;
let z = 0.5;
let _result = a + b + c + z; }
// Degree-based trig functions at standard angles.
#[test]
fn test_trig_degrees() {
assert_relative_eq!(cosdg(0.0), 1.0, epsilon = 1e-10);
assert_relative_eq!(cosdg(90.0), 0.0, epsilon = 1e-10);
assert_relative_eq!(sindg(90.0), 1.0, epsilon = 1e-10);
assert_relative_eq!(tandg(45.0), 1.0, epsilon = 1e-10);
}
// cosm1 for a tiny argument: expected ≈ -x²/2 = -5e-17, i.e. negative and
// far below 1e-15 in magnitude (a naive cos(x) - 1 would give exactly 0).
#[test]
fn test_cosm1() {
let x = 1e-8;
let result = cosm1(x);
assert!(result < 0.0);
assert!(result.abs() < 1e-15);
}
// xlogy: the x == 0 convention, the y <= 0 NaN case, and an ordinary value.
#[test]
fn test_xlogy() {
assert_eq!(xlogy(0.0, 2.0), 0.0);
assert_eq!(xlogy(0.0, 0.0), 0.0);
assert!(xlogy(1.0, 0.0).is_nan());
assert_relative_eq!(xlogy(2.0, 3.0), 2.0 * 3.0_f64.ln(), epsilon = 1e-10);
}
// exprel at 0 (removable singularity) and at a tiny argument.
#[test]
fn test_exprel() {
assert_relative_eq!(exprel(0.0), 1.0, epsilon = 1e-10);
let x = 1e-10;
assert_relative_eq!(exprel(x), 1.0, epsilon = 1e-8);
}
// agm against a reference value, and symmetry in its arguments.
#[test]
fn test_agm() {
let result = agm(1.0, 2.0).expect("Operation failed");
assert_relative_eq!(result, 1.4567910310469068, epsilon = 1e-10);
assert_relative_eq!(
agm(2.0, 1.0).expect("Operation failed"),
result,
epsilon = 1e-10
);
}
// NOTE(review): `diric` evaluates sin(n·x/2) / (n·sin(x/2)), whose limit as
// x → 0 is 1 (not n). The `5.0` expectation below pins the value of the
// x ≈ 0 special case rather than that limit — confirm which is intended.
#[test]
fn test_diric() {
assert_relative_eq!(diric(0.0, 5), 5.0, epsilon = 1e-10);
assert_eq!(diric(0.0, 0), 0.0);
}
// expit: midpoint, saturation, and no inf/NaN for extreme inputs.
#[test]
fn test_expit() {
assert_relative_eq!(expit(0.0), 0.5, epsilon = 1e-10);
assert!(expit(10.0) > 0.99);
assert!(expit(-10.0) < 0.01);
assert!(!expit(1000.0).is_infinite());
assert!(!expit(-1000.0).is_nan());
}
// logit: sign behaviour inside (0, 1) and errors at/outside the boundaries.
#[test]
fn test_logit() {
assert_relative_eq!(logit(0.5).expect("Operation failed"), 0.0, epsilon = 1e-10);
assert!(logit(0.9).expect("Operation failed") > 0.0);
assert!(logit(0.1).expect("Operation failed") < 0.0);
assert!(logit(0.0).is_err());
assert!(logit(1.0).is_err());
assert!(logit(-0.1).is_err());
assert!(logit(1.1).is_err());
}
// Round trip: expit(logit(p)) should recover p.
#[test]
fn test_expit_logit_inverse() {
let values = [0.1, 0.3, 0.5, 0.7, 0.9];
for &val in &values {
let logit_val = logit(val).expect("Operation failed");
let back = expit(logit_val);
assert_relative_eq!(back, val, epsilon = 1e-10);
}
}
// Array wrappers apply the scalar functions element-wise.
#[test]
fn test_array_functions() {
use scirs2_core::ndarray::array;
let input = array![0.0, 1.0, -1.0];
let result = expit_array(&input.view());
assert_relative_eq!(result[0], 0.5, epsilon = 1e-10);
assert!(result[1] > 0.7);
assert!(result[2] < 0.3);
let probinput = array![0.1, 0.5, 0.9];
let logit_result = logit_array(&probinput.view());
assert_relative_eq!(logit_result[1], 0.0, epsilon = 1e-10);
assert!(logit_result[0] < 0.0);
assert!(logit_result[2] > 0.0);
}
}