use alloc::vec;
use alloc::vec::Vec;
use super::state::AttentionState;
/// Log-linear attention state arranged as a Fenwick-style binary-counter
/// hierarchy: after `t` pushes, the set bits of `t` name exactly the levels
/// that hold data (until the top level overflows and over-accumulates).
#[derive(Clone, Debug)]
pub struct LogLinearState {
    /// One accumulator matrix per level, each `d_k` x `d_v`.
    levels: Vec<AttentionState>,
    /// `active[ell]` is true iff level `ell` currently holds data.
    active: Vec<bool>,
    /// Total number of leaves pushed so far (incremented with saturation).
    size: u64,
    /// Fixed level count; carries past the top fold into the last level.
    max_levels: usize,
    /// Key dimensionality of each pushed leaf.
    d_k: usize,
    /// Value dimensionality of each pushed leaf.
    d_v: usize,
    /// Flat concatenation of all level matrices, refreshed after each push.
    state_cache: Vec<f64>,
}
impl LogLinearState {
    /// Builds an empty hierarchy of `max_levels` zeroed matrix accumulators.
    ///
    /// The flat cache is sized for every level up front, so `flat_state`
    /// always has length `max_levels * d_k * d_v` and never reallocates.
    pub fn new(max_levels: usize, d_k: usize, d_v: usize) -> Self {
        debug_assert!(max_levels > 0, "max_levels must be positive");
        debug_assert!(d_k > 0, "d_k must be positive");
        debug_assert!(d_v > 0, "d_v must be positive");
        let mut levels = Vec::with_capacity(max_levels);
        for _ in 0..max_levels {
            levels.push(AttentionState::new_matrix(d_k, d_v));
        }
        Self {
            levels,
            active: vec![false; max_levels],
            size: 0,
            max_levels,
            d_k,
            d_v,
            state_cache: vec![0.0; max_levels * d_k * d_v],
        }
    }

    /// Number of levels in the hierarchy.
    #[inline]
    pub fn max_levels(&self) -> usize {
        self.max_levels
    }

    /// Key dimensionality of each leaf.
    #[inline]
    pub fn d_k(&self) -> usize {
        self.d_k
    }

    /// Value dimensionality of each leaf.
    #[inline]
    pub fn d_v(&self) -> usize {
        self.d_v
    }

    /// Total number of leaves pushed so far.
    #[inline]
    pub fn size(&self) -> u64 {
        self.size
    }

    /// Number of levels currently holding data (equals `popcount(size)`
    /// until the top level overflows).
    pub fn active_level_count(&self) -> usize {
        self.active.iter().copied().filter(|&flag| flag).count()
    }

    /// Whether `level` currently holds accumulated data.
    #[inline]
    pub fn is_active(&self, level: usize) -> bool {
        debug_assert!(
            level < self.max_levels,
            "level {} out of range (max_levels={})",
            level,
            self.max_levels
        );
        self.active[level]
    }

    /// Borrows the accumulator matrix at `level`.
    #[inline]
    pub fn level(&self, level: usize) -> &AttentionState {
        debug_assert!(
            level < self.max_levels,
            "level {} out of range (max_levels={})",
            level,
            self.max_levels
        );
        &self.levels[level]
    }

    /// Inserts one key/value outer product, carrying like binary addition:
    /// every occupied level is folded into the carry and cleared, until a
    /// free level absorbs the carry. If no level is free, the carry folds
    /// into the top level, which then over-accumulates.
    pub fn push_leaf(&mut self, k: &[f64], v: &[f64]) {
        debug_assert_eq!(k.len(), self.d_k, "k length must match d_k");
        debug_assert_eq!(v.len(), self.d_v, "v length must match d_v");
        debug_assert_eq!(
            self.active[0],
            self.size & 1 == 1,
            "Fenwick invariant: level 0 active iff size is odd"
        );
        // Seed the carry with the new rank-1 update.
        let mut carry = AttentionState::new_matrix(self.d_k, self.d_v);
        carry.add_outer_product(k, v);
        // Climb past occupied levels, draining each into the carry.
        let mut idx = 0usize;
        while idx < self.max_levels && self.active[idx] {
            let drained = take_matrix(&mut self.levels[idx], self.d_k, self.d_v);
            self.active[idx] = false;
            add_matrix_in_place(&mut carry, &drained);
            idx += 1;
        }
        if idx < self.max_levels {
            // First free level absorbs the accumulated carry.
            replace_matrix(&mut self.levels[idx], carry);
            self.active[idx] = true;
        } else {
            // Capacity overflow: fold into the top level instead of dropping.
            let top = self.max_levels - 1;
            add_matrix_in_place(&mut self.levels[top], &carry);
            self.active[top] = true;
        }
        self.size = self.size.saturating_add(1);
        self.refresh_cache();
    }

