// nntp_proxy/pool/buffer.rs

1use crate::types::BufferSize;
2use crossbeam::queue::SegQueue;
3use std::ops::Deref;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use tracing::info;
7
/// A pooled buffer that automatically returns to the pool when dropped
///
/// # Safety and Uninitialized Memory
///
/// Buffers contain **uninitialized memory** for performance (no zeroing overhead).
/// The type tracks initialized bytes and only exposes that portion, preventing UB.
///
/// ## Usage
/// ```ignore
/// let mut buffer = pool.get_buffer().await;
/// let n = buffer.read_from(&mut stream).await?;  // Automatic tracking
/// process(&*buffer);  // Deref returns only &buffer[..n]
/// ```
pub struct PooledBuffer {
    // Backing storage; its full length (`buffer.len()`) is the usable capacity.
    buffer: Vec<u8>,
    // Count of leading bytes holding valid data; Deref/AsRef expose only this prefix.
    initialized: usize,
    // Shared queue the backing Vec is pushed back onto when this wrapper drops.
    pool: Arc<SegQueue<Vec<u8>>>,
    // Approximate count of idle buffers currently sitting in `pool`.
    pool_size: Arc<AtomicUsize>,
    // Cap on pooled buffers; once reached, dropped buffers are freed instead of pooled.
    max_pool_size: usize,
}
28
29impl PooledBuffer {
30    /// Get the full capacity of the buffer
31    #[inline]
32    pub fn capacity(&self) -> usize {
33        self.buffer.len()
34    }
35
36    /// Get the number of initialized bytes
37    #[inline]
38    pub fn initialized(&self) -> usize {
39        self.initialized
40    }
41
42    /// Read from an AsyncRead source, automatically tracking initialized bytes
43    pub async fn read_from<R>(&mut self, reader: &mut R) -> std::io::Result<usize>
44    where
45        R: tokio::io::AsyncReadExt + Unpin,
46    {
47        let n = reader.read(&mut self.buffer[..]).await?;
48        self.initialized = n;
49        Ok(n)
50    }
51
52    /// Copy data into buffer and mark as initialized
53    ///
54    /// # Panics
55    /// Panics if data.len() > capacity
56    #[inline]
57    pub fn copy_from_slice(&mut self, data: &[u8]) {
58        assert!(
59            data.len() <= self.buffer.len(),
60            "data exceeds buffer capacity"
61        );
62        self.buffer[..data.len()].copy_from_slice(data);
63        self.initialized = data.len();
64    }
65}
66
67impl Deref for PooledBuffer {
68    type Target = [u8];
69
70    #[inline]
71    fn deref(&self) -> &Self::Target {
72        // Immutable access only returns the initialized portion
73        &self.buffer[..self.initialized]
74    }
75}
76
77// Intentionally NO DerefMut or AsMut - forces explicit use of read_from() or as_mut_slice()
78
79impl AsRef<[u8]> for PooledBuffer {
80    #[inline]
81    fn as_ref(&self) -> &[u8] {
82        // Only return initialized portion
83        &self.buffer[..self.initialized]
84    }
85}
86
87impl Drop for PooledBuffer {
88    fn drop(&mut self) {
89        // Atomically return buffer to pool if pool is not full
90        let mut current_size = self.pool_size.load(Ordering::Relaxed);
91        while current_size < self.max_pool_size {
92            match self.pool_size.compare_exchange_weak(
93                current_size,
94                current_size + 1,
95                Ordering::Relaxed,
96                Ordering::Relaxed,
97            ) {
98                Ok(_) => {
99                    let buffer = std::mem::take(&mut self.buffer);
100                    self.pool.push(buffer);
101                    return;
102                }
103                Err(new_size) => {
104                    current_size = new_size;
105                }
106            }
107        }
108        // If pool is full, buffer is dropped
109    }
110}
111
/// Lock-free buffer pool for reusing large I/O buffers
/// Uses crossbeam's SegQueue for lock-free operations
#[derive(Debug, Clone)]
pub struct BufferPool {
    // Lock-free queue of idle buffers, shared across clones of this pool.
    pool: Arc<SegQueue<Vec<u8>>>,
    // Size (in bytes) of every buffer this pool hands out.
    buffer_size: BufferSize,
    // Cap on how many idle buffers may be retained at once.
    max_pool_size: usize,
    // Approximate number of buffers currently idle in `pool`.
    pool_size: Arc<AtomicUsize>,
}
121
122impl BufferPool {
123    /// Create a page-aligned buffer for optimal DMA performance
124    ///
125    /// Returns a raw Vec<u8> that will be wrapped in PooledBuffer by get_buffer().
126    /// The buffer is NOT zero-initialized for performance.
127    ///
128    /// # Safety
129    ///
130    /// **INTERNAL USE ONLY.** This function is not exposed publicly and is only used
131    /// within the buffer pool implementation where the safety contract is guaranteed.
132    ///
133    /// The returned buffer contains uninitialized memory. Callers must ensure that the buffer
134    /// is only used with `AsyncRead`/`AsyncWrite` operations that fully initialize the bytes
135    /// before they are read. Only the initialized portion of the buffer (`&buf[..n]`, where `n`
136    /// is the number of bytes read or written) may be accessed. Accessing uninitialized bytes
137    /// is undefined behavior.
138    ///
139    /// The public API (`get_buffer()`) returns a `PooledBuffer` which is a safe wrapper that
140    /// enforces this contract through the type system and usage patterns.
141    #[allow(clippy::uninit_vec)]
142    fn create_aligned_buffer(size: usize) -> Vec<u8> {
143        // Align to page boundaries (4KB) for better memory performance
144        let page_size = 4096;
145        let aligned_size = size.div_ceil(page_size) * page_size;
146
147        // Use aligned allocation for better cache performance
148        let mut buffer = Vec::with_capacity(aligned_size);
149        // SAFETY: We're setting the length without initializing the data.
150        // This is safe because:
151        // 1. The Vec is wrapped in PooledBuffer which derefs to &[u8]
152        // 2. PooledBuffer is immediately used with AsyncRead which writes into it
153        // 3. Callers only access &buffer[..n] where n is the bytes actually read
154        // 4. Unwritten/uninitialized bytes are never accessed
155        unsafe {
156            buffer.set_len(size);
157        }
158        buffer
159    }
160
161    /// Create a new buffer pool with pre-allocated buffers
162    ///
163    /// # Arguments
164    /// * `buffer_size` - Size of each buffer in bytes (must be non-zero)
165    /// * `max_pool_size` - Maximum number of buffers to pool
166    ///
167    /// All buffers are pre-allocated at creation time for optimal performance.
168    #[must_use]
169    pub fn new(buffer_size: BufferSize, max_pool_size: usize) -> Self {
170        let pool = Arc::new(SegQueue::new());
171        let pool_size = Arc::new(AtomicUsize::new(0));
172
173        // Pre-allocate all buffers at startup to eliminate allocation overhead
174        info!(
175            "Pre-allocating {} buffers of {}KB each ({}MB total)",
176            max_pool_size,
177            buffer_size.get() / 1024,
178            (max_pool_size * buffer_size.get()) / (1024 * 1024)
179        );
180
181        for _ in 0..max_pool_size {
182            let buffer = Self::create_aligned_buffer(buffer_size.get());
183            pool.push(buffer);
184            pool_size.fetch_add(1, Ordering::Relaxed);
185        }
186
187        info!("Buffer pool pre-allocation complete");
188
189        Self {
190            pool,
191            buffer_size,
192            max_pool_size,
193            pool_size,
194        }
195    }
196
197    /// Get a buffer from the pool or create a new one (lock-free)
198    ///
199    /// Returns a PooledBuffer that automatically returns to the pool when dropped.
200    /// The buffer may contain old data, but this is safe because:
201    /// - Callers use AsyncRead which writes into the buffer
202    /// - They get back `n` bytes written and access only `&buf[..n]`
203    /// - Stale data beyond `n` is never accessed
204    pub async fn get_buffer(&self) -> PooledBuffer {
205        let buffer = if let Some(buffer) = self.pool.pop() {
206            self.pool_size.fetch_sub(1, Ordering::Relaxed);
207            // Buffer from pool is already the correct size (enforced on return)
208            debug_assert_eq!(buffer.len(), self.buffer_size.get());
209            buffer
210        } else {
211            // Create new page-aligned buffer for better DMA performance
212            Self::create_aligned_buffer(self.buffer_size.get())
213        };
214
215        PooledBuffer {
216            buffer,
217            initialized: 0, // Start with 0 bytes safe to read
218            pool: Arc::clone(&self.pool),
219            pool_size: Arc::clone(&self.pool_size),
220            max_pool_size: self.max_pool_size,
221        }
222    }
223
224    /// Return a buffer to the pool (lock-free)
225    ///
226    /// Note: Usually not needed as PooledBuffer returns itself automatically on drop
227    #[allow(dead_code)]
228    pub async fn return_buffer(&self, buffer: Vec<u8>) {
229        if buffer.len() == self.buffer_size.get() {
230            let current_size = self.pool_size.load(Ordering::Relaxed);
231            if current_size < self.max_pool_size {
232                self.pool.push(buffer);
233                self.pool_size.fetch_add(1, Ordering::Relaxed);
234            }
235            // If pool is full, just drop the buffer
236        }
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh pool hands out buffers of the configured size with no bytes
    /// marked initialized.
    #[tokio::test]
    async fn test_buffer_pool_creation() {
        let pool = BufferPool::new(BufferSize::new(8192).unwrap(), 10);

        // Pool should pre-allocate buffers
        let buffer1 = pool.get_buffer().await;
        assert_eq!(buffer1.capacity(), 8192);
        assert_eq!(buffer1.initialized(), 0); // No bytes initialized yet
        // Buffer automatically returned on drop
    }

    /// Buffers cycle back to the pool on drop and can be reacquired.
    #[tokio::test]
    async fn test_buffer_pool_get_and_return() {
        let pool = BufferPool::new(BufferSize::new(4096).unwrap(), 5);

        // Get a buffer
        let buffer = pool.get_buffer().await;
        assert_eq!(buffer.capacity(), 4096);
        assert_eq!(buffer.initialized(), 0);

        // A reused buffer may contain stale data from earlier use - callers
        // must write to it (read_from/copy_from_slice) before accessing.

        // Drop it (automatically returns to pool)
        drop(buffer);

        // Get it again - should be from pool
        let buffer2 = pool.get_buffer().await;
        assert_eq!(buffer2.capacity(), 4096);
    }

    /// When the pool is empty, get_buffer allocates a fresh buffer instead of
    /// blocking.
    #[tokio::test]
    async fn test_buffer_pool_exhaustion() {
        let pool = BufferPool::new(BufferSize::new(1024).unwrap(), 2);

        // Get all pre-allocated buffers
        let buf1 = pool.get_buffer().await;
        let buf2 = pool.get_buffer().await;

        // Pool is exhausted, should create new buffer
        let buf3 = pool.get_buffer().await;
        assert_eq!(buf3.capacity(), 1024);

        // Drop buffers (automatically returned)
        drop(buf1);
        drop(buf2);
        drop(buf3);
    }

    /// Many tasks can get buffers from a shared pool concurrently without
    /// panics or wrong-sized buffers.
    #[tokio::test]
    async fn test_buffer_pool_concurrent_access() {
        let pool = BufferPool::new(BufferSize::new(2048).unwrap(), 10);

        // Spawn multiple tasks accessing the pool concurrently
        let mut handles = vec![];

        for _ in 0..20 {
            let pool_clone = pool.clone();
            let handle = tokio::spawn(async move {
                let buffer = pool_clone.get_buffer().await;
                assert_eq!(buffer.capacity(), 2048);
                // Simulate some work
                tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
            });
            handles.push(handle);
        }

        // Wait for all tasks to complete
        for handle in handles {
            handle.await.unwrap();
        }
    }

    /// Buffer capacity is rounded to page-size multiples (the allocation
    /// itself is not guaranteed page-aligned; only the size is rounded).
    #[tokio::test]
    async fn test_buffer_alignment() {
        let pool = BufferPool::new(BufferSize::new(8192).unwrap(), 1);
        let buffer = pool.get_buffer().await;

        // Buffer capacity should cover the requested size
        assert!(buffer.capacity() >= 8192);
        // 8192 is itself a multiple of the 4096 rounding granularity
        assert_eq!(buffer.capacity() % 4096, 0);
    }

    /// Writing via copy_from_slice tracks the initialized prefix, and a
    /// recycled buffer may still hold the old contents.
    #[tokio::test]
    async fn test_buffer_clear_and_resize() {
        let pool = BufferPool::new(BufferSize::new(1024).unwrap(), 2);

        let mut buffer = pool.get_buffer().await;

        // Write data using copy_from_slice
        let data = vec![42u8; 101];
        buffer.copy_from_slice(&data);
        assert_eq!(buffer.initialized(), 101);

        // Drop returns it to pool
        drop(buffer);

        // Get it again - may contain old data (performance optimization)
        let buffer2 = pool.get_buffer().await;
        assert_eq!(buffer2.capacity(), 1024);
        // Note: buffer may contain previous data - callers must use &buf[..n] pattern
    }

    /// Returning more buffers than max_pool_size does not panic; excess
    /// buffers are simply freed on drop.
    #[tokio::test]
    async fn test_buffer_pool_max_size_enforcement() {
        let pool = BufferPool::new(BufferSize::new(512).unwrap(), 3);

        // Get all buffers
        let buf1 = pool.get_buffer().await;
        let buf2 = pool.get_buffer().await;
        let buf3 = pool.get_buffer().await;

        // Get one more (should create new)
        let buf4 = pool.get_buffer().await;

        // Drop all buffers (automatically returned)
        drop(buf1);
        drop(buf2);
        drop(buf3);
        drop(buf4);

        // Pool should not exceed max size
        // (We can't directly test pool size, but the implementation handles it)
    }

    /// return_buffer silently rejects buffers whose size does not match the
    /// pool's configured buffer size, so subsequent gets still hand out
    /// correctly-sized buffers. (PooledBuffer's Drop impl itself does not
    /// size-check; it only ever holds correctly-sized buffers.)
    #[tokio::test]
    async fn test_buffer_wrong_size_not_returned() {
        let pool = BufferPool::new(BufferSize::new(1024).unwrap(), 2);

        // Attempt to return a wrong-sized raw buffer - must be ignored.
        pool.return_buffer(vec![0u8; 100]).await;

        // Drain more buffers than were pre-allocated; every one must still
        // have the configured size, proving the 100-byte Vec never entered
        // the pool.
        let buf1 = pool.get_buffer().await;
        let buf2 = pool.get_buffer().await;
        let buf3 = pool.get_buffer().await;
        assert_eq!(buf1.capacity(), 1024);
        assert_eq!(buf2.capacity(), 1024);
        assert_eq!(buf3.capacity(), 1024);
    }

    /// Repeated get/write/drop cycles keep working and keep tracking the
    /// initialized count correctly.
    #[tokio::test]
    async fn test_buffer_pool_multiple_get_return_cycles() {
        let pool = BufferPool::new(BufferSize::new(4096).unwrap(), 5);

        // Do multiple get/return cycles
        for i in 0..20 {
            let mut buffer = pool.get_buffer().await;
            assert_eq!(buffer.capacity(), 4096);

            // Write some data using copy_from_slice
            let data = vec![i as u8; 1];
            buffer.copy_from_slice(&data);
            assert_eq!(buffer.initialized(), 1);
        }
    }

    /// Cloning a pool shares the underlying queue via Arc.
    #[test]
    fn test_buffer_pool_clone() {
        let pool1 = BufferPool::new(BufferSize::new(1024).unwrap(), 5);
        let _pool2 = pool1.clone();

        // Both should share the same underlying pool
        // (Arc ensures shared ownership)
    }

    /// Independent pools honor their own configured buffer sizes.
    #[tokio::test]
    async fn test_different_buffer_sizes() {
        let small_pool = BufferPool::new(BufferSize::new(1024).unwrap(), 5);
        let medium_pool = BufferPool::new(BufferSize::new(8192).unwrap(), 5);
        let large_pool = BufferPool::new(BufferSize::new(65536).unwrap(), 5);

        let small_buf = small_pool.get_buffer().await;
        let medium_buf = medium_pool.get_buffer().await;
        let large_buf = large_pool.get_buffer().await;

        assert_eq!(small_buf.capacity(), 1024);
        assert_eq!(medium_buf.capacity(), 8192);
        assert_eq!(large_buf.capacity(), 65536);

        // Buffers auto-return on drop
    }

    /// Hammer the pool from many tasks to shake out races in the lock-free
    /// get/return paths.
    #[tokio::test]
    async fn test_buffer_pool_stress() {
        let pool = BufferPool::new(BufferSize::new(4096).unwrap(), 10);

        // Stress test with many concurrent operations
        let mut handles = vec![];

        for _ in 0..100 {
            let pool_clone = pool.clone();
            let handle = tokio::spawn(async move {
                for _ in 0..10 {
                    let buffer = pool_clone.get_buffer().await;
                    assert_eq!(buffer.capacity(), 4096);
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }
    }
}