use nalgebra::DVector;
use num::complex::Complex64;
use parking_lot::Mutex;
use crate::thread_pool::ThreadExecutor;
#[cfg(not(feature = "rayon"))]
use crate::LadduError;
use crate::LadduResult;
/// Selects how an execution context schedules work.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadPolicy {
    /// Run everything on the calling thread; the only policy accepted when
    /// the `rayon` feature is disabled (see `ExecutionContext::new`).
    Single,
    /// Use the default executor (presumably rayon's global pool when the
    /// `rayon` feature is on — confirm against `ThreadExecutor::default`).
    GlobalPool,
    /// Use a dedicated executor with the given number of threads; requires
    /// the `rayon` feature.
    Dedicated(usize),
}
/// Reusable scratch buffers that grow on demand and are recycled across
/// evaluations to avoid per-call allocation.
#[derive(Debug, Default)]
pub struct ScratchAllocator {
    /// Raw byte workspace handed out by `reserve_bytes`.
    byte_scratch: Vec<u8>,
    /// `f64` workspace handed out by `reserve_scalars`.
    scalar_scratch: Vec<f64>,
    /// Shared complex buffer split into amplitude/slot workspaces by
    /// `reserve_value_workspaces` and `reserve_gradient_workspaces`.
    complex_scratch: Vec<Complex64>,
    /// Per-amplitude gradient vectors (one `DVector` per amplitude).
    gradient_event_scratch: Vec<DVector<Complex64>>,
    /// Per-expression-slot gradient vectors.
    gradient_expr_scratch: Vec<DVector<Complex64>>,
}
impl ScratchAllocator {
    /// Returns a mutable byte slice of exactly `len` bytes.
    ///
    /// The backing buffer only grows; bytes added by growth are zeroed, but
    /// bytes reused from an earlier call keep their previous contents, so
    /// callers must overwrite the slice before reading it.
    pub fn reserve_bytes(&mut self, len: usize) -> &mut [u8] {
        if self.byte_scratch.len() < len {
            self.byte_scratch.resize(len, 0);
        }
        &mut self.byte_scratch[..len]
    }

    /// Returns a mutable `f64` slice of exactly `len` elements.
    ///
    /// Same reuse semantics as [`Self::reserve_bytes`]: growth zero-fills,
    /// reused elements retain stale values.
    pub fn reserve_scalars(&mut self, len: usize) -> &mut [f64] {
        if self.scalar_scratch.len() < len {
            self.scalar_scratch.resize(len, 0.0);
        }
        &mut self.scalar_scratch[..len]
    }

    /// Returns two disjoint complex workspaces carved from one shared
    /// buffer: `amplitude_len` elements for per-amplitude values followed by
    /// `slot_count` elements for expression slots.
    ///
    /// Reused elements keep stale contents; callers must overwrite them.
    pub fn reserve_value_workspaces(
        &mut self,
        amplitude_len: usize,
        slot_count: usize,
    ) -> (&mut [Complex64], &mut [Complex64]) {
        Self::split_complex(&mut self.complex_scratch, amplitude_len, slot_count)
    }

    /// Like [`Self::reserve_value_workspaces`], but additionally returns
    /// per-amplitude and per-slot gradient vectors, each of dimension
    /// `grad_dim`.
    #[allow(clippy::type_complexity)]
    pub fn reserve_gradient_workspaces(
        &mut self,
        amplitude_len: usize,
        slot_count: usize,
        grad_dim: usize,
    ) -> (
        &mut [Complex64],
        &mut [Complex64],
        &mut [DVector<Complex64>],
        &mut [DVector<Complex64>],
    ) {
        Self::ensure_gradient_shape(&mut self.gradient_event_scratch, amplitude_len, grad_dim);
        Self::ensure_gradient_shape(&mut self.gradient_expr_scratch, slot_count, grad_dim);
        // `split_complex` is an associated fn borrowing only `complex_scratch`,
        // so the gradient fields below can still be borrowed independently.
        let (amplitudes, slots) =
            Self::split_complex(&mut self.complex_scratch, amplitude_len, slot_count);
        (
            amplitudes,
            slots,
            &mut self.gradient_event_scratch[..amplitude_len],
            &mut self.gradient_expr_scratch[..slot_count],
        )
    }

    /// Empties every scratch buffer while keeping the allocated capacity for
    /// later reuse.
    pub fn clear(&mut self) {
        self.byte_scratch.clear();
        self.scalar_scratch.clear();
        self.complex_scratch.clear();
        self.gradient_event_scratch.clear();
        self.gradient_expr_scratch.clear();
    }

