1#[cfg(feature = "no-std")]
7use alloc::alloc::{alloc, dealloc, Layout};
8#[cfg(not(feature = "no-std"))]
9use std::alloc::{alloc, dealloc, Layout};
10
11#[cfg(feature = "no-std")]
12use core::ptr::NonNull;
13#[cfg(not(feature = "no-std"))]
14use std::ptr::NonNull;
15
16#[cfg(feature = "no-std")]
17use core::{mem, slice};
18#[cfg(not(feature = "no-std"))]
19use std::{mem, slice};
20
/// Error returned when an aligned buffer cannot be allocated.
///
/// The type carries no payload: it is produced both when the requested
/// layout is invalid and when the allocator returns a null pointer.
/// It is a zero-sized marker, so it is cheap to copy and compare, which lets
/// callers write `assert_eq!(result, Err(AllocError))`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AllocError;
24
25#[cfg(feature = "no-std")]
26impl core::fmt::Display for AllocError {
27 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28 write!(f, "Memory allocation failed")
29 }
30}
31
32#[cfg(not(feature = "no-std"))]
33impl std::fmt::Display for AllocError {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 write!(f, "Memory allocation failed")
36 }
37}
38
39#[cfg(not(feature = "no-std"))]
40#[cfg(not(feature = "no-std"))]
41impl std::error::Error for AllocError {}
42
43#[cfg(feature = "no-std")]
44impl core::error::Error for AllocError {}
45
/// Assumed cache-line size in bytes.
///
/// This is a hard-coded value, not one queried from the hardware. It is the
/// byte stride `prefetch::prefetch_range` uses between prefetch requests.
pub const CACHE_LINE_SIZE: usize = 64;
/// Assumed L1 cache capacity in bytes (32 KiB, hard-coded).
///
/// `bandwidth::copy_with_prefetch` compares the destination's byte size
/// against this value to decide whether to try the streaming-store path.
pub const L1_CACHE_SIZE: usize = 32 * 1024;
/// Assumed L2 cache capacity (256 KiB, hard-coded) — presumably bytes, like
/// `L1_CACHE_SIZE`; not referenced by the code visible in this file.
pub const L2_CACHE_SIZE: usize = 256 * 1024;
/// Assumed L3 cache capacity (8 MiB, hard-coded) — presumably bytes, like
/// `L1_CACHE_SIZE`; not referenced by the code visible in this file.
pub const L3_CACHE_SIZE: usize = 8 * 1024 * 1024;
51
/// Alignment, in bytes, requested for buffers created by `AlignedAlloc`.
pub const SIMD_ALIGNMENT: usize = 32;

