
irithyll_core/attention/config.rs

//! Configuration types for the unified streaming linear attention engine.
//!
//! [`AttentionMode`] selects which architecture variant to use, while
//! [`AttentionConfig`] holds the full configuration including dimensions,
//! head count, and PRNG seed.

/// Gate dimensionality for GLA-family variants.
///
/// The original GLA paper (Yang et al., 2024) specifies a vector gate
/// `α_t ∈ (0,1)^{d_k}` — one decay scalar per key dimension. The simplified
/// scalar gate (`Scalar`) uses a single sigmoid output shared across all key
/// dimensions, which reduces parameter count at the cost of expressivity.
///
/// `Vector` is the mathematically canonical form from the paper.
/// `Scalar` is retained for backward compatibility with models trained before
/// the vector-gate upgrade.
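///
/// # Example
///
/// A minimal illustrative sketch (not this engine's internal API) of what
/// each mode computes per token, given hypothetical gate logits `z`:
///
/// ```
/// fn sigmoid(x: f64) -> f64 {
///     1.0 / (1.0 + (-x).exp())
/// }
///
/// // Hypothetical pre-activations, one per key dimension (d_k = 3).
/// let z = [0.3_f64, -1.2, 2.0];
/// // `Vector`: one independent sigmoid per key dimension.
/// let vector_gate: Vec<f64> = z.iter().map(|&zi| sigmoid(zi)).collect();
/// // `Scalar`: a single sigmoid shared by every key dimension.
/// let scalar_gate = sigmoid(z[0]);
/// assert!(vector_gate.iter().all(|&a| 0.0 < a && a < 1.0));
/// assert!(0.0 < scalar_gate && scalar_gate < 1.0);
/// ```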
#[derive(Clone, Debug, PartialEq)]
pub enum GateMode {
    /// Single scalar gate shared across all key dimensions (legacy default).
    Scalar,
    /// Per-key-dimension gate vector `α_t ∈ (0,1)^{d_k}` (paper-canonical).
    Vector,
}

/// Beta parameterization mode for GatedDeltaNet.
///
/// Controls whether the delta-rule mixing strength `β_t` is a static
/// scalar or a data-dependent value computed per token.
///
/// # Paper reference
///
/// Yang et al. ICLR 2025 (arXiv:2412.06464) specifies the canonical form
/// as per-token `β_t = sigmoid(W_β · x_t + b_β)`, making the delta-rule
/// mixing strength data-dependent. The static scalar `β_scale` is a
/// degenerate case retained for backward compatibility.
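///
/// # Example
///
/// A minimal sketch of the per-token computation with a hypothetical learned
/// projection `w_beta` and input `x_t` (not this engine's internal API; the
/// bias term is omitted for brevity):
///
/// ```
/// fn sigmoid(x: f64) -> f64 {
///     1.0 / (1.0 + (-x).exp())
/// }
///
/// let w_beta = [0.5_f64, -0.25, 0.1]; // hypothetical learned projection W_β
/// let x_t = [1.0_f64, 0.0, -2.0];     // current input token
/// let logit: f64 = w_beta.iter().zip(&x_t).map(|(w, x)| w * x).sum();
/// let beta_t = sigmoid(logit);        // data-dependent mixing strength
/// assert!(0.0 < beta_t && beta_t < 1.0);
/// ```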
#[derive(Clone, Debug, PartialEq)]
pub enum GatedDeltaMode {
    /// Static scalar `β_scale` shared across all tokens (legacy default).
    ///
    /// Equivalent to the original `beta_scale: f64` parameter. Back-compatible
    /// with all checkpoints trained before this field was introduced.
    Static,
    /// Per-token `β_t = sigmoid(W_β · x_t)` — paper-canonical form.
    ///
    /// Adds a learned `W_β ∈ R^{d_model}` projection per head. The mixing
    /// strength becomes data-dependent, allowing the model to selectively
    /// apply strong or weak delta-rule updates depending on the input.
    PerToken,
}

/// Selects the attention architecture variant.
///
/// Each variant corresponds to a different parameterization of the unified
/// structured linear attention recurrence `S_t = decay_t * S_{t-1} + update_t`
/// (sketched in the example below).
///
/// # Variants
///
/// - **RetNet** -- Fixed exponential decay (Sun et al., 2023)
/// - **Hawk** -- Gated scalar recurrence with vector state (De et al., 2024)
/// - **GLA** -- Gated Linear Attention with scalar gate (Yang et al., 2024, legacy)
/// - **GLAVector** -- GLA with per-key-dimension vector gate (paper-canonical form)
/// - **DeltaNet** -- Delta rule error-corrective update (Schlag et al., 2021)
/// - **GatedDeltaNet** -- GLA + delta rule combined (Yang et al., ICLR 2025, NVIDIA)
/// - **RWKV** -- Exponential decay with dynamic `w` (Peng et al., 2024)
/// - **MLSTM** -- xLSTM matrix memory with forget + input gates (Beck et al., 2024)
/// - **DeltaProduct** -- Product of `n_h` Householder delta rules (Siems et al., NeurIPS 2025)
/// - **RWKV7** -- Vector-gated delta rule with DPLR transitions (Peng et al., 2025)
/// - **HGRN2** -- Lower-bounded gated linear RNN with state expansion (Qin et al., ICML 2024)
/// - **LogLinear** -- O(log T) hierarchical Fenwick state (Han Guo et al., ICLR 2026)
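///
/// # Example
///
/// A scalar-state sketch of the shared recurrence, instantiated with a fixed
/// RetNet-style decay (illustrative only; the real state is a `d_k × d_v`
/// matrix per head):
///
/// ```
/// let gamma = 0.95_f64;  // fixed decay (RetNet)
/// let (k, v) = (0.5_f64, 2.0_f64);
/// let mut s = 0.0_f64;   // S_{t-1}
/// s = gamma * s + k * v; // S_t = decay_t * S_{t-1} + update_t
/// assert!((s - 1.0).abs() < 1e-12);
/// ```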
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum AttentionMode {
    /// Fixed exponential decay: `S = gamma * S + k * v^T`.
    RetNet {
        /// Decay factor in (0, 1). Typical: 0.9 to 0.999.
        gamma: f64,
    },
    /// Gated scalar recurrence with vector state (Griffin/Hawk).
    Hawk,
    /// Gated Linear Attention with scalar data-dependent sigmoid gate (legacy).
    ///
    /// Uses a single sigmoid scalar `σ(w^T x) ∈ (0,1)` shared across all key
    /// dimensions. Preserved for backward compatibility with pre-v10 checkpoints.
    /// For new code, prefer [`GLAVector`](AttentionMode::GLAVector).
    GLA,
    /// Gated Linear Attention with per-key-dimension vector gate (paper-canonical).
    ///
    /// Implements the gate from Yang et al. 2024 exactly: `α_t ∈ (0,1)^{d_k}`
    /// with one independent sigmoid per key dimension, giving the model
    /// per-dimension memory control. The gate projection is `W_α ∈ R^{d_k × d_model}`.
    ///
    /// This is the recommended GLA variant for new models. Produces different
    /// (and generally richer) dynamics than the scalar-gate `GLA` variant.
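    ///
    /// # Example
    ///
    /// A minimal sketch of the per-dimension update
    /// `S_t = diag(α_t) S_{t-1} + k_t v_t^T`, with hypothetical gate values
    /// (illustrative only, not this engine's internal API):
    ///
    /// ```
    /// let alpha = [0.9_f64, 0.5];     // vector gate: one sigmoid output per key dim
    /// let k = [1.0_f64, 0.2];         // key (d_k = 2)
    /// let v = [0.7_f64];              // value (d_v = 1)
    /// let mut s = [[0.25_f64; 1]; 2]; // S_{t-1} in R^{d_k x d_v}
    /// for i in 0..2 {
    ///     for j in 0..1 {
    ///         // row i decays by its own alpha[i] before the rank-1 update
    ///         s[i][j] = alpha[i] * s[i][j] + k[i] * v[j];
    ///     }
    /// }
    /// assert!((s[0][0] - (0.9 * 0.25 + 1.0 * 0.7)).abs() < 1e-12);
    /// ```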
    GLAVector,
    /// Delta rule: error-corrective associative memory update.
    DeltaNet,
    /// Gated delta rule: GLA gate + delta error correction (Yang et al., ICLR 2025).
    ///
    /// Combines GLA's data-dependent gating with DeltaNet's error-corrective
    /// delta rule, plus learnable beta scaling and L2-normalized keys.
    ///
    /// # Beta parameterization
    ///
    /// - [`GatedDeltaMode::Static`] (default): `beta_scale` scalar applies uniformly.
    /// - [`GatedDeltaMode::PerToken`]: paper-canonical form; `β_t = sigmoid(W_β · x_t)`
    ///   is computed per token per head via a learned projection. `beta_scale` is
    ///   ignored in this mode.
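    ///
    /// # Example
    ///
    /// Constructing the paper-canonical per-token form (the import path is an
    /// assumption based on this file's location in the crate):
    ///
    /// ```ignore
    /// use irithyll_core::attention::config::{AttentionMode, GatedDeltaMode};
    ///
    /// let mode = AttentionMode::GatedDeltaNet {
    ///     beta_scale: 1.0, // ignored in PerToken mode
    ///     gate_mode_delta: GatedDeltaMode::PerToken,
    /// };
    /// ```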
    GatedDeltaNet {
        /// Learnable beta scaling factor (default: 1.0).
        ///
        /// Used only when `gate_mode_delta` is [`GatedDeltaMode::Static`].
        /// Controls how aggressively the error-corrective update is applied.
        beta_scale: f64,
        /// Beta parameterization mode (default: [`GatedDeltaMode::Static`]).
        ///
        /// Set to [`GatedDeltaMode::PerToken`] for the paper-canonical
        /// data-dependent mixing strength (Yang et al., ICLR 2025).
        gate_mode_delta: GatedDeltaMode,
    },
    /// RWKV-style exponential decay with learned dynamic `w`.
    RWKV {
        /// Base decay rate before input-dependent modulation.
        initial_decay: f64,
    },
    /// xLSTM matrix memory with separate forget and input gates.
    MLSTM,
    /// Product of `n_h` Householder delta rules (Siems et al., NeurIPS 2025).
    ///
    /// Applies `n_compositions` sequential delta rule steps per token, each with
    /// its own key, value, and beta. The product of generalized Householder
    /// transformations gives a spectrally bounded (norm ≤ 1) transition matrix.
    /// Includes a scalar forget gate.
    ///
    /// # Reflections flag
    ///
    /// When `reflections` is `false` (default), each `β_{t,j} ∈ (0, 1)` via
    /// a plain sigmoid. When `true`, `β_{t,j} ∈ (0, 2)` via `2 · sigmoid(·)`,
    /// which enables full Householder reflections (negative eigenvalues) for
    /// stronger state-tracking capability (Siems et al., NeurIPS 2025,
    /// arXiv:2502.10297, §4).
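    ///
    /// # Example
    ///
    /// A sketch of the two beta ranges from a hypothetical logit `z`
    /// (illustrative only):
    ///
    /// ```
    /// fn sigmoid(x: f64) -> f64 {
    ///     1.0 / (1.0 + (-x).exp())
    /// }
    ///
    /// let z = 0.4_f64;                  // hypothetical beta logit
    /// let beta_plain = sigmoid(z);      // reflections = false: range (0, 1)
    /// let beta_refl = 2.0 * sigmoid(z); // reflections = true:  range (0, 2)
    /// // beta > 1 gives the Householder factor a negative eigenvalue along k.
    /// assert!(beta_refl > 1.0 && beta_refl < 2.0);
    /// assert!(beta_plain < 1.0);
    /// ```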
    DeltaProduct {
        /// Number of Householder compositions per token (typically 2-4).
        n_compositions: usize,
        /// Enable full Householder reflections (default: `false`).
        ///
        /// When `false`: `β ∈ (0, 1)` — standard delta rule range.
        /// When `true`: `β ∈ (0, 2)` — full reflections, enabling negative
        /// eigenvalues for parity / state-tracking tasks.
        reflections: bool,
    },
    /// RWKV-7 vector-gated delta rule with DPLR transitions (Peng et al., 2025).
    ///
    /// Uses per-dimension vector decay, a vector-valued in-context learning
    /// rate, and decoupled removal/replacement keys. The transition matrix is
    /// diagonal-plus-low-rank (DPLR), enabling state tracking beyond TC^0.
    RWKV7,
    /// HGRN2: gated linear RNN with outer-product state expansion (Qin et al., ICML 2024).
    ///
    /// Uses a lower-bounded forget gate for minimum memory retention.
    /// The state update is outer-product based (like GLA), but the bounded
    /// gate ensures the model never completely forgets:
    /// `alpha_t = lower_bound + (1 - lower_bound) * sigmoid(raw_t)` per dimension.
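    ///
    /// # Example
    ///
    /// A sketch of the lower-bounded gate with `lower_bound = 0.9` and a
    /// hypothetical raw logit (illustrative only):
    ///
    /// ```
    /// fn sigmoid(x: f64) -> f64 {
    ///     1.0 / (1.0 + (-x).exp())
    /// }
    ///
    /// let lower_bound = 0.9_f64;
    /// let raw = -3.0_f64; // even a strongly negative logit cannot fully forget
    /// let alpha = lower_bound + (1.0 - lower_bound) * sigmoid(raw);
    /// assert!(alpha > 0.9 && alpha < 1.0);
    /// ```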
    HGRN2 {
        /// Lower bound for the forget gate (default: 0.9, range 0.0..1.0).
        /// Higher values ensure stronger memory retention.
        lower_bound: f64,
    },
    /// Log-Linear Attention (Han Guo et al., ICLR 2026, arXiv:2506.04761).
    ///
    /// Replaces the single fixed-size recurrent state of the inner
    /// linear-attention rule with an O(log T) hierarchy of states
    /// organized by a Fenwick-tree decomposition of the prefix.
    /// Compute per token grows as `log T`; total compute is
    /// `O(T log T)`. State memory is fixed at
    /// `max_levels * d_k * d_v * n_heads`.
    ///
    /// # Inner mode (paper §3.2)
    ///
    /// Wraps any non-recursive `AttentionMode` as the leaf update
    /// rule. Recommended: `GatedDeltaNet` for strongest associative
    /// recall, `GLA` for stability. The wrapper handles per-token
    /// key normalization automatically for delta-rule families.
    ///
    /// # `max_levels` (paper §3, R1 §3.5)
    ///
    /// Hard cap on hierarchy depth: choose
    /// `max_levels = ⌊log₂(T_max)⌋ + 1`, where `T_max` is the
    /// expected maximum stream length. The default of 32 covers
    /// streams up to ~4 billion tokens. State is padded to
    /// `max_levels` — NOT popcount-sized — for shape stability
    /// across stream length (paper §3.4 Option B / R1 §3.4). This is
    /// the paper-mandated stability choice: streaming consumers
    /// require constant-shape state vectors.
    ///
    /// # `lambda_init` (paper §3.3, R1 §5.3)
    ///
    /// Static bias added to the per-level λ logits before
    /// softplus-softmax mixing. The default `1/max_levels` produces
    /// a uniform mixture across all levels — the principled choice
    /// when no information about which levels are useful is
    /// available (no backprop in the streaming setting).
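    ///
    /// # Example
    ///
    /// Sizing `max_levels` for an expected stream length with the
    /// `⌊log₂(T_max)⌋ + 1` rule (plain integer arithmetic, illustrative):
    ///
    /// ```
    /// let t_max: u64 = 1_000_000;
    /// // floor(log2(t_max)) + 1, valid for t_max > 0
    /// let max_levels = 64 - t_max.leading_zeros() as usize;
    /// assert_eq!(max_levels, 20); // 2^19 <= 1_000_000 < 2^20
    /// ```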
    LogLinear {
        /// Inner update rule (must NOT be `LogLinear`).
        ///
        /// Boxed because the enum is recursive: the indirection gives
        /// `AttentionMode` a fixed size and lets factory code check the
        /// inner shape recursively.
        inner: alloc::boxed::Box<AttentionMode>,
        /// Hard cap on Fenwick depth. Memory is
        /// `max_levels × d_k × d_v` per head. Default: 32.
        max_levels: usize,
        /// Initial λ bias before softplus-softmax mixing.
        /// Default: `1/max_levels` (uniform mixture).
        lambda_init: f64,
    },
}

/// Full configuration for a multi-head streaming attention layer.
///
/// # Defaults
///
/// - `d_model`: 16
/// - `n_heads`: 4
/// - `d_key`: `d_model / n_heads` (4)
/// - `d_value`: `d_model / n_heads` (4)
/// - `mode`: `RetNet { gamma: 0.95 }`
/// - `seed`: 42
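///
/// # Example
///
/// Overriding a single field on top of the defaults (the import path is an
/// assumption based on this file's location in the crate):
///
/// ```ignore
/// use irithyll_core::attention::config::{AttentionConfig, AttentionMode};
///
/// let cfg = AttentionConfig {
///     mode: AttentionMode::GLAVector,
///     ..AttentionConfig::default()
/// };
/// assert_eq!(cfg.d_model, 16); // other defaults are preserved
/// ```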
#[derive(Clone, Debug)]
pub struct AttentionConfig {
    /// Input/output dimension.
    pub d_model: usize,
    /// Number of attention heads.
    pub n_heads: usize,
    /// Per-head key dimension.
    pub d_key: usize,
    /// Per-head value dimension.
    pub d_value: usize,
    /// Architecture variant.
    pub mode: AttentionMode,
    /// PRNG seed for deterministic weight initialization.
    pub seed: u64,
}

impl Default for AttentionConfig {
    fn default() -> Self {
        Self {
            d_model: 16,
            n_heads: 4,
            d_key: 4,
            d_value: 4,
            mode: AttentionMode::RetNet { gamma: 0.95 },
            seed: 42,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_dimensions_consistent() {
        let cfg = AttentionConfig::default();
        assert_eq!(
            cfg.d_key,
            cfg.d_model / cfg.n_heads,
            "d_key should equal d_model / n_heads"
        );
        assert_eq!(
            cfg.d_value,
            cfg.d_model / cfg.n_heads,
            "d_value should equal d_model / n_heads"
        );
    }

    #[test]
    fn default_mode_is_retnet() {
        let cfg = AttentionConfig::default();
        match cfg.mode {
            AttentionMode::RetNet { gamma } => {
                assert!(
                    (gamma - 0.95).abs() < 1e-12,
                    "default gamma should be 0.95, got {}",
                    gamma
                );
            }
            _ => panic!("default mode should be RetNet"),
        }
    }

    #[test]
    fn config_clone_is_independent() {
        let cfg1 = AttentionConfig::default();
        let mut cfg2 = cfg1.clone();
        cfg2.d_model = 32;
        assert_eq!(cfg1.d_model, 16, "clone should be independent");
        assert_eq!(cfg2.d_model, 32, "cloned config should have new value");
    }

    #[test]
    fn all_modes_constructible() {
        let modes = [
            AttentionMode::RetNet { gamma: 0.9 },
            AttentionMode::Hawk,
            AttentionMode::GLA,
            AttentionMode::GLAVector,
            AttentionMode::DeltaNet,
            AttentionMode::GatedDeltaNet {
                beta_scale: 1.0,
                gate_mode_delta: GatedDeltaMode::Static,
            },
            AttentionMode::RWKV { initial_decay: 0.5 },
            AttentionMode::MLSTM,
            AttentionMode::DeltaProduct {
                n_compositions: 3,
                reflections: false,
            },
            AttentionMode::RWKV7,
            AttentionMode::HGRN2 { lower_bound: 0.9 },
            AttentionMode::LogLinear {
                inner: alloc::boxed::Box::new(AttentionMode::GLA),
                max_levels: 32,
                lambda_init: 1.0 / 32.0,
            },
        ];
        assert_eq!(modes.len(), 12, "should have exactly 12 modes");
    }

    /// `AttentionMode::LogLinear` is constructible and the `Box`-ed inner mode
    /// is preserved through `Clone` (required for builder paths).
    #[test]
    fn log_linear_attention_mode_variant_constructible() {
        let mode = AttentionMode::LogLinear {
            inner: alloc::boxed::Box::new(AttentionMode::GatedDeltaNet {
                beta_scale: 1.0,
                gate_mode_delta: GatedDeltaMode::Static,
            }),
            max_levels: 32,
            lambda_init: 1.0 / 32.0,
        };
        let cloned = mode.clone();
        let dbg = alloc::format!("{:?}", cloned);
        assert!(
            dbg.contains("LogLinear"),
            "Debug output must name LogLinear, got {dbg}"
        );
        assert!(
            dbg.contains("GatedDeltaNet"),
            "Debug output must include inner mode name, got {dbg}"
        );
    }

    #[test]
    fn config_debug_format_contains_mode() {
        let cfg = AttentionConfig {
            mode: AttentionMode::Hawk,
            ..AttentionConfig::default()
        };
        let debug = alloc::format!("{:?}", cfg);
        assert!(
            debug.contains("Hawk"),
            "debug format should contain mode name, got: {}",
            debug
        );
    }

    #[test]
    fn custom_config_preserves_values() {
        let cfg = AttentionConfig {
            d_model: 64,
            n_heads: 8,
            d_key: 8,
            d_value: 8,
            mode: AttentionMode::GatedDeltaNet {
                beta_scale: 1.0,
                gate_mode_delta: GatedDeltaMode::Static,
            },
            seed: 1234,
        };
        assert_eq!(cfg.d_model, 64, "d_model should be 64");
        assert_eq!(cfg.n_heads, 8, "n_heads should be 8");
        assert_eq!(cfg.seed, 1234, "seed should be 1234");
    }

    /// Per-token beta mode is distinct from static mode (enum distinguishable).
    #[test]
    fn gated_delta_mode_variants_distinguishable() {
        assert_ne!(
            GatedDeltaMode::Static,
            GatedDeltaMode::PerToken,
            "Static and PerToken must be distinct variants"
        );
    }

    /// DeltaProduct with `reflections = true` is constructible.
    #[test]
    fn delta_product_reflections_flag_constructible() {
        let mode = AttentionMode::DeltaProduct {
            n_compositions: 2,
            reflections: true,
        };
        let debug = alloc::format!("{:?}", mode);
        assert!(
            debug.contains("DeltaProduct"),
            "debug should contain DeltaProduct, got: {}",
            debug
        );
    }
}