Skip to main content

irithyll_core/attention/
log_linear_state.rs

1//! Hierarchical Fenwick-tree state for Log-Linear Attention.
2//!
3//! Implements the per-head state container for Log-Linear Attention
4//! (Han Guo et al., ICLR 2026, arXiv:2506.04761). Each head owns a stack
5//! of up to `max_levels` matrix states, organized as a binary-counter
6//! (Fenwick) decomposition of the prefix `[0, t)`. After `t` tokens the
7//! ACTIVE levels correspond exactly to the 1-bits of `t` (paper §2);
8//! storage is padded to `max_levels` so the public state vector is
9//! constant-shaped, satisfying the diagnostic-consumer invariant
10//! "state().len() is constant" (paper §3.4 in R1, "Option B —
11//! Recommended").
12//!
13//! # Carry-propagation algorithm
14//!
15//! On `push_leaf(k, v)`, with leaf bucket `s_leaf = k · vᵀ`:
16//! 1. Place `s_leaf` at level 0.
17//! 2. While level ℓ has TWO buckets of equal size 2^ℓ, sum them into a
18//!    bucket of size 2^(ℓ+1) at level ℓ+1, freeing both children.
19//! 3. Stop at the first free level. A carry that would pass level `max_levels - 1` is instead folded (matrix-added) into that top level.
20//!
21//! This is identical to incrementing a binary counter. While
22//! `t < 2^max_levels`, the active levels after `t` pushes are precisely
23//! the 1-bits of `t`, so occupancy is `popcount(t) ≤ ⌊log₂(t)⌋ + 1` —
24//! the O(log T) state bound advertised by the paper.
25//!
26//! # Padding to `max_levels` (NOT popcount)
27//!
28//! The paper-mandated stability choice (R1 §3.4): pad to a constant
29//! `max_levels` so `state()` length is stable across stream length. A
30//! popcount-sized vector would change shape every token, breaking the
31//! `AttentionLayer::state()` contract that downstream diagnostics
32//! depend on. Inactive levels are zero matrices.
33//!
34//! # `max_levels` capacity
35//!
36//! `max_levels = ⌊log₂(T_max)⌋ + 1`. For T_max = 2^32 (4 billion
37//! tokens), `max_levels = 33`. The recommended default is **32**,
38//! matching R1 §3.5: covers streams up to ~4 G tokens with constant
39//! overhead `max_levels * d_k * d_v` per head.
40
41use alloc::vec;
42use alloc::vec::Vec;
43
44use super::state::AttentionState;
45
46/// Hierarchical stack of matrix states, one per active Fenwick level.
47///
48/// Storage is fixed at `max_levels` slots; each slot is a `d_k x d_v`
49/// matrix (zeros when inactive). The `active` mask records which slots
50/// currently hold a real bucket. The `size` field counts tokens pushed
51/// so far — equivalently, after `size = t` pushes, the bits of `t`
52/// indicate which levels are active.
53///
54/// # Paper reference
55///
56/// Han Guo, Songlin Yang, Tarushii Goel, Eric P. Xing, Tri Dao, Yoon
57/// Kim. *Log-Linear Attention*. ICLR 2026. arXiv:2506.04761, §2-§3.
58#[derive(Clone, Debug)]
59pub struct LogLinearState {
60    /// Per-level matrix states, length `max_levels`. Each entry is a
61    /// `d_k x d_v` matrix; inactive entries hold all-zero data.
62    levels: Vec<AttentionState>,
63    /// Active mask: `active[ℓ] == true` iff level ℓ holds a real bucket.
64    /// Length `max_levels`. Equivalent to bit ℓ of `size`, but kept as a
65    /// separate vector for branch-free read access in hot paths.
66    active: Vec<bool>,
67    /// Token count pushed so far. The bit pattern of `size` matches
68    /// `active` exactly after each successful `push_leaf`.
69    size: u64,
70    /// Hard cap on hierarchy depth. State storage is fixed at
71    /// `max_levels` regardless of `size`.
72    max_levels: usize,
73    /// Per-head key dimension.
74    d_k: usize,
75    /// Per-head value dimension.
76    d_v: usize,
77    /// Flat state cache: concatenated levels in row-major
78    /// `[L0 | L1 | … | L_{max_levels-1}]` form, length
79    /// `max_levels * d_k * d_v`. Zeroed slots remain zero.
80    state_cache: Vec<f64>,
81}
82
83impl LogLinearState {
84    /// Create a new state with all `max_levels` matrices zero-initialized.
85    ///
86    /// # Panics
87    ///
88    /// Panics in debug mode if `max_levels == 0`, `d_k == 0`, or
89    /// `d_v == 0`.
90    pub fn new(max_levels: usize, d_k: usize, d_v: usize) -> Self {
91        debug_assert!(max_levels > 0, "max_levels must be positive");
92        debug_assert!(d_k > 0, "d_k must be positive");
93        debug_assert!(d_v > 0, "d_v must be positive");
94
95        let levels: Vec<AttentionState> = (0..max_levels)
96            .map(|_| AttentionState::new_matrix(d_k, d_v))
97            .collect();
98        let active = vec![false; max_levels];
99        let state_cache = vec![0.0; max_levels * d_k * d_v];
100
101        Self {
102            levels,
103            active,
104            size: 0,
105            max_levels,
106            d_k,
107            d_v,
108            state_cache,
109        }
110    }
111
112    /// Hierarchy depth cap (`max_levels`). Storage is always padded to
113    /// this size.
114    #[inline]
115    pub fn max_levels(&self) -> usize {
116        self.max_levels
117    }
118
119    /// Per-head key dimension.
120    #[inline]
121    pub fn d_k(&self) -> usize {
122        self.d_k
123    }
124
125    /// Per-head value dimension.
126    #[inline]
127    pub fn d_v(&self) -> usize {
128        self.d_v
129    }
130
131    /// Number of tokens pushed so far. Equivalent to `t` in the paper.
132    #[inline]
133    pub fn size(&self) -> u64 {
134        self.size
135    }
136
137    /// Number of currently active levels = `popcount(size)`.
138    ///
139    /// Always `≤ max_levels`. After exhausting capacity (size ≥ 2^max_levels),
140    /// the highest level absorbs further carries (see `push_leaf`).
141    pub fn active_level_count(&self) -> usize {
142        self.active.iter().filter(|&&a| a).count()
143    }
144
145    /// Whether level `ℓ` currently holds a real bucket.
146    ///
147    /// # Panics
148    ///
149    /// Panics in debug mode if `level >= max_levels`.
150    #[inline]
151    pub fn is_active(&self, level: usize) -> bool {
152        debug_assert!(
153            level < self.max_levels,
154            "level {} out of range (max_levels={})",
155            level,
156            self.max_levels
157        );
158        self.active[level]
159    }
160
161    /// Borrow level `ℓ`'s matrix state (zero matrix if inactive).
162    ///
163    /// # Panics
164    ///
165    /// Panics in debug mode if `level >= max_levels`.
166    #[inline]
167    pub fn level(&self, level: usize) -> &AttentionState {
168        debug_assert!(
169            level < self.max_levels,
170            "level {} out of range (max_levels={})",
171            level,
172            self.max_levels
173        );
174        &self.levels[level]
175    }
176
177    /// Push a new leaf bucket holding the outer product `k * v^T`,
178    /// then run carry-propagation upward.
179    ///
180    /// Algorithm (paper §2.1):
181    /// 1. Set level 0 to `k * v^T`. If level 0 was already active, the
182    ///    new leaf would collide — but classical Fenwick increment
183    ///    means that case happens iff the previous push produced a
184    ///    carry that did NOT consume level 0. By construction the
185    ///    invariant holds: after every prior push, level 0 is active
186    ///    iff bit 0 of `size` is set (== `size` is odd). So before
187    ///    push: `level0_active iff size_was_odd`. We treat this with
188    ///    standard binary-increment: place the new bucket at level 0
189    ///    pre-emptively, then run the standard carry loop.
190    ///
191    /// In the paper this is the carry-propagation form of the Fenwick
192    /// scan; in irithyll terms it's an in-place rewrite of the level
193    /// stack, no allocation past `max_levels`.
194    ///
195    /// # Capacity overflow
196    ///
197    /// If a carry would propagate above level `max_levels - 1`, the
198    /// excess bucket is folded into the topmost level via matrix
199    /// addition. This preserves the invariant "total information
200    /// captured by the Fenwick tree" at the cost of resolution at
201    /// the very deepest scale — equivalent to the paper's note that
202    /// `max_levels = ⌊log₂(T_max)⌋ + 1` should be chosen so
203    /// `T_max` exceeds the expected stream length.
204    ///
205    /// # Arguments
206    ///
207    /// - `k` — key vector, length `d_k`.
208    /// - `v` — value vector, length `d_v`.
209    pub fn push_leaf(&mut self, k: &[f64], v: &[f64]) {
210        debug_assert_eq!(k.len(), self.d_k, "k length must match d_k");
211        debug_assert_eq!(v.len(), self.d_v, "v length must match d_v");
212
213        // Sanity: classical binary-counter increment makes level 0
214        // collisions impossible when invariants hold; assert in debug.
215        // Specifically, before this push, level 0 active <=> size is
216        // odd. After push, level 0 active <=> (size+1) is odd.
217        debug_assert_eq!(
218            self.active[0],
219            self.size & 1 == 1,
220            "Fenwick invariant: level 0 active iff size is odd"
221        );
222
223        // The new leaf must enter at level 0. If level 0 is active
224        // (i.e., size was odd), classical binary increment carries up
225        // — but in the matrix interpretation, the "carry" means the
226        // existing level-0 bucket sums with the new leaf and is then
227        // written to level 1, then potentially summing with level 1's
228        // existing bucket, and so on, until we hit an inactive level.
229
230        // Build the new bucket as outer product (k * v^T).
231        let mut carry = AttentionState::new_matrix(self.d_k, self.d_v);
232        carry.add_outer_product(k, v);
233
234        let mut ell = 0usize;
235        loop {
236            if ell >= self.max_levels {
237                // Capacity exhausted: fold the carry into the topmost
238                // level (max_levels - 1). This caps memory at the
239                // configured bound while still accumulating information.
240                let top = self.max_levels - 1;
241                add_matrix_in_place(&mut self.levels[top], &carry);
242                self.active[top] = true;
243                break;
244            }
245
246            if !self.active[ell] {
247                // Slot is free — write the carry here, halt.
248                replace_matrix(&mut self.levels[ell], carry);
249                self.active[ell] = true;
250                break;
251            }
252
253            // Slot ℓ is active: sum the existing bucket into carry
254            // and clear ℓ. Continue propagation upward.
255            let existing = take_matrix(&mut self.levels[ell], self.d_k, self.d_v);
256            self.active[ell] = false;
257            add_matrix_in_place(&mut carry, &existing);
258            ell += 1;
259        }
260
261        self.size = self.size.saturating_add(1);
262        self.refresh_cache();
263    }
264
265    /// Reset all levels to zero and clear `size`. After reset,
266    /// `state()` returns all zeros and `active_level_count() == 0`.
267    pub fn reset(&mut self) {
268        for state in self.levels.iter_mut() {
269            state.reset();
270        }
271        for a in self.active.iter_mut() {
272            *a = false;
273        }
274        self.size = 0;
275        for x in self.state_cache.iter_mut() {
276            *x = 0.0;
277        }
278    }
279
280    /// Flat view of the padded state — concatenation of all
281    /// `max_levels` levels in row-major order.
282    ///
283    /// Length is always `max_levels * d_k * d_v`, regardless of
284    /// `active_level_count()`. Inactive levels contribute all-zero
285    /// blocks. This is the constant-shape contract required by
286    /// `AttentionLayer::state()` consumers.
287    #[inline]
288    pub fn flat_state(&self) -> &[f64] {
289        &self.state_cache
290    }
291
292    /// Compute the λ-weighted readout `Σ_ℓ λ_ℓ · q^T · S^(ℓ)` over all
293    /// `max_levels` slots and write into `out` (length `d_v`).
294    ///
295    /// Inactive levels contribute zero (their `S^(ℓ)` is the zero
296    /// matrix). The caller supplies `lambdas` of length `max_levels`
297    /// (typically a softplus-softmax mix bounding `Σ λ ≤ 1`).
298    ///
299    /// # Arguments
300    ///
301    /// - `q` — query vector, length `d_k`.
302    /// - `lambdas` — per-level non-negative mix weights, length
303    ///   `max_levels`.
304    /// - `out` — output buffer, length `d_v`. Overwritten.
305    ///
306    /// # Panics
307    ///
308    /// Panics in debug mode if `q.len() != d_k`,
309    /// `lambdas.len() != max_levels`, or `out.len() != d_v`.
310    pub fn query_mixed(&self, q: &[f64], lambdas: &[f64], out: &mut [f64]) {
311        debug_assert_eq!(q.len(), self.d_k, "q length must match d_k");
312        debug_assert_eq!(
313            lambdas.len(),
314            self.max_levels,
315            "lambdas length must match max_levels"
316        );
317        debug_assert_eq!(out.len(), self.d_v, "out length must match d_v");
318
319        for o in out.iter_mut() {
320            *o = 0.0;
321        }
322        for (ell, &lam) in lambdas.iter().enumerate() {
323            if !self.active[ell] || lam == 0.0 {
324                continue;
325            }
326            // Per-level readout: o_ℓ = q^T · S^(ℓ) (length d_v).
327            let o_l = self.levels[ell].query(q);
328            for (oi, ol) in out.iter_mut().zip(o_l.iter()) {
329                *oi += lam * ol;
330            }
331        }
332    }
333
334    /// Refresh the flat cache from the level matrices. Cheap: total
335    /// work is `max_levels * d_k * d_v` per token, equal to the
336    /// log-linear state size already advertised.
337    fn refresh_cache(&mut self) {
338        let mut offset = 0;
339        for state in self.levels.iter() {
340            let slice = state.as_slice();
341            let len = slice.len();
342            self.state_cache[offset..offset + len].copy_from_slice(slice);
343            offset += len;
344        }
345    }
346}
347
348/// In-place matrix add: `dst += src` (both `d_k x d_v` row-major).
349fn add_matrix_in_place(dst: &mut AttentionState, src: &AttentionState) {
350    match (dst, src) {
351        (
352            AttentionState::Matrix { data: dst_data, .. },
353            AttentionState::Matrix { data: src_data, .. },
354        ) => {
355            debug_assert_eq!(
356                dst_data.len(),
357                src_data.len(),
358                "matrix addition shape mismatch"
359            );
360            for (d, s) in dst_data.iter_mut().zip(src_data.iter()) {
361                *d += *s;
362            }
363        }
364        _ => panic!("add_matrix_in_place: both states must be Matrix"),
365    }
366}
367
368/// Move `src` into `*dst`, leaving `dst` holding the new bucket.
369/// Equivalent to assignment but uses the existing buffer of `dst`
370/// when possible to avoid alloc churn — copies element-wise.
371fn replace_matrix(dst: &mut AttentionState, src: AttentionState) {
372    match (dst, src) {
373        (
374            AttentionState::Matrix { data: dst_data, .. },
375            AttentionState::Matrix { data: src_data, .. },
376        ) => {
377            debug_assert_eq!(
378                dst_data.len(),
379                src_data.len(),
380                "matrix replace shape mismatch"
381            );
382            dst_data.copy_from_slice(&src_data);
383        }
384        _ => panic!("replace_matrix: both states must be Matrix"),
385    }
386}
387
388/// Read out the existing matrix at `dst` into a new owned
389/// `AttentionState`, leaving `dst` zeroed in place. Avoids a swap by
390/// copying then zeroing — the old data is preserved in the returned
391/// state.
392fn take_matrix(dst: &mut AttentionState, d_k: usize, d_v: usize) -> AttentionState {
393    let mut taken = AttentionState::new_matrix(d_k, d_v);
394    if let (
395        AttentionState::Matrix { data: dst_data, .. },
396        AttentionState::Matrix {
397            data: taken_data, ..
398        },
399    ) = (dst, &mut taken)
400    {
401        taken_data.copy_from_slice(dst_data);
402        for d in dst_data.iter_mut() {
403            *d = 0.0;
404        }
405    } else {
406        panic!("take_matrix: state must be Matrix");
407    }
408    taken
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_state_has_zero_size_and_no_active_levels() {
        let s = LogLinearState::new(8, 4, 4);
        assert_eq!(s.size(), 0, "fresh state has size 0");
        assert_eq!(
            s.active_level_count(),
            0,
            "fresh state has no active levels"
        );
        assert!(
            s.flat_state().iter().all(|&x| x == 0.0),
            "fresh state cache is all zeros"
        );
    }

    #[test]
    fn log_linear_state_padded_to_max_levels() {
        // The flat state slice MUST equal max_levels * d_k * d_v
        // regardless of how many tokens have been pushed. This is the
        // paper-mandated stability choice (R1 §3.4 Option B).
        let max_levels = 8;
        let d_k = 4;
        let d_v = 4;
        let mut s = LogLinearState::new(max_levels, d_k, d_v);
        let expected_len = max_levels * d_k * d_v;
        assert_eq!(
            s.flat_state().len(),
            expected_len,
            "flat state must be max_levels * d_k * d_v at t=0"
        );

        // Push one token: should add a leaf at level 0.
        s.push_leaf(&[1.0, 2.0, 3.0, 4.0], &[0.5, -0.5, 0.25, -0.25]);
        assert_eq!(
            s.flat_state().len(),
            expected_len,
            "flat state must remain max_levels * d_k * d_v after t=1"
        );
        assert_eq!(s.size(), 1);
        assert_eq!(s.active_level_count(), 1, "popcount(1) = 1");
        assert!(s.is_active(0), "after 1 push, level 0 is active");

        // Push three more tokens (size = 4 = 0b100), expect only
        // level 2 active (popcount = 1).
        for i in 0..3 {
            let f = (i + 1) as f64;
            s.push_leaf(&[f, f, f, f], &[f, f, f, f]);
        }
        assert_eq!(s.size(), 4);
        assert_eq!(s.active_level_count(), 1, "popcount(4) = 1");
        assert!(s.is_active(2), "size=4 -> level 2 active");
        assert!(!s.is_active(0));
        assert!(!s.is_active(1));
        assert_eq!(
            s.flat_state().len(),
            expected_len,
            "flat state still padded to max_levels"
        );
    }

    #[test]
    fn log_linear_state_reset_clears_all_levels() {
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        for i in 0..50u64 {
            let f = i as f64 + 1.0;
            s.push_leaf(&[f, f, f, f], &[f, f, f, f]);
        }
        assert!(s.size() > 0);
        assert!(s.active_level_count() > 0);
        assert!(
            s.flat_state().iter().any(|&x| x != 0.0),
            "after pushes, cache should have non-zero entries"
        );

        s.reset();

        assert_eq!(s.size(), 0, "reset clears size");
        assert_eq!(s.active_level_count(), 0, "reset deactivates all levels");
        assert!(
            s.flat_state().iter().all(|&x| x == 0.0),
            "reset clears flat state"
        );
        for ell in 0..max_levels {
            assert!(
                !s.is_active(ell),
                "level {} must be inactive after reset",
                ell
            );
            assert!(
                s.level(ell).as_slice().iter().all(|&x| x == 0.0),
                "level {} matrix must be zero after reset",
                ell
            );
        }
    }

    #[test]
    fn fenwick_active_levels_match_popcount_of_size() {
        // After t pushes, the active levels MUST equal the 1-bits of
        // t (Han Guo et al., ICLR 2026 §2). Verify across t = 1..=31.
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        let k = [0.5; 4];
        let v = [0.5; 4];

        for t in 1..=31u64 {
            s.push_leaf(&k, &v);
            for ell in 0..max_levels {
                let bit_set = (t >> ell) & 1 == 1;
                assert_eq!(
                    s.is_active(ell),
                    bit_set,
                    "at size={}, level {} active should match bit {} of size",
                    t,
                    ell,
                    ell
                );
            }
            assert_eq!(
                s.active_level_count() as u32,
                t.count_ones(),
                "active count must equal popcount of size"
            );
        }
    }

    #[test]
    fn level_matrix_size_doubles_with_level() {
        // After 2^k tokens with all-equal leaves, the merged bucket at
        // level k is the SUM of 2^k identical outer products, i.e., the
        // outer-product magnitude at level k is 2^k times the single
        // leaf magnitude. This verifies the merge semantics
        // (matrix addition of equal-size siblings, paper §2.1).
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        let k_vec = [1.0, 0.0, 0.0, 0.0];
        let v_vec = [1.0, 0.0, 0.0, 0.0];

        // Push exactly 4 = 2^2 tokens. Only level 2 should be active,
        // and its (0,0) element should be 4 (outer product (k * v^T) at
        // (0,0) = 1, summed 4 times).
        for _ in 0..4 {
            s.push_leaf(&k_vec, &v_vec);
        }
        assert_eq!(s.size(), 4);
        assert_eq!(s.active_level_count(), 1);
        assert!(s.is_active(2));
        let entry = s.level(2).get_matrix(0, 0);
        assert!(
            (entry - 4.0).abs() < 1e-12,
            "level 2 (0,0) should accumulate 4 leaves, got {}",
            entry
        );
    }

    #[test]
    fn query_mixed_zero_lambdas_gives_zero_output() {
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        s.push_leaf(&[1.0, 2.0, 3.0, 4.0], &[0.5, 0.5, 0.5, 0.5]);

        let q = [1.0; 4];
        let lambdas = [0.0; 8];
        let mut out = [42.0; 4];
        s.query_mixed(&q, &lambdas, &mut out);
        for &o in &out {
            assert_eq!(o, 0.0, "zero λ produces zero output");
        }
    }

    #[test]
    fn query_mixed_uniform_lambdas_sums_active_levels() {
        // With λ = 1.0 on all levels, output equals the unweighted
        // sum of per-level queries (only active levels contribute).
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        let k = [1.0, 0.0, 0.0, 0.0];
        let v = [1.0, 1.0, 1.0, 1.0];
        s.push_leaf(&k, &v); // level 0: k * v^T

        let q = [1.0, 0.0, 0.0, 0.0];
        let lambdas = [1.0; 8];
        let mut out = [0.0; 4];
        s.query_mixed(&q, &lambdas, &mut out);
        // S^(0) at (0,*) = v = [1,1,1,1]; S^T q at index j = sum_i S[i][j] * q[i] = S[0][j]*1 = v[j].
        for &o in &out {
            assert!(
                (o - 1.0).abs() < 1e-12,
                "uniform λ readout should equal v, got {}",
                o
            );
        }
    }

    #[test]
    fn query_mixed_inactive_levels_skipped() {
        // After 2 pushes (size=2 = 0b10), only level 1 is active.
        // λ on inactive levels must contribute exactly zero.
        let max_levels = 4;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        s.push_leaf(&[1.0, 0.0, 0.0, 0.0], &[1.0, 0.0, 0.0, 0.0]);
        s.push_leaf(&[1.0, 0.0, 0.0, 0.0], &[1.0, 0.0, 0.0, 0.0]);
        assert!(s.is_active(1));
        assert!(!s.is_active(0));
        assert!(!s.is_active(2));

        let q = [1.0, 0.0, 0.0, 0.0];
        // Compare:
        //   - All λ=1: only level 1 contributes
        //   - λ=1 only on level 0 (inactive): output should be zero.
        let mut out_all = [0.0; 4];
        s.query_mixed(&q, &[1.0; 4], &mut out_all);

        let mut out_inactive = [0.0; 4];
        s.query_mixed(&q, &[1.0, 0.0, 0.0, 0.0], &mut out_inactive);
        for &o in &out_inactive {
            assert_eq!(
                o, 0.0,
                "λ on inactive level 0 must contribute zero (level 0 is empty), got {}",
                o
            );
        }

        // The "all λ=1" output should be non-zero (level 1 has 2-leaf
        // accumulated bucket).
        assert!(
            out_all.iter().any(|&o| o != 0.0),
            "active level 1 with λ=1 must contribute non-zero output"
        );
    }

    #[test]
    fn capacity_overflow_folds_into_top_level() {
        // With max_levels=2, after 4 pushes the carry would propagate
        // to level 2 (out of range). Spec: fold into top level
        // (max_levels - 1 = 1).
        let max_levels = 2;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        let k = [1.0, 0.0, 0.0, 0.0];
        let v = [1.0, 0.0, 0.0, 0.0];
        for _ in 0..4 {
            s.push_leaf(&k, &v);
        }
        assert_eq!(s.size(), 4);
        assert!(!s.is_active(0), "overflowing carry drains level 0");
        // Top level should hold all four folded leaves.
        assert!(s.is_active(1), "top level must be active after overflow");
        let entry = s.level(1).get_matrix(0, 0);
        assert!(
            (entry - 4.0).abs() < 1e-12,
            "top level should accumulate all 4 folded leaves, got {}",
            entry
        );
    }

    #[test]
    fn capacity_overflow_preserves_total_mass() {
        // Repeated folds into the top level must not lose information:
        // the sum over all levels equals the sum of every pushed leaf.
        let max_levels = 2;
        let mut s = LogLinearState::new(max_levels, 2, 2);
        for _ in 0..11 {
            s.push_leaf(&[1.0, 0.0], &[1.0, 0.0]);
        }
        assert_eq!(s.size(), 11);
        let total: f64 = (0..max_levels)
            .map(|ell| s.level(ell).get_matrix(0, 0))
            .sum();
        assert!(
            (total - 11.0).abs() < 1e-12,
            "total mass across levels should be 11, got {}",
            total
        );
        assert!(
            s.is_active(0),
            "odd size keeps level 0 active even after overflow"
        );
    }

    #[test]
    fn flat_state_matches_concatenated_levels() {
        let max_levels = 4;
        let d_k = 3;
        let d_v = 3;
        let mut s = LogLinearState::new(max_levels, d_k, d_v);
        for i in 0..7u64 {
            let f = (i + 1) as f64 * 0.1;
            s.push_leaf(&[f, f, f], &[f, f, f]);
        }
        // Size = 7 = 0b111: levels 0, 1, 2 active.
        let flat = s.flat_state();
        assert_eq!(flat.len(), max_levels * d_k * d_v);
        let block = d_k * d_v;
        for ell in 0..max_levels {
            let level_slice = s.level(ell).as_slice();
            let cache_slice = &flat[ell * block..(ell + 1) * block];
            assert_eq!(
                level_slice, cache_slice,
                "flat cache for level {} must match level matrix",
                ell
            );
        }
    }

    #[test]
    fn flat_state_stays_in_sync_across_overflow() {
        // The flat cache must mirror every level after EVERY push,
        // including pushes whose carry folds into the top level.
        let max_levels = 3;
        let d_k = 2;
        let d_v = 2;
        let block = d_k * d_v;
        let mut s = LogLinearState::new(max_levels, d_k, d_v);
        for i in 0..20u64 {
            let f = (i + 1) as f64 * 0.25;
            s.push_leaf(&[f, -f], &[f, 2.0 * f]);
            let flat = s.flat_state();
            for ell in 0..max_levels {
                assert_eq!(
                    s.level(ell).as_slice(),
                    &flat[ell * block..(ell + 1) * block],
                    "after push {}, flat cache for level {} is stale",
                    i + 1,
                    ell
                );
            }
        }
    }

    #[test]
    fn deterministic_construction() {
        let mut a = LogLinearState::new(8, 4, 4);
        let mut b = LogLinearState::new(8, 4, 4);
        for t in 1..=20u64 {
            let f = t as f64 * 0.1;
            a.push_leaf(&[f, f, f, f], &[f, -f, f, -f]);
            b.push_leaf(&[f, f, f, f], &[f, -f, f, -f]);
        }
        for (x, y) in a.flat_state().iter().zip(b.flat_state().iter()) {
            assert!(
                (x - y).abs() < 1e-15,
                "identical pushes produce identical state"
            );
        }
    }
}