gllm_kernels/types.rs

//! Types and configuration for attention operations.

/// Result type for configuration validation.
pub type ConfigResult<T> = std::result::Result<T, String>;

/// Precision mode for kernel computations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum KernelPrecision {
    /// FP32 throughout (highest precision, slowest).
    FP32,
    /// FP16 compute with FP32 accumulation (balanced).
    #[default]
    FP16,
    /// BF16 compute with FP32 accumulation (best for training).
    BF16,
    /// FP8 compute with FP16 accumulation (fastest, Hopper+).
    FP8,
}

/// Configuration for standard attention.
#[derive(Debug, Clone)]
pub struct AttentionConfig {
    /// Batch size.
    pub batch_size: usize,
    /// Number of attention heads.
    pub num_heads: usize,
    /// Query sequence length.
    pub query_len: usize,
    /// Key/Value sequence length.
    pub kv_len: usize,
    /// Head dimension.
    pub head_dim: usize,
    /// Whether to apply causal masking.
    pub causal: bool,
    /// Softmax scale factor (usually 1/sqrt(head_dim)).
    pub scale: f32,
    /// Computation precision.
    pub precision: KernelPrecision,
    /// Block size for query tiling.
    pub block_q: usize,
    /// Block size for KV tiling.
    pub block_kv: usize,
}

impl Default for AttentionConfig {
    fn default() -> Self {
        Self {
            batch_size: 1,
            num_heads: 8,
            query_len: 1,
            kv_len: 1024,
            head_dim: 64,
            causal: true,
            scale: 0.125,
            precision: KernelPrecision::FP16,
            block_q: 64,
            block_kv: 64,
        }
    }
}
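
// Note: the defaults above describe single-token decode (query_len = 1) against a
// 1024-token KV cache; scale = 0.125 is 1.0 / sqrt(64.0) for the default head_dim of 64.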

impl AttentionConfig {
    /// Create a config for the given dimensions; `scale` defaults to 1/sqrt(head_dim).
    pub fn new(
        batch_size: usize,
        num_heads: usize,
        query_len: usize,
        kv_len: usize,
        head_dim: usize,
    ) -> Self {
        Self {
            batch_size,
            num_heads,
            query_len,
            kv_len,
            head_dim,
            scale: 1.0 / (head_dim as f32).sqrt(),
            ..Default::default()
        }
    }

    /// Set causal masking.
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }

    /// Set precision mode.
    pub fn with_precision(mut self, precision: KernelPrecision) -> Self {
        self.precision = precision;
        self
    }

    /// Set block sizes for tiling.
    pub fn with_block_sizes(mut self, block_q: usize, block_kv: usize) -> Self {
        self.block_q = block_q;
        self.block_kv = block_kv;
        self
    }

    /// Validate configuration values.
    pub fn validate(&self) -> ConfigResult<()> {
        if self.batch_size == 0 {
            return Err("batch_size must be > 0".to_string());
        }
        if self.num_heads == 0 {
            return Err("num_heads must be > 0".to_string());
        }
        if self.head_dim == 0 {
            return Err("head_dim must be > 0".to_string());
        }
        if self.head_dim > 256 {
            return Err("head_dim > 256 not supported".to_string());
        }
        if self.block_q == 0 || self.block_kv == 0 {
            return Err("block sizes must be > 0".to_string());
        }
        Ok(())
    }
}
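
// Illustrative usage sketch (the numeric values are assumptions chosen for the
// example): a prefill-style configuration built with the chained setters above.
//
//     let cfg = AttentionConfig::new(4, 32, 2048, 2048, 128)
//         .with_causal(true)
//         .with_precision(KernelPrecision::BF16)
//         .with_block_sizes(128, 64);
//     assert!(cfg.validate().is_ok());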

/// Configuration for paged attention.
#[derive(Debug, Clone)]
pub struct PagedAttentionConfig {
    /// Base attention config.
    pub attention: AttentionConfig,
    /// Block size in the page table (tokens per block).
    pub page_block_size: usize,
    /// Maximum number of blocks per sequence.
    pub max_blocks_per_seq: usize,
    /// Maximum number of sequences.
    pub max_num_seqs: usize,
}

impl Default for PagedAttentionConfig {
    fn default() -> Self {
        Self {
            attention: AttentionConfig::default(),
            page_block_size: 16,
            max_blocks_per_seq: 128,
            max_num_seqs: 256,
        }
    }
}

impl PagedAttentionConfig {
    /// Create config with the given block size.
    pub fn new(page_block_size: usize) -> Self {
        Self {
            page_block_size,
            ..Default::default()
        }
    }

    /// Set attention config.
    pub fn with_attention(mut self, attention: AttentionConfig) -> Self {
        self.attention = attention;
        self
    }

    /// Set max blocks per sequence.
    pub fn with_max_blocks(mut self, max_blocks: usize) -> Self {
        self.max_blocks_per_seq = max_blocks;
        self
    }

    /// Maximum supported context length, in tokens (page_block_size * max_blocks_per_seq).
    pub fn max_context_len(&self) -> usize {
        self.page_block_size * self.max_blocks_per_seq
    }

    /// Validate configuration values.
    pub fn validate(&self) -> ConfigResult<()> {
        self.attention.validate()?;
        if self.page_block_size == 0 {
            return Err("page_block_size must be > 0".to_string());
        }
        if self.max_blocks_per_seq == 0 {
            return Err("max_blocks_per_seq must be > 0".to_string());
        }
        Ok(())
    }
}
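
// Minimal sanity-check sketch for the builders and validation above; the expected
// values follow directly from the defaults and formulas defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn attention_config_new_sets_scale() {
        // new() derives scale = 1/sqrt(head_dim); 1/sqrt(64) = 0.125.
        let cfg = AttentionConfig::new(1, 8, 1, 1024, 64);
        assert!((cfg.scale - 0.125).abs() < 1e-6);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn attention_config_rejects_oversized_head_dim() {
        // head_dim above 256 is rejected by validate().
        let cfg = AttentionConfig::new(1, 8, 1, 1024, 512);
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn paged_config_max_context_len() {
        // Default page table: 16 tokens/block * 128 blocks/seq = 2048 tokens.
        let cfg = PagedAttentionConfig::default();
        assert_eq!(cfg.max_context_len(), 2048);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn default_precision_is_fp16() {
        assert_eq!(KernelPrecision::default(), KernelPrecision::FP16);
    }
}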