oxiblas_core/
blocking.rs

//! Cache-oblivious blocking utilities.
//!
//! This module provides utilities for implementing cache-oblivious algorithms,
//! which achieve near-optimal cache performance without requiring knowledge of
//! cache sizes.
//!
//! # Cache-Oblivious Algorithms
//!
//! Cache-oblivious algorithms use recursive divide-and-conquer strategies that
//! naturally adapt to all levels of the memory hierarchy. The key insight is that
//! by recursively splitting the problem until the subproblems fit in cache, we get
//! optimal cache behavior without needing to know the cache size.
//!
//! # Block Size Calculation
//!
//! For cache-aware algorithms, this module also provides utilities to calculate
//! optimal block sizes based on:
//! - Available cache sizes (L1, L2, L3)
//! - SIMD register widths
//! - Memory layout (row-major, column-major)
21
22use crate::tuning::{L1_CACHE_SIZE, L2_CACHE_SIZE};
23use core::mem::size_of;
24
/// Base case threshold for recursive algorithms.
///
/// When the problem size drops below this threshold, we switch to
/// a direct (non-recursive) implementation.
pub const BASE_CASE_THRESHOLD: usize = 64;

/// Minimum block size for tiled algorithms.
///
/// Keeping blocks at least this large amortizes per-block overhead;
/// 16 is also a multiple of common SIMD widths.
pub const MIN_BLOCK_SIZE: usize = 16;

/// Maximum block size for tiled algorithms.
///
/// Caps a tile's dimension regardless of how large the caches are.
pub const MAX_BLOCK_SIZE: usize = 512;
36
37/// Calculates the optimal block size for GEMM-like operations.
38///
39/// The block size is chosen to maximize data reuse in the L2 cache.
40/// For GEMM with blocks of size M×K and K×N, we want:
41/// `2 * M * K + K * N ≈ L2_CACHE_SIZE`
42///
43/// # Arguments
44/// * `m` - Number of rows in the result
45/// * `n` - Number of columns in the result
46/// * `k` - Inner dimension
47///
48/// # Returns
49/// A tuple `(block_m, block_n, block_k)` of optimal block sizes.
50pub fn gemm_block_sizes<T>(m: usize, n: usize, k: usize) -> (usize, usize, usize) {
51    let elem_size = size_of::<T>();
52
53    // Target: fit 2 input panels + 1 output panel in L2
54    // A panel: block_m × block_k
55    // B panel: block_k × block_n
56    // C panel: block_m × block_n
57    let target_bytes = L2_CACHE_SIZE / 2;
58
59    // Start with a balanced block size
60    let max_block = ((target_bytes / elem_size / 3) as f64).sqrt() as usize;
61    let mut block = max_block.clamp(MIN_BLOCK_SIZE, MAX_BLOCK_SIZE);
62
63    // Align to SIMD-friendly boundaries
64    block = (block / 8) * 8;
65    if block < MIN_BLOCK_SIZE {
66        block = MIN_BLOCK_SIZE;
67    }
68
69    // Adjust for actual dimensions
70    let block_m = block.min(m);
71    let block_n = block.min(n);
72    let block_k = block.min(k);
73
74    (block_m, block_n, block_k)
75}
76
77/// Calculates the optimal block size for triangular solves (TRSM).
78///
79/// For TRSM, we need to balance between:
80/// - Keeping the triangular block in L1 cache
81/// - Processing multiple right-hand side columns
82pub fn trsm_block_size<T>(n: usize, nrhs: usize) -> usize {
83    let elem_size = size_of::<T>();
84
85    // Target: fit triangular block in L1
86    // Triangular block: n² / 2 elements
87    let max_block = ((2 * L1_CACHE_SIZE / elem_size) as f64).sqrt() as usize;
88    let block = max_block.clamp(MIN_BLOCK_SIZE, MAX_BLOCK_SIZE / 2);
89
90    // Align and adjust
91    let block = (block / 8) * 8;
92    block.min(n).min(nrhs).max(MIN_BLOCK_SIZE)
93}
94
95/// Calculates the optimal panel width for factorizations (LU, Cholesky, QR).
96///
97/// The panel width determines how many columns are processed together
98/// before updating the trailing submatrix.
99pub fn factorization_panel_width<T>(n: usize) -> usize {
100    let elem_size = size_of::<T>();
101
102    // For factorization, we want the panel to fit in L2 cache
103    // Panel size: n × panel_width
104    let max_panel = L2_CACHE_SIZE / (elem_size * n.max(1));
105    let panel = max_panel.clamp(16, 128);
106
107    // Align to SIMD boundaries
108    ((panel / 4) * 4).min(n).max(16)
109}
110
/// Recursive block range for cache-oblivious algorithms.
///
/// Represents a half-open index range `[start, end)` that can be
/// recursively halved for divide-and-conquer algorithms.
#[derive(Debug, Clone, Copy)]
pub struct BlockRange {
    /// Start index (inclusive)
    pub start: usize,
    /// End index (exclusive)
    pub end: usize,
}

