Skip to main content

sklears_simd/
memory.rs

1//! Memory optimization utilities for SIMD operations
2//!
3//! This module provides cache-aware algorithms, memory prefetching,
4//! and aligned memory operations to improve SIMD performance.
5
6#[cfg(feature = "no-std")]
7use alloc::alloc::{alloc, dealloc, Layout};
8#[cfg(not(feature = "no-std"))]
9use std::alloc::{alloc, dealloc, Layout};
10
11#[cfg(feature = "no-std")]
12use core::ptr::NonNull;
13#[cfg(not(feature = "no-std"))]
14use std::ptr::NonNull;
15
16#[cfg(feature = "no-std")]
17use core::{mem, slice};
18#[cfg(not(feature = "no-std"))]
19use std::{mem, slice};
20
/// Error returned when a SIMD-aligned allocation cannot be satisfied.
///
/// Produced by `AlignedAlloc::new` when the requested size/alignment cannot be expressed as a
/// valid `Layout`, or when the global allocator returns a null pointer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AllocError;

// `core::fmt` is available in both `std` and `no-std` builds, so one impl serves both.
impl core::fmt::Display for AllocError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Memory allocation failed")
    }
}

#[cfg(not(feature = "no-std"))]
impl std::error::Error for AllocError {}

#[cfg(feature = "no-std")]
impl core::error::Error for AllocError {}
45
/// Assumed cache line size in bytes.
///
/// This and the cache-size constants below are fixed compile-time values; they are not
/// detected from the CPU the code runs on.
pub const CACHE_LINE_SIZE: usize = 64;
/// Assumed L1 data cache size in bytes (32 KiB).
pub const L1_CACHE_SIZE: usize = 32 * 1024;
/// Assumed L2 cache size in bytes (256 KiB).
pub const L2_CACHE_SIZE: usize = 256 * 1024;
/// Assumed L3 cache size in bytes (8 MiB).
pub const L3_CACHE_SIZE: usize = 8 * 1024 * 1024;

