1use std::cell::UnsafeCell;
7use std::mem::ManuallyDrop;
8use std::ops::{Deref, DerefMut};
9use std::sync::atomic::{AtomicU32, Ordering};
10
11use crate::error::{Error, Result};
12
13use super::ring::MpmcRing;
14use super::{DEFAULT_POOL_CAPACITY, MAX_PACKET_SIZE};
15
16#[repr(C, align(64))]
18pub struct PacketBuffer {
19 data: [u8; MAX_PACKET_SIZE],
21 len: u32,
23 index: u32,
25 refcount: AtomicU32,
27}
28
29impl PacketBuffer {
30 #[inline]
32 #[allow(clippy::large_stack_arrays)] const fn new(index: u32) -> Self {
34 Self {
35 data: [0; MAX_PACKET_SIZE],
36 len: 0,
37 index,
38 refcount: AtomicU32::new(0),
39 }
40 }
41
42 #[inline]
44 #[must_use]
45 pub const fn index(&self) -> u32 {
46 self.index
47 }
48
49 #[inline]
51 #[must_use]
52 pub const fn len(&self) -> usize {
53 self.len as usize
54 }
55
56 #[inline]
58 #[must_use]
59 pub const fn is_empty(&self) -> bool {
60 self.len == 0
61 }
62
63 #[inline]
65 pub fn set_len(&mut self, len: usize) {
66 self.len = len.min(MAX_PACKET_SIZE) as u32;
67 }
68
69 #[inline]
71 #[must_use]
72 pub fn as_slice(&self) -> &[u8] {
73 &self.data[..self.len as usize]
74 }
75
76 #[inline]
78 #[must_use]
79 pub fn as_mut_slice(&mut self) -> &mut [u8] {
80 &mut self.data[..self.len as usize]
81 }
82
83 #[inline]
85 #[must_use]
86 pub fn as_full_slice(&self) -> &[u8] {
87 &self.data
88 }
89
90 #[inline]
92 #[must_use]
93 pub fn as_full_mut_slice(&mut self) -> &mut [u8] {
94 &mut self.data
95 }
96
97 #[inline]
99 #[must_use]
100 pub fn as_ptr(&self) -> *const u8 {
101 self.data.as_ptr()
102 }
103
104 #[inline]
106 #[must_use]
107 pub fn as_mut_ptr(&mut self) -> *mut u8 {
108 self.data.as_mut_ptr()
109 }
110
111 #[inline]
113 pub fn add_ref(&self) {
114 self.refcount.fetch_add(1, Ordering::AcqRel);
115 }
116
117 #[inline]
119 pub fn release(&self) -> bool {
120 self.refcount.fetch_sub(1, Ordering::AcqRel) == 1
121 }
122
123 #[inline]
125 #[must_use]
126 pub fn refcount(&self) -> u32 {
127 self.refcount.load(Ordering::Acquire)
128 }
129
130 #[inline]
132 pub fn reset(&mut self) {
133 self.len = 0;
134 self.refcount.store(0, Ordering::Release);
135 }
136
137 pub fn copy_from_slice(&mut self, data: &[u8]) -> Result<()> {
143 if data.len() > MAX_PACKET_SIZE {
144 return Err(Error::PacketPool(format!(
145 "data too large: {} > {}",
146 data.len(),
147 MAX_PACKET_SIZE
148 )));
149 }
150 self.data[..data.len()].copy_from_slice(data);
151 self.len = data.len() as u32;
152 Ok(())
153 }
154}
155
156#[allow(clippy::missing_fields_in_debug)] impl std::fmt::Debug for PacketBuffer {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("PacketBuffer")
160 .field("len", &self.len)
161 .field("index", &self.index)
162 .field("refcount", &self.refcount.load(Ordering::Relaxed))
163 .finish()
164 }
165}
166
167pub struct PacketRef<'pool> {
174 pool: &'pool PacketPool,
175 idx: u32,
176}
177
178impl PacketRef<'_> {
179 #[inline]
181 #[must_use]
182 pub fn index(&self) -> u32 {
183 self.idx
184 }
185
186 #[inline]
192 #[must_use]
193 pub fn into_index(self) -> u32 {
194 let md = ManuallyDrop::new(self);
195 md.idx
196 }
197}
198
199impl Deref for PacketRef<'_> {
200 type Target = PacketBuffer;
201
202 #[inline]
203 fn deref(&self) -> &PacketBuffer {
204 unsafe { &*self.pool.buffers[self.idx as usize].get() }
208 }
209}
210
211impl DerefMut for PacketRef<'_> {
212 #[inline]
213 fn deref_mut(&mut self) -> &mut PacketBuffer {
214 unsafe { &mut *self.pool.buffers[self.idx as usize].get() }
218 }
219}
220
221impl Drop for PacketRef<'_> {
222 fn drop(&mut self) {
223 unsafe { self.pool.free_by_index(self.idx) };
226 }
227}
228
229impl std::fmt::Debug for PacketRef<'_> {
230 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 f.debug_struct("PacketRef")
232 .field("idx", &self.idx)
233 .field("buffer", &**self)
234 .finish()
235 }
236}
237
238pub struct PacketPool {
248 buffers: Box<[UnsafeCell<PacketBuffer>]>,
250 free_indices: MpmcRing<u32>,
252 capacity: usize,
254}
255
256impl PacketPool {
257 pub fn new(capacity: usize) -> Result<Self> {
263 let capacity = capacity.max(1);
264
265 let buffers: Vec<UnsafeCell<PacketBuffer>> = (0..capacity)
266 .map(|i| UnsafeCell::new(PacketBuffer::new(i as u32)))
267 .collect();
268
269 let free_indices = MpmcRing::new(capacity);
271 for i in 0..capacity {
272 let _ = free_indices.enqueue(i as u32);
273 }
274
275 Ok(Self {
276 buffers: buffers.into_boxed_slice(),
277 free_indices,
278 capacity,
279 })
280 }
281
282 pub fn with_default_capacity() -> Result<Self> {
288 Self::new(DEFAULT_POOL_CAPACITY)
289 }
290
291 #[inline]
293 #[must_use]
294 pub const fn capacity(&self) -> usize {
295 self.capacity
296 }
297
298 #[inline]
300 #[must_use]
301 pub fn free_count(&self) -> usize {
302 self.free_indices.len()
303 }
304
305 #[inline]
307 #[must_use]
308 pub fn allocated_count(&self) -> usize {
309 self.capacity - self.free_indices.len()
310 }
311
312 pub fn alloc(&self) -> Option<PacketRef<'_>> {
318 let idx = self.free_indices.dequeue()?;
319 unsafe {
321 (*self.buffers[idx as usize].get())
322 .refcount
323 .store(1, Ordering::Release);
324 }
325 Some(PacketRef { pool: self, idx })
326 }
327
328 pub fn alloc_index(&self) -> Option<u32> {
332 let idx = self.free_indices.dequeue()?;
333 unsafe {
335 (*self.buffers[idx as usize].get())
336 .refcount
337 .store(1, Ordering::Release);
338 }
339 Some(idx)
340 }
341
342 pub fn alloc_with_data(&self, data: &[u8]) -> Result<PacketRef<'_>> {
348 let mut pkt = self
349 .alloc()
350 .ok_or_else(|| Error::PacketPool("pool exhausted".to_string()))?;
351 pkt.copy_from_slice(data)?;
352 Ok(pkt)
353 }
354
355 pub(crate) unsafe fn free(&self, buffer: &mut PacketBuffer) {
365 let idx = buffer.index;
366 debug_assert!((idx as usize) < self.capacity);
367 buffer.reset();
368 while self.free_indices.enqueue(idx).is_err() {
373 std::hint::spin_loop();
374 }
375 }
376
377 pub unsafe fn free_by_index(&self, idx: u32) {
383 debug_assert!((idx as usize) < self.capacity);
384
385 let buffer = unsafe { &mut *self.buffers[idx as usize].get() };
387 unsafe { self.free(buffer) };
389 }
390
391 #[must_use]
397 pub unsafe fn get(&self, idx: u32) -> &PacketBuffer {
398 debug_assert!((idx as usize) < self.capacity);
399 unsafe { &*self.buffers[idx as usize].get() }
401 }
402
403 #[must_use]
411 #[allow(clippy::mut_from_ref)] pub unsafe fn get_mut(&self, idx: u32) -> &mut PacketBuffer {
413 debug_assert!((idx as usize) < self.capacity);
414 unsafe { &mut *self.buffers[idx as usize].get() }
416 }
417
418 pub fn alloc_batch_indices(&self, out: &mut [u32]) -> usize {
422 let mut count = 0;
423 for slot in out.iter_mut() {
424 if let Some(idx) = self.alloc_index() {
425 *slot = idx;
426 count += 1;
427 } else {
428 break;
429 }
430 }
431 count
432 }
433}
434
435#[allow(clippy::missing_fields_in_debug)] impl std::fmt::Debug for PacketPool {
437 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
438 f.debug_struct("PacketPool")
439 .field("capacity", &self.capacity)
440 .field("free_count", &self.free_count())
441 .field("allocated_count", &self.allocated_count())
442 .finish()
443 }
444}
445
446unsafe impl Send for PacketPool {}
448unsafe impl Sync for PacketPool {}
449
450#[allow(dead_code)]
453const _ASSERT_PACKET_REF_SEND: () = {
454 const fn assert_send<T: Send>() {}
455 assert_send::<PacketRef<'_>>();
456};
457
#[cfg(test)]
mod tests {
    use super::*;

