use core::alloc::Layout;
use core::mem::size_of;
#[cfg(not(feature = "std"))]
use alloc::alloc::{alloc, alloc_zeroed, dealloc};
#[cfg(feature = "std")]
use std::alloc::{alloc, alloc_zeroed, dealloc};
/// Allocator abstraction that hands out and reclaims raw memory described by a
/// [`Layout`]. Implementors must also be `Clone`.
///
/// # Safety
///
/// NOTE(review): the contract an implementor must uphold is not written down
/// in this file. Presumably a returned pointer must be usable for the requested
/// `layout` until it is passed to `deallocate` — confirm and document.
pub unsafe trait Alloc: Clone {
/// Requests memory for `layout` and returns a raw pointer to it.
///
/// NOTE(review): how failure is reported (e.g. a null pointer) is not
/// specified by the trait itself — confirm against implementors.
fn allocate(&self, layout: Layout) -> *mut u8;
/// Same request as `allocate`; going by the name, the returned memory is
/// expected to be zero-filled.
fn allocate_zeroed(&self, layout: Layout) -> *mut u8;
/// Returns the memory at `ptr`, described by `layout`, to the allocator.
///
/// # Safety
///
/// NOTE(review): presumably `ptr` must come from an allocate call on this
/// allocator with the same `layout` and must not be used afterwards — confirm.
unsafe fn deallocate(&self, ptr: *mut u8, layout: Layout);
}
/// Field-less allocator handle. Its `Alloc` implementation forwards to the
/// `alloc`, `alloc_zeroed` and `dealloc` functions imported at the top of this
/// file.
#[derive(Debug, Clone, Copy, Default)]
pub struct Global;
/// `Alloc` implementation that forwards to the `alloc`, `alloc_zeroed` and
/// `dealloc` functions.
///
/// Zero-sized layouts never reach those functions: allocation hands back a
/// non-null placeholder pointer whose address equals `layout.align()`, and
/// deallocating such a layout does nothing.
unsafe impl Alloc for Global {
    /// Returns memory for `layout`, or the alignment-valued placeholder pointer
    /// when `layout.size()` is zero.
    ///
    /// The pointer produced by `alloc` is returned as-is; it is not checked for
    /// null here.
    #[inline]
    fn allocate(&self, layout: Layout) -> *mut u8 {
        match layout.size() {
            // Nothing to allocate: hand out an aligned, non-null placeholder.
            0 => layout.align() as *mut u8,
            // SAFETY: this arm is only reached with a non-zero size, which is
            // the precondition `alloc` places on its layout.
            _ => unsafe { alloc(layout) },
        }
    }

    /// Same as `allocate`, but non-zero-sized requests go through
    /// `alloc_zeroed`.
    ///
    /// The pointer produced by `alloc_zeroed` is returned as-is; it is not
    /// checked for null here.
    #[inline]
    fn allocate_zeroed(&self, layout: Layout) -> *mut u8 {
        match layout.size() {
            // Nothing to allocate: hand out an aligned, non-null placeholder.
            0 => layout.align() as *mut u8,
            // SAFETY: this arm is only reached with a non-zero size, which is
            // the precondition `alloc_zeroed` places on its layout.
            _ => unsafe { alloc_zeroed(layout) },
        }
    }

