1#[cfg(not(feature = "std"))]
10use alloc::vec::Vec;
11
12use scirs2_core::ndarray::{Array1, Array2};
13use std::cell::RefCell;
14
15#[derive(Debug)]
17pub struct DiscretizationCache {
18 a_bar_cache: Vec<Array2<f32>>,
20 b_bar_cache: Vec<Array2<f32>>,
22 cached_delta: f32,
24 valid: bool,
26}
27
28impl DiscretizationCache {
29 pub fn new(num_layers: usize, hidden_dim: usize, state_dim: usize) -> Self {
31 let a_bar_cache = (0..num_layers)
32 .map(|_| Array2::zeros((hidden_dim, state_dim)))
33 .collect();
34 let b_bar_cache = (0..num_layers)
35 .map(|_| Array2::zeros((hidden_dim, state_dim)))
36 .collect();
37
38 Self {
39 a_bar_cache,
40 b_bar_cache,
41 cached_delta: 0.0,
42 valid: false,
43 }
44 }
45
46 pub fn update(&mut self, layer_idx: usize, delta: f32, a_bar: Array2<f32>, b_bar: Array2<f32>) {
48 if layer_idx < self.a_bar_cache.len() {
49 self.a_bar_cache[layer_idx] = a_bar;
50 self.b_bar_cache[layer_idx] = b_bar;
51 self.cached_delta = delta;
52 self.valid = true;
53 }
54 }
55
56 pub fn get(&self, layer_idx: usize, delta: f32) -> Option<(&Array2<f32>, &Array2<f32>)> {
58 if self.valid
59 && (delta - self.cached_delta).abs() < 1e-6
60 && layer_idx < self.a_bar_cache.len()
61 {
62 Some((&self.a_bar_cache[layer_idx], &self.b_bar_cache[layer_idx]))
63 } else {
64 None
65 }
66 }
67
68 pub fn invalidate(&mut self) {
70 self.valid = false;
71 }
72
73 pub fn is_valid(&self, delta: f32) -> bool {
75 self.valid && (delta - self.cached_delta).abs() < 1e-6
76 }
77}
78
/// Reusable scratch buffers for a single SSM layer's computation, pooled via
/// `WORKSPACE_POOL` to avoid reallocating on every call.
#[derive(Debug)]
pub struct SSMWorkspace {
    // Scratch vector of length `hidden_dim` (sized in `new`).
    temp_hidden: Array1<f32>,
    // Scratch matrix of shape `(hidden_dim, state_dim)`.
    temp_state: Array2<f32>,
    // Output scratch vector, also of length `hidden_dim`.
    temp_output: Array1<f32>,
}
89
90impl SSMWorkspace {
91 pub fn new(hidden_dim: usize, state_dim: usize) -> Self {
93 Self {
94 temp_hidden: Array1::zeros(hidden_dim),
95 temp_state: Array2::zeros((hidden_dim, state_dim)),
96 temp_output: Array1::zeros(hidden_dim),
97 }
98 }
99
100 pub fn temp_hidden_mut(&mut self) -> &mut Array1<f32> {
102 &mut self.temp_hidden
103 }
104
105 pub fn temp_state_mut(&mut self) -> &mut Array2<f32> {
107 &mut self.temp_state
108 }
109
110 pub fn temp_output_mut(&mut self) -> &mut Array1<f32> {
112 &mut self.temp_output
113 }
114
115 pub fn clear(&mut self) {
117 self.temp_hidden.fill(0.0);
118 self.temp_state.fill(0.0);
119 self.temp_output.fill(0.0);
120 }
121}
122
thread_local! {
    // Per-thread pool of reusable workspaces, so hot paths can skip buffer
    // allocation. `RefCell` gives single-threaded interior mutability; the
    // `const` initializer makes first access allocation-free.
    static WORKSPACE_POOL: RefCell<Vec<SSMWorkspace>> = const { RefCell::new(Vec::new()) };
}
127
128pub fn acquire_workspace(hidden_dim: usize, state_dim: usize) -> SSMWorkspace {
130 WORKSPACE_POOL.with(|pool| {
131 let mut pool = pool.borrow_mut();
132 pool.pop()
133 .unwrap_or_else(|| SSMWorkspace::new(hidden_dim, state_dim))
134 })
135}
136
137pub fn release_workspace(mut workspace: SSMWorkspace) {
139 workspace.clear();
140 WORKSPACE_POOL.with(|pool| {
141 let mut pool = pool.borrow_mut();
142 if pool.len() < 16 {
143 pool.push(workspace);
145 }
146 });
147}
148
149pub struct WorkspaceGuard {
151 workspace: Option<SSMWorkspace>,
152}
153
154impl WorkspaceGuard {
155 pub fn new(hidden_dim: usize, state_dim: usize) -> Self {
157 Self {
158 workspace: Some(acquire_workspace(hidden_dim, state_dim)),
159 }
160 }
161
162 pub fn get(&self) -> &SSMWorkspace {
164 self.workspace.as_ref().expect("workspace should exist")
165 }
166
167 pub fn get_mut(&mut self) -> &mut SSMWorkspace {
169 self.workspace.as_mut().expect("workspace should exist")
170 }
171}
172
173impl Drop for WorkspaceGuard {
174 fn drop(&mut self) {
175 if let Some(workspace) = self.workspace.take() {
176 release_workspace(workspace);
177 }
178 }
179}
180
/// Hint the CPU to prefetch the cache line containing `_ptr` into all cache
/// levels. A no-op on targets without SSE.
///
/// Uses the named `_MM_HINT_T0` constant instead of the bare literal `3` the
/// original passed, so the hint level is self-documenting.
#[inline(always)]
pub fn prefetch<T>(_ptr: *const T) {
    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
    // SAFETY: `_mm_prefetch` is purely a performance hint; it performs no
    // memory access and does not fault even for invalid addresses.
    unsafe {
        use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
        _mm_prefetch::<_MM_HINT_T0>(_ptr as *const i8);
    }
}
195
/// Wrapper forcing its contents onto a 64-byte alignment boundary
/// (one cache line on common x86-64/AArch64 parts), e.g. to keep hot values
/// from sharing a line with unrelated data.
#[repr(align(64))]
pub struct CacheAligned<T> {
    value: T,
}

