use std::cell::RefCell;
use crate::vector3::Vector3;
/// A small free-list of reusable `Vec<T>` buffers.
///
/// Buffers handed back with [`VectorPool::put`] are retained (up to `max_cached` of them) and
/// recycled by later [`VectorPool::get`] calls whose requested length fits in a cached buffer's
/// capacity, avoiding a fresh heap allocation.
pub struct VectorPool<T> {
    /// Cached buffers available for reuse; their contents are irrelevant and are reset by `get`.
    pool: Vec<Vec<T>>,
    /// Maximum number of buffers `put` will retain.
    max_cached: usize,
}

impl<T> VectorPool<T> {
    /// Creates an empty pool that retains at most 16 buffers.
    pub fn new() -> Self {
        Self {
            pool: Vec::new(),
            max_cached: 16,
        }
    }

    /// Creates an empty pool that retains at most `max_cached` buffers.
    ///
    /// Note that `max_cached` limits the *number of buffers* kept, not their element capacity.
    /// A value of 0 disables caching entirely.
    pub fn with_capacity(max_cached: usize) -> Self {
        Self {
            pool: Vec::with_capacity(max_cached),
            max_cached,
        }
    }

    /// Returns a buffer to the pool for later reuse.
    ///
    /// The buffer is dropped instead when the pool already holds `max_cached` buffers, or when
    /// it has capacity 0 (there is no allocation worth keeping).
    pub fn put(&mut self, vec: Vec<T>) {
        if self.pool.len() < self.max_cached && vec.capacity() > 0 {
            self.pool.push(vec);
        }
    }

    /// Drops every cached buffer, releasing its memory.
    pub fn clear(&mut self) {
        self.pool.clear();
    }

    /// Number of buffers currently cached.
    pub fn cached_count(&self) -> usize {
        self.pool.len()
    }

    /// Sum of the capacities (in elements) of all cached buffers.
    pub fn total_capacity(&self) -> usize {
        self.pool.iter().map(|v| v.capacity()).sum()
    }
}

impl<T: Default + Clone> VectorPool<T> {
    /// Returns a vector of length `size` whose elements are all `T::default()`.
    ///
    /// Reuses the cached buffer with the *smallest* capacity that can hold `size` elements
    /// (best fit). Large buffers therefore stay available for large requests instead of being
    /// consumed by small ones. Allocates a fresh vector when no cached buffer is big enough.
    pub fn get(&mut self, size: usize) -> Vec<T> {
        let best_fit = self
            .pool
            .iter()
            .enumerate()
            .filter(|(_, v)| v.capacity() >= size)
            .min_by_key(|(_, v)| v.capacity())
            .map(|(idx, _)| idx);

        match best_fit {
            Some(idx) => {
                // Order of the free-list is irrelevant, so O(1) removal is fine.
                let mut vec = self.pool.swap_remove(idx);
                vec.clear();
                vec.resize(size, T::default());
                vec
            }
            None => vec![T::default(); size],
        }
    }
}

impl<T> Default for VectorPool<T> {
    fn default() -> Self {
        Self::new()
    }
}
/// A free-list of reusable spin buffers (`Vec<Vector3<f64>>`).
///
/// Buffers handed back with [`SpinArrayPool::put`] are retained (up to `max_cached` of them)
/// and recycled by later [`SpinArrayPool::get`] calls, which always return buffers filled with
/// the zero vector.
pub struct SpinArrayPool {
    /// Cached buffers available for reuse; their contents are reset by `get`.
    pool: Vec<Vec<Vector3<f64>>>,
    /// Maximum number of buffers `put` will retain.
    max_cached: usize,
}

impl SpinArrayPool {
    /// Creates an empty pool that retains at most 8 buffers.
    pub fn new() -> Self {
        Self {
            pool: Vec::new(),
            max_cached: 8,
        }
    }

    /// Creates an empty pool that retains at most `max_cached` buffers.
    ///
    /// `max_cached` limits the *number of buffers* kept, not their element capacity; 0
    /// disables caching entirely.
    pub fn with_capacity(max_cached: usize) -> Self {
        Self {
            pool: Vec::with_capacity(max_cached),
            max_cached,
        }
    }

