use crate::tuning::{L1_CACHE_SIZE, L2_CACHE_SIZE};
use core::mem::size_of;
/// Default recursion cut-off for the cache-oblivious routines.
///
/// NOTE(review): not referenced anywhere in this file; presumably callers pass it as
/// the `threshold` argument of `cache_oblivious_traverse` / `is_base_case` — confirm.
pub const BASE_CASE_THRESHOLD: usize = 64;
/// Lower bound for the cache-derived block edge computed by `gemm_block_sizes` and
/// `trsm_block_size` (the edge is clamped to at least this value).
pub const MIN_BLOCK_SIZE: usize = 16;
/// Upper bound for the block edge computed by `gemm_block_sizes`;
/// `trsm_block_size` caps its edge at half of this value.
pub const MAX_BLOCK_SIZE: usize = 512;
/// Chooses `(block_m, block_n, block_k)` tile sizes for a blocked GEMM on `T` elements.
///
/// The edge length is chosen so that three square `block × block` tiles of `T`
/// occupy roughly half of `L2_CACHE_SIZE`. That edge is clamped to
/// `[MIN_BLOCK_SIZE, MAX_BLOCK_SIZE]` and rounded down to a multiple of 8, but never
/// below `MIN_BLOCK_SIZE`. It is then bounded by each matrix dimension, so a
/// dimension smaller than the block (including 0) is returned unchanged.
///
/// A zero-sized `T` is treated as 1 byte per element so the size computation cannot
/// divide by zero.
pub fn gemm_block_sizes<T>(m: usize, n: usize, k: usize) -> (usize, usize, usize) {
    // size_of::<T>() is 0 for zero-sized types; use 1 so the division is well defined.
    let elem_size = size_of::<T>().max(1);
    // Budget half of L2 for the three-tile working set (presumably the A, B and C tiles).
    let target_bytes = L2_CACHE_SIZE / 2;
    let max_block = ((target_bytes / elem_size / 3) as f64).sqrt() as usize;
    let block = max_block.clamp(MIN_BLOCK_SIZE, MAX_BLOCK_SIZE);
    // Round down to a multiple of 8 (presumably for SIMD/register-tile alignment),
    // re-applying the minimum in case MIN_BLOCK_SIZE is ever not a multiple of 8.
    let block = ((block / 8) * 8).max(MIN_BLOCK_SIZE);
    (block.min(m), block.min(n), block.min(k))
}
/// Chooses the block size for a blocked triangular solve with `n` unknowns and
/// `nrhs` right-hand sides.
///
/// The cache-derived edge is `sqrt(2 * L1_CACHE_SIZE / size_of::<T>())`, clamped to
/// `[MIN_BLOCK_SIZE, MAX_BLOCK_SIZE / 2]` and rounded down to a multiple of 8. It is
/// then bounded by `n` and `nrhs`, but the result is never below `MIN_BLOCK_SIZE`.
/// For small problems the returned size can therefore exceed `n` or `nrhs`.
///
/// A zero-sized `T` is treated as 1 byte per element so the size computation cannot
/// divide by zero.
pub fn trsm_block_size<T>(n: usize, nrhs: usize) -> usize {
    // size_of::<T>() is 0 for zero-sized types; use 1 so the division is well defined.
    let elem_size = size_of::<T>().max(1);
    let max_block = ((2 * L1_CACHE_SIZE / elem_size) as f64).sqrt() as usize;
    let block = max_block.clamp(MIN_BLOCK_SIZE, MAX_BLOCK_SIZE / 2);
    let block = (block / 8) * 8;
    block.min(n).min(nrhs).max(MIN_BLOCK_SIZE)
}
/// Chooses the panel width for a blocked factorization of a matrix with dimension `n`.
///
/// The width is the number of length-`n` vectors of `T` (presumably columns) that fit
/// in `L2_CACHE_SIZE`. It is clamped to `[16, 128]`, rounded down to a multiple of 4,
/// then bounded by `n`. The result is never below 16, so for `n < 16` it exceeds `n`.
///
/// A zero-sized `T` is treated as 1 byte per element. The per-vector byte count
/// saturates instead of overflowing for very large `n`.
pub fn factorization_panel_width<T>(n: usize) -> usize {
    // size_of::<T>() is 0 for zero-sized types; use 1 so the divisor is never 0.
    let elem_size = size_of::<T>().max(1);
    // Bytes in one length-n vector (at least one element); saturate rather than
    // overflow, which simply drives the panel width to its lower clamp.
    let vector_bytes = elem_size.saturating_mul(n.max(1));
    let max_panel = L2_CACHE_SIZE / vector_bytes;
    let panel = max_panel.clamp(16, 128);
    ((panel / 4) * 4).min(n).max(16)
}
/// A half-open index range `[start, end)` describing one axis of a block.
///
/// Ranges with `start >= end` are treated as empty: `len` saturates to zero and
/// `is_empty` returns `true`.
#[derive(Debug, Clone, Copy)]
pub struct BlockRange {
    /// First index in the range (inclusive).
    pub start: usize,
    /// One past the last index in the range (exclusive).
    pub end: usize,
}

