#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::__cpuid;
use std::cmp;
use std::mem;
use std::ptr;
/// Cache geometry used by the layout helpers in this file.
///
/// Values are produced by `detect_cache_info`; sizes are stored in bytes
/// (the detection code multiplies KiB counts by 1024 before storing them).
#[derive(Debug, Clone)]
struct CacheInfo {
/// Cache line size in bytes.
line_size: usize,
/// L1 cache size in bytes.
l1_size: usize,
/// L2 cache size in bytes.
l2_size: usize,
/// L3 cache size in bytes.
l3_size: usize,
/// Associativity value recorded alongside the L1 data by the detection code.
/// Stored but not read anywhere in this file.
#[allow(dead_code)]
associativity: usize,
}
// Process-wide cache description. `lazy_static` runs `detect_cache_info`
// on first access and reuses the result afterwards.
lazy_static::lazy_static! {
static ref CACHE_DATA: CacheInfo = detect_cache_info();
}
/// Layout strategy selected by callers of `optimize_layout`.
///
/// Each variant is dispatched to one private helper in this file.
#[derive(Debug, Copy, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum LayoutStrategy {
/// Handled by `align_for_cache_line`.
RowMajor,
/// Handled by `optimize_for_column_access`.
ColumnMajor,
/// Handled by `apply_morton_order`.
Morton,
/// Handled by `apply_hilbert_order`.
Hilbert,
/// Handled by `apply_cache_oblivious_layout`.
CacheOblivious,
/// Handled by `apply_blocked_layout`; the payload is passed through as its
/// `block_size` argument.
Blocked(usize), }
/// Applies the requested layout `strategy` to `data` in place.
///
/// This is a pure dispatcher: every [`LayoutStrategy`] variant is forwarded to
/// the private helper that implements it, and `Blocked` additionally forwards
/// its block size.
pub fn optimize_layout<T: Copy>(data: &mut [T], strategy: LayoutStrategy) {
    match strategy {
        LayoutStrategy::RowMajor => align_for_cache_line(data),
        LayoutStrategy::ColumnMajor => optimize_for_column_access(data),
        LayoutStrategy::Morton => apply_morton_order(data),
        LayoutStrategy::Hilbert => apply_hilbert_order(data),
        LayoutStrategy::CacheOblivious => apply_cache_oblivious_layout(data),
        LayoutStrategy::Blocked(block_size) => apply_blocked_layout(data, block_size),
    }
}
/// Row-major handler: intentionally leaves `data` untouched.
///
/// The elements of a row-major buffer are already in the desired order, so
/// there is nothing to permute. The address of a borrowed slice is fixed by
/// its owner, which means its cache-line alignment cannot be changed from
/// here: copying the elements to a shifted address would overwrite the slice's
/// own contents and write past its end into memory this function does not own.
/// Callers that need cache-line-aligned storage must allocate it aligned.
fn align_for_cache_line<T: Copy>(data: &mut [T]) {
    // Nothing to do; the binding is consumed only to keep the signature
    // unchanged without an unused-parameter warning.
    let _ = data;
}
/// Returns the `line_size` field of the shared cache description.
fn get_cache_line_size() -> usize {
    let info = get_cache_info();
    info.line_size
}
/// Returns a reference to the lazily initialised, process-wide `CACHE_DATA`.
fn get_cache_info() -> &'static CacheInfo {
    // Explicit deref through the lazy wrapper instead of relying on coercion.
    &*CACHE_DATA
}
/// Builds the cache description for the current target.
///
/// On x86_64 the values come from `detect_x86_cache_info`; every other
/// architecture gets a fixed set of defaults (64-byte lines, 32 KiB L1,
/// 256 KiB L2, 8 MiB L3, 8-way associativity).
fn detect_cache_info() -> CacheInfo {
    #[cfg(target_arch = "x86_64")]
    let info = detect_x86_cache_info();

    #[cfg(not(target_arch = "x86_64"))]
    let info = CacheInfo {
        line_size: 64,
        l1_size: 32 * 1024,
        l2_size: 256 * 1024,
        l3_size: 8 * 1024 * 1024,
        associativity: 8,
    };

    info
}
/// Detects cache geometry on x86_64 through CPUID.
///
/// Starts from conservative defaults, refines L2/L3 and the line size from
/// extended leaf `0x8000_0006` when the CPU exposes it, and then lets the
/// vendor-specific routine (Intel or AMD) fill in the L1 data. Leaf
/// `0x8000_0006` only describes L2 (ECX) and L3 (EDX), so the L1 size and
/// associativity are never derived from it; for vendors without a dedicated
/// routine they stay at the defaults.
///
/// The returned `line_size` and `l1_size` are guaranteed to be non-zero,
/// because other helpers in this file use them as divisors and step sizes.
#[cfg(target_arch = "x86_64")]
// `__cpuid` is an `unsafe fn` on older toolchains and a safe fn on newer ones;
// the `unsafe` blocks below keep both building without warnings.
#[allow(unused_unsafe)]
fn detect_x86_cache_info() -> CacheInfo {
    use std::arch::x86_64::__cpuid;

    const DEFAULT_LINE_SIZE: usize = 64;
    const DEFAULT_L1_SIZE: usize = 32 * 1024;

    let mut info = CacheInfo {
        line_size: DEFAULT_LINE_SIZE,
        l1_size: DEFAULT_L1_SIZE,
        l2_size: 256 * 1024,
        l3_size: 8 * 1024 * 1024,
        associativity: 8,
    };

    // SAFETY: the CPUID instruction is always available on x86_64.
    let max_extended_leaf = unsafe { __cpuid(0x8000_0000) }.eax;
    if max_extended_leaf >= 0x8000_0006 {
        // SAFETY: as above; the leaf was just reported as supported.
        let l2_l3 = unsafe { __cpuid(0x8000_0006) };

        // ECX[7:0]: line size in bytes. Zero means "not reported".
        let line_size = (l2_l3.ecx & 0xFF) as usize;
        if line_size != 0 {
            info.line_size = line_size;
        }
        // ECX[31:16]: L2 size in KiB. Zero means "not reported".
        let l2_kib = ((l2_l3.ecx >> 16) & 0xFFFF) as usize;
        if l2_kib != 0 {
            info.l2_size = l2_kib * 1024;
        }
        // EDX[31:18]: L3 size in 512 KiB units (zero when no L3 is reported).
        info.l3_size = ((l2_l3.edx >> 18) & 0x3FFF) as usize * 512 * 1024;
    }

    // SAFETY: the CPUID instruction is always available on x86_64.
    let vendor = unsafe { __cpuid(0) };
    match (vendor.ebx, vendor.edx, vendor.ecx) {
        // "GenuineIntel"
        (0x756e_6547, 0x4965_6e69, 0x6c65_746e) => detect_intel_cache_info(&mut info),
        // "AuthenticAMD"
        (0x6874_7541, 0x6974_6e65, 0x444d_4163) => detect_amd_cache_info(&mut info),
        _ => {}
    }

    // Never hand out zero for values that callers divide or step by.
    if info.line_size == 0 {
        info.line_size = DEFAULT_LINE_SIZE;
    }
    if info.l1_size == 0 {
        info.l1_size = DEFAULT_L1_SIZE;
    }
    info
}
/// Refines `info` from CPUID leaf 4, walking sub-leaves 0 through 10.
///
/// Enumeration stops at the first sub-leaf whose cache-type field is zero.
/// Only cache types 1 and 3 are recorded: level 1 updates the L1 size, line
/// size and associativity; levels 2 and 3 update the matching size only.
#[cfg(target_arch = "x86_64")]
fn detect_intel_cache_info(info: &mut CacheInfo) {
    for sub_leaf in 0..=10 {
        // SAFETY: the CPUID instruction is always available on x86_64.
        let regs = unsafe { __cpuid_count(4, sub_leaf) };

        let kind = regs.eax & 0x1F;
        if kind == 0 {
            break;
        }
        let level = (regs.eax >> 5) & 0x7;

        // Every geometry field is stored minus one.
        let line_size = ((regs.ebx & 0xFFF) + 1) as usize;
        let partitions = (((regs.ebx >> 12) & 0x3FF) + 1) as usize;
        let ways = (((regs.ebx >> 22) & 0x3FF) + 1) as usize;
        let sets = (regs.ecx + 1) as usize;
        let size = line_size * partitions * ways * sets;

        if kind != 1 && kind != 3 {
            continue;
        }
        match level {
            1 => {
                info.l1_size = size;
                info.line_size = line_size;
                info.associativity = ways;
            }
            2 => info.l2_size = size,
            3 => info.l3_size = size,
            _ => {}
        }
    }
}
/// Refines `info` from the AMD extended cache leaves.
///
/// Leaf `0x8000_0005` ECX supplies the L1 size (KiB, bits 31:24), the
/// associativity field (bits 23:16) and the line size (bits 7:0). Leaf
/// `0x8000_0006` supplies the L2 size (ECX bits 31:16, KiB) and the L3 size
/// (EDX bits 31:18, 512 KiB units).
///
/// Fields reported as zero are treated as "not reported" and leave the
/// existing L1/L2 values in place, so `line_size` and `l1_size` never become
/// zero here; other helpers in this file divide and step by them.
#[cfg(target_arch = "x86_64")]
// `__cpuid` is an `unsafe fn` on older toolchains and a safe fn on newer ones;
// the `unsafe` block below keeps both building without warnings.
#[allow(unused_unsafe)]
fn detect_amd_cache_info(info: &mut CacheInfo) {
    // SAFETY: the CPUID instruction is always available on x86_64.
    let (l1_info, l23_info) = unsafe { (__cpuid(0x8000_0005), __cpuid(0x8000_0006)) };

    let l1_kib = ((l1_info.ecx >> 24) & 0xFF) as usize;
    let line_size = (l1_info.ecx & 0xFF) as usize;
    if l1_kib != 0 && line_size != 0 {
        info.l1_size = l1_kib * 1024;
        info.line_size = line_size;
        info.associativity = ((l1_info.ecx >> 16) & 0xFF) as usize;
    }

    let l2_kib = ((l23_info.ecx >> 16) & 0xFFFF) as usize;
    if l2_kib != 0 {
        info.l2_size = l2_kib * 1024;
    }
    // Zero is kept here: it is what the leaf reports when there is no L3.
    info.l3_size = ((l23_info.edx >> 18) & 0x3FFF) as usize * 512 * 1024;
}
/// Executes CPUID for `leaf` / `sub_leaf` and returns the four result registers.
///
/// This is a thin wrapper around the standard-library intrinsic of the same
/// name, which already takes care of preserving `rbx` around the instruction.
///
/// # Safety
///
/// The CPU must support the CPUID instruction. That is always the case on
/// x86_64, the only architecture this function is compiled for.
#[cfg(target_arch = "x86_64")]
// The std intrinsic is an `unsafe fn` on older toolchains and safe on newer ones.
#[allow(unused_unsafe)]
unsafe fn __cpuid_count(leaf: u32, sub_leaf: u32) -> std::arch::x86_64::CpuidResult {
    // SAFETY: CPUID availability is the caller's obligation (see `# Safety`).
    unsafe { std::arch::x86_64::__cpuid_count(leaf, sub_leaf) }
}
/// Returns a square block edge length (in elements) sized so that one
/// `block x block` tile of `T` fits in the detected L1 cache.
///
/// The result is the integer part of `sqrt(l1_size / size_of::<T>())`,
/// clamped to `1..=1024`. Zero-sized types are treated as one byte wide so
/// the division is always defined.
pub fn calculate_optimal_block_size<T>() -> usize {
    // `max(1)` guards the division below against zero-sized `T`.
    let element_size = mem::size_of::<T>().max(1);
    let elements_per_cache = get_l1_cache_size() / element_size;
    let block_size = (elements_per_cache as f64).sqrt() as usize;
    block_size.clamp(1, 1024)
}
/// Column-major handler: forwards `data` and the detected cache line size to
/// `prefetch_data_pattern`.
fn optimize_for_column_access<T: Copy>(data: &mut [T]) {
    let line_size = get_cache_line_size();
    prefetch_data_pattern(data, line_size);
}
/// Reorders `data`, read as a row-major `side x side` grid, into Morton
/// (Z-order) sequence in place.
///
/// Slices shorter than four elements are left untouched. Morton codes only
/// enumerate a grid exactly once when `side` is a power of two; for any other
/// length (not a perfect square, or a square with a non-power-of-two side)
/// the function falls back to `apply_blocked_layout` with the block size from
/// `calculate_optimal_block_size`, so the result is always a permutation of
/// the input.
fn apply_morton_order<T: Copy>(data: &mut [T]) {
    let len = data.len();
    if len < 4 {
        return;
    }

    let side = (len as f64).sqrt() as usize;
    // With a non-power-of-two side, decoded coordinates would have to be
    // wrapped back into range, which duplicates some cells and drops others.
    if side * side != len || !side.is_power_of_two() {
        apply_blocked_layout(data, calculate_optimal_block_size::<T>());
        return;
    }

    // Output position `code` receives the element stored at the decoded (x, y).
    let reordered: Vec<T> = (0..len)
        .map(|code| {
            let (x, y) = morton_decode(code, side);
            data[y * side + x]
        })
        .collect();
    data.copy_from_slice(&reordered);
}
/// Reorders `data`, read as a row-major `side x side` grid, along the Hilbert
/// curve in place.
///
/// Slices shorter than four elements are left untouched. When the length is
/// not a perfect square with a power-of-two side, the work is delegated to
/// `apply_morton_order` instead.
fn apply_hilbert_order<T: Copy>(data: &mut [T]) {
    let len = data.len();
    if len < 4 {
        return;
    }

    let side = (len as f64).sqrt() as usize;
    let is_pow2_square = side * side == len && side.is_power_of_two();
    if !is_pow2_square {
        apply_morton_order(data);
        return;
    }

    let mut reordered = vec![data[0]; len];
    for (distance, slot) in reordered.iter_mut().enumerate() {
        let (x, y) = hilbert_decode(distance, side);
        if x >= side || y >= side {
            continue;
        }
        let source = y * side + x;
        if source < len {
            *slot = data[source];
        }
    }
    data.copy_from_slice(&reordered);
}
/// Cache-oblivious handler: runs `cache_oblivious_recursive` over the whole
/// slice unless the slice already fits in a single cache line.
///
/// Zero-sized element types are returned untouched: they occupy no memory, and
/// skipping them keeps the per-line element count below well defined.
fn apply_cache_oblivious_layout<T: Copy>(data: &mut [T]) {
    let element_size = mem::size_of::<T>();
    if element_size == 0 {
        return;
    }

    let elements_per_line = get_cache_line_size() / element_size;
    if data.len() <= elements_per_line {
        return;
    }
    cache_oblivious_recursive(data, 0, data.len());
}
/// Recursive step of the cache-oblivious layout over `data[start..end]`.
///
/// Ranges of at most one element, or ranges that fit in the detected L1
/// cache, are left as they are. Larger ranges are split in half, each half is
/// processed recursively, and the whole range is then passed to
/// `interleave_data`.
///
/// Callers must pass `start <= end <= data.len()`.
fn cache_oblivious_recursive<T: Copy>(data: &mut [T], start: usize, end: usize) {
    let len = end - start;
    // `max(1)` keeps the division defined for zero-sized `T`.
    let element_size = mem::size_of::<T>().max(1);
    let elements_in_l1 = get_cache_info().l1_size / element_size;
    if len <= 1 || len <= elements_in_l1 {
        return;
    }

    let mid = start + len / 2;
    cache_oblivious_recursive(data, start, mid);
    cache_oblivious_recursive(data, mid, end);
    interleave_data(&mut data[start..end]);
}
/// Reorders `data`, read as a row-major `side x side` grid, so that each
/// `block_size x block_size` tile is stored contiguously, in place.
///
/// Tiles are emitted row of tiles by row of tiles, and the elements inside a
/// tile stay in row-major order. Tiles on the right and bottom edges are
/// truncated when `side` is not a multiple of `block_size`.
///
/// The slice is left untouched when `block_size` is zero, when the slice is
/// empty, when it holds fewer than `block_size * block_size` elements, or when
/// its length is not a perfect square.
fn apply_blocked_layout<T: Copy>(data: &mut [T], block_size: usize) {
    let len = data.len();
    // A zero block size cannot tile anything (and `step_by(0)` panics).
    if block_size == 0 || len == 0 {
        return;
    }
    // Saturate so that an absurdly large block size is rejected, not wrapped.
    if len < block_size.saturating_mul(block_size) {
        return;
    }
    let side = (len as f64).sqrt() as usize;
    if side * side != len {
        return;
    }

    let mut reordered = Vec::with_capacity(len);
    for block_row in (0..side).step_by(block_size) {
        let row_end = cmp::min(block_row + block_size, side);
        for block_col in (0..side).step_by(block_size) {
            let col_end = cmp::min(block_col + block_size, side);
            for row in block_row..row_end {
                // Each tile row is one contiguous run of the source grid.
                let row_base = row * side;
                reordered.extend_from_slice(&data[row_base + block_col..row_base + col_end]);
            }
        }
    }
    data.copy_from_slice(&reordered);
}
/// Issues a prefetch hint for the start of every cache line of `data` after
/// the first one. The contents of `data` are never modified.
///
/// `cache_line_size` is in bytes. The stride is the number of `T` elements
/// per line, with a minimum of one so that element types wider than a cache
/// line (or a zero `cache_line_size`) still advance. Zero-sized element types
/// occupy no memory and are skipped entirely.
///
/// The hint is only emitted on x86_64; on other targets the loop has no effect.
fn prefetch_data_pattern<T: Copy>(data: &mut [T], cache_line_size: usize) {
    let element_size = mem::size_of::<T>();
    if element_size == 0 {
        return;
    }
    // At least one element per step: `step_by(0)` would panic.
    let elements_per_line = (cache_line_size / element_size).max(1);

    for line_start in (0..data.len()).step_by(elements_per_line) {
        let next_line = line_start.saturating_add(elements_per_line);
        if next_line < data.len() {
            #[cfg(target_arch = "x86_64")]
            // SAFETY: `next_line < data.len()`, so the pointer stays inside the
            // slice; the prefetch itself is only a hint and does not access memory
            // architecturally.
            unsafe {
                let ptr = data.as_ptr().add(next_line);
                std::arch::x86_64::_mm_prefetch(
                    ptr as *const i8,
                    std::arch::x86_64::_MM_HINT_T0,
                );
            }
        }
    }
}
/// Splits a Morton code into `(x, y)`: even bits of `morton` form `x`, odd
/// bits form `y`. Only the low 32 bits of the code (16 bit pairs) are read.
/// Both coordinates are reduced modulo `side` before being returned, so
/// `side` must be non-zero.
fn morton_decode(morton: usize, side: usize) -> (usize, usize) {
    let (mut x, mut y) = (0usize, 0usize);
    let mut rest = morton;

    for pair in 0..16 {
        if rest == 0 {
            break;
        }
        x |= (rest & 1) << pair;
        y |= ((rest >> 1) & 1) << pair;
        rest >>= 2;
    }

    (x % side, y % side)
}
/// Converts the distance `h` along a Hilbert curve into `(x, y)` coordinates
/// for a grid of side `n`.
///
/// Two bits of `h` are consumed per level while the level size doubles from 1
/// up to `n`. Both coordinates are reduced modulo `n` before being returned,
/// so `n` must be non-zero.
fn hilbert_decode(h: usize, n: usize) -> (usize, usize) {
    let (mut x, mut y) = (0usize, 0usize);
    let mut remaining = h;
    let mut s = 1usize;

    while s < n {
        let rx = (remaining / 2) & 1;
        let ry = (remaining ^ rx) & 1;

        // Rotate/flip the quadrant built so far before placing it.
        match (rx, ry) {
            (1, 0) => {
                let flipped_x = s - 1 - x;
                let flipped_y = s - 1 - y;
                x = flipped_y;
                y = flipped_x;
            }
            (_, 0) => std::mem::swap(&mut x, &mut y),
            _ => {}
        }

        x += s * rx;
        y += s * ry;
        remaining /= 4;
        s *= 2;
    }

    (x % n, y % n)
}
/// Interleaves the two halves of `data` in place.
///
/// With `half = len / 2`, the output alternates `data[i]` and `data[i + half]`
/// for `i` in `0..half`. For odd lengths the final element stays in place.
/// Slices shorter than two elements are left untouched.
fn interleave_data<T: Copy>(data: &mut [T]) {
    let len = data.len();
    if len < 2 {
        return;
    }

    let half = len / 2;
    let mut shuffled = Vec::with_capacity(len);
    {
        let (front, back) = data.split_at(half);
        for (&a, &b) in front.iter().zip(back) {
            shuffled.push(a);
            shuffled.push(b);
        }
        // `back` has one unpaired trailing element when `len` is odd.
        shuffled.extend_from_slice(&back[half..]);
    }
    data.copy_from_slice(&shuffled);
}
/// Returns the `l1_size` field of the shared cache description.
fn get_l1_cache_size() -> usize {
    let info = get_cache_info();
    info.l1_size
}
/// Returns the `l2_size` field of the shared cache description.
#[allow(dead_code)]
fn get_l2_cache_size() -> usize {
    let info = get_cache_info();
    info.l2_size
}
/// Returns the `l3_size` field of the shared cache description.
#[allow(dead_code)]
fn get_l3_cache_size() -> usize {
    let info = get_cache_info();
    info.l3_size
}