use serde::{Deserialize, Serialize};
use super::error::{SparseAttentionError, SparseAttentionResult};
/// Settings bundle for sparse attention: window size, global token
/// positions, head layout, causal flag and dropout.
///
/// [`SparseAttentionConfig::new`] runs [`SparseAttentionConfig::validate`]
/// before returning a value. All fields are public and the type derives
/// `Deserialize`, so values built by other routes are not checked
/// automatically; call `validate` on them explicitly.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SparseAttentionConfig {
/// Attention window size. `validate` rejects zero.
/// NOTE(review): whether this counts one side or the full window is not
/// visible here; confirm against the attention kernel.
pub window_size: usize,
/// Positions treated as global tokens. Every entry must be strictly less
/// than the sequence length passed to `validate_globals_for_seq_len`.
pub global_token_indices: Vec<usize>,
/// Number of attention heads. `validate` rejects zero.
pub num_heads: usize,
/// Width of each head. `validate` rejects zero; `d_model` returns
/// `num_heads * head_dim`.
pub head_dim: usize,
/// Causal flag; `new` initialises it to `false`.
/// NOTE(review): presumably enables causal masking; confirm with the consumer.
pub causal: bool,
/// Dropout value; `new` initialises it to `0.0`.
/// NOTE(review): presumably a probability; `validate` performs no range
/// check on it.
pub dropout: f64,
}
impl SparseAttentionConfig {
    /// Builds a configuration from the three mandatory sizes.
    ///
    /// The remaining fields start as: no global tokens, `causal = false`
    /// and `dropout = 0.0`. Use the `with_*` methods to change them.
    ///
    /// # Errors
    ///
    /// Returns the error reported by [`Self::validate`], which happens when
    /// any of the three arguments is zero.
    pub fn new(
        window_size: usize,
        num_heads: usize,
        head_dim: usize,
    ) -> SparseAttentionResult<Self> {
        let candidate = Self {
            window_size,
            num_heads,
            head_dim,
            global_token_indices: Vec::new(),
            causal: false,
            dropout: 0.0,
        };
        candidate.validate().map(|()| candidate)
    }

    /// Returns the configuration with its global token positions replaced by
    /// `indices`.
    ///
    /// The positions are stored as given; they are only range-checked by
    /// [`Self::validate_globals_for_seq_len`].
    #[must_use]
    pub fn with_global_tokens(self, indices: Vec<usize>) -> Self {
        Self {
            global_token_indices: indices,
            ..self
        }
    }

    /// Returns the configuration with the `causal` flag set to `causal`.
    #[must_use]
    pub fn with_causal(self, causal: bool) -> Self {
        Self { causal, ..self }
    }

    /// Returns the configuration with `dropout` set to the given value.
    ///
    /// The value is stored without any range check.
    #[must_use]
    pub fn with_dropout(self, dropout: f64) -> Self {
        Self { dropout, ..self }
    }

    /// Checks the size fields of the configuration.
    ///
    /// # Errors
    ///
    /// * `InvalidWindowSize(0)` when `window_size` is zero.
    /// * `DimensionMismatch` with context `"num_heads"` when `num_heads` is
    ///   zero.
    /// * `DimensionMismatch` with context `"head_dim"` when `head_dim` is
    ///   zero.
    ///
    /// Only the first failing check is reported, in the order listed.
    pub fn validate(&self) -> SparseAttentionResult<()> {
        if self.window_size == 0 {
            return Err(SparseAttentionError::InvalidWindowSize(0));
        }

        // Table order is significant: a zero `num_heads` is reported before a
        // zero `head_dim`.
        let dims = [("num_heads", self.num_heads), ("head_dim", self.head_dim)];
        match dims.iter().find(|(_, size)| *size == 0) {
            Some((label, _)) => Err(SparseAttentionError::DimensionMismatch {
                context: (*label).into(),
                expected: 1,
                got: 0,
            }),
            None => Ok(()),
        }
    }

    /// Returns the product `num_heads * head_dim`.
    pub fn d_model(&self) -> usize {
        let heads = self.num_heads;
        let width = self.head_dim;
        heads * width
    }

    /// Checks that every global token position is below `seq_len`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidGlobalIndices` carrying the first stored position
    /// that is `>= seq_len`, together with `seq_len` itself.
    pub fn validate_globals_for_seq_len(&self, seq_len: usize) -> SparseAttentionResult<()> {
        let first_out_of_range = self
            .global_token_indices
            .iter()
            .copied()
            .find(|&position| position >= seq_len);

        match first_out_of_range {
            Some(index) => Err(SparseAttentionError::InvalidGlobalIndices { index, seq_len }),
            None => Ok(()),
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `SparseAttentionConfig`.
    //!
    //! Assertions go through `map` and `matches!` rather than unwrapping, so
    //! a failed construction shows up as a failed assertion.

    use super::*;

    /// Baseline configuration shared by the tests: window 4, 2 heads of width 8.
    fn base() -> SparseAttentionResult<SparseAttentionConfig> {
        SparseAttentionConfig::new(4, 2, 8)
    }

    #[test]
    fn valid_config() {
        let d_model = base().map(|c| c.d_model());
        assert!(matches!(d_model, Ok(16)));
    }

    #[test]
    fn zero_window_rejected() {
        let result = SparseAttentionConfig::new(0, 2, 8);
        assert!(matches!(
            result,
            Err(SparseAttentionError::InvalidWindowSize(0))
        ));
    }

    #[test]
    fn zero_heads_rejected() {
        // Pin the variant so the test cannot pass because of an unrelated error.
        let result = SparseAttentionConfig::new(4, 0, 8);
        assert!(matches!(
            result,
            Err(SparseAttentionError::DimensionMismatch {
                expected: 1,
                got: 0,
                ..
            })
        ));
    }

    #[test]
    fn zero_head_dim_rejected() {
        let result = SparseAttentionConfig::new(4, 2, 0);
        assert!(matches!(
            result,
            Err(SparseAttentionError::DimensionMismatch {
                expected: 1,
                got: 0,
                ..
            })
        ));
    }

    #[test]
    fn global_oob_detected() {
        let outcome = base()
            .map(|c| c.with_global_tokens(vec![0, 99]))
            .map(|c| c.validate_globals_for_seq_len(16));
        // The error must name the offending position and the checked length.
        assert!(matches!(
            outcome,
            Ok(Err(SparseAttentionError::InvalidGlobalIndices {
                index: 99,
                seq_len: 16,
            }))
        ));
    }

    #[test]
    fn globals_in_range_accepted() {
        // 15 is the largest valid position for a sequence of length 16.
        let outcome = base()
            .map(|c| c.with_global_tokens(vec![0, 15]))
            .map(|c| c.validate_globals_for_seq_len(16));
        assert!(matches!(outcome, Ok(Ok(()))));
    }

    #[test]
    fn builder_chain() {
        let observed = base()
            .map(|c| {
                c.with_global_tokens(vec![0])
                    .with_causal(true)
                    .with_dropout(0.1)
            })
            .map(|c| (c.causal, c.dropout, c.global_token_indices.len()));
        assert_eq!(observed.ok(), Some((true, 0.1, 1)));
    }
}