1use crate::tuning::{L1_CACHE_SIZE, L2_CACHE_SIZE};
23use core::mem::size_of;
24
/// Matrix dimension at or below which recursive algorithms should switch
/// to a direct (non-recursive) base-case implementation.
pub const BASE_CASE_THRESHOLD: usize = 64;

/// Smallest cache-block edge length handed out by the sizing heuristics below.
pub const MIN_BLOCK_SIZE: usize = 16;

/// Largest cache-block edge length handed out by the sizing heuristics below.
pub const MAX_BLOCK_SIZE: usize = 512;
37pub fn gemm_block_sizes<T>(m: usize, n: usize, k: usize) -> (usize, usize, usize) {
51 let elem_size = size_of::<T>();
52
53 let target_bytes = L2_CACHE_SIZE / 2;
58
59 let max_block = ((target_bytes / elem_size / 3) as f64).sqrt() as usize;
61 let mut block = max_block.clamp(MIN_BLOCK_SIZE, MAX_BLOCK_SIZE);
62
63 block = (block / 8) * 8;
65 if block < MIN_BLOCK_SIZE {
66 block = MIN_BLOCK_SIZE;
67 }
68
69 let block_m = block.min(m);
71 let block_n = block.min(n);
72 let block_k = block.min(k);
73
74 (block_m, block_n, block_k)
75}
76
77pub fn trsm_block_size<T>(n: usize, nrhs: usize) -> usize {
83 let elem_size = size_of::<T>();
84
85 let max_block = ((2 * L1_CACHE_SIZE / elem_size) as f64).sqrt() as usize;
88 let block = max_block.clamp(MIN_BLOCK_SIZE, MAX_BLOCK_SIZE / 2);
89
90 let block = (block / 8) * 8;
92 block.min(n).min(nrhs).max(MIN_BLOCK_SIZE)
93}
94
95pub fn factorization_panel_width<T>(n: usize) -> usize {
100 let elem_size = size_of::<T>();
101
102 let max_panel = L2_CACHE_SIZE / (elem_size * n.max(1));
105 let panel = max_panel.clamp(16, 128);
106
107 ((panel / 4) * 4).min(n).max(16)
109}
110
/// Half-open index interval `[start, end)` along one matrix dimension.
#[derive(Debug, Clone, Copy)]
pub struct BlockRange {
    /// Inclusive lower bound of the interval.
    pub start: usize,
    /// Exclusive upper bound of the interval.
    pub end: usize,
}

impl BlockRange {
    /// Creates the interval `[start, end)`.
    #[inline]
    pub const fn new(start: usize, end: usize) -> Self {
        Self { start, end }
    }

    /// Creates the full interval `[0, n)`.
    #[inline]
    pub const fn from_len(n: usize) -> Self {
        Self { start: 0, end: n }
    }

    /// Number of indices covered; zero if the bounds are inverted.
    #[inline]
    pub const fn len(&self) -> usize {
        self.end.saturating_sub(self.start)
    }

    /// Whether the interval covers no indices at all.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Whether the interval is small enough to stop recursing on.
    #[inline]
    pub fn is_base_case(&self, threshold: usize) -> bool {
        self.len() <= threshold
    }

    /// Halves the interval, returning the lower and upper halves.
    #[inline]
    pub fn split(&self) -> (Self, Self) {
        let mid = self.start + self.len() / 2;
        (Self::new(self.start, mid), Self::new(mid, self.end))
    }

