#[cfg(feature = "no-std")]
use alloc::alloc::{alloc, dealloc, Layout};
#[cfg(not(feature = "no-std"))]
use std::alloc::{alloc, dealloc, Layout};
#[cfg(feature = "no-std")]
use core::ptr::NonNull;
#[cfg(not(feature = "no-std"))]
use std::ptr::NonNull;
#[cfg(feature = "no-std")]
use core::{mem, slice};
#[cfg(not(feature = "no-std"))]
use std::{mem, slice};
/// Error returned when an aligned buffer cannot be allocated.
///
/// Produced by [`AlignedAlloc::new`] when the requested layout is rejected or
/// the global allocator returns a null pointer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AllocError;

// `core::fmt` is usable with and without the standard library, so a single
// implementation serves both feature configurations.
impl core::fmt::Display for AllocError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Memory allocation failed")
    }
}

#[cfg(not(feature = "no-std"))]
impl std::error::Error for AllocError {}

#[cfg(feature = "no-std")]
impl core::error::Error for AllocError {}
// Binary size units used to spell out the capacities below.
const KIB: usize = 1 << 10;
const MIB: usize = KIB * KIB;

/// Cache-line size in bytes; used as the stride when prefetching a range.
pub const CACHE_LINE_SIZE: usize = 64;
/// L1 cache capacity in bytes (32 KiB); copies larger than this take the
/// streaming path in `bandwidth::copy_with_prefetch`.
pub const L1_CACHE_SIZE: usize = 32 * KIB;
/// L2 cache capacity (256 KiB). Presumably in bytes like `L1_CACHE_SIZE`;
/// not referenced elsewhere in this file.
pub const L2_CACHE_SIZE: usize = 256 * KIB;
/// L3 cache capacity (8 MiB). Presumably in bytes like `L1_CACHE_SIZE`;
/// not referenced elsewhere in this file.
pub const L3_CACHE_SIZE: usize = 8 * MIB;
/// Alignment in bytes requested for `AlignedAlloc` buffers.
pub const SIMD_ALIGNMENT: usize = 32;
/// Locality hint forwarded to the x86 `_mm_prefetch` intrinsic.
///
/// Each variant selects the `_MM_HINT_*` constant of the same name; on
/// targets without that intrinsic the hint is ignored.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrefetchHint {
    /// Prefetch with `_MM_HINT_T0`.
    T0,
    /// Prefetch with `_MM_HINT_T1`.
    T1,
    /// Prefetch with `_MM_HINT_T2`.
    T2,
    /// Prefetch with `_MM_HINT_NTA`.
    Nta,
}
/// Heap buffer of `len` elements of `T`, aligned for SIMD loads and stores.
///
/// The buffer is aligned to at least `SIMD_ALIGNMENT` bytes (or the natural
/// alignment of `T`, whichever is larger) and is released when the value is
/// dropped.
///
/// The storage is zero-filled on creation and element destructors are never
/// run, so the type is intended for plain numeric data (`f32`, `u8`, ...) for
/// which an all-zero bit pattern is a valid value.
pub struct AlignedAlloc<T> {
    /// Start of the buffer; an aligned dangling pointer when `layout` is empty.
    ptr: NonNull<T>,
    /// Layout handed to the allocator, kept so `drop` can release the memory.
    layout: Layout,
    /// Number of `T` elements in the buffer.
    len: usize,
}

impl<T> AlignedAlloc<T> {
    /// Allocates a zero-filled, aligned buffer holding `len` elements.
    ///
    /// # Errors
    ///
    /// Returns [`AllocError`] when `len * size_of::<T>()` overflows, when the
    /// resulting layout is rejected by [`Layout::from_size_align`], or when
    /// the allocator returns a null pointer.
    pub fn new(len: usize) -> Result<Self, AllocError> {
        // A plain multiplication would wrap silently in release builds.
        let bytes = len.checked_mul(mem::size_of::<T>()).ok_or(AllocError)?;
        // Never request less than `T` itself needs.
        let align = SIMD_ALIGNMENT.max(mem::align_of::<T>());
        let layout = Layout::from_size_align(bytes, align).map_err(|_| AllocError)?;

        if layout.size() == 0 {
            // The global allocator must not be asked for zero bytes. Use a
            // non-null, suitably aligned address instead; it is never
            // dereferenced and never passed to `dealloc`.
            let dangling = NonNull::new(layout.align() as *mut T).ok_or(AllocError)?;
            return Ok(Self {
                ptr: dangling,
                layout,
                len,
            });
        }

        // SAFETY: `layout` has a non-zero size, as `alloc` requires.
        let raw = unsafe { alloc(layout) };
        let ptr = NonNull::new(raw.cast::<T>()).ok_or(AllocError)?;
        // SAFETY: `raw` is non-null and points to `layout.size()` freshly
        // allocated, writable bytes. Zeroing them keeps `as_slice` from
        // exposing uninitialised memory.
        unsafe { core::ptr::write_bytes(raw, 0, layout.size()) };

        Ok(Self { ptr, layout, len })
    }

