// cjc_runtime/tensor_pool.rs
//! Thread-local tensor buffer pool for reusing allocations.
//!
//! In tight loops (e.g., 50-layer NN forward pass), the same tensor sizes are
//! allocated and freed every iteration. This pool caches freed buffers and
//! returns them on the next allocation of the same size, avoiding repeated
//! malloc/free cycles.
//!
//! # Determinism
//!
//! Pool reuse does NOT affect computed values — only memory addresses change.
//! The buffer contents are always overwritten before use. Snap hashes
//! are computed from data, not addresses, so they remain identical.
//!
//! # Usage
//!
//! ```ignore
//! // Get a buffer (may be recycled or freshly allocated)
//! let mut buf = tensor_pool::acquire(1000);
//! // ... fill buf with data ...
//! // Return the buffer explicitly when done (there is no Drop hook):
//! tensor_pool::recycle(buf);
//! ```

use std::cell::RefCell;

/// Maximum number of buffers cached per size class.
const MAX_CACHED_PER_SIZE: usize = 4;

/// Maximum total buffers in the pool (to prevent unbounded growth).
const MAX_TOTAL_CACHED: usize = 64;

/// Thread-local tensor buffer pool.
struct TensorPool {
    /// Cached buffers, scanned linearly by exact capacity on acquire.
    /// Bounded by `MAX_TOTAL_CACHED`, so the linear scan stays cheap.
    buffers: Vec<Vec<f64>>,
}

impl TensorPool {
    /// Create an empty pool; buffers are only allocated on demand.
    fn new() -> Self {
        TensorPool {
            buffers: Vec::new(),
        }
    }

    /// Acquire a buffer of at least `size` elements.
    ///
    /// Returns a recycled buffer (cleared to zero) if one of matching capacity
    /// exists, otherwise allocates a new one. Either way the result has
    /// `len() == size` and every element equal to `0.0`.
    fn acquire(&mut self, size: usize) -> Vec<f64> {
        // Look for an exact-size match first (most common case in loops).
        if let Some(pos) = self.buffers.iter().position(|b| b.capacity() == size) {
            // swap_remove is O(1); pool order carries no meaning.
            let mut buf = self.buffers.swap_remove(pos);
            // clear + resize re-zeroes the contents without reallocating,
            // since the capacity already equals `size`.
            buf.clear();
            buf.resize(size, 0.0);
            return buf;
        }
        // No match — allocate fresh.
        vec![0.0f64; size]
    }

    /// Return a buffer to the pool for future reuse.
    ///
    /// The buffer is silently dropped when the pool is at its total capacity
    /// or already holds `MAX_CACHED_PER_SIZE` buffers of this exact capacity.
    fn recycle(&mut self, buf: Vec<f64>) {
        if self.buffers.len() >= MAX_TOTAL_CACHED {
            return; // Pool is full, just drop the buffer.
        }
        let cap = buf.capacity();
        // Don't cache if too many of the same size already.
        let same_size_count = self.buffers.iter().filter(|b| b.capacity() == cap).count();
        if same_size_count >= MAX_CACHED_PER_SIZE {
            return;
        }
        self.buffers.push(buf);
    }
}
74
75thread_local! {
76    static POOL: RefCell<TensorPool> = RefCell::new(TensorPool::new());
77}
78
79/// Acquire a zeroed buffer of `size` f64 elements from the thread-local pool.
80///
81/// If a buffer of the exact capacity is available in the pool, it's reused
82/// (avoiding malloc). Otherwise a new buffer is allocated.
83pub fn acquire(size: usize) -> Vec<f64> {
84    POOL.with(|pool| pool.borrow_mut().acquire(size))
85}
86
87/// Return a buffer to the thread-local pool for future reuse.
88///
89/// The buffer's contents are irrelevant — it will be cleared on next acquire.
90/// If the pool is full, the buffer is simply dropped.
91pub fn recycle(buf: Vec<f64>) {
92    POOL.with(|pool| pool.borrow_mut().recycle(buf));
93}
94
95/// Returns the current number of cached buffers in the pool (for diagnostics).
96#[allow(dead_code)]
97pub fn pool_size() -> usize {
98    POOL.with(|pool| pool.borrow().buffers.len())
99}
100
// ── Tests ───────────────────────────────────────────────────────────────────
102
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: each #[test] runs on its own thread, so every test starts with a
    // fresh (empty) thread-local pool.

    #[test]
    fn test_acquire_returns_correct_size() {
        let buf = acquire(100);
        assert_eq!(buf.len(), 100);
        assert!(buf.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_recycle_and_reuse() {
        let buf = acquire(256);
        assert_eq!(pool_size(), 0);

        recycle(buf);
        assert_eq!(pool_size(), 1);

        let buf2 = acquire(256);
        assert_eq!(buf2.len(), 256);
        assert_eq!(pool_size(), 0); // Was taken from pool
    }

    #[test]
    fn test_pool_max_per_size() {
        // Hold all buffers before recycling. Recycling one at a time would
        // let each acquire take the buffer straight back, keeping the pool at
        // a single entry and never exercising the per-size cap.
        let held: Vec<Vec<f64>> = (0..10).map(|_| acquire(64)).collect();
        for buf in held {
            recycle(buf);
        }
        // Must cap at exactly MAX_CACHED_PER_SIZE; the surplus is dropped.
        assert_eq!(pool_size(), MAX_CACHED_PER_SIZE);
    }

    #[test]
    fn test_pool_total_limit() {
        // 100 distinct sizes, so the per-size cap never triggers and only the
        // total cap bounds growth.
        for size in 0..100 {
            let buf = acquire(size + 1);
            recycle(buf);
        }
        assert!(pool_size() <= MAX_TOTAL_CACHED);
    }

    #[test]
    fn test_acquired_buffer_is_zeroed() {
        let mut buf = acquire(10);
        for x in buf.iter_mut() {
            *x = 42.0; // Dirty the buffer
        }
        recycle(buf);

        let buf2 = acquire(10);
        assert!(buf2.iter().all(|&x| x == 0.0), "Recycled buffer must be zeroed");
    }
}
157}