use crate::kernel::{Complex, Float};
use crate::prelude::*;
/// Owned, heap-allocated scratch buffer of complex values.
///
/// Returned by [`get_scratch`]. Dereferences to `[Complex<T>]` (see the
/// `Deref`/`DerefMut` impls below), so it can be used anywhere a slice of
/// complex values is expected.
pub struct ScratchGuard<T: Float> {
// Backing storage; created zero-filled by `get_scratch`.
buf: Vec<Complex<T>>,
}
impl<T: Float> ScratchGuard<T> {
    /// Borrows the buffer contents as an immutable slice.
    #[inline]
    pub fn as_slice(&self) -> &[Complex<T>] {
        self.buf.as_slice()
    }

    /// Borrows the buffer contents as a mutable slice.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [Complex<T>] {
        self.buf.as_mut_slice()
    }

    /// Returns the number of complex elements in the buffer.
    #[inline]
    #[must_use]
    pub fn len(&self) -> usize {
        self.as_slice().len()
    }

    /// Returns `true` when the buffer holds no elements.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
impl<T: Float> core::ops::Deref for ScratchGuard<T> {
    type Target = [Complex<T>];

    /// Lets a guard be used directly as `&[Complex<T>]`.
    #[inline]
    fn deref(&self) -> &[Complex<T>] {
        self.buf.as_slice()
    }
}
impl<T: Float> core::ops::DerefMut for ScratchGuard<T> {
    /// Lets a guard be used directly as `&mut [Complex<T>]`.
    #[inline]
    fn deref_mut(&mut self) -> &mut [Complex<T>] {
        self.buf.as_mut_slice()
    }
}
#[cfg(feature = "std")]
mod tls {
    use super::*;
    use std::cell::RefCell;

    /// Type-erased, thread-local byte buffer backing the scratch slices.
    ///
    /// All `Complex<T>` instantiations on a thread share this one buffer,
    /// which is why it is stored as raw bytes and reinterpreted per call.
    struct RawScratch {
        bytes: Vec<u8>,
        // High-water mark in bytes. It never shrinks, so repeated calls at
        // or below this size perform no allocation.
        capacity: usize,
    }

    impl RawScratch {
        fn new() -> Self {
            Self {
                bytes: Vec::new(),
                capacity: 0,
            }
        }

        /// Makes `bytes[..byte_count]` available and fully zeroed.
        ///
        /// BUG FIX: the previous grow path used `resize` alone, which zeroes
        /// only the newly appended tail — stale data written by an earlier,
        /// smaller `with_scratch` call survived in the prefix, breaking the
        /// zeroed-buffer contract. Clearing first makes `resize` zero the
        /// entire region (the Vec's allocation/capacity is retained).
        fn ensure_capacity(&mut self, byte_count: usize) {
            if byte_count > self.capacity {
                self.bytes.clear();
                self.bytes.resize(byte_count, 0);
                self.capacity = byte_count;
            } else {
                // Reuse path: scrub the prefix we are about to hand out.
                self.bytes[..byte_count].fill(0);
            }
        }

        fn as_mut_ptr(&mut self) -> *mut u8 {
            self.bytes.as_mut_ptr()
        }
    }

    thread_local! {
        static SCRATCH: RefCell<RawScratch> = RefCell::new(RawScratch::new());
    }

