/// Shape of a gate value: one shared scalar or one entry per dimension.
///
/// NOTE(review): the consuming code is not visible in this chunk — the
/// scalar/vector reading is taken from the variant names; confirm at the
/// use site.
///
/// Fieldless two-variant enum, so `Copy`, `Eq`, and `Hash` are derived
/// eagerly alongside the existing traits (Rust API guidelines,
/// C-COMMON-TRAITS); this is strictly backward-compatible for callers.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GateMode {
    /// A single gate value, presumably shared across the whole state.
    Scalar,
    /// A per-dimension gate value.
    Vector,
}
/// Gating policy carried by `AttentionMode::GatedDeltaNet`
/// (its `gate_mode_delta` field).
///
/// NOTE(review): the runtime effect of each variant is not visible in this
/// chunk; the descriptions below are taken from the variant names.
///
/// Fieldless two-variant enum, so `Copy`, `Eq`, and `Hash` are derived
/// eagerly alongside the existing traits (Rust API guidelines,
/// C-COMMON-TRAITS); this is strictly backward-compatible for callers.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GatedDeltaMode {
    /// Presumably a fixed gate, shared across the sequence.
    Static,
    /// Presumably a gate recomputed for every token.
    PerToken,
}
/// Selects which attention/recurrence variant a layer is built with.
///
/// Marked `#[non_exhaustive]`: more variants may be added later, so
/// downstream `match`es must keep a wildcard arm.
///
/// NOTE(review): variant names follow the linear-attention literature
/// (RetNet, GLA, DeltaNet, RWKV, mLSTM, HGRN2, …), but the corresponding
/// implementations are not visible in this chunk — parameter semantics
/// below are hedged accordingly.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum AttentionMode {
    /// Retention-style recurrence with a fixed decay factor.
    RetNet {
        /// Per-step decay; presumably expected in (0, 1) — not enforced here.
        gamma: f64,
    },
    /// Hawk variant; carries no configuration at this level.
    Hawk,
    /// Gated linear attention; presumably the scalar-gate form
    /// (cf. `GLAVector`).
    GLA,
    /// Gated linear attention with a vector gate — presumably; confirm at
    /// the use site.
    GLAVector,
    /// Delta-rule linear attention; no parameters at this level.
    DeltaNet,
    /// DeltaNet with an additional gate.
    GatedDeltaNet {
        /// Scaling applied to the delta-rule beta term — presumably;
        /// usage not visible here.
        beta_scale: f64,
        /// Whether the gate is static or recomputed per token.
        gate_mode_delta: GatedDeltaMode,
    },
    /// RWKV-style recurrence.
    RWKV {
        /// Initial decay value — presumably a starting point for a learned
        /// decay; confirm against the implementation.
        initial_decay: f64,
    },
    /// Matrix-LSTM variant; no parameters at this level.
    MLSTM,
    /// Composition of several delta-rule updates per token.
    DeltaProduct {
        /// How many delta updates are composed per step.
        n_compositions: usize,
        /// Presumably toggles Householder-reflection-style updates —
        /// TODO confirm at the use site.
        reflections: bool,
    },
    /// RWKV version 7; no parameters at this level.
    RWKV7,
    /// HGRN2 recurrence.
    HGRN2 {
        /// Lower bound applied to the forget gate — presumably; confirm
        /// against the implementation.
        lower_bound: f64,
    },
    /// Log-linear wrapper around another mode.
    LogLinear {
        /// The wrapped mode; boxed because the enum is recursive and must
        /// otherwise have infinite size.
        inner: alloc::boxed::Box<AttentionMode>,
        /// Maximum number of hierarchy levels — presumably; usage not
        /// visible here.
        max_levels: usize,
        /// Initial per-level mixing weight (tests construct it as
        /// 1.0 / max_levels, but that relation is not enforced here).
        lambda_init: f64,
    },
}
/// Configuration bundle for building an attention layer.
#[derive(Clone, Debug)]
pub struct AttentionConfig {
    /// Model (embedding) width.
    pub d_model: usize,
    /// Number of attention heads.
    pub n_heads: usize,
    /// Per-head key dimension; the tests in this file expect
    /// `d_model / n_heads`, but the relation is not enforced here.
    pub d_key: usize,
    /// Per-head value dimension; same expected relation as `d_key`.
    pub d_value: usize,
    /// Which attention variant to instantiate.
    pub mode: AttentionMode,
    /// Seed — presumably for reproducible weight initialization; the
    /// consuming code is not visible in this chunk.
    pub seed: u64,
}
impl Default for AttentionConfig {
fn default() -> Self {
Self {
d_model: 16,
n_heads: 4,
d_key: 4,
d_value: 4,
mode: AttentionMode::RetNet { gamma: 0.95 },
seed: 42,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default key/value widths must equal the d_model / n_heads split.
    #[test]
    fn default_config_dimensions_consistent() {
        let cfg = AttentionConfig::default();
        let head_dim = cfg.d_model / cfg.n_heads;
        assert_eq!(cfg.d_key, head_dim, "d_key should equal d_model / n_heads");
        assert_eq!(
            cfg.d_value, head_dim,
            "d_value should equal d_model / n_heads"
        );
    }

    /// The default mode is RetNet with gamma = 0.95.
    #[test]
    fn default_mode_is_retnet() {
        let gamma = match AttentionConfig::default().mode {
            AttentionMode::RetNet { gamma } => gamma,
            _ => panic!("default mode should be RetNet"),
        };
        assert!(
            (gamma - 0.95).abs() < 1e-12,
            "default gamma should be 0.95, got {}",
            gamma
        );
    }

    /// Mutating a clone must not affect the original.
    #[test]
    fn config_clone_is_independent() {
        let original = AttentionConfig::default();
        let mut copy = original.clone();
        copy.d_model = 32;
        assert_eq!(original.d_model, 16, "clone should be independent");
        assert_eq!(copy.d_model, 32, "cloned config should have new value");
    }

    /// Every public variant can be constructed (the array itself is the
    /// real check — it only compiles if all constructors are valid).
    #[test]
    fn all_modes_constructible() {
        let constructible = [
            AttentionMode::RetNet { gamma: 0.9 },
            AttentionMode::Hawk,
            AttentionMode::GLA,
            AttentionMode::GLAVector,
            AttentionMode::DeltaNet,
            AttentionMode::GatedDeltaNet {
                beta_scale: 1.0,
                gate_mode_delta: GatedDeltaMode::Static,
            },
            AttentionMode::RWKV { initial_decay: 0.5 },
            AttentionMode::MLSTM,
            AttentionMode::DeltaProduct {
                n_compositions: 3,
                reflections: false,
            },
            AttentionMode::RWKV7,
            AttentionMode::HGRN2 { lower_bound: 0.9 },
            AttentionMode::LogLinear {
                inner: alloc::boxed::Box::new(AttentionMode::GLA),
                max_levels: 32,
                lambda_init: 1.0 / 32.0,
            },
        ];
        assert_eq!(constructible.len(), 12, "should have exactly 12 modes");
    }

    /// LogLinear can wrap another mode; both names show up in Debug output.
    #[test]
    fn log_linear_attention_mode_variant_constructible() {
        let inner = alloc::boxed::Box::new(AttentionMode::GatedDeltaNet {
            beta_scale: 1.0,
            gate_mode_delta: GatedDeltaMode::Static,
        });
        let mode = AttentionMode::LogLinear {
            inner,
            max_levels: 32,
            lambda_init: 1.0 / 32.0,
        };
        // `dbg` is also captured inline by the assertion messages below.
        let dbg = alloc::format!("{:?}", mode.clone());
        assert!(
            dbg.contains("LogLinear"),
            "Debug output must name LogLinear, got {dbg}"
        );
        assert!(
            dbg.contains("GatedDeltaNet"),
            "Debug output must include inner mode name, got {dbg}"
        );
    }

    /// Debug formatting of a config mentions the active mode.
    #[test]
    fn config_debug_format_contains_mode() {
        let mut cfg = AttentionConfig::default();
        cfg.mode = AttentionMode::Hawk;
        let rendered = alloc::format!("{:?}", cfg);
        assert!(
            rendered.contains("Hawk"),
            "debug format should contain mode name, got: {}",
            rendered
        );
    }

    /// A fully explicit config keeps the values it was given.
    #[test]
    fn custom_config_preserves_values() {
        let cfg = AttentionConfig {
            mode: AttentionMode::GatedDeltaNet {
                beta_scale: 1.0,
                gate_mode_delta: GatedDeltaMode::Static,
            },
            seed: 1234,
            d_model: 64,
            n_heads: 8,
            d_key: 8,
            d_value: 8,
        };
        assert_eq!(cfg.d_model, 64, "d_model should be 64");
        assert_eq!(cfg.n_heads, 8, "n_heads should be 8");
        assert_eq!(cfg.seed, 1234, "seed should be 1234");
    }

    /// The two gate-mode variants compare unequal.
    #[test]
    fn gated_delta_mode_variants_distinguishable() {
        let (fixed, per_token) = (GatedDeltaMode::Static, GatedDeltaMode::PerToken);
        assert_ne!(
            fixed, per_token,
            "Static and PerToken must be distinct variants"
        );
    }

    /// DeltaProduct's reflections flag is constructible and Debug-visible.
    #[test]
    fn delta_product_reflections_flag_constructible() {
        let rendered = alloc::format!(
            "{:?}",
            AttentionMode::DeltaProduct {
                n_compositions: 2,
                reflections: true,
            }
        );
        assert!(
            rendered.contains("DeltaProduct"),
            "debug should contain DeltaProduct, got: {}",
            rendered
        );
    }
}