kizzasi_core/optimizations.rs

//! Performance optimizations for kizzasi-core
//!
//! This module provides optimizations based on profiling results:
//! 1. Allocation reduction through object pooling
//! 2. Cache-friendly data layouts
//! 3. Instruction-level parallelism
//! 4. Prefetching strategies

#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

use scirs2_core::ndarray::{Array1, Array2};
use std::cell::RefCell;

/// Cache for discretized SSM matrices to avoid recomputation
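///
/// # Example
///
/// A minimal usage sketch showing the cache hit/miss behavior; the
/// `kizzasi_core::optimizations` import path is assumed from this file's
/// location, and the dimensions are arbitrary:
///
/// ```
/// use kizzasi_core::optimizations::DiscretizationCache;
/// use scirs2_core::ndarray::Array2;
///
/// let mut cache = DiscretizationCache::new(2, 64, 8);
/// assert!(cache.get(0, 0.1).is_none());
///
/// // Store discretized matrices for layer 0 at delta = 0.1.
/// cache.update(0, 0.1, Array2::zeros((64, 8)), Array2::zeros((64, 8)));
/// assert!(cache.get(0, 0.1).is_some());
///
/// // A different delta misses the cache.
/// assert!(cache.get(0, 0.2).is_none());
/// ```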
#[derive(Debug)]
pub struct DiscretizationCache {
    /// Cached A_bar matrices (one per layer)
    a_bar_cache: Vec<Array2<f32>>,
    /// Cached B_bar matrices (one per layer)
    b_bar_cache: Vec<Array2<f32>>,
    /// Delta value used for discretization
    cached_delta: f32,
    /// Whether cache is valid
    valid: bool,
}

impl DiscretizationCache {
    /// Create a new discretization cache
    pub fn new(num_layers: usize, hidden_dim: usize, state_dim: usize) -> Self {
        let a_bar_cache = (0..num_layers)
            .map(|_| Array2::zeros((hidden_dim, state_dim)))
            .collect();
        let b_bar_cache = (0..num_layers)
            .map(|_| Array2::zeros((hidden_dim, state_dim)))
            .collect();

        Self {
            a_bar_cache,
            b_bar_cache,
            cached_delta: 0.0,
            valid: false,
        }
    }

    /// Update the cache with new discretized matrices
    pub fn update(&mut self, layer_idx: usize, delta: f32, a_bar: Array2<f32>, b_bar: Array2<f32>) {
        if layer_idx < self.a_bar_cache.len() {
            self.a_bar_cache[layer_idx] = a_bar;
            self.b_bar_cache[layer_idx] = b_bar;
            self.cached_delta = delta;
            self.valid = true;
        }
    }

    /// Get cached discretized matrices if valid
    pub fn get(&self, layer_idx: usize, delta: f32) -> Option<(&Array2<f32>, &Array2<f32>)> {
        if self.valid
            && (delta - self.cached_delta).abs() < 1e-6
            && layer_idx < self.a_bar_cache.len()
        {
            Some((&self.a_bar_cache[layer_idx], &self.b_bar_cache[layer_idx]))
        } else {
            None
        }
    }

    /// Invalidate the cache
    pub fn invalidate(&mut self) {
        self.valid = false;
    }

    /// Check if cache is valid for given delta
    pub fn is_valid(&self, delta: f32) -> bool {
        self.valid && (delta - self.cached_delta).abs() < 1e-6
    }
}

/// Preallocated workspace for SSM computations to reduce allocations
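///
/// # Example
///
/// A minimal sketch of reusing the preallocated buffers; the
/// `kizzasi_core::optimizations` import path is assumed from this file's
/// location:
///
/// ```
/// use kizzasi_core::optimizations::SSMWorkspace;
///
/// let mut ws = SSMWorkspace::new(64, 8);
/// ws.temp_hidden_mut().fill(1.0);
/// assert_eq!(ws.temp_hidden_mut().sum(), 64.0);
///
/// // Reset the buffers before the workspace is reused.
/// ws.clear();
/// assert_eq!(ws.temp_hidden_mut().sum(), 0.0);
/// ```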
#[derive(Debug)]
pub struct SSMWorkspace {
    /// Temporary storage for intermediate results
    temp_hidden: Array1<f32>,
    /// Temporary storage for state updates
    temp_state: Array2<f32>,
    /// Temporary storage for layer outputs
    temp_output: Array1<f32>,
}

impl SSMWorkspace {
    /// Create a new workspace
    pub fn new(hidden_dim: usize, state_dim: usize) -> Self {
        Self {
            temp_hidden: Array1::zeros(hidden_dim),
            temp_state: Array2::zeros((hidden_dim, state_dim)),
            temp_output: Array1::zeros(hidden_dim),
        }
    }

    /// Get temporary hidden vector (mutable)
    pub fn temp_hidden_mut(&mut self) -> &mut Array1<f32> {
        &mut self.temp_hidden
    }

    /// Get temporary state matrix (mutable)
    pub fn temp_state_mut(&mut self) -> &mut Array2<f32> {
        &mut self.temp_state
    }

    /// Get temporary output vector (mutable)
    pub fn temp_output_mut(&mut self) -> &mut Array1<f32> {
        &mut self.temp_output
    }

    /// Reset all temporary storage to zeros
    pub fn clear(&mut self) {
        self.temp_hidden.fill(0.0);
        self.temp_state.fill(0.0);
        self.temp_output.fill(0.0);
    }
}

// Thread-local workspace pool to avoid allocations
thread_local! {
    static WORKSPACE_POOL: RefCell<Vec<SSMWorkspace>> = const { RefCell::new(Vec::new()) };
}

/// Acquire a workspace from the pool
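///
/// # Example
///
/// A minimal sketch of the acquire/release pattern; the
/// `kizzasi_core::optimizations` import path is assumed from this file's
/// location:
///
/// ```
/// use kizzasi_core::optimizations::{acquire_workspace, release_workspace};
///
/// let mut ws = acquire_workspace(64, 8);
/// ws.temp_hidden_mut().fill(1.0);
///
/// // Returning the workspace clears it and makes it available for reuse
/// // on this thread.
/// release_workspace(ws);
/// ```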
pub fn acquire_workspace(hidden_dim: usize, state_dim: usize) -> SSMWorkspace {
    WORKSPACE_POOL.with(|pool| {
        let mut pool = pool.borrow_mut();
        pool.pop()
            // Only reuse a pooled workspace whose buffers match the requested
            // dimensions; otherwise allocate a fresh one.
            .filter(|w| {
                w.temp_hidden.len() == hidden_dim && w.temp_state.dim() == (hidden_dim, state_dim)
            })
            .unwrap_or_else(|| SSMWorkspace::new(hidden_dim, state_dim))
    })
}

/// Return a workspace to the pool
pub fn release_workspace(mut workspace: SSMWorkspace) {
    workspace.clear();
    WORKSPACE_POOL.with(|pool| {
        let mut pool = pool.borrow_mut();
        if pool.len() < 16 {
            // Limit pool size
            pool.push(workspace);
        }
    });
}

/// RAII guard for automatic workspace return
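///
/// # Example
///
/// A minimal sketch; the `kizzasi_core::optimizations` import path is assumed
/// from this file's location. The workspace is returned to the thread-local
/// pool when the guard goes out of scope:
///
/// ```
/// use kizzasi_core::optimizations::WorkspaceGuard;
///
/// {
///     let mut guard = WorkspaceGuard::new(64, 8);
///     guard.get_mut().temp_hidden_mut().fill(1.0);
/// } // guard dropped here; the workspace goes back to the pool
/// ```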
pub struct WorkspaceGuard {
    workspace: Option<SSMWorkspace>,
}

impl WorkspaceGuard {
    /// Create a new workspace guard
    pub fn new(hidden_dim: usize, state_dim: usize) -> Self {
        Self {
            workspace: Some(acquire_workspace(hidden_dim, state_dim)),
        }
    }

    /// Get reference to the workspace
    pub fn get(&self) -> &SSMWorkspace {
        self.workspace.as_ref().expect("workspace should exist")
    }

    /// Get mutable reference to the workspace
    pub fn get_mut(&mut self) -> &mut SSMWorkspace {
        self.workspace.as_mut().expect("workspace should exist")
    }
}

impl Drop for WorkspaceGuard {
    fn drop(&mut self) {
        if let Some(workspace) = self.workspace.take() {
            release_workspace(workspace);
        }
    }
}

/// Prefetch hint for cache optimization
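///
/// # Example
///
/// A minimal sketch of issuing a prefetch a fixed distance ahead of a
/// streaming loop; the `kizzasi_core::optimizations` import path is assumed
/// from this file's location and the 16-element lookahead is an arbitrary
/// illustration:
///
/// ```
/// use kizzasi_core::optimizations::prefetch;
///
/// let data = vec![1.0f32; 1024];
/// let mut sum = 0.0f32;
/// for i in 0..data.len() {
///     if i + 16 < data.len() {
///         // Hint that data[i + 16] will be needed soon.
///         prefetch(&data[i + 16] as *const f32);
///     }
///     sum += data[i];
/// }
/// assert_eq!(sum, 1024.0);
/// ```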
#[inline(always)]
pub fn prefetch<T>(_ptr: *const T) {
    // Prefetching is only a hint, and the available intrinsics are
    // platform-specific. On x86_64 with SSE enabled this issues a
    // `prefetcht0`; on all other targets the function compiles to a no-op.
    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
    unsafe {
        // 3 corresponds to `_MM_HINT_T0` (prefetch into all cache levels).
        core::arch::x86_64::_mm_prefetch::<3>(_ptr as *const i8);
    }

    // ARM prefetch intrinsics are still unstable, so they are skipped here;
    // the compiler's own optimizations handle prefetching on ARM.
}

/// Cache-aligned buffer for better memory performance
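///
/// # Example
///
/// A minimal sketch; the `kizzasi_core::optimizations` import path is assumed
/// from this file's location:
///
/// ```
/// use kizzasi_core::optimizations::CacheAligned;
///
/// let mut buffer = CacheAligned::new([0.0f32; 16]);
/// buffer.get_mut()[0] = 1.0;
///
/// // The wrapper guarantees 64-byte alignment of the stored value.
/// assert_eq!(core::mem::align_of::<CacheAligned<[f32; 16]>>(), 64);
/// assert_eq!(buffer.into_inner()[0], 1.0);
/// ```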
#[repr(align(64))]
pub struct CacheAligned<T> {
    data: T,
}

impl<T> CacheAligned<T> {
    /// Create a new cache-aligned value
    pub fn new(data: T) -> Self {
        Self { data }
    }

    /// Get a reference to the inner data
    pub fn get(&self) -> &T {
        &self.data
    }

    /// Get a mutable reference to the inner data
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.data
    }

    /// Consume and return the inner data
    pub fn into_inner(self) -> T {
        self.data
    }
}

/// Instruction-level parallelism optimizations
pub mod ilp {
    use scirs2_core::ndarray::{Array1, ArrayView1};

    /// Dot product with manual loop unrolling for ILP
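    ///
    /// # Example
    ///
    /// A minimal sketch comparing the unrolled kernel with a scalar reference;
    /// the `kizzasi_core::optimizations::ilp` import path is assumed from this
    /// file's location:
    ///
    /// ```
    /// use kizzasi_core::optimizations::ilp::dot_unrolled;
    /// use scirs2_core::ndarray::arr1;
    ///
    /// let a = arr1(&[1.0f32, 2.0, 3.0, 4.0, 5.0]);
    /// let b = arr1(&[2.0f32, 3.0, 4.0, 5.0, 6.0]);
    /// let reference: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    /// assert!((dot_unrolled(a.view(), b.view()) - reference).abs() < 1e-5);
    /// ```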
    #[inline]
    pub fn dot_unrolled(a: ArrayView1<f32>, b: ArrayView1<f32>) -> f32 {
        let len = a.len().min(b.len());
        let mut sum0 = 0.0f32;
        let mut sum1 = 0.0f32;
        let mut sum2 = 0.0f32;
        let mut sum3 = 0.0f32;

        let chunks = len / 4;
        let remainder = len % 4;

        // Process 4 elements at a time for ILP
        for i in 0..chunks {
            let idx = i * 4;
            sum0 += a[idx] * b[idx];
            sum1 += a[idx + 1] * b[idx + 1];
            sum2 += a[idx + 2] * b[idx + 2];
            sum3 += a[idx + 3] * b[idx + 3];
        }

        // Process remainder
        let mut sum_remainder = 0.0f32;
        for i in (chunks * 4)..(chunks * 4 + remainder) {
            sum_remainder += a[i] * b[i];
        }

        sum0 + sum1 + sum2 + sum3 + sum_remainder
    }

    /// Vector addition with loop unrolling
    #[inline]
    pub fn add_unrolled(a: &Array1<f32>, b: &Array1<f32>, out: &mut Array1<f32>) {
        let len = a.len().min(b.len()).min(out.len());
        let chunks = len / 4;
        let remainder = len % 4;

        for i in 0..chunks {
            let idx = i * 4;
            out[idx] = a[idx] + b[idx];
            out[idx + 1] = a[idx + 1] + b[idx + 1];
            out[idx + 2] = a[idx + 2] + b[idx + 2];
            out[idx + 3] = a[idx + 3] + b[idx + 3];
        }

        for i in (chunks * 4)..(chunks * 4 + remainder) {
            out[i] = a[i] + b[i];
        }
    }

    /// Fused multiply-add with loop unrolling
    #[inline]
    pub fn fma_unrolled(a: &Array1<f32>, b: &Array1<f32>, c: &Array1<f32>, out: &mut Array1<f32>) {
        let len = a.len().min(b.len()).min(c.len()).min(out.len());
        let chunks = len / 4;
        let remainder = len % 4;

        for i in 0..chunks {
            let idx = i * 4;
            out[idx] = a[idx].mul_add(b[idx], c[idx]);
            out[idx + 1] = a[idx + 1].mul_add(b[idx + 1], c[idx + 1]);
            out[idx + 2] = a[idx + 2].mul_add(b[idx + 2], c[idx + 2]);
            out[idx + 3] = a[idx + 3].mul_add(b[idx + 3], c[idx + 3]);
        }

        for i in (chunks * 4)..(chunks * 4 + remainder) {
            out[i] = a[i].mul_add(b[i], c[i]);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_discretization_cache() {
        let mut cache = DiscretizationCache::new(2, 64, 8);
        assert!(!cache.is_valid(0.1));

        let a_bar = Array2::ones((64, 8));
        let b_bar = Array2::ones((64, 8));

        cache.update(0, 0.1, a_bar.clone(), b_bar.clone());
        assert!(cache.is_valid(0.1));

        let (cached_a, cached_b) = cache.get(0, 0.1).expect("cache should hit");
        assert_eq!(cached_a.shape(), &[64, 8]);
        assert_eq!(cached_b.shape(), &[64, 8]);

        cache.invalidate();
        assert!(!cache.is_valid(0.1));
    }

    #[test]
    fn test_workspace() {
        let mut workspace = SSMWorkspace::new(64, 8);
        workspace.temp_hidden_mut().fill(1.0);
        assert_eq!(workspace.temp_hidden_mut().len(), 64);

        workspace.clear();
        assert_eq!(workspace.temp_hidden_mut().sum(), 0.0);
    }

    #[test]
    fn test_workspace_pool() {
        let workspace1 = acquire_workspace(64, 8);
        assert_eq!(workspace1.temp_hidden.len(), 64);

        release_workspace(workspace1);

        let workspace2 = acquire_workspace(64, 8);
        assert_eq!(workspace2.temp_hidden.len(), 64);
    }

    #[test]
    fn test_workspace_guard() {
        let mut guard = WorkspaceGuard::new(64, 8);
        guard.get_mut().temp_hidden_mut().fill(1.0);
        assert_eq!(guard.get().temp_hidden.len(), 64);
    }

    #[test]
    fn test_cache_aligned() {
        let aligned = CacheAligned::new(vec![1.0f32, 2.0, 3.0]);
        assert_eq!(aligned.get().len(), 3);

        let mut aligned = CacheAligned::new(42);
        *aligned.get_mut() = 100;
        assert_eq!(*aligned.get(), 100);
    }

    #[test]
    fn test_ilp_dot_unrolled() {
        use scirs2_core::ndarray::arr1;

        let a = arr1(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = arr1(&[2.0, 3.0, 4.0, 5.0, 6.0]);
        let result = ilp::dot_unrolled(a.view(), b.view());
        let expected: f32 = 1.0 * 2.0 + 2.0 * 3.0 + 3.0 * 4.0 + 4.0 * 5.0 + 5.0 * 6.0;
        assert!((result - expected).abs() < 1e-5);
    }

    #[test]
    fn test_ilp_add_unrolled() {
        use scirs2_core::ndarray::arr1;

        let a = arr1(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = arr1(&[2.0, 3.0, 4.0, 5.0, 6.0]);
        let mut out = Array1::zeros(5);

        ilp::add_unrolled(&a, &b, &mut out);
        assert_eq!(out[0], 3.0);
        assert_eq!(out[4], 11.0);
    }

    #[test]
    fn test_ilp_fma_unrolled() {
        use scirs2_core::ndarray::arr1;

        let a = arr1(&[1.0, 2.0, 3.0, 4.0]);
        let b = arr1(&[2.0, 3.0, 4.0, 5.0]);
        let c = arr1(&[1.0, 1.0, 1.0, 1.0]);
        let mut out = Array1::zeros(4);

        ilp::fma_unrolled(&a, &b, &c, &mut out);
        assert_eq!(out[0], 1.0 * 2.0 + 1.0);
        assert_eq!(out[3], 4.0 * 5.0 + 1.0);
    }
}