    /// Returns the whole buffer as a mutable slice of `len` elements.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        // SAFETY: `ptr` is non-null, aligned for `T`, and backs `len`
        // zero-initialised elements owned exclusively by `self`.
        unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
    }

    /// Returns the whole buffer as a shared slice of `len` elements.
    pub fn as_slice(&self) -> &[T] {
        // SAFETY: same invariants as `as_mut_slice`, with shared access.
        unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
    }

    /// Returns a raw const pointer to the first element.
    pub fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr()
    }

    /// Returns a raw mutable pointer to the first element.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr.as_ptr()
    }
}

impl<T> Drop for AlignedAlloc<T> {
    fn drop(&mut self) {
        // Empty buffers hold a dangling pointer that never came from `alloc`.
        if self.layout.size() != 0 {
            // SAFETY: `ptr` was returned by `alloc` with exactly this layout
            // and has not been freed before.
            unsafe {
                dealloc(self.ptr.as_ptr().cast::<u8>(), self.layout);
            }
        }
    }
}
/// Software prefetch helpers.
pub mod prefetch {
    use super::PrefetchHint;

    /// Issues a prefetch for the cache line containing `_address`.
    ///
    /// On x86_64, and on 32-bit x86 builds with SSE enabled, this calls
    /// `_mm_prefetch` with the `_MM_HINT_*` constant selected by `_hint`.
    /// On every other target the function does nothing.
    #[inline(always)]
    pub fn prefetch_read_data(_address: *const u8, _hint: PrefetchHint) {
        #[cfg(any(
            target_arch = "x86_64",
            all(target_arch = "x86", target_feature = "sse")
        ))]
        // SAFETY: `_mm_prefetch` only needs SSE, which the `cfg` above
        // guarantees. A prefetch is a hint that does not access memory
        // architecturally, so any pointer value is acceptable.
        unsafe {
            // Each architecture exposes the intrinsics under its own module.
            #[cfg(target_arch = "x86")]
            use core::arch::x86::{
                _mm_prefetch, _MM_HINT_NTA, _MM_HINT_T0, _MM_HINT_T1, _MM_HINT_T2,
            };
            #[cfg(target_arch = "x86_64")]
            use core::arch::x86_64::{
                _mm_prefetch, _MM_HINT_NTA, _MM_HINT_T0, _MM_HINT_T1, _MM_HINT_T2,
            };

            match _hint {
                PrefetchHint::T0 => _mm_prefetch(_address as *const i8, _MM_HINT_T0),
                PrefetchHint::T1 => _mm_prefetch(_address as *const i8, _MM_HINT_T1),
                PrefetchHint::T2 => _mm_prefetch(_address as *const i8, _MM_HINT_T2),
                PrefetchHint::Nta => _mm_prefetch(_address as *const i8, _MM_HINT_NTA),
            }
        }
    }

    /// Prefetches `slice` one `CACHE_LINE_SIZE` stride at a time.
    ///
    /// The first prefetch targets the start of the slice and the rest follow
    /// at `CACHE_LINE_SIZE`-byte steps while the offset is inside the slice.
    /// An empty slice issues no prefetch.
    #[inline]
    pub fn prefetch_range<T>(slice: &[T], hint: PrefetchHint) {
        let base = slice.as_ptr().cast::<u8>();
        let byte_len = core::mem::size_of_val(slice);
        for offset in (0..byte_len).step_by(super::CACHE_LINE_SIZE) {
            // SAFETY: `offset < byte_len`, so the pointer stays inside the
            // slice; stepping a pointer beyond one-past-the-end with `add`
            // would be undefined behaviour.
            let line = unsafe { base.add(offset) };
            prefetch_read_data(line, hint);
        }
    }
}
/// Cache-blocked (tiled) matrix kernels.
pub mod cache_aware {
    /// Returns the largest `r` such that `r * r <= n`.
    ///
    /// Uses integer Newton iteration, so it needs no floating-point support
    /// and is exact over the whole `usize` range.
    fn integer_sqrt(n: usize) -> usize {
        if n < 2 {
            return n;
        }
        // `n / 2 + 1` is an upper bound of the root for every `n >= 2`; the
        // iteration then decreases monotonically to the floor of the root.
        let mut root = n / 2 + 1;
        let mut next = (root + n / root) / 2;
        while next < root {
            root = next;
            next = (root + n / root) / 2;
        }
        root
    }

