/// Positional-encoding scheme configured for an attention layer.
#[derive(Clone, Debug, PartialEq)]
pub enum PositionalEncoding {
    /// Rotary position embedding parameterised by its frequency base.
    RoPE { freq_base: f32 },
    /// ALiBi (attention with linear biases).
    AliBi,
    /// No positional encoding configured (the default chosen by `LayerAttentionConfig::new`).
    None,
}
/// Attention hyper-parameters for a single transformer layer.
///
/// Construct with [`LayerAttentionConfig::new`] and refine with the `with_*` builder methods.
/// `PartialEq` is derived (like [`PositionalEncoding`]) so whole configs can be compared.
#[derive(Debug, Clone, PartialEq)]
pub struct LayerAttentionConfig {
    /// Index of this layer within the model.
    pub layer_idx: usize,
    /// Number of query heads.
    pub num_q_heads: usize,
    /// Number of key/value heads (may be smaller than `num_q_heads`).
    pub num_kv_heads: usize,
    /// Dimension of each attention head.
    pub head_dim: usize,
    /// Positional-encoding scheme used by this layer.
    pub positional_encoding: PositionalEncoding,
    /// `None` means full attention; `Some(w)` caps the attended KV span at `w` positions
    /// (plus `sink_tokens`).
    pub sliding_window: Option<usize>,
    /// Extra positions added on top of the sliding window; only used when
    /// `sliding_window` is `Some`.
    pub sink_tokens: usize,
    /// Attention score scale; `new` initialises it to `1 / sqrt(head_dim)`
    /// (or `1.0` when `head_dim == 0`).
    pub scale: f32,
}
impl LayerAttentionConfig {
    /// Creates a layer config with no positional encoding, full attention and no sink tokens.
    ///
    /// `scale` is initialised to `1 / sqrt(head_dim)`; a `head_dim` of zero falls back to
    /// `1.0` instead of producing an infinite scale.
    pub fn new(layer_idx: usize, num_q_heads: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        let scale = if head_dim > 0 {
            1.0_f32 / (head_dim as f32).sqrt()
        } else {
            1.0_f32
        };
        Self {
            layer_idx,
            num_q_heads,
            num_kv_heads,
            head_dim,
            positional_encoding: PositionalEncoding::None,
            sliding_window: None,
            sink_tokens: 0,
            scale,
        }
    }

    /// Switches the layer to RoPE with the given frequency base.
    #[must_use]
    pub fn with_rope(mut self, freq_base: f32) -> Self {
        self.positional_encoding = PositionalEncoding::RoPE { freq_base };
        self
    }

    /// Switches the layer to ALiBi positional encoding.
    #[must_use]
    pub fn with_alibi(mut self) -> Self {
        self.positional_encoding = PositionalEncoding::AliBi;
        self
    }

    /// Restricts the layer to a sliding window of `window` positions.
    #[must_use]
    pub fn with_sliding_window(mut self, window: usize) -> Self {
        self.sliding_window = Some(window);
        self
    }

    /// Sets the number of sink tokens counted on top of the sliding window.
    #[must_use]
    pub fn with_sink_tokens(mut self, n: usize) -> Self {
        self.sink_tokens = n;
        self
    }

    /// Returns `true` when the layer has no sliding window.
    #[inline]
    pub fn is_full_attention(&self) -> bool {
        self.sliding_window.is_none()
    }

    /// Number of KV positions attended when the sequence holds `total_len` positions.
    ///
    /// Full-attention layers attend to all `total_len` positions. Sliding-window layers
    /// attend to at most `window + sink_tokens`, clamped to `total_len`. The sum saturates,
    /// so a very large window (e.g. `usize::MAX`) cannot overflow. It would otherwise panic
    /// in debug builds or wrap to a tiny value in release builds. `_q_pos` is currently
    /// ignored.
    pub fn effective_kv_len(&self, total_len: usize, _q_pos: usize) -> usize {
        match self.sliding_window {
            None => total_len,
            Some(window) => window.saturating_add(self.sink_tokens).min(total_len),
        }
    }
}
/// Attention configuration for a whole model: one [`LayerAttentionConfig`] per layer.
///
/// Derives `Debug` and `Clone` for consistency with [`LayerAttentionConfig`].
#[derive(Debug, Clone)]
pub struct ModelAttentionConfig {
    /// Per-layer configs, indexed by position in this vector.
    pub layers: Vec<LayerAttentionConfig>,
}
impl ModelAttentionConfig {
    /// Wraps an explicit list of per-layer configs.
    pub fn new(layers: Vec<LayerAttentionConfig>) -> Self {
        Self { layers }
    }

