use crate::error::{AttentionError, AttentionResult};
use crate::traits::Attention;
/// Non-linearity used by [`LinearAttention`] to turn random projections into features.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KernelType {
    /// Positive random features `exp(wᵀx − ‖x‖²/(2·dim)) / √num_features`, whose inner
    /// products estimate the exponential kernel `exp(q·k / dim)`.
    Softmax,
    /// `max(wᵀx, 0)` applied to each random projection.
    ReLU,
    /// ELU applied to each random projection (`wᵀx` if non-negative, else `exp(wᵀx) − 1`).
    ///
    /// NOTE(review): unlike `elu(x) + 1`, these features can be negative, so the attention
    /// normaliser is not guaranteed to be positive.
    ELU,
}

/// Attention whose cost is linear in the number of keys, obtained by replacing the
/// softmax similarity with an inner product of random feature maps.
#[derive(Clone, Debug)]
pub struct LinearAttention {
    /// Required length of query and key vectors.
    dim: usize,
    /// Number of random projections, i.e. the length of every feature vector.
    num_features: usize,
    /// Non-linearity applied to each projection.
    kernel: KernelType,
    /// Row-major `num_features × dim` projection matrix with entries approximately
    /// `N(0, 1/dim)`.
    random_features: Vec<f32>,
}

impl LinearAttention {
    /// Creates a linear attention module using the [`KernelType::Softmax`] feature map.
    pub fn new(dim: usize, num_features: usize) -> Self {
        Self::with_kernel(dim, num_features, KernelType::Softmax)
    }

    /// Creates a linear attention module with an explicit feature-map kernel.
    ///
    /// The projection matrix is generated from a fixed seed, so two modules built with the
    /// same `dim` and `num_features` share identical projections.
    pub fn with_kernel(dim: usize, num_features: usize, kernel: KernelType) -> Self {
        Self {
            dim,
            num_features,
            kernel,
            random_features: Self::generate_random_features(dim, num_features),
        }
    }

    /// Deterministically samples a row-major `num_features × dim` matrix whose entries are
    /// approximately `N(0, 1/dim)`.
    ///
    /// A 64-bit LCG (fixed seed 42) supplies uniform pairs to the Box–Muller transform; each
    /// pair yields two normals (cosine and sine branch), the last sine being dropped when the
    /// total count is odd.
    fn generate_random_features(dim: usize, num_features: usize) -> Vec<f32> {
        use std::f32::consts::PI;
        const LCG_MULTIPLIER: u64 = 6364136223846793005;

        let total = num_features * dim;
        let mut features = Vec::with_capacity(total);
        let mut seed = 42u64;
        let mut next_uniform = || {
            seed = seed.wrapping_mul(LCG_MULTIPLIER).wrapping_add(1);
            (seed as f32) / (u64::MAX as f32)
        };

        while features.len() < total {
            let u1 = next_uniform();
            let u2 = next_uniform();
            // Clamp away from 0 so `ln` stays finite.
            let radius = (-2.0 * u1.max(1e-10).ln()).sqrt();
            let theta = 2.0 * PI * u2;
            features.push(radius * theta.cos());
            if features.len() < total {
                features.push(radius * theta.sin());
            }
        }

        // Scale standard normals to variance 1/dim. `feature_map` relies on this variance
        // for the softmax correction term.
        let scale = 1.0 / (dim as f32).sqrt();
        features.iter_mut().for_each(|w| *w *= scale);
        features
    }

    /// Maps `x` to its `num_features`-dimensional kernel feature vector.
    ///
    /// Feature `i` is the projection `p = wᵢᵀx` onto row `i` of the random matrix, passed
    /// through the configured kernel:
    ///
    /// - `Softmax`: `exp(p − ‖x‖²/(2·dim)) / √num_features`. Because `wᵢ ~ N(0, I/dim)`,
    ///   `E[exp(wᵀx)] = exp(‖x‖²/(2·dim))`, so this offset makes `φ(q)·φ(k)` estimate
    ///   `exp(q·k / dim)` and keeps features from underflowing for large-norm inputs.
    /// - `ReLU`: `max(p, 0)`.
    /// - `ELU`: `p` if `p >= 0`, else `exp(p) − 1`.
    ///
    /// Callers are expected to pass `x.len() == dim`. A shorter `x` uses only the leading
    /// weights of each row; extra trailing elements of a longer `x` are ignored rather than
    /// reading into the next row.
    fn feature_map(&self, x: &[f32]) -> Vec<f32> {
        // Loop-invariant pieces, computed once per call rather than once per feature.
        let softmax_offset = match self.kernel {
            KernelType::Softmax => Self::softmax_offset(x, self.dim),
            KernelType::ReLU | KernelType::ELU => 0.0,
        };
        let feature_scale = (self.num_features as f32).sqrt();

        (0..self.num_features)
            .map(|i| {
                let row = &self.random_features[i * self.dim..(i + 1) * self.dim];
                let projection: f32 = x.iter().zip(row).map(|(xj, wj)| xj * wj).sum();
                match self.kernel {
                    KernelType::Softmax => (projection - softmax_offset).exp() / feature_scale,
                    KernelType::ReLU => projection.max(0.0),
                    KernelType::ELU => {
                        if projection >= 0.0 {
                            projection
                        } else {
                            projection.exp() - 1.0
                        }
                    }
                }
            })
            .collect()
    }

