numrs2/memory_optimize/
cache_layout.rs1#[cfg(target_arch = "x86_64")]
7use std::arch::x86_64::__cpuid;
8use std::cmp;
9use std::mem;
10use std::ptr;
11
/// Description of the CPU cache hierarchy used to pick layout parameters.
///
/// Values are either detected via CPUID on x86_64 or taken from the
/// hard-coded defaults in `detect_cache_info`.
#[derive(Debug, Clone)]
struct CacheInfo {
    /// Cache line size in bytes (default 64).
    line_size: usize,
    /// L1 cache size in bytes (default 32 KiB).
    l1_size: usize,
    /// L2 cache size in bytes (default 256 KiB).
    l2_size: usize,
    /// L3 cache size in bytes (default 8 MiB).
    l3_size: usize,
    // Written by the detection routines but not read anywhere in the visible code.
    #[allow(dead_code)]
    /// Number of ways reported for the L1 cache (default 8).
    associativity: usize,
}
22
// Process-wide cache description shared by every helper in this module.
// `lazy_static` evaluates `detect_cache_info()` on first access and hands out
// the same value on every later access.
lazy_static::lazy_static! {
    static ref CACHE_DATA: CacheInfo = detect_cache_info();
}
26
27#[derive(Debug, Copy, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
29pub enum LayoutStrategy {
30 RowMajor,
32 ColumnMajor,
34 Morton,
36 Hilbert,
38 CacheOblivious,
40 Blocked(usize), }
43
44pub fn optimize_layout<T: Copy>(data: &mut [T], strategy: LayoutStrategy) {
55 match strategy {
56 LayoutStrategy::RowMajor => {
57 align_for_cache_line(data);
60 }
61 LayoutStrategy::ColumnMajor => {
62 optimize_for_column_access(data);
66 }
67 LayoutStrategy::Morton => {
68 apply_morton_order(data);
70 }
71 LayoutStrategy::Hilbert => {
72 apply_hilbert_order(data);
74 }
75 LayoutStrategy::CacheOblivious => {
76 apply_cache_oblivious_layout(data);
78 }
79 LayoutStrategy::Blocked(block_size) => {
80 apply_blocked_layout(data, block_size);
82 }
83 }
84}
85
/// Row-major "alignment" step of the layout optimizer.
///
/// This is intentionally a no-op. The storage behind a borrowed slice cannot
/// be relocated, so its address alignment is fixed by whoever allocated it.
/// The previous implementation tried to "shift" the elements forward with
/// `ptr::copy`, which wrote `len` elements starting `shift` bytes into the
/// slice. That runs past the end of the borrowed buffer, which is undefined
/// behaviour. Callers that need cache-line-aligned storage must allocate it
/// with the required alignment up front.
///
/// The contents of `data` are never modified.
fn align_for_cache_line<T: Copy>(data: &mut [T]) {
    // Nothing can be done soundly for an already-allocated slice.
    let _ = data;
}
115
116fn get_cache_line_size() -> usize {
121 get_cache_info().line_size
122}
123
124fn get_cache_info() -> &'static CacheInfo {
126 &CACHE_DATA
127}
128
129fn detect_cache_info() -> CacheInfo {
131 #[cfg(target_arch = "x86_64")]
132 {
133 detect_x86_cache_info()
134 }
135
136 #[cfg(not(target_arch = "x86_64"))]
137 {
138 CacheInfo {
140 line_size: 64,
141 l1_size: 32 * 1024,
142 l2_size: 256 * 1024,
143 l3_size: 8 * 1024 * 1024,
144 associativity: 8,
145 }
146 }
147}
148
149#[cfg(target_arch = "x86_64")]
150fn detect_x86_cache_info() -> CacheInfo {
151 #[cfg(target_arch = "x86_64")]
152 use std::arch::x86_64::__cpuid;
153
154 let mut info = CacheInfo {
155 line_size: 64,
156 l1_size: 32 * 1024,
157 l2_size: 256 * 1024,
158 l3_size: 8 * 1024 * 1024,
159 associativity: 8,
160 };
161
162 let cpuid_result = __cpuid(0x80000000);
165 if cpuid_result.eax >= 0x80000006 {
166 let cache_result = __cpuid(0x80000006);
167
168 info.l1_size = ((cache_result.ecx >> 24) & 0xFF) as usize * 1024;
170 info.line_size = (cache_result.ecx & 0xFF) as usize;
171 info.associativity = ((cache_result.ecx >> 16) & 0xFF) as usize;
172
173 info.l2_size = ((cache_result.ecx >> 16) & 0xFFFF) as usize * 1024;
175
176 info.l3_size = ((cache_result.edx >> 18) & 0x3FFF) as usize * 512 * 1024;
178 }
179
180 let vendor_result = __cpuid(0);
182 if vendor_result.ebx == 0x756e6547 && vendor_result.edx == 0x49656e69 && vendor_result.ecx == 0x6c65746e
185 {
186 detect_intel_cache_info(&mut info);
188 }
189
190 if vendor_result.ebx == 0x68747541 && vendor_result.edx == 0x69746e65 && vendor_result.ecx == 0x444d4163
194 {
195 detect_amd_cache_info(&mut info);
197 }
198
199 info
200}
201
202#[cfg(target_arch = "x86_64")]
203fn detect_intel_cache_info(info: &mut CacheInfo) {
204 unsafe {
206 let mut cache_level = 0;
207 loop {
208 let cache_info = __cpuid_count(4, cache_level);
209
210 if cache_info.eax & 0x1F == 0 {
212 break;
213 }
214
215 let cache_type = cache_info.eax & 0x1F;
216 let level = (cache_info.eax >> 5) & 0x7;
217 let line_size = ((cache_info.ebx & 0xFFF) + 1) as usize;
218 let partitions = (((cache_info.ebx >> 12) & 0x3FF) + 1) as usize;
219 let ways = (((cache_info.ebx >> 22) & 0x3FF) + 1) as usize;
220 let sets = (cache_info.ecx + 1) as usize;
221
222 let size = line_size * partitions * ways * sets;
223
224 if cache_type == 1 || cache_type == 3 {
226 match level {
227 1 => {
228 info.l1_size = size;
229 info.line_size = line_size;
230 info.associativity = ways;
231 }
232 2 => info.l2_size = size,
233 3 => info.l3_size = size,
234 _ => {}
235 }
236 }
237
238 cache_level += 1;
239 if cache_level > 10 {
240 break;
242 }
243 }
244 }
245}
246
247#[cfg(target_arch = "x86_64")]
248fn detect_amd_cache_info(info: &mut CacheInfo) {
249 let l1_info = __cpuid(0x80000005);
253 info.l1_size = ((l1_info.ecx >> 24) & 0xFF) as usize * 1024;
254 info.line_size = (l1_info.ecx & 0xFF) as usize;
255 info.associativity = ((l1_info.ecx >> 16) & 0xFF) as usize;
256
257 let l23_info = __cpuid(0x80000006);
259 info.l2_size = ((l23_info.ecx >> 16) & 0xFFFF) as usize * 1024;
260 info.l3_size = ((l23_info.edx >> 18) & 0x3FFF) as usize * 512 * 1024;
261}
262
/// Executes `cpuid` for `leaf` / `sub_leaf` and returns the four result registers.
///
/// `rbx` is saved and restored around the instruction; its value is copied
/// into a compiler-chosen scratch register before being restored.
///
/// # Safety
///
/// Runs inline assembly; the caller must be on a CPU that supports `cpuid`.
#[cfg(target_arch = "x86_64")]
unsafe fn __cpuid_count(leaf: u32, sub_leaf: u32) -> std::arch::x86_64::CpuidResult {
    // EAX and ECX carry the inputs; all four registers carry outputs.
    let (mut a, mut c, mut d) = (leaf, sub_leaf, 0);
    let b: u32;

    std::arch::asm!(
        "push rbx",
        "cpuid",
        "mov {0:e}, ebx",
        "pop rbx",
        out(reg) b,
        inout("eax") a,
        inout("ecx") c,
        inout("edx") d,
    );

    std::arch::x86_64::CpuidResult {
        eax: a,
        ebx: b,
        ecx: c,
        edx: d,
    }
}
284
285pub fn calculate_optimal_block_size<T>() -> usize {
290 let l1_cache_size = get_l1_cache_size();
292 let type_size = mem::size_of::<T>();
293
294 let elements_per_cache = l1_cache_size / type_size;
297 let block_size = (elements_per_cache as f64).sqrt() as usize;
298
299 block_size.clamp(1, 1024)
301}
302
303fn optimize_for_column_access<T: Copy>(data: &mut [T]) {
305 prefetch_data_pattern(data, get_cache_line_size());
309}
310
311fn apply_morton_order<T: Copy>(data: &mut [T]) {
313 let len = data.len();
314 if len < 4 {
315 return; }
317
318 let side = (len as f64).sqrt() as usize;
321 if side * side != len {
322 apply_blocked_layout(data, calculate_optimal_block_size::<T>());
324 return;
325 }
326
327 let mut temp = vec![data[0]; len];
329
330 for (i, temp_item) in temp.iter_mut().enumerate().take(len) {
332 let (x, y) = morton_decode(i, side);
333 if x < side && y < side {
334 let linear_index = y * side + x;
335 if linear_index < len {
336 *temp_item = data[linear_index];
337 }
338 }
339 }
340
341 data.copy_from_slice(&temp);
343}
344
345fn apply_hilbert_order<T: Copy>(data: &mut [T]) {
347 let len = data.len();
348 if len < 4 {
349 return; }
351
352 let side = (len as f64).sqrt() as usize;
354 if side * side != len || !side.is_power_of_two() {
355 apply_morton_order(data);
357 return;
358 }
359
360 let mut temp = vec![data[0]; len];
362
363 for (i, temp_item) in temp.iter_mut().enumerate().take(len) {
365 let (x, y) = hilbert_decode(i, side);
366 if x < side && y < side {
367 let linear_index = y * side + x;
368 if linear_index < len {
369 *temp_item = data[linear_index];
370 }
371 }
372 }
373
374 data.copy_from_slice(&temp);
376}
377
378fn apply_cache_oblivious_layout<T: Copy>(data: &mut [T]) {
380 if data.len() <= get_cache_line_size() / mem::size_of::<T>() {
381 return; }
383
384 cache_oblivious_recursive(data, 0, data.len());
386}
387
388fn cache_oblivious_recursive<T: Copy>(data: &mut [T], start: usize, end: usize) {
390 let len = end - start;
391 if len <= 1 {
392 return;
393 }
394
395 let cache_size = get_cache_info().l1_size / mem::size_of::<T>();
396 if len <= cache_size {
397 return; }
399
400 let mid = start + len / 2;
402 cache_oblivious_recursive(data, start, mid);
403 cache_oblivious_recursive(data, mid, end);
404
405 interleave_data(&mut data[start..end]);
407}
408
/// Reorders `data`, interpreted as a row-major square grid, so that each
/// `block_size` x `block_size` tile becomes contiguous (tiles in row-major
/// order, elements inside a tile in row-major order).
///
/// The slice is left untouched when:
/// * it is empty or `block_size` is zero (previously these panicked on
///   `data[0]` or `step_by(0)`),
/// * it holds fewer than `block_size * block_size` elements (the product is
///   computed with overflow checking), or
/// * its length is not a perfect square.
///
/// Tiles at the right and bottom edge are truncated when the side is not a
/// multiple of `block_size`.
fn apply_blocked_layout<T: Copy>(data: &mut [T], block_size: usize) {
    let len = data.len();
    if len == 0 || block_size == 0 {
        return;
    }

    // An overflowing tile area is necessarily larger than `len`.
    match block_size.checked_mul(block_size) {
        Some(tile_area) if len >= tile_area => {}
        _ => return,
    }

    let side = (len as f64).sqrt() as usize;
    if side * side != len {
        return;
    }

    let mut blocked: Vec<T> = Vec::with_capacity(len);

    for block_row in (0..side).step_by(block_size) {
        let row_end = cmp::min(block_row + block_size, side);
        for block_col in (0..side).step_by(block_size) {
            let col_end = cmp::min(block_col + block_size, side);

            // Each row of the tile is a contiguous run in the source grid.
            for row in block_row..row_end {
                let row_start = row * side;
                blocked.extend_from_slice(&data[row_start + block_col..row_start + col_end]);
            }
        }
    }

    // The tiles partition the grid, so every element was copied exactly once.
    debug_assert_eq!(blocked.len(), len);
    data.copy_from_slice(&blocked);
}
447
/// Issues a prefetch hint for the start of every cache-line-sized stride of
/// `data` after the first one.
///
/// `cache_line_size` is in bytes. The stride is
/// `cache_line_size / size_of::<T>()` elements, but never less than one.
/// Previously an element larger than a cache line (or a zero line size) gave
/// a stride of zero and `step_by(0)` panicked. Zero-sized `T` returns
/// immediately, because there is nothing to prefetch and the division would
/// otherwise be by zero.
///
/// The hints are only emitted on x86_64; elsewhere the loop does nothing.
/// `data` is never modified.
fn prefetch_data_pattern<T: Copy>(data: &mut [T], cache_line_size: usize) {
    let elem_size = mem::size_of::<T>();
    if elem_size == 0 {
        return;
    }

    let elements_per_line = (cache_line_size / elem_size).max(1);

    // Same targets as before: indices `k * elements_per_line` for `k >= 1`
    // that are strictly inside the slice.
    for next in (elements_per_line..data.len()).step_by(elements_per_line) {
        #[cfg(target_arch = "x86_64")]
        // SAFETY: `next < data.len()`, so the pointer stays inside the slice;
        // the prefetch itself is only a hint.
        unsafe {
            let ptr = data.as_ptr().add(next);
            std::arch::x86_64::_mm_prefetch(
                ptr as *const i8,
                std::arch::x86_64::_MM_HINT_T0,
            );
        }

        #[cfg(not(target_arch = "x86_64"))]
        let _ = next;
    }
}
469
/// De-interleaves the bits of a Morton code into `(x, y)`.
///
/// Even-numbered bits of `morton` form `x`, odd-numbered bits form `y`.
/// At most 16 bit pairs (the low 32 bits) are consumed, and both coordinates
/// are reduced modulo `side` before being returned.
fn morton_decode(morton: usize, side: usize) -> (usize, usize) {
    let (mut col, mut row) = (0usize, 0usize);
    let mut rest = morton;

    for pair in 0..16u32 {
        if rest == 0 {
            break;
        }
        col |= (rest & 1) << pair;
        row |= ((rest >> 1) & 1) << pair;
        rest >>= 2;
    }

    (col % side, row % side)
}
493
/// Converts distance `h` along a Hilbert curve into `(x, y)` on an `n` x `n` grid.
///
/// The loop doubles the working scale from 1 until it reaches `n`, consuming
/// two bits of `h` per step and rotating or flipping the partial coordinates
/// as required. Both coordinates are reduced modulo `n` before being returned.
fn hilbert_decode(h: usize, n: usize) -> (usize, usize) {
    let (mut col, mut row) = (0usize, 0usize);
    let mut remaining = h;
    let mut scale = 1usize;

    while scale < n {
        let rx = (remaining / 2) & 1;
        let ry = (remaining ^ rx) & 1;

        // Rotate the sub-square when the quadrant lies in the lower half.
        if ry == 0 {
            if rx == 1 {
                col = scale - 1 - col;
                row = scale - 1 - row;
            }
            (col, row) = (row, col);
        }

        col += scale * rx;
        row += scale * ry;
        remaining /= 4;
        scale *= 2;
    }

    (col % n, row % n)
}
523
/// Interleaves the first and second half of `data` in place.
///
/// With `half = len / 2`, the result is
/// `data[0], data[half], data[1], data[half + 1], ...`; when the length is
/// odd the final element stays in the last position. Slices shorter than two
/// elements are returned unchanged.
fn interleave_data<T: Copy>(data: &mut [T]) {
    let total = data.len();
    if total < 2 {
        return;
    }

    let half = total / 2;
    let mut shuffled: Vec<T> = Vec::with_capacity(total);

    let (front, back) = (&data[..half], &data[half..2 * half]);
    for (lo, hi) in front.iter().zip(back) {
        shuffled.push(*lo);
        shuffled.push(*hi);
    }

    // An odd length leaves one trailing element that keeps its place.
    if total % 2 == 1 {
        shuffled.push(data[total - 1]);
    }

    data.copy_from_slice(&shuffled);
}
549
550fn get_l1_cache_size() -> usize {
552 get_cache_info().l1_size
553}
554
555#[allow(dead_code)]
557fn get_l2_cache_size() -> usize {
558 get_cache_info().l2_size
559}
560
561#[allow(dead_code)]
563fn get_l3_cache_size() -> usize {
564 get_cache_info().l3_size
565}
566
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache detection must complete and report non-zero sizes for the line
    /// size and all three cache levels.
    #[test]
    fn test_issue_11_cpuid_safe() {
        let CacheInfo {
            line_size,
            l1_size,
            l2_size,
            l3_size,
            ..
        } = detect_cache_info();

        assert!(line_size > 0, "cache line size should be non-zero");
        assert!(l1_size > 0, "L1 cache size should be non-zero");
        assert!(l2_size > 0, "L2 cache size should be non-zero");
        assert!(l3_size > 0, "L3 cache size should be non-zero");
    }
}