    /// Preset with 36 full-attention layers: 32 query heads, 8 KV heads, `head_dim` 128 and
    /// RoPE with a frequency base of `1_000_000.0`.
    pub fn bonsai_8b() -> Self {
        let layers = (0..36)
            .map(|i| LayerAttentionConfig::new(i, 32, 8, 128).with_rope(1_000_000.0))
            .collect();
        Self { layers }
    }

    /// Builds `num_layers` RoPE layers that alternate between full attention (even indices)
    /// and a 1024-position sliding window with 4 sink tokens (odd indices).
    ///
    /// `head_dim` is `hidden_size / num_q_heads`; when `num_q_heads` is zero it falls back to
    /// `hidden_size`.
    pub fn mixed_window_config(
        num_layers: usize,
        hidden_size: usize,
        num_q_heads: usize,
        num_kv_heads: usize,
    ) -> Self {
        let head_dim = hidden_size.checked_div(num_q_heads).unwrap_or(hidden_size);
        let layers = (0..num_layers)
            .map(|i| {
                let base = LayerAttentionConfig::new(i, num_q_heads, num_kv_heads, head_dim)
                    .with_rope(1_000_000.0);
                if i % 2 == 0 {
                    base
                } else {
                    base.with_sliding_window(1024).with_sink_tokens(4)
                }
            })
            .collect();
        Self { layers }
    }

    /// Returns the config for `layer_idx`, or `None` if it is out of range.
    pub fn get(&self, layer_idx: usize) -> Option<&LayerAttentionConfig> {
        self.layers.get(layer_idx)
    }

    /// Number of configured layers.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Number of layers without a sliding window.
    pub fn full_attention_layers(&self) -> usize {
        self.layers.iter().filter(|l| l.is_full_attention()).count()
    }

    /// Number of layers that use a sliding window.
    ///
    /// Defined via `is_full_attention` so it always agrees with `full_attention_layers`.
    pub fn sliding_window_layers(&self) -> usize {
        self.layers.iter().filter(|l| !l.is_full_attention()).count()
    }

    /// Estimated KV-cache size in bytes, assuming 4-byte (`f32`) elements.
    ///
    /// Equivalent to [`Self::memory_estimate_kv_cache_with_elem_bytes`] with `elem_bytes = 4`.
    pub fn memory_estimate_kv_cache(&self, context_len: usize, num_batches: usize) -> usize {
        self.memory_estimate_kv_cache_with_elem_bytes(context_len, num_batches, 4)
    }