impl BlockRange {
    /// Creates a new block range covering `[start, end)`.
    #[inline]
    pub const fn new(start: usize, end: usize) -> Self {
        Self { start, end }
    }

    /// Creates a range covering `[0, n)`.
    #[inline]
    pub const fn from_len(n: usize) -> Self {
        Self::new(0, n)
    }

    /// Returns the length of this range (0 for inverted ranges).
    #[inline]
    pub const fn len(&self) -> usize {
        if self.end > self.start {
            self.end - self.start
        } else {
            0
        }
    }

    /// Returns true if this range contains no indices.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns true if this range is a base case (should not be split further).
    #[inline]
    pub fn is_base_case(&self, threshold: usize) -> bool {
        self.len() <= threshold
    }

    /// Splits this range in half at the midpoint.
    ///
    /// Returns `(left_half, right_half)`; the halves share the midpoint
    /// as boundary, so together they cover the original range exactly.
    #[inline]
    pub fn split(&self) -> (Self, Self) {
        let mid = self.start + self.len() / 2;
        (Self::new(self.start, mid), Self::new(mid, self.end))
    }

    /// Splits at offset `point` from the start (clamped to the end).
    #[inline]
    pub fn split_at(&self, point: usize) -> (Self, Self) {
        let cut = usize::min(self.start + point, self.end);
        (Self::new(self.start, cut), Self::new(cut, self.end))
    }
}
176
177/// Task for cache-oblivious recursive algorithm.
178///
179/// This represents a subproblem in a recursive decomposition.
180#[derive(Debug, Clone, Copy)]
181pub struct RecursiveTask {
182    /// Row range
183    pub rows: BlockRange,
184    /// Column range
185    pub cols: BlockRange,
186}
187
188impl RecursiveTask {
189    /// Creates a new recursive task.
190    #[inline]
191    pub const fn new(rows: BlockRange, cols: BlockRange) -> Self {
192        RecursiveTask { rows, cols }
193    }
194
195    /// Creates a task for an m×n matrix.
196    #[inline]
197    pub const fn from_dims(m: usize, n: usize) -> Self {
198        RecursiveTask {
199            rows: BlockRange::from_len(m),
200            cols: BlockRange::from_len(n),
201        }
202    }
203
204    /// Returns the number of elements in this task.
205    #[inline]
206    pub fn size(&self) -> usize {
207        self.rows.len() * self.cols.len()
208    }
209
210    /// Returns true if this is a base case.
211    #[inline]
212    pub fn is_base_case(&self, threshold: usize) -> bool {
213        self.rows.len() <= threshold && self.cols.len() <= threshold
214    }
215
216    /// Splits along the larger dimension.
217    ///
218    /// Returns two subtasks by splitting the larger dimension in half.
219    pub fn split(&self) -> (Self, Self) {
220        if self.rows.len() >= self.cols.len() {
221            // Split rows
222            let (r1, r2) = self.rows.split();
223            (
224                RecursiveTask::new(r1, self.cols),
225                RecursiveTask::new(r2, self.cols),
226            )
227        } else {
228            // Split columns
229            let (c1, c2) = self.cols.split();
230            (
231                RecursiveTask::new(self.rows, c1),
232                RecursiveTask::new(self.rows, c2),
233            )
234        }
235    }
236
237    /// Quadrant decomposition for 2D recursive algorithms.
238    ///
239    /// Returns `(top_left, top_right, bottom_left, bottom_right)`.
240    pub fn quadrants(&self) -> (Self, Self, Self, Self) {
241        let (r1, r2) = self.rows.split();
242        let (c1, c2) = self.cols.split();
243
244        (
245            RecursiveTask::new(r1, c1), // top-left
246            RecursiveTask::new(r1, c2), // top-right
247            RecursiveTask::new(r2, c1), // bottom-left
248            RecursiveTask::new(r2, c2), // bottom-right
249        )
250    }
251}
252
/// Visitor pattern for cache-oblivious matrix traversal.
///
/// Implement this trait to process matrix blocks in a cache-efficient order.
pub trait BlockVisitor {
    /// The error type for visit operations.
    type Error;