    /// Returns the state to its freshly-constructed, all-zero condition.
    pub fn reset(&mut self) {
        for state in self.levels.iter_mut() {
            state.reset();
        }
        self.active.fill(false);
        self.size = 0;
        self.state_cache.fill(0.0);
    }

    /// Flat concatenation of every level's matrix, in level order.
    #[inline]
    pub fn flat_state(&self) -> &[f64] {
        &self.state_cache
    }

    /// Writes `sum_ell lambdas[ell] * levels[ell].query(q)` into `out`,
    /// skipping inactive levels and levels with a zero mixing weight.
    pub fn query_mixed(&self, q: &[f64], lambdas: &[f64], out: &mut [f64]) {
        debug_assert_eq!(q.len(), self.d_k, "q length must match d_k");
        debug_assert_eq!(
            lambdas.len(),
            self.max_levels,
            "lambdas length must match max_levels"
        );
        debug_assert_eq!(out.len(), self.d_v, "out length must match d_v");
        out.fill(0.0);
        for (idx, &lam) in lambdas.iter().enumerate() {
            if !self.active[idx] || lam == 0.0 {
                continue;
            }
            let contribution = self.levels[idx].query(q);
            for (acc, c) in out.iter_mut().zip(contribution.iter()) {
                *acc += lam * c;
            }
        }
    }