    /// Returns a buffer of length `size` filled with `Vector3::new(0.0, 0.0, 0.0)`.
    ///
    /// Reuses the cached buffer with the *smallest* capacity that can hold `size` elements
    /// (best fit), so large buffers stay available for large requests. Allocates a fresh
    /// buffer when none of the cached ones is big enough.
    pub fn get(&mut self, size: usize) -> Vec<Vector3<f64>> {
        let zero = Vector3::new(0.0, 0.0, 0.0);
        let best_fit = self
            .pool
            .iter()
            .enumerate()
            .filter(|(_, v)| v.capacity() >= size)
            .min_by_key(|(_, v)| v.capacity())
            .map(|(idx, _)| idx);

        match best_fit {
            Some(idx) => {
                // Order of the free-list is irrelevant, so O(1) removal is fine.
                let mut vec = self.pool.swap_remove(idx);
                vec.clear();
                vec.resize(size, zero);
                vec
            }
            None => vec![zero; size],
        }
    }

    /// Returns a buffer to the pool for later reuse.
    ///
    /// The buffer is dropped instead when the pool is full or when it has capacity 0.
    pub fn put(&mut self, vec: Vec<Vector3<f64>>) {
        if self.pool.len() < self.max_cached && vec.capacity() > 0 {
            self.pool.push(vec);
        }
    }

    /// Drops every cached buffer, releasing its memory.
    pub fn clear(&mut self) {
        self.pool.clear();
    }

    /// Number of buffers currently cached.
    pub fn cached_count(&self) -> usize {
        self.pool.len()
    }

    /// Sum of the capacities (in elements) of all cached buffers.
    pub fn total_capacity(&self) -> usize {
        self.pool.iter().map(|v| v.capacity()).sum()
    }
}

impl Default for SpinArrayPool {
    fn default() -> Self {
        Self::new()
    }
}
thread_local! {
    /// This thread's pool of `f64` buffers, used by `get_f64_vec` / `put_f64_vec`.
    static F64_POOL: RefCell<VectorPool<f64>> = RefCell::new(VectorPool::new());
    /// This thread's pool of spin buffers, used by `get_spin_array` / `put_spin_array`.
    static SPIN_POOL: RefCell<SpinArrayPool> = RefCell::new(SpinArrayPool::new());
}

/// Obtains a `Vec<f64>` of length `size` from the current thread's `F64_POOL`
/// (see `VectorPool::get`).
pub fn get_f64_vec(size: usize) -> Vec<f64> {
    F64_POOL.with(|cell| {
        let mut pool = cell.borrow_mut();
        pool.get(size)
    })
}

/// Hands `vec` back to the current thread's `F64_POOL` (see `VectorPool::put`).
pub fn put_f64_vec(vec: Vec<f64>) {
    F64_POOL.with(|cell| {
        let mut pool = cell.borrow_mut();
        pool.put(vec);
    });
}

/// Obtains a spin buffer of length `size` from the current thread's `SPIN_POOL`
/// (see `SpinArrayPool::get`).
pub fn get_spin_array(size: usize) -> Vec<Vector3<f64>> {
    SPIN_POOL.with(|cell| {
        let mut pool = cell.borrow_mut();
        pool.get(size)
    })
}

/// Hands `vec` back to the current thread's `SPIN_POOL` (see `SpinArrayPool::put`).
pub fn put_spin_array(vec: Vec<Vector3<f64>>) {
    SPIN_POOL.with(|cell| {
        let mut pool = cell.borrow_mut();
        pool.put(vec);
    });
}
/// Scratch buffers for an RK4-style step: four stage buffers `k1`..`k4` plus a `temp` buffer.
///
/// All five buffers are created with, and resized to, the same length.
pub struct Rk4Workspace {
    pub k1: Vec<Vector3<f64>>,
    pub k2: Vec<Vector3<f64>>,
    pub k3: Vec<Vector3<f64>>,
    pub k4: Vec<Vector3<f64>>,
    pub temp: Vec<Vector3<f64>>,
}

impl Rk4Workspace {
    /// Allocates all five buffers, each holding `size` zero vectors.
    pub fn new(size: usize) -> Self {
        let zero = Vector3::new(0.0, 0.0, 0.0);
        Self {
            k1: vec![zero; size],
            k2: vec![zero; size],
            k3: vec![zero; size],
            k4: vec![zero; size],
            temp: vec![zero; size],
        }
    }