    /// Estimated KV-cache size in bytes for elements of `elem_bytes` bytes each.
    ///
    /// For every layer this counts
    /// `num_kv_heads * head_dim * effective_kv_len(context_len, context_len - 1)` elements,
    /// doubled for the K and V halves of the cache. It then sums over layers and multiplies
    /// by `num_batches`.
    ///
    /// All arithmetic saturates at `usize::MAX`, so oversized inputs return `usize::MAX`.
    /// Without saturation they would panic in debug builds or wrap to a misleadingly small
    /// value in release builds.
    pub fn memory_estimate_kv_cache_with_elem_bytes(
        &self,
        context_len: usize,
        num_batches: usize,
        elem_bytes: usize,
    ) -> usize {
        // Position of the last query token; `context_len == 0` maps to 0.
        let q_pos = context_len.saturating_sub(1);
        self.layers
            .iter()
            .map(|l| {
                let eff = l.effective_kv_len(context_len, q_pos);
                l.num_kv_heads
                    .saturating_mul(l.head_dim)
                    .saturating_mul(eff)
                    .saturating_mul(elem_bytes)
                    // Factor of 2: separate K and V entries.
                    .saturating_mul(2)
            })
            .fold(0_usize, usize::saturating_add)
            .saturating_mul(num_batches)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails with "scale mismatch" unless `actual` is within `1e-7` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-7, "scale mismatch");
    }

    /// Shorthand for building the RoPE variant used throughout these tests.
    fn rope(freq_base: f32) -> PositionalEncoding {
        PositionalEncoding::RoPE { freq_base }
    }

    #[test]
    fn test_layer_attention_config_defaults() {
        let cfg = LayerAttentionConfig::new(3, 16, 4, 64);
        let dims = (cfg.layer_idx, cfg.num_q_heads, cfg.num_kv_heads, cfg.head_dim);
        assert_eq!(dims, (3, 16, 4, 64));
        assert_eq!(cfg.positional_encoding, PositionalEncoding::None);
        assert_eq!(cfg.sliding_window, None);
        assert_eq!(cfg.sink_tokens, 0);
        assert_close(cfg.scale, 1.0_f32 / 64.0_f32.sqrt());
    }

    #[test]
    fn test_layer_attention_config_with_rope() {
        let cfg = LayerAttentionConfig::new(0, 32, 8, 128).with_rope(1_000_000.0);
        assert_eq!(cfg.positional_encoding, rope(1_000_000.0));
    }

    #[test]
    fn test_layer_attention_config_with_alibi() {
        let encoding = LayerAttentionConfig::new(0, 8, 8, 64).with_alibi().positional_encoding;
        assert_eq!(encoding, PositionalEncoding::AliBi);
    }

    #[test]
    fn test_layer_attention_config_sliding_window() {
        let cfg = LayerAttentionConfig::new(1, 8, 2, 64)
            .with_sliding_window(512)
            .with_sink_tokens(4);
        assert_eq!((cfg.sliding_window, cfg.sink_tokens), (Some(512), 4));
        assert!(!cfg.is_full_attention());
    }

    #[test]
    fn test_effective_kv_len_full() {
        let cfg = LayerAttentionConfig::new(0, 8, 2, 64);
        for (total, q_pos, expected) in [(100, 99, 100), (1, 0, 1), (0, 0, 0)] {
            assert_eq!(cfg.effective_kv_len(total, q_pos), expected);
        }
    }

    #[test]
    fn test_effective_kv_len_sliding() {
        let cfg = LayerAttentionConfig::new(0, 8, 2, 64)
            .with_sliding_window(8)
            .with_sink_tokens(2);
        for (total, q_pos, expected) in [(100, 99, 10), (5, 4, 5), (10, 9, 10)] {
            assert_eq!(cfg.effective_kv_len(total, q_pos), expected);
        }
    }

    #[test]
    fn test_model_attention_config_bonsai_8b() {
        let cfg = ModelAttentionConfig::bonsai_8b();
        let counts = (
            cfg.num_layers(),
            cfg.full_attention_layers(),
            cfg.sliding_window_layers(),
        );
        assert_eq!(counts, (36, 36, 0));
        let first = cfg.get(0).expect("layer 0 must exist");
        assert_eq!(
            (first.num_q_heads, first.num_kv_heads, first.head_dim),
            (32, 8, 128)
        );
        assert_eq!(first.positional_encoding, rope(1_000_000.0));
    }

    #[test]
    fn test_model_attention_config_mixed() {
        let cfg = ModelAttentionConfig::mixed_window_config(8, 1024, 8, 2);
        assert_eq!(cfg.num_layers(), 8);
        assert_eq!(cfg.full_attention_layers(), 4);
        assert_eq!(cfg.sliding_window_layers(), 4);
        assert!(cfg.get(0).expect("layer 0 must exist").is_full_attention());
        let windowed = cfg.get(1).expect("layer 1 must exist");
        assert_eq!(windowed.sliding_window, Some(1024));
        assert_eq!(windowed.sink_tokens, 4);
    }

    #[test]
    fn test_memory_estimate_kv_cache() {
        let single = |layer: LayerAttentionConfig| ModelAttentionConfig::new(vec![layer]);
        let mem = single(LayerAttentionConfig::new(0, 4, 2, 4)).memory_estimate_kv_cache(10, 1);
        assert_eq!(mem, 640, "expected 640 bytes, got {mem}");
        let windowed = LayerAttentionConfig::new(0, 4, 2, 4)
            .with_sliding_window(5)
            .with_sink_tokens(1);
        let mem2 = single(windowed).memory_estimate_kv_cache(10, 1);
        assert_eq!(mem2, 384, "expected 384 bytes, got {mem2}");
    }
}