impl<T> CacheAligned<T> {
    /// Wrap `data` in a cache-line-aligned container.
    pub fn new(data: T) -> Self {
        Self { value: data }
    }

    /// Borrow the wrapped value.
    pub fn get(&self) -> &T {
        &self.value
    }

    /// Mutably borrow the wrapped value.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.value
    }

    /// Consume the wrapper, returning the inner value.
    pub fn into_inner(self) -> T {
        self.value
    }
}
223
224pub mod ilp {
226 use scirs2_core::ndarray::{Array1, ArrayView1};
227
228 #[inline]
230 pub fn dot_unrolled(a: ArrayView1<f32>, b: ArrayView1<f32>) -> f32 {
231 let len = a.len().min(b.len());
232 let mut sum0 = 0.0f32;
233 let mut sum1 = 0.0f32;
234 let mut sum2 = 0.0f32;
235 let mut sum3 = 0.0f32;
236
237 let chunks = len / 4;
238 let remainder = len % 4;
239
240 for i in 0..chunks {
242 let idx = i * 4;
243 sum0 += a[idx] * b[idx];
244 sum1 += a[idx + 1] * b[idx + 1];
245 sum2 += a[idx + 2] * b[idx + 2];
246 sum3 += a[idx + 3] * b[idx + 3];
247 }
248
249 let mut sum_remainder = 0.0f32;
251 for i in (chunks * 4)..(chunks * 4 + remainder) {
252 sum_remainder += a[i] * b[i];
253 }
254
255 sum0 + sum1 + sum2 + sum3 + sum_remainder
256 }
257
258 #[inline]
260 pub fn add_unrolled(a: &Array1<f32>, b: &Array1<f32>, out: &mut Array1<f32>) {
261 let len = a.len().min(b.len()).min(out.len());
262 let chunks = len / 4;
263 let remainder = len % 4;
264
265 for i in 0..chunks {
266 let idx = i * 4;
267 out[idx] = a[idx] + b[idx];
268 out[idx + 1] = a[idx + 1] + b[idx + 1];
269 out[idx + 2] = a[idx + 2] + b[idx + 2];
270 out[idx + 3] = a[idx + 3] + b[idx + 3];
271 }
272
273 for i in (chunks * 4)..(chunks * 4 + remainder) {
274 out[i] = a[i] + b[i];
275 }
276 }
277
278 #[inline]
280 pub fn fma_unrolled(a: &Array1<f32>, b: &Array1<f32>, c: &Array1<f32>, out: &mut Array1<f32>) {
281 let len = a.len().min(b.len()).min(c.len()).min(out.len());
282 let chunks = len / 4;
283 let remainder = len % 4;
284
285 for i in 0..chunks {
286 let idx = i * 4;
287 out[idx] = a[idx].mul_add(b[idx], c[idx]);
288 out[idx + 1] = a[idx + 1].mul_add(b[idx + 1], c[idx + 1]);
289 out[idx + 2] = a[idx + 2].mul_add(b[idx + 2], c[idx + 2]);
290 out[idx + 3] = a[idx + 3].mul_add(b[idx + 3], c[idx + 3]);
291 }
292
293 for i in (chunks * 4)..(chunks * 4 + remainder) {
294 out[i] = a[i].mul_add(b[i], c[i]);
295 }
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_discretization_cache() {
        // A fresh cache holds nothing valid.
        let mut cache = DiscretizationCache::new(2, 64, 8);
        assert!(!cache.is_valid(0.1));

        let a = Array2::ones((64, 8));
        let b = Array2::ones((64, 8));
        cache.update(0, 0.1, a.clone(), b.clone());
        assert!(cache.is_valid(0.1));

        // Same layer + same delta must hit and return the stored shapes.
        let (hit_a, hit_b) = cache.get(0, 0.1).expect("cache should hit");
        assert_eq!(hit_a.shape(), &[64, 8]);
        assert_eq!(hit_b.shape(), &[64, 8]);

        // Explicit invalidation drops validity.
        cache.invalidate();
        assert!(!cache.is_valid(0.1));
    }

    #[test]
    fn test_workspace() {
        let mut ws = SSMWorkspace::new(64, 8);
        ws.temp_hidden_mut().fill(1.0);
        assert_eq!(ws.temp_hidden_mut().len(), 64);

        // Clearing zeroes the buffers again.
        ws.clear();
        assert_eq!(ws.temp_hidden_mut().sum(), 0.0);
    }

    #[test]
    fn test_workspace_pool() {
        let first = acquire_workspace(64, 8);
        assert_eq!(first.temp_hidden.len(), 64);

        release_workspace(first);

        // A second acquisition (possibly reusing the pooled one) is sized right.
        let second = acquire_workspace(64, 8);
        assert_eq!(second.temp_hidden.len(), 64);
    }

    #[test]
    fn test_workspace_guard() {
        let mut guard = WorkspaceGuard::new(64, 8);
        guard.get_mut().temp_hidden_mut().fill(1.0);
        assert_eq!(guard.get().temp_hidden.len(), 64);
    }

    #[test]
    fn test_cache_aligned() {
        let wrapped = CacheAligned::new(vec![1.0f32, 2.0, 3.0]);
        assert_eq!(wrapped.get().len(), 3);

        let mut counter = CacheAligned::new(42);
        *counter.get_mut() = 100;
        assert_eq!(*counter.get(), 100);
    }

    #[test]
    fn test_ilp_dot_unrolled() {
        use scirs2_core::ndarray::arr1;

        let x = arr1(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = arr1(&[2.0, 3.0, 4.0, 5.0, 6.0]);
        let got = ilp::dot_unrolled(x.view(), y.view());
        let want: f32 = 1.0 * 2.0 + 2.0 * 3.0 + 3.0 * 4.0 + 4.0 * 5.0 + 5.0 * 6.0;
        assert!((got - want).abs() < 1e-5);
    }

    #[test]
    fn test_ilp_add_unrolled() {
        use scirs2_core::ndarray::arr1;

        let x = arr1(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = arr1(&[2.0, 3.0, 4.0, 5.0, 6.0]);
        let mut sum = Array1::zeros(5);

        ilp::add_unrolled(&x, &y, &mut sum);
        assert_eq!(sum[0], 3.0);
        assert_eq!(sum[4], 11.0);
    }

    #[test]
    fn test_ilp_fma_unrolled() {
        use scirs2_core::ndarray::arr1;

        let x = arr1(&[1.0, 2.0, 3.0, 4.0]);
        let y = arr1(&[2.0, 3.0, 4.0, 5.0]);
        let ones = arr1(&[1.0, 1.0, 1.0, 1.0]);
        let mut fused = Array1::zeros(4);

        ilp::fma_unrolled(&x, &y, &ones, &mut fused);
        assert_eq!(fused[0], 1.0 * 2.0 + 1.0);
        assert_eq!(fused[3], 4.0 * 5.0 + 1.0);
    }
}