// cjc_runtime/tensor_pool.rs
1//! Thread-local tensor buffer pool for reusing allocations.
2//!
3//! In tight loops (e.g., 50-layer NN forward pass), the same tensor sizes are
4//! allocated and freed every iteration. This pool caches freed buffers and
5//! returns them on the next allocation of the same size, avoiding repeated
6//! malloc/free cycles.
7//!
8//! # Determinism
9//!
10//! Pool reuse does NOT affect computed values — only memory addresses change.
11//! The buffer contents are always overwritten before use. Snap hashes
12//! are computed from data, not addresses, so they remain identical.
13//!
14//! # Usage
15//!
16//! ```ignore
17//! // Get a buffer (may be recycled or freshly allocated)
18//! let mut buf = tensor_pool::acquire(1000);
19//! // ... fill buf with data ...
//! // Return the buffer explicitly for reuse; buffers are NOT auto-returned on drop:
//! tensor_pool::recycle(buf);
21//! ```
22
23use std::cell::RefCell;
24
/// Maximum number of buffers cached per size class, so one hot size cannot
/// monopolize the pool's slots.
const MAX_CACHED_PER_SIZE: usize = 4;

/// Maximum total buffers in the pool (to prevent unbounded growth when many
/// distinct sizes pass through).
const MAX_TOTAL_CACHED: usize = 64;
30
/// Thread-local tensor buffer pool.
///
/// Caches freed `Vec<f64>` buffers so hot loops can reuse allocations instead
/// of paying a malloc/free per iteration.
struct TensorPool {
    /// Cached buffers in no particular order; lookups are linear scans keyed
    /// on each `Vec`'s capacity (`swap_remove` on acquire destroys any order).
    /// Bounded by `MAX_TOTAL_CACHED` overall and `MAX_CACHED_PER_SIZE` per size.
    buffers: Vec<Vec<f64>>,
}
37
38impl TensorPool {
39 fn new() -> Self {
40 TensorPool {
41 buffers: Vec::new(),
42 }
43 }
44
45 /// Acquire a buffer of at least `size` elements.
46 /// Returns a recycled buffer (cleared to zero) if one of matching size exists,
47 /// otherwise allocates a new one.
48 fn acquire(&mut self, size: usize) -> Vec<f64> {
49 // Look for an exact-size match first (most common case in loops).
50 if let Some(pos) = self.buffers.iter().position(|b| b.capacity() == size) {
51 let mut buf = self.buffers.swap_remove(pos);
52 buf.clear();
53 buf.resize(size, 0.0);
54 return buf;
55 }
56 // No match — allocate fresh.
57 vec![0.0f64; size]
58 }
59
60 /// Return a buffer to the pool for future reuse.
61 fn recycle(&mut self, buf: Vec<f64>) {
62 if self.buffers.len() >= MAX_TOTAL_CACHED {
63 return; // Pool is full, just drop the buffer.
64 }
65 let cap = buf.capacity();
66 // Don't cache if too many of the same size already.
67 let same_size_count = self.buffers.iter().filter(|b| b.capacity() == cap).count();
68 if same_size_count >= MAX_CACHED_PER_SIZE {
69 return;
70 }
71 self.buffers.push(buf);
72 }
73}
74
thread_local! {
    /// Per-thread pool instance. Thread-local storage avoids locking; each
    /// thread recycles and reuses only its own buffers.
    static POOL: RefCell<TensorPool> = RefCell::new(TensorPool::new());
}
78
79/// Acquire a zeroed buffer of `size` f64 elements from the thread-local pool.
80///
81/// If a buffer of the exact capacity is available in the pool, it's reused
82/// (avoiding malloc). Otherwise a new buffer is allocated.
83pub fn acquire(size: usize) -> Vec<f64> {
84 POOL.with(|pool| pool.borrow_mut().acquire(size))
85}
86
87/// Return a buffer to the thread-local pool for future reuse.
88///
89/// The buffer's contents are irrelevant — it will be cleared on next acquire.
90/// If the pool is full, the buffer is simply dropped.
91pub fn recycle(buf: Vec<f64>) {
92 POOL.with(|pool| pool.borrow_mut().recycle(buf));
93}
94
95/// Returns the current number of cached buffers in the pool (for diagnostics).
96#[allow(dead_code)]
97pub fn pool_size() -> usize {
98 POOL.with(|pool| pool.borrow().buffers.len())
99}
100
101// ── Tests ───────────────────────────────────────────────────────────────────
102
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_acquire_returns_correct_size() {
        // A fresh acquire hands back a fully zeroed buffer of the requested length.
        let buf = acquire(100);
        assert_eq!(buf.len(), 100);
        assert!(buf.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_recycle_and_reuse() {
        // Each test runs on its own thread, so the thread-local pool starts empty.
        let first = acquire(256);
        assert_eq!(pool_size(), 0);

        recycle(first);
        assert_eq!(pool_size(), 1);

        // The cached buffer is handed back out, leaving the pool empty again.
        let second = acquire(256);
        assert_eq!(second.len(), 256);
        assert_eq!(pool_size(), 0);
    }

    #[test]
    fn test_pool_max_per_size() {
        // Cycling one size repeatedly must not grow past the per-size cap.
        for _ in 0..10 {
            recycle(acquire(64));
        }
        assert!(pool_size() <= MAX_CACHED_PER_SIZE);
    }

    #[test]
    fn test_pool_total_limit() {
        // 100 distinct sizes exceed the total cap; the pool must stay bounded.
        for size in 1..=100 {
            recycle(acquire(size));
        }
        assert!(pool_size() <= MAX_TOTAL_CACHED);
    }

    #[test]
    fn test_acquired_buffer_is_zeroed() {
        let mut dirty = acquire(10);
        dirty.fill(42.0); // Dirty the buffer before returning it.
        recycle(dirty);

        let clean = acquire(10);
        assert!(clean.iter().all(|&v| v == 0.0), "Recycled buffer must be zeroed");
    }
}
157}