pub use rlx_ir::{Coord2, Strides2, Tile2};
/// Coordinate-addressed access to `f32` elements behind a raw base pointer.
///
/// An implementor supplies only [`TileIO::address`]; the provided `load`,
/// `store` and `prefetch` methods all route through it.
pub trait TileIO {
    /// Returns the address of the element at coordinate `c`, computed from `base`.
    ///
    /// # Safety
    ///
    /// The exact requirements are set by each implementation. The provided
    /// methods of this trait forward exactly the `base` and `c` they receive.
    unsafe fn address(&self, base: *const f32, c: Coord2) -> *const f32;

    /// Reads the element at coordinate `c`.
    ///
    /// # Safety
    ///
    /// `self.address(base, c)` must be sound to call and must yield a pointer
    /// that is valid and aligned for reading one `f32`.
    #[inline(always)]
    unsafe fn load(&self, base: *const f32, c: Coord2) -> f32 {
        // SAFETY: the caller guarantees the computed address is readable.
        unsafe {
            let src = self.address(base, c);
            src.read()
        }
    }

    /// Writes `v` to the element at coordinate `c`.
    ///
    /// # Safety
    ///
    /// `self.address(base, c)` must be sound to call and must yield a pointer
    /// that is valid and aligned for writing one `f32`.
    #[inline(always)]
    unsafe fn store(&self, base: *mut f32, c: Coord2, v: f32) {
        // SAFETY: the caller guarantees the computed address is writable.
        unsafe {
            // `address` hands back a `*const`; the pointer originated from the
            // mutable `base`, so its mutability is restored before writing.
            let slot = self.address(base, c).cast_mut();
            slot.write(v);
        }
    }

    /// Hints that the element at coordinate `c` is about to be read.
    ///
    /// On `aarch64` this issues a `prfm pldl1keep` instruction for the computed
    /// address. On every other architecture the address is still computed but
    /// then discarded.
    ///
    /// # Safety
    ///
    /// `self.address(base, c)` must be sound to call.
    #[inline(always)]
    unsafe fn prefetch(&self, base: *const f32, c: Coord2) {
        // SAFETY: the caller guarantees `address` may be called with these arguments.
        let target = unsafe { self.address(base, c) };
        #[cfg(target_arch = "aarch64")]
        {
            // SAFETY: the instruction only takes the address in a register; the
            // options declare that it uses no stack and writes no memory.
            unsafe {
                std::arch::asm!("prfm pldl1keep, [{0}]", in(reg) target,
                options(nostack, readonly));
            }
        }
        #[cfg(not(target_arch = "aarch64"))]
        {
            // No prefetch is emitted here; consume the value to avoid an unused warning.
            let _ = target;
        }
    }
}
/// Addressing scheme in which element `(row, col)` sits
/// `row * shape.cols + col` elements past the base pointer.
#[derive(Debug, Clone, Copy)]
pub struct RowMajorTile {
    /// Tile extents; `address` reads only `cols` (as the row pitch).
    pub shape: Tile2,
}

impl TileIO for RowMajorTile {
    /// Computes `base + c.row * shape.cols + c.col` (in units of `f32`).
    ///
    /// # Safety
    ///
    /// The computed offset must satisfy the requirements of pointer `add`
    /// with respect to `base`.
    #[inline(always)]
    unsafe fn address(&self, base: *const f32, c: Coord2) -> *const f32 {
        let row_pitch = self.shape.cols;
        let linear = c.row * row_pitch + c.col;
        // SAFETY: the caller guarantees `linear` is a valid offset from `base`.
        unsafe { base.add(linear) }
    }
}
/// Addressing scheme with independent per-row and per-column strides:
/// element `(row, col)` sits `row * strides.row + col * strides.col`
/// elements past the base pointer.
#[derive(Debug, Clone, Copy)]
pub struct StridedTile {
    /// Tile extents; not consulted by `address`.
    pub shape: Tile2,
    /// Element strides applied to the row and column coordinates.
    pub strides: Strides2,
}