    /// Side length, in elements, of the largest square tile whose element
    /// count fits in `cache_size / element_size`.
    ///
    /// The result is never zero, so it can be passed directly as a
    /// `block_size`. An `element_size` of zero is treated as one.
    pub fn optimal_block_size(cache_size: usize, element_size: usize) -> usize {
        let elements_in_cache = cache_size / element_size.max(1);
        integer_sqrt(elements_in_cache).max(1)
    }

    /// Transposes the row-major `rows x cols` matrix `input` into `output`.
    ///
    /// `output` receives the `cols x rows` transpose in row-major order. The
    /// matrix is walked in `block_size x block_size` tiles; a `block_size` of
    /// zero is treated as one.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows or if either slice does not hold
    /// exactly `rows * cols` elements.
    pub fn transpose_blocked(
        input: &[f32],
        output: &mut [f32],
        rows: usize,
        cols: usize,
        block_size: usize,
    ) {
        let total = rows.checked_mul(cols).expect("rows * cols overflows usize");
        assert_eq!(input.len(), total);
        assert_eq!(output.len(), total);
        // `step_by(0)` panics; the tile size only affects traversal order.
        let block_size = block_size.max(1);

        for block_row in (0..rows).step_by(block_size) {
            let end_row = block_row.saturating_add(block_size).min(rows);
            for block_col in (0..cols).step_by(block_size) {
                let end_col = block_col.saturating_add(block_size).min(cols);
                for i in block_row..end_row {
                    for j in block_col..end_col {
                        output[j * rows + i] = input[i * cols + j];
                    }
                }
            }
        }
    }