    /// Re-copies every level matrix into the flat cache, back to back.
    fn refresh_cache(&mut self) {
        let mut cursor = 0usize;
        for state in self.levels.iter() {
            let src = state.as_slice();
            let end = cursor + src.len();
            self.state_cache[cursor..end].copy_from_slice(src);
            cursor = end;
        }
    }
}
/// Element-wise adds `src`'s matrix payload into `dst`'s.
///
/// Panics if either state is not the `Matrix` variant.
fn add_matrix_in_place(dst: &mut AttentionState, src: &AttentionState) {
    if let (
        AttentionState::Matrix { data: acc, .. },
        AttentionState::Matrix { data: delta, .. },
    ) = (dst, src)
    {
        debug_assert_eq!(acc.len(), delta.len(), "matrix addition shape mismatch");
        for (a, d) in acc.iter_mut().zip(delta.iter()) {
            *a += *d;
        }
    } else {
        panic!("add_matrix_in_place: both states must be Matrix");
    }
}
/// Overwrites `dst`'s matrix payload with `src`'s, consuming `src`.
///
/// Panics if either state is not the `Matrix` variant.
fn replace_matrix(dst: &mut AttentionState, src: AttentionState) {
    if let (
        AttentionState::Matrix { data: target, .. },
        AttentionState::Matrix { data: source, .. },
    ) = (dst, src)
    {
        debug_assert_eq!(target.len(), source.len(), "matrix replace shape mismatch");
        target.copy_from_slice(&source);
    } else {
        panic!("replace_matrix: both states must be Matrix");
    }
}
/// Moves `dst`'s matrix contents into a freshly allocated state, leaving
/// `dst` zeroed in place.
///
/// Panics if `dst` is not the `Matrix` variant.
fn take_matrix(dst: &mut AttentionState, d_k: usize, d_v: usize) -> AttentionState {
    let mut extracted = AttentionState::new_matrix(d_k, d_v);
    match (dst, &mut extracted) {
        (
            AttentionState::Matrix { data: source, .. },
            AttentionState::Matrix { data: sink, .. },
        ) => {
            sink.copy_from_slice(source);
            source.fill(0.0);
        }
        _ => panic!("take_matrix: state must be Matrix"),
    }
    extracted
}
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly constructed state holds nothing: size 0, no active levels,
    // fully zeroed flat cache.
    #[test]
    fn new_state_has_zero_size_and_no_active_levels() {
        let s = LogLinearState::new(8, 4, 4);
        assert_eq!(s.size(), 0, "fresh state has size 0");
        assert_eq!(
            s.active_level_count(),
            0,
            "fresh state has no active levels"
        );
        assert!(
            s.flat_state().iter().all(|&x| x == 0.0),
            "fresh state cache is all zeros"
        );
    }

    // The flat cache is always padded to max_levels * d_k * d_v, regardless
    // of how many levels are actually active.
    #[test]
    fn log_linear_state_padded_to_max_levels() {
        let max_levels = 8;
        let d_k = 4;
        let d_v = 4;
        let mut s = LogLinearState::new(max_levels, d_k, d_v);
        let expected_len = max_levels * d_k * d_v;
        assert_eq!(
            s.flat_state().len(),
            expected_len,
            "flat state must be max_levels * d_k * d_v at t=0"
        );
        s.push_leaf(&[1.0, 2.0, 3.0, 4.0], &[0.5, -0.5, 0.25, -0.25]);
        assert_eq!(
            s.flat_state().len(),
            expected_len,
            "flat state must remain max_levels * d_k * d_v after t=1"
        );
        assert_eq!(s.size(), 1);
        assert_eq!(s.active_level_count(), 1, "popcount(1) = 1");
        assert!(s.is_active(0), "after 1 push, level 0 is active");
        // Three more pushes: size goes 2, 3, 4, exercising the carry chain.
        for i in 0..3 {
            let f = (i + 1) as f64;
            s.push_leaf(&[f, f, f, f], &[f, f, f, f]);
        }
        assert_eq!(s.size(), 4);
        assert_eq!(s.active_level_count(), 1, "popcount(4) = 1");
        assert!(s.is_active(2), "size=4 -> level 2 active");
        assert!(!s.is_active(0));
        assert!(!s.is_active(1));
        assert_eq!(
            s.flat_state().len(),
            expected_len,
            "flat state still padded to max_levels"
        );
    }

    // reset() must clear size, active flags, the flat cache, and every
    // per-level matrix.
    #[test]
    fn log_linear_state_reset_clears_all_levels() {
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        for i in 0..50u64 {
            let f = i as f64 + 1.0;
            s.push_leaf(&[f, f, f, f], &[f, f, f, f]);
        }
        assert!(s.size() > 0);
        assert!(s.active_level_count() > 0);
        assert!(
            s.flat_state().iter().any(|&x| x != 0.0),
            "after pushes, cache should have non-zero entries"
        );
        s.reset();
        assert_eq!(s.size(), 0, "reset clears size");
        assert_eq!(s.active_level_count(), 0, "reset deactivates all levels");
        assert!(
            s.flat_state().iter().all(|&x| x == 0.0),
            "reset clears flat state"
        );
        for ell in 0..max_levels {
            assert!(
                !s.is_active(ell),
                "level {} must be inactive after reset",
                ell
            );
            assert!(
                s.level(ell).as_slice().iter().all(|&x| x == 0.0),
                "level {} matrix must be zero after reset",
                ell
            );
        }
    }

    // Core Fenwick invariant: after t pushes, level ell is active iff bit
    // ell of t is set, so the active count equals popcount(t). 31 < 2^8, so
    // the overflow path is never hit here.
    #[test]
    fn fenwick_active_levels_match_popcount_of_size() {
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        let k = [0.5; 4];
        let v = [0.5; 4];
        for t in 1..=31u64 {
            s.push_leaf(&k, &v);
            for ell in 0..max_levels {
                let bit_set = (t >> ell) & 1 == 1;
                assert_eq!(
                    s.is_active(ell),
                    bit_set,
                    "at size={}, level {} active should match bit {} of size",
                    t,
                    ell,
                    ell
                );
            }
            assert_eq!(
                s.active_level_count() as u32,
                t.count_ones(),
                "active count must equal popcount of size"
            );
        }
    }

    // Level ell aggregates 2^ell leaves: after 4 identical rank-1 pushes,
    // level 2 holds the sum of all 4 outer products.
    #[test]
    fn level_matrix_size_doubles_with_level() {
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        let k_vec = [1.0, 0.0, 0.0, 0.0];
        let v_vec = [1.0, 0.0, 0.0, 0.0];
        for _ in 0..4 {
            s.push_leaf(&k_vec, &v_vec);
        }
        assert_eq!(s.size(), 4);
        assert_eq!(s.active_level_count(), 1);
        assert!(s.is_active(2));
        let entry = s.level(2).get_matrix(0, 0);
        assert!(
            (entry - 4.0).abs() < 1e-12,
            "level 2 (0,0) should accumulate 4 leaves, got {}",
            entry
        );
    }

    // All-zero mixing weights must zero the output buffer, even if it held
    // garbage beforehand.
    #[test]
    fn query_mixed_zero_lambdas_gives_zero_output() {
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        s.push_leaf(&[1.0, 2.0, 3.0, 4.0], &[0.5, 0.5, 0.5, 0.5]);
        let q = [1.0; 4];
        let lambdas = [0.0; 8];
        let mut out = [42.0; 4];
        s.query_mixed(&q, &lambdas, &mut out);
        for &o in &out {
            assert_eq!(o, 0.0, "zero λ produces zero output");
        }
    }

    // With one leaf and uniform λ=1, the read-out for q = k reproduces v
    // (only level 0 is active, holding the single outer product k vᵀ).
    #[test]
    fn query_mixed_uniform_lambdas_sums_active_levels() {
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        let k = [1.0, 0.0, 0.0, 0.0];
        let v = [1.0, 1.0, 1.0, 1.0];
        s.push_leaf(&k, &v);
        let q = [1.0, 0.0, 0.0, 0.0];
        let lambdas = [1.0; 8];
        let mut out = [0.0; 4];
        s.query_mixed(&q, &lambdas, &mut out);
        for &o in &out {
            assert!(
                (o - 1.0).abs() < 1e-12,
                "uniform λ readout should equal v, got {}",
                o
            );
        }
    }

    // λ weight on an inactive (empty) level contributes nothing; weight on
    // the active level contributes non-zero output.
    #[test]
    fn query_mixed_inactive_levels_skipped() {
        let max_levels = 4;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        // Two pushes: both leaves carry into level 1; level 0 is left empty.
        s.push_leaf(&[1.0, 0.0, 0.0, 0.0], &[1.0, 0.0, 0.0, 0.0]);
        s.push_leaf(&[1.0, 0.0, 0.0, 0.0], &[1.0, 0.0, 0.0, 0.0]);
        assert!(s.is_active(1));
        assert!(!s.is_active(0));
        assert!(!s.is_active(2));
        let q = [1.0, 0.0, 0.0, 0.0];
        let mut out_all = [0.0; 4];
        s.query_mixed(&q, &[1.0; 4], &mut out_all);
        let mut out_inactive = [0.0; 4];
        s.query_mixed(&q, &[1.0, 0.0, 0.0, 0.0], &mut out_inactive);
        for &o in &out_inactive {
            assert_eq!(
                o, 0.0,
                "λ on inactive level 0 must contribute zero (level 0 is empty), got {}",
                o
            );
        }
        assert!(
            out_all.iter().any(|&o| o != 0.0),
            "active level 1 with λ=1 must contribute non-zero output"
        );
    }

    // Once size exceeds 2^max_levels capacity, carries fold into the top
    // level instead of being dropped.
    #[test]
    fn capacity_overflow_folds_into_top_level() {
        let max_levels = 2;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        let k = [1.0, 0.0, 0.0, 0.0];
        let v = [1.0, 0.0, 0.0, 0.0];
        for _ in 0..4 {
            s.push_leaf(&k, &v);
        }
        assert_eq!(s.size(), 4);
        assert!(s.is_active(1), "top level must be active after overflow");
        let entry = s.level(1).get_matrix(0, 0);
        assert!(
            entry > 0.0,
            "top level should accumulate folded carries, got {}",
            entry
        );
    }

    // flat_state() is exactly the level matrices concatenated in level
    // order, one d_k*d_v block per level.
    #[test]
    fn flat_state_matches_concatenated_levels() {
        let max_levels = 4;
        let d_k = 3;
        let d_v = 3;
        let mut s = LogLinearState::new(max_levels, d_k, d_v);
        for i in 0..7u64 {
            let f = (i + 1) as f64 * 0.1;
            s.push_leaf(&[f, f, f], &[f, f, f]);
        }
        let flat = s.flat_state();
        assert_eq!(flat.len(), max_levels * d_k * d_v);
        let block = d_k * d_v;
        for ell in 0..max_levels {
            let level_slice = s.level(ell).as_slice();
            let cache_slice = &flat[ell * block..(ell + 1) * block];
            assert_eq!(
                level_slice, cache_slice,
                "flat cache for level {} must match level matrix",
                ell
            );
        }
    }

    // The structure is fully deterministic: identical push sequences on two
    // independent instances yield identical flat states.
    #[test]
    fn deterministic_construction() {
        let mut a = LogLinearState::new(8, 4, 4);
        let mut b = LogLinearState::new(8, 4, 4);
        for t in 1..=20u64 {
            let f = t as f64 * 0.1;
            a.push_leaf(&[f, f, f, f], &[f, -f, f, -f]);
            b.push_leaf(&[f, f, f, f], &[f, -f, f, -f]);
        }
        for (x, y) in a.flat_state().iter().zip(b.flat_state().iter()) {
            assert!(
                (x - y).abs() < 1e-15,
                "identical pushes produce identical state"
            );
        }
    }
}