impl BlockRange {
    /// Creates the range `[start, end)` without validating `start <= end`.
    #[inline]
    pub const fn new(start: usize, end: usize) -> Self {
        BlockRange { start, end }
    }

    /// Creates the range `[0, n)`.
    #[inline]
    pub const fn from_len(n: usize) -> Self {
        BlockRange { start: 0, end: n }
    }

    /// Number of indices in the range; 0 when `start >= end`.
    #[inline]
    pub const fn len(&self) -> usize {
        self.end.saturating_sub(self.start)
    }

    /// Returns `true` when the range contains no indices.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.start >= self.end
    }

    /// Returns `true` when `len() <= threshold`, i.e. the range is small enough to be
    /// processed directly instead of being split further.
    #[inline]
    pub const fn is_base_case(&self, threshold: usize) -> bool {
        self.len() <= threshold
    }

    /// Splits the range at its midpoint.
    ///
    /// For odd lengths the second half receives the extra index; e.g. `[0, 5)`
    /// becomes `[0, 2)` and `[2, 5)`.
    #[inline]
    pub fn split(&self) -> (Self, Self) {
        let mid = self.start + self.len() / 2;
        (
            BlockRange::new(self.start, mid),
            BlockRange::new(mid, self.end),
        )
    }

    /// Splits the range `point` indices after `start`.
    ///
    /// The split index is clamped to `end`, so any `point >= len()` (including
    /// `usize::MAX`) yields `self` followed by an empty range at `end`.
    #[inline]
    pub fn split_at(&self, point: usize) -> (Self, Self) {
        // saturating_add: a plain `start + point` overflows (and panics in debug
        // builds) for large `point`; saturation lets the `min(end)` clamp apply.
        let split = self.start.saturating_add(point).min(self.end);
        (
            BlockRange::new(self.start, split),
            BlockRange::new(split, self.end),
        )
    }
}
/// A rectangular sub-problem: the product of a row range and a column range.
#[derive(Debug, Clone, Copy)]
pub struct RecursiveTask {
    /// Row indices covered by this task.
    pub rows: BlockRange,
    /// Column indices covered by this task.
    pub cols: BlockRange,
}

impl RecursiveTask {
    /// Builds a task from explicit row and column ranges.
    #[inline]
    pub const fn new(rows: BlockRange, cols: BlockRange) -> Self {
        Self { rows, cols }
    }

    /// Builds the task covering a whole `m × n` matrix: rows `[0, m)`, columns `[0, n)`.
    #[inline]
    pub const fn from_dims(m: usize, n: usize) -> Self {
        Self::new(BlockRange::from_len(m), BlockRange::from_len(n))
    }

    /// Number of elements covered, `rows.len() * cols.len()`.
    #[inline]
    pub fn size(&self) -> usize {
        let (height, width) = (self.rows.len(), self.cols.len());
        height * width
    }

    /// Returns `true` when neither dimension exceeds `threshold`.
    #[inline]
    pub fn is_base_case(&self, threshold: usize) -> bool {
        !(self.rows.len() > threshold || self.cols.len() > threshold)
    }