    /// Computes `c = a * b` for row-major matrices using tiled loops.
    ///
    /// `a` is `m x k`, `b` is `k x n` and `c` is `m x n`. Any previous
    /// contents of `c` are overwritten. A `block_size` of zero is treated as
    /// one.
    ///
    /// # Panics
    ///
    /// Panics if a dimension product overflows or if a slice length does not
    /// match its matrix dimensions.
    pub fn matrix_multiply_blocked(
        a: &[f32],
        b: &[f32],
        c: &mut [f32],
        m: usize,
        n: usize,
        k: usize,
        block_size: usize,
    ) {
        assert_eq!(a.len(), m.checked_mul(k).expect("m * k overflows usize"));
        assert_eq!(b.len(), k.checked_mul(n).expect("k * n overflows usize"));
        assert_eq!(c.len(), m.checked_mul(n).expect("m * n overflows usize"));
        // `step_by(0)` panics; the tile size only affects traversal order.
        let block_size = block_size.max(1);

        c.fill(0.0);
        for kk in (0..k).step_by(block_size) {
            let end_k = kk.saturating_add(block_size).min(k);
            for ii in (0..m).step_by(block_size) {
                let end_i = ii.saturating_add(block_size).min(m);
                for jj in (0..n).step_by(block_size) {
                    let end_j = jj.saturating_add(block_size).min(n);
                    for i in ii..end_i {
                        for j in jj..end_j {
                            // Partial dot product over the current k-tile.
                            let mut sum = 0.0f32;
                            for l in kk..end_k {
                                sum += a[i * k + l] * b[l * n + j];
                            }
                            c[i * n + j] += sum;
                        }
                    }
                }
            }
        }
    }
}
/// Copies that write the destination with `_mm_stream_ps` stores.
pub mod streaming {
    /// Copies `src` into `dest`, using `_mm_stream_ps` stores when SSE2 is
    /// detected and a plain `copy_from_slice` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the slices have different lengths.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    pub fn stream_store_f32(dest: &mut [f32], src: &[f32]) {
        assert_eq!(dest.len(), src.len());
        if !crate::simd_feature_detected!("sse2") {
            dest.copy_from_slice(src);
            return;
        }
        // SAFETY: SSE2 availability was checked just above.
        unsafe {
            stream_store_sse2(dest, src);
        }
    }

    /// SSE2 implementation of [`stream_store_f32`].
    ///
    /// `_mm_stream_ps` requires a 16-byte aligned destination, while a
    /// `&mut [f32]` only guarantees 4-byte alignment. The destination is
    /// therefore split into an unaligned head, an aligned body that is
    /// streamed, and a tail; head and tail are copied normally.
    ///
    /// # Safety
    ///
    /// The CPU must support SSE2, and `src` must be as long as `dest`.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[target_feature(enable = "sse2")]
    unsafe fn stream_store_sse2(dest: &mut [f32], src: &[f32]) {
        // Each architecture exposes the intrinsics under its own module.
        #[cfg(target_arch = "x86")]
        use core::arch::x86::{__m128, _mm_loadu_ps, _mm_sfence, _mm_stream_ps};
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64::{__m128, _mm_loadu_ps, _mm_sfence, _mm_stream_ps};

        /// Number of `f32` lanes in one `__m128`.
        const LANES: usize = 4;

        // SAFETY: every bit pattern of four `f32` values is a valid `__m128`
        // and vice versa, so reinterpreting the aligned middle is sound.
        let (head, body, tail) = unsafe { dest.align_to_mut::<__m128>() };
        let (src_head, src_rest) = src.split_at(head.len());
        let (src_body, src_tail) = src_rest.split_at(body.len() * LANES);

        head.copy_from_slice(src_head);
        tail.copy_from_slice(src_tail);

        // SAFETY: each `vector` is a valid, 16-byte aligned `__m128` slot in
        // `dest`, each `chunk` holds exactly `LANES` readable floats, and
        // SSE2 is enabled for this function.
        unsafe {
            for (vector, chunk) in body.iter_mut().zip(src_body.chunks_exact(LANES)) {
                let data = _mm_loadu_ps(chunk.as_ptr());
                _mm_stream_ps((vector as *mut __m128).cast::<f32>(), data);
            }
            // Order the streaming stores before any later stores.
            _mm_sfence();
        }
    }

    /// Copies `src` into `dest`; on this target it is a plain slice copy.
    ///
    /// # Panics
    ///
    /// Panics if the slices have different lengths.
    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
    pub fn stream_store_f32(dest: &mut [f32], src: &[f32]) {
        dest.copy_from_slice(src);
    }
}
/// Bulk copy helper and a simple copy-bandwidth measurement.
pub mod bandwidth {
    use super::{prefetch::prefetch_range, PrefetchHint};
    #[cfg(not(feature = "no-std"))]
    use std::{mem, time::Instant};

    /// Copies `src` into `dest` after prefetching the source range.
    ///
    /// On x86/x86_64, buffers larger than `L1_CACHE_SIZE` bytes whose element
    /// type has the size and alignment of `f32` are copied through
    /// `streaming::stream_store_f32`. Everything else uses `copy_from_slice`.
    ///
    /// # Panics
    ///
    /// Panics if the slices have different lengths.
    pub fn copy_with_prefetch<T: Copy>(dest: &mut [T], src: &[T]) {
        assert_eq!(dest.len(), src.len());
        prefetch_range(src, PrefetchHint::Nta);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            let exceeds_l1 = core::mem::size_of_val(dest) > super::L1_CACHE_SIZE;
            // Size alone is not enough: a 4-byte type such as `[u8; 4]` has
            // alignment 1 and must not be viewed through an `f32` reference.
            let f32_compatible = core::mem::size_of::<T>() == core::mem::size_of::<f32>()
                && core::mem::align_of::<T>() >= core::mem::align_of::<f32>();
            if exceeds_l1 && f32_compatible {
                let len = dest.len();
                // SAFETY: `T` is `Copy` and has the size and alignment of
                // `f32`, both slices hold `len` elements, and every
                // initialised 4-byte pattern is a valid `f32`.
                // NOTE(review): a 4-byte `T` containing padding would still
                // be read as initialised data here - confirm callers only
                // pass plain scalar element types.
                let (dest_f32, src_f32) = unsafe {
                    (
                        core::slice::from_raw_parts_mut(dest.as_mut_ptr().cast::<f32>(), len),
                        core::slice::from_raw_parts(src.as_ptr().cast::<f32>(), len),
                    )
                };
                super::streaming::stream_store_f32(dest_f32, src_f32);
                return;
            }
        }

