use scirs2_core::ndarray::{Array1, Array2};
/// Machine epsilon for `f32`: gap between 1.0 and the next representable value.
pub const F32_EPS: f32 = 1.192_092_9e-7;
/// Smallest positive *normal* `f32` (subnormals go lower but lose precision).
pub const F32_MIN_POSITIVE: f32 = 1.175_494_4e-38;
/// Largest finite `f32`; used as a saturation value instead of +inf.
pub const F32_MAX: f32 = 3.402_823_5e38;
/// Conservative lower bound for exponents: `exp(x)` underflows toward 0 below this.
pub const LOG_MIN: f32 = -87.0;
/// Conservative upper bound for exponents: `exp(x)` overflows `f32` above ~88.7.
pub const LOG_MAX: f32 = 88.0;
/// Generic small tolerance used in divide-by-zero and norm guards.
pub const EPS: f32 = 1e-8;
/// Computes `ln(sum(exp(x_i)))` with the max-shift trick so large magnitudes
/// neither overflow nor underflow.
///
/// Returns `NEG_INFINITY` for an empty slice. An infinite maximum is
/// propagated directly: all `-inf` inputs yield `-inf`, and any `+inf`
/// input yields `+inf` (the previous code produced NaN here, because
/// `(inf - inf).exp()` is NaN).
pub fn log_sum_exp(x: &[f32]) -> f32 {
    if x.is_empty() {
        return f32::NEG_INFINITY;
    }
    let max_val = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // Shifting by an infinite max is ill-defined; short-circuit instead.
    if max_val.is_infinite() {
        return max_val;
    }
    let sum: f32 = x.iter().map(|&v| (v - max_val).exp()).sum();
    max_val + sum.ln()
}
/// Computes `ln(exp(a) + exp(b))` without overflow.
///
/// Anchors on the larger argument and uses `ln_1p` for better accuracy when
/// the exponential term is small. Handles infinities explicitly: both `-inf`
/// yields `-inf` (the previous code produced NaN via `(-inf) - (-inf)`), and
/// a `+inf` argument dominates.
#[inline]
pub fn log_add_exp(a: f32, b: f32) -> f32 {
    let (hi, lo) = if a >= b { (a, b) } else { (b, a) };
    if hi.is_infinite() {
        // +inf dominates the sum; both -inf means the sum is exp(-inf) + exp(-inf) = 0.
        return hi;
    }
    hi + (lo - hi).exp().ln_1p()
}
/// Numerically stable log-softmax: shifts by the maximum before
/// exponentiating so no element overflows.
pub fn log_softmax_stable(x: &Array1<f32>) -> Array1<f32> {
    let peak = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let centered = x.mapv(|v| v - peak);
    // log of the partition function over the shifted values.
    let normalizer = centered.mapv(|v| v.exp()).sum().ln();
    centered.mapv(|v| v - normalizer)
}
/// Numerically stable softmax. Falls back to a uniform distribution when all
/// exponentials underflow (or the sum is otherwise non-positive).
pub fn softmax_stable(x: &Array1<f32>) -> Array1<f32> {
    let peak = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let weights = x.mapv(|v| (v - peak).exp());
    let total: f32 = weights.sum();
    if total > 0.0 {
        weights / total
    } else {
        // Degenerate case: spread mass evenly.
        Array1::from_elem(x.len(), 1.0 / x.len() as f32)
    }
}
/// Exponential that saturates instead of overflowing/underflowing:
/// inputs above `LOG_MAX` return `F32_MAX`, inputs below `LOG_MIN` return 0.
#[inline]
pub fn safe_exp(x: f32) -> f32 {
    match x {
        v if v > LOG_MAX => F32_MAX,
        v if v < LOG_MIN => 0.0,
        v => v.exp(),
    }
}
/// Natural log clamped to `LOG_MIN`; non-positive inputs return `LOG_MIN`
/// instead of `-inf`/NaN.
#[inline]
pub fn safe_ln(x: f32) -> f32 {
    if x > 0.0 {
        // f32::max also maps a NaN from ln() onto LOG_MIN.
        x.ln().max(LOG_MIN)
    } else {
        LOG_MIN
    }
}
/// Base-10 log; non-positive inputs return the base-10 equivalent of
/// `LOG_MIN` instead of `-inf`/NaN.
#[inline]
pub fn safe_log10(x: f32) -> f32 {
    if x <= 0.0 {
        // ln -> log10 change of base.
        return LOG_MIN / std::f32::consts::LN_10;
    }
    x.log10()
}
/// Clamps `x` into `[LOG_MIN, LOG_MAX]` so a following `exp()` stays finite.
#[inline]
pub fn clamp_for_exp(x: f32) -> f32 {
    // Explicit branch form of clamp; NaN falls through unchanged, as clamp does.
    if x < LOG_MIN {
        LOG_MIN
    } else if x > LOG_MAX {
        LOG_MAX
    } else {
        x
    }
}
/// Division that saturates to `±F32_MAX` (signed by the numerator) when the
/// denominator is effectively zero.
#[inline]
pub fn safe_div(num: f32, denom: f32) -> f32 {
    if denom.abs() < EPS {
        // Near-zero denominator: return a signed "infinity" stand-in.
        return if num >= 0.0 { F32_MAX } else { -F32_MAX };
    }
    num / denom
}
/// L2-normalizes `x`; a (near-)zero vector maps to the zero vector instead
/// of dividing by zero.
pub fn safe_normalize(x: &Array1<f32>) -> Array1<f32> {
    let magnitude = x.iter().fold(0.0f32, |acc, v| acc + v * v).sqrt();
    if magnitude < EPS {
        return Array1::zeros(x.len());
    }
    x / magnitude
}
/// L2-normalizes `x`, dividing by at least `min_norm` so small vectors are
/// scaled gently rather than blown up.
pub fn l2_normalize(x: &Array1<f32>, min_norm: f32) -> Array1<f32> {
    let squared: f32 = x.iter().map(|v| v * v).sum();
    x / squared.sqrt().max(min_norm)
}
/// Kahan (compensated) summation accumulator: tracks a running correction
/// term so long sums of small `f32` values lose far less precision than a
/// naive running total.
#[derive(Debug, Clone, Default)]
pub struct KahanSum {
    // Running total.
    total: f32,
    // Low-order bits lost by the last additions.
    carry: f32,
}
impl KahanSum {
    /// Empty accumulator (sum = 0).
    pub fn new() -> Self {
        Self::default()
    }
    /// Accumulator seeded with `initial`.
    pub fn with_value(initial: f32) -> Self {
        Self {
            total: initial,
            carry: 0.0,
        }
    }
    /// Adds `value`, folding previously lost low-order bits back in.
    #[inline]
    pub fn add(&mut self, value: f32) {
        let adjusted = value - self.carry;
        let updated = self.total + adjusted;
        // (updated - total) recovers what was actually added; the difference
        // from `adjusted` is the new rounding error.
        self.carry = (updated - self.total) - adjusted;
        self.total = updated;
    }
    /// Current compensated sum.
    pub fn sum(&self) -> f32 {
        self.total
    }
    /// Clears the accumulator back to zero.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
/// Sums a slice with Kahan compensation; more accurate than `iter().sum()`
/// for long sequences.
pub fn kahan_sum(values: &[f32]) -> f32 {
    values
        .iter()
        .fold(KahanSum::new(), |mut acc, &v| {
            acc.add(v);
            acc
        })
        .sum()
}
/// Compensated mean of a slice; an empty slice yields 0.0.
pub fn kahan_mean(values: &[f32]) -> f32 {
    match values.len() {
        0 => 0.0,
        n => kahan_sum(values) / n as f32,
    }
}
/// Single-pass running mean/variance via Welford's algorithm — numerically
/// stable compared with the naive sum-of-squares formula.
#[derive(Debug, Clone, Default)]
pub struct WelfordVariance {
    // Number of samples seen so far.
    n: usize,
    // Running mean.
    mu: f32,
    // Sum of squared deviations from the running mean (a.k.a. M2).
    sum_sq_dev: f32,
}
impl WelfordVariance {
    /// Fresh accumulator with no samples.
    pub fn new() -> Self {
        Self::default()
    }
    /// Folds one sample into the running statistics.
    pub fn add(&mut self, value: f32) {
        self.n += 1;
        let delta_pre = value - self.mu;
        self.mu += delta_pre / self.n as f32;
        // Uses the deviation both before and after the mean update (Welford).
        self.sum_sq_dev += delta_pre * (value - self.mu);
    }
    /// Running mean (0.0 before any sample).
    pub fn mean(&self) -> f32 {
        self.mu
    }
    /// Sample (Bessel-corrected, n-1) variance; 0.0 with fewer than 2 samples.
    pub fn variance(&self) -> f32 {
        match self.n {
            0 | 1 => 0.0,
            n => self.sum_sq_dev / (n - 1) as f32,
        }
    }
    /// Population (divide-by-n) variance; 0.0 with no samples.
    pub fn variance_population(&self) -> f32 {
        if self.n == 0 {
            0.0
        } else {
            self.sum_sq_dev / self.n as f32
        }
    }
    /// Sample standard deviation.
    pub fn std(&self) -> f32 {
        self.variance().sqrt()
    }
    /// Number of samples folded in so far.
    pub fn count(&self) -> usize {
        self.n
    }
    /// Clears all state back to empty.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
/// Matrix exponential `exp(A)` via a Padé-style rational approximation with
/// scaling-and-squaring: `A` is scaled down by `2^s`, the rational
/// approximation `(V - U)^{-1}(V + U)` is formed (U = odd-power terms,
/// V = even-power terms), and the result is squared `s` times.
///
/// `order` is the number of series terms; it is clamped to 6, the highest
/// order with tabulated coefficients (previously, order > 6 panicked with an
/// out-of-bounds index into the coefficient tables).
///
/// # Panics
/// Panics if `a` is not square.
pub fn matrix_exp_pade(a: &Array2<f32>, order: usize) -> Array2<f32> {
    let n = a.shape()[0];
    assert_eq!(a.shape()[1], n, "Matrix must be square");
    // Clamp so the `c_u[k]` / `c_v[k]` indexing below cannot go out of bounds.
    let order = order.min(6);
    // Scale down so the truncated series converges well; squaring undoes it.
    let norm: f32 = a.iter().map(|x| x.abs()).sum::<f32>();
    let s = if norm > 0.0 {
        (norm.log2() as i32).max(0) as u32
    } else {
        0
    };
    let scale = 2.0f32.powi(-(s as i32));
    let a_scaled = a.mapv(|x| x * scale);
    // U collects odd-power terms and therefore has NO constant term: it must
    // start at zero. (Starting it at the identity — as the previous version
    // did — cancels the constant term of V - U, so even exp(0) came out wrong.)
    let mut u: Array2<f32> = Array2::zeros((n, n));
    let mut v: Array2<f32> = Array2::eye(n);
    let (c_u, c_v) = pade_coefficients(order);
    let mut a_power = Array2::eye(n);
    for k in 1..=order {
        a_power = a_power.dot(&a_scaled);
        if k % 2 == 1 {
            u = &u + &a_power.mapv(|x| x * c_u[k]);
        } else {
            v = &v + &a_power.mapv(|x| x * c_v[k]);
        }
    }
    // exp(A_scaled) ≈ (V - U)^{-1} (V + U).
    let result = solve_linear(&(&v - &u), &(&v + &u));
    // Undo scaling: exp(A) = (exp(A / 2^s))^{2^s}.
    let mut exp_a = result;
    for _ in 0..s {
        exp_a = exp_a.dot(&exp_a);
    }
    exp_a
}
/// Tabulated series coefficients used by `matrix_exp_pade`: odd-power terms
/// in `c_u`, even-power terms in `c_v` (with `c_v[0] = 1` as the constant).
///
/// Both vectors have length `order + 1` so callers can index any `k <= order`
/// without panicking; coefficients are only tabulated up to order 6, and
/// entries beyond that stay zero (their terms simply drop out of the
/// expansion). The previous length clamp to 7 entries made callers with
/// `order > 6` panic on an out-of-bounds index.
fn pade_coefficients(order: usize) -> (Vec<f32>, Vec<f32>) {
    let mut c_u = vec![0.0f32; order + 1];
    let mut c_v = vec![0.0f32; order + 1];
    c_v[0] = 1.0;
    // Only the first 6 orders have tabulated values.
    let filled = order.min(6);
    if filled >= 1 {
        c_u[1] = 0.5;
    }
    if filled >= 2 {
        c_v[2] = 1.0 / 12.0;
    }
    if filled >= 3 {
        c_u[3] = 1.0 / 120.0;
    }
    if filled >= 4 {
        c_v[4] = 1.0 / 30240.0;
    }
    if filled >= 5 {
        c_u[5] = 1.0 / 1209600.0;
    }
    if filled >= 6 {
        c_v[6] = 1.0 / 17297280.0;
    }
    (c_u, c_v)
}
fn solve_linear(a: &Array2<f32>, b: &Array2<f32>) -> Array2<f32> {
let n = a.shape()[0];
let mut x = b.clone();
let identity = Array2::eye(n);
for _ in 0..10 {
let residual = b - &a.dot(&x);
let correction = residual.mapv(|v| v * 0.5);
x = &x + &identity.dot(&correction);
}
x
}
/// Zero-order-hold discretization of a full (non-diagonal) state-space pair:
/// `A_d = exp(A·dt)` and `B_d = (∫₀^dt exp(A·s) ds) · B`.
///
/// Both factors are computed from truncated Taylor series of matching order
/// (8 terms). The integral expands as `Σ_{k≥0} A^k dt^{k+1} / (k+1)!`; the
/// previous version truncated it at `(I + A·dt/2)·dt` (second order), which
/// was inconsistent with the 8th-order `A_d`.
pub fn zoh_discretize(a: &Array2<f32>, b: &Array2<f32>, dt: f32) -> (Array2<f32>, Array2<f32>) {
    const TERMS: usize = 8;
    let n = a.shape()[0];
    let a_dt = a.mapv(|x| x * dt);
    // A_d = exp(A dt), truncated Taylor series.
    let a_d = taylor_exp(&a_dt, TERMS);
    // Integral series: dt·I + A dt²/2! + A² dt³/3! + ...
    let mut integral: Array2<f32> = Array2::eye(n).mapv(|x| x * dt);
    let mut a_power: Array2<f32> = Array2::eye(n);
    let mut coeff = dt;
    for k in 1..=TERMS {
        a_power = a_power.dot(a);
        // coeff goes dt^{k+1}/(k+1)! term by term.
        coeff *= dt / (k + 1) as f32;
        integral = &integral + &a_power.mapv(|x| x * coeff);
    }
    let b_d = integral.dot(b);
    (a_d, b_d)
}
/// Selects how `discretize` maps continuous-time diagonal state-space
/// parameters (a, B) to their discrete-time counterparts.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DiscretizationMethod {
    /// Zero-order hold: per-element `exp(dt * a_i)` (see `zoh_discretize_diagonal`).
    Zoh,
    /// Bilinear (Tustin) transform: `(1 + dt*a/2) / (1 - dt*a/2)`.
    Bilinear,
    /// Explicit first-order Euler step: `1 + dt * a`.
    ForwardEuler,
}
/// Exact zero-order-hold discretization for a *diagonal* state matrix:
/// `a_bar_i = exp(dt·a_i)` and `b_bar_ij = ((exp(dt·a_i) - 1)/a_i) · b_ij`,
/// with the `a_i -> 0` limit handled as `dt`.
pub fn zoh_discretize_diagonal(
    a: &Array1<f32>,
    b: &Array2<f32>,
    dt: f32,
) -> (Array1<f32>, Array2<f32>) {
    let a_bar = a.mapv(|ai| (dt * ai).exp());
    let b_bar = Array2::from_shape_fn((a.len(), b.ncols()), |(i, j)| {
        let ai = a[i];
        let y = dt * ai;
        // exp_m1 keeps (e^y - 1) accurate for small y; below the threshold
        // use the analytic limit dt to avoid 0/0.
        let gain = if y.abs() < 1e-6 { dt } else { y.exp_m1() / ai };
        gain * b[[i, j]]
    });
    (a_bar, b_bar)
}
/// Bilinear (Tustin) discretization for a diagonal state matrix:
/// `a_bar = (1 + dt·a/2) / (1 - dt·a/2)`, `b_bar = (dt/2)(1 + a_bar)·b`.
pub fn bilinear_discretize(
    a: &Array1<f32>,
    b: &Array2<f32>,
    dt: f32,
) -> (Array1<f32>, Array2<f32>) {
    let h = dt * 0.5;
    let a_bar: Array1<f32> = a.mapv(|ai| (1.0 + h * ai) / (1.0 - h * ai));
    let b_bar = Array2::from_shape_fn((a.len(), b.ncols()), |(i, j)| {
        h * (1.0 + a_bar[i]) * b[[i, j]]
    });
    (a_bar, b_bar)
}
/// First-order (explicit Euler) discretization for a diagonal state matrix:
/// `a_bar = 1 + dt·a`, `b_bar = dt·b`.
pub fn forward_euler_discretize(
    a: &Array1<f32>,
    b: &Array2<f32>,
    dt: f32,
) -> (Array1<f32>, Array2<f32>) {
    let a_bar = a.mapv(|ai| dt * ai + 1.0);
    let b_bar = b.mapv(|bij| bij * dt);
    (a_bar, b_bar)
}
pub fn discretize(
method: DiscretizationMethod,
a: &Array1<f32>,
b: &Array2<f32>,
dt: f32,
) -> (Array1<f32>, Array2<f32>) {
match method {
DiscretizationMethod::Zoh => zoh_discretize_diagonal(a, b, dt),
DiscretizationMethod::Bilinear => bilinear_discretize(a, b, dt),
DiscretizationMethod::ForwardEuler => forward_euler_discretize(a, b, dt),
}
}
/// Matrix exponential via a truncated Taylor series:
/// `I + A + A²/2! + ... + A^terms/terms!`.
fn taylor_exp(a: &Array2<f32>, terms: usize) -> Array2<f32> {
    let n = a.shape()[0];
    let mut accum: Array2<f32> = Array2::eye(n);
    let mut power: Array2<f32> = Array2::eye(n);
    let mut fact = 1.0f32;
    for k in 1..=terms {
        fact *= k as f32;
        power = power.dot(a);
        accum = &accum + &power.mapv(|x| x / fact);
    }
    accum
}
/// Scales all gradients in place so their combined L2 norm does not exceed
/// `max_norm`. Returns the norm *before* clipping.
pub fn clip_grad_norm(gradients: &mut [Array1<f32>], max_norm: f32) -> f32 {
    let squared: f32 = gradients
        .iter()
        .map(|g| g.iter().map(|x| x * x).sum::<f32>())
        .sum::<f32>();
    let total_norm = squared.sqrt();
    // EPS guards against dividing by a zero norm.
    let scale = max_norm / (total_norm + EPS);
    if scale < 1.0 {
        for grad in gradients.iter_mut() {
            grad.mapv_inplace(|x| x * scale);
        }
    }
    total_norm
}
/// Clamps every gradient component into `[-max_value, max_value]` in place.
pub fn clip_grad_value(gradient: &mut Array1<f32>, max_value: f32) {
    for g in gradient.iter_mut() {
        *g = g.clamp(-max_value, max_value);
    }
}
/// Returns true if any element is NaN or infinite (i.e. not finite).
pub fn has_nan_inf(x: &Array1<f32>) -> bool {
    x.iter().any(|v| !v.is_finite())
}
/// Copies `x` with every NaN replaced by `default` (infinities pass through).
pub fn replace_nan(x: &Array1<f32>, default: f32) -> Array1<f32> {
    x.mapv(|v| match v.is_nan() {
        true => default,
        false => v,
    })
}
/// Copies `x` replacing NaN with `nan_value` and ±infinity with
/// ±`inf_value` (sign of the infinity preserved).
pub fn sanitize(x: &Array1<f32>, nan_value: f32, inf_value: f32) -> Array1<f32> {
    x.mapv(|v| {
        if v.is_nan() {
            return nan_value;
        }
        if v.is_infinite() {
            // For an infinite v, sign_positive distinguishes +inf from -inf.
            return if v.is_sign_positive() {
                inf_value
            } else {
                -inf_value
            };
        }
        v
    })
}
/// Clamps every element into `[min, max]`; NaN entries are replaced by the
/// interval midpoint.
pub fn clamp_to_valid(x: &Array1<f32>, min: f32, max: f32) -> Array1<f32> {
    let midpoint = 0.5 * (min + max);
    x.mapv(|v| if v.is_nan() { midpoint } else { v.clamp(min, max) })
}
// Tests cover numerical-stability paths (large magnitudes, zero vectors,
// NaN/Inf inputs) and cross-check the discretization variants against each
// other and against the exact diagonal matrix exponential.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_log_sum_exp() {
        // ln(e^1 + e^2 + e^3) ≈ 3.408
        let x = vec![1.0, 2.0, 3.0];
        let result = log_sum_exp(&x);
        assert!((result - 3.408).abs() < 0.01);
    }
    #[test]
    fn test_log_sum_exp_large() {
        // Naive exp(1000) would overflow f32; the max-shift keeps it finite.
        let x = vec![1000.0, 1001.0, 1002.0];
        let result = log_sum_exp(&x);
        assert!((result - 1002.408).abs() < 0.01);
    }
    #[test]
    fn test_log_add_exp() {
        let a = 2.0f32;
        let b = 3.0f32;
        let result = log_add_exp(a, b);
        // Compare against the direct (safe at this magnitude) computation.
        let expected = (a.exp() + b.exp()).ln();
        assert!((result - expected).abs() < 0.001);
    }
    #[test]
    fn test_softmax_stable() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = softmax_stable(&x);
        // Probabilities sum to 1 and preserve the input ordering.
        assert!((result.sum() - 1.0).abs() < 0.001);
        assert!(result[2] > result[1] && result[1] > result[0]);
    }
    #[test]
    fn test_softmax_large_values() {
        // Values this large overflow a naive softmax.
        let x = Array1::from_vec(vec![1000.0, 1001.0, 1002.0]);
        let result = softmax_stable(&x);
        assert!((result.sum() - 1.0).abs() < 0.001);
    }
    #[test]
    fn test_safe_exp() {
        // Saturates instead of overflowing/underflowing.
        assert!(safe_exp(100.0) < f32::INFINITY);
        assert!(safe_exp(100.0) > 0.0);
        assert!(safe_exp(-100.0) >= 0.0);
        assert!((safe_exp(0.0) - 1.0).abs() < 0.001);
        assert!((safe_exp(1.0) - std::f32::consts::E).abs() < 0.001);
    }
    #[test]
    fn test_safe_ln() {
        // Non-positive inputs must not produce -inf or NaN.
        assert!(safe_ln(0.0).is_finite());
        assert!(safe_ln(-1.0).is_finite());
        assert!((safe_ln(1.0) - 0.0).abs() < 0.001);
    }
    #[test]
    fn test_kahan_sum() {
        // 0.1 is not exactly representable; compensation keeps the sum tight.
        let values: Vec<f32> = (0..1000).map(|_| 0.1).collect();
        let result = kahan_sum(&values);
        assert!((result - 100.0).abs() < 0.001);
    }
    #[test]
    fn test_welford_variance() {
        let mut acc = WelfordVariance::new();
        for i in 1..=5 {
            acc.add(i as f32);
        }
        // {1..5}: mean 3, sample variance 2.5.
        assert!((acc.mean() - 3.0).abs() < 0.001);
        assert!((acc.variance() - 2.5).abs() < 0.001);
    }
    #[test]
    fn test_safe_normalize() {
        // Zero vector must map to zeros, not NaN from 0/0.
        let x = Array1::from_vec(vec![0.0, 0.0, 0.0]);
        let result = safe_normalize(&x);
        assert!(!has_nan_inf(&result));
        let x = Array1::from_vec(vec![3.0, 4.0]);
        let result = safe_normalize(&x);
        let norm: f32 = result.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }
    #[test]
    fn test_clip_grad_norm() {
        let mut grads = vec![
            Array1::from_vec(vec![3.0, 4.0]),
            Array1::from_vec(vec![5.0, 12.0]),
        ];
        // Combined norm sqrt(25 + 169) ≈ 13.93 exceeds the 5.0 cap.
        let norm = clip_grad_norm(&mut grads, 5.0);
        assert!((norm - 13.93).abs() < 0.1);
        let new_norm: f32 = grads
            .iter()
            .map(|g| g.iter().map(|x| x * x).sum::<f32>())
            .sum::<f32>()
            .sqrt();
        assert!((new_norm - 5.0).abs() < 0.1);
    }
    #[test]
    fn test_sanitize() {
        let x = Array1::from_vec(vec![1.0, f32::NAN, f32::INFINITY, -f32::INFINITY, 2.0]);
        let result = sanitize(&x, 0.0, 1e6);
        assert!(!has_nan_inf(&result));
        // Finite values pass through untouched.
        assert_eq!(result[0], 1.0);
        assert_eq!(result[1], 0.0);
        assert_eq!(result[4], 2.0);
    }
    #[test]
    fn test_taylor_exp_identity() {
        // exp(0) = I for the matrix series.
        let n = 3;
        let a = Array2::zeros((n, n));
        let result = taylor_exp(&a, 6);
        for i in 0..n {
            for j in 0..n {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((result[[i, j]] - expected).abs() < 0.001);
            }
        }
    }
    #[test]
    fn test_zoh_discretize() {
        // Diagonal A: a_d diagonal should approximate exp(dt * a_ii).
        let a = Array2::from_diag(&Array1::from_vec(vec![-1.0, -2.0]));
        let b: Array2<f32> = Array2::eye(2);
        let dt = 0.1;
        let (a_d, b_d) = zoh_discretize(&a, &b, dt);
        assert!(
            (a_d[[0, 0]] - 0.905).abs() < 0.1,
            "a_d[0,0] = {}",
            a_d[[0, 0]]
        );
        assert!(
            (a_d[[1, 1]] - 0.82).abs() < 0.1,
            "a_d[1,1] = {}",
            a_d[[1, 1]]
        );
        assert!(b_d[[0, 0]].abs() > 0.0, "b_d[0,0] = {}", b_d[[0, 0]]);
    }
    #[test]
    fn bilinear_stable_negative_eigenvalue() {
        // The bilinear transform maps left-half-plane eigenvalues inside the
        // unit circle for every step size tested.
        let a = Array1::from_vec(vec![-1.0f32, -0.5, -2.0]);
        let b = Array2::<f32>::ones((3, 1));
        for &dt in &[0.01f32, 0.1, 0.5, 1.0, 2.0] {
            let (a_bar, _) = bilinear_discretize(&a, &b, dt);
            for &x in a_bar.iter() {
                assert!(x.abs() < 1.0, "stability violated: a_bar={x} at dt={dt}");
            }
        }
    }
    #[test]
    fn bilinear_exact_at_zero_eigenvalue() {
        // a = 0: a_bar must be exactly 1 and b_bar must reduce to dt * b.
        let a = Array1::<f32>::zeros(2);
        let b = Array2::from_shape_vec((2, 1), vec![2.0f32, 3.0]).unwrap();
        let dt = 0.1;
        let (a_bar, b_bar) = bilinear_discretize(&a, &b, dt);
        for &x in a_bar.iter() {
            assert!((x - 1.0).abs() < 1e-6, "a_bar should be 1 for a=0, got {x}");
        }
        let expected_b = b.mapv(|x| dt * x);
        for (got, exp) in b_bar.iter().zip(expected_b.iter()) {
            assert!((got - exp).abs() < 1e-6, "b_bar mismatch: {got} vs {exp}");
        }
    }
    #[test]
    fn forward_euler_close_to_zoh_small_dt() {
        // First-order Euler agrees with exact ZOH to O(dt^2).
        let a = Array1::from_vec(vec![-1.0f32]);
        let b = Array2::<f32>::ones((1, 1));
        let dt = 1e-4_f32;
        let (a_fe, _) = forward_euler_discretize(&a, &b, dt);
        let (a_zoh, _) = zoh_discretize_diagonal(&a, &b, dt);
        let err = (a_fe[0] - a_zoh[0]).abs();
        assert!(
            err < 1e-7,
            "FE vs ZOH-diagonal error {err} too large for dt={dt}"
        );
    }
    #[test]
    fn discretize_zoh_matches_zoh_discretize_diagonal() {
        // The dispatcher must route Zoh to the diagonal ZOH routine.
        let a = Array1::from_vec(vec![-0.5f32, -1.0, -2.0]);
        let b = Array2::<f32>::ones((3, 2));
        let dt = 0.05;
        let (a1, b1) = zoh_discretize_diagonal(&a, &b, dt);
        let (a2, b2) = discretize(DiscretizationMethod::Zoh, &a, &b, dt);
        for (x, y) in a1.iter().zip(a2.iter()) {
            assert!((x - y).abs() < 1e-7, "ZOH mismatch: {x} vs {y}");
        }
        for (x, y) in b1.iter().zip(b2.iter()) {
            assert!((x - y).abs() < 1e-7, "ZOH B mismatch: {x} vs {y}");
        }
    }
    #[test]
    fn zoh_diagonal_exact_expm() {
        // For diagonal A, ZOH is exact: a_bar_i = exp(dt * a_i).
        let a = Array1::from_vec(vec![-1.0f32, -2.0, -0.5]);
        let b = Array2::<f32>::ones((3, 2));
        let dt = 0.1;
        let (a_bar, _) = zoh_discretize_diagonal(&a, &b, dt);
        for (i, (&ab, &ai)) in a_bar.iter().zip(a.iter()).enumerate() {
            let expected = (dt * ai).exp();
            assert!(
                (ab - expected).abs() < 1e-6,
                "ZOH-diagonal a_bar[{i}]={ab} expected {expected}"
            );
        }
    }
}