#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use scirs2_core::ndarray::{Array1, Array2};
use std::cell::RefCell;
/// Per-layer cache of discretized state-space matrices (A_bar, B_bar).
///
/// Discretization depends on the time step `delta`, so each layer records
/// the `delta` its entry was computed with and a lookup only hits when the
/// requested `delta` matches that layer's recorded value. Tracking `delta`
/// per layer (rather than one global value) prevents a stale hit when
/// different layers were last updated with different time steps.
#[derive(Debug)]
pub struct DiscretizationCache {
    /// Cached discretized state matrices, one per layer.
    a_bar_cache: Vec<Array2<f32>>,
    /// Cached discretized input matrices, one per layer.
    b_bar_cache: Vec<Array2<f32>>,
    /// The `delta` each layer's entry was computed with.
    cached_deltas: Vec<f32>,
    /// Whether each layer's entry holds real (non-placeholder) data.
    layer_valid: Vec<bool>,
}
impl DiscretizationCache {
    /// Absolute tolerance for treating two `delta` values as equal.
    const DELTA_TOLERANCE: f32 = 1e-6;

    /// Creates a cache with zeroed placeholder matrices for `num_layers`
    /// layers of shape `(hidden_dim, state_dim)`; every entry starts invalid.
    pub fn new(num_layers: usize, hidden_dim: usize, state_dim: usize) -> Self {
        let a_bar_cache = (0..num_layers)
            .map(|_| Array2::zeros((hidden_dim, state_dim)))
            .collect();
        let b_bar_cache = (0..num_layers)
            .map(|_| Array2::zeros((hidden_dim, state_dim)))
            .collect();
        Self {
            a_bar_cache,
            b_bar_cache,
            cached_deltas: (0..num_layers).map(|_| 0.0).collect(),
            layer_valid: (0..num_layers).map(|_| false).collect(),
        }
    }

    /// Stores freshly discretized matrices for `layer_idx`, recording the
    /// `delta` they were computed with. An out-of-range index is ignored.
    pub fn update(&mut self, layer_idx: usize, delta: f32, a_bar: Array2<f32>, b_bar: Array2<f32>) {
        if layer_idx < self.a_bar_cache.len() {
            self.a_bar_cache[layer_idx] = a_bar;
            self.b_bar_cache[layer_idx] = b_bar;
            self.cached_deltas[layer_idx] = delta;
            self.layer_valid[layer_idx] = true;
        }
    }

    /// Returns the cached `(A_bar, B_bar)` for `layer_idx` when that layer
    /// holds valid data computed with (approximately) this `delta`.
    pub fn get(&self, layer_idx: usize, delta: f32) -> Option<(&Array2<f32>, &Array2<f32>)> {
        if layer_idx < self.a_bar_cache.len()
            && self.layer_valid[layer_idx]
            && (delta - self.cached_deltas[layer_idx]).abs() < Self::DELTA_TOLERANCE
        {
            Some((&self.a_bar_cache[layer_idx], &self.b_bar_cache[layer_idx]))
        } else {
            None
        }
    }

    /// Marks every layer's entry invalid; the matrices themselves are kept
    /// as reusable storage.
    pub fn invalidate(&mut self) {
        self.layer_valid.fill(false);
    }