        dest.copy_from_slice(src);
    }

    /// Measures copy throughput of [`copy_with_prefetch`] in GiB per second.
    ///
    /// Copies a buffer of 1 Mi `f32` values 100 times and counts every byte
    /// twice (one read plus one write).
    #[cfg(not(feature = "no-std"))]
    pub fn measure_bandwidth() -> f64 {
        const SIZE: usize = 1024 * 1024;
        const ITERATIONS: usize = 100;
        const BYTES_PER_GIB: f64 = 1024.0 * 1024.0 * 1024.0;

        let src = vec![1.0f32; SIZE];
        let mut dest = vec![0.0f32; SIZE];

        let start = Instant::now();
        for _ in 0..ITERATIONS {
            copy_with_prefetch(&mut dest, &src);
            // Keep the optimiser from discarding copies whose result is
            // otherwise never read.
            core::hint::black_box(&dest);
        }
        let elapsed = start.elapsed();

        let bytes_transferred = SIZE * mem::size_of::<f32>() * ITERATIONS * 2;
        bytes_transferred as f64 / elapsed.as_secs_f64() / BYTES_PER_GIB
    }

    /// Placeholder used when the `no-std` feature is enabled: `Instant` is
    /// only imported with `std`, so nothing is timed and `1.0` is returned.
    #[cfg(feature = "no-std")]
    pub fn measure_bandwidth() -> f64 {
        1.0
    }
}
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    #[cfg(feature = "no-std")]
    use alloc::{vec, vec::Vec};

    /// The buffer must honour `SIMD_ALIGNMENT` and be writable end to end.
    #[test]
    fn test_aligned_alloc() {
        let mut buffer = AlignedAlloc::<f32>::new(1024).expect("operation should succeed");
        let values = buffer.as_mut_slice();
        assert_eq!(values.as_ptr() as usize % SIMD_ALIGNMENT, 0);
        values[0] = 1.0;
        values[1023] = 2.0;
        assert_eq!(values[0], 1.0);
        assert_eq!(values[1023], 2.0);
    }

    /// Every element of the blocked transpose must land at its mirrored index.
    #[test]
    fn test_cache_aware_transpose() {
        let rows = 64;
        let cols = 64;
        let mut input = vec![0.0f32; rows * cols];
        let mut output = vec![0.0f32; rows * cols];
        // Row-major fill: each cell stores its own linear index.
        for (index, cell) in input.iter_mut().enumerate() {
            *cell = index as f32;
        }
        cache_aware::transpose_blocked(&input, &mut output, rows, cols, 16);
        for row in 0..rows {
            for col in 0..cols {
                let transposed = output[col * rows + row];
                let original = input[row * cols + col];
                assert_relative_eq!(transposed, original, epsilon = 1e-6);
            }
        }
    }

    /// Multiplying all-ones matrices must yield `k` in every cell.
    #[test]
    fn test_cache_aware_matrix_multiply() {
        let (m, n, k) = (32, 32, 32);
        let a = vec![1.0f32; m * k];
        let b = vec![1.0f32; k * n];
        let mut c = vec![0.0f32; m * n];
        cache_aware::matrix_multiply_blocked(&a, &b, &mut c, m, n, k, 16);
        for cell in c.iter().copied() {
            assert_relative_eq!(cell, k as f32, epsilon = 1e-6);
        }
    }

    /// The streaming copy must reproduce the source exactly.
    #[test]
    fn test_stream_store() {
        let src = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut dest = vec![0.0f32; 8];
        streaming::stream_store_f32(&mut dest, &src);
        for (&copied, &expected) in dest.iter().zip(src.iter()) {
            assert_relative_eq!(copied, expected, epsilon = 1e-6);
        }
    }

    /// The measured bandwidth must be a positive number.
    #[test]
    fn test_bandwidth_measurement() {
        let measured = bandwidth::measure_bandwidth();
        assert!(measured > 0.0);
        println!("Measured bandwidth: {:.2} GB/s", measured);
    }

    /// The L1 block size for 4-byte elements must fall in a sane range.
    #[test]
    fn test_optimal_block_size() {
        let side = cache_aware::optimal_block_size(L1_CACHE_SIZE, 4);
        assert!(side > 0);
        assert!(side < 1000);
    }

    /// A small copy must reproduce the source exactly.
    #[test]
    fn test_copy_with_prefetch() {
        let src = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let mut dest = vec![0.0f32; 5];
        bandwidth::copy_with_prefetch(&mut dest, &src);
        for (&copied, &expected) in dest.iter().zip(src.iter()) {
            assert_relative_eq!(copied, expected, epsilon = 1e-6);
        }
    }
}