    /// Visits a matrix block.
    ///
    /// # Arguments
    /// * `row_start`, `row_end` - Row range (exclusive end)
    /// * `col_start`, `col_end` - Column range (exclusive end)
    ///
    /// # Errors
    /// Returning `Err` aborts the traversal early: `cache_oblivious_traverse`
    /// propagates the first error it sees and visits no further blocks.
    fn visit_block(
        &mut self,
        row_start: usize,
        row_end: usize,
        col_start: usize,
        col_end: usize,
    ) -> Result<(), Self::Error>;
}
273
274/// Performs a cache-oblivious traversal of a matrix.
275///
276/// This recursively divides the matrix into quadrants until reaching
277/// the base case threshold, then visits each block.
278pub fn cache_oblivious_traverse<V: BlockVisitor>(
279    visitor: &mut V,
280    task: RecursiveTask,
281    threshold: usize,
282) -> Result<(), V::Error> {
283    if task.is_base_case(threshold) {
284        // Base case: visit this block directly
285        visitor.visit_block(
286            task.rows.start,
287            task.rows.end,
288            task.cols.start,
289            task.cols.end,
290        )
291    } else {
292        // Recursive case: split and process
293        let (t1, t2) = task.split();
294        cache_oblivious_traverse(visitor, t1, threshold)?;
295        cache_oblivious_traverse(visitor, t2, threshold)
296    }
297}
298
/// Morton (Z-order) curve index calculation.
///
/// Morton ordering provides good cache locality for 2D data by
/// interleaving the bits of x and y coordinates: x occupies the even
/// bit positions of the result, y the odd ones.
#[inline]
pub fn morton_index(x: u32, y: u32) -> u64 {
    /// Spreads the 32 bits of `v` to the even bit positions of a u64.
    fn spread(v: u32) -> u64 {
        let mut bits = u64::from(v);
        // Each step halves the group size and doubles the gap.
        for &(shift, mask) in &[
            (16, 0x0000_FFFF_0000_FFFF_u64),
            (8, 0x00FF_00FF_00FF_00FF),
            (4, 0x0F0F_0F0F_0F0F_0F0F),
            (2, 0x3333_3333_3333_3333),
            (1, 0x5555_5555_5555_5555),
        ] {
            bits = (bits | (bits << shift)) & mask;
        }
        bits
    }
    spread(x) | (spread(y) << 1)
}
316
/// Inverse Morton index: extracts (x, y) from a Morton index.
///
/// `x` is gathered from the even bit positions of `z`, `y` from the
/// odd ones — the exact inverse of `morton_index`.
#[inline]
pub fn morton_decode(z: u64) -> (u32, u32) {
    /// Gathers the even-position bits of `bits` into a compact u32.
    fn gather(mut bits: u64) -> u32 {
        bits &= 0x5555_5555_5555_5555;
        // Reverse of the spread: each step merges adjacent bit groups.
        for &(shift, mask) in &[
            (1, 0x3333_3333_3333_3333_u64),
            (2, 0x0F0F_0F0F_0F0F_0F0F),
            (4, 0x00FF_00FF_00FF_00FF),
            (8, 0x0000_FFFF_0000_FFFF),
            (16, 0x0000_0000_FFFF_FFFF),
        ] {
            bits = (bits | (bits >> shift)) & mask;
        }
        bits as u32
    }
    (gather(z), gather(z >> 1))
}
331
332#[cfg(test)]
333mod tests {
334    use super::*;
335
    #[test]
    fn test_gemm_block_sizes() {
        // Large square GEMM; exact values depend on the tuned L2 size,
        // so only invariants are checked.
        let (bm, bn, bk) = gemm_block_sizes::<f64>(1024, 1024, 1024);

        // Block sizes should be reasonable
        assert!(bm >= MIN_BLOCK_SIZE);
        assert!(bn >= MIN_BLOCK_SIZE);
        assert!(bk >= MIN_BLOCK_SIZE);
        assert!(bm <= MAX_BLOCK_SIZE);
        assert!(bn <= MAX_BLOCK_SIZE);
        assert!(bk <= MAX_BLOCK_SIZE);

        // Should be divisible by 8
        assert_eq!(bm % 8, 0);
    }
351
    #[test]
    fn test_block_range() {
        let range = BlockRange::new(0, 100);
        assert_eq!(range.len(), 100);

        // Halving splits at the midpoint; halves share the boundary.
        let (left, right) = range.split();
        assert_eq!(left.start, 0);
        assert_eq!(left.end, 50);
        assert_eq!(right.start, 50);
        assert_eq!(right.end, 100);

        // Base-case check compares length against the threshold.
        assert!(BlockRange::new(0, 32).is_base_case(64));
        assert!(!BlockRange::new(0, 100).is_base_case(64));
    }
366
    #[test]
    fn test_recursive_task() {
        let task = RecursiveTask::from_dims(100, 200);
        assert_eq!(task.size(), 20000);

        // Should split along columns (larger dimension)
        let (t1, t2) = task.split();
        assert_eq!(t1.cols.len(), 100);
        assert_eq!(t2.cols.len(), 100);
        // Row ranges are carried over unchanged by a column split.
        assert_eq!(t1.rows.len(), 100);
        assert_eq!(t2.rows.len(), 100);
    }
379
    #[test]
    fn test_quadrants() {
        // 100×100 task splits into four 50×50 quadrants.
        let task = RecursiveTask::from_dims(100, 100);
        let (tl, _tr, _bl, br) = task.quadrants();

        assert_eq!(tl.rows.start, 0);
        assert_eq!(tl.rows.end, 50);
        assert_eq!(tl.cols.start, 0);
        assert_eq!(tl.cols.end, 50);

        assert_eq!(br.rows.start, 50);
        assert_eq!(br.rows.end, 100);
        assert_eq!(br.cols.start, 50);
        assert_eq!(br.cols.end, 100);
    }
395
    #[test]
    fn test_morton_index() {
        // Morton index interleaves bits: x on even, y on odd positions.
        assert_eq!(morton_index(0, 0), 0);
        assert_eq!(morton_index(1, 0), 1);
        assert_eq!(morton_index(0, 1), 2);
        assert_eq!(morton_index(1, 1), 3);
        assert_eq!(morton_index(2, 0), 4);

        // Roundtrip test: decode must invert encode over a small grid.
        for x in 0..100 {
            for y in 0..100 {
                let z = morton_index(x, y);
                let (dx, dy) = morton_decode(z);
                assert_eq!((dx, dy), (x, y));
            }
        }
    }
414
    /// Test visitor that counts visited blocks and sums their areas.
    struct CountingVisitor {
        /// Number of blocks visited so far.
        count: usize,
        /// Sum of (rows × cols) over all visited blocks.
        total_elements: usize,
    }

    impl BlockVisitor for CountingVisitor {
        /// Counting never fails, so the error type is trivial.
        type Error = ();

        fn visit_block(
            &mut self,
            row_start: usize,
            row_end: usize,
            col_start: usize,
            col_end: usize,
        ) -> Result<(), ()> {
            self.count += 1;
            // Accumulate the block's area so tests can verify full coverage.
            self.total_elements += (row_end - row_start) * (col_end - col_start);
            Ok(())
        }
    }
435
    #[test]
    fn test_cache_oblivious_traverse() {
        // 128×128 with threshold 32 forces several levels of splitting.
        let task = RecursiveTask::from_dims(128, 128);
        let mut visitor = CountingVisitor {
            count: 0,
            total_elements: 0,
        };

        cache_oblivious_traverse(&mut visitor, task, 32).unwrap();

        // Should visit multiple blocks
        assert!(visitor.count > 1);
        // Should cover all elements
        assert_eq!(visitor.total_elements, 128 * 128);
    }
451}