    /// Returns `true` if any layer holds valid data computed with
    /// (approximately) this `delta`.
    pub fn is_valid(&self, delta: f32) -> bool {
        self.layer_valid
            .iter()
            .zip(self.cached_deltas.iter())
            .any(|(&valid, &cached)| valid && (delta - cached).abs() < Self::DELTA_TOLERANCE)
    }
}
/// Reusable scratch buffers for an SSM step, avoiding repeated allocation
/// of intermediate arrays.
#[derive(Debug)]
pub struct SSMWorkspace {
    // Scratch vector of length `hidden_dim`.
    temp_hidden: Array1<f32>,
    // Scratch matrix of shape `(hidden_dim, state_dim)`.
    temp_state: Array2<f32>,
    // Scratch vector of length `hidden_dim`.
    temp_output: Array1<f32>,
}
impl SSMWorkspace {
    /// Allocates zero-initialized scratch buffers for the given dimensions.
    pub fn new(hidden_dim: usize, state_dim: usize) -> Self {
        let temp_hidden = Array1::zeros(hidden_dim);
        let temp_state = Array2::zeros((hidden_dim, state_dim));
        let temp_output = Array1::zeros(hidden_dim);
        Self {
            temp_hidden,
            temp_state,
            temp_output,
        }
    }
    /// Mutable access to the hidden-activation scratch vector.
    pub fn temp_hidden_mut(&mut self) -> &mut Array1<f32> {
        &mut self.temp_hidden
    }
    /// Mutable access to the state scratch matrix.
    pub fn temp_state_mut(&mut self) -> &mut Array2<f32> {
        &mut self.temp_state
    }
    /// Mutable access to the output scratch vector.
    pub fn temp_output_mut(&mut self) -> &mut Array1<f32> {
        &mut self.temp_output
    }
    /// Zeroes every scratch buffer.
    pub fn clear(&mut self) {
        self.temp_output.fill(0.0);
        self.temp_state.fill(0.0);
        self.temp_hidden.fill(0.0);
    }
}
// Per-thread pool of reusable scratch workspaces. `RefCell` (instead of a
// lock) is sufficient because each thread only ever touches its own pool.
thread_local! {
    static WORKSPACE_POOL: RefCell<Vec<SSMWorkspace>> = const { RefCell::new(Vec::new()) };
}
/// Obtains a scratch workspace for the given dimensions, reusing one from
/// the thread-local pool when possible.
///
/// A pooled workspace is reused only when its buffer shapes match the
/// requested `hidden_dim`/`state_dim`; otherwise it is dropped and a fresh
/// workspace is allocated, so callers never receive wrongly-sized buffers.
pub fn acquire_workspace(hidden_dim: usize, state_dim: usize) -> SSMWorkspace {
    WORKSPACE_POOL.with(|pool| {
        let mut pool = pool.borrow_mut();
        match pool.pop() {
            // The state-matrix check covers both dimensions; the vector
            // check guards `temp_hidden`/`temp_output` (both `hidden_dim`).
            Some(ws)
                if ws.temp_hidden.len() == hidden_dim
                    && ws.temp_state.dim() == (hidden_dim, state_dim) =>
            {
                ws
            }
            // Pool empty, or the pooled workspace has the wrong shape.
            _ => SSMWorkspace::new(hidden_dim, state_dim),
        }
    })
}
/// Returns `workspace` to the thread-local pool for later reuse.
///
/// The workspace is zeroed first so a later `acquire_workspace` never
/// observes stale values. When the pool is already at capacity the
/// workspace is simply dropped, bounding per-thread idle memory.
pub fn release_workspace(mut workspace: SSMWorkspace) {
    // Maximum number of idle workspaces retained per thread.
    const MAX_POOLED_WORKSPACES: usize = 16;
    workspace.clear();
    WORKSPACE_POOL.with(|pool| {
        let mut pool = pool.borrow_mut();
        if pool.len() < MAX_POOLED_WORKSPACES {
            pool.push(workspace);
        }
    });
}
/// RAII handle that holds a pooled [`SSMWorkspace`] and returns it to the
/// thread-local pool when dropped.
pub struct WorkspaceGuard {
    // `Option` so `Drop` can move the workspace out and back into the pool.
    workspace: Option<SSMWorkspace>,
}
impl WorkspaceGuard {
    /// Acquires (or freshly allocates) a workspace of the given dimensions.
    pub fn new(hidden_dim: usize, state_dim: usize) -> Self {
        let workspace = acquire_workspace(hidden_dim, state_dim);
        Self {
            workspace: Some(workspace),
        }
    }
    /// Shared access to the held workspace.
    pub fn get(&self) -> &SSMWorkspace {
        self.workspace.as_ref().expect("workspace should exist")
    }
    /// Exclusive access to the held workspace.
    pub fn get_mut(&mut self) -> &mut SSMWorkspace {
        self.workspace.as_mut().expect("workspace should exist")
    }
}
impl Drop for WorkspaceGuard {
    fn drop(&mut self) {
        // `take` moves the workspace out so it can be handed to the pool;
        // the slot is `None` only during/after drop.
        if let Some(ws) = self.workspace.take() {
            release_workspace(ws);
        }
    }
}
/// Hints the CPU to load the cache line containing `_ptr`.
///
/// Compiles to a no-op on targets without x86-64 SSE; `inline(always)` is
/// required so the hint lands at the call site rather than behind a call.
#[inline(always)]
pub fn prefetch<T>(_ptr: *const T) {
    #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
    // SAFETY: `_mm_prefetch` is purely a hint and never faults, regardless
    // of whether the pointer is valid or dereferenceable.
    unsafe {
        use core::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
        // Named locality constant (T0 = all cache levels) instead of the
        // magic literal `3`.
        _mm_prefetch::<{ _MM_HINT_T0 }>(_ptr as *const i8);
    }
}
/// Wrapper that aligns its contents to a 64-byte cache-line boundary,
/// useful for avoiding false sharing between adjacent values.
#[repr(align(64))]
pub struct CacheAligned<T> {
    value: T,
}
impl<T> CacheAligned<T> {
    /// Wraps `data` in a cache-line-aligned container.
    #[inline]
    pub fn new(data: T) -> Self {
        Self { value: data }
    }
    /// Shared reference to the wrapped value.
    #[inline]
    pub fn get(&self) -> &T {
        &self.value
    }
    /// Exclusive reference to the wrapped value.
    #[inline]
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.value
    }
    /// Consumes the wrapper and yields the inner value.
    #[inline]
    pub fn into_inner(self) -> T {
        self.value
    }
}
pub mod ilp {
    //! Manually 4-way-unrolled vector kernels. Independent accumulator and
    //! output chains give the CPU instruction-level parallelism instead of
    //! serializing every operation on a single dependency chain.
    use scirs2_core::ndarray::{Array1, ArrayView1};
    /// Dot product over the common prefix of `a` and `b`.
    ///
    /// Accumulates into four independent partial sums and combines them as
    /// `acc0 + acc1 + acc2 + acc3 + tail` (a strided, not strictly
    /// sequential, summation order).
    #[inline]
    pub fn dot_unrolled(a: ArrayView1<f32>, b: ArrayView1<f32>) -> f32 {
        let len = a.len().min(b.len());
        // Largest multiple of 4 that fits in `len`.
        let main = len - len % 4;
        let (mut acc0, mut acc1, mut acc2, mut acc3) = (0.0f32, 0.0f32, 0.0f32, 0.0f32);
        let mut base = 0;
        while base < main {
            acc0 += a[base] * b[base];
            acc1 += a[base + 1] * b[base + 1];
            acc2 += a[base + 2] * b[base + 2];
            acc3 += a[base + 3] * b[base + 3];
            base += 4;
        }
        // The 0-3 leftover elements are handled sequentially.
        let mut tail = 0.0f32;
        for i in main..len {
            tail += a[i] * b[i];
        }
        acc0 + acc1 + acc2 + acc3 + tail
    }
    /// Element-wise `out = a + b` over the common prefix of all three
    /// arrays; elements past that prefix are left untouched.
    #[inline]
    pub fn add_unrolled(a: &Array1<f32>, b: &Array1<f32>, out: &mut Array1<f32>) {
        let len = a.len().min(b.len()).min(out.len());
        let main = len - len % 4;
        let mut base = 0;
        while base < main {
            out[base] = a[base] + b[base];
            out[base + 1] = a[base + 1] + b[base + 1];
            out[base + 2] = a[base + 2] + b[base + 2];
            out[base + 3] = a[base + 3] + b[base + 3];
            base += 4;
        }
        for i in main..len {
            out[i] = a[i] + b[i];
        }
    }
    /// Element-wise fused multiply-add `out = a * b + c` over the common
    /// prefix of all four arrays; elements past that prefix are untouched.
    #[inline]
    pub fn fma_unrolled(a: &Array1<f32>, b: &Array1<f32>, c: &Array1<f32>, out: &mut Array1<f32>) {
        let len = a.len().min(b.len()).min(c.len()).min(out.len());
        let main = len - len % 4;
        let mut base = 0;
        while base < main {
            out[base] = a[base].mul_add(b[base], c[base]);
            out[base + 1] = a[base + 1].mul_add(b[base + 1], c[base + 1]);
            out[base + 2] = a[base + 2].mul_add(b[base + 2], c[base + 2]);
            out[base + 3] = a[base + 3].mul_add(b[base + 3], c[base + 3]);
            base += 4;
        }
        for i in main..len {
            out[i] = a[i].mul_add(b[i], c[i]);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Exercises the cache lifecycle: fresh -> miss, update -> hit,
    // invalidate -> miss again.
    #[test]
    fn test_discretization_cache() {
        let mut cache = DiscretizationCache::new(2, 64, 8);
        // A fresh cache holds no valid entries.
        assert!(!cache.is_valid(0.1));
        let a_bar = Array2::ones((64, 8));
        let b_bar = Array2::ones((64, 8));
        cache.update(0, 0.1, a_bar.clone(), b_bar.clone());
        assert!(cache.is_valid(0.1));
        // Lookup with the same layer and delta must hit.
        let (cached_a, cached_b) = cache.get(0, 0.1).expect("cache should hit");
        assert_eq!(cached_a.shape(), &[64, 8]);
        assert_eq!(cached_b.shape(), &[64, 8]);
        // Invalidation must clear every entry.
        cache.invalidate();
        assert!(!cache.is_valid(0.1));
    }
    // Scratch buffers are sized correctly and clear() zeroes them.
    #[test]
    fn test_workspace() {
        let mut workspace = SSMWorkspace::new(64, 8);
        workspace.temp_hidden_mut().fill(1.0);
        assert_eq!(workspace.temp_hidden_mut().len(), 64);
        workspace.clear();
        assert_eq!(workspace.temp_hidden_mut().sum(), 0.0);
    }
    // A released workspace can be re-acquired with the right dimensions.
    #[test]
    fn test_workspace_pool() {
        let workspace1 = acquire_workspace(64, 8);
        assert_eq!(workspace1.temp_hidden.len(), 64);
        release_workspace(workspace1);
        let workspace2 = acquire_workspace(64, 8);
        assert_eq!(workspace2.temp_hidden.len(), 64);
    }
    // The RAII guard hands out usable mutable/shared access; the workspace
    // is returned to the pool when the guard drops at end of scope.
    #[test]
    fn test_workspace_guard() {
        let mut guard = WorkspaceGuard::new(64, 8);
        guard.get_mut().temp_hidden_mut().fill(1.0);
        assert_eq!(guard.get().temp_hidden.len(), 64);
    }
    // Accessors round-trip the wrapped value (alignment itself is a
    // compile-time `#[repr(align(64))]` guarantee).
    #[test]
    fn test_cache_aligned() {
        let aligned = CacheAligned::new(vec![1.0f32, 2.0, 3.0]);
        assert_eq!(aligned.get().len(), 3);
        let mut aligned = CacheAligned::new(42);
        *aligned.get_mut() = 100;
        assert_eq!(*aligned.get(), 100);
    }
    // Length 5 covers one full unrolled chunk plus a 1-element remainder.
    #[test]
    fn test_ilp_dot_unrolled() {
        use scirs2_core::ndarray::arr1;
        let a = arr1(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = arr1(&[2.0, 3.0, 4.0, 5.0, 6.0]);
        let result = ilp::dot_unrolled(a.view(), b.view());
        let expected: f32 = 1.0 * 2.0 + 2.0 * 3.0 + 3.0 * 4.0 + 4.0 * 5.0 + 5.0 * 6.0;
        assert!((result - expected).abs() < 1e-5);
    }
    // Checks both a chunked element (index 0) and the remainder (index 4).
    #[test]
    fn test_ilp_add_unrolled() {
        use scirs2_core::ndarray::arr1;
        let a = arr1(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = arr1(&[2.0, 3.0, 4.0, 5.0, 6.0]);
        let mut out = Array1::zeros(5);
        ilp::add_unrolled(&a, &b, &mut out);
        assert_eq!(out[0], 3.0);
        assert_eq!(out[4], 11.0);
    }
    // Length 4 is exactly one unrolled chunk (no remainder path).
    #[test]
    fn test_ilp_fma_unrolled() {
        use scirs2_core::ndarray::arr1;
        let a = arr1(&[1.0, 2.0, 3.0, 4.0]);
        let b = arr1(&[2.0, 3.0, 4.0, 5.0]);
        let c = arr1(&[1.0, 1.0, 1.0, 1.0]);
        let mut out = Array1::zeros(4);
        ilp::fma_unrolled(&a, &b, &c, &mut out);
        assert_eq!(out[0], 1.0 * 2.0 + 1.0);
        assert_eq!(out[3], 4.0 * 5.0 + 1.0);
    }
}