irithyll_core/attention/config.rs
//! Configuration types for the unified streaming linear attention engine.
//!
//! [`AttentionMode`] selects which architecture variant to use, while
//! [`AttentionConfig`] holds the full configuration including dimensions,
//! head count, and PRNG seed.

/// Gate dimensionality for GLA-family variants.
///
/// The original GLA paper (Yang et al., 2024) specifies a vector gate
/// `α_t ∈ (0,1)^{d_k}` — one decay scalar per key dimension. The simplified
/// scalar gate (`Scalar`) uses a single sigmoid output shared across all key
/// dimensions, which reduces parameter count at the cost of expressivity.
///
/// `Vector` is the mathematically canonical form from the paper.
/// `Scalar` is retained for backward compatibility with models trained before
/// the vector-gate upgrade.
#[derive(Clone, Debug, PartialEq)]
pub enum GateMode {
    /// Single scalar gate shared across all key dimensions (legacy default).
    Scalar,
    /// Per-key-dimension gate vector `α_t ∈ (0,1)^{d_k}` (paper-canonical).
    Vector,
}

/// Beta parameterization mode for GatedDeltaNet.
///
/// Controls whether the delta-rule mixing strength `β_t` is a static
/// scalar or a data-dependent value computed per token.
///
/// # Paper reference
///
/// Yang et al., ICLR 2025 (arXiv:2412.06464) specifies the canonical form
/// as per-token `β_t = sigmoid(W_β · x_t + b_β)`, making the delta-rule
/// mixing strength data-dependent. The static scalar `β_scale` is a
/// degenerate case retained for backward compatibility.
#[derive(Clone, Debug, PartialEq)]
pub enum GatedDeltaMode {
    /// Static scalar `β_scale` shared across all tokens (legacy default).
    ///
    /// Equivalent to the original `beta_scale: f64` parameter. Back-compatible
    /// with all checkpoints trained before this field was introduced.
    Static,
    /// Per-token `β_t = sigmoid(W_β · x_t)` — paper-canonical form.
    ///
    /// Adds a learned `W_β ∈ R^{d_model}` projection per head. The mixing
    /// strength becomes data-dependent, allowing the model to selectively
    /// apply strong or weak delta-rule updates depending on the input.
    PerToken,
}

/// Selects the attention architecture variant.
///
/// Each variant corresponds to a different parameterization of the unified
/// structured linear attention recurrence `S_t = decay_t * S_{t-1} + update_t`.
///
/// # Variants
///
/// - **RetNet** -- Fixed exponential decay (Sun et al., 2023)
/// - **Hawk** -- Gated scalar recurrence with vector state (De et al., 2024)
/// - **GLA** -- Gated Linear Attention with scalar gate (Yang et al., 2024, legacy)
/// - **GLAVector** -- GLA with per-key-dimension vector gate (paper-canonical form)
/// - **DeltaNet** -- Delta rule error-corrective update (Schlag et al., 2021)
/// - **GatedDeltaNet** -- GLA + delta rule combined (Yang et al., ICLR 2025, NVIDIA)
/// - **RWKV** -- Exponential decay with dynamic w (Peng et al., 2024)
/// - **MLSTM** -- xLSTM matrix memory with forget+input gates (Beck et al., 2024)
/// - **DeltaProduct** -- Product of n_h Householder delta rules (Siems et al., NeurIPS 2025)
/// - **RWKV7** -- Vector-gated delta rule with DPLR transitions (Peng et al., 2025)
/// - **HGRN2** -- Lower-bounded gated linear RNN with state expansion (Qin et al., ICML 2024)
/// - **LogLinear** -- O(log T) hierarchical Fenwick state (Han Guo et al., ICLR 2026)
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum AttentionMode {
    /// Fixed exponential decay: `S = gamma * S + k * v^T`.
    RetNet {
        /// Decay factor in (0, 1). Typical: 0.9 to 0.999.
        gamma: f64,
    },
    /// Gated scalar recurrence with vector state (Griffin/Hawk).
    Hawk,
    /// Gated Linear Attention with a scalar data-dependent sigmoid gate (legacy).
    ///
    /// Uses a single sigmoid scalar `σ(w^T x) ∈ (0,1)` shared across all key
    /// dimensions. Preserved for backward compatibility with pre-v10 checkpoints.
    /// For new code, prefer [`GLAVector`](AttentionMode::GLAVector).
    GLA,
    /// Gated Linear Attention with per-key-dimension vector gate (paper-canonical).
    ///
    /// Implements the gate from Yang et al. 2024 exactly: `α_t ∈ (0,1)^{d_k}`
    /// with one independent sigmoid per key dimension, giving the model
    /// per-dimension memory control. The gate projection is `W_α ∈ R^{d_k × d_model}`.
    ///
    /// This is the recommended GLA variant for new models. Produces different
    /// (and generally richer) dynamics than the scalar-gate `GLA` variant.
    GLAVector,
    /// Delta rule: error-corrective associative memory update.
    DeltaNet,
    /// Gated delta rule: GLA gate + delta error correction (Yang et al., ICLR 2025).
    ///
    /// Combines GLA's data-dependent gating with DeltaNet's error-corrective
    /// delta rule, plus learnable beta scaling and L2-normalized keys.
    ///
    /// # Beta parameterization
    ///
    /// - [`GatedDeltaMode::Static`] (default): the `beta_scale` scalar applies uniformly.
    /// - [`GatedDeltaMode::PerToken`]: paper-canonical form; `β_t = sigmoid(W_β · x_t)`
    ///   is computed per token per head via a learned projection. `beta_scale` is
    ///   ignored in this mode.
    GatedDeltaNet {
        /// Learnable beta scaling factor (default: 1.0).
        ///
        /// Used only when `gate_mode_delta` is [`GatedDeltaMode::Static`].
        /// Controls how aggressively the error-corrective update is applied.
        beta_scale: f64,
        /// Beta parameterization mode (default: [`GatedDeltaMode::Static`]).
        ///
        /// Set to [`GatedDeltaMode::PerToken`] for the paper-canonical
        /// data-dependent mixing strength (Yang et al., ICLR 2025).
        gate_mode_delta: GatedDeltaMode,
    },
    /// RWKV-style exponential decay with learned dynamic w.
    RWKV {
        /// Base decay rate before input-dependent modulation.
        initial_decay: f64,
    },
    /// xLSTM matrix memory with separate forget and input gates.
    MLSTM,
    /// Product of n_h Householder delta rules (Siems et al., NeurIPS 2025).
    ///
    /// Applies `n_compositions` sequential delta rule steps per token, each with
    /// its own key, value, and beta. The product of generalized Householder
    /// transformations gives a spectrally bounded (norm ≤ 1) transition matrix.
    /// Includes a scalar forget gate.
    ///
    /// # Reflections flag
    ///
    /// When `reflections` is `false` (default), each `β_{t,j} ∈ (0, 1)` via
    /// a plain sigmoid. When `true`, `β_{t,j} ∈ (0, 2)` via `2 · sigmoid(·)`,
    /// which enables full Householder reflections (negative eigenvalues) for
    /// stronger state-tracking capability (Siems et al., NeurIPS 2025,
    /// arXiv:2502.10297, §4).
    DeltaProduct {
        /// Number of Householder compositions per token (typically 2-4).
        n_compositions: usize,
        /// Enable full Householder reflections (default: `false`).
        ///
        /// When `false`: `β ∈ (0, 1)` — standard delta rule range.
        /// When `true`: `β ∈ (0, 2)` — full reflections, enables negative
        /// eigenvalues for parity / state-tracking tasks.
        reflections: bool,
    },
    /// RWKV-7 vector-gated delta rule with DPLR transitions (Peng et al., 2025).
    ///
    /// Uses per-dimension vector decay, a vector-valued in-context learning rate,
    /// and decoupled removal/replacement keys. The transition matrix is
    /// diagonal-plus-low-rank (DPLR), enabling state tracking beyond TC^0.
    RWKV7,
    /// HGRN2: gated linear RNN with outer-product state expansion (Qin et al., ICML 2024).
    ///
    /// Uses a lower-bounded forget gate for minimum memory retention.
    /// The state update is outer-product based (like GLA), but the bounded
    /// gate ensures the model never completely forgets:
    /// `alpha_t = lower_bound + (1 - lower_bound) * sigmoid(raw_t)` per dimension.
    HGRN2 {
        /// Lower bound for the forget gate (default: 0.9, range 0.0..1.0).
        /// Higher values ensure stronger memory retention.
        lower_bound: f64,
    },
    /// Log-Linear Attention (Han Guo et al., ICLR 2026, arXiv:2506.04761).
    ///
    /// Replaces the single fixed-size recurrent state of the inner
    /// linear-attention rule with an O(log T) hierarchy of states
    /// organized by a Fenwick-tree decomposition of the prefix.
    /// Compute per token grows as `log T`; total compute is
    /// `O(T log T)`. State memory is fixed at
    /// `max_levels * d_k * d_v * n_heads`.
    ///
    /// # Inner mode (paper §3.2)
    ///
    /// Wraps any non-recursive `AttentionMode` as the leaf update
    /// rule. Recommended: `GatedDeltaNet` for the strongest associative
    /// recall, `GLA` for stability. The wrapper handles per-token
    /// key normalization automatically for the delta-rule families.
    ///
    /// # `max_levels` (paper §3, R1 §3.5)
    ///
    /// Hard cap on hierarchy depth:
    /// `max_levels = ⌊log₂(T_max)⌋ + 1`, where `T_max` is the
    /// expected maximum stream length. The default of 32 covers streams
    /// up to ~4 G tokens. State is padded to `max_levels`
    /// — NOT popcount-sized — for shape stability across stream
    /// length (paper §3.4 Option B / R1 §3.4). This is the
    /// paper-mandated stability choice: streaming consumers require
    /// constant-shape state vectors.
    ///
    /// # `lambda_init` (paper §3.3, R1 §5.3)
    ///
    /// Static bias added to the per-level λ logits before
    /// softplus-softmax mixing. The default `1/max_levels` produces
    /// a uniform mixture across all levels — the principled choice
    /// when no information about which levels are useful is
    /// available (there is no backprop in the streaming setting).
    LogLinear {
        /// Inner update rule (must NOT be `LogLinear`).
        ///
        /// Boxed because the variant is recursive (the enum cannot contain
        /// itself by value); the indirection also keeps the enum small and
        /// lets factory code walk the nested shape.
        inner: alloc::boxed::Box<AttentionMode>,
        /// Hard cap on Fenwick depth. Memory is
        /// `max_levels × d_k × d_v` per head. Default: 32.
        max_levels: usize,
        /// Initial λ bias before softplus-softmax mixing.
        /// Default: `1/max_levels` (uniform mixture).
        lambda_init: f64,
    },
}

/// Full configuration for a multi-head streaming attention layer.
///
/// # Defaults
///
/// - `d_model`: 16
/// - `n_heads`: 4
/// - `d_key`: `d_model / n_heads` (4)
/// - `d_value`: `d_model / n_heads` (4)
/// - `mode`: `RetNet { gamma: 0.95 }`
/// - `seed`: 42
#[derive(Clone, Debug)]
pub struct AttentionConfig {
    /// Input/output dimension.
    pub d_model: usize,
    /// Number of attention heads.
    pub n_heads: usize,
    /// Per-head key dimension.
    pub d_key: usize,
    /// Per-head value dimension.
    pub d_value: usize,
    /// Architecture variant.
    pub mode: AttentionMode,
    /// PRNG seed for deterministic weight initialization.
    pub seed: u64,
}

impl Default for AttentionConfig {
    fn default() -> Self {
        Self {
            d_model: 16,
            n_heads: 4,
            d_key: 4,
            d_value: 4,
            mode: AttentionMode::RetNet { gamma: 0.95 },
            seed: 42,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_dimensions_consistent() {
        let cfg = AttentionConfig::default();
        assert_eq!(
            cfg.d_key,
            cfg.d_model / cfg.n_heads,
            "d_key should equal d_model / n_heads"
        );
        assert_eq!(
            cfg.d_value,
            cfg.d_model / cfg.n_heads,
            "d_value should equal d_model / n_heads"
        );
    }

    #[test]
    fn default_mode_is_retnet() {
        let cfg = AttentionConfig::default();
        match cfg.mode {
            AttentionMode::RetNet { gamma } => {
                assert!(
                    (gamma - 0.95).abs() < 1e-12,
                    "default gamma should be 0.95, got {}",
                    gamma
                );
            }
            _ => panic!("default mode should be RetNet"),
        }
    }

    #[test]
    fn config_clone_is_independent() {
        let cfg1 = AttentionConfig::default();
        let mut cfg2 = cfg1.clone();
        cfg2.d_model = 32;
        assert_eq!(cfg1.d_model, 16, "clone should be independent");
        assert_eq!(cfg2.d_model, 32, "cloned config should have new value");
    }

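    /// Numeric sketch of the unified recurrence `S_t = decay_t * S_{t-1} + update_t`
    /// documented on `AttentionMode`, in its simplest RetNet form with
    /// `d_k = d_v = 1`. Purely illustrative; it does not exercise the real engine.
    #[test]
    fn retnet_recurrence_decay_sketch() {
        let gamma = 0.95_f64;
        let (k, v) = (1.0_f64, 2.0_f64);
        let mut s = 0.0_f64;
        // Write the same key/value association for three steps.
        for _ in 0..3 {
            s = gamma * s + k * v;
        }
        // Closed form after three steps: s = k*v * (1 + gamma + gamma^2).
        let expected = k * v * (1.0 + gamma + gamma * gamma);
        assert!((s - expected).abs() < 1e-12, "recurrence mismatch: {s} vs {expected}");
        // Old contributions decay geometrically, so the state stays bounded by
        // k*v / (1 - gamma) for gamma in (0, 1).
        assert!(s < k * v / (1.0 - gamma));
    }
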
    #[test]
    fn all_modes_constructible() {
        let modes = [
            AttentionMode::RetNet { gamma: 0.9 },
            AttentionMode::Hawk,
            AttentionMode::GLA,
            AttentionMode::GLAVector,
            AttentionMode::DeltaNet,
            AttentionMode::GatedDeltaNet {
                beta_scale: 1.0,
                gate_mode_delta: GatedDeltaMode::Static,
            },
            AttentionMode::RWKV { initial_decay: 0.5 },
            AttentionMode::MLSTM,
            AttentionMode::DeltaProduct {
                n_compositions: 3,
                reflections: false,
            },
            AttentionMode::RWKV7,
            AttentionMode::HGRN2 { lower_bound: 0.9 },
            AttentionMode::LogLinear {
                inner: alloc::boxed::Box::new(AttentionMode::GLA),
                max_levels: 32,
                lambda_init: 1.0 / 32.0,
            },
        ];
        assert_eq!(modes.len(), 12, "should have exactly 12 modes");
    }

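    /// Sketch of the `GateMode` semantics documented on that enum: a scalar gate
    /// applies one decay to every key dimension, while a vector gate can retain
    /// one dimension and forget another. Gate values here are hand-picked for
    /// illustration and are independent of the actual learned gate projections.
    #[test]
    fn gate_mode_scalar_vs_vector_sketch() {
        let state = [1.0_f64, 1.0];
        // Scalar gate: a single value broadcast across d_k = 2 dimensions.
        let scalar_gate = 0.5_f64;
        let scalar_out = [scalar_gate * state[0], scalar_gate * state[1]];
        assert_eq!(scalar_out[0], scalar_out[1], "scalar gate decays all dims equally");
        // Vector gate: an independent sigmoid output per key dimension.
        let vector_gate = [0.99_f64, 0.01];
        let vector_out = [vector_gate[0] * state[0], vector_gate[1] * state[1]];
        assert!(
            vector_out[0] > 0.9 && vector_out[1] < 0.1,
            "vector gate can keep dim 0 while forgetting dim 1"
        );
        // The two enum variants are distinct configuration choices.
        assert_ne!(GateMode::Scalar, GateMode::Vector);
    }
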
    /// `AttentionMode::LogLinear` is constructible and the boxed inner mode is
    /// preserved through `Clone` (required for builder paths).
    #[test]
    fn log_linear_attention_mode_variant_constructible() {
        let mode = AttentionMode::LogLinear {
            inner: alloc::boxed::Box::new(AttentionMode::GatedDeltaNet {
                beta_scale: 1.0,
                gate_mode_delta: GatedDeltaMode::Static,
            }),
            max_levels: 32,
            lambda_init: 1.0 / 32.0,
        };
        let cloned = mode.clone();
        let dbg = alloc::format!("{:?}", cloned);
        assert!(
            dbg.contains("LogLinear"),
            "Debug output must name LogLinear, got {dbg}"
        );
        assert!(
            dbg.contains("GatedDeltaNet"),
            "Debug output must include inner mode name, got {dbg}"
        );
    }

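    /// `GatedDeltaNet` is constructible in the paper-canonical per-token beta
    /// mode. Per its docs, `beta_scale` is ignored in that mode, so any value
    /// is acceptable here; this is a construction sketch, not an engine test.
    #[test]
    fn gated_delta_net_per_token_mode_constructible() {
        let mode = AttentionMode::GatedDeltaNet {
            beta_scale: 1.0,
            gate_mode_delta: GatedDeltaMode::PerToken,
        };
        let dbg = alloc::format!("{:?}", mode);
        assert!(dbg.contains("PerToken"), "Debug output must name PerToken, got {dbg}");
    }
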
    #[test]
    fn config_debug_format_contains_mode() {
        let cfg = AttentionConfig {
            mode: AttentionMode::Hawk,
            ..AttentionConfig::default()
        };
        let debug = alloc::format!("{:?}", cfg);
        assert!(
            debug.contains("Hawk"),
            "debug format should contain mode name, got: {}",
            debug
        );
    }

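    /// Numeric sketch of the HGRN2 lower-bounded forget gate documented on the
    /// `HGRN2` variant: `alpha_t = lower_bound + (1 - lower_bound) * sigmoid(raw_t)`.
    /// Uses the sigmoid's limiting values 0.0 and 1.0 directly instead of
    /// evaluating a sigmoid; illustrative only.
    #[test]
    fn hgrn2_lower_bound_gate_range_sketch() {
        let lower_bound = 0.9_f64;
        let alpha = |sig: f64| lower_bound + (1.0 - lower_bound) * sig;
        // As sigmoid(raw_t) sweeps (0, 1), alpha_t sweeps (lower_bound, 1).
        assert!((alpha(0.0) - lower_bound).abs() < 1e-12);
        assert!((alpha(1.0) - 1.0).abs() < 1e-12);
        // So the per-dimension decay never drops below the configured bound.
        let mode = AttentionMode::HGRN2 { lower_bound };
        let dbg = alloc::format!("{:?}", mode);
        assert!(dbg.contains("HGRN2"), "got {dbg}");
    }
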
    #[test]
    fn custom_config_preserves_values() {
        let cfg = AttentionConfig {
            d_model: 64,
            n_heads: 8,
            d_key: 8,
            d_value: 8,
            mode: AttentionMode::GatedDeltaNet {
                beta_scale: 1.0,
                gate_mode_delta: GatedDeltaMode::Static,
            },
            seed: 1234,
        };
        assert_eq!(cfg.d_model, 64, "d_model should be 64");
        assert_eq!(cfg.n_heads, 8, "n_heads should be 8");
        assert_eq!(cfg.seed, 1234, "seed should be 1234");
    }

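    /// Numeric sketch of the `reflections` flag documented on `DeltaProduct`:
    /// with `beta` in (0, 2), a Householder factor `1 - beta` can go negative,
    /// which plain (0, 1) betas cannot produce. The sigmoid value is hand-picked,
    /// and the eigenvalue reading follows the cited paper, not this crate's kernels.
    #[test]
    fn delta_product_reflection_range_sketch() {
        let sigmoid_out = 0.75_f64; // stand-in for sigmoid(raw beta logit)
        // reflections = false: beta in (0, 1), so 1 - beta stays positive.
        let beta_plain = sigmoid_out;
        assert!(beta_plain > 0.0 && beta_plain < 1.0);
        assert!(1.0 - beta_plain > 0.0);
        // reflections = true: beta in (0, 2), so 1 - beta can be negative.
        let beta_reflect = 2.0 * sigmoid_out;
        assert!(beta_reflect > 0.0 && beta_reflect < 2.0);
        assert!(1.0 - beta_reflect < 0.0, "full reflections admit negative eigenvalues");
    }
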
    /// Per-token beta mode is distinct from static mode (the enum variants are
    /// distinguishable).
    #[test]
    fn gated_delta_mode_variants_distinguishable() {
        assert_ne!(
            GatedDeltaMode::Static,
            GatedDeltaMode::PerToken,
            "Static and PerToken must be distinct variants"
        );
    }

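    /// Sizing sketch for `LogLinear`, following the formulas in its docs:
    /// `max_levels = ⌊log₂(T_max)⌋ + 1` (computed here via bit length), state
    /// memory of `max_levels * d_k * d_v * n_heads` entries, and a uniform
    /// `lambda_init = 1 / max_levels`. Numbers only; no engine state is allocated.
    #[test]
    fn log_linear_sizing_sketch() {
        // Bit length of T_max equals floor(log2(T_max)) + 1 for T_max >= 1.
        let t_max: u64 = (1u64 << 32) - 1; // ~4 G tokens
        let max_levels = (u64::BITS - t_max.leading_zeros()) as usize;
        assert_eq!(max_levels, 32, "default of 32 levels covers ~4 G tokens");

        // State entries per layer at the default head/key/value sizes.
        let cfg = AttentionConfig::default();
        let entries = max_levels * cfg.d_key * cfg.d_value * cfg.n_heads;
        assert_eq!(entries, 32 * 4 * 4 * 4);

        // Uniform level mixture as the default lambda bias.
        let lambda_init = 1.0 / max_levels as f64;
        assert!((lambda_init - 1.0 / 32.0).abs() < 1e-12);
    }
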
    /// DeltaProduct with `reflections = true` is constructible.
    #[test]
    fn delta_product_reflections_flag_constructible() {
        let mode = AttentionMode::DeltaProduct {
            n_compositions: 2,
            reflections: true,
        };
        let debug = alloc::format!("{:?}", mode);
        assert!(
            debug.contains("DeltaProduct"),
            "debug should contain DeltaProduct, got: {}",
            debug
        );
    }
}
408}