use crate::shape::Shape;
#[cfg(feature = "std")]
use std::cell::RefCell;
#[cfg(feature = "std")]
use std::sync::Mutex;
#[cfg(not(feature = "std"))]
use core::cell::RefCell;
pub const MAX_STACK_DIMS: usize = 8;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StackShape<const N: usize> {
pub dims: [usize; N],
pub ndim: usize,
}
impl<const N: usize> StackShape<N> {
#[inline]
pub const fn new(dims: [usize; N]) -> Self {
Self { dims, ndim: N }
}
#[inline]
pub fn from_slice(dims: &[usize]) -> Option<Self> {
if dims.len() > N {
return None;
}
let mut stack_dims = [0; N];
let mut i = 0;
while i < dims.len() {
stack_dims[i] = dims[i];
i += 1;
}
Some(Self {
dims: stack_dims,
ndim: dims.len(),
})
}
#[inline]
pub fn as_slice(&self) -> &[usize] {
&self.dims[..self.ndim]
}
#[inline]
pub const fn numel(&self) -> usize {
let mut product = 1;
let mut i = 0;
while i < self.ndim {
product *= self.dims[i];
i += 1;
}
product
}
#[inline]
pub fn to_shape(&self) -> Shape {
Shape::new(self.as_slice().to_vec())
}
#[inline]
pub fn broadcast_compatible<const M: usize>(&self, other: &StackShape<M>) -> bool {
let max_ndim = self.ndim.max(other.ndim);
for i in 0..max_ndim {
let dim1 = if i < self.ndim {
self.dims[self.ndim - 1 - i]
} else {
1
};
let dim2 = if i < other.ndim {
other.dims[other.ndim - 1 - i]
} else {
1
};
if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
return false;
}
}
true
}
}
#[derive(Debug, Clone)]
pub enum CowShape {
Borrowed(&'static [usize]),
Owned(Shape),
}
impl CowShape {
#[inline]
pub const fn from_static(dims: &'static [usize]) -> Self {
CowShape::Borrowed(dims)
}
#[inline]
pub fn from_owned(shape: Shape) -> Self {
CowShape::Owned(shape)
}
#[inline]
pub fn as_slice(&self) -> &[usize] {
match self {
CowShape::Borrowed(dims) => dims,
CowShape::Owned(shape) => shape.dims(),
}
}
#[inline]
pub fn into_owned(self) -> Shape {
match self {
CowShape::Borrowed(dims) => Shape::new(dims.to_vec()),
CowShape::Owned(shape) => shape,
}
}
#[inline]
pub fn numel(&self) -> usize {
self.as_slice().iter().product()
}
}
#[derive(Debug, Default, Clone, Copy)]
pub struct AllocationStats {
pub total_allocations: u64,
pub total_bytes: u64,
pub avoidable_allocations: u64,
pub avoidable_bytes: u64,
pub small_allocations: u64,
pub medium_allocations: u64,
pub large_allocations: u64,
}
impl AllocationStats {
#[inline]
pub fn record_allocation(&mut self, bytes: usize, avoidable: bool) {
self.total_allocations += 1;
self.total_bytes += bytes as u64;
if avoidable {
self.avoidable_allocations += 1;
self.avoidable_bytes += bytes as u64;
}
if bytes < 64 {
self.small_allocations += 1;
} else if bytes < 1024 {
self.medium_allocations += 1;
} else {
self.large_allocations += 1;
}
}
pub fn waste_percentage(&self) -> f64 {
if self.total_bytes == 0 {
0.0
} else {
(self.avoidable_bytes as f64 / self.total_bytes as f64) * 100.0
}
}
pub fn report(&self) -> String {
format!(
"Allocation Statistics:\n\
Total: {} allocations, {} bytes\n\
Avoidable: {} allocations, {} bytes ({:.1}% waste)\n\
Size distribution: {} small, {} medium, {} large",
self.total_allocations,
self.total_bytes,
self.avoidable_allocations,
self.avoidable_bytes,
self.waste_percentage(),
self.small_allocations,
self.medium_allocations,
self.large_allocations
)
}
}
#[cfg(feature = "std")]
thread_local! {
static ALLOC_STATS: RefCell<AllocationStats> = RefCell::new(AllocationStats::default());
}
#[cfg(feature = "std")]
#[inline]
pub fn track_allocation(bytes: usize, avoidable: bool) {
ALLOC_STATS.with(|stats| {
stats.borrow_mut().record_allocation(bytes, avoidable);
});
}
#[cfg(feature = "std")]
pub fn get_allocation_stats() -> AllocationStats {
ALLOC_STATS.with(|stats| *stats.borrow())
}
#[cfg(feature = "std")]
pub fn reset_allocation_stats() {
ALLOC_STATS.with(|stats| {
*stats.borrow_mut() = AllocationStats::default();
});
}
#[cfg(feature = "std")]
pub struct BufferPool<T> {
buffers: Mutex<Vec<Vec<T>>>,
max_pool_size: usize,
buffer_capacity: usize,
}
#[cfg(feature = "std")]
impl<T: Clone + Default> BufferPool<T> {
pub fn new(buffer_capacity: usize, max_pool_size: usize) -> Self {
Self {
buffers: Mutex::new(Vec::new()),
max_pool_size,
buffer_capacity,
}
}
pub fn acquire(&self) -> Vec<T> {
let mut buffers = self.buffers.lock().expect("lock should not be poisoned");
buffers
.pop()
.unwrap_or_else(|| Vec::with_capacity(self.buffer_capacity))
}
pub fn release(&self, mut buffer: Vec<T>) {
buffer.clear();
let mut buffers = self.buffers.lock().expect("lock should not be poisoned");
if buffers.len() < self.max_pool_size {
buffers.push(buffer);
}
}
pub fn stats(&self) -> (usize, usize) {
let buffers = self.buffers.lock().expect("lock should not be poisoned");
(buffers.len(), self.max_pool_size)
}
}
#[cfg(feature = "std")]
static SHAPE_BUFFER_POOL: once_cell::sync::Lazy<BufferPool<usize>> =
once_cell::sync::Lazy::new(|| BufferPool::new(8, 100));
#[cfg(feature = "std")]
#[inline]
pub fn acquire_shape_buffer() -> Vec<usize> {
SHAPE_BUFFER_POOL.acquire()
}
#[cfg(feature = "std")]
#[inline]
pub fn release_shape_buffer(buffer: Vec<usize>) {
SHAPE_BUFFER_POOL.release(buffer);
}
#[cfg(feature = "std")]
pub struct ScopedBuffer<T: Clone + Default + 'static> {
buffer: Option<Vec<T>>,
pool: &'static BufferPool<T>,
}
#[cfg(feature = "std")]
impl<T: Clone + Default + 'static> ScopedBuffer<T> {
pub fn new(pool: &'static BufferPool<T>) -> Self {
Self {
buffer: Some(pool.acquire()),
pool,
}
}
pub fn get_mut(&mut self) -> &mut Vec<T> {
self.buffer
.as_mut()
.expect("buffer should be present before drop")
}
pub fn get(&self) -> &Vec<T> {
self.buffer
.as_ref()
.expect("buffer should be present before drop")
}
}
#[cfg(feature = "std")]
impl<T: Clone + Default + 'static> Drop for ScopedBuffer<T> {
fn drop(&mut self) {
if let Some(buffer) = self.buffer.take() {
self.pool.release(buffer);
}
}
}
#[derive(Debug, Clone)]
pub struct OptimizationRecommendations {
pub use_stack_shapes: bool,
pub use_buffer_pools: bool,
pub use_cow_shapes: bool,
pub estimated_speedup: f64,
pub estimated_memory_savings: u64,
}
impl OptimizationRecommendations {
pub fn from_stats(stats: &AllocationStats) -> Self {
let use_stack_shapes = stats.small_allocations > stats.total_allocations / 2;
let use_buffer_pools = stats.avoidable_allocations > stats.total_allocations / 3;
let use_cow_shapes = stats.total_allocations > 100;
let mut estimated_speedup = 1.0;
if use_stack_shapes {
estimated_speedup *= 2.0;
}
if use_buffer_pools {
estimated_speedup *= 1.5;
}
if use_cow_shapes {
estimated_speedup *= 1.2;
}
Self {
use_stack_shapes,
use_buffer_pools,
use_cow_shapes,
estimated_speedup,
estimated_memory_savings: stats.avoidable_bytes,
}
}
pub fn report(&self) -> String {
let mut recommendations = Vec::new();
if self.use_stack_shapes {
recommendations.push("Use StackShape for operations with ≤8 dimensions");
}
if self.use_buffer_pools {
recommendations.push("Use buffer pools for temporary allocations");
}
if self.use_cow_shapes {
recommendations.push("Use CowShape for borrowed/static shapes");
}
format!(
"Optimization Recommendations:\n\
{}\n\
Estimated speedup: {:.1}x\n\
Estimated memory savings: {} bytes",
recommendations.join("\n"),
self.estimated_speedup,
self.estimated_memory_savings
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_stack_shape_creation() {
let shape = StackShape::<4>::new([2, 3, 4, 5]);
assert_eq!(shape.ndim, 4);
assert_eq!(shape.as_slice(), &[2, 3, 4, 5]);
assert_eq!(shape.numel(), 120);
}
#[test]
fn test_stack_shape_from_slice() {
let dims = vec![2, 3, 4];
let shape = StackShape::<8>::from_slice(&dims).expect("from_slice should succeed");
assert_eq!(shape.ndim, 3);
assert_eq!(shape.as_slice(), &[2, 3, 4]);
}
#[test]
fn test_stack_shape_broadcast_compatible() {
let shape1 = StackShape::<4>::new([3, 1, 4, 1]);
let shape2 = StackShape::<3>::from_slice(&[2, 4, 5]).expect("from_slice should succeed");
assert!(shape1.broadcast_compatible(&shape2));
let shape3 = StackShape::<3>::from_slice(&[1, 4, 5]).expect("from_slice should succeed");
assert!(shape1.broadcast_compatible(&shape3));
let shape4 = StackShape::<3>::from_slice(&[2, 3, 4]).expect("from_slice should succeed");
let shape5 = StackShape::<3>::from_slice(&[2, 5, 4]).expect("from_slice should succeed");
assert!(!shape4.broadcast_compatible(&shape5)); }
#[test]
fn test_cow_shape_borrowed() {
static DIMS: [usize; 3] = [2, 3, 4];
let cow = CowShape::from_static(&DIMS);
assert_eq!(cow.as_slice(), &[2, 3, 4]);
assert_eq!(cow.numel(), 24);
}
#[test]
fn test_cow_shape_owned() {
let shape = Shape::new(vec![2, 3, 4]);
let cow = CowShape::from_owned(shape);
assert_eq!(cow.as_slice(), &[2, 3, 4]);
}
#[test]
fn test_allocation_stats() {
let mut stats = AllocationStats::default();
stats.record_allocation(32, true); stats.record_allocation(128, false); stats.record_allocation(2048, true);
assert_eq!(stats.total_allocations, 3);
assert_eq!(stats.avoidable_allocations, 2);
assert_eq!(stats.small_allocations, 1);
assert_eq!(stats.medium_allocations, 1);
assert_eq!(stats.large_allocations, 1);
let waste = stats.waste_percentage();
assert!(waste > 90.0); }
#[test]
#[cfg(feature = "std")]
fn test_buffer_pool() {
let pool = BufferPool::<usize>::new(10, 5);
let mut buffer1 = pool.acquire();
buffer1.extend_from_slice(&[1, 2, 3]);
pool.release(buffer1);
let (available, max) = pool.stats();
assert_eq!(available, 1);
assert_eq!(max, 5);
let buffer2 = pool.acquire();
assert!(buffer2.is_empty());
let (available, _) = pool.stats();
assert_eq!(available, 0);
pool.release(buffer2);
let (available, _) = pool.stats();
assert_eq!(available, 1);
}
#[test]
#[cfg(feature = "std")]
fn test_scoped_buffer() {
static POOL: once_cell::sync::Lazy<BufferPool<usize>> =
once_cell::sync::Lazy::new(|| BufferPool::new(10, 5));
{
let mut scoped = ScopedBuffer::new(&*POOL);
scoped.get_mut().push(42);
assert_eq!(scoped.get()[0], 42);
}
let (available, _) = POOL.stats();
assert_eq!(available, 1);
}
#[test]
fn test_optimization_recommendations() {
let mut stats = AllocationStats::default();
for _ in 0..100 {
stats.record_allocation(32, true);
}
let recommendations = OptimizationRecommendations::from_stats(&stats);
assert!(recommendations.use_stack_shapes);
assert!(recommendations.use_buffer_pools);
assert!(recommendations.estimated_speedup > 1.5);
}
#[test]
#[cfg(feature = "std")]
fn test_global_shape_buffer_pool() {
let mut buffer = acquire_shape_buffer();
buffer.extend_from_slice(&[1, 2, 3, 4]);
assert_eq!(buffer.len(), 4);
release_shape_buffer(buffer);
let buffer2 = acquire_shape_buffer();
assert_eq!(buffer2.len(), 0);
release_shape_buffer(buffer2);
}
#[test]
#[cfg(feature = "std")]
fn test_allocation_tracking() {
reset_allocation_stats();
track_allocation(64, false);
track_allocation(128, true);
let stats = get_allocation_stats();
assert_eq!(stats.total_allocations, 2);
assert_eq!(stats.avoidable_allocations, 1);
}
}