    /// Variance correction `‖x‖² / (2·dim)` for the softmax features; `0` when `dim == 0`
    /// (the projections are then empty and the correction would be `0/0`).
    fn softmax_offset(x: &[f32], dim: usize) -> f32 {
        if dim == 0 {
            return 0.0;
        }
        let norm_sq: f32 = x.iter().map(|xi| xi * xi).sum();
        norm_sq / (2.0 * dim as f32)
    }
}
impl Attention for LinearAttention {
    /// Computes linear attention of one `query` over `keys`/`values`.
    ///
    /// The result is `Σᵢ φ(q)·φ(kᵢ) vᵢ / Σᵢ φ(q)·φ(kᵢ)`, evaluated in `O(n · num_features)`
    /// by first accumulating `Σ φ(kᵢ) vᵢᵀ` and `Σ φ(kᵢ)`. If the normaliser is exactly zero
    /// (for example, every feature is zero) the unnormalised accumulation is returned.
    ///
    /// # Errors
    ///
    /// - `InvalidConfig` if `keys` is empty.
    /// - `DimensionMismatch` in any of these cases:
    ///   - `keys` and `values` differ in count;
    ///   - the query or any key is not `dim` long;
    ///   - the value vectors do not all share the same length.
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>> {
        if keys.is_empty() {
            return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
        }
        if keys.len() != values.len() {
            return Err(AttentionError::DimensionMismatch {
                expected: keys.len(),
                actual: values.len(),
            });
        }
        if query.len() != self.dim {
            return Err(AttentionError::DimensionMismatch {
                expected: self.dim,
                actual: query.len(),
            });
        }
        // Keys are projected onto `dim`-wide weight rows, so they must match `dim` exactly
        // just like the query.
        if let Some(bad) = keys.iter().find(|k| k.len() != self.dim) {
            return Err(AttentionError::DimensionMismatch {
                expected: self.dim,
                actual: bad.len(),
            });
        }
        // Every value fills one `value_dim`-wide row of `kv_sum`; a ragged value would
        // spill into the next row or index out of bounds.
        let value_dim = values[0].len();
        if let Some(bad) = values.iter().find(|v| v.len() != value_dim) {
            return Err(AttentionError::DimensionMismatch {
                expected: value_dim,
                actual: bad.len(),
            });
        }

        let phi_q = self.feature_map(query);

        // kv_sum is a row-major `num_features × value_dim` matrix Σ φ(k) vᵀ; k_sum is Σ φ(k).
        let mut kv_sum = vec![0.0f32; self.num_features * value_dim];
        let mut k_sum = vec![0.0f32; self.num_features];
        for (key, value) in keys.iter().zip(values) {
            let phi_k = self.feature_map(key);
            for (i, &phi_ki) in phi_k.iter().enumerate() {
                let kv_row = &mut kv_sum[i * value_dim..(i + 1) * value_dim];
                for (acc, &vj) in kv_row.iter_mut().zip(value.iter()) {
                    *acc += phi_ki * vj;
                }
                k_sum[i] += phi_ki;
            }
        }

        let mut output = vec![0.0f32; value_dim];
        let mut normalizer = 0.0f32;
        for (i, (&phi_qi, &k_sum_i)) in phi_q.iter().zip(&k_sum).enumerate() {
            let kv_row = &kv_sum[i * value_dim..(i + 1) * value_dim];
            for (out_j, &kv) in output.iter_mut().zip(kv_row) {
                *out_j += phi_qi * kv;
            }
            normalizer += phi_qi * k_sum_i;
        }

        // Only an exactly-zero normaliser is skipped. An absolute threshold would leave the
        // output unnormalised whenever features are legitimately small (e.g. small-scale
        // inputs), even though the division is well defined.
        if normalizer != 0.0 {
            output.iter_mut().for_each(|x| *x /= normalizer);
        }
        Ok(output)
    }

