use crate::{ModelError, ModelResult};
use scirs2_core::random::{rng, RngExt};
use std::collections::HashMap;
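/// Feature map φ applied to queries and keys before the linear-attention
/// recurrence. `EluPlus1` (the default) computes `elu(x) + 1`, which keeps every
/// component strictly positive so the normaliser never vanishes; `Relu` clamps
/// negatives to zero; `Identity` passes values through unchanged.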
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FeatureMap {
#[default]
EluPlus1,
Relu,
Identity,
}
impl FeatureMap {
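    /// Applies the feature map element-wise, returning a new vector of the same length.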
#[inline]
pub fn apply(&self, x: &[f32]) -> Vec<f32> {
match self {
FeatureMap::EluPlus1 => x
.iter()
.map(|&v| if v >= 0.0 { v + 1.0 } else { v.exp() })
.collect(),
FeatureMap::Relu => x.iter().map(|&v| v.max(0.0)).collect(),
FeatureMap::Identity => x.to_vec(),
}
}
}
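/// Configuration for `FlashLinearAttention`.
///
/// `d_model` must be divisible by `n_heads`, with `d_head = d_model / n_heads`.
/// `chunk_size` controls how `chunked_forward` blocks the sequence, `eps`
/// stabilises the normaliser in `LinearAttnState::query`, and `causal` selects
/// between strictly causal and whole-sequence attention in `forward`.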
#[derive(Debug, Clone)]
pub struct FlashLinearAttnConfig {
pub d_model: usize,
pub n_heads: usize,
pub d_head: usize,
pub chunk_size: usize,
pub feature_map: FeatureMap,
pub eps: f32,
pub causal: bool,
}
impl FlashLinearAttnConfig {
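    /// Builds a config with `d_head = d_model / n_heads` and the defaults
    /// `chunk_size = 64`, `feature_map = EluPlus1`, `eps = 1e-6`, `causal = true`.
    ///
    /// Returns an error if `d_model` is not divisible by `n_heads`.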
pub fn new(d_model: usize, n_heads: usize) -> ModelResult<Self> {
if !d_model.is_multiple_of(n_heads) {
return Err(ModelError::invalid_config(format!(
"d_model={d_model} not divisible by n_heads={n_heads}"
)));
}
Ok(Self {
d_model,
n_heads,
d_head: d_model / n_heads,
chunk_size: 64,
feature_map: FeatureMap::EluPlus1,
eps: 1e-6,
causal: true,
})
}
}
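/// Per-head recurrent state for linear attention.
///
/// It keeps the running sums `S_t = Σ_{i≤t} φ(k_i) v_iᵀ` (row-major
/// `d_head × d_head`, stored in `s`) and `z_t = Σ_{i≤t} φ(k_i)` (length `d_head`,
/// stored in `z`), so a query can be answered in O(d_head²) time and memory
/// regardless of sequence length: `o_t = (φ(q_t)ᵀ S_t) / (φ(q_t)ᵀ z_t + eps)`.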
#[derive(Debug, Clone)]
pub struct LinearAttnState {
pub s: Vec<f32>,
pub z: Vec<f32>,
pub d_head: usize,
}
impl LinearAttnState {
pub fn new(d_head: usize) -> Self {
Self {
s: vec![0.0_f32; d_head * d_head],
z: vec![0.0_f32; d_head],
d_head,
}
}
#[inline]
pub fn reset(&mut self) {
self.s.fill(0.0);
self.z.fill(0.0);
}
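    /// Rank-1 state update for one token: `S += φ(k) vᵀ` and `z += φ(k)`.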
#[inline]
pub fn update(&mut self, phi_k: &[f32], v: &[f32]) {
let d = self.d_head;
debug_assert_eq!(phi_k.len(), d);
debug_assert_eq!(v.len(), d);
for (i, &pk_i) in phi_k.iter().enumerate() {
let row_start = i * d;
for (j, &vj) in v.iter().enumerate() {
self.s[row_start + j] += pk_i * vj;
}
}
for (zi, &pk_i) in self.z.iter_mut().zip(phi_k.iter()) {
*zi += pk_i;
}
}
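    /// Reads the state out for one query: `(φ(q)ᵀ S) / (φ(q)ᵀ z + eps)`.
    /// Components of `φ(q)` that are exactly zero are skipped.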
#[inline]
pub fn query(&self, phi_q: &[f32], eps: f32) -> Vec<f32> {
let d = self.d_head;
debug_assert_eq!(phi_q.len(), d);
let mut num = vec![0.0_f32; d];
for (i, &pq_i) in phi_q.iter().enumerate() {
if pq_i == 0.0 {
continue;
}
let row_start = i * d;
for (nj, &sij) in num.iter_mut().zip(self.s[row_start..row_start + d].iter()) {
*nj += pq_i * sij;
}
}
let denom: f32 = phi_q
.iter()
.zip(self.z.iter())
.map(|(&a, &b)| a * b)
.sum::<f32>()
+ eps;
num.iter().map(|&v| v / denom).collect()
}
}
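/// Multi-head linear attention with O(d_head²) recurrent state per head.
///
/// `w_q`, `w_k`, `w_v` and `w_o` are dense row-major `d_model × d_model`
/// projections; one `LinearAttnState` is kept per head, so the layer can be
/// driven over a whole sequence with `forward` or token by token with `step`.
///
/// Illustrative usage (a sketch; the surrounding imports depend on where this
/// module is re-exported, so the block is not compiled as a doctest):
///
/// ```ignore
/// let cfg = FlashLinearAttnConfig::new(64, 4)?;
/// let mut attn = FlashLinearAttention::new(cfg)?;
/// let seq_len = 8;
/// let input = vec![0.0_f32; seq_len * 64];
/// let out = attn.forward(&input, seq_len)?;
/// assert_eq!(out.len(), seq_len * 64);
/// ```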
pub struct FlashLinearAttention {
config: FlashLinearAttnConfig,
w_q: Vec<f32>,
w_k: Vec<f32>,
w_v: Vec<f32>,
w_o: Vec<f32>,
states: Vec<LinearAttnState>,
}
impl FlashLinearAttention {
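    /// Creates a layer with Xavier-uniform initialised projections
    /// (`limit = sqrt(3 / d_model)`) and zeroed per-head states.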
pub fn new(config: FlashLinearAttnConfig) -> ModelResult<Self> {
let d = config.d_model;
let limit = (3.0_f32 / d as f32).sqrt();
let mut rng = rng();
let mut xavier_init = |size: usize| -> Vec<f32> {
(0..size)
.map(|_| (rng.random::<f32>() * 2.0 - 1.0) * limit)
.collect()
};
let states: Vec<LinearAttnState> = (0..config.n_heads)
.map(|_| LinearAttnState::new(config.d_head))
.collect();
Ok(Self {
w_q: xavier_init(d * d),
w_k: xavier_init(d * d),
w_v: xavier_init(d * d),
w_o: xavier_init(d * d),
states,
config,
})
}
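    /// Runs the layer over a full sequence laid out row-major as `seq_len × d_model`.
    ///
    /// Per-head states are reset at the start of every call. With `causal = true`
    /// each token queries the state accumulated from strictly earlier tokens (the
    /// same order `step` uses); with `causal = false` the whole sequence is
    /// accumulated first and every token then queries the complete state.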
pub fn forward(&mut self, input: &[f32], seq_len: usize) -> ModelResult<Vec<f32>> {
let d = self.config.d_model;
if input.len() != seq_len * d {
return Err(ModelError::invalid_config(format!(
"forward: input length {} != seq_len*d_model={}",
input.len(),
seq_len * d
)));
}
let q = matmul(input, seq_len, d, &self.w_q, d)?;
let k = matmul(input, seq_len, d, &self.w_k, d)?;
let v = matmul(input, seq_len, d, &self.w_v, d)?;
let mut pre_out = vec![0.0_f32; seq_len * d];
let n_heads = self.config.n_heads;
let d_head = self.config.d_head;
let eps = self.config.eps;
let feature_map = self.config.feature_map;
let causal = self.config.causal;
        for h in 0..n_heads {
            self.states[h].reset();
            let head_start = h * d_head;
            let head_end = head_start + d_head;
            if causal {
                // Strictly causal: token t reads the state built from tokens 0..t,
                // then folds its own key/value in (the same order `step` uses).
                for t in 0..seq_len {
                    let tok_off = t * d;
                    let q_t = &q[tok_off + head_start..tok_off + head_end];
                    let k_t = &k[tok_off + head_start..tok_off + head_end];
                    let v_t = &v[tok_off + head_start..tok_off + head_end];
                    let phi_q = feature_map.apply(q_t);
                    let phi_k = feature_map.apply(k_t);
                    let out_h = self.states[h].query(&phi_q, eps);
                    self.states[h].update(&phi_k, v_t);
                    let dst = tok_off + head_start;
                    pre_out[dst..dst + d_head].copy_from_slice(&out_h);
                }
            } else {
                // Non-causal: accumulate the whole sequence into the state first,
                // then let every position query the complete state.
                for t in 0..seq_len {
                    let tok_off = t * d;
                    let k_t = &k[tok_off + head_start..tok_off + head_end];
                    let v_t = &v[tok_off + head_start..tok_off + head_end];
                    let phi_k = feature_map.apply(k_t);
                    self.states[h].update(&phi_k, v_t);
                }
                for t in 0..seq_len {
                    let tok_off = t * d;
                    let q_t = &q[tok_off + head_start..tok_off + head_end];
                    let phi_q = feature_map.apply(q_t);
                    let out_h = self.states[h].query(&phi_q, eps);
                    let dst = tok_off + head_start;
                    pre_out[dst..dst + d_head].copy_from_slice(&out_h);
                }
            }
        }
matmul(&pre_out, seq_len, d, &self.w_o, d)
}
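    /// Processes a single token recurrently, reusing (and then updating) the
    /// per-head states. Unlike `forward`, this does not reset state, so call
    /// `reset_states` before starting a new sequence.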
pub fn step(&mut self, x: &[f32]) -> ModelResult<Vec<f32>> {
let d = self.config.d_model;
if x.len() != d {
return Err(ModelError::invalid_config(format!(
"step: input length {} != d_model={d}",
x.len()
)));
}
let q = matmul(x, 1, d, &self.w_q, d)?;
let k = matmul(x, 1, d, &self.w_k, d)?;
let v = matmul(x, 1, d, &self.w_v, d)?;
let n_heads = self.config.n_heads;
let d_head = self.config.d_head;
let eps = self.config.eps;
let feature_map = self.config.feature_map;
let mut pre_out = vec![0.0_f32; d];
for h in 0..n_heads {
let head_start = h * d_head;
let head_end = head_start + d_head;
let q_h = &q[head_start..head_end];
let k_h = &k[head_start..head_end];
let v_h = &v[head_start..head_end];
let phi_q = feature_map.apply(q_h);
let phi_k = feature_map.apply(k_h);
let out_h = self.states[h].query(&phi_q, eps);
self.states[h].update(&phi_k, v_h);
for (j, &val) in out_h.iter().enumerate() {
pre_out[head_start + j] = val;
}
}
matmul(&pre_out, 1, d, &self.w_o, d)
}
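    /// Zeroes every per-head recurrent state.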
pub fn reset_states(&mut self) {
for s in &mut self.states {
s.reset();
}
}
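    /// Loads the four projection matrices from `weights`, looking up the keys
    /// `{prefix}w_q`, `{prefix}w_k`, `{prefix}w_v` and `{prefix}w_o`. Each entry
    /// must be a row-major `d_model × d_model` matrix; a missing or mis-sized
    /// entry produces an error (keys already loaded before the failure stay
    /// assigned).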
pub fn load_weights(
&mut self,
weights: &HashMap<String, Vec<f32>>,
prefix: &str,
) -> ModelResult<()> {
let d2 = self.config.d_model * self.config.d_model;
let fetch = |key: &str| -> ModelResult<Vec<f32>> {
let full_key = format!("{prefix}{key}");
let src = weights.get(&full_key).ok_or_else(|| {
ModelError::invalid_config(format!("load_weights: missing key '{full_key}'"))
})?;
if src.len() != d2 {
return Err(ModelError::invalid_config(format!(
"load_weights: key '{full_key}' has length {} but expected {d2}",
src.len()
)));
}
Ok(src.clone())
};
self.w_q = fetch("w_q")?;
self.w_k = fetch("w_k")?;
self.w_v = fetch("w_v")?;
self.w_o = fetch("w_o")?;
Ok(())
}
#[inline]
pub fn config(&self) -> &FlashLinearAttnConfig {
&self.config
}
#[inline]
pub fn states(&self) -> &[LinearAttnState] {
&self.states
}
#[inline]
pub fn states_mut(&mut self) -> &mut [LinearAttnState] {
&mut self.states
}
}
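/// Convenience wrapper that resets the states and then feeds the sequence to
/// `step` one token at a time, block by block. Because the work is delegated to
/// `step`, `chunk_size` only controls the outer blocking of the loop; the result
/// matches calling `step` over the whole sequence directly.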
pub fn chunked_forward(
attn: &mut FlashLinearAttention,
input: &[f32],
seq_len: usize,
) -> ModelResult<Vec<f32>> {
let d = attn.config().d_model;
    // Guard against a zero chunk size, which would otherwise make the loop below spin forever.
    let chunk_size = attn.config().chunk_size.max(1);
if input.len() != seq_len * d {
return Err(ModelError::invalid_config(format!(
"chunked_forward: input length {} != seq_len*d_model={}",
input.len(),
seq_len * d
)));
}
attn.reset_states();
let mut output = vec![0.0_f32; seq_len * d];
let mut t = 0usize;
while t < seq_len {
let end = (t + chunk_size).min(seq_len);
let chunk_len = end - t;
let chunk_input = &input[t * d..end * d];
for i in 0..chunk_len {
let token = &chunk_input[i * d..(i + 1) * d];
let tok_out = attn.step(token)?;
let dst = (t + i) * d;
output[dst..dst + d].copy_from_slice(&tok_out);
}
t = end;
}
Ok(output)
}
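/// Naive row-major matrix multiply `A (m × k) · B (k × n)`, skipping zero
/// entries of `A`. Returns an error if either slice has the wrong length.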
fn matmul(a: &[f32], m: usize, k: usize, b: &[f32], n: usize) -> ModelResult<Vec<f32>> {
if a.len() != m * k {
return Err(ModelError::invalid_config(format!(
"matmul: A length {} != m*k={}",
a.len(),
m * k
)));
}
if b.len() != k * n {
return Err(ModelError::invalid_config(format!(
"matmul: B length {} != k*n={}",
b.len(),
k * n
)));
}
let mut out = vec![0.0_f32; m * n];
for i in 0..m {
for l in 0..k {
let a_il = a[i * k + l];
if a_il == 0.0 {
continue;
}
let out_row = i * n;
let b_row = l * n;
for j in 0..n {
out[out_row + j] += a_il * b[b_row + j];
}
}
}
Ok(out)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_feature_map_elu_plus1_nonnegative() {
let phi = FeatureMap::EluPlus1;
let vals = vec![-2.0_f32, -1.0, 0.0, 1.0, 2.0];
let mapped = phi.apply(&vals);
for v in &mapped {
assert!(*v > 0.0, "ELU+1 should be positive, got {v}");
}
}
#[test]
fn test_feature_map_relu_nonnegative() {
let phi = FeatureMap::Relu;
let vals = vec![-1.0_f32, 0.0, 1.0];
let mapped = phi.apply(&vals);
assert_eq!(mapped, vec![0.0, 0.0, 1.0]);
}
#[test]
fn test_feature_map_identity() {
let phi = FeatureMap::Identity;
let vals = vec![-1.0_f32, 0.0, 1.0];
let mapped = phi.apply(&vals);
assert_eq!(mapped, vals);
}
#[test]
fn test_feature_map_elu_plus1_positive_inputs() {
let phi = FeatureMap::EluPlus1;
let vals = vec![0.0_f32, 1.0, 2.0];
let mapped = phi.apply(&vals);
assert!((mapped[0] - 1.0).abs() < 1e-6);
assert!((mapped[1] - 2.0).abs() < 1e-6);
assert!((mapped[2] - 3.0).abs() < 1e-6);
}
#[test]
fn test_feature_map_elu_plus1_negative_inputs() {
let phi = FeatureMap::EluPlus1;
let vals = vec![-1.0_f32];
let mapped = phi.apply(&vals);
let expected = (-1.0_f32).exp();
assert!((mapped[0] - expected).abs() < 1e-6);
}
#[test]
fn test_linear_attn_state_update_query() {
let mut state = LinearAttnState::new(4);
state.update(&[1.0_f32, 0.0, 0.0, 0.0], &[1.0, 2.0, 3.0, 4.0]);
let out = state.query(&[1.0_f32, 0.0, 0.0, 0.0], 1e-6);
assert!((out[0] - 1.0).abs() < 0.01, "out[0]={}", out[0]);
assert!((out[1] - 2.0).abs() < 0.01, "out[1]={}", out[1]);
assert!((out[2] - 3.0).abs() < 0.01, "out[2]={}", out[2]);
assert!((out[3] - 4.0).abs() < 0.01, "out[3]={}", out[3]);
}
#[test]
fn test_linear_attn_state_reset() {
let mut state = LinearAttnState::new(2);
state.update(&[1.0_f32, 1.0], &[1.0, 1.0]);
state.reset();
let out = state.query(&[1.0_f32, 1.0], 1e-6);
assert!(out[0].is_finite(), "output should be finite after reset");
assert!(out[1].is_finite(), "output should be finite after reset");
}
#[test]
fn test_linear_attn_state_accumulates() {
let mut state = LinearAttnState::new(2);
state.update(&[1.0_f32, 0.0], &[2.0, 3.0]);
state.update(&[0.0_f32, 1.0], &[4.0, 5.0]);
let out = state.query(&[1.0_f32, 0.0], 1e-6);
assert!((out[0] - 2.0).abs() < 0.01, "out[0]={}", out[0]);
assert!((out[1] - 3.0).abs() < 0.01, "out[1]={}", out[1]);
}
#[test]
fn test_flash_linear_attn_config_new() {
let cfg = FlashLinearAttnConfig::new(64, 4).unwrap();
assert_eq!(cfg.d_head, 16);
assert_eq!(cfg.chunk_size, 64);
assert_eq!(cfg.n_heads, 4);
assert_eq!(cfg.d_model, 64);
assert_eq!(cfg.eps, 1e-6);
assert!(cfg.causal);
}
#[test]
fn test_flash_linear_attn_config_invalid() {
let result = FlashLinearAttnConfig::new(65, 4);
assert!(result.is_err(), "65 not divisible by 4 should error");
}
#[test]
fn test_flash_linear_attn_config_single_head() {
let cfg = FlashLinearAttnConfig::new(16, 1).unwrap();
assert_eq!(cfg.d_head, 16);
}
#[test]
fn test_flash_linear_attention_forward_shape() {
let config = FlashLinearAttnConfig::new(8, 2).unwrap();
let mut attn = FlashLinearAttention::new(config).unwrap();
let seq_len = 4;
let d_model = 8;
let input = vec![0.1_f32; seq_len * d_model];
let output = attn.forward(&input, seq_len).unwrap();
assert_eq!(output.len(), seq_len * d_model);
}
#[test]
fn test_flash_linear_attention_finite_output() {
let config = FlashLinearAttnConfig::new(8, 2).unwrap();
let mut attn = FlashLinearAttention::new(config).unwrap();
let input: Vec<f32> = (0..32).map(|i| i as f32 * 0.01).collect();
let output = attn.forward(&input, 4).unwrap();
for &v in &output {
assert!(v.is_finite(), "output should be finite, got {v}");
}
}
#[test]
    fn test_flash_linear_attention_step_matches_forward() {
        let config = FlashLinearAttnConfig::new(8, 2).unwrap();
        let mut attn = FlashLinearAttention::new(config).unwrap();
        let d = 8;
        let seq_len = 2;
        let input: Vec<f32> = (0..seq_len * d).map(|i| i as f32 * 0.01).collect();
        // A causal forward pass over the whole sequence...
        let full = attn.forward(&input, seq_len).unwrap();
        // ...should agree with running the recurrence one token at a time.
        attn.reset_states();
        for t in 0..seq_len {
            let step_out = attn.step(&input[t * d..(t + 1) * d]).unwrap();
            assert_eq!(step_out.len(), d);
            for j in 0..d {
                assert!(step_out[j].is_finite());
                assert!(
                    (full[t * d + j] - step_out[j]).abs() < 1e-5,
                    "step/forward mismatch at t={t}, j={j}"
                );
            }
        }
    }
#[test]
fn test_flash_linear_attention_invalid_input_length() {
let config = FlashLinearAttnConfig::new(8, 2).unwrap();
let mut attn = FlashLinearAttention::new(config).unwrap();
let input = vec![0.1_f32; 24];
let result = attn.forward(&input, 4);
assert!(result.is_err());
}
#[test]
fn test_flash_linear_attention_step_invalid_length() {
let config = FlashLinearAttnConfig::new(8, 2).unwrap();
let mut attn = FlashLinearAttention::new(config).unwrap();
        let result = attn.step(&[0.1_f32; 5]);
        assert!(result.is_err());
}
#[test]
fn test_flash_linear_attention_reset_states() {
let config = FlashLinearAttnConfig::new(8, 2).unwrap();
let mut attn = FlashLinearAttention::new(config).unwrap();
let _ = attn.step(&[1.0_f32; 8]).unwrap();
attn.reset_states();
for state in attn.states() {
assert!(state.s.iter().all(|&v| v == 0.0));
assert!(state.z.iter().all(|&v| v == 0.0));
}
}
#[test]
fn test_flash_linear_attention_load_weights_missing_key() {
let config = FlashLinearAttnConfig::new(8, 2).unwrap();
let mut attn = FlashLinearAttention::new(config).unwrap();
let weights = HashMap::new();
let result = attn.load_weights(&weights, "layer.");
assert!(result.is_err());
}
#[test]
fn test_flash_linear_attention_load_weights_wrong_size() {
let config = FlashLinearAttnConfig::new(8, 2).unwrap();
let mut attn = FlashLinearAttention::new(config).unwrap();
let mut weights = HashMap::new();
weights.insert("layer.w_q".to_string(), vec![0.0_f32; 10]);
weights.insert("layer.w_k".to_string(), vec![0.0_f32; 64]);
weights.insert("layer.w_v".to_string(), vec![0.0_f32; 64]);
weights.insert("layer.w_o".to_string(), vec![0.0_f32; 64]);
let result = attn.load_weights(&weights, "layer.");
assert!(result.is_err());
}
#[test]
fn test_flash_linear_attention_load_weights_ok() {
let config = FlashLinearAttnConfig::new(8, 2).unwrap();
let mut attn = FlashLinearAttention::new(config).unwrap();
let d2 = 64_usize;
let mut weights = HashMap::new();
weights.insert("attn.w_q".to_string(), vec![0.1_f32; d2]);
weights.insert("attn.w_k".to_string(), vec![0.2_f32; d2]);
weights.insert("attn.w_v".to_string(), vec![0.3_f32; d2]);
weights.insert("attn.w_o".to_string(), vec![0.4_f32; d2]);
attn.load_weights(&weights, "attn.").unwrap();
assert_eq!(attn.w_q[0], 0.1);
assert_eq!(attn.w_k[0], 0.2);
assert_eq!(attn.w_v[0], 0.3);
assert_eq!(attn.w_o[0], 0.4);
}
#[test]
fn test_chunked_forward_shape() {
let config = FlashLinearAttnConfig::new(8, 2).unwrap();
let mut attn = FlashLinearAttention::new(config).unwrap();
let seq_len = 16;
let d = 8;
let input: Vec<f32> = (0..seq_len * d).map(|i| i as f32 * 0.001).collect();
let out = chunked_forward(&mut attn, &input, seq_len).unwrap();
assert_eq!(out.len(), seq_len * d);
for &v in &out {
assert!(v.is_finite(), "chunked output must be finite");
}
}
#[test]
fn test_matmul_identity() {
let identity = vec![1.0_f32, 0.0, 0.0, 1.0];
let a = vec![1.0_f32, 2.0, 3.0, 4.0];
let out = matmul(&identity, 2, 2, &a, 2).unwrap();
assert!((out[0] - 1.0).abs() < 1e-6);
assert!((out[1] - 2.0).abs() < 1e-6);
assert!((out[2] - 3.0).abs() < 1e-6);
assert!((out[3] - 4.0).abs() < 1e-6);
}
#[test]
fn test_matmul_size_mismatch() {
let a = vec![1.0_f32; 4];
let b = vec![1.0_f32; 9];
let result = matmul(&a, 2, 2, &b, 3);
assert!(result.is_err());
}
}