use bytes::{Bytes, BytesMut};
use crossbeam_channel::{bounded, Receiver, Sender};
use std::cell::RefCell;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;

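/// Size classes used to bucket buffer allocations into fixed-capacity pools.
///
/// Requests larger than [`SizeClass::Large`] fall into [`SizeClass::Huge`] and
/// bypass the pool entirely.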
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SizeClass {
    /// Buffers up to 4 KiB.
    Small = 0,
    /// Buffers up to 64 KiB.
    Medium = 1,
    /// Buffers up to 1 MiB.
    Large = 2,
    /// Buffers larger than 1 MiB; allocated directly and never pooled.
    Huge = 3,
}

impl SizeClass {
    /// Buffer capacity for this class, in bytes (0 for `Huge`, which is unpooled).
    pub const fn size(&self) -> usize {
        match self {
            Self::Small => 4 * 1024,
            Self::Medium => 64 * 1024,
            Self::Large => 1024 * 1024,
            Self::Huge => 0,
        }
    }

    /// Returns the smallest class whose capacity can hold `size` bytes.
    pub fn for_size(size: usize) -> Self {
        if size <= Self::Small.size() {
            Self::Small
        } else if size <= Self::Medium.size() {
            Self::Medium
        } else if size <= Self::Large.size() {
            Self::Large
        } else {
            Self::Huge
        }
    }
}

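/// Configuration for a [`BufferPool`]: how many buffers each size class retains,
/// how many each thread-local cache may hold, and whether statistics are tracked.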
#[derive(Debug, Clone)]
pub struct BufferPoolConfig {
    /// Maximum number of pooled small (4 KiB) buffers.
    pub small_pool_size: usize,
    /// Maximum number of pooled medium (64 KiB) buffers.
    pub medium_pool_size: usize,
    /// Maximum number of pooled large (1 MiB) buffers.
    pub large_pool_size: usize,
    /// Target number of buffers kept in each thread-local cache.
    pub thread_cache_size: usize,
    /// Whether to record allocation statistics in [`PoolStats`].
    pub enable_tracking: bool,
}

impl Default for BufferPoolConfig {
    fn default() -> Self {
        Self {
            small_pool_size: 1024,
            medium_pool_size: 256,
            large_pool_size: 32,
            thread_cache_size: 16,
            enable_tracking: true,
        }
    }
}

impl BufferPoolConfig {
    /// Larger pools and caches with tracking disabled, trading memory for throughput.
    pub fn high_throughput() -> Self {
        Self {
            small_pool_size: 4096,
            medium_pool_size: 1024,
            large_pool_size: 128,
            thread_cache_size: 64,
            enable_tracking: false,
        }
    }

    /// Smaller pools and caches for memory-constrained deployments.
    pub fn low_memory() -> Self {
        Self {
            small_pool_size: 256,
            medium_pool_size: 64,
            large_pool_size: 8,
            thread_cache_size: 4,
            enable_tracking: true,
        }
    }
}

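/// Counters describing pool behaviour. All counters use relaxed atomics, so they
/// are cheap to update but only approximately ordered under concurrency.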
#[derive(Debug, Default)]
pub struct PoolStats {
    /// Total buffers handed out by the pool.
    pub allocations: AtomicU64,
    /// Total buffers returned to the pool.
    pub deallocations: AtomicU64,
    /// Allocations served from a thread-local cache.
    pub cache_hits: AtomicU64,
    /// Allocations that missed the thread-local cache.
    pub cache_misses: AtomicU64,
    /// Allocations that missed the shared pool and fell back to the allocator.
    pub pool_misses: AtomicU64,
    /// Bytes currently checked out of the pool.
    pub bytes_allocated: AtomicUsize,
    /// High-water mark of `bytes_allocated`.
    pub peak_bytes: AtomicUsize,
}

impl PoolStats {
    /// Fraction of `PooledBuffer` acquisitions served from a thread-local cache.
    pub fn cache_hit_rate(&self) -> f64 {
        let hits = self.cache_hits.load(Ordering::Relaxed);
        let misses = self.cache_misses.load(Ordering::Relaxed);
        let total = hits + misses;
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }

    /// Fraction of allocations served from the pool rather than the allocator.
    pub fn pool_hit_rate(&self) -> f64 {
        let allocs = self.allocations.load(Ordering::Relaxed);
        let misses = self.pool_misses.load(Ordering::Relaxed);
        if allocs == 0 {
            1.0
        } else {
            1.0 - (misses as f64 / allocs as f64)
        }
    }
}

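/// A shared pool of reusable `BytesMut` buffers, bucketed by [`SizeClass`] and
/// backed by bounded channels so buffers can be checked out and returned from
/// any thread.
///
/// A minimal usage sketch (illustrative only):
///
/// ```ignore
/// let pool = BufferPool::new(BufferPoolConfig::default());
/// let buf = pool.allocate(1024);   // served from the small (4 KiB) pool
/// pool.deallocate(buf);            // returned for reuse
/// ```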
pub struct BufferPool {
    /// Recycled 4 KiB buffers.
    small_pool: (Sender<BytesMut>, Receiver<BytesMut>),
    /// Recycled 64 KiB buffers.
    medium_pool: (Sender<BytesMut>, Receiver<BytesMut>),
    /// Recycled 1 MiB buffers.
    large_pool: (Sender<BytesMut>, Receiver<BytesMut>),
    /// Pool sizing and tracking options.
    config: BufferPoolConfig,
    /// Shared allocation statistics.
    stats: Arc<PoolStats>,
}

impl BufferPool {
    /// Creates a pool with the given configuration and pre-fills every size class.
    pub fn new(config: BufferPoolConfig) -> Arc<Self> {
        let small_pool = bounded(config.small_pool_size);
        let medium_pool = bounded(config.medium_pool_size);
        let large_pool = bounded(config.large_pool_size);

        let pool = Arc::new(Self {
            small_pool,
            medium_pool,
            large_pool,
            config: config.clone(),
            stats: Arc::new(PoolStats::default()),
        });

        pool.preallocate();

        pool
    }

    /// Fills each pool with buffers of its class size so early allocations are
    /// served without hitting the system allocator.
    fn preallocate(&self) {
        for _ in 0..self.config.small_pool_size {
            let buf = BytesMut::with_capacity(SizeClass::Small.size());
            let _ = self.small_pool.0.try_send(buf);
        }

        for _ in 0..self.config.medium_pool_size {
            let buf = BytesMut::with_capacity(SizeClass::Medium.size());
            let _ = self.medium_pool.0.try_send(buf);
        }

        for _ in 0..self.config.large_pool_size {
            let buf = BytesMut::with_capacity(SizeClass::Large.size());
            let _ = self.large_pool.0.try_send(buf);
        }
    }

    /// Allocates a buffer with capacity for at least `size` bytes, reusing a
    /// pooled buffer of the matching size class when one is available.
    pub fn allocate(&self, size: usize) -> BytesMut {
        if self.config.enable_tracking {
            self.stats.allocations.fetch_add(1, Ordering::Relaxed);
        }

        let class = SizeClass::for_size(size);
        let (receiver, class_size) = match class {
            SizeClass::Small => (&self.small_pool.1, SizeClass::Small.size()),
            SizeClass::Medium => (&self.medium_pool.1, SizeClass::Medium.size()),
            SizeClass::Large => (&self.large_pool.1, SizeClass::Large.size()),
            SizeClass::Huge => {
                // Oversized requests bypass the pool entirely.
                let buf = BytesMut::with_capacity(size);
                if self.config.enable_tracking {
                    self.stats.pool_misses.fetch_add(1, Ordering::Relaxed);
                    self.update_bytes_allocated(buf.capacity() as isize);
                }
                return buf;
            }
        };

        match receiver.try_recv() {
            Ok(mut buf) => {
                buf.clear();
                if self.config.enable_tracking {
                    // Track the actual capacity so it matches what `deallocate` subtracts.
                    self.update_bytes_allocated(buf.capacity() as isize);
                }
                buf
            }
            Err(_) => {
                // Pool exhausted: fall back to a fresh allocation of the class size.
                let buf = BytesMut::with_capacity(class_size);
                if self.config.enable_tracking {
                    self.stats.pool_misses.fetch_add(1, Ordering::Relaxed);
                    self.update_bytes_allocated(buf.capacity() as isize);
                }
                buf
            }
        }
    }

    /// Returns a buffer to the pool for reuse. Buffers above the large size
    /// class, or returned while their pool is already full, are simply dropped.
    pub fn deallocate(&self, mut buf: BytesMut) {
        if self.config.enable_tracking {
            self.stats.deallocations.fetch_add(1, Ordering::Relaxed);
            self.update_bytes_allocated(-(buf.capacity() as isize));
        }

        buf.clear();
        let class = SizeClass::for_size(buf.capacity());

        let sender = match class {
            SizeClass::Small => &self.small_pool.0,
            SizeClass::Medium => &self.medium_pool.0,
            SizeClass::Large => &self.large_pool.0,
            SizeClass::Huge => return,
        };

        let _ = sender.try_send(buf);
    }

    /// Returns the pool's statistics counters.
    pub fn stats(&self) -> &PoolStats {
        &self.stats
    }

    fn update_bytes_allocated(&self, delta: isize) {
        if delta > 0 {
            let new = self
                .stats
                .bytes_allocated
                .fetch_add(delta as usize, Ordering::Relaxed)
                + delta as usize;
            // Maintain the high-water mark with a CAS loop.
            let mut peak = self.stats.peak_bytes.load(Ordering::Relaxed);
            while new > peak {
                match self.stats.peak_bytes.compare_exchange_weak(
                    peak,
                    new,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => break,
                    Err(p) => peak = p,
                }
            }
        } else {
            self.stats
                .bytes_allocated
                .fetch_sub((-delta) as usize, Ordering::Relaxed);
        }
    }
}

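// Per-thread cache of recycled buffers. `PooledBuffer` consults this before the
// shared pool to keep the hot path free of cross-thread traffic.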
thread_local! {
    static THREAD_CACHE: RefCell<ThreadCache> = RefCell::new(ThreadCache::new());
}

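/// Thread-local free lists, one per pooled size class, capped so an idle thread
/// does not pin an unbounded amount of memory.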
struct ThreadCache {
    small: Vec<BytesMut>,
    medium: Vec<BytesMut>,
    large: Vec<BytesMut>,
    max_size: usize,
}

impl ThreadCache {
    fn new() -> Self {
        Self {
            small: Vec::with_capacity(16),
            medium: Vec::with_capacity(8),
            large: Vec::with_capacity(4),
            // Note: the per-thread cap is currently fixed here rather than taken
            // from `BufferPoolConfig::thread_cache_size`.
            max_size: 16,
        }
    }

    /// Pops a cached buffer of the requested class, if any.
    fn get(&mut self, class: SizeClass) -> Option<BytesMut> {
        match class {
            SizeClass::Small => self.small.pop(),
            SizeClass::Medium => self.medium.pop(),
            SizeClass::Large => self.large.pop(),
            SizeClass::Huge => None,
        }
    }

    /// Tries to cache `buf` for this thread. Returns the buffer back to the
    /// caller when the cache is full or the buffer is too large to cache, so
    /// it can be handed to the shared pool instead.
    fn put(&mut self, buf: BytesMut) -> Option<BytesMut> {
        let class = SizeClass::for_size(buf.capacity());
        let (cache, max) = match class {
            SizeClass::Small => (&mut self.small, self.max_size),
            SizeClass::Medium => (&mut self.medium, self.max_size / 2),
            SizeClass::Large => (&mut self.large, self.max_size / 4),
            SizeClass::Huge => return Some(buf),
        };

        if cache.len() < max {
            cache.push(buf);
            None
        } else {
            Some(buf)
        }
    }
}

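/// An RAII handle over a pooled `BytesMut`: acquired from the thread-local cache
/// or the shared pool on creation, and recycled automatically on drop.
///
/// A minimal usage sketch (illustrative only):
///
/// ```ignore
/// let pool = BufferPool::new(BufferPoolConfig::default());
/// let mut buf = PooledBuffer::new(pool.clone(), 256);
/// buf.extend_from_slice(b"payload");   // DerefMut to BytesMut
/// let frozen = buf.freeze();           // immutable Bytes; not returned to the pool
/// ```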
pub struct PooledBuffer {
    /// The underlying buffer; `None` only after `freeze` or during drop.
    inner: Option<BytesMut>,
    /// Pool the buffer is returned to when dropped.
    pool: Arc<BufferPool>,
}

impl PooledBuffer {
    /// Acquires a buffer with room for at least `size` bytes, preferring the
    /// thread-local cache and falling back to the shared pool on a miss.
    pub fn new(pool: Arc<BufferPool>, size: usize) -> Self {
        let cached = THREAD_CACHE.with(|cache| {
            let mut cache = cache.borrow_mut();
            let class = SizeClass::for_size(size);
            cache.get(class)
        });

        let buf = match cached {
            Some(buf) => {
                if pool.config.enable_tracking {
                    pool.stats.cache_hits.fetch_add(1, Ordering::Relaxed);
                }
                buf
            }
            None => {
                if pool.config.enable_tracking {
                    pool.stats.cache_misses.fetch_add(1, Ordering::Relaxed);
                }
                pool.allocate(size)
            }
        };

        Self {
            inner: Some(buf),
            pool,
        }
    }
431
432 pub fn inner_mut(&mut self) -> &mut BytesMut {
434 self.inner.as_mut().unwrap()
435 }
436
437 pub fn inner_ref(&self) -> &BytesMut {
439 self.inner.as_ref().unwrap()
440 }
441
442 pub fn freeze(mut self) -> Bytes {
444 self.inner.take().unwrap().freeze()
445 }
446
447 pub fn len(&self) -> usize {
449 self.inner.as_ref().map(|b| b.len()).unwrap_or(0)
450 }
451
452 pub fn is_empty(&self) -> bool {
454 self.len() == 0
455 }
456
457 pub fn capacity(&self) -> usize {
459 self.inner.as_ref().map(|b| b.capacity()).unwrap_or(0)
460 }
461}
462
impl Drop for PooledBuffer {
    fn drop(&mut self) {
        if let Some(mut buf) = self.inner.take() {
            buf.clear();

            // Prefer the thread-local cache; if it declines the buffer, hand it
            // back to the shared pool instead of dropping it.
            if let Some(buf) = THREAD_CACHE.with(|cache| cache.borrow_mut().put(buf)) {
                self.pool.deallocate(buf);
            }
        }
    }
}

impl std::ops::Deref for PooledBuffer {
    type Target = BytesMut;

    fn deref(&self) -> &Self::Target {
        self.inner.as_ref().unwrap()
    }
}

impl std::ops::DerefMut for PooledBuffer {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner.as_mut().unwrap()
    }
}

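/// An ordered list of `Bytes` segments with a cached total length. Segments can
/// be appended or prepended without copying; `flatten` copies them into one
/// contiguous buffer only when needed.
///
/// ```ignore
/// let mut chain = BufferChain::new();
/// chain.push(Bytes::from_static(b"hello "));
/// chain.push(Bytes::from_static(b"world"));
/// assert_eq!(&chain.flatten()[..], b"hello world");
/// ```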
#[derive(Default)]
pub struct BufferChain {
    /// Segments in order.
    buffers: Vec<Bytes>,
    /// Sum of all segment lengths, kept in sync with `buffers`.
    total_len: usize,
}

impl BufferChain {
    /// Creates an empty chain.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a chain holding a single segment.
    pub fn single(buf: Bytes) -> Self {
        let len = buf.len();
        Self {
            buffers: vec![buf],
            total_len: len,
        }
    }

    /// Appends a segment to the end of the chain.
    pub fn push(&mut self, buf: Bytes) {
        self.total_len += buf.len();
        self.buffers.push(buf);
    }

    /// Inserts a segment at the front of the chain.
    pub fn prepend(&mut self, buf: Bytes) {
        self.total_len += buf.len();
        self.buffers.insert(0, buf);
    }

    /// Total number of bytes across all segments.
    pub fn len(&self) -> usize {
        self.total_len
    }

    /// Returns `true` if the chain holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.total_len == 0
    }

    /// Number of segments in the chain.
    pub fn buffer_count(&self) -> usize {
        self.buffers.len()
    }

    /// Iterates over the segments in order.
    pub fn iter(&self) -> impl Iterator<Item = &Bytes> {
        self.buffers.iter()
    }

    /// Copies all segments into a single contiguous `Bytes`. A single-segment
    /// chain is returned without copying.
    pub fn flatten(self) -> Bytes {
        if self.buffers.len() == 1 {
            return self.buffers.into_iter().next().unwrap();
        }

        let mut result = BytesMut::with_capacity(self.total_len);
        for buf in self.buffers {
            result.extend_from_slice(&buf);
        }
        result.freeze()
    }

    /// Borrowed view of every segment, suitable for vectored writes.
    pub fn as_slices(&self) -> Vec<&[u8]> {
        self.buffers.iter().map(|b| b.as_ref()).collect()
    }
}

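/// A fixed 4 KiB buffer aligned to a 4096-byte boundary, suitable for
/// page-aligned I/O (for example `O_DIRECT`-style reads and writes).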
#[repr(C, align(4096))]
pub struct AlignedBuffer {
    data: [u8; 4096],
}

impl Default for AlignedBuffer {
    fn default() -> Self {
        Self::new()
    }
}

impl AlignedBuffer {
    /// Creates a zeroed, page-aligned buffer.
    pub const fn new() -> Self {
        Self { data: [0u8; 4096] }
    }

    /// Read-only view of the buffer contents.
    pub fn as_slice(&self) -> &[u8] {
        &self.data
    }

    /// Mutable view of the buffer contents.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        &mut self.data
    }

    /// Returns `true` if the buffer starts on a 4096-byte boundary.
    pub fn is_aligned(&self) -> bool {
        (self.data.as_ptr() as usize).is_multiple_of(4096)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_size_class() {
        assert_eq!(SizeClass::for_size(100), SizeClass::Small);
        assert_eq!(SizeClass::for_size(4096), SizeClass::Small);
        assert_eq!(SizeClass::for_size(4097), SizeClass::Medium);
        assert_eq!(SizeClass::for_size(65536), SizeClass::Medium);
        assert_eq!(SizeClass::for_size(65537), SizeClass::Large);
        assert_eq!(SizeClass::for_size(1024 * 1024), SizeClass::Large);
        assert_eq!(SizeClass::for_size(1024 * 1024 + 1), SizeClass::Huge);
    }

    #[test]
    fn test_buffer_pool_allocate() {
        let pool = BufferPool::new(BufferPoolConfig::default());

        let buf1 = pool.allocate(100);
        assert!(buf1.capacity() >= 100);
        assert!(buf1.capacity() <= SizeClass::Small.size());

        let buf2 = pool.allocate(10000);
        assert!(buf2.capacity() >= 10000);
        assert!(buf2.capacity() <= SizeClass::Medium.size());
    }

    #[test]
    fn test_buffer_pool_roundtrip() {
        let pool = BufferPool::new(BufferPoolConfig::default());

        let buf = pool.allocate(1000);
        let cap = buf.capacity();

        pool.deallocate(buf);

        let buf2 = pool.allocate(1000);
        assert_eq!(buf2.capacity(), cap);
    }

    #[test]
    fn test_pooled_buffer() {
        let pool = BufferPool::new(BufferPoolConfig::default());

        {
            let mut buf = PooledBuffer::new(pool.clone(), 1000);
            buf.extend_from_slice(b"hello world");
            assert_eq!(buf.len(), 11);
        }
    }

    #[test]
    fn test_buffer_chain() {
        let mut chain = BufferChain::new();
        chain.push(Bytes::from_static(b"hello "));
        chain.push(Bytes::from_static(b"world"));

        assert_eq!(chain.len(), 11);
        assert_eq!(chain.buffer_count(), 2);

        let flat = chain.flatten();
        assert_eq!(&flat[..], b"hello world");
    }

    #[test]
    fn test_aligned_buffer() {
        let buf = AlignedBuffer::new();
        assert!(buf.is_aligned());
    }

    #[test]
    fn test_pool_stats() {
        let config = BufferPoolConfig {
            enable_tracking: true,
            ..Default::default()
        };
        let pool = BufferPool::new(config);

        let _buf1 = pool.allocate(100);
        let _buf2 = pool.allocate(200);

        assert_eq!(pool.stats().allocations.load(Ordering::Relaxed), 2);
    }
}