    /// Current capacities of the byte and scalar buffers, in that order
    /// (diagnostic use only; the complex/gradient buffers are not reported).
    pub fn capacities(&self) -> (usize, usize) {
        (self.byte_scratch.capacity(), self.scalar_scratch.capacity())
    }

    /// Grows `buffer` to at least `amplitude_len + slot_count` elements
    /// (zero-filling new slots) and splits that prefix into an amplitude
    /// part and a slot part. Shared by both `reserve_*_workspaces` methods.
    fn split_complex(
        buffer: &mut Vec<Complex64>,
        amplitude_len: usize,
        slot_count: usize,
    ) -> (&mut [Complex64], &mut [Complex64]) {
        let total = amplitude_len + slot_count;
        if buffer.len() < total {
            buffer.resize(total, Complex64::ZERO);
        }
        buffer[..total].split_at_mut(amplitude_len)
    }

    /// Resizes `buffer` to exactly `outer_len` gradient vectors of dimension
    /// `grad_dim`, reallocating only the vectors whose dimension changed.
    fn ensure_gradient_shape(
        buffer: &mut Vec<DVector<Complex64>>,
        outer_len: usize,
        grad_dim: usize,
    ) {
        // `resize_with` handles both growth (new zero vectors) and
        // truncation in a single call, replacing the manual branch pair.
        buffer.resize_with(outer_len, || DVector::zeros(grad_dim));
        for gradient in buffer.iter_mut() {
            if gradient.len() != grad_dim {
                *gradient = DVector::zeros(grad_dim);
            }
        }
    }
}
/// Bundles a thread policy, the executor implementing it, and lock-guarded
/// scratch storage shared across evaluations.
#[derive(Debug)]
pub struct ExecutionContext {
    /// Policy requested at construction time (returned by `thread_policy`).
    thread_policy: ThreadPolicy,
    /// Executor that actually runs closures passed to `install`.
    executor: ThreadExecutor,
    /// Mutex so `&self` methods can hand out `&mut ScratchAllocator`.
    scratch: Mutex<ScratchAllocator>,
}
impl ExecutionContext {
    /// Builds a context for the given thread policy.
    ///
    /// # Errors
    ///
    /// Without the `rayon` feature, any policy other than
    /// [`ThreadPolicy::Single`] returns `LadduError::ExecutionContextError`.
    /// With the feature enabled, constructing a dedicated pool may also fail
    /// (propagated via `?`).
    pub fn new(thread_policy: ThreadPolicy) -> LadduResult<Self> {
        #[cfg(not(feature = "rayon"))]
        {
            // Parallel policies cannot be honored without rayon; fail early
            // so the match below never sees them in this configuration.
            if thread_policy != ThreadPolicy::Single {
                return Err(LadduError::ExecutionContextError {
                    reason: "Rayon feature is required for non-single thread policies".into(),
                });
            }
        }
        let executor = match thread_policy {
            ThreadPolicy::Single | ThreadPolicy::GlobalPool => ThreadExecutor::default(),
            // `n_threads` is unused in the non-rayon build, hence the allow.
            #[allow(unused_variables)]
            ThreadPolicy::Dedicated(n_threads) => {
                #[cfg(feature = "rayon")]
                {
                    ThreadExecutor::dedicated(n_threads)?
                }
                #[cfg(not(feature = "rayon"))]
                {
                    // Guarded by the early return above.
                    unreachable!("non-single thread policies are rejected above")
                }
            }
        };
        Ok(Self {
            thread_policy,
            executor,
            scratch: Mutex::new(ScratchAllocator::default()),
        })
    }
    /// Returns the policy this context was constructed with.
    pub fn thread_policy(&self) -> ThreadPolicy {
        self.thread_policy
    }
    /// Runs `op` via the underlying executor. The `Send` bounds allow the
    /// closure and its result to cross threads in the rayon build.
    #[cfg(feature = "rayon")]
    pub fn install<R: Send>(&self, op: impl FnOnce() -> R + Send) -> R {
        self.executor.install(op)
    }
    /// Runs `op` via the underlying executor (non-rayon build; no `Send`
    /// bound is required here).
    #[cfg(not(feature = "rayon"))]
    pub fn install<R>(&self, op: impl FnOnce() -> R) -> R {
        self.executor.install(op)
    }
    /// Locks the shared scratch allocator and passes it to `op`; the lock is
    /// released when `op` returns, so `op` should not block for long.
    pub fn with_scratch<R>(&self, op: impl FnOnce(&mut ScratchAllocator) -> R) -> R {
        let mut scratch = self.scratch.lock();
        op(&mut scratch)
    }
}