//! Zero-Copy Buffer Pool
//!
//! High-performance buffer management for Rivven with:
//! - **Slab Allocation**: Pre-allocated buffers to avoid runtime allocation
//! - **Size Classes**: Different buffer sizes for optimal memory usage
//! - **Thread-Local Caching**: Reduce contention in hot paths
//! - **Reference Counting**: Safe buffer sharing without copies
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                      Buffer Pool                                │
//! ├─────────────────────────────────────────────────────────────────┤
//! │  ┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────────────┐ │
//! │  │  Small   │  │  Medium  │  │  Large   │  │   Huge (alloc)   │ │
//! │  │  <= 4KB  │  │  <= 64KB │  │  <= 1MB  │  │   > 1MB          │ │
//! │  └────┬─────┘  └────┬─────┘  └────┬─────┘  └────────┬─────────┘ │
//! │       │             │             │                 │           │
//! │       └─────────────┴─────────────┴─────────────────┘           │
//! │                           │                                     │
//! │              ┌────────────▼────────────┐                        │
//! │              │    Thread-Local Cache   │                        │
//! │              │  (lock-free fast path)  │                        │
//! │              └────────────┬────────────┘                        │
//! │                           │                                     │
//! │              ┌────────────▼────────────┐                        │
//! │              │    Global Pool (CAS)    │                        │
//! │              └─────────────────────────┘                        │
//! └─────────────────────────────────────────────────────────────────┘
//! ```
//!
//! # Performance Characteristics
//!
//! - Allocation: O(1) for cached sizes
//! - Deallocation: O(1) (return to pool)
//! - Memory overhead: ~4 bytes per buffer (ref count)
//! - Contention: Near-zero with thread-local caching
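//!
//! # Example
//!
//! A minimal usage sketch (the import path assumes this module is exposed
//! as `rivven_core::buffer_pool`):
//!
//! ```ignore
//! use rivven_core::buffer_pool::{BufferPool, BufferPoolConfig, PooledBuffer};
//!
//! let pool = BufferPool::new(BufferPoolConfig::default());
//!
//! // RAII handle: the buffer returns to the pool (or the thread-local
//! // cache) when it is dropped.
//! let mut buf = PooledBuffer::new(pool.clone(), 1024);
//! buf.extend_from_slice(b"hello");
//! assert_eq!(buf.len(), 5);
//! ```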

use bytes::{Bytes, BytesMut};
use crossbeam_channel::{bounded, Receiver, Sender};
use std::cell::RefCell;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;

// ============================================================================
// Configuration
// ============================================================================

/// Buffer size classes
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SizeClass {
    /// Small buffers (<= 4KB) - for headers, small messages
    Small = 0,
    /// Medium buffers (<= 64KB) - for typical messages
    Medium = 1,
    /// Large buffers (<= 1MB) - for batch operations
    Large = 2,
    /// Huge buffers (> 1MB) - direct allocation
    Huge = 3,
}

impl SizeClass {
    /// Get the buffer size for this class
    pub const fn size(&self) -> usize {
        match self {
            Self::Small => 4 * 1024,    // 4 KB
            Self::Medium => 64 * 1024,  // 64 KB
            Self::Large => 1024 * 1024, // 1 MB
            Self::Huge => 0,            // Dynamic
        }
    }

    /// Determine the size class for a given size
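    ///
    /// Boundaries are inclusive. For example:
    ///
    /// ```ignore
    /// assert_eq!(SizeClass::for_size(4096), SizeClass::Small);
    /// assert_eq!(SizeClass::for_size(4097), SizeClass::Medium);
    /// ```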
    pub fn for_size(size: usize) -> Self {
        if size <= Self::Small.size() {
            Self::Small
        } else if size <= Self::Medium.size() {
            Self::Medium
        } else if size <= Self::Large.size() {
            Self::Large
        } else {
            Self::Huge
        }
    }
}

/// Buffer pool configuration
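///
/// Individual fields can be overridden from a preset or the default
/// (the values below are illustrative):
///
/// ```ignore
/// let config = BufferPoolConfig {
///     small_pool_size: 2048,
///     ..BufferPoolConfig::default()
/// };
/// ```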
#[derive(Debug, Clone)]
pub struct BufferPoolConfig {
    /// Number of small buffers to pre-allocate
    pub small_pool_size: usize,
    /// Number of medium buffers to pre-allocate
    pub medium_pool_size: usize,
    /// Number of large buffers to pre-allocate
    pub large_pool_size: usize,
    /// Thread-local cache size per size class
    pub thread_cache_size: usize,
    /// Enable memory usage tracking
    pub enable_tracking: bool,
}

impl Default for BufferPoolConfig {
    fn default() -> Self {
        Self {
            small_pool_size: 1024,
            medium_pool_size: 256,
            large_pool_size: 32,
            thread_cache_size: 16,
            enable_tracking: true,
        }
    }
}

impl BufferPoolConfig {
    /// Configuration for high-throughput workloads
    pub fn high_throughput() -> Self {
        Self {
            small_pool_size: 4096,
            medium_pool_size: 1024,
            large_pool_size: 128,
            thread_cache_size: 64,
            enable_tracking: false,
        }
    }

    /// Configuration for low-memory environments
    pub fn low_memory() -> Self {
        Self {
            small_pool_size: 256,
            medium_pool_size: 64,
            large_pool_size: 8,
            thread_cache_size: 4,
            enable_tracking: true,
        }
    }
}

// ============================================================================
// Pool Statistics
// ============================================================================

/// Buffer pool statistics
#[derive(Debug, Default)]
pub struct PoolStats {
    /// Total allocations
    pub allocations: AtomicU64,
    /// Total deallocations (returns to pool)
    pub deallocations: AtomicU64,
    /// Cache hits (from thread-local)
    pub cache_hits: AtomicU64,
    /// Cache misses (went to global pool)
    pub cache_misses: AtomicU64,
    /// Pool misses (new allocation)
    pub pool_misses: AtomicU64,
    /// Current bytes allocated
    pub bytes_allocated: AtomicUsize,
    /// Peak bytes allocated
    pub peak_bytes: AtomicUsize,
}

impl PoolStats {
    /// Get cache hit rate (0.0 - 1.0)
    pub fn cache_hit_rate(&self) -> f64 {
        let hits = self.cache_hits.load(Ordering::Relaxed);
        let misses = self.cache_misses.load(Ordering::Relaxed);
        let total = hits + misses;
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }

    /// Get pool hit rate (0.0 - 1.0)
    pub fn pool_hit_rate(&self) -> f64 {
        let allocs = self.allocations.load(Ordering::Relaxed);
        let misses = self.pool_misses.load(Ordering::Relaxed);
        if allocs == 0 {
            1.0
        } else {
            1.0 - (misses as f64 / allocs as f64)
        }
    }
}

// ============================================================================
// Global Buffer Pool
// ============================================================================

/// Global buffer pool with size-class segregation
pub struct BufferPool {
    /// Small buffer free list
    small_pool: (Sender<BytesMut>, Receiver<BytesMut>),
    /// Medium buffer free list
    medium_pool: (Sender<BytesMut>, Receiver<BytesMut>),
    /// Large buffer free list
    large_pool: (Sender<BytesMut>, Receiver<BytesMut>),
    /// Configuration
    config: BufferPoolConfig,
    /// Statistics
    stats: Arc<PoolStats>,
}

impl BufferPool {
    /// Create a new buffer pool
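    ///
    /// A minimal sketch, mirroring the tests below:
    ///
    /// ```ignore
    /// let pool = BufferPool::new(BufferPoolConfig::default());
    /// let buf = pool.allocate(100);
    /// assert!(buf.capacity() >= 100);
    /// pool.deallocate(buf);
    /// ```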
    pub fn new(config: BufferPoolConfig) -> Arc<Self> {
        let small_pool = bounded(config.small_pool_size);
        let medium_pool = bounded(config.medium_pool_size);
        let large_pool = bounded(config.large_pool_size);

        let pool = Arc::new(Self {
            small_pool,
            medium_pool,
            large_pool,
            config: config.clone(),
            stats: Arc::new(PoolStats::default()),
        });

        // Pre-allocate buffers
        pool.preallocate();

        pool
    }

    /// Pre-allocate buffers to fill the pools
    fn preallocate(&self) {
        // Small buffers
        for _ in 0..self.config.small_pool_size {
            let buf = BytesMut::with_capacity(SizeClass::Small.size());
            let _ = self.small_pool.0.try_send(buf);
        }

        // Medium buffers
        for _ in 0..self.config.medium_pool_size {
            let buf = BytesMut::with_capacity(SizeClass::Medium.size());
            let _ = self.medium_pool.0.try_send(buf);
        }

        // Large buffers
        for _ in 0..self.config.large_pool_size {
            let buf = BytesMut::with_capacity(SizeClass::Large.size());
            let _ = self.large_pool.0.try_send(buf);
        }
    }

    /// Allocate a buffer of at least the given size
    pub fn allocate(&self, size: usize) -> BytesMut {
        if self.config.enable_tracking {
            self.stats.allocations.fetch_add(1, Ordering::Relaxed);
        }

        let class = SizeClass::for_size(size);
        let (receiver, class_size) = match class {
            SizeClass::Small => (&self.small_pool.1, SizeClass::Small.size()),
            SizeClass::Medium => (&self.medium_pool.1, SizeClass::Medium.size()),
            SizeClass::Large => (&self.large_pool.1, SizeClass::Large.size()),
            SizeClass::Huge => {
                // Huge buffers are always freshly allocated
                if self.config.enable_tracking {
                    self.stats.pool_misses.fetch_add(1, Ordering::Relaxed);
                    self.update_bytes_allocated(size as isize);
                }
                return BytesMut::with_capacity(size);
            }
        };

        // Try to get a recycled buffer from the pool
        match receiver.try_recv() {
            Ok(mut buf) => {
                buf.clear();
                if self.config.enable_tracking {
                    self.update_bytes_allocated(class_size as isize);
                }
                buf
            }
            Err(_) => {
                // Pool empty, allocate a new buffer
                if self.config.enable_tracking {
                    self.stats.pool_misses.fetch_add(1, Ordering::Relaxed);
                    self.update_bytes_allocated(class_size as isize);
                }
                BytesMut::with_capacity(class_size)
            }
        }
    }

    /// Return a buffer to the pool
    pub fn deallocate(&self, mut buf: BytesMut) {
        if self.config.enable_tracking {
            self.stats.deallocations.fetch_add(1, Ordering::Relaxed);
            self.update_bytes_allocated(-(buf.capacity() as isize));
        }

        buf.clear();
        // Classify by *current* capacity: a buffer that grew while checked
        // out is returned to the larger class it now belongs to.
        let class = SizeClass::for_size(buf.capacity());

        let sender = match class {
            SizeClass::Small => &self.small_pool.0,
            SizeClass::Medium => &self.medium_pool.0,
            SizeClass::Large => &self.large_pool.0,
            SizeClass::Huge => return, // Don't pool huge buffers
        };

        // Try to return to the pool; drop the buffer if the pool is full
        let _ = sender.try_send(buf);
    }

    /// Get pool statistics
    pub fn stats(&self) -> &PoolStats {
        &self.stats
    }

    fn update_bytes_allocated(&self, delta: isize) {
        if delta > 0 {
            // fetch_add returns the previous value; add delta for the new total
            let new = self
                .stats
                .bytes_allocated
                .fetch_add(delta as usize, Ordering::Relaxed)
                + delta as usize;
            // Update the peak with a CAS loop so a concurrent update can
            // never overwrite a higher peak with a lower one
            let mut peak = self.stats.peak_bytes.load(Ordering::Relaxed);
            while new > peak {
                match self.stats.peak_bytes.compare_exchange_weak(
                    peak,
                    new,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => break,
                    Err(p) => peak = p,
                }
            }
        } else {
            self.stats
                .bytes_allocated
                .fetch_sub((-delta) as usize, Ordering::Relaxed);
        }
    }
}

// ============================================================================
// Thread-Local Buffer Cache
// ============================================================================

thread_local! {
    static THREAD_CACHE: RefCell<ThreadCache> = RefCell::new(ThreadCache::new());
}

/// Thread-local buffer cache for the lock-free fast path
struct ThreadCache {
    small: Vec<BytesMut>,
    medium: Vec<BytesMut>,
    large: Vec<BytesMut>,
    max_size: usize,
}

impl ThreadCache {
    fn new() -> Self {
        Self {
            small: Vec::with_capacity(16),
            medium: Vec::with_capacity(8),
            large: Vec::with_capacity(4),
            // Matches BufferPoolConfig::default().thread_cache_size; a
            // thread-local cannot see per-pool configuration.
            max_size: 16,
        }
    }

    fn get(&mut self, class: SizeClass) -> Option<BytesMut> {
        match class {
            SizeClass::Small => self.small.pop(),
            SizeClass::Medium => self.medium.pop(),
            SizeClass::Large => self.large.pop(),
            SizeClass::Huge => None,
        }
    }

    /// Try to cache a buffer; hands the buffer back if the cache is full
    /// or the buffer is too large to cache.
    fn put(&mut self, buf: BytesMut) -> Option<BytesMut> {
        let class = SizeClass::for_size(buf.capacity());
        let (cache, max) = match class {
            SizeClass::Small => (&mut self.small, self.max_size),
            SizeClass::Medium => (&mut self.medium, self.max_size / 2),
            SizeClass::Large => (&mut self.large, self.max_size / 4),
            SizeClass::Huge => return Some(buf),
        };

        if cache.len() < max {
            cache.push(buf);
            None
        } else {
            Some(buf)
        }
    }
}

// ============================================================================
// Pooled Buffer Handle
// ============================================================================

/// A buffer handle that returns to the pool on drop
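///
/// A minimal sketch, mirroring `test_pooled_buffer` below:
///
/// ```ignore
/// let pool = BufferPool::new(BufferPoolConfig::default());
/// {
///     let mut buf = PooledBuffer::new(pool.clone(), 1000);
///     buf.extend_from_slice(b"hello world"); // via DerefMut to BytesMut
///     assert_eq!(buf.len(), 11);
/// } // storage returns to the thread-local cache or global pool here
/// ```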
pub struct PooledBuffer {
    inner: Option<BytesMut>,
    pool: Arc<BufferPool>,
}

impl PooledBuffer {
    /// Create from pool
    pub fn new(pool: Arc<BufferPool>, size: usize) -> Self {
        let class = SizeClass::for_size(size);

        // Try the thread-local cache first (lock-free fast path)
        let cached = THREAD_CACHE.with(|cache| cache.borrow_mut().get(class));

        // Count a hit only when the buffer actually came from the
        // thread-local cache; otherwise fall back to the global pool.
        let buf = match cached {
            Some(buf) => {
                if pool.config.enable_tracking {
                    pool.stats.cache_hits.fetch_add(1, Ordering::Relaxed);
                }
                buf
            }
            None => {
                if pool.config.enable_tracking {
                    pool.stats.cache_misses.fetch_add(1, Ordering::Relaxed);
                }
                pool.allocate(size)
            }
        };

        Self {
            inner: Some(buf),
            pool,
        }
    }

    /// Get mutable reference to buffer
    pub fn inner_mut(&mut self) -> &mut BytesMut {
        self.inner.as_mut().unwrap()
    }

    /// Get immutable reference to buffer
    pub fn inner_ref(&self) -> &BytesMut {
        self.inner.as_ref().unwrap()
    }

    /// Freeze into immutable Bytes (consumes the buffer; the storage is
    /// handed to `Bytes` and does not return to the pool)
    pub fn freeze(mut self) -> Bytes {
        self.inner.take().unwrap().freeze()
    }

    /// Get length of data in buffer
    pub fn len(&self) -> usize {
        self.inner.as_ref().map(|b| b.len()).unwrap_or(0)
    }

    /// Check if buffer is empty
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get capacity of buffer
    pub fn capacity(&self) -> usize {
        self.inner.as_ref().map(|b| b.capacity()).unwrap_or(0)
    }
}

impl Drop for PooledBuffer {
    fn drop(&mut self) {
        if let Some(mut buf) = self.inner.take() {
            buf.clear();

            // Move the buffer into the thread-local cache; `put` hands it
            // back if it cannot be cached
            let rejected = THREAD_CACHE.with(|cache| cache.borrow_mut().put(buf));

            if let Some(buf) = rejected {
                // Thread-local cache full, return to the global pool
                self.pool.deallocate(buf);
            }
        }
    }
}

impl std::ops::Deref for PooledBuffer {
    type Target = BytesMut;

    fn deref(&self) -> &Self::Target {
        self.inner.as_ref().unwrap()
    }
}

impl std::ops::DerefMut for PooledBuffer {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner.as_mut().unwrap()
    }
}

// ============================================================================
// Zero-Copy Buffer Chain
// ============================================================================

/// A chain of buffers for scatter-gather I/O
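///
/// A minimal sketch, mirroring `test_buffer_chain` below:
///
/// ```ignore
/// let mut chain = BufferChain::new();
/// chain.push(Bytes::from_static(b"hello "));
/// chain.push(Bytes::from_static(b"world"));
/// assert_eq!(chain.len(), 11);
/// assert_eq!(&chain.flatten()[..], b"hello world");
/// ```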
#[derive(Default)]
pub struct BufferChain {
    buffers: Vec<Bytes>,
    total_len: usize,
}

impl BufferChain {
    /// Create empty chain
    pub fn new() -> Self {
        Self::default()
    }

    /// Create chain with single buffer
    pub fn single(buf: Bytes) -> Self {
        let len = buf.len();
        Self {
            buffers: vec![buf],
            total_len: len,
        }
    }

    /// Append a buffer to the chain
    pub fn push(&mut self, buf: Bytes) {
        self.total_len += buf.len();
        self.buffers.push(buf);
    }

    /// Prepend a buffer to the chain
    pub fn prepend(&mut self, buf: Bytes) {
        self.total_len += buf.len();
        self.buffers.insert(0, buf);
    }

    /// Get total length of all buffers
    pub fn len(&self) -> usize {
        self.total_len
    }

    /// Check if chain is empty
    pub fn is_empty(&self) -> bool {
        self.total_len == 0
    }

    /// Get number of buffers in chain
    pub fn buffer_count(&self) -> usize {
        self.buffers.len()
    }

    /// Iterate over buffers
    pub fn iter(&self) -> impl Iterator<Item = &Bytes> {
        self.buffers.iter()
    }

    /// Flatten into a single buffer (copies data, unless the chain already
    /// holds exactly one buffer)
    pub fn flatten(self) -> Bytes {
        if self.buffers.len() == 1 {
            return self.buffers.into_iter().next().unwrap();
        }

        let mut result = BytesMut::with_capacity(self.total_len);
        for buf in self.buffers {
            result.extend_from_slice(&buf);
        }
        result.freeze()
    }

    /// Convert to iovec-style slices for vectored I/O
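    ///
    /// A hedged sketch of vectored I/O (`writer` stands in for any
    /// `std::io::Write` implementor):
    ///
    /// ```ignore
    /// use std::io::{IoSlice, Write};
    ///
    /// let slices = chain.as_slices();
    /// let iovecs: Vec<IoSlice<'_>> = slices.into_iter().map(IoSlice::new).collect();
    /// let written = writer.write_vectored(&iovecs)?;
    /// ```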
    pub fn as_slices(&self) -> Vec<&[u8]> {
        self.buffers.iter().map(|b| b.as_ref()).collect()
    }
}

// ============================================================================
// Aligned Buffer for Direct I/O
// ============================================================================

/// Buffer aligned for direct I/O (O_DIRECT).
/// The required alignment is typically 512 bytes or 4KB.
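///
/// A minimal sketch (the alignment check mirrors `test_aligned_buffer`
/// below):
///
/// ```ignore
/// let mut buf = AlignedBuffer::new();
/// assert!(buf.is_aligned());
/// buf.as_mut_slice()[0] = 0xFF;
/// ```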
#[repr(C, align(4096))]
pub struct AlignedBuffer {
    data: [u8; 4096],
}

impl Default for AlignedBuffer {
    fn default() -> Self {
        Self::new()
    }
}

impl AlignedBuffer {
    /// Create new aligned buffer
    pub const fn new() -> Self {
        Self { data: [0u8; 4096] }
    }

    /// Get slice reference
    pub fn as_slice(&self) -> &[u8] {
        &self.data
    }

    /// Get mutable slice reference
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        &mut self.data
    }

    /// Check alignment
    pub fn is_aligned(&self) -> bool {
        (self.data.as_ptr() as usize).is_multiple_of(4096)
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_size_class() {
        assert_eq!(SizeClass::for_size(100), SizeClass::Small);
        assert_eq!(SizeClass::for_size(4096), SizeClass::Small);
        assert_eq!(SizeClass::for_size(4097), SizeClass::Medium);
        assert_eq!(SizeClass::for_size(65536), SizeClass::Medium);
        assert_eq!(SizeClass::for_size(65537), SizeClass::Large);
        assert_eq!(SizeClass::for_size(1024 * 1024), SizeClass::Large);
        assert_eq!(SizeClass::for_size(1024 * 1024 + 1), SizeClass::Huge);
    }

    #[test]
    fn test_buffer_pool_allocate() {
        let pool = BufferPool::new(BufferPoolConfig::default());

        let buf1 = pool.allocate(100);
        assert!(buf1.capacity() >= 100);
        assert!(buf1.capacity() <= SizeClass::Small.size());

        let buf2 = pool.allocate(10000);
        assert!(buf2.capacity() >= 10000);
        assert!(buf2.capacity() <= SizeClass::Medium.size());
    }

    #[test]
    fn test_buffer_pool_roundtrip() {
        let pool = BufferPool::new(BufferPoolConfig::default());

        let buf = pool.allocate(1000);
        let cap = buf.capacity();

        pool.deallocate(buf);

        let buf2 = pool.allocate(1000);
        assert_eq!(buf2.capacity(), cap);
    }

    #[test]
    fn test_pooled_buffer() {
        let pool = BufferPool::new(BufferPoolConfig::default());

        {
            let mut buf = PooledBuffer::new(pool.clone(), 1000);
            buf.extend_from_slice(b"hello world");
            assert_eq!(buf.len(), 11);
        }
        // Buffer returned to the thread-local cache or pool on drop
    }

    #[test]
    fn test_buffer_chain() {
        let mut chain = BufferChain::new();
        chain.push(Bytes::from_static(b"hello "));
        chain.push(Bytes::from_static(b"world"));

        assert_eq!(chain.len(), 11);
        assert_eq!(chain.buffer_count(), 2);

        let flat = chain.flatten();
        assert_eq!(&flat[..], b"hello world");
    }

    #[test]
    fn test_aligned_buffer() {
        let buf = AlignedBuffer::new();
        assert!(buf.is_aligned());
    }

    #[test]
    fn test_pool_stats() {
        let config = BufferPoolConfig {
            enable_tracking: true,
            ..Default::default()
        };
        let pool = BufferPool::new(config);

        let _buf1 = pool.allocate(100);
        let _buf2 = pool.allocate(200);

        assert_eq!(pool.stats().allocations.load(Ordering::Relaxed), 2);
    }
}