    /// Runs `f` with a zeroed scratch slice of `n` complex values, reusing
    /// the thread-local allocation when it is large enough and suitably
    /// aligned; otherwise a fresh `Vec` is allocated for this call only.
    ///
    /// # Panics
    ///
    /// Panics if `n * size_of::<Complex<T>>()` overflows `usize`, or if `f`
    /// re-enters `with_scratch` or calls `scratch_capacity` (the thread-local
    /// buffer is exclusively borrowed for the duration of the call). Use
    /// [`with_scratch_nested`] for nested scratch buffers.
    pub fn with_scratch<T: Float, F, R>(n: usize, f: F) -> R
    where
        F: FnOnce(&mut [Complex<T>]) -> R,
    {
        if n == 0 {
            return f(&mut []);
        }
        // BUG FIX: guard against wrap-around. A wrapped byte count would
        // make the unsafe slice below cover memory the buffer does not own.
        let byte_count = n
            .checked_mul(core::mem::size_of::<Complex<T>>())
            .expect("scratch buffer size overflows usize");
        SCRATCH.with(|cell| {
            let mut raw = cell.borrow_mut();
            raw.ensure_capacity(byte_count);
            let ptr = raw.as_mut_ptr();
            let align = core::mem::align_of::<Complex<T>>();
            if !(ptr as usize).is_multiple_of(align) {
                // A Vec<u8> allocation is not guaranteed to satisfy the
                // alignment of Complex<T>; fall back to a fresh allocation.
                drop(raw);
                let mut fallback = vec![Complex::<T>::zero(); n];
                return f(&mut fallback);
            }
            // SAFETY: `ptr` points to at least `byte_count == n *
            // size_of::<Complex<T>>()` zeroed bytes (ensure_capacity), is
            // aligned for Complex<T> (checked above), and the RefMut guard
            // `raw` stays alive across the call, giving exclusive access.
            let slice = unsafe { core::slice::from_raw_parts_mut(ptr.cast::<Complex<T>>(), n) };
            f(slice)
        })
    }

    /// Runs `f` with a zeroed scratch slice of `n` complex values that is
    /// safe to use inside a `with_scratch` closure.
    ///
    /// Always allocates: the thread-local buffer may already be borrowed by
    /// an enclosing `with_scratch` call, so a fresh `Vec` is used instead.
    pub fn with_scratch_nested<T: Float, F, R>(n: usize, f: F) -> R
    where
        F: FnOnce(&mut [Complex<T>]) -> R,
    {
        let mut buf = vec![Complex::<T>::zero(); n];
        f(&mut buf)
    }

    /// Returns an owned, zero-filled scratch buffer of `n` complex values.
    ///
    /// Unlike `with_scratch`, the returned guard owns its allocation and
    /// does not touch the thread-local buffer.
    pub fn get_scratch<T: Float>(n: usize) -> ScratchGuard<T> {
        ScratchGuard {
            buf: vec![Complex::<T>::zero(); n],
        }
    }

    /// Reports the current thread-local capacity, measured in `Complex<T>`
    /// elements. Panics if called while `with_scratch` holds the buffer.
    pub fn scratch_capacity<T: Float>() -> usize {
        let elem_size = core::mem::size_of::<Complex<T>>();
        if elem_size == 0 {
            return 0;
        }
        SCRATCH.with(|cell| {
            let raw = cell.borrow();
            raw.capacity / elem_size
        })
    }
}
#[cfg(not(feature = "std"))]
mod fallback {
    use super::*;

    /// Runs `f` on a freshly allocated, zero-filled slice of `n` complex
    /// values. Without `std` there is no thread-local reuse, so every call
    /// allocates.
    pub fn with_scratch<T: Float, F, R>(n: usize, f: F) -> R
    where
        F: FnOnce(&mut [Complex<T>]) -> R,
    {
        let mut storage = vec![Complex::<T>::zero(); n];
        f(storage.as_mut_slice())
    }

    /// Identical to [`with_scratch`] here: allocation is per-call anyway,
    /// so nesting needs no special handling.
    pub fn with_scratch_nested<T: Float, F, R>(n: usize, f: F) -> R
    where
        F: FnOnce(&mut [Complex<T>]) -> R,
    {
        with_scratch(n, f)
    }

    /// Returns an owned, zero-filled scratch buffer of `n` complex values.
    pub fn get_scratch<T: Float>(n: usize) -> ScratchGuard<T> {
        let buf = vec![Complex::<T>::zero(); n];
        ScratchGuard { buf }
    }