impl TileIO for StridedTile {
    /// Computes `base + c.row * strides.row + c.col * strides.col` (in units of `f32`).
    ///
    /// # Safety
    ///
    /// The computed offset must satisfy the requirements of pointer `add`
    /// with respect to `base`.
    #[inline(always)]
    unsafe fn address(&self, base: *const f32, c: Coord2) -> *const f32 {
        let row_part = c.row * self.strides.row;
        let col_part = c.col * self.strides.col;
        // SAFETY: the caller guarantees the summed offset is valid from `base`.
        unsafe { base.add(row_part + col_part) }
    }
}
/// Calls `f` once for every coordinate of `shape`, rows outermost and
/// columns innermost, each counted up from zero.
///
/// Nothing is visited when either `shape.rows` or `shape.cols` is zero.
#[inline(always)]
pub fn for_each_coord(shape: Tile2, f: impl FnMut(Coord2)) {
    (0..shape.rows)
        .flat_map(|row| (0..shape.cols).map(move |col| Coord2 { row, col }))
        .for_each(f);
}
#[inline]
pub unsafe fn copy_tile<S: TileIO, D: TileIO>(
src_io: &S,
src_base: *const f32,
dst_io: &D,
dst_base: *mut f32,
shape: Tile2,
) {
for_each_coord(shape, |c| unsafe {
dst_io.store(dst_base, c, src_io.load(src_base, c));
});
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parent buffer `[0.0, 1.0, .., 31.0]` shared by the strided-view tests.
    fn ramp32() -> Vec<f32> {
        (0u8..32).map(f32::from).collect()
    }

    /// A value stored through `RowMajorTile` reads back unchanged and lands at
    /// linear index `row * cols + col`.
    #[test]
    fn row_major_round_trip() {
        let tile = RowMajorTile {
            shape: Tile2::new(3, 4),
        };
        let at = Coord2 { row: 1, col: 2 };
        let mut storage = [0f32; 12];
        // SAFETY: (1, 2) with 4 columns maps to index 6 of a 12-element buffer.
        let read_back = unsafe {
            tile.store(storage.as_mut_ptr(), at, 42.0);
            tile.load(storage.as_ptr(), at)
        };
        assert_eq!(read_back, 42.0);
        assert_eq!(storage[4 + 2], 42.0);
    }

    /// A strided view starting two elements into the parent reads the expected
    /// parent element.
    #[test]
    fn strided_reads_non_contig_view() {
        let parent = ramp32();
        let window = StridedTile {
            shape: Tile2::new(4, 4),
            strides: Strides2 { row: 8, col: 1 },
        };
        // SAFETY: offset 2 is inside the 32-element parent.
        let origin = unsafe { parent.as_ptr().add(2) };
        // SAFETY: (1, 1) resolves to parent index 2 + 8 + 1 = 11.
        let got = unsafe { window.load(origin, Coord2 { row: 1, col: 1 }) };
        assert_eq!(got, 11.0);
    }

    /// Prefetching the first and last element of a tile completes without panicking.
    #[test]
    fn prefetch_doesnt_panic() {
        let storage = vec![0f32; 64];
        let tile = RowMajorTile {
            shape: Tile2::new(8, 8),
        };
        let corners = [Coord2 { row: 0, col: 0 }, Coord2 { row: 7, col: 7 }];
        for corner in corners {
            // SAFETY: both corners map inside the 64-element buffer.
            unsafe { tile.prefetch(storage.as_ptr(), corner) };
        }
    }

    /// Copying a strided 4x4 window into a dense 4x4 buffer yields the window's
    /// rows back to back.
    #[test]
    fn copy_tile_strided_to_contig() {
        let parent = ramp32();
        let mut dense = vec![0f32; 16];
        let window = StridedTile {
            shape: Tile2::new(4, 4),
            strides: Strides2 { row: 8, col: 1 },
        };
        let packed = RowMajorTile {
            shape: Tile2::new(4, 4),
        };
        // SAFETY: offset 2 is inside the 32-element parent.
        let origin = unsafe { parent.as_ptr().add(2) };
        // SAFETY: the window's largest parent index is 2 + 3 * 8 + 3 = 29, and
        // the destination holds all 16 elements.
        unsafe {
            copy_tile(
                &window,
                origin,
                &packed,
                dense.as_mut_ptr(),
                Tile2::new(4, 4),
            );
        }
        assert_eq!(&dense[0..4], &[2.0, 3.0, 4.0, 5.0]);
        assert_eq!(&dense[4..8], &[10.0, 11.0, 12.0, 13.0]);
    }
}