/// Computes the dot product `sum_i a[i] * b[i]` of two `f32` slices.
///
/// Elements are paired positionally. When the slices have different lengths,
/// pairing stops at the end of the shorter one, so trailing elements of the
/// longer slice are silently ignored rather than reported as an error. If
/// either slice is empty the result is zero.
// NOTE(review): callers presumably always pass equal-length slices; a length
// mismatch is not detected here — confirm that is intended.
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let pairwise_products = a.iter().zip(b.iter()).map(|(&lhs, &rhs)| lhs * rhs);
    pairwise_products.sum::<f32>()
}
/// Numerically stable softmax with inverse temperature `beta`.
///
/// Returns `p_i = exp(beta * x_i) / sum_j exp(beta * x_j)` for every input,
/// in input order. A larger `beta` sharpens the distribution toward the
/// largest value. `beta == 0` gives a uniform distribution. A negative `beta`
/// favours the smallest value.
///
/// Stability: each exponent is shifted so the element with the largest
/// `beta * x` maps to `exp(0) = 1`, which keeps every exponent `<= 0` and
/// prevents overflow. For `beta >= 0` that element is `max(values)`. For
/// `beta < 0`, multiplying by `beta` reverses the ordering, so the element is
/// `min(values)`.
///
/// An empty input yields an empty vector. If the normaliser is at or below
/// `f32::EPSILON`, a uniform distribution is returned. This is a defensive
/// guard: with the shift above and finite inputs, the pivot term alone
/// contributes `1` to the sum.
pub fn softmax(values: &[f32], beta: f32) -> Vec<f32> {
    if values.is_empty() {
        return Vec::new();
    }
    // Pick the pivot whose scaled value is the maximum of `beta * x`.
    // Shifting by max(values) when beta < 0 would make every exponent
    // non-negative, overflow `exp` to +inf and produce NaN after division.
    let pivot = if beta >= 0.0 {
        values.iter().copied().fold(f32::NEG_INFINITY, f32::max)
    } else {
        values.iter().copied().fold(f32::INFINITY, f32::min)
    };
    let exp_values: Vec<f32> = values
        .iter()
        .map(|&x| ((x - pivot) * beta).exp())
        .collect();
    let sum: f32 = exp_values.iter().sum();
    if sum <= f32::EPSILON {
        let n = exp_values.len() as f32;
        return vec![1.0 / n; exp_values.len()];
    }
    exp_values.iter().map(|&x| x / sum).collect()
}
/// Scores each stored pattern against `query` and normalises the scores.
///
/// Returns `(attention, similarities)`:
/// - `similarities[i]` is the dot product of `patterns[i]` with `query`.
///   The dot product pairs elements positionally and stops at the shorter of
///   the two slices, so a length mismatch is not reported.
/// - `attention` is the softmax of `similarities` at inverse temperature `beta`.
///
/// Both vectors have exactly one entry per pattern, in pattern order.
pub fn compute_attention(
    patterns: &[Vec<f32>],
    query: &[f32],
    beta: f32,
) -> (Vec<f32>, Vec<f32>) {
    let mut similarities = Vec::with_capacity(patterns.len());
    for pattern in patterns {
        similarities.push(dot_product(pattern, query));
    }
    let attention = softmax(&similarities, beta);
    (attention, similarities)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // ---- dot_product -------------------------------------------------------

    #[test]
    fn test_dot_product() {
        // 1*4 + 2*5 + 3*6 = 32
        let lhs = [1.0_f32, 2.0, 3.0];
        let rhs = [4.0_f32, 5.0, 6.0];
        assert_relative_eq!(dot_product(&lhs, &rhs), 32.0, epsilon = 1e-6);
    }

    #[test]
    fn test_dot_product_orthogonal() {
        let e_x = [1.0_f32, 0.0, 0.0];
        let e_y = [0.0_f32, 1.0, 0.0];
        assert_relative_eq!(dot_product(&e_x, &e_y), 0.0, epsilon = 1e-6);
    }

    // ---- softmax -----------------------------------------------------------

    #[test]
    fn test_softmax_uniform() {
        // Equal inputs must yield equal probabilities.
        let expected: f32 = 1.0 / 3.0;
        for p in softmax(&[1.0_f32, 1.0, 1.0], 1.0) {
            assert_relative_eq!(p, expected, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let total: f32 = softmax(&[0.5_f32, 1.0, 1.5, 2.0], 1.0).iter().sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_softmax_temperature_effect() {
        // `beta` multiplies the inputs (inverse temperature). Raising it should
        // push more probability mass onto the larger input.
        let inputs = [1.0_f32, 2.0];
        let soft = softmax(&inputs, 0.5);
        let sharp = softmax(&inputs, 5.0);
        assert!(sharp[1] > soft[1]);
    }

    #[test]
    fn test_softmax_empty() {
        let nothing: [f32; 0] = [];
        assert!(softmax(&nothing, 1.0).is_empty());
    }

    #[test]
    fn test_softmax_numerical_stability() {
        // exp(1000.0) alone would overflow f32; softmax must still normalise.
        let total: f32 = softmax(&[1000.0_f32, 1001.0, 1002.0], 1.0).iter().sum();
        assert_relative_eq!(total, 1.0, epsilon = 1e-5);
    }

    // ---- compute_attention -------------------------------------------------

    #[test]
    fn test_compute_attention_orthogonal_patterns() {
        let patterns = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let query = [1.0_f32, 0.0, 0.0];
        let (attention, similarities) = compute_attention(&patterns, &query, 1.0);
        for (idx, &want) in [1.0_f32, 0.0, 0.0].iter().enumerate() {
            assert_relative_eq!(similarities[idx], want, epsilon = 1e-6);
        }
        // The pattern equal to the query should receive the most attention.
        assert!(attention[0] > attention[1]);
        assert!(attention[0] > attention[2]);
    }

    #[test]
    fn test_compute_attention_identical_patterns() {
        let ones = vec![1.0_f32; 3];
        let patterns = vec![ones.clone(), ones.clone()];
        let (attention, similarities) = compute_attention(&patterns, &ones, 1.0);
        for &score in &similarities[..2] {
            assert_relative_eq!(score, 3.0, epsilon = 1e-6);
        }
        // Identical scores split the attention evenly.
        for &weight in &attention[..2] {
            assert_relative_eq!(weight, 0.5, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_compute_attention_beta_effect() {
        let patterns = vec![vec![1.0_f32, 0.0], vec![0.5, 0.5]];
        let query = [1.0_f32, 0.0];
        let (soft, _) = compute_attention(&patterns, &query, 0.5);
        let (sharp, _) = compute_attention(&patterns, &query, 5.0);
        // A higher beta concentrates attention on the better-matching pattern.
        assert!(sharp[0] > soft[0]);
        assert!(sharp[1] < soft[1]);
    }
}