/// Minimum alignment in bytes for SIMD buffers (32 bytes, matching AVX2 register width).
pub const SIMD_ALIGNMENT: usize = 32;
54
/// Locality hint for software prefetching.
///
/// On x86/x86-64 each variant selects the matching `_MM_HINT_*` constant passed to
/// `_mm_prefetch`; on other targets prefetching is a no-op and the hint is ignored.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PrefetchHint {
    /// Prefetch for read with temporal locality (`_MM_HINT_T0`).
    T0,
    /// Prefetch for read with low temporal locality (`_MM_HINT_T1`).
    T1,
    /// Prefetch for read with minimal temporal locality (`_MM_HINT_T2`).
    T2,
    /// Prefetch for read with no temporal locality (`_MM_HINT_NTA`).
    Nta,
}
67
68/// SIMD-aligned memory allocator
69pub struct AlignedAlloc<T> {
70    ptr: NonNull<T>,
71    layout: Layout,
72    len: usize,
73}
74
75impl<T> AlignedAlloc<T> {
76    /// Allocate aligned memory for SIMD operations
77    pub fn new(len: usize) -> Result<Self, AllocError> {
78        let layout = Layout::from_size_align(len * mem::size_of::<T>(), SIMD_ALIGNMENT)
79            .map_err(|_| AllocError)?;
80
81        let ptr = unsafe { alloc(layout) };
82        if ptr.is_null() {
83            return Err(AllocError);
84        }
85
86        Ok(Self {
87            ptr: unsafe { NonNull::new_unchecked(ptr as *mut T) },
88            layout,
89            len,
90        })
91    }
92
93    /// Get a mutable slice to the aligned memory
94    pub fn as_mut_slice(&mut self) -> &mut [T] {
95        unsafe { slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
96    }
97
98    /// Get a slice to the aligned memory
99    pub fn as_slice(&self) -> &[T] {
100        unsafe { slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
101    }
102
103    /// Get the raw pointer
104    pub fn as_ptr(&self) -> *const T {
105        self.ptr.as_ptr()
106    }
107
108    /// Get the raw mutable pointer
109    pub fn as_mut_ptr(&mut self) -> *mut T {
110        self.ptr.as_ptr()
111    }
112}
113
114impl<T> Drop for AlignedAlloc<T> {
115    fn drop(&mut self) {
116        unsafe {
117            dealloc(self.ptr.as_ptr() as *mut u8, self.layout);
118        }
119    }
120}
121
122/// Memory prefetch operations
123pub mod prefetch {
124    use super::PrefetchHint;
125
126    /// Prefetch memory to cache with specified hint
127    #[inline(always)]
128    pub fn prefetch_read_data(_address: *const u8, _hint: PrefetchHint) {
129        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
130        unsafe {
131            #[cfg(feature = "no-std")]
132            use core::arch::x86_64::*;
133            #[cfg(not(feature = "no-std"))]
134            use core::arch::x86_64::*;
135            match _hint {
136                PrefetchHint::T0 => _mm_prefetch(_address as *const i8, _MM_HINT_T0),
137                PrefetchHint::T1 => _mm_prefetch(_address as *const i8, _MM_HINT_T1),
138                PrefetchHint::T2 => _mm_prefetch(_address as *const i8, _MM_HINT_T2),
139                PrefetchHint::Nta => _mm_prefetch(_address as *const i8, _MM_HINT_NTA),
140            }
141        }
142    }
143
144    /// Prefetch multiple cache lines for a memory range
145    #[inline]
146    pub fn prefetch_range<T>(slice: &[T], hint: PrefetchHint) {
147        let start = slice.as_ptr() as *const u8;
148        let size = core::mem::size_of_val(slice);
149        let end = unsafe { start.add(size) };
150
151        let mut current = start;
152        while current < end {
153            prefetch_read_data(current, hint);
154            current = unsafe { current.add(super::CACHE_LINE_SIZE) };
155        }
156    }
157}
158
/// Cache-aware operations on dense row-major `f32` matrices.
pub mod cache_aware {

    /// `floor(sqrt(n))` computed exactly in integer arithmetic (works without `std`).
    fn integer_sqrt(n: usize) -> usize {
        if n < 2 {
            return n;
        }
        // Newton's method from an over-estimate decreases monotonically to floor(sqrt(n));
        // starting at n / 2 + 1 keeps `x + n / x` from overflowing.
        let mut x = n / 2 + 1;
        let mut y = (x + n / x) / 2;
        while y < x {
            x = y;
            y = (x + n / x) / 2;
        }
        x
    }

    /// Suggest a square tile edge length for cache-blocked matrix kernels.
    ///
    /// Returns `floor(sqrt(cache_size / element_size))`: the edge of the largest square tile of
    /// `element_size`-byte elements that fits in `cache_size` bytes. The result is clamped to
    /// at least 1 so it is always a usable block size. An `element_size` of 0 is treated as 1
    /// rather than dividing by zero.
    pub fn optimal_block_size(cache_size: usize, element_size: usize) -> usize {
        let elements_in_cache = cache_size / element_size.max(1);
        integer_sqrt(elements_in_cache).max(1)
    }

    /// Transpose a row-major `rows x cols` matrix into a row-major `cols x rows` matrix.
    ///
    /// The matrix is processed in `block_size x block_size` tiles so the strided writes to
    /// `output` stay within a small working set.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows, if `input` or `output` does not hold exactly
    /// `rows * cols` elements, or if `block_size` is zero.
    pub fn transpose_blocked(
        input: &[f32],
        output: &mut [f32],
        rows: usize,
        cols: usize,
        block_size: usize,
    ) {
        let total = rows.checked_mul(cols).expect("rows * cols overflows usize");
        assert_eq!(input.len(), total);
        assert_eq!(output.len(), total);
        assert!(block_size > 0, "block_size must be non-zero");

        for block_row in (0..rows).step_by(block_size) {
            let end_row = block_row.saturating_add(block_size).min(rows);
            for block_col in (0..cols).step_by(block_size) {
                let end_col = block_col.saturating_add(block_size).min(cols);

                for i in block_row..end_row {
                    for j in block_col..end_col {
                        output[j * rows + i] = input[i * cols + j];
                    }
                }
            }
        }
    }

    /// Compute `c = a * b` for row-major `a` (`m x k`) and `b` (`k x n`) using cache blocking.
    ///
    /// `c` (`m x n`) is overwritten: it is zeroed first, then partial products from each
    /// `k`-tile are accumulated into it.
    ///
    /// # Panics
    ///
    /// Panics if any of `m * k`, `k * n`, `m * n` overflows or does not match the length of the
    /// corresponding slice, or if `block_size` is zero.
    pub fn matrix_multiply_blocked(
        a: &[f32],
        b: &[f32],
        c: &mut [f32],
        m: usize,
        n: usize,
        k: usize,
        block_size: usize,
    ) {
        assert_eq!(a.len(), m.checked_mul(k).expect("m * k overflows usize"));
        assert_eq!(b.len(), k.checked_mul(n).expect("k * n overflows usize"));
        assert_eq!(c.len(), m.checked_mul(n).expect("m * n overflows usize"));
        assert!(block_size > 0, "block_size must be non-zero");

        c.fill(0.0);

        for kk in (0..k).step_by(block_size) {
            let end_k = kk.saturating_add(block_size).min(k);
            for ii in (0..m).step_by(block_size) {
                let end_i = ii.saturating_add(block_size).min(m);
                for jj in (0..n).step_by(block_size) {
                    let end_j = jj.saturating_add(block_size).min(n);

                    for i in ii..end_i {
                        for j in jj..end_j {
                            // Sum this k-tile's contribution before touching `c`, so each
                            // output element is updated once per tile.
                            let mut sum = 0.0;
                            for l in kk..end_k {
                                sum += a[i * k + l] * b[l * n + j];
                            }
                            c[i * n + j] += sum;
                        }
                    }
                }
            }
        }
    }
}
232
233/// Non-temporal store operations for streaming data
234pub mod streaming {
235    /// Non-temporal store for f32 arrays (bypasses cache)
236    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
237    pub fn stream_store_f32(dest: &mut [f32], src: &[f32]) {
238        assert_eq!(dest.len(), src.len());
239
240        if !crate::simd_feature_detected!("sse2") {
241            dest.copy_from_slice(src);
242            return;
243        }
244
245        unsafe {
246            stream_store_sse2(dest, src);
247        }
248    }
249
250    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
251    #[target_feature(enable = "sse2")]
252    unsafe fn stream_store_sse2(dest: &mut [f32], src: &[f32]) {
253        #[cfg(feature = "no-std")]
254        use core::arch::x86_64::*;
255        #[cfg(not(feature = "no-std"))]
256        use core::arch::x86_64::*;
257
258        let mut i = 0;
259        let len = dest.len();
260
261        // Process 4 elements at a time with non-temporal stores
262        while i + 4 <= len {
263            let data = _mm_loadu_ps(src.as_ptr().add(i));
264            _mm_stream_ps(dest.as_mut_ptr().add(i), data);
265            i += 4;
266        }
267
268        // Handle remaining elements
269        while i < len {
270            dest[i] = src[i];
271            i += 1;
272        }
273
274        // Ensure all stores are complete
275        _mm_sfence();
276    }
277
278    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
279    pub fn stream_store_f32(dest: &mut [f32], src: &[f32]) {
280        dest.copy_from_slice(src);
281    }
282}
283
284/// Memory bandwidth optimization utilities
285pub mod bandwidth {
286    use super::{prefetch::prefetch_range, PrefetchHint};
287
288    #[cfg(not(feature = "no-std"))]
289    use std::{mem, time::Instant};
290
291    /// Bandwidth-optimized vector copy with prefetching
292    pub fn copy_with_prefetch<T: Copy>(dest: &mut [T], src: &[T]) {
293        assert_eq!(dest.len(), src.len());
294
295        // Prefetch the source data
296        prefetch_range(src, PrefetchHint::Nta);
297
298        // Use streaming store for large arrays
299        if core::mem::size_of_val(dest) > super::L1_CACHE_SIZE {
300            #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
301            if core::mem::size_of::<T>() == core::mem::size_of::<f32>() {
302                unsafe {
303                    super::streaming::stream_store_f32(
304                        core::slice::from_raw_parts_mut(dest.as_mut_ptr() as *mut f32, dest.len()),
305                        core::slice::from_raw_parts(src.as_ptr() as *const f32, src.len()),
306                    );
307                }
308                return;
309            }
310        }
311
312        dest.copy_from_slice(src);
313    }
314
315    /// Memory bandwidth test for performance tuning
316    #[cfg(not(feature = "no-std"))]
317    pub fn measure_bandwidth() -> f64 {
318        const SIZE: usize = 1024 * 1024; // 1MB
319        let src = vec![1.0f32; SIZE];
320        let mut dest = vec![0.0f32; SIZE];
321
322        let start = Instant::now();
323        for _ in 0..100 {
324            copy_with_prefetch(&mut dest, &src);
325        }
326        let elapsed = start.elapsed();
327
328        let bytes_transferred = SIZE * mem::size_of::<f32>() * 100 * 2; // read + write
329        bytes_transferred as f64 / elapsed.as_secs_f64() / (1024.0 * 1024.0 * 1024.0)
330        // GB/s
331    }
332
333    /// Memory bandwidth test for performance tuning (no-std version)
334    #[cfg(feature = "no-std")]
335    pub fn measure_bandwidth() -> f64 {
336        // Return a mock value for no-std environments where timing is not available
337        1.0 // GB/s
338    }
339}
340
// Unit tests; they need `std` (and the `approx` dev-dependency), so they are compiled only
// when the `no-std` feature is off.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_aligned_alloc() {
        let mut alloc = AlignedAlloc::<f32>::new(1024).expect("operation should succeed");
        let slice = alloc.as_mut_slice();

        // Check alignment
        assert_eq!(slice.as_ptr() as usize % SIMD_ALIGNMENT, 0);

        // Test basic operations
        slice[0] = 1.0;
        slice[1023] = 2.0;
        assert_eq!(slice[0], 1.0);
        assert_eq!(slice[1023], 2.0);
    }

    #[test]
    fn test_cache_aware_transpose() {
        let rows = 64;
        let cols = 64;
        let mut input = vec![0.0f32; rows * cols];
        let mut output = vec![0.0f32; rows * cols];

        // Initialize input matrix
        for i in 0..rows {
            for j in 0..cols {
                input[i * cols + j] = (i * cols + j) as f32;
            }
        }

        cache_aware::transpose_blocked(&input, &mut output, rows, cols, 16);

        // Verify transpose
        for i in 0..rows {
            for j in 0..cols {
                assert_relative_eq!(output[j * rows + i], input[i * cols + j], epsilon = 1e-6);
            }
        }
    }

    #[test]
    fn test_cache_aware_matrix_multiply() {
        let m = 32;
        let n = 32;
        let k = 32;

        let a = vec![1.0f32; m * k];
        let b = vec![1.0f32; k * n];
        let mut c = vec![0.0f32; m * n];

        cache_aware::matrix_multiply_blocked(&a, &b, &mut c, m, n, k, 16);

        // Verify result (each element should be k)
        for &val in &c {
            assert_relative_eq!(val, k as f32, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_stream_store() {
        let src = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut dest = vec![0.0f32; 8];

        streaming::stream_store_f32(&mut dest, &src);

        for (i, &val) in dest.iter().enumerate() {
            assert_relative_eq!(val, src[i], epsilon = 1e-6);
        }
    }

    #[test]
    fn test_bandwidth_measurement() {
        let bandwidth = bandwidth::measure_bandwidth();
        // Just check that bandwidth measurement runs and returns positive value
        assert!(bandwidth > 0.0);
        println!("Measured bandwidth: {:.2} GB/s", bandwidth);
    }

    #[test]
    fn test_optimal_block_size() {
        let block_size = cache_aware::optimal_block_size(L1_CACHE_SIZE, 4);
        assert!(block_size > 0);
        assert!(block_size < 1000); // Reasonable upper bound
    }

    #[test]
    fn test_copy_with_prefetch() {
        let src = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let mut dest = vec![0.0f32; 5];

        bandwidth::copy_with_prefetch(&mut dest, &src);

        for (i, &val) in dest.iter().enumerate() {
            assert_relative_eq!(val, src[i], epsilon = 1e-6);
        }
    }
}