    /// Halves the task along its longer dimension; rows are split on ties.
    ///
    /// A dimension of length 1 splits into lengths 0 and 1, so the second half can
    /// be identical to `self`. Recursive callers must stop before reaching that point.
    pub fn split(&self) -> (Self, Self) {
        let split_rows = self.rows.len() >= self.cols.len();
        if !split_rows {
            let (left, right) = self.cols.split();
            return (Self::new(self.rows, left), Self::new(self.rows, right));
        }
        let (top, bottom) = self.rows.split();
        (Self::new(top, self.cols), Self::new(bottom, self.cols))
    }

    /// Splits both dimensions at their midpoints.
    ///
    /// Returns `(top_left, top_right, bottom_left, bottom_right)`, where "top" holds
    /// the lower row indices and "left" the lower column indices.
    pub fn quadrants(&self) -> (Self, Self, Self, Self) {
        let (top, bottom) = self.rows.split();
        let (left, right) = self.cols.split();
        let top_left = Self::new(top, left);
        let top_right = Self::new(top, right);
        let bottom_left = Self::new(bottom, left);
        let bottom_right = Self::new(bottom, right);
        (top_left, top_right, bottom_left, bottom_right)
    }
}
/// Receives the leaf blocks produced by [`cache_oblivious_traverse`].
///
/// Each call describes one rectangular block with half-open bounds: rows
/// `row_start..row_end` and columns `col_start..col_end`. A block can be empty
/// (e.g. `row_start == row_end`) when the traversed task has an empty dimension.
pub trait BlockVisitor {
    /// Error a visit can fail with. `cache_oblivious_traverse` returns it unchanged
    /// and visits no further blocks after it.
    type Error;
    /// Handles one block.
    ///
    /// # Errors
    ///
    /// Return `Err` to abort the traversal.
    fn visit_block(
        &mut self,
        row_start: usize,
        row_end: usize,
        col_start: usize,
        col_end: usize,
    ) -> Result<(), Self::Error>;
}
/// Recursively subdivides `task` and calls `visitor.visit_block` on every leaf.
///
/// A task is a leaf once both of its dimensions are `<= threshold`. Otherwise it is
/// halved along its longer dimension (see `RecursiveTask::split`) and the two halves
/// are traversed depth-first, first half before second. The visited leaves partition
/// `task`. Tasks with an empty dimension still produce (empty) leaf visits.
///
/// A `threshold` of 0 is treated as 1. With 0, a dimension of length 1 would split
/// into lengths 0 and 1 forever and the recursion would overflow the stack.
///
/// # Errors
///
/// Returns the first error produced by `visitor.visit_block`; no further blocks are
/// visited after it.
pub fn cache_oblivious_traverse<V: BlockVisitor>(
    visitor: &mut V,
    task: RecursiveTask,
    threshold: usize,
) -> Result<(), V::Error> {
    // With threshold >= 1, a non-leaf has some dimension >= 2 and the split picks the
    // longer one, so both halves are strictly smaller and the recursion terminates.
    let threshold = threshold.max(1);
    if task.is_base_case(threshold) {
        return visitor.visit_block(
            task.rows.start,
            task.rows.end,
            task.cols.start,
            task.cols.end,
        );
    }
    let (first, second) = task.split();
    cache_oblivious_traverse(visitor, first, threshold)?;
    cache_oblivious_traverse(visitor, second, threshold)
}
/// Interleaves the bits of `x` and `y` into a 64-bit Z-order (Morton) code.
///
/// Bit `i` of `x` lands at bit `2 * i` of the result and bit `i` of `y` at bit
/// `2 * i + 1`, so `morton_index(1, 0) == 1` and `morton_index(0, 1) == 2`.
#[inline]
pub fn morton_index(x: u32, y: u32) -> u64 {
    // (shift, mask) steps of the bit-spreading sequence: each step moves the upper
    // half of every bit group up by `shift` and masks away the overlap, until every
    // source bit sits in an even position with a zero gap above it.
    const SPREAD: [(u32, u64); 5] = [
        (16, 0x0000_FFFF_0000_FFFF),
        (8, 0x00FF_00FF_00FF_00FF),
        (4, 0x0F0F_0F0F_0F0F_0F0F),
        (2, 0x3333_3333_3333_3333),
        (1, 0x5555_5555_5555_5555),
    ];
    let spread = |value: u32| -> u64 {
        SPREAD
            .iter()
            .fold(u64::from(value), |acc, &(shift, mask)| {
                (acc | (acc << shift)) & mask
            })
    };
    let even_bits = spread(x);
    let odd_bits = spread(y) << 1;
    even_bits | odd_bits
}
/// Splits a 64-bit Z-order (Morton) code back into its `(x, y)` coordinates.
///
/// The even bits of `z` form `x` and the odd bits form `y`; this is the inverse of
/// interleaving `x` into even and `y` into odd bit positions.
#[inline]
pub fn morton_decode(z: u64) -> (u32, u32) {
    // (shift, mask) steps that collapse every-other-bit data back into a dense word,
    // doubling the size of each contiguous bit group per step.
    const GATHER: [(u32, u64); 5] = [
        (1, 0x3333_3333_3333_3333),
        (2, 0x0F0F_0F0F_0F0F_0F0F),
        (4, 0x00FF_00FF_00FF_00FF),
        (8, 0x0000_FFFF_0000_FFFF),
        (16, 0x0000_0000_FFFF_FFFF),
    ];
    let gather = |word: u64| -> u32 {
        let packed = GATHER
            .iter()
            .fold(word & 0x5555_5555_5555_5555, |acc, &(shift, mask)| {
                (acc | (acc >> shift)) & mask
            });
        // The final mask keeps only the low 32 bits, so this cast is lossless.
        packed as u32
    };
    let x = gather(z);
    let y = gather(z >> 1);
    (x, y)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gemm_block_sizes() {
        let (bm, bn, bk) = gemm_block_sizes::<f64>(1024, 1024, 1024);
        // Every tile edge (not just block_m) must respect the bounds and alignment.
        for b in [bm, bn, bk] {
            assert!(b >= MIN_BLOCK_SIZE);
            assert!(b <= MAX_BLOCK_SIZE);
            assert_eq!(b % 8, 0);
        }
        // Dimensions smaller than the chosen block are returned unchanged.
        assert_eq!(gemm_block_sizes::<f64>(3, 5, 7), (3, 5, 7));
    }

    #[test]
    fn test_trsm_block_size() {
        let b = trsm_block_size::<f64>(1024, 1024);
        assert!(b >= MIN_BLOCK_SIZE);
        assert!(b <= MAX_BLOCK_SIZE / 2);
        assert_eq!(b % 8, 0);
        // Small problems are floored at the minimum block size.
        assert_eq!(trsm_block_size::<f64>(4, 4), MIN_BLOCK_SIZE);
    }

    #[test]
    fn test_factorization_panel_width() {
        let w = factorization_panel_width::<f64>(1000);
        assert!((16..=128).contains(&w));
        assert_eq!(w % 4, 0);
        // Never below 16, even when n is smaller.
        assert_eq!(factorization_panel_width::<f64>(8), 16);
    }

    #[test]
    fn test_block_range() {
        let range = BlockRange::new(0, 100);
        assert_eq!(range.len(), 100);
        let (left, right) = range.split();
        assert_eq!(left.start, 0);
        assert_eq!(left.end, 50);
        assert_eq!(right.start, 50);
        assert_eq!(right.end, 100);
        assert!(BlockRange::new(0, 32).is_base_case(64));
        assert!(!BlockRange::new(0, 100).is_base_case(64));

        // Odd lengths give the extra index to the second half.
        let (a, b) = BlockRange::new(0, 5).split();
        assert_eq!((a.start, a.end, b.start, b.end), (0, 2, 2, 5));

        // Inverted ranges are empty.
        assert_eq!(BlockRange::new(7, 3).len(), 0);
        assert!(BlockRange::new(7, 3).is_empty());
    }

    #[test]
    fn test_block_range_split_at() {
        let r = BlockRange::new(10, 20);
        let (a, b) = r.split_at(3);
        assert_eq!((a.start, a.end, b.start, b.end), (10, 13, 13, 20));
        // Points past the end clamp to `end`.
        let (a, b) = r.split_at(100);
        assert_eq!((a.start, a.end, b.start, b.end), (10, 20, 20, 20));
    }

    #[test]
    fn test_recursive_task() {
        let task = RecursiveTask::from_dims(100, 200);
        assert_eq!(task.size(), 20000);
        let (t1, t2) = task.split();
        assert_eq!(t1.cols.len(), 100);
        assert_eq!(t2.cols.len(), 100);
        assert_eq!(t1.rows.len(), 100);
        assert_eq!(t2.rows.len(), 100);

        // Ties split the rows.
        let (t1, t2) = RecursiveTask::from_dims(64, 64).split();
        assert_eq!((t1.rows.len(), t1.cols.len()), (32, 64));
        assert_eq!((t2.rows.len(), t2.cols.len()), (32, 64));
    }

    #[test]
    fn test_quadrants() {
        let task = RecursiveTask::from_dims(100, 100);
        let (tl, tr, bl, br) = task.quadrants();
        assert_eq!(tl.rows.start, 0);
        assert_eq!(tl.rows.end, 50);
        assert_eq!(tl.cols.start, 0);
        assert_eq!(tl.cols.end, 50);
        assert_eq!((tr.rows.start, tr.rows.end, tr.cols.start, tr.cols.end), (0, 50, 50, 100));
        assert_eq!((bl.rows.start, bl.rows.end, bl.cols.start, bl.cols.end), (50, 100, 0, 50));
        assert_eq!(br.rows.start, 50);
        assert_eq!(br.rows.end, 100);
        assert_eq!(br.cols.start, 50);
        assert_eq!(br.cols.end, 100);
    }

    #[test]
    fn test_morton_index() {
        assert_eq!(morton_index(0, 0), 0);
        assert_eq!(morton_index(1, 0), 1);
        assert_eq!(morton_index(0, 1), 2);
        assert_eq!(morton_index(1, 1), 3);
        assert_eq!(morton_index(2, 0), 4);
        assert_eq!(morton_index(u32::MAX, u32::MAX), u64::MAX);
        assert_eq!(morton_decode(u64::MAX), (u32::MAX, u32::MAX));
        for x in 0..100 {
            for y in 0..100 {
                let z = morton_index(x, y);
                let (dx, dy) = morton_decode(z);
                assert_eq!((dx, dy), (x, y));
            }
        }
    }

    struct CountingVisitor {
        count: usize,
        total_elements: usize,
    }

    impl BlockVisitor for CountingVisitor {
        type Error = ();
        fn visit_block(
            &mut self,
            row_start: usize,
            row_end: usize,
            col_start: usize,
            col_end: usize,
        ) -> Result<(), ()> {
            self.count += 1;
            self.total_elements += (row_end - row_start) * (col_end - col_start);
            Ok(())
        }
    }

    /// Visitor that succeeds `allowed` times, then fails with the number of visits so far.
    struct FailAfter {
        allowed: usize,
        visited: usize,
    }

    impl BlockVisitor for FailAfter {
        type Error = usize;
        fn visit_block(&mut self, _: usize, _: usize, _: usize, _: usize) -> Result<(), usize> {
            if self.visited == self.allowed {
                return Err(self.visited);
            }
            self.visited += 1;
            Ok(())
        }
    }

    #[test]
    fn test_cache_oblivious_traverse() {
        let task = RecursiveTask::from_dims(128, 128);
        let mut visitor = CountingVisitor {
            count: 0,
            total_elements: 0,
        };
        cache_oblivious_traverse(&mut visitor, task, 32).unwrap();
        // 128x128 with threshold 32 decomposes into exactly 4x4 blocks of 32x32.
        assert_eq!(visitor.count, 16);
        assert_eq!(visitor.total_elements, 128 * 128);
    }

    #[test]
    fn test_cache_oblivious_traverse_stops_on_error() {
        let mut visitor = FailAfter {
            allowed: 3,
            visited: 0,
        };
        let result =
            cache_oblivious_traverse(&mut visitor, RecursiveTask::from_dims(128, 128), 32);
        assert_eq!(result, Err(3));
        assert_eq!(visitor.visited, 3);
    }
}