    /// Cuts the interval `point` indices past `start`; the cut is clamped
    /// to `end` so neither half ever extends beyond the original bounds.
    #[inline]
    pub fn split_at(&self, point: usize) -> (Self, Self) {
        let cut = self.end.min(self.start + point);
        (Self::new(self.start, cut), Self::new(cut, self.end))
    }
}
176
177#[derive(Debug, Clone, Copy)]
181pub struct RecursiveTask {
182 pub rows: BlockRange,
184 pub cols: BlockRange,
186}
187
188impl RecursiveTask {
189 #[inline]
191 pub const fn new(rows: BlockRange, cols: BlockRange) -> Self {
192 RecursiveTask { rows, cols }
193 }
194
195 #[inline]
197 pub const fn from_dims(m: usize, n: usize) -> Self {
198 RecursiveTask {
199 rows: BlockRange::from_len(m),
200 cols: BlockRange::from_len(n),
201 }
202 }
203
204 #[inline]
206 pub fn size(&self) -> usize {
207 self.rows.len() * self.cols.len()
208 }
209
210 #[inline]
212 pub fn is_base_case(&self, threshold: usize) -> bool {
213 self.rows.len() <= threshold && self.cols.len() <= threshold
214 }
215
216 pub fn split(&self) -> (Self, Self) {
220 if self.rows.len() >= self.cols.len() {
221 let (r1, r2) = self.rows.split();
223 (
224 RecursiveTask::new(r1, self.cols),
225 RecursiveTask::new(r2, self.cols),
226 )
227 } else {
228 let (c1, c2) = self.cols.split();
230 (
231 RecursiveTask::new(self.rows, c1),
232 RecursiveTask::new(self.rows, c2),
233 )
234 }
235 }
236
237 pub fn quadrants(&self) -> (Self, Self, Self, Self) {
241 let (r1, r2) = self.rows.split();
242 let (c1, c2) = self.cols.split();
243
244 (
245 RecursiveTask::new(r1, c1), RecursiveTask::new(r1, c2), RecursiveTask::new(r2, c1), RecursiveTask::new(r2, c2), )
250 }
251}
252
/// Callback invoked once per leaf block during a recursive traversal
/// (see `cache_oblivious_traverse`).
pub trait BlockVisitor {
    /// Error type propagated out of the traversal on failure.
    type Error;

    /// Processes the half-open block
    /// `[row_start, row_end) x [col_start, col_end)`.
    ///
    /// Returning `Err` aborts the traversal immediately; the error is
    /// passed through to the original caller.
    fn visit_block(
        &mut self,
        row_start: usize,
        row_end: usize,
        col_start: usize,
        col_end: usize,
    ) -> Result<(), Self::Error>;
}
273
274pub fn cache_oblivious_traverse<V: BlockVisitor>(
279 visitor: &mut V,
280 task: RecursiveTask,
281 threshold: usize,
282) -> Result<(), V::Error> {
283 if task.is_base_case(threshold) {
284 visitor.visit_block(
286 task.rows.start,
287 task.rows.end,
288 task.cols.start,
289 task.cols.end,
290 )
291 } else {
292 let (t1, t2) = task.split();
294 cache_oblivious_traverse(visitor, t1, threshold)?;
295 cache_oblivious_traverse(visitor, t2, threshold)
296 }
297}
298
/// Interleaves the bits of `x` (even positions) and `y` (odd positions)
/// into a single Morton (Z-order) index.
///
/// Z-ordering maps nearby `(x, y)` coordinates to nearby indices, which
/// improves cache locality when storing blocked 2-D data linearly.
#[inline]
pub fn morton_index(x: u32, y: u32) -> u64 {
    // Spreads the 32 input bits into the even bit positions of a u64 by
    // repeated shift-or-mask steps (the classic "binary magic numbers"
    // expansion, applied from the widest gap down to the narrowest).
    fn spread(bits: u32) -> u64 {
        let steps: [(u32, u64); 5] = [
            (16, 0x0000_FFFF_0000_FFFF),
            (8, 0x00FF_00FF_00FF_00FF),
            (4, 0x0F0F_0F0F_0F0F_0F0F),
            (2, 0x3333_3333_3333_3333),
            (1, 0x5555_5555_5555_5555),
        ];
        steps
            .iter()
            .fold(u64::from(bits), |v, &(shift, mask)| (v | (v << shift)) & mask)
    }
    (spread(y) << 1) | spread(x)
}
316
/// Recovers the `(x, y)` coordinates from a Morton (Z-order) index.
///
/// Inverse of `morton_index`: `x` comes from the even bit positions of
/// `z`, `y` from the odd ones.
#[inline]
pub fn morton_decode(z: u64) -> (u32, u32) {
    // Compacts the even bits of `v` back into the low 32 bits by running
    // the spreading steps of the encoder in reverse order.
    fn gather(v: u64) -> u32 {
        let steps: [(u32, u64); 5] = [
            (1, 0x3333_3333_3333_3333),
            (2, 0x0F0F_0F0F_0F0F_0F0F),
            (4, 0x00FF_00FF_00FF_00FF),
            (8, 0x0000_FFFF_0000_FFFF),
            (16, 0x0000_0000_FFFF_FFFF),
        ];
        let seed = v & 0x5555_5555_5555_5555;
        steps
            .iter()
            .fold(seed, |acc, &(shift, mask)| (acc | (acc >> shift)) & mask) as u32
    }
    (gather(z), gather(z >> 1))
}
331
#[cfg(test)]
mod tests {
    use super::*;

