use super::SurvivalLocationScaleError;
use ndarray::Array1;
/// Numerically stable softplus, `ln(1 + e^x)`.
///
/// Evaluated as `max(x, 0) + ln(1 + e^{-|x|})`, so the exponential argument is
/// never positive and cannot overflow. NaN propagates, `+inf` maps to `+inf`,
/// and `-inf` maps to `0`.
#[inline]
pub(super) fn softplus(x: f64) -> f64 {
    if x.is_nan() {
        return f64::NAN;
    }
    if x.is_infinite() {
        return if x > 0.0 { f64::INFINITY } else { 0.0 };
    }
    // For x >= 0 this is x + ln1p(e^-x); for x < 0 it is 0 + ln1p(e^x).
    x.max(0.0) + (-x.abs()).exp().ln_1p()
}
/// Multiplies two values with "safe" semantics.
///
/// - If either factor is zero the result is `+0.0`, even when the other factor
///   is infinite or NaN.
/// - A product that overflows to `±inf` saturates to `f64::MAX` / `f64::MIN`.
/// - Otherwise the plain IEEE product is returned; NaN propagates.
#[inline]
pub(super) fn safe_product(lhs: f64, rhs: f64) -> f64 {
    if lhs == 0.0 || rhs == 0.0 {
        return 0.0;
    }
    let product = lhs * rhs;
    if !product.is_infinite() {
        return product;
    }
    // f64::MIN == -f64::MAX, so copying the sign picks the right bound.
    f64::MAX.copysign(product)
}
/// Adds two values, mapping `inf + (-inf)` to `0.0` instead of NaN.
///
/// Any other NaN outcome (i.e. a NaN operand) is returned as NaN.
#[inline]
pub(super) fn safe_sum2(a: f64, b: f64) -> f64 {
    let total = a + b;
    if !total.is_nan() {
        return total;
    }
    // A NaN sum with a zero operand means the other operand is the NaN.
    if a == 0.0 {
        return b;
    }
    if b == 0.0 {
        return a;
    }
    // Two infinities only sum to NaN when their signs are opposite.
    if a.is_infinite() && b.is_infinite() {
        return 0.0;
    }
    total
}
/// Left-to-right three-term sum using `safe_sum2` for each addition.
#[inline]
pub(super) fn safe_sum3(a: f64, b: f64, c: f64) -> f64 {
    let partial = safe_sum2(a, b);
    safe_sum2(partial, c)
}
/// Multiplies three factors with `safe_product` semantics while avoiding
/// spurious intermediate overflow or flush-to-zero.
///
/// The factors are ordered by magnitude (`total_cmp` on `abs()`, so any NaN
/// sorts last).
///
/// - All factors finite: the smallest and largest magnitudes are multiplied
///   first and the middle one last. That intermediate cannot overflow or flush
///   to zero unless the full product would too. For example,
///   `1e-200 * 1e-200 * 1e200` yields `1e-200` rather than collapsing to `0`
///   through `1e-200 * 1e-200`.
/// - Some factor infinite or NaN: the two smaller factors are combined first
///   and the non-finite one last. This keeps the previous behaviour, where an
///   infinite factor saturates the result to `±f64::MAX`.
#[inline]
pub(super) fn safe_product3(a: f64, b: f64, c: f64) -> f64 {
    let mut factors = [a, b, c];
    factors.sort_by(|lhs, rhs| lhs.abs().total_cmp(&rhs.abs()));
    let [small, mid, large] = factors;
    if large.is_finite() {
        // Pair the extremes so their magnitudes offset each other before the
        // middle factor is applied.
        safe_product(safe_product(small, large), mid)
    } else {
        safe_product(safe_product(small, mid), large)
    }
}
/// Element-wise product of two equally long vectors using `safe_product`.
///
/// # Errors
/// - A length mismatch is reported through `crate::bail_dim_sls!`, presumably
///   as an early-returned dimension error (confirm against the macro).
/// - `SurvivalLocationScaleError::NumericalFailure` is returned if any entry of
///   the result is NaN.
pub(super) fn safe_hadamard_product(
    lhs: &Array1<f64>,
    rhs: &Array1<f64>,
) -> Result<Array1<f64>, SurvivalLocationScaleError> {
    let (lhs_len, rhs_len) = (lhs.len(), rhs.len());
    if lhs_len != rhs_len {
        crate::bail_dim_sls!(
            "safe_hadamard_product length mismatch: lhs has {}, rhs has {}",
            lhs_len,
            rhs_len
        );
    }
    let product: Array1<f64> = lhs
        .iter()
        .zip(rhs.iter())
        .map(|(&left, &right)| safe_product(left, right))
        .collect();
    let has_nan = product.iter().any(|entry| entry.is_nan());
    if has_nan {
        Err(SurvivalLocationScaleError::NumericalFailure {
            reason: "safe_hadamard_product produced NaN values".to_string(),
        })
    } else {
        Ok(product)
    }
}
/// Element-wise `a * b + c * d` over four equally long vectors.
///
/// Each product is formed with `safe_product`, and the two products are
/// combined with `safe_sum2`.
///
/// # Errors
/// - A length mismatch is reported through `crate::bail_dim_sls!`, presumably
///   as an early-returned dimension error (confirm against the macro).
/// - `SurvivalLocationScaleError::NumericalFailure` is returned if any entry of
///   the result is NaN.
pub(super) fn safe_linear_combo2_arrays(
    a: &Array1<f64>,
    b: &Array1<f64>,
    c: &Array1<f64>,
    d: &Array1<f64>,
) -> Result<Array1<f64>, SurvivalLocationScaleError> {
    let n = a.len();
    let lengths_agree = [b.len(), c.len(), d.len()].iter().all(|&len| len == n);
    if !lengths_agree {
        crate::bail_dim_sls!(
            "safe_linear_combo2_arrays length mismatch: a={}, b={}, c={}, d={}",
            a.len(),
            b.len(),
            c.len(),
            d.len()
        );
    }
    let combined: Array1<f64> = a
        .iter()
        .zip(b.iter())
        .zip(c.iter().zip(d.iter()))
        .map(|((&ai, &bi), (&ci, &di))| {
            safe_sum2(safe_product(ai, bi), safe_product(ci, di))
        })
        .collect();
    if !combined.iter().any(|entry| entry.is_nan()) {
        return Ok(combined);
    }
    Err(SurvivalLocationScaleError::NumericalFailure {
        reason: "safe_linear_combo2_arrays produced NaN values".to_string(),
    })
}
/// Returns a copy of `weights` with non-finite entries replaced.
///
/// - `+inf` becomes `f64::MAX`.
/// - `-inf` becomes `f64::MIN`.
/// - NaN becomes `0.0`.
/// - Finite values are kept unchanged.
pub(super) fn sanitize_survival_weight_vector(weights: &Array1<f64>) -> Array1<f64> {
    weights
        .iter()
        .map(|&weight| {
            if weight.is_nan() {
                0.0
            } else {
                // Finite values are untouched; infinities pin to the bounds.
                weight.clamp(f64::MIN, f64::MAX)
            }
        })
        .collect()
}
/// A difference `lhs - rhs` together with an estimate of the floating-point
/// roundoff it may carry, as produced by `compensated_difference`.
///
/// `Debug` is derived so values can be logged and shown in assertion failures.
#[derive(Clone, Copy, Debug)]
pub(super) struct StableDifference {
    /// The computed difference. May be infinite or NaN for degenerate operands.
    pub(super) value: f64,
    /// Estimated absolute rounding error in `value`.
    pub(super) roundoff_slack: f64,
    /// Magnitude of the larger operand, `max(|lhs|, |rhs|)`, when built by
    /// `compensated_difference`.
    pub(super) operand_scale: f64,
}
/// Error-free subtraction: Knuth's TwoSum applied to `lhs + (-rhs)`.
///
/// Returns `(rounded, error)`, where `rounded` is the floating-point result of
/// `lhs - rhs`. For finite, non-overflowing inputs, `rounded + error` equals
/// the exact real difference.
#[inline]
fn two_diff(lhs: f64, rhs: f64) -> (f64, f64) {
    let rounded = lhs - rhs;
    // The portion of `-rhs` that actually made it into `rounded`.
    let absorbed = rounded - lhs;
    // Combine the rounding lost from `lhs` and from `rhs`.
    let error = (lhs - (rounded - absorbed)) - (rhs + absorbed);
    (rounded, error)
}
/// Computes `lhs - rhs` and an estimate of its roundoff error.
///
/// `operand_scale` is always `max(|lhs|, |rhs|)`. The other two fields depend
/// on the operands:
/// - If either operand is NaN, `value` is NaN and `roundoff_slack` is `0`.
/// - If an operand is infinite, `value` comes from `safe_sum2(lhs, -rhs)`.
///   When that collapses to `0` (same-signed infinities), the slack is set to
///   `operand_scale`, i.e. infinite, to flag a fully uncertain result.
///   Otherwise the slack is `0`.
/// - If the finite difference overflows, `value` is the infinite result and
///   the slack is `0`.
/// - Otherwise `value = high + low` from `two_diff`, and the slack is
///   `|low| + 128 * EPSILON * max(operand_scale, |value|)`. The factor 128
///   looks like a conservative safety margin — NOTE(review): confirm intent.
#[inline]
pub(super) fn compensated_difference(lhs: f64, rhs: f64) -> StableDifference {
    let operand_scale = lhs.abs().max(rhs.abs());
    let build = |value: f64, roundoff_slack: f64| StableDifference {
        value,
        roundoff_slack,
        operand_scale,
    };

    if lhs.is_nan() || rhs.is_nan() {
        return build(f64::NAN, 0.0);
    }

    // NaN is excluded above, so "not finite" here means "infinite".
    if lhs.is_infinite() || rhs.is_infinite() {
        let diff = safe_sum2(lhs, -rhs);
        let collapsed = diff == 0.0 && operand_scale > 0.0;
        return build(diff, if collapsed { operand_scale } else { 0.0 });
    }

    let (high, low) = two_diff(lhs, rhs);
    if high.is_finite() {
        let value = high + low;
        let slack = low.abs() + 128.0 * f64::EPSILON * operand_scale.max(value.abs());
        build(value, slack)
    } else {
        build(high, 0.0)
    }
}