    /// Like [`compute`](Attention::compute), but only attends to positions where `mask` is
    /// `true`. With `mask == None` all positions are used.
    ///
    /// # Errors
    ///
    /// - `DimensionMismatch` in either of these cases:
    ///   - the mask length differs from `keys.len()`;
    ///   - `keys` and `values` differ in count.
    /// - Any error from `compute` on the retained positions. In particular, `InvalidConfig`
    ///   is returned when the mask removes every key.
    fn compute_with_mask(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>> {
        let mask = match mask {
            Some(m) => m,
            None => return self.compute(query, keys, values),
        };
        if mask.len() != keys.len() {
            return Err(AttentionError::DimensionMismatch {
                expected: keys.len(),
                actual: mask.len(),
            });
        }
        if keys.len() != values.len() {
            return Err(AttentionError::DimensionMismatch {
                expected: keys.len(),
                actual: values.len(),
            });
        }

        let (kept_keys, kept_values): (Vec<&[f32]>, Vec<&[f32]>) = keys
            .iter()
            .zip(values)
            .zip(mask)
            .filter(|(_, keep)| **keep)
            .map(|((key, value), _)| (*key, *value))
            .unzip();
        self.compute(query, &kept_keys, &kept_values)
    }

    /// Required length of query and key vectors.
    fn dim(&self) -> usize {
        self.dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows each owned row as a slice, matching the `Attention` API's `&[&[f32]]`.
    fn as_rows(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_linear_attention() {
        let attention = LinearAttention::new(64, 32);
        let query = vec![0.5; 64];
        let keys = vec![vec![0.3; 64]; 100];
        let values = vec![vec![1.0; 64]; 100];
        let result = attention
            .compute(&query, &as_rows(&keys), &as_rows(&values))
            .unwrap();
        assert_eq!(result.len(), 64);
        // Every value vector is all ones, so any normalised weighting must return ones.
        assert!(
            result.iter().all(|&x| (x - 1.0).abs() < 1e-4),
            "{:?}",
            result
        );
    }

    #[test]
    fn test_kernel_types() {
        for kernel in [KernelType::Softmax, KernelType::ReLU, KernelType::ELU] {
            let attention = LinearAttention::with_kernel(32, 16, kernel);
            let query = vec![1.0; 32];
            let keys = vec![vec![0.5; 32]; 10];
            let values = vec![vec![1.0; 32]; 10];
            let result = attention
                .compute(&query, &as_rows(&keys), &as_rows(&values))
                .unwrap();
            assert_eq!(result.len(), 32);
            assert!(result.iter().all(|x| x.is_finite()));
        }
    }

    #[test]
    fn test_mask_excludes_keys() {
        let attention = LinearAttention::new(64, 32);
        let query = vec![0.5; 64];
        let keys = vec![vec![0.3; 64]; 4];
        let values = vec![vec![1.0; 3], vec![1.0; 3], vec![5.0; 3], vec![5.0; 3]];
        let (k, v) = (as_rows(&keys), as_rows(&values));

        // Masking out the 5.0 rows leaves only all-ones values.
        let mask = [true, true, false, false];
        let masked = attention
            .compute_with_mask(&query, &k, &v, Some(&mask[..]))
            .unwrap();
        assert!(
            masked.iter().all(|&x| (x - 1.0).abs() < 1e-4),
            "{:?}",
            masked
        );

        // An all-true mask is equivalent to no mask at all.
        let all = [true; 4];
        assert_eq!(
            attention
                .compute_with_mask(&query, &k, &v, Some(&all[..]))
                .unwrap(),
            attention.compute(&query, &k, &v).unwrap()
        );
    }

    #[test]
    fn test_invalid_inputs_are_rejected() {
        let attention = LinearAttention::new(4, 8);
        let keys = vec![vec![0.1f32; 4]; 2];
        let values = vec![vec![1.0f32; 2]; 2];
        let (k, v) = (as_rows(&keys), as_rows(&values));
        let query = [0.0f32; 4];
        let short_query = [0.0f32; 3];
        let no_rows: Vec<&[f32]> = Vec::new();

        assert!(attention.compute(&query, &no_rows, &no_rows).is_err());
        assert!(attention.compute(&short_query, &k, &v).is_err());
        assert!(attention.compute(&query, &k, &v[..1]).is_err());
    }
}