    /// Releases `ptr`; a zero-sized `layout` is ignored because no real
    /// allocation was made for it.
    ///
    /// # Safety
    ///
    /// For a non-zero-sized `layout` the call is forwarded to `dealloc`, so its
    /// requirements apply: `ptr` must denote a live block obtained with this
    /// same `layout`.
    #[inline]
    unsafe fn deallocate(&self, ptr: *mut u8, layout: Layout) {
        if layout.size() == 0 {
            // Placeholder pointer from a zero-sized allocation: nothing to free.
            return;
        }
        dealloc(ptr, layout);
    }
}
/// Selects which prefetch hint `prefetch_read` / `prefetch_write` issue.
///
/// The per-variant notes list the hint used on each supported architecture.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PrefetchLocality {
/// x86_64: `_MM_HINT_NTA`. aarch64: the same L3 "keep" hint as `Low`.
NonTemporal,
/// x86_64: `_MM_HINT_T2`. aarch64: `pldl3keep` (read) / `pstl3keep` (write).
Low,
/// x86_64: `_MM_HINT_T1`. aarch64: `pldl2keep` (read) / `pstl2keep` (write).
Medium,
/// x86_64: `_MM_HINT_T0`. aarch64: `pldl1keep` (read) / `pstl1keep` (write).
High,
}
/// Hints to the CPU that the memory at `ptr` is about to be read.
///
/// The hint issued depends on the target architecture:
///
/// | `locality`    | x86_64         | aarch64     |
/// |---------------|----------------|-------------|
/// | `NonTemporal` | `_MM_HINT_NTA` | `pldl3keep` |
/// | `Low`         | `_MM_HINT_T2`  | `pldl3keep` |
/// | `Medium`      | `_MM_HINT_T1`  | `pldl2keep` |
/// | `High`        | `_MM_HINT_T0`  | `pldl1keep` |
///
/// On every other architecture the call does nothing.
#[inline]
pub fn prefetch_read<T>(ptr: *const T, locality: PrefetchLocality) {
    #[cfg(target_arch = "x86_64")]
    {
        use core::arch::x86_64::{
            _mm_prefetch, _MM_HINT_NTA, _MM_HINT_T0, _MM_HINT_T1, _MM_HINT_T2,
        };

        let addr: *const i8 = ptr.cast();
        // SAFETY: relies on the prefetch instruction being a pure hint that
        // does not fault, which is what allows this function to be safe for
        // arbitrary pointers.
        unsafe {
            match locality {
                PrefetchLocality::High => _mm_prefetch::<_MM_HINT_T0>(addr),
                PrefetchLocality::Medium => _mm_prefetch::<_MM_HINT_T1>(addr),
                PrefetchLocality::Low => _mm_prefetch::<_MM_HINT_T2>(addr),
                PrefetchLocality::NonTemporal => _mm_prefetch::<_MM_HINT_NTA>(addr),
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        match locality {
            PrefetchLocality::High => {
                // SAFETY: `prfm` is only a hint; the asm touches no stack and
                // leaves the flags alone, as declared in `options`.
                unsafe {
                    core::arch::asm!(
                        "prfm pldl1keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
            PrefetchLocality::Medium => {
                // SAFETY: as above.
                unsafe {
                    core::arch::asm!(
                        "prfm pldl2keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
            // Both of these share the L3 hint.
            PrefetchLocality::Low | PrefetchLocality::NonTemporal => {
                // SAFETY: as above.
                unsafe {
                    core::arch::asm!(
                        "prfm pldl3keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        // No prefetch support wired up for this target; just consume the arguments.
        let _ = (ptr, locality);
    }
}
/// Hints to the CPU that the memory at `ptr` is about to be written.
///
/// The hint issued depends on the target architecture:
///
/// | `locality`    | x86_64         | aarch64     |
/// |---------------|----------------|-------------|
/// | `NonTemporal` | `_MM_HINT_NTA` | `pstl3keep` |
/// | `Low`         | `_MM_HINT_T2`  | `pstl3keep` |
/// | `Medium`      | `_MM_HINT_T1`  | `pstl2keep` |
/// | `High`        | `_MM_HINT_T0`  | `pstl1keep` |
///
/// On every other architecture the call does nothing.
///
/// NOTE(review): on x86_64 these are the same hints `prefetch_read` issues;
/// nothing in this path signals write intent — confirm that is intended.
#[inline]
pub fn prefetch_write<T>(ptr: *mut T, locality: PrefetchLocality) {
    #[cfg(target_arch = "x86_64")]
    {
        use core::arch::x86_64::{
            _mm_prefetch, _MM_HINT_NTA, _MM_HINT_T0, _MM_HINT_T1, _MM_HINT_T2,
        };

        let addr: *const i8 = ptr.cast::<i8>();
        // SAFETY: relies on the prefetch instruction being a pure hint that
        // does not fault, which is what allows this function to be safe for
        // arbitrary pointers.
        unsafe {
            match locality {
                PrefetchLocality::High => _mm_prefetch::<_MM_HINT_T0>(addr),
                PrefetchLocality::Medium => _mm_prefetch::<_MM_HINT_T1>(addr),
                PrefetchLocality::Low => _mm_prefetch::<_MM_HINT_T2>(addr),
                PrefetchLocality::NonTemporal => _mm_prefetch::<_MM_HINT_NTA>(addr),
            }
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        match locality {
            PrefetchLocality::High => {
                // SAFETY: `prfm` is only a hint; the asm touches no stack and
                // leaves the flags alone, as declared in `options`.
                unsafe {
                    core::arch::asm!(
                        "prfm pstl1keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
            PrefetchLocality::Medium => {
                // SAFETY: as above.
                unsafe {
                    core::arch::asm!(
                        "prfm pstl2keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
            // Both of these share the L3 hint.
            PrefetchLocality::Low | PrefetchLocality::NonTemporal => {
                // SAFETY: as above.
                unsafe {
                    core::arch::asm!(
                        "prfm pstl3keep, [{0}]",
                        in(reg) ptr,
                        options(nostack, preserves_flags)
                    );
                }
            }
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        // No prefetch support wired up for this target; just consume the arguments.
        let _ = (ptr, locality);
    }
}
/// Issues a read prefetch for every cache line touched by the `count` elements
/// of type `T` starting at `ptr`.
///
/// One `prefetch_read` call is made per `CACHE_LINE_SIZE` bytes. The offset of
/// `ptr` inside its cache line is taken into account, so a range that starts
/// mid-line and spills into one extra line has that final line prefetched too.
///
/// Nothing is prefetched when the range is empty (`count == 0` or `T` is
/// zero-sized) or when `count * size_of::<T>()` does not fit in `usize`.
///
/// `ptr` is only used to form addresses for the prefetch hint and is never
/// dereferenced here.
#[inline]
pub fn prefetch_read_range<T>(ptr: *const T, count: usize, locality: PrefetchLocality) {
    // A byte length that overflows `usize` cannot describe a real memory range.
    let Some(bytes) = count.checked_mul(size_of::<T>()) else {
        return;
    };
    if bytes == 0 {
        return;
    }

    let start = ptr.cast::<u8>();
    // Bytes between the beginning of `ptr`'s cache line and `ptr` itself.
    let lead = (start as usize) % CACHE_LINE_SIZE;
    let num_lines = bytes.saturating_add(lead).div_ceil(CACHE_LINE_SIZE);

    for line in 0..num_lines {
        // `wrapping_add` rather than `add`: this is a safe function, so the
        // caller has promised nothing about `ptr`, and `add` would be undefined
        // behaviour for an address outside a live allocation.
        prefetch_read(start.wrapping_add(line * CACHE_LINE_SIZE), locality);
    }
}
/// Issues a write prefetch for every cache line touched by the `count` elements
/// of type `T` starting at `ptr`.
///
/// One `prefetch_write` call is made per `CACHE_LINE_SIZE` bytes. The offset of
/// `ptr` inside its cache line is taken into account, so a range that starts
/// mid-line and spills into one extra line has that final line prefetched too.
///
/// Nothing is prefetched when the range is empty (`count == 0` or `T` is
/// zero-sized) or when `count * size_of::<T>()` does not fit in `usize`.
///
/// `ptr` is only used to form addresses for the prefetch hint and is never
/// dereferenced here.
#[inline]
pub fn prefetch_write_range<T>(ptr: *mut T, count: usize, locality: PrefetchLocality) {
    // A byte length that overflows `usize` cannot describe a real memory range.
    let Some(bytes) = count.checked_mul(size_of::<T>()) else {
        return;
    };
    if bytes == 0 {
        return;
    }

    let start = ptr.cast::<u8>();
    // Bytes between the beginning of `ptr`'s cache line and `ptr` itself.
    let lead = (start as usize) % CACHE_LINE_SIZE;
    let num_lines = bytes.saturating_add(lead).div_ceil(CACHE_LINE_SIZE);

    for line in 0..num_lines {
        // `wrapping_add` rather than `add`: this is a safe function, so the
        // caller has promised nothing about `ptr`, and `add` would be undefined
        // behaviour for an address outside a live allocation.
        prefetch_write(start.wrapping_add(line * CACHE_LINE_SIZE), locality);
    }
}
/// A prefetch look-ahead distance, stored as a number of cache lines.
#[derive(Debug, Clone, Copy)]
pub struct PrefetchDistance {
/// Number of cache lines to look ahead; multiplied by `CACHE_LINE_SIZE` to
/// obtain a byte distance.
pub lines_ahead: usize,
}
impl Default for PrefetchDistance {
fn default() -> Self {
Self { lines_ahead: 8 }
}
}
impl PrefetchDistance {
    /// Creates a distance of `lines_ahead` cache lines.
    pub const fn new(lines_ahead: usize) -> Self {
        Self { lines_ahead }
    }

    /// The distance in bytes: `lines_ahead * CACHE_LINE_SIZE`.
    #[inline]
    pub const fn offset_bytes(&self) -> usize {
        self.lines_ahead * CACHE_LINE_SIZE
    }

    /// The distance expressed in whole elements of type `T`, rounded down.
    ///
    /// For a zero-sized `T` no element count corresponds to a byte distance,
    /// so `0` is returned instead of dividing by zero.
    #[inline]
    pub const fn offset_elements<T>(&self) -> usize {
        match size_of::<T>() {
            // Dividing by a zero element size would panic.
            0 => 0,
            elem_size => self.offset_bytes() / elem_size,
        }
    }
}
/// Cache line size in bytes assumed by this module when compiling for aarch64.
#[cfg(target_arch = "aarch64")]
pub const CACHE_LINE_SIZE: usize = 128;
/// Cache line size in bytes assumed by this module on every target other than
/// aarch64.
#[cfg(not(target_arch = "aarch64"))]
pub const CACHE_LINE_SIZE: usize = 64;
/// Default alignment in bytes; equal to `CACHE_LINE_SIZE`.
pub const DEFAULT_ALIGN: usize = CACHE_LINE_SIZE;
/// Size in bytes of `count` elements of type `T`, rounded up to the next
/// multiple of `align`.
///
/// `align` must be a non-zero power of two; this is checked in debug builds
/// only.
///
/// # Panics
///
/// Panics, in every build profile, if `count * size_of::<T>()` or the rounded
/// result does not fit in `usize`. A size that is already aligned is returned
/// unchanged even when it is within `align` of `usize::MAX`.
#[inline]
pub const fn aligned_size<T>(count: usize, align: usize) -> usize {
    debug_assert!(align.is_power_of_two());
    let mask = align - 1;
    let size = match count.checked_mul(size_of::<T>()) {
        Some(size) => size,
        None => panic!("aligned_size: `count * size_of::<T>()` overflows usize"),
    };
    // Add `align - 1` (not `align`) so an already-aligned size close to
    // `usize::MAX` does not overflow on the way to being masked.
    match size.checked_add(mask) {
        Some(bumped) => bumped & !mask,
        None => panic!("aligned_size: rounded size overflows usize"),
    }
}
/// Number of whole elements of type `T` that fit in `bytes` bytes.
///
/// For a zero-sized `T` the quotient is undefined; `0` is returned as the
/// conservative answer instead of dividing by zero.
#[inline]
pub const fn elements_per_aligned_bytes<T>(bytes: usize) -> usize {
    match size_of::<T>() {
        // Dividing by a zero element size would panic.
        0 => 0,
        elem_size => bytes / elem_size,
    }
}
/// Rounds `value` up to the next multiple of `align`.
///
/// `align` must be a non-zero power of two; this is checked in debug builds
/// only.
///
/// # Panics
///
/// Panics, in every build profile, if the rounded result does not fit in
/// `usize`. A `value` that is already aligned is returned unchanged even when
/// it is within `align` of `usize::MAX`.
#[inline]
pub const fn round_up_pow2(value: usize, align: usize) -> usize {
    debug_assert!(align.is_power_of_two());
    let mask = align - 1;
    // Add `align - 1` (not `align`) so an already-aligned value close to
    // `usize::MAX` does not overflow on the way to being masked.
    match value.checked_add(mask) {
        Some(bumped) => bumped & !mask,
        None => panic!("round_up_pow2: rounded value overflows usize"),
    }
}
/// Reports whether the address held in `ptr` is a multiple of `align`.
///
/// Only the numeric address is inspected; `ptr` is never dereferenced.
///
/// # Panics
///
/// Panics if `align` is zero, because the check takes a remainder by `align`.
#[inline]
pub fn is_aligned<T>(ptr: *const T, align: usize) -> bool {
    let address = ptr as usize;
    let remainder = address % align;
    remainder == 0
}