    /// The `align(64)` attribute must show up in the type's alignment.
    #[test]
    fn test_buffer_size() {
        let alignment = std::mem::align_of::<PacketBuffer>();
        assert_eq!(alignment, 64);
    }

    /// A fresh pool reports every buffer as free.
    #[test]
    fn test_pool_creation() {
        let pool = PacketPool::new(100).unwrap();
        let observed = (pool.capacity(), pool.free_count(), pool.allocated_count());
        assert_eq!(observed, (100, 100, 0));
    }

    /// Dropping a handle returns its buffer to the pool.
    #[test]
    fn test_pool_alloc_drop() {
        let pool = PacketPool::new(10).unwrap();

        {
            let first = pool.alloc().unwrap();
            assert_eq!(first.refcount(), 1);
            assert_eq!(pool.free_count(), 9);
            assert_eq!(pool.allocated_count(), 1);
        } // `first` is dropped here

        assert_eq!(pool.free_count(), 10);
        assert_eq!(pool.allocated_count(), 0);

        let second = pool.alloc().unwrap();
        assert!(second.index() < 10);
    }

    /// `into_index` skips the destructor, so the buffer stays checked out
    /// until it is freed by index.
    #[test]
    fn test_packet_ref_into_index() {
        let pool = PacketPool::new(4).unwrap();

        let handle = pool.alloc().unwrap();
        let expected = handle.index();

        let raw = handle.into_index();
        assert_eq!(raw, expected);
        assert_eq!(pool.free_count(), 3);

        // SAFETY: `raw` came from a live handle and is freed exactly once.
        unsafe { pool.free_by_index(raw) };
        assert_eq!(pool.free_count(), 4);
    }