/// Locality hint forwarded to the prefetch instruction.
///
/// On x86 / x86_64 each variant selects the `_MM_HINT_*` constant of the
/// same name when passed to `prefetch::prefetch_read_data`.
#[derive(Debug, Clone, Copy)]
pub enum PrefetchHint {
    /// Selects `_MM_HINT_T0`.
    T0,
    /// Selects `_MM_HINT_T1`.
    T1,
    /// Selects `_MM_HINT_T2`.
    T2,
    /// Selects `_MM_HINT_NTA`.
    Nta,
}
67
/// Owning handle to a heap buffer of `len` elements of `T`, requested from
/// the global allocator with an explicit [`Layout`].
///
/// The buffer is released in the `Drop` impl; the elements themselves are
/// not dropped individually.
pub struct AlignedAlloc<T> {
    /// Start of the buffer; the slice accessors are built from this pointer.
    ptr: NonNull<T>,
    /// Layout the buffer was requested with. Deallocation must use the very
    /// same layout, so it is stored alongside the pointer.
    layout: Layout,
    /// Number of `T` elements (not bytes) in the buffer.
    len: usize,
}
74
75impl<T> AlignedAlloc<T> {
76 pub fn new(len: usize) -> Result<Self, AllocError> {
78 let layout = Layout::from_size_align(len * mem::size_of::<T>(), SIMD_ALIGNMENT)
79 .map_err(|_| AllocError)?;
80
81 let ptr = unsafe { alloc(layout) };
82 if ptr.is_null() {
83 return Err(AllocError);
84 }
85
86 Ok(Self {
87 ptr: unsafe { NonNull::new_unchecked(ptr as *mut T) },
88 layout,
89 len,
90 })
91 }
92
93 pub fn as_mut_slice(&mut self) -> &mut [T] {
95 unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
96 }
97
98 pub fn as_slice(&self) -> &[T] {
100 unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
101 }
102
103 pub fn as_ptr(&self) -> *const T {
105 self.ptr.as_ptr()
106 }
107
108 pub fn as_mut_ptr(&mut self) -> *mut T {
110 self.ptr.as_ptr()
111 }
112}
113
114impl<T> Drop for AlignedAlloc<T> {
115 fn drop(&mut self) {
116 unsafe {
117 dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
118 }
119 }
120}
121
122pub mod prefetch {
124 use super::PrefetchHint;
125
126 #[inline(always)]
128 pub fn prefetch_read_data(_address: *const u8, _hint: PrefetchHint) {
129 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
130 unsafe {
131 #[cfg(feature = "no-std")]
132 use core::arch::x86_64::*;
133 #[cfg(not(feature = "no-std"))]
134 use core::arch::x86_64::*;
135 match _hint {
136 PrefetchHint::T0 => _mm_prefetch(_address as *const i8, _MM_HINT_T0),
137 PrefetchHint::T1 => _mm_prefetch(_address as *const i8, _MM_HINT_T1),
138 PrefetchHint::T2 => _mm_prefetch(_address as *const i8, _MM_HINT_T2),
139 PrefetchHint::Nta => _mm_prefetch(_address as *const i8, _MM_HINT_NTA),
140 }
141 }
142 }
143
144 #[inline]
146 pub fn prefetch_range<T>(slice: &[T], hint: PrefetchHint) {
147 let start = slice.as_ptr() as *const u8;
148 let size = core::mem::size_of_val(slice);
149 let end = unsafe { start.add(size) };
150
151 let mut current = start;
152 while current < end {
153 prefetch_read_data(current, hint);
154 current = unsafe { current.add(super::CACHE_LINE_SIZE) };
155 }
156 }
157}
158
/// Blocked (tiled) matrix kernels and a block-size heuristic.
pub mod cache_aware {
    /// Returns `floor(sqrt(n))` using integer Newton iteration.
    ///
    /// Integer arithmetic keeps this usable without `std`, where the
    /// floating-point `sqrt` method is not available.
    fn integer_sqrt(n: usize) -> usize {
        if n < 2 {
            return n;
        }
        // Start from an over-estimate (`n / 2 + 1 >= sqrt(n)`); the iteration
        // then decreases monotonically until it reaches the floor.
        let mut estimate = n / 2 + 1;
        let mut next = (estimate + n / estimate) / 2;
        while next < estimate {
            estimate = next;
            next = (estimate + n / estimate) / 2;
        }
        estimate
    }

    /// Side length of a square block whose elements fit in `cache_size`.
    ///
    /// Computes `floor(sqrt(cache_size / element_size))`, where both sizes
    /// use the same unit. The result is never smaller than 1, so it can
    /// always be used as a block size. An `element_size` of 0 is treated
    /// as 1 instead of dividing by zero.
    pub fn optimal_block_size(cache_size: usize, element_size: usize) -> usize {
        let elements_in_cache = cache_size / element_size.max(1);
        integer_sqrt(elements_in_cache).max(1)
    }

    /// Multiplies two matrix dimensions, panicking on overflow so that the
    /// length assertions below can never be satisfied by a wrapped product.
    fn element_count(rows: usize, cols: usize) -> usize {
        rows.checked_mul(cols)
            .expect("matrix dimensions overflow usize")
    }

    /// Writes the transpose of the row-major `rows x cols` matrix `input`
    /// into `output` (row-major `cols x rows`), working block by block.
    ///
    /// A `block_size` of 0 is treated as 1.
    ///
    /// # Panics
    ///
    /// Panics if `input` or `output` does not hold exactly `rows * cols`
    /// elements.
    pub fn transpose_blocked(
        input: &[f32],
        output: &mut [f32],
        rows: usize,
        cols: usize,
        block_size: usize,
    ) {
        let total = element_count(rows, cols);
        assert_eq!(input.len(), total);
        assert_eq!(output.len(), total);

        // `step_by(0)` panics, so clamp instead of failing deep in the loop.
        let block = block_size.max(1);

        for block_row in (0..rows).step_by(block) {
            let end_row = block_row.saturating_add(block).min(rows);
            for block_col in (0..cols).step_by(block) {
                let end_col = block_col.saturating_add(block).min(cols);

                for i in block_row..end_row {
                    for j in block_col..end_col {
                        output[j * rows + i] = input[i * cols + j];
                    }
                }
            }
        }
    }