    /// Resizes every buffer to `size`.
    ///
    /// Existing leading elements are kept; any newly added elements are the zero vector.
    pub fn resize(&mut self, size: usize) {
        let zero = Vector3::new(0.0, 0.0, 0.0);
        for buf in self.buffers_mut() {
            buf.resize(size, zero);
        }
    }

    /// Overwrites every element of every buffer with the zero vector, keeping the lengths.
    pub fn clear(&mut self) {
        let zero = Vector3::new(0.0, 0.0, 0.0);
        for buf in self.buffers_mut() {
            buf.fill(zero);
        }
    }

    /// Mutable handles to all buffers, so bulk operations cannot miss one.
    fn buffers_mut(&mut self) -> [&mut Vec<Vector3<f64>>; 5] {
        [
            &mut self.k1,
            &mut self.k2,
            &mut self.k3,
            &mut self.k4,
            &mut self.temp,
        ]
    }
}
/// Scratch buffers `predictor`, `corrector` and `noise` for a Heun-style step.
///
/// All three buffers are created with, and resized to, the same length.
pub struct HeunWorkspace {
    pub predictor: Vec<Vector3<f64>>,
    pub corrector: Vec<Vector3<f64>>,
    pub noise: Vec<Vector3<f64>>,
}

impl HeunWorkspace {
    /// Allocates all three buffers, each holding `size` zero vectors.
    pub fn new(size: usize) -> Self {
        let zero = Vector3::new(0.0, 0.0, 0.0);
        Self {
            predictor: vec![zero; size],
            corrector: vec![zero; size],
            noise: vec![zero; size],
        }
    }

    /// Resizes every buffer to `size`.
    ///
    /// Existing leading elements are kept; any newly added elements are the zero vector.
    pub fn resize(&mut self, size: usize) {
        let zero = Vector3::new(0.0, 0.0, 0.0);
        for buf in self.buffers_mut() {
            buf.resize(size, zero);
        }
    }

    /// Overwrites every element of every buffer with the zero vector, keeping the lengths.
    pub fn clear(&mut self) {
        let zero = Vector3::new(0.0, 0.0, 0.0);
        for buf in self.buffers_mut() {
            buf.fill(zero);
        }
    }

    /// Mutable handles to all buffers, so bulk operations cannot miss one.
    fn buffers_mut(&mut self) -> [&mut Vec<Vector3<f64>>; 3] {
        [&mut self.predictor, &mut self.corrector, &mut self.noise]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh buffer is zero-filled; a returned buffer is cached and then handed out again.
    #[test]
    fn test_vector_pool_basic() {
        let mut pool = VectorPool::<f64>::new();

        let fresh = pool.get(100);
        assert_eq!(fresh.len(), 100);
        assert!(fresh.iter().all(|&x| x == 0.0));

        // Returning the buffer caches it...
        pool.put(fresh);
        assert_eq!(pool.cached_count(), 1);

        // ...and a smaller request takes it back out.
        let recycled = pool.get(50);
        assert_eq!(recycled.len(), 50);
        assert_eq!(pool.cached_count(), 0);
    }

    /// The spin pool hands out buffers of the requested length, including after reuse.
    #[test]
    fn test_spin_array_pool() {
        let mut pool = SpinArrayPool::new();

        let first = pool.get(100);
        assert_eq!(first.len(), 100);
        pool.put(first);

        let second = pool.get(50);
        assert_eq!(second.len(), 50);
    }

    /// Round-trips one buffer through each thread-local pool.
    #[test]
    fn test_thread_local_pool() {
        let scalars = get_f64_vec(100);
        assert_eq!(scalars.len(), 100);
        put_f64_vec(scalars);

        let spins = get_spin_array(100);
        assert_eq!(spins.len(), 100);
        put_spin_array(spins);
    }

    /// Every RK4 buffer is allocated with the requested length.
    #[test]
    fn test_rk4_workspace() {
        let ws = Rk4Workspace::new(100);
        let stages = [&ws.k1, &ws.k2, &ws.k3, &ws.k4, &ws.temp];
        for stage in stages.iter() {
            assert_eq!(stage.len(), 100);
        }
    }

    /// Every Heun buffer is allocated with the requested length.
    #[test]
    fn test_heun_workspace() {
        let ws = HeunWorkspace::new(100);
        let buffers = [&ws.predictor, &ws.corrector, &ws.noise];
        for buf in buffers.iter() {
            assert_eq!(buf.len(), 100);
        }
    }
}