use crate::error::{AttentionError, AttentionResult};
/// Infallible softmax that degrades to a uniform distribution on degenerate input.
///
/// * An empty slice yields an empty vector.
/// * Non-finite entries (NaN, `+inf`, `-inf`) are skipped when searching for the
///   maximum and always receive a weight of `0.0`.
/// * When no entry is finite, or the sum of the exponentials is `<= 1e-10` or
///   non-finite, every position receives `1 / values.len()`.
///
/// The maximum finite entry is subtracted before exponentiation so that `exp`
/// is only ever evaluated on non-positive arguments.
#[inline]
pub fn stable_softmax(values: &[f32]) -> Vec<f32> {
    let len = values.len();
    if len == 0 {
        return Vec::new();
    }

    // Shared fallback for every degenerate case below.
    let uniform = || vec![1.0 / len as f32; len];

    // Largest finite entry; stays at -inf when there is none.
    let mut peak = f32::NEG_INFINITY;
    for &v in values.iter().filter(|v| v.is_finite()) {
        peak = peak.max(v);
    }
    if !peak.is_finite() {
        return uniform();
    }

    let weights: Vec<f32> = values
        .iter()
        .map(|&v| match v.is_finite() {
            true => (v - peak).exp(),
            false => 0.0,
        })
        .collect();

    let total = weights.iter().sum::<f32>();
    if !(total > 1e-10 && total.is_finite()) {
        return uniform();
    }

    // Multiply by the reciprocal rather than dividing each element.
    let reciprocal = 1.0 / total;
    weights.into_iter().map(|w| w * reciprocal).collect()
}
/// Computes the softmax of `values`, subtracting the maximum before
/// exponentiation for numerical stability.
///
/// Entries equal to `-inf` are accepted as long as at least one entry is finite;
/// they end up with weight `0.0`.
///
/// # Errors
///
/// * [`AttentionError::EmptyInput`] when `values` is empty.
/// * [`AttentionError::NumericalInstability`] when the maximum is not finite
///   (an entry is `+inf`, or no entry is greater than `-inf`), or when the sum
///   of exponentials is non-positive or non-finite (e.g. a NaN entry).
#[inline]
pub fn softmax(values: &[f32]) -> AttentionResult<Vec<f32>> {
    if values.is_empty() {
        return Err(AttentionError::EmptyInput(
            "cannot compute softmax of empty slice".to_string(),
        ));
    }

    // `f32::max` ignores a NaN operand, so NaN inputs are caught by the sum
    // check further down rather than here.
    let mut peak = f32::NEG_INFINITY;
    for &v in values {
        peak = peak.max(v);
    }
    if !peak.is_finite() {
        return Err(AttentionError::NumericalInstability(
            "non-finite values in softmax input".to_string(),
        ));
    }

    let mut weights = Vec::with_capacity(values.len());
    weights.extend(values.iter().map(|&v| (v - peak).exp()));

    let total = weights.iter().sum::<f32>();
    if !(total > 0.0 && total.is_finite()) {
        return Err(AttentionError::NumericalInstability(
            "invalid sum in softmax computation".to_string(),
        ));
    }

    let reciprocal = 1.0 / total;
    for w in &mut weights {
        *w *= reciprocal;
    }
    Ok(weights)
}
/// Softmax over `values` with an optional boolean keep-mask.
///
/// Positions whose mask entry is `false` are replaced by `-inf` before the
/// softmax, so they receive a weight of `0.0`. With `mask == None` the input is
/// forwarded to [`softmax`] unchanged, without copying it.
///
/// # Errors
///
/// * [`AttentionError::EmptyInput`] when `values` is empty.
/// * [`AttentionError::InvalidMask`] when the mask length differs from
///   `values.len()`.
/// * Any error produced by [`softmax`]; in particular a mask that keeps no
///   position leaves only `-inf` entries, which `softmax` rejects.
#[inline]
pub fn masked_softmax(values: &[f32], mask: Option<&[bool]>) -> AttentionResult<Vec<f32>> {
    if values.is_empty() {
        return Err(AttentionError::EmptyInput(
            "cannot compute softmax of empty slice".to_string(),
        ));
    }

    // No mask: nothing to rewrite, so borrow the caller's slice directly.
    let mask = match mask {
        Some(mask) => mask,
        None => return softmax(values),
    };

    if mask.len() != values.len() {
        return Err(AttentionError::InvalidMask {
            expected: values.len().to_string(),
            actual: mask.len().to_string(),
        });
    }

    let masked_values: Vec<f32> = values
        .iter()
        .zip(mask)
        .map(|(&v, &keep)| if keep { v } else { f32::NEG_INFINITY })
        .collect();
    softmax(&masked_values)
}
/// Applies a causal mask in place to a row-major `query_len x key_len` score
/// matrix stored in `scores`.
///
/// For every query row `i`, each key column `j > i` is overwritten with
/// `f32::NEG_INFINITY`; entries with `j <= i` are left untouched.
///
/// # Errors
///
/// Returns [`AttentionError::InvalidMask`] when `scores.len()` is not equal to
/// `query_len * key_len`, including the case where that product does not fit
/// in a `usize`.
pub fn apply_causal_mask(
    scores: &mut [f32],
    query_len: usize,
    key_len: usize,
) -> AttentionResult<()> {
    // `checked_mul` rejects overflowing dimensions instead of panicking (debug)
    // or wrapping to a value that happens to match `scores.len()` (release).
    if query_len.checked_mul(key_len) != Some(scores.len()) {
        return Err(AttentionError::InvalidMask {
            expected: format!("{}x{}", query_len, key_len),
            actual: scores.len().to_string(),
        });
    }

    // Zero-width rows have nothing to mask, and `chunks_exact_mut` requires a
    // non-zero chunk size.
    if key_len == 0 {
        return Ok(());
    }

    for (row_idx, row) in scores.chunks_exact_mut(key_len).enumerate() {
        // Everything strictly to the right of the diagonal is a "future" key.
        for score in row.iter_mut().skip(row_idx + 1) {
            *score = f32::NEG_INFINITY;
        }
    }
    Ok(())
}
/// Returns the dot product of two equally long slices.
///
/// # Errors
///
/// Returns [`AttentionError::DimensionMismatch`] (with `expected = a.len()` and
/// `actual = b.len()`) when the slices differ in length.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> AttentionResult<f32> {
    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(AttentionError::DimensionMismatch {
            expected: len_a,
            actual: len_b,
        });
    }

    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    Ok(products.sum::<f32>())
}
/// Multiplies every element of `vector` by `scale`, in place.
#[inline]
pub fn scale_vector(vector: &mut [f32], scale: f32) {
    for elem in vector {
        *elem *= scale;
    }
}
/// Adds two equally long slices element by element and returns the result as a
/// new vector.
///
/// # Errors
///
/// Returns [`AttentionError::DimensionMismatch`] (with `expected = a.len()` and
/// `actual = b.len()`) when the slices differ in length.
#[inline]
pub fn add_vectors(a: &[f32], b: &[f32]) -> AttentionResult<Vec<f32>> {
    if a.len() == b.len() {
        let mut sums = Vec::with_capacity(a.len());
        sums.extend(a.iter().zip(b).map(|(&lhs, &rhs)| lhs + rhs));
        Ok(sums)
    } else {
        Err(AttentionError::DimensionMismatch {
            expected: a.len(),
            actual: b.len(),
        })
    }
}
/// Returns the Euclidean (L2) norm of `vector`; an empty slice has norm zero.
///
/// The squares are accumulated in `f64`. Squaring in `f32` overflows to `inf`
/// once an element exceeds roughly `1.8e19` and underflows to `0` below roughly
/// `1e-23`, which would report `inf` or `0` for vectors whose true norm is a
/// perfectly ordinary `f32`. Every squared `f32` fits comfortably in `f64`, so
/// widening removes both failure modes; the final value is rounded back to
/// `f32` (and is `inf` only if the true norm itself exceeds `f32::MAX`).
#[inline]
pub fn l2_norm(vector: &[f32]) -> f32 {
    let sum_of_squares: f64 = vector
        .iter()
        .map(|&x| {
            let wide = f64::from(x);
            wide * wide
        })
        .sum();
    sum_of_squares.sqrt() as f32
}
/// Rescales `vector` in place to unit L2 norm and returns the norm it had
/// before rescaling.
///
/// # Errors
///
/// Returns [`AttentionError::NumericalInstability`] when the norm reported by
/// [`l2_norm`] is not strictly positive or is not finite; `vector` is left
/// unmodified in that case.
pub fn normalize_vector(vector: &mut [f32]) -> AttentionResult<f32> {
    let magnitude = l2_norm(vector);

    let usable = magnitude > 0.0 && magnitude.is_finite();
    if !usable {
        return Err(AttentionError::NumericalInstability(
            "cannot normalize zero or non-finite vector".to_string(),
        ));
    }

    // One division up front, then a multiply per component.
    let reciprocal = 1.0 / magnitude;
    for component in vector.iter_mut() {
        *component *= reciprocal;
    }
    Ok(magnitude)
}
/// Applies inverted dropout to `vector` in place.
///
/// Does nothing when `training` is `false` or `dropout_prob` is exactly `0.0`.
/// Otherwise one `f32` is drawn from `rng` per element, in order: if the draw
/// is below `dropout_prob` the element is set to `0.0`, else it is multiplied
/// by `1 / (1 - dropout_prob)`.
///
/// `dropout_prob` is not validated here.
/// NOTE(review): values outside `[0, 1)` (or NaN) are presumably a caller
/// error — confirm whether they should be rejected upstream.
pub fn apply_dropout(
    vector: &mut [f32],
    dropout_prob: f32,
    training: bool,
    rng: &mut impl rand::Rng,
) {
    let active = training && dropout_prob != 0.0;
    if !active {
        return;
    }

    // Survivors are scaled up by this factor.
    let keep_scale = 1.0 / (1.0 - dropout_prob);
    vector.iter_mut().for_each(|elem| {
        let dropped = rng.gen::<f32>() < dropout_prob;
        if dropped {
            *elem = 0.0;
        } else {
            *elem *= keep_scale;
        }
    });
}
/// Unit tests for the attention math helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Output sums to one and preserves the ordering of the inputs.
    #[test]
    fn test_softmax() {
        let values = vec![1.0, 2.0, 3.0];
        let result = softmax(&values).unwrap();
        let sum: f32 = result.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
        assert!(result[0] < result[1]);
        assert!(result[1] < result[2]);
    }

    /// Large inputs must not overflow `exp`.
    #[test]
    fn test_softmax_numerical_stability() {
        let values = vec![1000.0, 1001.0, 1002.0];
        let result = softmax(&values).unwrap();
        let sum: f32 = result.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
    }

    /// Masked-out positions get zero weight; the rest still sum to one.
    #[test]
    fn test_masked_softmax() {
        let values = vec![1.0, 2.0, 3.0, 4.0];
        let mask = vec![true, true, false, false];
        let result = masked_softmax(&values, Some(&mask)).unwrap();
        assert_relative_eq!(result[2], 0.0, epsilon = 1e-6);
        assert_relative_eq!(result[3], 0.0, epsilon = 1e-6);
        let sum: f32 = result[0] + result[1];
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = dot_product(&a, &b).unwrap();
        assert_relative_eq!(result, 32.0, epsilon = 1e-6);
    }

    #[test]
    fn test_scale_vector() {
        let mut vector = vec![1.0, 2.0, 3.0];
        scale_vector(&mut vector, 2.0);
        assert_relative_eq!(vector[0], 2.0);
        assert_relative_eq!(vector[1], 4.0);
        assert_relative_eq!(vector[2], 6.0);
    }

    /// Returns the original norm and leaves a unit-length vector behind.
    #[test]
    fn test_normalize_vector() {
        let mut vector = vec![3.0, 4.0];
        let norm = normalize_vector(&mut vector).unwrap();
        assert_relative_eq!(norm, 5.0, epsilon = 1e-6);
        assert_relative_eq!(l2_norm(&vector), 1.0, epsilon = 1e-6);
    }

    /// Entries above the diagonal are masked; the diagonal itself is kept.
    #[test]
    fn test_causal_mask() {
        let mut scores = vec![0.0; 9];
        apply_causal_mask(&mut scores, 3, 3).unwrap();

        assert_eq!(scores[1], f32::NEG_INFINITY);
        assert_eq!(scores[2], f32::NEG_INFINITY);
        assert_eq!(scores[5], f32::NEG_INFINITY);

        assert_eq!(scores[0], 0.0);
        assert_eq!(scores[4], 0.0);
        assert_eq!(scores[8], 0.0);
    }

    /// `stable_softmax` never fails: empty in, empty out; no finite entry
    /// falls back to a uniform distribution.
    #[test]
    fn test_stable_softmax_fallbacks() {
        assert!(stable_softmax(&[]).is_empty());

        let uniform = stable_softmax(&[f32::NEG_INFINITY; 4]);
        assert_eq!(uniform.len(), 4);
        for p in uniform {
            assert_relative_eq!(p, 0.25, epsilon = 1e-6);
        }

        let result = stable_softmax(&[1.0, 2.0, 3.0]);
        let sum: f32 = result.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_add_vectors() {
        let sum = add_vectors(&[1.0, 2.0], &[3.0, 4.0]).unwrap();
        assert_eq!(sum.len(), 2);
        assert_relative_eq!(sum[0], 4.0);
        assert_relative_eq!(sum[1], 6.0);
    }

    /// Every fallible helper reports malformed input as an error.
    #[test]
    fn test_invalid_inputs_are_rejected() {
        assert!(softmax(&[]).is_err());
        assert!(masked_softmax(&[], None).is_err());

        let short_mask = [true];
        assert!(masked_softmax(&[1.0, 2.0], Some(&short_mask[..])).is_err());

        assert!(dot_product(&[1.0], &[1.0, 2.0]).is_err());
        assert!(add_vectors(&[1.0], &[1.0, 2.0]).is_err());

        let mut scores = vec![0.0; 5];
        assert!(apply_causal_mask(&mut scores, 2, 3).is_err());

        let mut zero = vec![0.0, 0.0];
        assert!(normalize_vector(&mut zero).is_err());
    }
}