    /// Computes `c = a * b` for row-major matrices, where `a` is `m x k`,
    /// `b` is `k x n` and `c` is `m x n`, iterating over square blocks.
    ///
    /// `c` is overwritten; its previous contents are ignored. A
    /// `block_size` of 0 is treated as 1.
    ///
    /// # Panics
    ///
    /// Panics if any slice length does not match its stated dimensions.
    pub fn matrix_multiply_blocked(
        a: &[f32],
        b: &[f32],
        c: &mut [f32],
        m: usize,
        n: usize,
        k: usize,
        block_size: usize,
    ) {
        assert_eq!(a.len(), element_count(m, k));
        assert_eq!(b.len(), element_count(k, n));
        assert_eq!(c.len(), element_count(m, n));

        // Partial products are accumulated into `c`, so it must start at 0.
        c.fill(0.0);

        let block = block_size.max(1);

        for kk in (0..k).step_by(block) {
            let end_k = kk.saturating_add(block).min(k);
            for ii in (0..m).step_by(block) {
                let end_i = ii.saturating_add(block).min(m);
                for jj in (0..n).step_by(block) {
                    let end_j = jj.saturating_add(block).min(n);

                    for i in ii..end_i {
                        for j in jj..end_j {
                            // Sum this k-block locally, then fold it into
                            // the running total held in `c`.
                            let mut sum = 0.0f32;
                            for l in kk..end_k {
                                sum += a[i * k + l] * b[l * n + j];
                            }
                            c[i * n + j] += sum;
                        }
                    }
                }
            }
        }
    }
}
232
233pub mod streaming {
235 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
237 pub fn stream_store_f32(dest: &mut [f32], src: &[f32]) {
238 assert_eq!(dest.len(), src.len());
239
240 if !crate::simd_feature_detected!("sse2") {
241 dest.copy_from_slice(src);
242 return;
243 }
244
245 unsafe {
246 stream_store_sse2(dest, src);
247 }
248 }
249
250 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
251 #[target_feature(enable = "sse2")]
252 unsafe fn stream_store_sse2(dest: &mut [f32], src: &[f32]) {
253 #[cfg(feature = "no-std")]
254 use core::arch::x86_64::*;
255 #[cfg(not(feature = "no-std"))]
256 use core::arch::x86_64::*;
257
258 let mut i = 0;
259 let len = dest.len();
260
261 while i + 4 <= len {
263 let data = _mm_loadu_ps(src.as_ptr().add(i));
264 _mm_stream_ps(dest.as_mut_ptr().add(i), data);
265 i += 4;
266 }
267
268 while i < len {
270 dest[i] = src[i];
271 i += 1;
272 }
273
274 _mm_sfence();
276 }
277
278 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
279 pub fn stream_store_f32(dest: &mut [f32], src: &[f32]) {
280 dest.copy_from_slice(src);
281 }
282}
283
284pub mod bandwidth {
286 use super::{prefetch::prefetch_range, PrefetchHint};
287
288 #[cfg(not(feature = "no-std"))]
289 use std::{mem, time::Instant};
290
291 pub fn copy_with_prefetch<T: Copy>(dest: &mut [T], src: &[T]) {
293 assert_eq!(dest.len(), src.len());
294
295 prefetch_range(src, PrefetchHint::Nta);
297
298 if core::mem::size_of_val(dest) > super::L1_CACHE_SIZE {
300 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
301 if core::mem::size_of::<T>() == core::mem::size_of::<f32>() {
302 unsafe {
303 super::streaming::stream_store_f32(
304 core::slice::from_raw_parts_mut(dest.as_mut_ptr() as *mut f32, dest.len()),
305 core::slice::from_raw_parts(src.as_ptr() as *const f32, src.len()),
306 );
307 }
308 return;
309 }
310 }
311
312 dest.copy_from_slice(src);
313 }
314
315 #[cfg(not(feature = "no-std"))]
317 pub fn measure_bandwidth() -> f64 {
318 const SIZE: usize = 1024 * 1024; let src = vec![1.0f32; SIZE];
320 let mut dest = vec![0.0f32; SIZE];
321
322 let start = Instant::now();
323 for _ in 0..100 {
324 copy_with_prefetch(&mut dest, &src);
325 }
326 let elapsed = start.elapsed();
327
328 let bytes_transferred = SIZE * mem::size_of::<f32>() * 100 * 2; bytes_transferred as f64 / elapsed.as_secs_f64() / (1024.0 * 1024.0 * 1024.0)
330 }
332
333 #[cfg(feature = "no-std")]
335 pub fn measure_bandwidth() -> f64 {
336 1.0 }
339}
340
// Unit tests; they rely on `std` (`vec!`, `println!`) and are therefore
// only built when the `no-std` feature is disabled.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// The buffer honours `SIMD_ALIGNMENT` and is writable end to end.
    #[test]
    fn test_aligned_alloc() {
        let mut alloc = AlignedAlloc::<f32>::new(1024).expect("operation should succeed");
        let slice = alloc.as_mut_slice();

        assert_eq!(slice.as_ptr() as usize % SIMD_ALIGNMENT, 0);

        // Touch the first and the last element to cover both ends.
        slice[0] = 1.0;
        slice[1023] = 2.0;
        assert_eq!(slice[0], 1.0);
        assert_eq!(slice[1023], 2.0);
    }

    /// A 64 x 64 matrix transposed in 16 x 16 blocks matches the definition.
    #[test]
    fn test_cache_aware_transpose() {
        let rows = 64;
        let cols = 64;

        // Each element stores its own row-major index.
        let input: Vec<f32> = (0..rows * cols).map(|index| index as f32).collect();
        let mut output = vec![0.0f32; rows * cols];

        cache_aware::transpose_blocked(&input, &mut output, rows, cols, 16);

        for i in 0..rows {
            for j in 0..cols {
                assert_relative_eq!(output[j * rows + i], input[i * cols + j], epsilon = 1e-6);
            }
        }
    }

    /// All-ones matrices multiply to a matrix whose entries all equal `k`.
    #[test]
    fn test_cache_aware_matrix_multiply() {
        let m = 32;
        let n = 32;
        let k = 32;

        let a = vec![1.0f32; m * k];
        let b = vec![1.0f32; k * n];
        let mut c = vec![0.0f32; m * n];

        cache_aware::matrix_multiply_blocked(&a, &b, &mut c, m, n, k, 16);

        for &val in &c {
            assert_relative_eq!(val, k as f32, epsilon = 1e-6);
        }
    }

    /// The streaming copy reproduces its source.
    #[test]
    fn test_stream_store() {
        let src = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut dest = vec![0.0f32; 8];

        streaming::stream_store_f32(&mut dest, &src);

        for (copied, original) in dest.iter().zip(&src) {
            assert_relative_eq!(*copied, *original, epsilon = 1e-6);
        }
    }

    /// The measurement yields a positive throughput figure.
    #[test]
    fn test_bandwidth_measurement() {
        let bandwidth = bandwidth::measure_bandwidth();
        assert!(bandwidth > 0.0);
        println!("Measured bandwidth: {:.2} GB/s", bandwidth);
    }

    /// The heuristic returns a plausible block size for the L1 cache.
    #[test]
    fn test_optimal_block_size() {
        let block_size = cache_aware::optimal_block_size(L1_CACHE_SIZE, 4);
        assert!(block_size > 0);
        assert!(block_size < 1000);
    }

    /// Small inputs are copied exactly.
    #[test]
    fn test_copy_with_prefetch() {
        let src = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let mut dest = vec![0.0f32; 5];

        bandwidth::copy_with_prefetch(&mut dest, &src);

        for (copied, original) in dest.iter().zip(&src) {
            assert_relative_eq!(*copied, *original, epsilon = 1e-6);
        }
    }
}