    /// No buffer is retained between calls in the fallback implementation,
    /// so the cached capacity is always zero.
    pub fn scratch_capacity<T: Float>() -> usize {
        0
    }
}
#[cfg(feature = "std")]
pub use tls::{get_scratch, scratch_capacity, with_scratch, with_scratch_nested};
#[cfg(not(feature = "std"))]
pub use fallback::{get_scratch, scratch_capacity, with_scratch, with_scratch_nested};
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly requested buffer must come back zero-filled.
    #[test]
    fn test_with_scratch_basic() {
        with_scratch::<f64, _, _>(128, |scratch| {
            assert_eq!(scratch.len(), 128);
            for value in scratch.iter() {
                assert_eq!(value.re, 0.0);
                assert_eq!(value.im, 0.0);
            }
        });
    }

    // The cached capacity is a high-water mark: it grows on demand and
    // never shrinks when a smaller buffer is requested afterwards.
    #[test]
    fn test_scratch_grows_but_does_not_shrink() {
        with_scratch::<f64, _, _>(64, |scratch| {
            assert_eq!(scratch.len(), 64);
        });
        #[cfg(feature = "std")]
        {
            let cap1 = scratch_capacity::<f64>();
            assert!(cap1 >= 64);
            with_scratch::<f64, _, _>(256, |scratch| {
                assert_eq!(scratch.len(), 256);
            });
            let cap2 = scratch_capacity::<f64>();
            assert!(cap2 >= 256);
            with_scratch::<f64, _, _>(32, |scratch| {
                assert_eq!(scratch.len(), 32);
            });
            let cap3 = scratch_capacity::<f64>();
            assert!(cap3 >= 256, "capacity should not shrink: got {cap3}");
        }
    }

    // n == 0 is valid and yields an empty slice.
    #[test]
    fn test_scratch_zero_size() {
        with_scratch::<f64, _, _>(0, |scratch| {
            assert!(scratch.is_empty());
        });
    }

    // A nested scratch request must not panic, and must not clobber the
    // outer buffer's contents.
    #[test]
    fn test_scratch_nested_does_not_panic() {
        with_scratch::<f64, _, _>(64, |outer| {
            outer[0] = Complex::new(1.0, 2.0);
            with_scratch_nested::<f64, _, _>(32, |inner| {
                assert_eq!(inner.len(), 32);
                inner[0] = Complex::new(3.0, 4.0);
            });
            assert_eq!(outer[0].re, 1.0);
            assert_eq!(outer[0].im, 2.0);
        });
    }

    // The owned guard exposes length, slice views, and mutation.
    #[test]
    fn test_get_scratch_guard() {
        let mut guard = get_scratch::<f64>(512);
        assert_eq!(guard.len(), 512);
        assert!(!guard.is_empty());
        guard[0] = Complex::new(42.0, 0.0);
        assert_eq!(guard.as_slice()[0].re, 42.0);
        assert_eq!(guard.as_mut_slice()[0].re, 42.0);
    }

    // Same zero-fill contract for f32 as for f64.
    #[test]
    fn test_scratch_f32() {
        with_scratch::<f32, _, _>(256, |scratch| {
            assert_eq!(scratch.len(), 256);
            for value in scratch.iter() {
                assert_eq!(value.re, 0.0f32);
                assert_eq!(value.im, 0.0f32);
            }
        });
    }

    // Each thread gets its own independent scratch buffer; concurrent
    // writers must never observe each other's data.
    #[cfg(feature = "std")]
    #[test]
    fn test_scratch_across_threads() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let success_count = Arc::new(AtomicUsize::new(0));
        let num_threads = 4;
        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let counter = Arc::clone(&success_count);
                std::thread::spawn(move || {
                    with_scratch::<f64, _, _>(1024, |scratch| {
                        assert_eq!(scratch.len(), 1024);
                        for (idx, slot) in scratch.iter_mut().enumerate() {
                            slot.re = idx as f64;
                            slot.im = -(idx as f64);
                        }
                        for (idx, slot) in scratch.iter().enumerate() {
                            assert_eq!(slot.re, idx as f64);
                            assert_eq!(slot.im, -(idx as f64));
                        }
                        counter.fetch_add(1, Ordering::SeqCst);
                    });
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("thread panicked");
        }
        assert_eq!(success_count.load(Ordering::SeqCst), num_threads);
    }

    // The guard coerces to &[Complex<f64>] via Deref.
    #[test]
    fn test_scratch_guard_deref() {
        let guard = get_scratch::<f64>(16);
        let _slice: &[Complex<f64>] = &guard;
        assert_eq!(_slice.len(), 16);
    }

    // Repeated same-size requests exercise the buffer-reuse path.
    #[test]
    fn test_scratch_repeated_same_size() {
        for _ in 0..100 {
            with_scratch::<f64, _, _>(128, |scratch| {
                assert_eq!(scratch.len(), 128);
            });
        }
    }
}