use super::FixedPoint;
use super::linalg::{ComputeStorage, upscale_to_compute, round_to_storage};
use crate::fixed_point::universal::fasc::stack_evaluator::compute::{
compute_add, compute_subtract, compute_multiply, compute_divide,
compute_negate, compute_is_zero,
sqrt_at_compute_tier, exp_at_compute_tier,
};
use crate::fixed_point::core_types::errors::OverflowDetected;
/// Returns `FixedPoint::ZERO` widened into the compute tier.
///
/// Serves as the initial accumulator for the compute-tier sums in this module.
#[inline]
fn compute_zero() -> ComputeStorage {
    let zero_raw = FixedPoint::ZERO.raw();
    upscale_to_compute(zero_raw)
}
/// Returns `FixedPoint::one()` widened into the compute tier.
///
/// Used as the numerator for reciprocals and as the `1` in `1 + e^(-x)`.
#[inline]
fn compute_one() -> ComputeStorage {
    let one_raw = FixedPoint::one().raw();
    upscale_to_compute(one_raw)
}
/// Computes `sqrt(Σ vᵢ²)`, the L2 norm of `values`.
///
/// Every square and the running sum are kept in the compute tier. The result
/// is rounded back to storage precision only once, after the square root.
/// An empty slice yields `sqrt_at_compute_tier` of the compute-tier zero.
pub fn sqrt_sum_sq(values: &[FixedPoint]) -> FixedPoint {
    let sum_of_squares = values.iter().fold(compute_zero(), |running, value| {
        let wide = upscale_to_compute(value.raw());
        compute_add(running, compute_multiply(wide, wide))
    });
    let norm = sqrt_at_compute_tier(sum_of_squares);
    FixedPoint::from_raw(round_to_storage(norm))
}
/// Computes the Euclidean distance `sqrt(Σ (aᵢ - bᵢ)²)` between two points.
///
/// Differences, squares and their sum are evaluated in the compute tier.
/// The result is rounded to storage precision once, after the square root.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn euclidean_distance(a: &[FixedPoint], b: &[FixedPoint]) -> FixedPoint {
    assert_eq!(a.len(), b.len(), "euclidean_distance: dimension mismatch");
    let squared_distance = a
        .iter()
        .zip(b.iter())
        .fold(compute_zero(), |running, (lhs, rhs)| {
            let delta = compute_subtract(
                upscale_to_compute(lhs.raw()),
                upscale_to_compute(rhs.raw()),
            );
            compute_add(running, compute_multiply(delta, delta))
        });
    FixedPoint::from_raw(round_to_storage(sqrt_at_compute_tier(squared_distance)))
}
/// Computes `softmax(scores)ᵢ = e^(sᵢ - max) / Σⱼ e^(sⱼ - max)`.
///
/// The largest score is subtracted before exponentiating, so every exponent
/// is `<= 0`. Exponentials, their sum and the divisions all happen in the
/// compute tier. Each weight is rounded to storage precision individually.
///
/// An empty input returns an empty vector.
///
/// # Errors
///
/// Returns `OverflowDetected::DivisionByZero` if the sum of exponentials is
/// zero. Propagates the first error reported by `compute_divide`.
pub fn softmax(scores: &[FixedPoint]) -> Result<Vec<FixedPoint>, OverflowDetected> {
    let (first, rest) = match scores.split_first() {
        Some(parts) => parts,
        None => return Ok(vec![]),
    };

    // Largest raw score; the first occurrence wins on ties.
    let peak_raw = rest.iter().fold(first.raw(), |best, score| {
        let candidate = score.raw();
        if candidate > best {
            candidate
        } else {
            best
        }
    });
    let peak = upscale_to_compute(peak_raw);

    let mut weights: Vec<ComputeStorage> = Vec::with_capacity(scores.len());
    let mut total = compute_zero();
    for score in scores {
        let shifted = compute_subtract(upscale_to_compute(score.raw()), peak);
        let weight = exp_at_compute_tier(shifted);
        total = compute_add(total, weight);
        weights.push(weight);
    }

    if compute_is_zero(&total) {
        return Err(OverflowDetected::DivisionByZero);
    }

    // Collecting into `Result` stops at the first failed division.
    weights
        .iter()
        .map(|&weight| {
            compute_divide(weight, total)
                .map(|share| FixedPoint::from_raw(round_to_storage(share)))
        })
        .collect()
}
/// Returns the RMS-normalisation scale `1 / sqrt(mean(vᵢ²) + eps)`.
///
/// Squares, the mean, `eps`, the square root and the reciprocal are evaluated
/// in the compute tier. The result is rounded to storage precision once.
/// `eps` is added unchecked, so a negative `eps` that drives the sum below
/// zero is passed straight to `sqrt_at_compute_tier`.
///
/// # Errors
///
/// - `OverflowDetected::DivisionByZero` if `values` is empty.
/// - `OverflowDetected::DivisionByZero` if `sqrt(mean + eps)` is zero.
/// - Any error reported by `compute_divide`.
pub fn rms_norm_factor(values: &[FixedPoint], eps: FixedPoint) -> Result<FixedPoint, OverflowDetected> {
    if values.is_empty() {
        return Err(OverflowDetected::DivisionByZero);
    }

    let one = compute_one();

    // The element count is accumulated as a compute-tier value next to the
    // squares. The previous `FixedPoint::from_int(values.len() as i32)` had
    // two problems: the `as` cast silently truncates lengths above
    // `i32::MAX`, and building `n` at storage precision can exceed the storage
    // integer range (presumably about 32767 for a Q16.16 layout) well before
    // the compute tier would, assuming the compute tier is the wider one as
    // `upscale_to_compute` suggests. For small `n` the value is identical.
    let mut sum_sq = compute_zero();
    let mut n_compute = compute_zero();
    for v in values {
        let vc = upscale_to_compute(v.raw());
        sum_sq = compute_add(sum_sq, compute_multiply(vc, vc));
        n_compute = compute_add(n_compute, one);
    }

    let mean = compute_divide(sum_sq, n_compute)?;
    let mean_eps = compute_add(mean, upscale_to_compute(eps.raw()));
    let root = sqrt_at_compute_tier(mean_eps);

    // Explicit check so a zero root always reports `DivisionByZero`,
    // independent of how `compute_divide` treats a zero divisor.
    if compute_is_zero(&root) {
        return Err(OverflowDetected::DivisionByZero);
    }

    let inv = compute_divide(one, root)?;
    Ok(FixedPoint::from_raw(round_to_storage(inv)))
}
/// SiLU (swish) activation: `x / (1 + e^(-x))`.
///
/// Evaluated in the compute tier and rounded to storage precision once.
/// This function never fails. If the denominator is zero, or
/// `compute_divide` returns an error, the result is `FixedPoint::ZERO`.
pub fn silu(x: FixedPoint) -> FixedPoint {
    let wide_x = upscale_to_compute(x.raw());
    let exp_term = exp_at_compute_tier(compute_negate(wide_x));
    let denominator = compute_add(compute_one(), exp_term);

    // Mathematically `1 + e^(-x) >= 1`. This guard only matters if the
    // compute-tier arithmetic degenerates.
    if compute_is_zero(&denominator) {
        return FixedPoint::ZERO;
    }

    compute_divide(wide_x, denominator)
        .map(|quotient| FixedPoint::from_raw(round_to_storage(quotient)))
        .unwrap_or(FixedPoint::ZERO)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `FixedPoint` from a decimal literal.
    ///
    /// A leading `-` is stripped and applied by negation, so
    /// `FixedPoint::from_str` only ever sees a non-negative magnitude.
    fn fp(s: &str) -> FixedPoint {
        match s.strip_prefix('-') {
            Some(magnitude) => -FixedPoint::from_str(magnitude),
            None => FixedPoint::from_str(s),
        }
    }

    /// Absolute tolerance for checks expected to be accurate to near the
    /// storage resolution of the active `table_format`.
    fn tight() -> FixedPoint {
        // Q16.16 presumably has a resolution of 2^-16 (~1.5e-5), where 1e-9
        // would round to zero, so it gets a coarser bound.
        #[cfg(table_format = "q16_16")]
        { fp("0.001") }
        #[cfg(any(
            table_format = "q32_32",
            table_format = "q64_64",
            table_format = "q128_128",
            table_format = "q256_256"
        ))]
        { fp("0.000000001") }
    }

    #[test]
    fn test_sqrt_sum_sq_basic() {
        let vals = [fp("3"), fp("4")];
        let result = sqrt_sum_sq(&vals);
        let diff = (result - fp("5")).abs();
        assert!(diff < tight(), "sqrt(3²+4²) = {}, expected 5", result);
    }

    #[test]
    fn test_sqrt_sum_sq_single() {
        let vals = [fp("7")];
        let result = sqrt_sum_sq(&vals);
        let diff = (result - fp("7")).abs();
        assert!(diff < tight(), "sqrt(7²) = {}, expected 7", result);
    }

    #[test]
    fn test_euclidean_distance_basic() {
        let a = [FixedPoint::ZERO, FixedPoint::ZERO];
        let b = [fp("3"), fp("4")];
        let dist = euclidean_distance(&a, &b);
        let diff = (dist - fp("5")).abs();
        assert!(diff < tight(), "dist([0,0],[3,4]) = {}, expected 5", dist);
    }

    #[test]
    fn test_euclidean_distance_same_point() {
        let a = [fp("1"), fp("2"), fp("3")];
        let dist = euclidean_distance(&a, &a);
        assert!(
            dist.is_zero() || dist.abs() < tight(),
            "distance to self should be 0, got {}",
            dist
        );
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn test_euclidean_distance_dimension_mismatch() {
        let _ = euclidean_distance(&[fp("1")], &[fp("1"), fp("2")]);
    }

    #[test]
    fn test_softmax_empty() {
        let result = softmax(&[]).unwrap();
        assert!(result.is_empty(), "softmax([]) should be empty");
    }

    #[test]
    fn test_softmax_uniform() {
        let scores = vec![fp("1"); 4];
        let result = softmax(&scores).unwrap();
        let expected = fp("0.25");
        for (i, w) in result.iter().enumerate() {
            let diff = (*w - expected).abs();
            assert!(diff < fp("0.001"), "softmax[{}] = {}, expected 0.25", i, w);
        }
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let scores = vec![fp("1"), fp("2"), fp("3"), fp("4")];
        let result = softmax(&scores).unwrap();
        let sum: FixedPoint = result.iter().copied().fold(FixedPoint::ZERO, |a, b| a + b);
        let diff = (sum - fp("1")).abs();
        assert!(diff < tight(), "softmax sum = {}, expected 1.0", sum);
    }

    #[test]
    fn test_softmax_monotone() {
        let scores = vec![fp("1"), fp("2"), fp("3")];
        let result = softmax(&scores).unwrap();
        assert!(result[0] < result[1], "softmax not monotone: {} >= {}", result[0], result[1]);
        assert!(result[1] < result[2], "softmax not monotone: {} >= {}", result[1], result[2]);
    }

    #[test]
    fn test_rms_norm_factor_constant() {
        let c = fp("2");
        let eps = fp("0.000001");
        let vals = vec![c; 4];
        let factor = rms_norm_factor(&vals, eps).unwrap();
        let diff = (factor - fp("0.5")).abs();
        assert!(diff < fp("0.001"), "rms_norm_factor = {}, expected ~0.5", factor);
    }

    #[test]
    fn test_rms_norm_factor_empty_is_error() {
        let result = rms_norm_factor(&[], fp("0.000001"));
        assert!(
            matches!(result, Err(OverflowDetected::DivisionByZero)),
            "rms_norm_factor([]) should return DivisionByZero"
        );
    }

    #[test]
    fn test_silu_zero() {
        let result = silu(FixedPoint::ZERO);
        assert!(result.abs() < tight(), "silu(0) = {}, expected 0", result);
    }

    #[test]
    fn test_silu_positive() {
        let x = fp("10");
        let result = silu(x);
        let diff = (result - x).abs();
        assert!(diff < fp("0.001"), "silu(10) = {}, expected ~10", result);
    }

    #[test]
    fn test_silu_negative() {
        let result = silu(fp("-10"));
        assert!(result.abs() < fp("0.001"), "silu(-10) = {}, expected ~0", result);
    }
}