use super::super::Tensor;
/// Smallest non-zero magnitude an `f32` may have before `stabilize_f32`
/// pushes it out to this value (and `is_stable_f32` flags it).
#[allow(dead_code)]
pub const STABILITY_EPSILON_F32: f32 = 1.0e-7;
/// Smallest non-zero magnitude an `f64` may have before `stabilize_f64`
/// pushes it out to this value (and `is_stable_f64` flags it).
#[allow(dead_code)]
pub const STABILITY_EPSILON_F64: f64 = 1.0e-15;
/// Largest magnitude an `f32` may have before `stabilize_f32` clamps it
/// to this value (and `is_stable_f32` flags it).
#[allow(dead_code)]
pub const MAX_SAFE_VALUE_F32: f32 = 1.0e30;
/// Largest magnitude an `f64` may have before `stabilize_f64` clamps it
/// to this value (and `is_stable_f64` flags it).
#[allow(dead_code)]
pub const MAX_SAFE_VALUE_F64: f64 = 1.0e300;
/// Returns `true` when `x` is a finite `f32` whose magnitude lies in the
/// "safe" range.
///
/// A value is stable when it is finite and `|x| <= MAX_SAFE_VALUE_F32`, and
/// it is either exactly zero (`+0.0` or `-0.0`) or `|x| >= STABILITY_EPSILON_F32`.
///
/// Both bounds are inclusive so that every value produced by
/// `stabilize_f32`, which clamps to exactly `±MAX_SAFE_VALUE_F32` or
/// `±STABILITY_EPSILON_F32`, is reported as stable.
#[allow(dead_code)]
pub fn is_stable_f32(x: f32) -> bool {
    let magnitude = x.abs();
    x.is_finite()
        && magnitude <= MAX_SAFE_VALUE_F32
        && (x == 0.0 || magnitude >= STABILITY_EPSILON_F32)
}
/// Returns `true` when `x` is a finite `f64` whose magnitude lies in the
/// "safe" range.
///
/// A value is stable when it is finite and `|x| <= MAX_SAFE_VALUE_F64`, and
/// it is either exactly zero (`+0.0` or `-0.0`) or `|x| >= STABILITY_EPSILON_F64`.
///
/// Both bounds are inclusive so that every value produced by
/// `stabilize_f64`, which clamps to exactly `±MAX_SAFE_VALUE_F64` or
/// `±STABILITY_EPSILON_F64`, is reported as stable.
#[allow(dead_code)]
pub fn is_stable_f64(x: f64) -> bool {
    let magnitude = x.abs();
    x.is_finite()
        && magnitude <= MAX_SAFE_VALUE_F64
        && (x == 0.0 || magnitude >= STABILITY_EPSILON_F64)
}
/// Maps an `f32` into the safe numeric range.
///
/// - NaN and ±infinity become `0.0`. Note that infinities are *not* clamped
///   to `±MAX_SAFE_VALUE_F32`.
/// - Finite values with `|x| > MAX_SAFE_VALUE_F32` are clamped to
///   `±MAX_SAFE_VALUE_F32`, keeping their sign.
/// - Non-zero values with `|x| < STABILITY_EPSILON_F32` are pushed out to
///   `±STABILITY_EPSILON_F32`, keeping their sign.
/// - Everything else, including exact zero, is returned unchanged.
#[allow(dead_code)]
pub fn stabilize_f32(x: f32) -> f32 {
    let magnitude = x.abs();
    if !x.is_finite() {
        0.0
    } else if magnitude > MAX_SAFE_VALUE_F32 {
        MAX_SAFE_VALUE_F32.copysign(x)
    } else if x != 0.0 && magnitude < STABILITY_EPSILON_F32 {
        // Too close to zero: nudge away from it without flipping the sign.
        STABILITY_EPSILON_F32.copysign(x)
    } else {
        x
    }
}
/// Maps an `f64` into the safe numeric range.
///
/// - NaN and ±infinity become `0.0`. Note that infinities are *not* clamped
///   to `±MAX_SAFE_VALUE_F64`.
/// - Finite values with `|x| > MAX_SAFE_VALUE_F64` are clamped to
///   `±MAX_SAFE_VALUE_F64`, keeping their sign.
/// - Non-zero values with `|x| < STABILITY_EPSILON_F64` are pushed out to
///   `±STABILITY_EPSILON_F64`, keeping their sign.
/// - Everything else, including exact zero, is returned unchanged.
#[allow(dead_code)]
pub fn stabilize_f64(x: f64) -> f64 {
    let magnitude = x.abs();
    if !x.is_finite() {
        0.0
    } else if magnitude > MAX_SAFE_VALUE_F64 {
        MAX_SAFE_VALUE_F64.copysign(x)
    } else if x != 0.0 && magnitude < STABILITY_EPSILON_F64 {
        // Too close to zero: nudge away from it without flipping the sign.
        STABILITY_EPSILON_F64.copysign(x)
    } else {
        x
    }
}
impl Tensor {
    /// Returns `true` when two shapes are compatible under trailing-dimension
    /// broadcasting (the same rule NumPy uses).
    ///
    /// The shapes are aligned from their last dimension. Every aligned pair of
    /// dimensions must either be equal or have one side equal to `1`.
    /// Dimensions that exist only in the longer shape are implicitly paired
    /// with `1`, so they never cause a mismatch. Either shape may be empty.
    ///
    /// Zero-sized dimensions follow the same rule: `0` is compatible with `0`
    /// or `1`, but not with any other size.
    pub fn shapes_are_broadcastable(shape1: &[usize], shape2: &[usize]) -> bool {
        // Only the common trailing suffix can conflict. `zip` stops at the
        // shorter shape, which is exactly the set of pairs that are not
        // matched against an implicit 1. No padded copies are needed.
        shape1
            .iter()
            .rev()
            .zip(shape2.iter().rev())
            .all(|(&d1, &d2)| d1 == d2 || d1 == 1 || d2 == 1)
    }
}