use scirs2_core::ndarray::{Array1, Array2, Zip};
/// Element-wise fused multiply-add: `result[i] = a[i] * b[i] + c[i]`.
///
/// Each element is computed with [`f32::mul_add`], so the multiply and the
/// add are performed with a single rounding.
///
/// # Panics
/// Debug builds assert that all three inputs have the same length; ndarray's
/// `Zip` additionally panics on mismatched shapes.
#[inline]
pub fn fused_mul_add(a: &Array1<f32>, b: &Array1<f32>, c: &Array1<f32>) -> Array1<f32> {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), c.len());
    Zip::from(a)
        .and(b)
        .and(c)
        .map_collect(|&lhs, &rhs, &addend| lhs.mul_add(rhs, addend))
}
/// Scales `a` by `scalar` and adds `b`: `result[i] = scalar * a[i] + b[i]`,
/// computed with a single rounding via [`f32::mul_add`].
///
/// # Panics
/// Debug builds assert `a.len() == b.len()`; ndarray's `Zip` also panics on
/// mismatched shapes.
#[inline]
pub fn fused_mul_scalar_add(a: &Array1<f32>, scalar: f32, b: &Array1<f32>) -> Array1<f32> {
    debug_assert_eq!(a.len(), b.len());
    Zip::from(a)
        .and(b)
        .map_collect(|&ai, &bi| scalar.mul_add(ai, bi))
}
/// Element-wise natural exponential `e^a[i]`, returned as a new array.
#[inline]
pub fn exp_array(a: &Array1<f32>) -> Array1<f32> {
    a.map(|&value| value.exp())
}
/// Element-wise natural logarithm `ln(a[i])`, returned as a new array.
///
/// Follows [`f32::ln`]: negative inputs yield NaN and `0.0` yields `-inf`.
#[inline]
pub fn ln_array(a: &Array1<f32>) -> Array1<f32> {
    a.map(|&value| value.ln())
}
#[inline]
pub fn silu(x: &Array1<f32>) -> Array1<f32> {
x.mapv(|v| {
let sigmoid = 1.0 / (1.0 + (-v).exp());
v * sigmoid
})
}
#[inline]
pub fn sigmoid(x: &Array1<f32>) -> Array1<f32> {
x.mapv(|v| 1.0 / (1.0 + (-v).exp()))
}
/// Element-wise hyperbolic tangent, returned as a new array.
#[inline]
pub fn tanh_array(x: &Array1<f32>) -> Array1<f32> {
    x.map(|&value| value.tanh())
}
/// Rectified linear unit: `max(x[i], 0)` element-wise.
///
/// Uses [`f32::max`], which returns the non-NaN operand, so NaN inputs map
/// to `0.0` rather than propagating.
#[inline]
pub fn relu(x: &Array1<f32>) -> Array1<f32> {
    x.map(|&v| f32::max(v, 0.0))
}
/// GELU activation using the tanh approximation:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
#[inline]
pub fn gelu(x: &Array1<f32>) -> Array1<f32> {
    // sqrt(2 / pi), and the cubic coefficient of the tanh approximation.
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    const COEFF: f32 = 0.044715;
    x.map(|&v| {
        let cubic_term = COEFF * v * v * v;
        let scaled = SQRT_2_OVER_PI * (v + cubic_term);
        let half_v = 0.5 * v;
        half_v * (1.0 + scaled.tanh())
    })
}
/// Element-wise linear recurrence step:
/// `h_new[i] = a[i] * h_prev[i] + b[i] * x[i]`.
///
/// `b[i] * x[i]` is rounded first; the `a[i] * h_prev[i]` product and the
/// final addition are fused via [`f32::mul_add`].
///
/// # Panics
/// Debug builds assert all four inputs share one length; ndarray's `Zip`
/// also panics on mismatched shapes.
#[inline]
pub fn ssm_state_update(
    h_prev: &Array1<f32>,
    a: &Array1<f32>,
    b: &Array1<f32>,
    x: &Array1<f32>,
) -> Array1<f32> {
    debug_assert_eq!(h_prev.len(), a.len());
    debug_assert_eq!(h_prev.len(), b.len());
    debug_assert_eq!(h_prev.len(), x.len());
    Zip::from(h_prev)
        .and(a)
        .and(b)
        .and(x)
        .map_collect(|&prev, &a_i, &b_i, &x_i| a_i.mul_add(prev, b_i * x_i))
}
/// Recurrence step with a diagonal transition and a scalar input:
/// `new_state[i] = a_diag[i] * state[i] + b[i] * x_scalar`.
///
/// `b[i] * x_scalar` is rounded first; the remaining multiply-add is fused
/// via [`f32::mul_add`].
///
/// # Panics
/// Debug builds assert `state`, `a_diag` and `b` share one length; ndarray's
/// `Zip` also panics on mismatched shapes.
#[inline]
pub fn diagonal_ssm_update(
    state: &Array1<f32>,
    a_diag: &Array1<f32>,
    b: &Array1<f32>,
    x_scalar: f32,
) -> Array1<f32> {
    debug_assert_eq!(state.len(), a_diag.len());
    debug_assert_eq!(state.len(), b.len());
    Zip::from(state)
        .and(a_diag)
        .and(b)
        .map_collect(|&current, &diag, &gain| diag.mul_add(current, gain * x_scalar))
}
#[inline]
pub fn matvec(matrix: &Array2<f32>, vector: &Array1<f32>) -> Array1<f32> {
let (m, n) = matrix.dim();
debug_assert_eq!(vector.len(), n);
let mut result = Array1::zeros(m);
for i in 0..m {
let row = matrix.row(i);
let mut sum = 0.0;
for j in 0..n {
sum = row[j].mul_add(vector[j], sum);
}
result[i] = sum;
}
result
}
/// Numerically-stable softmax: `exp(x[i] - max(x)) / sum_j exp(x[j] - max(x))`.
///
/// Subtracting the maximum before exponentiating keeps `exp` from
/// overflowing. NaN elements are skipped when finding the maximum (via
/// [`f32::max`]) but still propagate through the sum. An empty input yields
/// an empty output.
#[inline]
pub fn softmax(x: &Array1<f32>) -> Array1<f32> {
    let peak = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut weights = x.mapv(|v| (v - peak).exp());
    let total: f32 = weights.sum();
    // Normalise in place rather than allocating a second array.
    weights.mapv_inplace(|w| w / total);
    weights
}
/// Layer normalisation over the whole vector followed by an affine map:
/// `result[i] = gamma[i] * ((x[i] - mean) / sqrt(var + eps)) + beta[i]`.
///
/// `var` is the population (divide-by-N) variance of `x`. `eps` is added
/// inside the square root to avoid dividing by zero for constant inputs. An
/// empty `x` uses a mean of `0.0` and produces an empty result.
///
/// # Panics
/// Debug builds assert `gamma` and `beta` match `x` in length; ndarray's
/// `Zip` also panics on mismatched shapes.
#[inline]
pub fn layer_norm(
    x: &Array1<f32>,
    gamma: &Array1<f32>,
    beta: &Array1<f32>,
    eps: f32,
) -> Array1<f32> {
    debug_assert_eq!(x.len(), gamma.len());
    debug_assert_eq!(x.len(), beta.len());
    let count = x.len() as f32;
    let mean = x.mean().unwrap_or(0.0);
    let sum_sq_dev = x.iter().map(|&v| (v - mean).powi(2)).sum::<f32>();
    let std_dev = (sum_sq_dev / count + eps).sqrt();
    Zip::from(x)
        .and(gamma)
        .and(beta)
        .map_collect(|&xi, &scale, &shift| scale * ((xi - mean) / std_dev) + shift)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `a` and `b` differ by strictly less than `epsilon`.
    fn approx_eq(a: f32, b: f32, epsilon: f32) -> bool {
        let delta = (a - b).abs();
        delta < epsilon
    }

    /// Returns `true` when both arrays have the same length and every
    /// corresponding pair of elements satisfies [`approx_eq`].
    fn array_approx_eq(a: &Array1<f32>, b: &Array1<f32>, epsilon: f32) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter()
            .zip(b.iter())
            .all(|(&lhs, &rhs)| approx_eq(lhs, rhs, epsilon))
    }

    #[test]
    fn test_fused_mul_add() {
        let got = fused_mul_add(
            &Array1::from_vec(vec![1.0, 2.0, 3.0]),
            &Array1::from_vec(vec![2.0, 3.0, 4.0]),
            &Array1::from_vec(vec![1.0, 1.0, 1.0]),
        );
        let want = Array1::from_vec(vec![3.0, 7.0, 13.0]);
        assert!(array_approx_eq(&got, &want, 1e-6));
    }

    #[test]
    fn test_silu() {
        let got = silu(&Array1::from_vec(vec![0.0, 1.0, -1.0]));
        // (expected value, tolerance) per input element.
        let checks: [(f32, f32); 3] = [(0.0, 1e-5), (0.731, 1e-2), (-0.269, 1e-2)];
        for (idx, &(want, tol)) in checks.iter().enumerate() {
            assert!(approx_eq(got[idx], want, tol));
        }
    }

    #[test]
    fn test_relu() {
        let got = relu(&Array1::from_vec(vec![-1.0, 0.0, 1.0, 2.0]));
        let want = Array1::from_vec(vec![0.0, 0.0, 1.0, 2.0]);
        assert!(array_approx_eq(&got, &want, 1e-6));
    }

    #[test]
    fn test_softmax() {
        let probs = softmax(&Array1::from_vec(vec![1.0, 2.0, 3.0]));
        let total: f32 = probs.sum();
        assert!(approx_eq(total, 1.0, 1e-5));
        // Larger logits must map to strictly larger probabilities.
        assert!(probs[0] < probs[1]);
        assert!(probs[1] < probs[2]);
    }

    #[test]
    fn test_ssm_state_update() {
        let got = ssm_state_update(
            &Array1::from_vec(vec![1.0, 2.0, 3.0]),
            &Array1::from_vec(vec![0.9, 0.8, 0.7]),
            &Array1::from_vec(vec![0.1, 0.2, 0.3]),
            &Array1::from_vec(vec![1.0, 1.0, 1.0]),
        );
        let want = Array1::from_vec(vec![1.0, 1.8, 2.4]);
        assert!(array_approx_eq(&got, &want, 1e-5));
    }

    #[test]
    fn test_diagonal_ssm_update() {
        let got = diagonal_ssm_update(
            &Array1::from_vec(vec![1.0, 2.0, 3.0]),
            &Array1::from_vec(vec![0.9, 0.8, 0.7]),
            &Array1::from_vec(vec![0.1, 0.2, 0.3]),
            2.0,
        );
        let want = Array1::from_vec(vec![1.1, 2.0, 2.7]);
        assert!(array_approx_eq(&got, &want, 1e-5));
    }

    #[test]
    fn test_matvec() {
        let matrix = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("Failed to create test matrix");
        let got = matvec(&matrix, &Array1::from_vec(vec![1.0, 2.0, 3.0]));
        let want = Array1::from_vec(vec![14.0, 32.0]);
        assert!(array_approx_eq(&got, &want, 1e-5));
    }

    #[test]
    fn test_layer_norm() {
        let got = layer_norm(
            &Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]),
            &Array1::from_vec(vec![1.0, 1.0, 1.0, 1.0]),
            &Array1::from_vec(vec![0.0, 0.0, 0.0, 0.0]),
            1e-5,
        );
        // With unit gamma and zero beta the normalised output is centred.
        let centre = got.mean().expect("Failed to compute mean");
        assert!(approx_eq(centre, 0.0, 1e-5));
    }
}