use alloc::vec;
use alloc::vec::Vec;
use super::state::AttentionState;
use crate::math;
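/// Additive (linear-attention) update with a scalar forget gate:
/// `S <- decay * S + k v^T`, i.e. `S[i][j] = decay * S[i][j] + k[i] * v[j]`.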
pub fn additive_update(state: &mut AttentionState, k: &[f64], v: &[f64], decay: f64) {
state.scale(decay);
state.add_outer_product(k, v);
}
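/// Additive update with a per-row decay vector, as in gated linear
/// attention (GLA): `S <- diag(alpha) S + k v^T`. Row `i` of the state
/// decays by `alpha[i]` before the outer-product write.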
pub fn additive_update_vec(state: &mut AttentionState, k: &[f64], v: &[f64], alpha: &[f64]) {
state.scale_per_row(alpha);
state.add_outer_product(k, v);
}
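/// DeltaNet-style delta-rule update: read the state's current prediction
/// for `k`, `pred = S^T k`, then write back only the error,
/// `S <- S + k (v - pred)^T`. Re-writing an existing key replaces its
/// stored value rather than accumulating on top of it.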
pub fn delta_update(state: &mut AttentionState, k: &[f64], v: &[f64]) {
let pred = state.query(k);
let d_v = v.len();
let mut error = vec![0.0; d_v];
for j in 0..d_v {
error[j] = v[j] - pred[j];
}
state.add_outer_product(k, &error);
}
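/// Gated delta rule: a scalar decay followed by a delta-rule write with a
/// normalized key and a write-strength scale:
/// `S <- decay * S`, `k_hat = k / ||k||`,
/// `S <- S + beta_scale * k_hat (v - S^T k_hat)^T`.
/// Keys with norm below `1e-12` are treated as zero to avoid dividing by
/// a vanishing norm.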
pub fn gated_delta_update(
state: &mut AttentionState,
k: &[f64],
v: &[f64],
decay: f64,
beta_scale: f64,
) {
state.scale(decay);
let d_k = k.len();
let norm_sq: f64 = k.iter().map(|&x| x * x).sum();
let norm = math::sqrt(norm_sq);
    let k_norm: Vec<f64> = if norm < 1e-12 {
vec![0.0; d_k]
} else {
let inv = 1.0 / norm;
k.iter().map(|&x| x * inv).collect()
};
let pred = state.query(&k_norm);
let d_v = v.len();
let mut error = vec![0.0; d_v];
for j in 0..d_v {
error[j] = beta_scale * (v[j] - pred[j]);
}
state.add_outer_product(&k_norm, &error);
}
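/// Decayed update with an element-wise exponential feature map on the
/// key: `S <- exp(-w) * S + exp(k) v^T`, where `exp(k)` is applied
/// component-wise.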
pub fn exponential_update(state: &mut AttentionState, k: &[f64], v: &[f64], w: f64) {
let decay = math::exp(-w);
state.scale(decay);
let d_k = k.len();
let mut exp_k = vec![0.0; d_k];
for i in 0..d_k {
exp_k[i] = math::exp(k[i]);
}
state.add_outer_product(&exp_k, v);
}
pub fn hawk_update(state: &mut AttentionState, x: &[f64], alpha: &[f64], beta: &[f64]) {
match state {
AttentionState::Vector(h) => {
debug_assert_eq!(h.len(), x.len(), "state and input must have same length");
debug_assert_eq!(
h.len(),
alpha.len(),
"state and alpha must have same length"
);
debug_assert_eq!(h.len(), beta.len(), "state and beta must have same length");
for i in 0..h.len() {
h[i] = alpha[i] * h[i] + beta[i] * x[i];
}
}
AttentionState::Matrix { .. } => panic!("hawk_update requires Vector state"),
}
}
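/// mLSTM-style matrix cell update with scalar forget and input gates:
/// `C <- forget * C + input * k v^T`. Only the matrix cell state is
/// modeled here; the separate normalizer state of the full mLSTM is not.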
pub fn mlstm_update(state: &mut AttentionState, k: &[f64], v: &[f64], forget: f64, input: f64) {
state.scale(forget);
let d_v = v.len();
let mut scaled_v = vec![0.0; d_v];
for (j, sv) in scaled_v.iter_mut().enumerate() {
*sv = input * v[j];
}
state.add_outer_product(k, &scaled_v);
}
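/// DeltaProduct update: one scalar gate followed by `n` delta-rule
/// micro-steps, one per `(key, value, beta)` triple. Each micro-step
/// computes `pred = S^T k_j` and applies
/// `S <- S + k_j (beta_j * (v_j - pred))^T`, so the overall transition is
/// a product of rank-1 corrections. A single step with `beta = 1` and
/// `gate = 1` reduces to `delta_update`.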
pub fn delta_product_update(
state: &mut AttentionState,
keys: &[&[f64]],
values: &[&[f64]],
betas: &[f64],
gate: f64,
) {
let n = betas.len();
debug_assert_eq!(keys.len(), n, "keys length must match n_compositions");
debug_assert_eq!(values.len(), n, "values length must match n_compositions");
state.scale(gate);
for j in 0..n {
let pred = state.query(keys[j]);
let d_v = values[j].len();
let mut error = vec![0.0; d_v];
for idx in 0..d_v {
error[idx] = betas[j] * (values[j][idx] - pred[idx]);
}
state.add_outer_product(keys[j], &error);
}
}
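/// RWKV-7 style update: per-row decay, removal of the state component
/// along the (assumed unit-norm) removal key `kappa_hat`, scaled
/// per-dimension by `a`, then a fresh outer-product write:
/// `S <- diag(w) S`,
/// `S <- S - (a * kappa_hat) (S^T kappa_hat)^T` (`a * kappa_hat` element-wise),
/// `S <- S + k_tilde v^T`.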
pub fn rwkv7_update(
state: &mut AttentionState,
w: &[f64],
kappa_hat: &[f64],
a: &[f64],
k_tilde: &[f64],
v: &[f64],
) {
state.scale_per_row(w);
let proj = state.query(kappa_hat);
let d_k = kappa_hat.len();
let mut a_kappa = vec![0.0; d_k];
for i in 0..d_k {
        a_kappa[i] = -(a[i] * kappa_hat[i]);
    }
state.add_outer_product(&a_kappa, &proj);
state.add_outer_product(k_tilde, v);
}
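/// HGRN2-style update: per-row decay plus an outer-product write,
/// `S <- diag(alpha) S + k v^T`. Structurally the same as
/// `additive_update_vec`; HGRN2 additionally ties the key to the forget
/// gate (typically `k = 1 - alpha`), which callers are expected to do
/// before calling this function.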
pub fn hgrn2_update(state: &mut AttentionState, k: &[f64], v: &[f64], alpha: &[f64]) {
state.scale_per_row(alpha);
state.add_outer_product(k, v);
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn additive_update_from_zero_state() {
let mut state = AttentionState::new_matrix(2, 3);
let k = [1.0, 2.0];
let v = [3.0, 4.0, 5.0];
additive_update(&mut state, &k, &v, 0.9);
assert!(
(state.get_matrix(0, 0) - 3.0).abs() < 1e-12,
"S[0][0] should be 1*3=3, got {}",
state.get_matrix(0, 0)
);
assert!(
(state.get_matrix(1, 2) - 10.0).abs() < 1e-12,
"S[1][2] should be 2*5=10, got {}",
state.get_matrix(1, 2)
);
}
#[test]
fn additive_update_decay_applied() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 10.0);
state.set_matrix(1, 1, 20.0);
let k = [0.0, 0.0];
let v = [0.0, 0.0];
additive_update(&mut state, &k, &v, 0.5);
assert!(
(state.get_matrix(0, 0) - 5.0).abs() < 1e-12,
"decayed S[0][0] should be 10*0.5=5, got {}",
state.get_matrix(0, 0)
);
assert!(
(state.get_matrix(1, 1) - 10.0).abs() < 1e-12,
"decayed S[1][1] should be 20*0.5=10, got {}",
state.get_matrix(1, 1)
);
}
#[test]
fn delta_update_error_corrective() {
let mut state = AttentionState::new_matrix(2, 2);
let k = [1.0, 0.0];
let v = [5.0, 3.0];
delta_update(&mut state, &k, &v);
let out = state.query(&k);
assert!(
(out[0] - 5.0).abs() < 1e-12,
"after delta write, read-back should be ~5.0, got {}",
out[0]
);
assert!(
(out[1] - 3.0).abs() < 1e-12,
"after delta write, read-back should be ~3.0, got {}",
out[1]
);
}
#[test]
fn delta_update_corrects_existing() {
let mut state = AttentionState::new_matrix(2, 2);
let k = [1.0, 0.0];
let v1 = [5.0, 3.0];
delta_update(&mut state, &k, &v1);
let v2 = [10.0, 7.0];
delta_update(&mut state, &k, &v2);
let out = state.query(&k);
assert!(
(out[0] - 10.0).abs() < 1e-12,
"after second delta write, should read 10.0, got {}",
out[0]
);
assert!(
(out[1] - 7.0).abs() < 1e-12,
"after second delta write, should read 7.0, got {}",
out[1]
);
}
#[test]
fn gated_delta_update_combines_decay_and_correction() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 100.0);
let k = [1.0, 0.0];
let v = [5.0, 3.0];
gated_delta_update(&mut state, &k, &v, 0.0, 1.0);
let out = state.query(&k);
assert!(
(out[0] - 5.0).abs() < 1e-12,
"with decay=0, should read fresh value 5.0, got {}",
out[0]
);
}
#[test]
fn exponential_update_changes_state() {
let mut state = AttentionState::new_matrix(2, 3);
let k = [0.1, -0.1];
let v = [1.0, 2.0, 3.0];
exponential_update(&mut state, &k, &v, 0.5);
let s = state.as_slice();
        let sum: f64 = s.iter().map(|&x| math::abs(x)).sum();
assert!(
sum > 0.0,
"state should be non-zero after exponential update"
);
}
#[test]
fn exponential_update_exp_k_applied() {
let mut state = AttentionState::new_matrix(1, 1);
        let k = [0.0];
        let v = [7.0];
exponential_update(&mut state, &k, &v, 0.0);
assert!(
(state.get_matrix(0, 0) - 7.0).abs() < 1e-12,
"with w=0 and k=0, state should be exp(0)*7=7, got {}",
state.get_matrix(0, 0)
);
}
#[test]
fn hawk_update_vector_recurrence() {
let mut state = AttentionState::new_vector(3);
let x = [1.0, 2.0, 3.0];
let alpha = [0.9, 0.8, 0.7];
let beta = [0.1, 0.2, 0.3];
hawk_update(&mut state, &x, &alpha, &beta);
let s = state.as_slice();
assert!(
(s[0] - 0.1).abs() < 1e-12,
"h[0] should be 0.1*1=0.1, got {}",
s[0]
);
assert!(
(s[1] - 0.4).abs() < 1e-12,
"h[1] should be 0.2*2=0.4, got {}",
s[1]
);
assert!(
(s[2] - 0.9).abs() < 1e-12,
"h[2] should be 0.3*3=0.9, got {}",
s[2]
);
}
#[test]
fn hawk_update_accumulates() {
let mut state = AttentionState::new_vector(2);
let alpha = [0.5, 0.5];
let beta = [1.0, 1.0];
hawk_update(&mut state, &[2.0, 4.0], &alpha, &beta);
hawk_update(&mut state, &[1.0, 1.0], &alpha, &beta);
let s = state.as_slice();
assert!(
(s[0] - 2.0).abs() < 1e-12,
"h[0] should be 2.0, got {}",
s[0]
);
assert!(
(s[1] - 3.0).abs() < 1e-12,
"h[1] should be 3.0, got {}",
s[1]
);
}
#[test]
fn mlstm_update_from_zero() {
let mut state = AttentionState::new_matrix(2, 2);
let k = [1.0, 0.0];
let v = [5.0, 3.0];
mlstm_update(&mut state, &k, &v, 0.9, 0.8);
assert!(
(state.get_matrix(0, 0) - 4.0).abs() < 1e-12,
"S[0][0] should be 0.8*5*1=4.0, got {}",
state.get_matrix(0, 0)
);
assert!(
(state.get_matrix(0, 1) - 2.4).abs() < 1e-12,
"S[0][1] should be 0.8*3*1=2.4, got {}",
state.get_matrix(0, 1)
);
assert!(
state.get_matrix(1, 0).abs() < 1e-12,
"S[1][0] should be 0.8*5*0=0, got {}",
state.get_matrix(1, 0)
);
}
#[test]
fn mlstm_forget_gate_decays_state() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 10.0);
state.set_matrix(1, 1, 20.0);
let k = [0.0, 0.0];
let v = [0.0, 0.0];
mlstm_update(&mut state, &k, &v, 0.5, 1.0);
assert!(
(state.get_matrix(0, 0) - 5.0).abs() < 1e-12,
"forget gate 0.5 should halve state, got {}",
state.get_matrix(0, 0)
);
}
#[test]
fn delta_product_single_step_matches_delta() {
let mut state1 = AttentionState::new_matrix(2, 2);
let mut state2 = AttentionState::new_matrix(2, 2);
        let k = [0.6, 0.8];
        let v = [5.0, 3.0];
delta_update(&mut state1, &k, &v);
delta_product_update(&mut state2, &[&k[..]], &[&v[..]], &[1.0], 1.0);
let s1 = state1.as_slice();
let s2 = state2.as_slice();
for i in 0..s1.len() {
assert!(
(s1[i] - s2[i]).abs() < 1e-12,
"single-step DeltaProduct should match DeltaNet at {}: {} vs {}",
i,
s1[i],
s2[i]
);
}
}
#[test]
fn delta_product_multi_step_changes_state() {
let mut state = AttentionState::new_matrix(2, 2);
let k1 = [1.0, 0.0];
let k2 = [0.0, 1.0];
let v1 = [3.0, 4.0];
let v2 = [5.0, 6.0];
delta_product_update(
&mut state,
&[&k1[..], &k2[..]],
&[&v1[..], &v2[..]],
&[1.0, 1.0],
1.0,
);
let s = state.as_slice();
let sum: f64 = s.iter().map(|x| math::abs(*x)).sum();
assert!(sum > 0.0, "multi-step should produce non-zero state");
}
#[test]
fn delta_product_gate_decays_state() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 10.0);
state.set_matrix(1, 1, 20.0);
let k = [1.0, 0.0];
let v = [0.0, 0.0];
delta_product_update(&mut state, &[&k[..]], &[&v[..]], &[1.0], 0.5);
        assert!(
            state.get_matrix(0, 0).abs() < 1e-12,
            "after gate decay, the delta step should correct S[0][0] to the target 0, got {}",
            state.get_matrix(0, 0)
        );
}
#[test]
fn delta_product_beta_two_reflects() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 10.0);
let k = [1.0, 0.0];
let v = [0.0, 0.0];
delta_product_update(&mut state, &[&k[..]], &[&v[..]], &[2.0], 1.0);
assert!(
(state.get_matrix(0, 0) - (-10.0)).abs() < 1e-12,
"beta=2 should reflect: got {}",
state.get_matrix(0, 0)
);
}
#[test]
fn rwkv7_update_from_zero() {
let mut state = AttentionState::new_matrix(2, 2);
let w = [0.9, 0.8];
        let kappa_hat = [1.0, 0.0];
        let a = [0.5, 0.5];
let k_tilde = [0.6, 0.8];
let v = [3.0, 7.0];
rwkv7_update(&mut state, &w, &kappa_hat, &a, &k_tilde, &v);
assert!(
(state.get_matrix(0, 0) - 1.8).abs() < 1e-12,
"from zero, S[0][0] = 0.6*3 = 1.8, got {}",
state.get_matrix(0, 0)
);
assert!(
(state.get_matrix(1, 1) - 5.6).abs() < 1e-12,
"from zero, S[1][1] = 0.8*7 = 5.6, got {}",
state.get_matrix(1, 1)
);
}
#[test]
fn rwkv7_decay_per_dimension() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 10.0);
state.set_matrix(1, 1, 10.0);
        let w = [0.5, 0.9];
        let kappa_hat = [0.0, 0.0];
        let a = [0.0, 0.0];
        let k_tilde = [0.0, 0.0];
        let v = [0.0, 0.0];
rwkv7_update(&mut state, &w, &kappa_hat, &a, &k_tilde, &v);
assert!(
(state.get_matrix(0, 0) - 5.0).abs() < 1e-12,
"row 0 decayed by 0.5: 10*0.5=5, got {}",
state.get_matrix(0, 0)
);
assert!(
(state.get_matrix(1, 1) - 9.0).abs() < 1e-12,
"row 1 decayed by 0.9: 10*0.9=9, got {}",
state.get_matrix(1, 1)
);
}
#[test]
fn rwkv7_delta_removal() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 5.0);
state.set_matrix(0, 1, 3.0);
        let w = [1.0, 1.0];
        let kappa_hat = [1.0, 0.0];
        let a = [1.0, 1.0];
        let k_tilde = [0.0, 0.0];
        let v = [0.0, 0.0];
rwkv7_update(&mut state, &w, &kappa_hat, &a, &k_tilde, &v);
assert!(
state.get_matrix(0, 0).abs() < 1e-12,
"full removal should clear row 0, got {}",
state.get_matrix(0, 0)
);
assert!(
state.get_matrix(0, 1).abs() < 1e-12,
"full removal should clear row 0, got {}",
state.get_matrix(0, 1)
);
}
#[test]
fn rwkv7_combined_remove_and_write() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 10.0);
state.set_matrix(0, 1, 20.0);
        let w = [1.0, 1.0];
        let kappa_hat = [1.0, 0.0];
        let a = [1.0, 1.0];
        let k_tilde = [0.0, 1.0];
        let v = [5.0, 3.0];
rwkv7_update(&mut state, &w, &kappa_hat, &a, &k_tilde, &v);
assert!(
state.get_matrix(0, 0).abs() < 1e-12,
"removed association should be cleared"
);
        assert!(
            (state.get_matrix(1, 0) - 5.0).abs() < 1e-12,
            "new association: key [0,1] should map to value [5,3], so S[1][0]=5, got {}",
            state.get_matrix(1, 0)
        );
}
#[test]
fn all_updates_change_state_from_zero() {
let k = [1.0, 0.5];
let v = [2.0, 3.0];
let x = [1.0, 2.0];
let alpha = [0.9, 0.8];
let beta = [0.1, 0.2];
let mut s1 = AttentionState::new_matrix(2, 2);
additive_update(&mut s1, &k, &v, 0.9);
let sum1: f64 = s1.as_slice().iter().map(|x| math::abs(*x)).sum();
assert!(sum1 > 0.0, "additive_update should change state");
let mut s2 = AttentionState::new_matrix(2, 2);
delta_update(&mut s2, &k, &v);
let sum2: f64 = s2.as_slice().iter().map(|x| math::abs(*x)).sum();
assert!(sum2 > 0.0, "delta_update should change state");
let mut s3 = AttentionState::new_matrix(2, 2);
gated_delta_update(&mut s3, &k, &v, 0.9, 1.0);
let sum3: f64 = s3.as_slice().iter().map(|x| math::abs(*x)).sum();
assert!(sum3 > 0.0, "gated_delta_update should change state");
let mut s4 = AttentionState::new_matrix(2, 2);
exponential_update(&mut s4, &k, &v, 0.5);
let sum4: f64 = s4.as_slice().iter().map(|x| math::abs(*x)).sum();
assert!(sum4 > 0.0, "exponential_update should change state");
let mut s5 = AttentionState::new_vector(2);
hawk_update(&mut s5, &x, &alpha, &beta);
let sum5: f64 = s5.as_slice().iter().map(|x| math::abs(*x)).sum();
assert!(sum5 > 0.0, "hawk_update should change state");
let mut s6 = AttentionState::new_matrix(2, 2);
mlstm_update(&mut s6, &k, &v, 0.9, 0.8);
let sum6: f64 = s6.as_slice().iter().map(|x| math::abs(*x)).sum();
assert!(sum6 > 0.0, "mlstm_update should change state");
let mut s7 = AttentionState::new_matrix(2, 2);
delta_product_update(&mut s7, &[&k[..]], &[&v[..]], &[1.0], 1.0);
let sum7: f64 = s7.as_slice().iter().map(|x| math::abs(*x)).sum();
assert!(sum7 > 0.0, "delta_product_update should change state");
let mut s8 = AttentionState::new_matrix(2, 2);
rwkv7_update(&mut s8, &[0.9, 0.8], &k, &[0.5, 0.5], &k, &v);
let sum8: f64 = s8.as_slice().iter().map(|x| math::abs(*x)).sum();
assert!(sum8 > 0.0, "rwkv7_update should change state");
let mut s9 = AttentionState::new_matrix(2, 2);
hgrn2_update(&mut s9, &k, &v, &[0.95, 0.9]);
let sum9: f64 = s9.as_slice().iter().map(|x| math::abs(*x)).sum();
assert!(sum9 > 0.0, "hgrn2_update should change state");
}
#[test]
fn gated_delta_net_beta_scale_default_matches_original() {
let mut state1 = AttentionState::new_matrix(2, 2);
let mut state2 = AttentionState::new_matrix(2, 2);
        // k has unit norm, so gated_delta_update's key normalization is a no-op here.
        let k = [0.6, 0.8];
        let v = [5.0, 3.0];
let decay = 0.9;
state1.scale(decay);
let pred = state1.query(&k);
let mut error = vec![0.0; 2];
for j in 0..2 {
error[j] = v[j] - pred[j];
}
state1.add_outer_product(&k, &error);
gated_delta_update(&mut state2, &k, &v, decay, 1.0);
let s1 = state1.as_slice();
let s2 = state2.as_slice();
for i in 0..s1.len() {
assert!(
(s1[i] - s2[i]).abs() < 1e-12,
"beta_scale=1.0 should match original at index {}: {} vs {}",
i,
s1[i],
s2[i]
);
}
}
#[test]
fn gated_delta_net_key_normalization_bounded_state() {
let mut state = AttentionState::new_matrix(2, 2);
let v = [1.0, 1.0];
let decay = 0.95;
for i in 0..100 {
let scale = (i + 1) as f64 * 10.0;
let k = [scale, scale];
gated_delta_update(&mut state, &k, &v, decay, 1.0);
}
let state_norm_sq: f64 = state.as_slice().iter().map(|&x| x * x).sum();
let state_norm = math::sqrt(state_norm_sq);
assert!(
state_norm < 100.0,
"state norm should be bounded with normalized keys, got {}",
state_norm
);
}
#[test]
fn gated_delta_net_beta_scale_zero_freezes_state() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 10.0);
state.set_matrix(1, 1, 20.0);
let k = [1.0, 0.0];
        let v = [999.0, 888.0];
        let decay = 0.5;
gated_delta_update(&mut state, &k, &v, decay, 0.0);
assert!(
(state.get_matrix(0, 0) - 5.0).abs() < 1e-12,
"with beta=0, S[0][0] should be 10*0.5=5.0, got {}",
state.get_matrix(0, 0)
);
assert!(
(state.get_matrix(1, 1) - 10.0).abs() < 1e-12,
"with beta=0, S[1][1] should be 20*0.5=10.0, got {}",
state.get_matrix(1, 1)
);
assert!(
state.get_matrix(0, 1).abs() < 1e-12,
"with beta=0, S[0][1] should remain 0, got {}",
state.get_matrix(0, 1)
);
assert!(
state.get_matrix(1, 0).abs() < 1e-12,
"with beta=0, S[1][0] should remain 0, got {}",
state.get_matrix(1, 0)
);
}
#[test]
fn hgrn2_update_basic() {
let mut state = AttentionState::new_matrix(2, 3);
let k = [1.0, 2.0];
let v = [3.0, 4.0, 5.0];
        let alpha = [0.9, 0.8];
        hgrn2_update(&mut state, &k, &v, &alpha);
assert!(
(state.get_matrix(0, 0) - 3.0).abs() < 1e-12,
"S[0][0] should be 1*3=3, got {}",
state.get_matrix(0, 0)
);
assert!(
(state.get_matrix(1, 2) - 10.0).abs() < 1e-12,
"S[1][2] should be 2*5=10, got {}",
state.get_matrix(1, 2)
);
}
#[test]
fn hgrn2_lower_bound_ensures_retention() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 100.0);
state.set_matrix(1, 1, 200.0);
        let k = [0.0, 0.0];
        let v = [0.0, 0.0];
        let alpha = [0.99, 0.99];
        hgrn2_update(&mut state, &k, &v, &alpha);
assert!(
(state.get_matrix(0, 0) - 99.0).abs() < 1e-12,
"with alpha=0.99, S[0][0] should be 100*0.99=99, got {}",
state.get_matrix(0, 0)
);
assert!(
(state.get_matrix(1, 1) - 198.0).abs() < 1e-12,
"with alpha=0.99, S[1][1] should be 200*0.99=198, got {}",
state.get_matrix(1, 1)
);
}
#[test]
fn hgrn2_lower_bound_zero_matches_gla() {
let mut state_hgrn2 = AttentionState::new_matrix(2, 2);
let mut state_gla = AttentionState::new_matrix(2, 2);
let k = [1.0, 0.5];
let v = [2.0, 3.0];
let decay = 0.7;
let alpha = [decay, decay];
hgrn2_update(&mut state_hgrn2, &k, &v, &alpha);
additive_update(&mut state_gla, &k, &v, decay);
let s1 = state_hgrn2.as_slice();
let s2 = state_gla.as_slice();
for i in 0..s1.len() {
assert!(
(s1[i] - s2[i]).abs() < 1e-12,
"HGRN2 with uniform alpha should match GLA at {}: {} vs {}",
i,
s1[i],
s2[i]
);
}
}
#[test]
fn hgrn2_per_dimension_decay() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 10.0);
state.set_matrix(1, 1, 10.0);
        let k = [0.0, 0.0];
        let v = [0.0, 0.0];
        let alpha = [0.5, 0.9];
        hgrn2_update(&mut state, &k, &v, &alpha);
assert!(
(state.get_matrix(0, 0) - 5.0).abs() < 1e-12,
"row 0 decayed by 0.5: 10*0.5=5, got {}",
state.get_matrix(0, 0)
);
assert!(
(state.get_matrix(1, 1) - 9.0).abs() < 1e-12,
"row 1 decayed by 0.9: 10*0.9=9, got {}",
state.get_matrix(1, 1)
);
}
#[test]
fn additive_update_vec_from_zero_state() {
let mut state = AttentionState::new_matrix(2, 2);
let k = [1.0, 2.0];
let v = [3.0, 4.0];
        let alpha = [0.9, 0.8];
        additive_update_vec(&mut state, &k, &v, &alpha);
assert!(
(state.get_matrix(0, 0) - 3.0).abs() < 1e-12,
"S[0][0] should be 1*3=3, got {}",
state.get_matrix(0, 0)
);
assert!(
(state.get_matrix(1, 1) - 8.0).abs() < 1e-12,
"S[1][1] should be 2*4=8, got {}",
state.get_matrix(1, 1)
);
}
#[test]
fn additive_update_vec_per_row_decay() {
let mut state = AttentionState::new_matrix(2, 2);
state.set_matrix(0, 0, 10.0);
state.set_matrix(1, 1, 20.0);
        let k = [0.0, 0.0];
        let v = [0.0, 0.0];
        let alpha = [0.5, 0.9];
        additive_update_vec(&mut state, &k, &v, &alpha);
assert!(
(state.get_matrix(0, 0) - 5.0).abs() < 1e-12,
"row 0 decayed by 0.5: 10*0.5=5, got {}",
state.get_matrix(0, 0)
);
assert!(
(state.get_matrix(1, 1) - 18.0).abs() < 1e-12,
"row 1 decayed by 0.9: 20*0.9=18, got {}",
state.get_matrix(1, 1)
);
}
#[test]
fn additive_update_vec_uniform_alpha_matches_scalar() {
let mut state1 = AttentionState::new_matrix(2, 2);
let mut state2 = AttentionState::new_matrix(2, 2);
let k = [1.0, 0.5];
let v = [2.0, 3.0];
let decay = 0.7;
let alpha = [decay, decay];
additive_update_vec(&mut state1, &k, &v, &alpha);
additive_update(&mut state2, &k, &v, decay);
let s1 = state1.as_slice();
let s2 = state2.as_slice();
for i in 0..s1.len() {
assert!(
(s1[i] - s2[i]).abs() < 1e-12,
"uniform alpha vec should match scalar at {}: {} vs {}",
i,
s1[i],
s2[i]
);
}
}
}