/// Result alias returned by the configuration validators in this file.
///
/// The error side is a plain `String` holding a human-readable description of
/// the check that failed.
pub type ConfigResult<T> = std::result::Result<T, String>;
5
/// Precision selector carried by [`AttentionConfig`].
///
/// The derived `Default` yields [`KernelPrecision::FP16`].
///
/// NOTE(review): the variant names look like the usual floating-point format
/// abbreviations, but the exact numeric formats are defined by whatever code
/// consumes this enum — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum KernelPrecision {
    /// Presumably 32-bit floating point — confirm against the consumer.
    FP32,
    /// Presumably 16-bit (half) floating point — confirm. This is the default variant.
    #[default]
    FP16,
    /// Presumably bfloat16 — confirm against the consumer.
    BF16,
    /// Presumably an 8-bit floating-point format — confirm which one.
    FP8,
}
19
/// Parameters describing one attention invocation.
///
/// All fields are public, so a value can be edited after construction; call
/// [`AttentionConfig::validate`] before relying on it.
///
/// NOTE(review): the field meanings given below are inferred from the field
/// names — confirm against the code that consumes this struct.
#[derive(Debug, Clone)]
pub struct AttentionConfig {
    /// Batch dimension (presumably the number of sequences); `validate` requires it to be non-zero.
    pub batch_size: usize,
    /// Number of attention heads; `validate` requires it to be non-zero.
    pub num_heads: usize,
    /// Query length; defaults to 1.
    pub query_len: usize,
    /// Key/value length; defaults to 1024.
    pub kv_len: usize,
    /// Per-head dimension; `validate` requires it to be in `1..=256`.
    pub head_dim: usize,
    /// Causal flag (presumably enables causal masking — confirm); defaults to `true`.
    pub causal: bool,
    /// Scale factor; `AttentionConfig::new` sets it to `1 / sqrt(head_dim)`.
    pub scale: f32,
    /// Precision selector; defaults to `KernelPrecision::FP16`.
    pub precision: KernelPrecision,
    /// Block size on the query side (presumably a tiling parameter — confirm);
    /// `validate` requires it to be non-zero.
    pub block_q: usize,
    /// Block size on the key/value side (presumably a tiling parameter — confirm);
    /// `validate` requires it to be non-zero.
    pub block_kv: usize,
}
44
45impl Default for AttentionConfig {
46 fn default() -> Self {
47 Self {
48 batch_size: 1,
49 num_heads: 8,
50 query_len: 1,
51 kv_len: 1024,
52 head_dim: 64,
53 causal: true,
54 scale: 0.125,
55 precision: KernelPrecision::FP16,
56 block_q: 64,
57 block_kv: 64,
58 }
59 }
60}
61
62impl AttentionConfig {
63 pub fn new(
65 batch_size: usize,
66 num_heads: usize,
67 query_len: usize,
68 kv_len: usize,
69 head_dim: usize,
70 ) -> Self {
71 Self {
72 batch_size,
73 num_heads,
74 query_len,
75 kv_len,
76 head_dim,
77 scale: 1.0 / (head_dim as f32).sqrt(),
78 ..Default::default()
79 }
80 }
81
82 pub fn with_causal(mut self, causal: bool) -> Self {
84 self.causal = causal;
85 self
86 }
87
88 pub fn with_precision(mut self, precision: KernelPrecision) -> Self {
90 self.precision = precision;
91 self
92 }
93
94 pub fn with_block_sizes(mut self, block_q: usize, block_kv: usize) -> Self {
96 self.block_q = block_q;
97 self.block_kv = block_kv;
98 self
99 }
100
101 pub fn validate(&self) -> ConfigResult<()> {
103 if self.batch_size == 0 {
104 return Err("batch_size must be > 0".to_string());
105 }
106 if self.num_heads == 0 {
107 return Err("num_heads must be > 0".to_string());
108 }
109 if self.head_dim == 0 {
110 return Err("head_dim must be > 0".to_string());
111 }
112 if self.head_dim > 256 {
113 return Err("head_dim > 256 not supported".to_string());
114 }
115 if self.block_q == 0 || self.block_kv == 0 {
116 return Err("block sizes must be > 0".to_string());
117 }
118 Ok(())
119 }
120}
121
/// An [`AttentionConfig`] combined with paging limits.
///
/// NOTE(review): the field meanings given below are inferred from the field
/// names — confirm against the code that consumes this struct.
#[derive(Debug, Clone)]
pub struct PagedAttentionConfig {
    /// The wrapped attention parameters; checked first by `PagedAttentionConfig::validate`.
    pub attention: AttentionConfig,
    /// Size of one page block (presumably in tokens — confirm); `validate` requires it to be non-zero.
    pub page_block_size: usize,
    /// Upper bound on blocks per sequence; `validate` requires it to be non-zero.
    /// Multiplied by `page_block_size` in `max_context_len`.
    pub max_blocks_per_seq: usize,
    /// Upper bound on the number of sequences (presumably concurrent — confirm); defaults to 256.
    pub max_num_seqs: usize,
}
134
135impl Default for PagedAttentionConfig {
136 fn default() -> Self {
137 Self {
138 attention: AttentionConfig::default(),
139 page_block_size: 16,
140 max_blocks_per_seq: 128,
141 max_num_seqs: 256,
142 }
143 }
144}
145
146impl PagedAttentionConfig {
147 pub fn new(page_block_size: usize) -> Self {
149 Self {
150 page_block_size,
151 ..Default::default()
152 }
153 }
154
155 pub fn with_attention(mut self, attention: AttentionConfig) -> Self {
157 self.attention = attention;
158 self
159 }
160
161 pub fn with_max_blocks(mut self, max_blocks: usize) -> Self {
163 self.max_blocks_per_seq = max_blocks;
164 self
165 }
166
167 pub fn max_context_len(&self) -> usize {
169 self.page_block_size * self.max_blocks_per_seq
170 }
171
172 pub fn validate(&self) -> ConfigResult<()> {
174 self.attention.validate()?;
175 if self.page_block_size == 0 {
176 return Err("page_block_size must be > 0".to_string());
177 }
178 if self.max_blocks_per_seq == 0 {
179 return Err("max_blocks_per_seq must be > 0".to_string());
180 }
181 Ok(())
182 }
183}