/// Lookahead distance, in blocks, used by `PrefetchConfig::default()`.
const DEFAULT_LOOKAHEAD_BLOCKS: usize = 4;
/// Prefetch tuning knobs: how far ahead to prefetch and which strategy to use.
///
/// Build one with `PrefetchConfig::default()`, `PrefetchConfig::for_gemv()`,
/// `PrefetchConfig::for_gemm(batch_size)` or `PrefetchConfig::none()`.
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
/// Lookahead distance in blocks; `PrefetchConfig::none()` sets it to 0.
// NOTE(review): the size of a "block" is not defined in this file — presumably
// it is decided by the kernel consuming this config; confirm against callers.
pub lookahead_blocks: usize,
/// Prefetch strategy that accompanies the lookahead distance.
pub strategy: PrefetchStrategy,
}
impl Default for PrefetchConfig {
fn default() -> Self {
Self {
lookahead_blocks: DEFAULT_LOOKAHEAD_BLOCKS,
strategy: PrefetchStrategy::Temporal,
}
}
}
impl PrefetchConfig {
    /// Preset for GEMV-style workloads: 4 blocks of lookahead, `Temporal` strategy.
    pub fn for_gemv() -> Self {
        let lookahead_blocks = 4;
        let strategy = PrefetchStrategy::Temporal;
        Self {
            lookahead_blocks,
            strategy,
        }
    }

    /// Preset for GEMM-style workloads, chosen from `batch_size`.
    ///
    /// * `batch_size <= 32` — 4 blocks of lookahead, `Temporal` strategy.
    /// * `batch_size > 32` — 8 blocks of lookahead, `NonTemporal` strategy.
    pub fn for_gemm(batch_size: usize) -> Self {
        let (lookahead_blocks, strategy) = match batch_size {
            0..=32 => (4, PrefetchStrategy::Temporal),
            _ => (8, PrefetchStrategy::NonTemporal),
        };
        Self {
            lookahead_blocks,
            strategy,
        }
    }

