use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;
use super::config::AttentionMode;
#[cfg(test)]
use super::config::GatedDeltaMode;
use super::gating::{init_weights, mat_vec, Xorshift64};
use super::log_linear_state::LogLinearState;
use super::AttentionLayer;
use crate::math;
use crate::streaming_primitives::{softplus_softmax_mix, tanh_inplace};
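/// Default cap on the number of state levels. Assuming each level ℓ aggregates on
/// the order of 2^ℓ leaves (binary-counter merging), 32 levels cover any practical
/// stream length.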
pub const DEFAULT_MAX_LEVELS: usize = 32;
pub fn default_lambda_init(max_levels: usize) -> f64 {
1.0 / (max_levels as f64).max(1.0)
}
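/// Default temperature for the softplus-softmax that normalizes the level weights λ.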
pub const DEFAULT_TAU: f64 = 1.0;
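/// Default step size for the online SGD updates in `train_one`.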
pub const DEFAULT_LEARNING_RATE: f64 = 0.05;
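/// Log-linear attention: wraps an inner linear-attention update rule with a
/// logarithmic hierarchy of states. Each level of `LogLinearState` summarizes a
/// span of the stream, and an input-dependent distribution λ (one weight per level)
/// mixes the per-level readouts into a single output, so memory and per-step work
/// grow with the number of levels, i.e. O(log T) for a stream of length T.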
pub struct LogLinearAttention {
inner_mode: Box<AttentionMode>,
state: LogLinearState,
w_key: Vec<f64>,
w_value: Vec<f64>,
w_query: Vec<f64>,
w_lambda: Vec<f64>,
lambda_bias: f64,
d_model: usize,
d_key: usize,
d_value: usize,
max_levels: usize,
tau: f64,
learning_rate: f64,
train_step_count: u64,
scratch_lambda_raw: Vec<f64>,
scratch_lambda: Vec<f64>,
scratch_k: Vec<f64>,
scratch_v: Vec<f64>,
scratch_q: Vec<f64>,
}
impl LogLinearAttention {
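    /// Builds a fresh layer: the projection matrices `W_k`, `W_v`, `W_q` and the
    /// level-weight matrix `W_λ` are randomly initialized from `seed`, and the
    /// hierarchical state starts empty. `lambda_init` seeds the shared bias on the
    /// raw level logits (use `default_lambda_init` for a uniform mix). The inner
    /// mode must not itself be log-linear; nesting is rejected in debug builds.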
pub fn new(
inner_mode: AttentionMode,
d_model: usize,
d_key: usize,
d_value: usize,
max_levels: usize,
lambda_init: f64,
seed: u64,
) -> Self {
debug_assert!(d_model > 0, "d_model must be positive");
debug_assert!(d_key > 0, "d_key must be positive");
debug_assert!(d_value > 0, "d_value must be positive");
debug_assert!(max_levels > 0, "max_levels must be positive");
debug_assert!(
!matches!(inner_mode, AttentionMode::LogLinear { .. }),
"log-linear cannot wrap log-linear (no recursive nesting)"
);
let mut rng = Xorshift64(seed);
let w_key = init_weights(&mut rng, d_key * d_model);
let w_value = init_weights(&mut rng, d_value * d_model);
let w_query = init_weights(&mut rng, d_key * d_model);
let w_lambda = init_weights(&mut rng, max_levels * d_model);
let state = LogLinearState::new(max_levels, d_key, d_value);
Self {
inner_mode: Box::new(inner_mode),
state,
w_key,
w_value,
w_query,
w_lambda,
lambda_bias: lambda_init,
d_model,
d_key,
d_value,
max_levels,
tau: DEFAULT_TAU,
learning_rate: DEFAULT_LEARNING_RATE,
train_step_count: 0,
scratch_lambda_raw: vec![0.0; max_levels],
scratch_lambda: vec![0.0; max_levels],
scratch_k: vec![0.0; d_key],
scratch_v: vec![0.0; d_value],
scratch_q: vec![0.0; d_key],
}
}
#[inline]
pub fn learning_rate(&self) -> f64 {
self.learning_rate
}
pub fn set_learning_rate(&mut self, lr: f64) {
debug_assert!(
lr.is_finite() && lr > 0.0,
"learning_rate must be a finite positive number, got {lr}"
);
self.learning_rate = lr;
}
#[inline]
pub fn train_step_count(&self) -> u64 {
self.train_step_count
}
pub fn reset_train_step_count(&mut self) {
self.train_step_count = 0;
}
pub fn inner_mode(&self) -> &AttentionMode {
&self.inner_mode
}
pub fn log_linear_state(&self) -> &LogLinearState {
&self.state
}
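    /// Computes the level mixture λ = softplus-softmax of (W_λ · input + bias) at
    /// temperature τ into `scratch_lambda`. The softplus keeps every weight positive
    /// and the normalization makes them sum to 1 (see `softplus_softmax_mix`).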
fn compute_lambda(&mut self, input: &[f64]) {
mat_vec(
&self.w_lambda,
input,
self.max_levels,
self.d_model,
&mut self.scratch_lambda_raw,
);
for r in self.scratch_lambda_raw.iter_mut() {
*r += self.lambda_bias;
}
softplus_softmax_mix(&self.scratch_lambda_raw, self.tau, &mut self.scratch_lambda);
}
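    /// Reads the current mixed state with a query derived from `input`, without
    /// pushing a new leaf or updating any weights. Takes `&mut self` only to reuse
    /// the scratch buffers; the hierarchical state itself is left untouched.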
pub fn query_readonly(&mut self, input: &[f64]) -> Vec<f64> {
debug_assert_eq!(
input.len(),
self.d_model,
"input must have d_model elements"
);
        self.scratch_q.fill(0.0);
mat_vec(
&self.w_query,
input,
self.d_key,
self.d_model,
&mut self.scratch_q,
);
self.compute_lambda(input);
let mut out = vec![0.0; self.d_value];
self.state
.query_mixed(&self.scratch_q, &self.scratch_lambda, &mut out);
tanh_inplace(&mut out);
out
}
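    /// One online training step: project (k, v, q) from `input`, push the new leaf,
    /// read the post-push mixed state through `tanh`, then take a single SGD step on
    /// all four projection matrices against the squared error to `target`. The
    /// gradient is streaming and local: it flows through the current leaf and the
    /// current mixture only, not through earlier merges. Returns the pre-update output.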
#[allow(clippy::needless_range_loop)]
pub fn train_one(&mut self, input: &[f64], target: &[f64]) -> Vec<f64> {
debug_assert_eq!(
input.len(),
self.d_model,
"input must have d_model elements"
);
debug_assert_eq!(
target.len(),
self.d_value,
"target must have d_value elements"
);
        self.scratch_k.fill(0.0);
        self.scratch_v.fill(0.0);
        self.scratch_q.fill(0.0);
mat_vec(
&self.w_key,
input,
self.d_key,
self.d_model,
&mut self.scratch_k,
);
mat_vec(
&self.w_value,
input,
self.d_value,
self.d_model,
&mut self.scratch_v,
);
mat_vec(
&self.w_query,
input,
self.d_key,
self.d_model,
&mut self.scratch_q,
);
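        // Delta-rule-style inner modes assume unit-norm keys, so normalize the key
        // before it enters the state (keeping the raw norm for the backward pass).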
let is_delta_family = matches!(
self.inner_mode.as_ref(),
AttentionMode::DeltaNet
| AttentionMode::GatedDeltaNet { .. }
| AttentionMode::DeltaProduct { .. }
| AttentionMode::RWKV7
);
        let k_raw_norm: f64 = if is_delta_family {
            let n_sq: f64 = self.scratch_k.iter().map(|&x| x * x).sum();
            math::sqrt(n_sq)
        } else {
            0.0
        };
let k_for_leaf: Vec<f64> = if is_delta_family {
l2_normalize(&self.scratch_k)
} else {
self.scratch_k.clone()
};
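        // Recompute λ inline rather than via compute_lambda: the backward pass below
        // needs the raw logits and the softplus normalizer, not just the final mix.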
mat_vec(
&self.w_lambda,
input,
self.max_levels,
self.d_model,
&mut self.scratch_lambda_raw,
);
for r in self.scratch_lambda_raw.iter_mut() {
*r += self.lambda_bias;
}
let inv_tau = 1.0 / self.tau;
let mut softplus_sum = 0.0;
for (i, &xi) in self.scratch_lambda_raw.iter().enumerate() {
let sp = math::softplus(xi * inv_tau);
self.scratch_lambda[i] = sp;
softplus_sum += sp;
}
if softplus_sum > 0.0 {
for s in self.scratch_lambda.iter_mut() {
*s /= softplus_sum;
}
}
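        // With binary-counter merging, the leaf pushed at size s cascades upward
        // through s.trailing_ones() merges; clamp to the top level once saturated.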
let pre_push_size = self.state.size();
let landed_level = (pre_push_size.trailing_ones() as usize).min(self.max_levels - 1);
self.state.push_leaf(&k_for_leaf, &self.scratch_v);
let mut o_pre = vec![0.0; self.d_value];
self.state
.query_mixed(&self.scratch_q, &self.scratch_lambda, &mut o_pre);
let mut o = o_pre.clone();
tanh_inplace(&mut o);
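        // Backprop 0.5 * Σ (o_d - target_d)^2 through the tanh: dL/do_pre = err * (1 - o^2).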
let mut delta = vec![0.0; self.d_value];
for d in 0..self.d_value {
let err = o[d] - target[d];
delta[d] = err * (1.0 - o[d] * o[d]);
}
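        // dL/dλ_ℓ = δ · z_ℓ, where z_ℓ = S_ℓᵀ q is the readout of level ℓ alone.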
let mut dl_dlambda = vec![0.0; self.max_levels];
for ell in 0..self.max_levels {
if !self.state.is_active(ell) {
continue;
}
let z_l = self.state.level(ell).query(&self.scratch_q);
let mut dot = 0.0;
for d in 0..self.d_value {
dot += delta[d] * z_l[d];
}
dl_dlambda[ell] = dot;
}
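        // dL/dq = Σ_ℓ λ_ℓ S_ℓ δ (each level's state is a d_key × d_value matrix, row-major).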
let mut dl_dq = vec![0.0; self.d_key];
for ell in 0..self.max_levels {
if !self.state.is_active(ell) || self.scratch_lambda[ell] == 0.0 {
continue;
}
let lam = self.scratch_lambda[ell];
let s_l = self.state.level(ell).as_slice();
for i in 0..self.d_key {
let row_start = i * self.d_value;
let mut acc = 0.0;
for d in 0..self.d_value {
acc += s_l[row_start + d] * delta[d];
}
dl_dq[i] += lam * acc;
}
}
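        // Jacobian of the softplus-softmax: with sp_j = softplus(raw_j / τ) and S = Σ sp,
        // dL/draw_j = (σ(raw_j / τ) / (τ S)) * (dL/dλ_j - Σ_ℓ λ_ℓ dL/dλ_ℓ).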
let mut weighted_sum = 0.0;
for ell in 0..self.max_levels {
weighted_sum += self.scratch_lambda[ell] * dl_dlambda[ell];
}
let mut dl_draw = vec![0.0; self.max_levels];
if softplus_sum > 0.0 {
for j in 0..self.max_levels {
let sigma = math::sigmoid(self.scratch_lambda_raw[j] * inv_tau);
dl_draw[j] = (sigma * inv_tau / softplus_sum) * (dl_dlambda[j] - weighted_sum);
}
}
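        // Gradients for k and v flow only through the leaf just pushed: its rank-1
        // contribution to the output is λ_ℓ (k̂ · q) v at the level where it landed.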
        // landed_level is clamped to max_levels - 1 above, so indexing is always in bounds.
        let lam_l = self.scratch_lambda[landed_level];
let kq_dot: f64 = {
let mut acc = 0.0;
for i in 0..self.d_key {
acc += k_for_leaf[i] * self.scratch_q[i];
}
acc
};
let v_delta_dot: f64 = {
let mut acc = 0.0;
for d in 0..self.d_value {
acc += self.scratch_v[d] * delta[d];
}
acc
};
let mut dl_dv = vec![0.0; self.d_value];
for d in 0..self.d_value {
dl_dv[d] = lam_l * kq_dot * delta[d];
}
let mut dl_dk_for_leaf = vec![0.0; self.d_key];
for i in 0..self.d_key {
dl_dk_for_leaf[i] = lam_l * v_delta_dot * self.scratch_q[i];
}
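        // Backward through L2 normalization k̂ = k / ‖k‖:
        // dL/dk = (dL/dk̂ - k̂ (k̂ · dL/dk̂)) / ‖k‖. Below the norm threshold, the
        // gradient passes through unchanged.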
let dl_dk: Vec<f64> = if is_delta_family && k_raw_norm > 1e-12 {
let kn_dot_grad: f64 = {
let mut acc = 0.0;
for i in 0..self.d_key {
acc += k_for_leaf[i] * dl_dk_for_leaf[i];
}
acc
};
let inv_norm = 1.0 / k_raw_norm;
let mut grad_raw = vec![0.0; self.d_key];
for i in 0..self.d_key {
grad_raw[i] = inv_norm * (dl_dk_for_leaf[i] - k_for_leaf[i] * kn_dot_grad);
}
grad_raw
} else {
dl_dk_for_leaf
};
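        // Each projection gets a rank-1 update: W -= lr * (dL/dy) ⊗ input.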
let lr = self.learning_rate;
sgd_outer_descent(
&mut self.w_query,
&dl_dq,
input,
self.d_key,
self.d_model,
lr,
);
sgd_outer_descent(&mut self.w_key, &dl_dk, input, self.d_key, self.d_model, lr);
sgd_outer_descent(
&mut self.w_value,
&dl_dv,
input,
self.d_value,
self.d_model,
lr,
);
sgd_outer_descent(
&mut self.w_lambda,
&dl_draw,
input,
self.max_levels,
self.d_model,
lr,
);
self.train_step_count = self.train_step_count.saturating_add(1);
o
}
}
impl AttentionLayer for LogLinearAttention {
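    /// Streaming forward pass: read the *pre-push* mixed state with the current
    /// query, then absorb (k, v) as a new leaf. Unlike `train_one`, the output at
    /// step t therefore does not see the pair inserted at step t.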
fn forward(&mut self, input: &[f64]) -> Vec<f64> {
debug_assert_eq!(
input.len(),
self.d_model,
"input must have d_model elements"
);
        self.scratch_k.fill(0.0);
        self.scratch_v.fill(0.0);
        self.scratch_q.fill(0.0);
mat_vec(
&self.w_key,
input,
self.d_key,
self.d_model,
&mut self.scratch_k,
);
mat_vec(
&self.w_value,
input,
self.d_value,
self.d_model,
&mut self.scratch_v,
);
mat_vec(
&self.w_query,
input,
self.d_key,
self.d_model,
&mut self.scratch_q,
);
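        // Same key convention as train_one: delta-family inner modes get unit-norm keys.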
let k_for_leaf: Vec<f64> = match self.inner_mode.as_ref() {
AttentionMode::DeltaNet
| AttentionMode::GatedDeltaNet { .. }
| AttentionMode::DeltaProduct { .. }
| AttentionMode::RWKV7 => l2_normalize(&self.scratch_k),
_ => self.scratch_k.clone(),
};
self.compute_lambda(input);
let mut out = vec![0.0; self.d_value];
self.state
.query_mixed(&self.scratch_q, &self.scratch_lambda, &mut out);
self.state.push_leaf(&k_for_leaf, &self.scratch_v);
tanh_inplace(&mut out);
out
}
fn state(&self) -> &[f64] {
self.state.flat_state()
}
fn output_dim(&self) -> usize {
self.d_value
}
fn reset(&mut self) {
self.state.reset();
}
}
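/// Returns v / ‖v‖₂, or an all-zero vector when ‖v‖₂ < 1e-12.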
fn l2_normalize(v: &[f64]) -> Vec<f64> {
let norm_sq: f64 = v.iter().map(|&x| x * x).sum();
let norm = math::sqrt(norm_sq);
if norm < 1e-12 {
vec![0.0; v.len()]
} else {
let inv = 1.0 / norm;
v.iter().map(|&x| x * inv).collect()
}
}
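/// Rank-1 SGD update W -= lr * grad_y ⊗ input on a row-major `rows × cols` matrix,
/// skipping zero gradient rows (and the whole update when lr == 0).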
#[inline]
fn sgd_outer_descent(
w: &mut [f64],
grad_y: &[f64],
input: &[f64],
rows: usize,
cols: usize,
lr: f64,
) {
debug_assert_eq!(w.len(), rows * cols, "W shape mismatch");
debug_assert_eq!(grad_y.len(), rows, "grad_y must have rows elements");
debug_assert_eq!(input.len(), cols, "input must have cols elements");
if lr == 0.0 {
return;
}
for (i, &gi) in grad_y.iter().enumerate() {
if gi == 0.0 {
continue;
}
let lr_gi = lr * gi;
let row_start = i * cols;
for (j, &xj) in input.iter().enumerate() {
w[row_start + j] -= lr_gi * xj;
}
}
}
#[cfg(test)]
mod tests {
use super::*;
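    /// Deterministic pseudo-random input of dimension d_model = 8 for step `t`.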
fn xs(t: usize) -> Vec<f64> {
let n = 8usize;
(0..n).map(|i| ((t * 7 + i * 3) as f64).sin()).collect()
}
#[test]
fn log_linear_wraps_arbitrary_inner_update_rule() {
let inner_modes: Vec<AttentionMode> = vec![
AttentionMode::RetNet { gamma: 0.95 },
AttentionMode::GLA,
AttentionMode::GLAVector,
AttentionMode::DeltaNet,
AttentionMode::GatedDeltaNet {
beta_scale: 1.0,
gate_mode_delta: GatedDeltaMode::Static,
},
AttentionMode::DeltaProduct {
n_compositions: 2,
reflections: false,
},
AttentionMode::RWKV7,
AttentionMode::HGRN2 { lower_bound: 0.9 },
AttentionMode::MLSTM,
AttentionMode::Hawk,
AttentionMode::RWKV { initial_decay: 0.5 },
];
for inner in inner_modes {
let mode_dbg = alloc::format!("{:?}", inner);
let mut lla = LogLinearAttention::new(inner, 8, 4, 4, 8, default_lambda_init(8), 42);
let x = xs(0);
let out = lla.forward(&x);
assert_eq!(
out.len(),
4,
"inner={mode_dbg}: output dim must equal d_value=4"
);
assert!(
out.iter().all(|v| v.is_finite()),
"inner={mode_dbg}: output must be finite"
);
assert!(
out.iter().all(|v| v.abs() <= 1.0),
"inner={mode_dbg}: tanh-bounded output must be in [-1, 1]"
);
}
}
#[test]
fn forward_advances_size_by_one() {
let mut lla =
LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 42);
assert_eq!(lla.log_linear_state().size(), 0);
for t in 1..=5u64 {
let _ = lla.forward(&xs(t as usize));
assert_eq!(
lla.log_linear_state().size(),
t,
"size must increment by 1 per forward"
);
}
}
#[test]
fn reset_returns_to_fresh_state() {
let mut lla =
LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 42);
for t in 0..50 {
let _ = lla.forward(&xs(t));
}
assert!(lla.log_linear_state().size() > 0);
assert!(lla.state().iter().any(|&v| v != 0.0));
lla.reset();
assert_eq!(lla.log_linear_state().size(), 0);
assert!(lla.state().iter().all(|&v| v == 0.0));
}
#[test]
fn output_bounded_by_tanh() {
let mut lla = LogLinearAttention::new(
AttentionMode::DeltaNet,
8,
4,
4,
8,
default_lambda_init(8),
17,
);
for t in 0..100 {
let out = lla.forward(&xs(t));
for &v in &out {
assert!(
v.is_finite() && v.abs() <= 1.0,
"tanh-bounded output must be in [-1, 1] at t={}, got {}",
t,
v
);
}
}
}
#[test]
fn deterministic_with_same_seed() {
let mut lla1 =
LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 42);
let mut lla2 =
LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 42);
for t in 0..30 {
let x = xs(t);
let o1 = lla1.forward(&x);
let o2 = lla2.forward(&x);
for (a, b) in o1.iter().zip(o2.iter()) {
assert!(
(a - b).abs() < 1e-15,
"same seed must produce same output (t={})",
t
);
}
}
}
#[test]
fn state_padded_to_max_levels() {
let max_levels = 12;
let d_key = 4;
let d_value = 4;
let mut lla = LogLinearAttention::new(
AttentionMode::GLA,
8,
d_key,
d_value,
max_levels,
default_lambda_init(max_levels),
42,
);
let expected = max_levels * d_key * d_value;
assert_eq!(
lla.state().len(),
expected,
"state() must be max_levels * d_k * d_v (constant shape)"
);
for t in 1..=20 {
let _ = lla.forward(&xs(t));
assert_eq!(
lla.state().len(),
expected,
"state shape must be constant after forward t={}",
t
);
}
}
#[test]
fn lambda_sums_bounded_after_softplus_softmax() {
let mut lla =
LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 42);
for t in 0..30 {
let x = xs(t);
lla.compute_lambda(&x);
let sum: f64 = lla.scratch_lambda.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-9,
"softplus_softmax_mix must produce a probability distribution (sum=1), got {sum}"
);
for &lam in &lla.scratch_lambda {
assert!(
(0.0..=1.0).contains(&lam),
"λ entry must be in [0, 1], got {lam}"
);
}
}
}
#[test]
fn query_readonly_does_not_mutate_state() {
let mut lla =
LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 42);
for t in 0..10 {
let _ = lla.forward(&xs(t));
}
let size_before = lla.log_linear_state().size();
let state_before: Vec<f64> = lla.state().to_vec();
let _ = lla.query_readonly(&xs(99));
let size_after = lla.log_linear_state().size();
let state_after: Vec<f64> = lla.state().to_vec();
assert_eq!(
size_before, size_after,
"query_readonly must not advance size"
);
assert_eq!(
state_before, state_after,
"query_readonly must not mutate state cache"
);
}
#[test]
fn default_lambda_init_uniform_at_max_levels() {
for ml in [1, 4, 16, 32] {
let lam = default_lambda_init(ml);
assert!(
(lam - 1.0 / ml as f64).abs() < 1e-15,
"default_lambda_init({ml}) should be 1/{ml}"
);
}
}
#[test]
fn log_linear_default_learning_rate_is_finite_positive() {
let lla =
LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 7);
let lr = lla.learning_rate();
assert!(
lr.is_finite() && lr > 0.0,
"default learning_rate must be positive finite, got {lr}"
);
assert!(
(lr - DEFAULT_LEARNING_RATE).abs() < 1e-15,
"default learning_rate should equal DEFAULT_LEARNING_RATE, got {lr}"
);
}
#[test]
fn log_linear_set_learning_rate_overrides_default() {
let mut lla =
LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 7);
lla.set_learning_rate(0.123);
assert!(
(lla.learning_rate() - 0.123).abs() < 1e-15,
"set_learning_rate should override default"
);
}
#[test]
fn log_linear_train_one_returns_d_value_output() {
let mut lla =
LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 42);
let target = vec![0.1, -0.2, 0.3, -0.4];
let out = lla.train_one(&xs(0), &target);
assert_eq!(out.len(), 4, "train_one output must equal d_value");
for &v in &out {
assert!(
v.is_finite() && v.abs() <= 1.0,
"tanh-bounded train_one output must be in [-1, 1], got {v}"
);
}
}
#[test]
fn log_linear_train_one_advances_train_step_count() {
let mut lla =
LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 42);
let target = vec![0.0; 4];
assert_eq!(lla.train_step_count(), 0);
for t in 1..=5 {
let _ = lla.train_one(&xs(t), &target);
assert_eq!(
lla.train_step_count(),
t as u64,
"train_step_count should increment by 1 per call"
);
}
lla.reset_train_step_count();
assert_eq!(
lla.train_step_count(),
0,
"reset_train_step_count should clear the counter"
);
}
#[test]
fn log_linear_train_one_advances_state_size() {
let mut lla =
LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 42);
let target = vec![0.0; 4];
assert_eq!(lla.log_linear_state().size(), 0);
for t in 1..=5u64 {
let _ = lla.train_one(&xs(t as usize), &target);
assert_eq!(
lla.log_linear_state().size(),
t,
"size must increment by 1 per train_one"
);
}
}
#[test]
fn log_linear_train_one_modifies_q_k_v_lambda_weights() {
let mut lla =
LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 42);
let w_q_before = lla.w_query.clone();
let w_k_before = lla.w_key.clone();
let w_v_before = lla.w_value.clone();
let w_l_before = lla.w_lambda.clone();
let target = vec![0.7, -0.5, 0.3, 0.2];
for t in 0..30 {
let _ = lla.train_one(&xs(t), &target);
}
let any_q_changed = w_q_before
.iter()
.zip(lla.w_query.iter())
.any(|(a, b)| (a - b).abs() > 1e-12);
let any_k_changed = w_k_before
.iter()
.zip(lla.w_key.iter())
.any(|(a, b)| (a - b).abs() > 1e-12);
let any_v_changed = w_v_before
.iter()
.zip(lla.w_value.iter())
.any(|(a, b)| (a - b).abs() > 1e-12);
let any_l_changed = w_l_before
.iter()
.zip(lla.w_lambda.iter())
.any(|(a, b)| (a - b).abs() > 1e-12);
assert!(any_q_changed, "W_q must be updated by train_one");
assert!(any_k_changed, "W_k must be updated by train_one");
assert!(any_v_changed, "W_v must be updated by train_one");
assert!(any_l_changed, "W_lambda must be updated by train_one");
}
#[test]
fn log_linear_qkv_projections_update_via_streaming_gradient() {
let mut lla =
LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 42);
let probe_input = xs(99);
let target = vec![0.4_f64, -0.3, 0.2, -0.1];
lla.reset();
let o0 = lla.train_one(&probe_input, &target);
let initial_loss: f64 = o0
.iter()
.zip(target.iter())
.map(|(p, t)| (p - t).powi(2))
.sum();
for _ in 0..300 {
lla.reset();
let _ = lla.train_one(&probe_input, &target);
}
lla.reset();
let o_final = lla.train_one(&probe_input, &target);
let final_loss: f64 = o_final
.iter()
.zip(target.iter())
.map(|(p, t)| (p - t).powi(2))
.sum();
assert!(
final_loss < initial_loss,
"Gradient must descend on a single-pair fresh-state task: \
initial_loss={initial_loss:.6}, final_loss={final_loss:.6}"
);
assert!(
final_loss.is_finite() && initial_loss.is_finite(),
"loss must remain finite throughout"
);
}
#[test]
fn log_linear_online_training_reduces_mqar_loss() {
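        // MQAR-style recall: memorize key → value pairs online, then query each key
        // read-only and require the recall MSE to drop during training.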
let n_pairs = 2usize;
let d_model = 8usize;
let d_k = 4usize;
let d_v = 4usize;
let max_levels = 8usize;
let lr = 0.1_f64;
let n_epochs = 200usize;
let mut lla = LogLinearAttention::new(
AttentionMode::GatedDeltaNet {
beta_scale: 1.0,
gate_mode_delta: GatedDeltaMode::Static,
},
d_model,
d_k,
d_v,
max_levels,
default_lambda_init(max_levels),
0xABCD,
);
lla.set_learning_rate(lr);
        let pairs: Vec<(Vec<f64>, Vec<f64>)> = (0..n_pairs)
            .map(|i| {
                let k: Vec<f64> = (0..d_model)
                    .map(|j| ((i * 13 + j * 7) as f64).sin())
                    .collect();
                let v: Vec<f64> = (0..d_v)
                    .map(|j| ((i * 17 + j * 11) as f64).cos() * 0.5)
                    .collect();
                (k, v)
            })
            .collect();
        let recall_loss = |lla: &mut LogLinearAttention, pairs: &[(Vec<f64>, Vec<f64>)]| -> f64 {
lla.reset();
for (k, target) in pairs {
let _ = lla.train_one(k, target);
}
let mut total = 0.0;
for (k, target) in pairs {
let o = lla.query_readonly(k);
total += o
.iter()
.zip(target.iter())
.map(|(p, t)| (p - t).powi(2))
.sum::<f64>()
/ o.len() as f64;
}
total / pairs.len() as f64
};
let initial_loss = recall_loss(&mut lla, &pairs);
let mut min_loss = initial_loss;
for _ in 0..n_epochs {
let l = recall_loss(&mut lla, &pairs);
if l < min_loss {
min_loss = l;
}
assert!(
l.is_finite(),
"recall loss must stay finite during training"
);
}
assert!(
min_loss < 0.7 * initial_loss,
"Online streaming SGD must reduce MQAR recall MSE by ≥ 30%: \
initial_loss={initial_loss:.6}, min_loss={min_loss:.6}, \
ratio={:.4} (must be < 0.70)",
min_loss / initial_loss
);
assert!(
initial_loss.is_finite() && min_loss.is_finite(),
"loss must stay finite — initial={initial_loss}, min={min_loss}"
);
}
#[test]
fn log_linear_train_one_zero_lr_is_no_op_on_weights() {
        let mut lla_zero =
            LogLinearAttention::new(AttentionMode::GLA, 8, 4, 4, 8, default_lambda_init(8), 7);
        // Set the field directly: set_learning_rate debug-asserts lr > 0, while
        // lr = 0.0 must turn every sgd_outer_descent call into a no-op.
        lla_zero.learning_rate = 0.0;
let w_q_before = lla_zero.w_query.clone();
let target = vec![0.1, -0.1, 0.05, -0.05];
for t in 0..10 {
let _ = lla_zero.train_one(&xs(t), &target);
}
assert_eq!(
lla_zero.w_query, w_q_before,
"lr=0 SGD must leave W_q unchanged"
);
}
}