    // The block heuristic must respect the global min/max bounds and the
    // multiple-of-8 rounding for a typical large square problem.
    #[test]
    fn test_gemm_block_sizes() {
        let (bm, bn, bk) = gemm_block_sizes::<f64>(1024, 1024, 1024);

        assert!(bm >= MIN_BLOCK_SIZE);
        assert!(bn >= MIN_BLOCK_SIZE);
        assert!(bk >= MIN_BLOCK_SIZE);
        assert!(bm <= MAX_BLOCK_SIZE);
        assert!(bn <= MAX_BLOCK_SIZE);
        assert!(bk <= MAX_BLOCK_SIZE);

        // Blocks are rounded down to a multiple of 8 for SIMD friendliness.
        assert_eq!(bm % 8, 0);
    }

    #[test]
    fn test_block_range() {
        let range = BlockRange::new(0, 100);
        assert_eq!(range.len(), 100);

        // split() halves the range at its midpoint into two half-open parts.
        let (left, right) = range.split();
        assert_eq!(left.start, 0);
        assert_eq!(left.end, 50);
        assert_eq!(right.start, 50);
        assert_eq!(right.end, 100);

        // Base-case classification compares the length against the threshold.
        assert!(BlockRange::new(0, 32).is_base_case(64));
        assert!(!BlockRange::new(0, 100).is_base_case(64));
    }

    #[test]
    fn test_recursive_task() {
        let task = RecursiveTask::from_dims(100, 200);
        assert_eq!(task.size(), 20000);

        // cols (200) is the longer side, so split() halves the columns
        // and leaves the row range untouched in both children.
        let (t1, t2) = task.split();
        assert_eq!(t1.cols.len(), 100);
        assert_eq!(t2.cols.len(), 100);
        assert_eq!(t1.rows.len(), 100);
        assert_eq!(t2.rows.len(), 100);
    }

    #[test]
    fn test_quadrants() {
        let task = RecursiveTask::from_dims(100, 100);
        let (tl, _tr, _bl, br) = task.quadrants();

        // Top-left quadrant covers [0, 50) in both dimensions.
        assert_eq!(tl.rows.start, 0);
        assert_eq!(tl.rows.end, 50);
        assert_eq!(tl.cols.start, 0);
        assert_eq!(tl.cols.end, 50);

        // Bottom-right quadrant covers [50, 100) in both dimensions.
        assert_eq!(br.rows.start, 50);
        assert_eq!(br.rows.end, 100);
        assert_eq!(br.cols.start, 50);
        assert_eq!(br.cols.end, 100);
    }

    #[test]
    fn test_morton_index() {
        // Spot-check the first few Z-order codes...
        assert_eq!(morton_index(0, 0), 0);
        assert_eq!(morton_index(1, 0), 1);
        assert_eq!(morton_index(0, 1), 2);
        assert_eq!(morton_index(1, 1), 3);
        assert_eq!(morton_index(2, 0), 4);

        // ...and verify encode/decode round-trips over a small grid.
        for x in 0..100 {
            for y in 0..100 {
                let z = morton_index(x, y);
                let (dx, dy) = morton_decode(z);
                assert_eq!((dx, dy), (x, y));
            }
        }
    }

    // Visitor that counts leaf blocks and accumulates their total area,
    // used to check that a traversal tiles the matrix exactly.
    struct CountingVisitor {
        count: usize,
        total_elements: usize,
    }

    impl BlockVisitor for CountingVisitor {
        type Error = ();

        fn visit_block(
            &mut self,
            row_start: usize,
            row_end: usize,
            col_start: usize,
            col_end: usize,
        ) -> Result<(), ()> {
            self.count += 1;
            self.total_elements += (row_end - row_start) * (col_end - col_start);
            Ok(())
        }
    }

    #[test]
    fn test_cache_oblivious_traverse() {
        let task = RecursiveTask::from_dims(128, 128);
        let mut visitor = CountingVisitor {
            count: 0,
            total_elements: 0,
        };

        cache_oblivious_traverse(&mut visitor, task, 32).unwrap();

        // 128x128 with threshold 32 must recurse at least once, and the
        // visited blocks must tile the matrix exactly (no gaps/overlap).
        assert!(visitor.count > 1);
        assert_eq!(visitor.total_elements, 128 * 128);
    }
}