/// Configuration for a multi-head attention layer.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct AttentionConfig {
    /// Number of attention heads.
    pub num_heads: usize,
    /// Dimensionality of each head.
    pub head_dim: usize,
    /// Dropout probability applied to the attention weights.
    pub dropout_prob: f64,
    /// Whether to apply a causal (lower-triangular) mask.
    pub causal: bool,
    /// Whether to use a fused flash-attention kernel when available.
    pub use_flash: bool,
    /// Optional override for the score scaling factor; `None` falls back to
    /// `1 / sqrt(head_dim)`.
    pub scale: Option<f64>,
}

impl Default for AttentionConfig {
    fn default() -> Self {
        Self {
            num_heads: 8,
            head_dim: 64,
            dropout_prob: 0.0,
            causal: false,
            use_flash: true,
            scale: None,
        }
    }
}

impl AttentionConfig {
    /// Returns the explicit `scale` if one was set, otherwise the standard
    /// `1 / sqrt(head_dim)` factor from scaled dot-product attention.
    pub fn effective_scale(&self) -> f64 {
        self.scale
            .unwrap_or_else(|| 1.0 / (self.head_dim as f64).sqrt())
    }
}
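
// Illustrative sketch, not part of the original API: raw scaled dot-product
// scores computed with `effective_scale`. The free function `scaled_scores`
// and the `[seq_len][head_dim]` row layout for `q`/`k` are assumptions made
// for this example.
pub fn scaled_scores(cfg: &AttentionConfig, q: &[Vec<f64>], k: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let scale = cfg.effective_scale();
    q.iter()
        .map(|qi| {
            k.iter()
                // Dot product of one query row with each key row, then scale.
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f64>() * scale)
                .collect()
        })
        .collect()
}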

/// Masking strategies for attention scores.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum AttentionMask {
    /// No masking; every position may attend to every other.
    None,
    /// Lower-triangular mask: position `i` attends only to positions `<= i`.
    Causal,
    /// Explicit per-position boolean mask.
    Custom(Vec<Vec<bool>>),
    /// Per-sequence valid lengths, used to mask out padded positions.
    PaddingMask(Vec<usize>),
}
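
// Illustrative sketch, not part of the original API: materializing a causal
// mask as the `Custom` variant. The convention that `true` marks a position
// a query may attend to is an assumption made for this example.
pub fn causal_mask(seq_len: usize) -> AttentionMask {
    let rows: Vec<Vec<bool>> = (0..seq_len)
        .map(|i| (0..seq_len).map(|j| j <= i).collect())
        .collect();
    AttentionMask::Custom(rows)
}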

/// Result of an attention forward pass.
#[derive(Debug, Clone)]
pub struct AttentionOutput {
    /// Attended output, one row per position.
    pub output: Vec<Vec<f64>>,
    /// Per-head attention weights, if they were requested.
    pub attention_weights: Option<Vec<Vec<Vec<f64>>>>,
}

impl AttentionOutput {
    /// Creates an output that does not retain attention weights.
    pub fn new(output: Vec<Vec<f64>>) -> Self {
        Self {
            output,
            attention_weights: None,
        }
    }

    /// Creates an output that also retains the attention weights.
    pub fn with_weights(output: Vec<Vec<f64>>, weights: Vec<Vec<Vec<f64>>>) -> Self {
        Self {
            output,
            attention_weights: Some(weights),
        }
    }
}

/// Position-encoding schemes.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PositionEncoding {
    /// Fixed sinusoidal encodings ("Attention Is All You Need").
    Sinusoidal,
    /// Trainable position embeddings.
    Learned,
    /// Rotary position embeddings.
    RoPE,
    /// Attention with Linear Biases.
    ALiBi,
    /// No position encoding.
    NoPE,
}
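
// Illustrative sketch, not part of the original API: the fixed sinusoidal
// table behind `PositionEncoding::Sinusoidal`, following "Attention Is All
// You Need": PE[pos][2i] = sin(pos / 10000^(2i/d)), PE[pos][2i+1] = cos(...).
// The other variants need learned tables or per-head logic and are omitted.
pub fn sinusoidal_encoding(seq_len: usize, dim: usize) -> Vec<Vec<f64>> {
    (0..seq_len)
        .map(|pos| {
            (0..dim)
                .map(|i| {
                    // Paired sin/cos slots share the same frequency exponent.
                    let exponent = (2 * (i / 2)) as f64 / dim as f64;
                    let angle = pos as f64 / 10000f64.powf(exponent);
                    if i % 2 == 0 { angle.sin() } else { angle.cos() }
                })
                .collect()
        })
        .collect()
}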

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_config_default() {
        let cfg = AttentionConfig::default();
        assert_eq!(cfg.num_heads, 8);
        assert_eq!(cfg.head_dim, 64);
        assert_eq!(cfg.dropout_prob, 0.0);
        assert!(!cfg.causal);
        assert!(cfg.use_flash);
        assert!(cfg.scale.is_none());
    }

    #[test]
    fn test_attention_config_effective_scale_default() {
        let cfg = AttentionConfig::default();
        let expected = 1.0 / (64.0_f64).sqrt();
        assert!((cfg.effective_scale() - expected).abs() < 1e-12);
    }

    #[test]
    fn test_attention_config_effective_scale_custom() {
        let cfg = AttentionConfig {
            scale: Some(0.5),
            ..Default::default()
        };
        assert!((cfg.effective_scale() - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_attention_output_struct() {
        let output = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let ao = AttentionOutput::new(output.clone());
        assert_eq!(ao.output, output);
        assert!(ao.attention_weights.is_none());
    }

    #[test]
    fn test_attention_output_with_weights() {
        let output = vec![vec![0.5; 4]];
        let w = vec![vec![vec![0.25; 4]; 4]; 2];
        let ao = AttentionOutput::with_weights(output, w);
        assert!(ao.attention_weights.is_some());
        assert_eq!(ao.attention_weights.as_ref().map(|x| x.len()), Some(2));
    }

    #[test]
    fn test_position_encoding_variants() {
        let variants = [
            PositionEncoding::Sinusoidal,
            PositionEncoding::Learned,
            PositionEncoding::RoPE,
            PositionEncoding::ALiBi,
            PositionEncoding::NoPE,
        ];
        for v in &variants {
            assert_eq!(v, v);
        }
    }

    #[test]
    fn test_attention_mask_causal_variant() {
        let mask = AttentionMask::Causal;
        assert!(matches!(mask, AttentionMask::Causal));
    }
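
    #[test]
    fn test_causal_mask_helper_sketch() {
        // Exercises the illustrative `causal_mask` helper defined above.
        if let AttentionMask::Custom(rows) = causal_mask(3) {
            assert!(rows[0][0] && !rows[0][1]);
            assert!(rows[2].iter().all(|&b| b));
        } else {
            panic!("expected Custom mask");
        }
    }

    #[test]
    fn test_sinusoidal_encoding_sketch() {
        // At pos = 0, the sine slots are 0 and the cosine slots are 1.
        let pe = sinusoidal_encoding(2, 4);
        assert!(pe[0][0].abs() < 1e-12);
        assert!((pe[0][1] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_scaled_scores_sketch() {
        // Dot product 1.0 scaled by 1 / sqrt(head_dim) = 1/2.
        let cfg = AttentionConfig {
            head_dim: 4,
            ..Default::default()
        };
        let q = vec![vec![1.0, 0.0, 0.0, 0.0]];
        let s = scaled_scores(&cfg, &q, &q);
        assert!((s[0][0] - 0.5).abs() < 1e-12);
    }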
}