    /// Allocation fails once every buffer is held.
    #[test]
    fn test_pool_exhaustion() {
        let pool = PacketPool::new(2).unwrap();

        let _held: Vec<_> = (0..2).map(|_| pool.alloc().unwrap()).collect();

        assert!(pool.alloc().is_none());
        assert_eq!(pool.free_count(), 0);
    }

    /// `copy_from_slice` stores the bytes and updates the length.
    #[test]
    fn test_buffer_copy() {
        let pool = PacketPool::new(1).unwrap();
        let mut pkt = pool.alloc().unwrap();

        let payload = [1u8, 2, 3, 4, 5];
        pkt.copy_from_slice(&payload).unwrap();

        assert_eq!(pkt.len(), 5);
        assert_eq!(pkt.as_slice(), &payload[..]);
    }

    /// `alloc_with_data` allocates and fills in one step.
    #[test]
    fn test_alloc_with_data() {
        let pool = PacketPool::new(1).unwrap();
        let payload = [0xAB; 100];

        let pkt = pool.alloc_with_data(&payload).unwrap();
        assert_eq!(pkt.len(), 100);
        assert_eq!(pkt.as_slice(), &payload[..]);
    }

    /// Batch allocation stops when the pool runs dry.
    #[test]
    fn test_batch_alloc_indices() {
        let pool = PacketPool::new(5).unwrap();
        let mut indices = [0u32; 10];

        let filled = pool.alloc_batch_indices(&mut indices);
        assert_eq!(filled, 5);
        assert_eq!(pool.free_count(), 0);

        indices.iter().take(5).for_each(|idx| {
            assert!(*idx < 5);
        });
    }

    /// Several threads allocate, write and drop; every buffer must end up free.
    #[test]
    fn test_concurrent_alloc_drop() {
        use std::sync::Arc;

        const PAYLOAD: [u8; 4] = [0xDE, 0xAD, 0xBE, 0xEF];

        let pool = Arc::new(PacketPool::new(64).unwrap());
        let iterations = 1000;
        let threads = 4;

        let mut workers = Vec::new();
        for _ in 0..threads {
            let shared = Arc::clone(&pool);
            workers.push(std::thread::spawn(move || {
                for _ in 0..iterations {
                    let mut pkt = match shared.alloc() {
                        Some(pkt) => pkt,
                        None => continue,
                    };
                    pkt.set_len(4);
                    pkt.as_full_mut_slice()[..4].copy_from_slice(&PAYLOAD);
                    assert_eq!(pkt.as_slice(), &PAYLOAD[..]);
                }
            }));
        }

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(pool.free_count(), 64);
        assert_eq!(pool.allocated_count(), 0);
    }

    /// Same as above, but releasing through `into_index` + `free_by_index`.
    #[test]
    fn test_concurrent_alloc_into_index_free() {
        use std::sync::Arc;

        let pool = Arc::new(PacketPool::new(32).unwrap());
        let iterations = 500;
        let threads = 4;

        let mut workers = Vec::new();
        for _ in 0..threads {
            let shared = Arc::clone(&pool);
            workers.push(std::thread::spawn(move || {
                for _ in 0..iterations {
                    let idx = match shared.alloc() {
                        Some(pkt) => pkt.into_index(),
                        None => continue,
                    };
                    assert!(idx < 32);
                    // SAFETY: `idx` was just taken from a live handle and is
                    // freed exactly once.
                    unsafe { shared.free_by_index(idx) };
                }
            }));
        }

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(pool.free_count(), 32);
    }
}