// irithyll_core/attention/log_linear_state.rs
1//! Hierarchical Fenwick-tree state for Log-Linear Attention.
2//!
3//! Implements the per-head state container for Log-Linear Attention
4//! (Han Guo et al., ICLR 2026, arXiv:2506.04761). Each head owns a stack
5//! of up to `max_levels` matrix states, organized as a binary-counter
6//! (Fenwick) decomposition of the prefix `[0, t)`. After `t` tokens the
7//! ACTIVE levels correspond exactly to the 1-bits of `t` (paper §2);
8//! storage is padded to `max_levels` so the public state vector is
9//! constant-shaped, satisfying the diagnostic-consumer invariant
10//! "state().len() is constant" (paper §3.4 in R1, "Option B —
11//! Recommended").
12//!
13//! # Carry-propagation algorithm
14//!
15//! On `push_leaf(s_leaf)`:
16//! 1. Place `s_leaf` at level 0.
17//! 2. While level ℓ has TWO buckets of equal size 2^ℓ, sum them into a
18//! bucket of size 2^(ℓ+1) at level ℓ+1, freeing both children.
19//! 3. Continue until the carry stops or `max_levels` is exceeded.
20//!
21//! This is identical to incrementing a binary counter. After `t` pushes
22//! the active levels are precisely the 1-bits of `t`, so the maximum
23//! occupancy is `popcount(t) ≤ ⌊log₂(t)⌋ + 1` — the O(log T) state
24//! bound advertised by the paper.
25//!
26//! # Padding to `max_levels` (NOT popcount)
27//!
28//! The paper-mandated stability choice (R1 §3.4): pad to a constant
29//! `max_levels` so `state()` length is stable across stream length. A
30//! popcount-sized vector would change shape every token, breaking the
31//! `AttentionLayer::state()` contract that downstream diagnostics
32//! depend on. Inactive levels are zero matrices.
33//!
34//! # `max_levels` capacity
35//!
36//! `max_levels = ⌊log₂(T_max)⌋ + 1`. For T_max = 2^32 (4 billion
37//! tokens), `max_levels = 33`. The recommended default is **32**,
38//! matching R1 §3.5: covers streams up to ~4 G tokens with constant
39//! overhead `max_levels * d_k * d_v` per head.
40
41use alloc::vec;
42use alloc::vec::Vec;
43
44use super::state::AttentionState;
45
/// Hierarchical stack of matrix states, one per active Fenwick level.
///
/// Storage is fixed at `max_levels` slots; each slot is a `d_k x d_v`
/// matrix (zeros when inactive). The `active` mask records which slots
/// currently hold a real bucket. The `size` field counts tokens pushed
/// so far — equivalently, after `size = t` pushes, the bits of `t`
/// indicate which levels are active.
///
/// # Paper reference
///
/// Han Guo, Songlin Yang, Tarushii Goel, Eric P. Xing, Tri Dao, Yoon
/// Kim. *Log-Linear Attention*. ICLR 2026. arXiv:2506.04761, §2-§3.
#[derive(Clone, Debug)]
pub struct LogLinearState {
    /// Per-level matrix states, length `max_levels`. Each entry is a
    /// `d_k x d_v` matrix; inactive entries hold all-zero data.
    levels: Vec<AttentionState>,
    /// Active mask: `active[ℓ] == true` iff level ℓ holds a real bucket.
    /// Length `max_levels`. Equivalent to bit ℓ of `size`, but kept as a
    /// separate vector for branch-free read access in hot paths.
    active: Vec<bool>,
    /// Token count pushed so far. The bit pattern of `size` matches
    /// `active` exactly after each successful `push_leaf`, until a
    /// capacity-overflow fold occurs (size ≥ 2^max_levels); after that
    /// the top level may stay active while the corresponding bit of
    /// `size` is clear — see `push_leaf`'s "Capacity overflow" notes.
    size: u64,
    /// Hard cap on hierarchy depth. State storage is fixed at
    /// `max_levels` regardless of `size`.
    max_levels: usize,
    /// Per-head key dimension.
    d_k: usize,
    /// Per-head value dimension.
    d_v: usize,
    /// Flat state cache: concatenated levels in row-major
    /// `[L0 | L1 | … | L_{max_levels-1}]` form, length
    /// `max_levels * d_k * d_v`. Zeroed slots remain zero.
    state_cache: Vec<f64>,
}
82
83impl LogLinearState {
84 /// Create a new state with all `max_levels` matrices zero-initialized.
85 ///
86 /// # Panics
87 ///
88 /// Panics in debug mode if `max_levels == 0`, `d_k == 0`, or
89 /// `d_v == 0`.
90 pub fn new(max_levels: usize, d_k: usize, d_v: usize) -> Self {
91 debug_assert!(max_levels > 0, "max_levels must be positive");
92 debug_assert!(d_k > 0, "d_k must be positive");
93 debug_assert!(d_v > 0, "d_v must be positive");
94
95 let levels: Vec<AttentionState> = (0..max_levels)
96 .map(|_| AttentionState::new_matrix(d_k, d_v))
97 .collect();
98 let active = vec![false; max_levels];
99 let state_cache = vec![0.0; max_levels * d_k * d_v];
100
101 Self {
102 levels,
103 active,
104 size: 0,
105 max_levels,
106 d_k,
107 d_v,
108 state_cache,
109 }
110 }
111
112 /// Hierarchy depth cap (`max_levels`). Storage is always padded to
113 /// this size.
114 #[inline]
115 pub fn max_levels(&self) -> usize {
116 self.max_levels
117 }
118
119 /// Per-head key dimension.
120 #[inline]
121 pub fn d_k(&self) -> usize {
122 self.d_k
123 }
124
125 /// Per-head value dimension.
126 #[inline]
127 pub fn d_v(&self) -> usize {
128 self.d_v
129 }
130
131 /// Number of tokens pushed so far. Equivalent to `t` in the paper.
132 #[inline]
133 pub fn size(&self) -> u64 {
134 self.size
135 }
136
137 /// Number of currently active levels = `popcount(size)`.
138 ///
139 /// Always `≤ max_levels`. After exhausting capacity (size ≥ 2^max_levels),
140 /// the highest level absorbs further carries (see `push_leaf`).
141 pub fn active_level_count(&self) -> usize {
142 self.active.iter().filter(|&&a| a).count()
143 }
144
145 /// Whether level `ℓ` currently holds a real bucket.
146 ///
147 /// # Panics
148 ///
149 /// Panics in debug mode if `level >= max_levels`.
150 #[inline]
151 pub fn is_active(&self, level: usize) -> bool {
152 debug_assert!(
153 level < self.max_levels,
154 "level {} out of range (max_levels={})",
155 level,
156 self.max_levels
157 );
158 self.active[level]
159 }
160
161 /// Borrow level `ℓ`'s matrix state (zero matrix if inactive).
162 ///
163 /// # Panics
164 ///
165 /// Panics in debug mode if `level >= max_levels`.
166 #[inline]
167 pub fn level(&self, level: usize) -> &AttentionState {
168 debug_assert!(
169 level < self.max_levels,
170 "level {} out of range (max_levels={})",
171 level,
172 self.max_levels
173 );
174 &self.levels[level]
175 }
176
177 /// Push a new leaf bucket holding the outer product `k * v^T`,
178 /// then run carry-propagation upward.
179 ///
180 /// Algorithm (paper §2.1):
181 /// 1. Set level 0 to `k * v^T`. If level 0 was already active, the
182 /// new leaf would collide — but classical Fenwick increment
183 /// means that case happens iff the previous push produced a
184 /// carry that did NOT consume level 0. By construction the
185 /// invariant holds: after every prior push, level 0 is active
186 /// iff bit 0 of `size` is set (== `size` is odd). So before
187 /// push: `level0_active iff size_was_odd`. We treat this with
188 /// standard binary-increment: place the new bucket at level 0
189 /// pre-emptively, then run the standard carry loop.
190 ///
191 /// In the paper this is the carry-propagation form of the Fenwick
192 /// scan; in irithyll terms it's an in-place rewrite of the level
193 /// stack, no allocation past `max_levels`.
194 ///
195 /// # Capacity overflow
196 ///
197 /// If a carry would propagate above level `max_levels - 1`, the
198 /// excess bucket is folded into the topmost level via matrix
199 /// addition. This preserves the invariant "total information
200 /// captured by the Fenwick tree" at the cost of resolution at
201 /// the very deepest scale — equivalent to the paper's note that
202 /// `max_levels = ⌊log₂(T_max)⌋ + 1` should be chosen so
203 /// `T_max` exceeds the expected stream length.
204 ///
205 /// # Arguments
206 ///
207 /// - `k` — key vector, length `d_k`.
208 /// - `v` — value vector, length `d_v`.
209 pub fn push_leaf(&mut self, k: &[f64], v: &[f64]) {
210 debug_assert_eq!(k.len(), self.d_k, "k length must match d_k");
211 debug_assert_eq!(v.len(), self.d_v, "v length must match d_v");
212
213 // Sanity: classical binary-counter increment makes level 0
214 // collisions impossible when invariants hold; assert in debug.
215 // Specifically, before this push, level 0 active <=> size is
216 // odd. After push, level 0 active <=> (size+1) is odd.
217 debug_assert_eq!(
218 self.active[0],
219 self.size & 1 == 1,
220 "Fenwick invariant: level 0 active iff size is odd"
221 );
222
223 // The new leaf must enter at level 0. If level 0 is active
224 // (i.e., size was odd), classical binary increment carries up
225 // — but in the matrix interpretation, the "carry" means the
226 // existing level-0 bucket sums with the new leaf and is then
227 // written to level 1, then potentially summing with level 1's
228 // existing bucket, and so on, until we hit an inactive level.
229
230 // Build the new bucket as outer product (k * v^T).
231 let mut carry = AttentionState::new_matrix(self.d_k, self.d_v);
232 carry.add_outer_product(k, v);
233
234 let mut ell = 0usize;
235 loop {
236 if ell >= self.max_levels {
237 // Capacity exhausted: fold the carry into the topmost
238 // level (max_levels - 1). This caps memory at the
239 // configured bound while still accumulating information.
240 let top = self.max_levels - 1;
241 add_matrix_in_place(&mut self.levels[top], &carry);
242 self.active[top] = true;
243 break;
244 }
245
246 if !self.active[ell] {
247 // Slot is free — write the carry here, halt.
248 replace_matrix(&mut self.levels[ell], carry);
249 self.active[ell] = true;
250 break;
251 }
252
253 // Slot ℓ is active: sum the existing bucket into carry
254 // and clear ℓ. Continue propagation upward.
255 let existing = take_matrix(&mut self.levels[ell], self.d_k, self.d_v);
256 self.active[ell] = false;
257 add_matrix_in_place(&mut carry, &existing);
258 ell += 1;
259 }
260
261 self.size = self.size.saturating_add(1);
262 self.refresh_cache();
263 }
264
265 /// Reset all levels to zero and clear `size`. After reset,
266 /// `state()` returns all zeros and `active_level_count() == 0`.
267 pub fn reset(&mut self) {
268 for state in self.levels.iter_mut() {
269 state.reset();
270 }
271 for a in self.active.iter_mut() {
272 *a = false;
273 }
274 self.size = 0;
275 for x in self.state_cache.iter_mut() {
276 *x = 0.0;
277 }
278 }
279
280 /// Flat view of the padded state — concatenation of all
281 /// `max_levels` levels in row-major order.
282 ///
283 /// Length is always `max_levels * d_k * d_v`, regardless of
284 /// `active_level_count()`. Inactive levels contribute all-zero
285 /// blocks. This is the constant-shape contract required by
286 /// `AttentionLayer::state()` consumers.
287 #[inline]
288 pub fn flat_state(&self) -> &[f64] {
289 &self.state_cache
290 }
291
292 /// Compute the λ-weighted readout `Σ_ℓ λ_ℓ · q^T · S^(ℓ)` over all
293 /// `max_levels` slots and write into `out` (length `d_v`).
294 ///
295 /// Inactive levels contribute zero (their `S^(ℓ)` is the zero
296 /// matrix). The caller supplies `lambdas` of length `max_levels`
297 /// (typically a softplus-softmax mix bounding `Σ λ ≤ 1`).
298 ///
299 /// # Arguments
300 ///
301 /// - `q` — query vector, length `d_k`.
302 /// - `lambdas` — per-level non-negative mix weights, length
303 /// `max_levels`.
304 /// - `out` — output buffer, length `d_v`. Overwritten.
305 ///
306 /// # Panics
307 ///
308 /// Panics in debug mode if `q.len() != d_k`,
309 /// `lambdas.len() != max_levels`, or `out.len() != d_v`.
310 pub fn query_mixed(&self, q: &[f64], lambdas: &[f64], out: &mut [f64]) {
311 debug_assert_eq!(q.len(), self.d_k, "q length must match d_k");
312 debug_assert_eq!(
313 lambdas.len(),
314 self.max_levels,
315 "lambdas length must match max_levels"
316 );
317 debug_assert_eq!(out.len(), self.d_v, "out length must match d_v");
318
319 for o in out.iter_mut() {
320 *o = 0.0;
321 }
322 for (ell, &lam) in lambdas.iter().enumerate() {
323 if !self.active[ell] || lam == 0.0 {
324 continue;
325 }
326 // Per-level readout: o_ℓ = q^T · S^(ℓ) (length d_v).
327 let o_l = self.levels[ell].query(q);
328 for (oi, ol) in out.iter_mut().zip(o_l.iter()) {
329 *oi += lam * ol;
330 }
331 }
332 }
333
334 /// Refresh the flat cache from the level matrices. Cheap: total
335 /// work is `max_levels * d_k * d_v` per token, equal to the
336 /// log-linear state size already advertised.
337 fn refresh_cache(&mut self) {
338 let mut offset = 0;
339 for state in self.levels.iter() {
340 let slice = state.as_slice();
341 let len = slice.len();
342 self.state_cache[offset..offset + len].copy_from_slice(slice);
343 offset += len;
344 }
345 }
346}
347
348/// In-place matrix add: `dst += src` (both `d_k x d_v` row-major).
349fn add_matrix_in_place(dst: &mut AttentionState, src: &AttentionState) {
350 match (dst, src) {
351 (
352 AttentionState::Matrix { data: dst_data, .. },
353 AttentionState::Matrix { data: src_data, .. },
354 ) => {
355 debug_assert_eq!(
356 dst_data.len(),
357 src_data.len(),
358 "matrix addition shape mismatch"
359 );
360 for (d, s) in dst_data.iter_mut().zip(src_data.iter()) {
361 *d += *s;
362 }
363 }
364 _ => panic!("add_matrix_in_place: both states must be Matrix"),
365 }
366}
367
368/// Move `src` into `*dst`, leaving `dst` holding the new bucket.
369/// Equivalent to assignment but uses the existing buffer of `dst`
370/// when possible to avoid alloc churn — copies element-wise.
371fn replace_matrix(dst: &mut AttentionState, src: AttentionState) {
372 match (dst, src) {
373 (
374 AttentionState::Matrix { data: dst_data, .. },
375 AttentionState::Matrix { data: src_data, .. },
376 ) => {
377 debug_assert_eq!(
378 dst_data.len(),
379 src_data.len(),
380 "matrix replace shape mismatch"
381 );
382 dst_data.copy_from_slice(&src_data);
383 }
384 _ => panic!("replace_matrix: both states must be Matrix"),
385 }
386}
387
388/// Read out the existing matrix at `dst` into a new owned
389/// `AttentionState`, leaving `dst` zeroed in place. Avoids a swap by
390/// copying then zeroing — the old data is preserved in the returned
391/// state.
392fn take_matrix(dst: &mut AttentionState, d_k: usize, d_v: usize) -> AttentionState {
393 let mut taken = AttentionState::new_matrix(d_k, d_v);
394 if let (
395 AttentionState::Matrix { data: dst_data, .. },
396 AttentionState::Matrix {
397 data: taken_data, ..
398 },
399 ) = (dst, &mut taken)
400 {
401 taken_data.copy_from_slice(dst_data);
402 for d in dst_data.iter_mut() {
403 *d = 0.0;
404 }
405 } else {
406 panic!("take_matrix: state must be Matrix");
407 }
408 taken
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_state_has_zero_size_and_no_active_levels() {
        let s = LogLinearState::new(8, 4, 4);
        assert_eq!(s.size(), 0, "fresh state has size 0");
        assert_eq!(
            s.active_level_count(),
            0,
            "fresh state has no active levels"
        );
        assert!(
            s.flat_state().iter().all(|&x| x == 0.0),
            "fresh state cache is all zeros"
        );
    }

    #[test]
    fn log_linear_state_padded_to_max_levels() {
        // The flat state slice MUST equal max_levels * d_k * d_v
        // regardless of how many tokens have been pushed. This is the
        // paper-mandated stability choice (R1 §3.4 Option B).
        let max_levels = 8;
        let d_k = 4;
        let d_v = 4;
        let mut s = LogLinearState::new(max_levels, d_k, d_v);
        let expected_len = max_levels * d_k * d_v;
        assert_eq!(
            s.flat_state().len(),
            expected_len,
            "flat state must be max_levels * d_k * d_v at t=0"
        );

        // Push one token: should add a leaf at level 0.
        s.push_leaf(&[1.0, 2.0, 3.0, 4.0], &[0.5, -0.5, 0.25, -0.25]);
        assert_eq!(
            s.flat_state().len(),
            expected_len,
            "flat state must remain max_levels * d_k * d_v after t=1"
        );
        assert_eq!(s.size(), 1);
        assert_eq!(s.active_level_count(), 1, "popcount(1) = 1");
        assert!(s.is_active(0), "after 1 push, level 0 is active");

        // Push three more tokens (size = 4 = 0b100), expect only
        // level 2 active (popcount = 1).
        for i in 0..3 {
            let f = (i + 1) as f64;
            s.push_leaf(&[f, f, f, f], &[f, f, f, f]);
        }
        assert_eq!(s.size(), 4);
        assert_eq!(s.active_level_count(), 1, "popcount(4) = 1");
        assert!(s.is_active(2), "size=4 -> level 2 active");
        assert!(!s.is_active(0));
        assert!(!s.is_active(1));
        assert_eq!(
            s.flat_state().len(),
            expected_len,
            "flat state still padded to max_levels"
        );
    }

    #[test]
    fn log_linear_state_reset_clears_all_levels() {
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        for i in 0..50u64 {
            let f = i as f64 + 1.0;
            s.push_leaf(&[f, f, f, f], &[f, f, f, f]);
        }
        assert!(s.size() > 0);
        assert!(s.active_level_count() > 0);
        assert!(
            s.flat_state().iter().any(|&x| x != 0.0),
            "after pushes, cache should have non-zero entries"
        );

        s.reset();

        assert_eq!(s.size(), 0, "reset clears size");
        assert_eq!(s.active_level_count(), 0, "reset deactivates all levels");
        assert!(
            s.flat_state().iter().all(|&x| x == 0.0),
            "reset clears flat state"
        );
        // Every level must be both inactive and zeroed — reset
        // restores the exact post-`new` state.
        for ell in 0..max_levels {
            assert!(
                !s.is_active(ell),
                "level {} must be inactive after reset",
                ell
            );
            assert!(
                s.level(ell).as_slice().iter().all(|&x| x == 0.0),
                "level {} matrix must be zero after reset",
                ell
            );
        }
    }

    #[test]
    fn fenwick_active_levels_match_popcount_of_size() {
        // After t pushes, the active levels MUST equal the 1-bits of
        // t (Han Guo et al., ICLR 2026 §2). Verify across t = 1..32.
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        let k = [0.5; 4];
        let v = [0.5; 4];

        for t in 1..=31u64 {
            s.push_leaf(&k, &v);
            for ell in 0..max_levels {
                let bit_set = (t >> ell) & 1 == 1;
                assert_eq!(
                    s.is_active(ell),
                    bit_set,
                    "at size={}, level {} active should match bit {} of size",
                    t,
                    ell,
                    ell
                );
            }
            assert_eq!(
                s.active_level_count() as u32,
                t.count_ones(),
                "active count must equal popcount of size"
            );
        }
    }

    #[test]
    fn level_matrix_size_doubles_with_level() {
        // After 2^k tokens with all-equal leaves, the merged bucket at
        // level k is the SUM of 2^k identical outer products, i.e., the
        // outer-product magnitude at level k is 2^k times the single
        // leaf magnitude. This verifies the merge semantics
        // (matrix addition of equal-size siblings, paper §2.1).
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        let k_vec = [1.0, 0.0, 0.0, 0.0];
        let v_vec = [1.0, 0.0, 0.0, 0.0];

        // Push exactly 4 = 2^2 tokens. Only level 2 should be active,
        // and its (0,0) element should be 4 (outer product (k * v^T) at
        // (0,0) = 1, summed 4 times).
        for _ in 0..4 {
            s.push_leaf(&k_vec, &v_vec);
        }
        assert_eq!(s.size(), 4);
        assert_eq!(s.active_level_count(), 1);
        assert!(s.is_active(2));
        let entry = s.level(2).get_matrix(0, 0);
        assert!(
            (entry - 4.0).abs() < 1e-12,
            "level 2 (0,0) should accumulate 4 leaves, got {}",
            entry
        );
    }

    #[test]
    fn query_mixed_zero_lambdas_gives_zero_output() {
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        s.push_leaf(&[1.0, 2.0, 3.0, 4.0], &[0.5, 0.5, 0.5, 0.5]);

        // `out` is pre-filled with garbage (42.0) to verify that
        // query_mixed overwrites rather than accumulates.
        let q = [1.0; 4];
        let lambdas = [0.0; 8];
        let mut out = [42.0; 4];
        s.query_mixed(&q, &lambdas, &mut out);
        for &o in &out {
            assert_eq!(o, 0.0, "zero λ produces zero output");
        }
    }

    #[test]
    fn query_mixed_uniform_lambdas_sums_active_levels() {
        // With λ = 1.0 on all levels, output equals the unweighted
        // sum of per-level queries (only active levels contribute).
        let max_levels = 8;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        let k = [1.0, 0.0, 0.0, 0.0];
        let v = [1.0, 1.0, 1.0, 1.0];
        s.push_leaf(&k, &v); // level 0: k * v^T

        let q = [1.0, 0.0, 0.0, 0.0];
        let lambdas = [1.0; 8];
        let mut out = [0.0; 4];
        s.query_mixed(&q, &lambdas, &mut out);
        // S^(0) at (0,*) = v = [1,1,1,1]; S^T q at index j = sum_i S[i][j] * q[i] = S[0][j]*1 = v[j].
        for &o in &out {
            assert!(
                (o - 1.0).abs() < 1e-12,
                "uniform λ readout should equal v, got {}",
                o
            );
        }
    }

    #[test]
    fn query_mixed_inactive_levels_skipped() {
        // After 2 pushes (size=2 = 0b10), only level 1 is active.
        // λ on inactive levels must contribute exactly zero.
        let max_levels = 4;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        s.push_leaf(&[1.0, 0.0, 0.0, 0.0], &[1.0, 0.0, 0.0, 0.0]);
        s.push_leaf(&[1.0, 0.0, 0.0, 0.0], &[1.0, 0.0, 0.0, 0.0]);
        assert!(s.is_active(1));
        assert!(!s.is_active(0));
        assert!(!s.is_active(2));

        let q = [1.0, 0.0, 0.0, 0.0];
        // Compare:
        //  - All λ=1: only level 1 contributes
        //  - λ=1 only on level 0 (inactive): output should be zero.
        let mut out_all = [0.0; 4];
        s.query_mixed(&q, &[1.0; 4], &mut out_all);

        let mut out_inactive = [0.0; 4];
        s.query_mixed(&q, &[1.0, 0.0, 0.0, 0.0], &mut out_inactive);
        for &o in &out_inactive {
            assert_eq!(
                o, 0.0,
                "λ on inactive level 0 must contribute zero (level 0 is empty), got {}",
                o
            );
        }

        // The "all λ=1" output should be non-zero (level 1 has 2-leaf
        // accumulated bucket).
        assert!(
            out_all.iter().any(|&o| o != 0.0),
            "active level 1 with λ=1 must contribute non-zero output"
        );
    }

    #[test]
    fn capacity_overflow_folds_into_top_level() {
        // With max_levels=2, after 4 pushes the carry would propagate
        // to level 2 (out of range). Spec: fold into top level
        // (max_levels - 1 = 1).
        let max_levels = 2;
        let mut s = LogLinearState::new(max_levels, 4, 4);
        let k = [1.0, 0.0, 0.0, 0.0];
        let v = [1.0, 0.0, 0.0, 0.0];
        for _ in 0..4 {
            s.push_leaf(&k, &v);
        }
        assert_eq!(s.size(), 4);
        // Top level should hold the accumulated information.
        assert!(s.is_active(1), "top level must be active after overflow");
        let entry = s.level(1).get_matrix(0, 0);
        assert!(
            entry > 0.0,
            "top level should accumulate folded carries, got {}",
            entry
        );
    }

    #[test]
    fn flat_state_matches_concatenated_levels() {
        let max_levels = 4;
        let d_k = 3;
        let d_v = 3;
        let mut s = LogLinearState::new(max_levels, d_k, d_v);
        for i in 0..7u64 {
            let f = (i + 1) as f64 * 0.1;
            s.push_leaf(&[f, f, f], &[f, f, f]);
        }
        // Size = 7 = 0b111: levels 0, 1, 2 active.
        let flat = s.flat_state();
        assert_eq!(flat.len(), max_levels * d_k * d_v);
        // Each level's d_k*d_v block in the cache must mirror the
        // level matrix exactly, in order.
        let block = d_k * d_v;
        for ell in 0..max_levels {
            let level_slice = s.level(ell).as_slice();
            let cache_slice = &flat[ell * block..(ell + 1) * block];
            assert_eq!(
                level_slice, cache_slice,
                "flat cache for level {} must match level matrix",
                ell
            );
        }
    }

    #[test]
    fn deterministic_construction() {
        let mut a = LogLinearState::new(8, 4, 4);
        let mut b = LogLinearState::new(8, 4, 4);
        for t in 1..=20u64 {
            let f = t as f64 * 0.1;
            a.push_leaf(&[f, f, f, f], &[f, -f, f, -f]);
            b.push_leaf(&[f, f, f, f], &[f, -f, f, -f]);
        }
        for (x, y) in a.flat_state().iter().zip(b.flat_state().iter()) {
            assert!(
                (x - y).abs() < 1e-15,
                "identical pushes produce identical state"
            );
        }
    }
}