    /// Configuration with prefetching switched off: zero lookahead and the
    /// `None` strategy.
    pub fn none() -> Self {
        let lookahead_blocks = 0;
        let strategy = PrefetchStrategy::None;
        Self {
            lookahead_blocks,
            strategy,
        }
    }
}
/// Prefetch strategy selected by a `PrefetchConfig`.
///
// NOTE(review): nothing in this file maps a strategy onto a `PrefetchLocality`
// or an actual prefetch call — that translation presumably lives in the
// consuming kernels; confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchStrategy {
/// No prefetching; used by `PrefetchConfig::none()` together with zero lookahead.
None,
/// Strategy used by `PrefetchConfig::default()`, `for_gemv()` and
/// `for_gemm(batch_size)` when `batch_size <= 32`.
Temporal,
/// Strategy used by `PrefetchConfig::for_gemm(batch_size)` when `batch_size > 32`.
NonTemporal,
}
/// Locality hint passed to the prefetch functions; each variant selects a
/// different hardware hint on the supported architectures.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchLocality {
/// x86_64: `_MM_HINT_T0`; aarch64: `_prefetch` locality argument 3.
High,
/// x86_64: `_MM_HINT_T1`; aarch64: `_prefetch` locality argument 2.
Medium,
/// x86_64: `_MM_HINT_NTA`; aarch64: `_prefetch` locality argument 0.
Low,
}
/// Hints the CPU to prefetch the memory at `ptr` for reading.
///
/// The pointer is reinterpreted as a byte pointer and handed to the
/// architecture-specific helper (`prefetch_read_x86` on x86_64,
/// `prefetch_read_aarch64` on aarch64). On every other architecture the call
/// does nothing.
///
/// * `ptr` — address to prefetch; it is never dereferenced by this function.
/// * `locality` — locality hint forwarded to the helper.
#[inline(always)]
pub fn prefetch_read<T>(ptr: *const T, locality: PrefetchLocality) {
    #[cfg(target_arch = "x86_64")]
    prefetch_read_x86(ptr.cast::<i8>(), locality);

    #[cfg(target_arch = "aarch64")]
    prefetch_read_aarch64(ptr.cast::<i8>(), locality);

    // Unsupported architecture: consume the arguments so they are not
    // reported as unused.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    let _ = (ptr, locality);
}
/// Hints the CPU to prefetch the memory at `ptr` ahead of a write.
///
/// The pointer is reinterpreted as a byte pointer and handed to the
/// architecture-specific helper (`prefetch_write_x86` on x86_64,
/// `prefetch_write_aarch64` on aarch64). On every other architecture the call
/// does nothing.
///
/// * `ptr` — address to prefetch; it is never dereferenced by this function.
/// * `locality` — locality hint forwarded to the helper.
#[inline(always)]
pub fn prefetch_write<T>(ptr: *mut T, locality: PrefetchLocality) {
    #[cfg(target_arch = "x86_64")]
    prefetch_write_x86(ptr.cast::<i8>(), locality);

    #[cfg(target_arch = "aarch64")]
    prefetch_write_aarch64(ptr.cast::<i8>(), locality);

    // Unsupported architecture: consume the arguments so they are not
    // reported as unused.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    let _ = (ptr, locality);
}
/// Issues read prefetches for every cache line overlapped by the byte range
/// `[ptr, ptr + byte_count)`.
///
/// * `ptr` — start of the range; it is never dereferenced here.
/// * `byte_count` — length of the range in bytes; `0` issues no prefetch.
/// * `locality` — locality hint forwarded to [`prefetch_read`].
///
/// The walk starts at the beginning of the cache line containing `ptr` and
/// advances one line at a time, so a range that starts in the middle of a line
/// still has its final, partially covered line prefetched.
///
/// This function is safe to call with any pointer/length pair: addresses are
/// computed with wrapping pointer arithmetic, which has no in-bounds
/// requirement, and the results are only ever used as prefetch hints.
#[inline]
pub fn prefetch_range_read<T>(ptr: *const T, byte_count: usize, locality: PrefetchLocality) {
    // Stride between prefetches; this module assumes 64-byte cache lines.
    const CACHE_LINE_BYTES: usize = 64;

    if byte_count == 0 {
        return;
    }

    let base = ptr.cast::<u8>();
    // How far `ptr` sits past the start of its cache line.
    let lead = (base as usize) % CACHE_LINE_BYTES;
    // `wrapping_sub`/`wrapping_add` instead of `sub`/`add`: the latter are UB
    // when the result leaves the allocation, which a safe function taking an
    // arbitrary raw pointer cannot rule out.
    let first_line = base.wrapping_sub(lead);
    // Bytes from the start of the first line to the end of the range.
    let span = byte_count.saturating_add(lead);

    // `step_by` stops cleanly at the end of the range and cannot overflow the
    // way a manual `offset += CACHE_LINE_BYTES` can for huge `byte_count`.
    for offset in (0..span).step_by(CACHE_LINE_BYTES) {
        prefetch_read(first_line.wrapping_add(offset), locality);
    }
}
/// x86_64 backend for read prefetches.
///
/// Maps `locality` onto an `_mm_prefetch` hint:
/// `High` → `_MM_HINT_T0`, `Medium` → `_MM_HINT_T1`, `Low` → `_MM_HINT_NTA`.
#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn prefetch_read_x86(ptr: *const i8, locality: PrefetchLocality) {
    use core::arch::x86_64::{_mm_prefetch, _MM_HINT_NTA, _MM_HINT_T0, _MM_HINT_T1};

    // SAFETY: `_mm_prefetch` is only a cache hint; `ptr` is never dereferenced
    // by Rust code here. NOTE(review): relies on the documented behaviour that
    // the prefetch instruction does not fault on invalid addresses.
    match locality {
        PrefetchLocality::High => unsafe { _mm_prefetch(ptr, _MM_HINT_T0) },
        PrefetchLocality::Medium => unsafe { _mm_prefetch(ptr, _MM_HINT_T1) },
        PrefetchLocality::Low => unsafe { _mm_prefetch(ptr, _MM_HINT_NTA) },
    }
}
#[cfg(target_arch = "x86_64")]
#[inline(always)]
fn prefetch_write_x86(ptr: *const i8, locality: PrefetchLocality) {
prefetch_read_x86(ptr, locality);
}
/// aarch64 backend for read prefetches.
///
/// Calls `_prefetch(ptr, 0, level)` where the second argument `0` is the value
/// this module uses for reads and `level` is chosen from `locality`:
/// `High` → 3, `Medium` → 2, `Low` → 0.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn prefetch_read_aarch64(ptr: *const i8, locality: PrefetchLocality) {
    use core::arch::aarch64::_prefetch;

    // SAFETY: `_prefetch` is only a cache hint; `ptr` is never dereferenced by
    // Rust code here.
    match locality {
        PrefetchLocality::High => unsafe { _prefetch(ptr, 0, 3) },
        PrefetchLocality::Medium => unsafe { _prefetch(ptr, 0, 2) },
        PrefetchLocality::Low => unsafe { _prefetch(ptr, 0, 0) },
    }
}
/// aarch64 backend for write prefetches.
///
/// Calls `_prefetch(ptr, 1, level)` where the second argument `1` is the value
/// this module uses for writes and `level` is chosen from `locality`:
/// `High` → 3, `Medium` → 2, `Low` → 0.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
fn prefetch_write_aarch64(ptr: *const i8, locality: PrefetchLocality) {
    use core::arch::aarch64::_prefetch;

    // SAFETY: `_prefetch` is only a cache hint; `ptr` is never dereferenced by
    // Rust code here.
    match locality {
        PrefetchLocality::High => unsafe { _prefetch(ptr, 1, 3) },
        PrefetchLocality::Medium => unsafe { _prefetch(ptr, 1, 2) },
        PrefetchLocality::Low => unsafe { _prefetch(ptr, 1, 0) },
    }
}
/// Unit tests for the prefetch configuration presets and smoke tests for the
/// prefetch entry points (which have no observable result beyond not crashing).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn prefetch_config_defaults() {
        let config = PrefetchConfig::default();
        assert_eq!(config.lookahead_blocks, 4);
        assert_eq!(config.strategy, PrefetchStrategy::Temporal);
    }

    #[test]
    fn prefetch_config_for_gemv() {
        let config = PrefetchConfig::for_gemv();
        assert_eq!(config.strategy, PrefetchStrategy::Temporal);
        assert!(config.lookahead_blocks > 0);
    }

    #[test]
    fn prefetch_config_for_gemm_small_batch() {
        let config = PrefetchConfig::for_gemm(4);
        assert_eq!(config.strategy, PrefetchStrategy::Temporal);
    }

    #[test]
    fn prefetch_config_for_gemm_large_batch() {
        let config = PrefetchConfig::for_gemm(64);
        assert_eq!(config.strategy, PrefetchStrategy::NonTemporal);
        assert!(config.lookahead_blocks > 4);
    }

    /// Pins the exact threshold: 32 is still a "small" batch, 33 is "large".
    #[test]
    fn prefetch_config_for_gemm_batch_boundary() {
        let at_threshold = PrefetchConfig::for_gemm(32);
        assert_eq!(at_threshold.strategy, PrefetchStrategy::Temporal);
        assert_eq!(at_threshold.lookahead_blocks, 4);

        let above_threshold = PrefetchConfig::for_gemm(33);
        assert_eq!(above_threshold.strategy, PrefetchStrategy::NonTemporal);
        assert_eq!(above_threshold.lookahead_blocks, 8);

        let empty_batch = PrefetchConfig::for_gemm(0);
        assert_eq!(empty_batch.strategy, PrefetchStrategy::Temporal);
    }

    #[test]
    fn prefetch_config_none() {
        let config = PrefetchConfig::none();
        assert_eq!(config.lookahead_blocks, 0);
        assert_eq!(config.strategy, PrefetchStrategy::None);
    }

    #[test]
    fn prefetch_read_smoke_test() {
        let data = [1.0f32, 2.0, 3.0, 4.0];
        prefetch_read(data.as_ptr(), PrefetchLocality::High);
        prefetch_read(data.as_ptr(), PrefetchLocality::Medium);
        prefetch_read(data.as_ptr(), PrefetchLocality::Low);
    }

    #[test]
    fn prefetch_write_smoke_test() {
        let mut data = [0.0f32; 16];
        prefetch_write(data.as_mut_ptr(), PrefetchLocality::High);
        prefetch_write(data.as_mut_ptr(), PrefetchLocality::Medium);
        prefetch_write(data.as_mut_ptr(), PrefetchLocality::Low);
        // A prefetch must not interfere with a subsequent ordinary write.
        data[0] = 42.0;
        assert!((data[0] - 42.0).abs() < f32::EPSILON);
    }

    #[test]
    fn prefetch_range_read_smoke_test() {
        let data = vec![0.0f32; 1024];
        let byte_count = data.len() * std::mem::size_of::<f32>();
        prefetch_range_read(data.as_ptr(), byte_count, PrefetchLocality::High);
        prefetch_range_read(data.as_ptr(), byte_count, PrefetchLocality::Low);
    }

    /// Edge cases: empty range, a length that is not a multiple of the stride,
    /// and a start address that is not aligned to the stride.
    #[test]
    fn prefetch_range_read_empty_partial_and_unaligned() {
        let data = vec![7u8; 100];
        prefetch_range_read(data.as_ptr(), 0, PrefetchLocality::High);
        prefetch_range_read(data.as_ptr(), data.len(), PrefetchLocality::Medium);

        let tail = &data[1..];
        prefetch_range_read(tail.as_ptr(), tail.len(), PrefetchLocality::Low);

        // The buffer is untouched by prefetching.
        assert!(data.iter().all(|&byte| byte == 7));
    }

    #[test]
    fn prefetch_strategy_equality() {
        assert_eq!(PrefetchStrategy::None, PrefetchStrategy::None);
        assert_ne!(PrefetchStrategy::Temporal, PrefetchStrategy::NonTemporal);
    }

    #[test]
    fn prefetch_locality_equality() {
        assert_eq!(PrefetchLocality::High, PrefetchLocality::High);
        assert_ne!(PrefetchLocality::High, PrefetchLocality::Low);
    }
}