use core::fmt::Debug;
use core::ops::{Deref, DerefMut};
use core::ptr::NonNull;
use core::{cell::RefCell, marker::PhantomData};
use crate::hal::IxgbeHal;
use crate::{IxgbeError, IxgbeResult};
use alloc::sync::Arc;
use alloc::vec::Vec;
use alloc::{fmt, slice};
/// A physical memory address (as written into NIC descriptors).
pub type PhysAddr = usize;
/// A virtual memory address.
pub type VirtAddr = usize;
/// log2 of the huge-page size: 2^21 = 2 MiB pages.
const HUGE_PAGE_BITS: u32 = 21;
/// Size of one huge page in bytes (2 MiB).
const HUGE_PAGE_SIZE: usize = 1 << HUGE_PAGE_BITS;
/// Bytes kept free in front of every packet payload (see `Packet::headroom_mut`).
pub const PACKET_HEADROOM: usize = 32;
/// A fixed-size pool of DMA-capable packet buffers.
///
/// All buffers live in one contiguous DMA allocation; free buffers are
/// tracked by id on a stack. Buffer ids are `0..num_entries`.
pub struct MemPool {
    // Virtual base address of the backing DMA region.
    base_addr: *mut u8,
    // Number of fixed-size buffers in the pool.
    num_entries: usize,
    // Size of each buffer in bytes.
    entry_size: usize,
    // Physical address of each buffer, indexed by buffer id (precomputed once).
    phys_addr: Vec<usize>,
    // Stack of currently-free buffer ids.
    // NOTE(review): `RefCell` gives single-threaded interior mutability only;
    // see the `unsafe impl Sync` at the bottom of this file.
    pub(crate) free_stack: RefCell<Vec<usize>>,
}
impl MemPool {
    /// Allocates a new memory pool backed by a single DMA region.
    ///
    /// `entries` is the number of fixed-size buffers; `size` is the size of
    /// each buffer in bytes, where `0` selects the default of 2048.
    ///
    /// # Errors
    ///
    /// Returns [`IxgbeError::PageNotAligned`] if the entry size does not
    /// evenly divide the huge-page size (buffers must not straddle a
    /// huge-page boundary).
    pub fn allocate<H: IxgbeHal>(entries: usize, size: usize) -> IxgbeResult<Arc<MemPool>> {
        // A size of 0 selects the default entry size.
        let entry_size = match size {
            0 => 2048,
            x => x,
        };
        if HUGE_PAGE_SIZE % entry_size != 0 {
            error!("entry size must be a divisor of the page size");
            return Err(IxgbeError::PageNotAligned);
        }
        // Guard against overflow: a plain `entries * entry_size` would wrap
        // silently in release builds and under-allocate the DMA region.
        let total_size = entries
            .checked_mul(entry_size)
            .expect("memory pool size overflows usize");
        let dma = Dma::<u8, H>::allocate(total_size, false)?;
        // Precompute every entry's physical address so the hot path never
        // has to translate a virtual address.
        let mut phys_addr = Vec::with_capacity(entries);
        for i in 0..entries {
            phys_addr.push(unsafe {
                // SAFETY: `i * entry_size < total_size`, so the pointer stays
                // inside the DMA region allocated above.
                H::mmio_virt_to_phys(
                    NonNull::new(dma.virt.add(i * entry_size)).unwrap(),
                    entry_size,
                )
            })
        }
        let pool = MemPool {
            base_addr: dma.virt,
            num_entries: entries,
            entry_size,
            phys_addr,
            // Initially every buffer id is free.
            free_stack: RefCell::new((0..entries).collect()),
        };
        Ok(Arc::new(pool))
    }

    /// Pops a free buffer id, or `None` if the pool is exhausted.
    pub(crate) fn alloc_buf(&self) -> Option<usize> {
        self.free_stack.borrow_mut().pop()
    }

    /// Returns buffer `id` to the free stack.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range or the buffer is already free.
    pub(crate) fn free_buf(&self, id: usize) {
        assert!(
            id < self.num_entries,
            "buffer outside of memory pool, id: {id}"
        );
        let mut free_stack = self.free_stack.borrow_mut();
        // O(n) scan, but it catches double-frees that would otherwise hand
        // the same DMA buffer to two packets.
        if free_stack.contains(&id) {
            panic!("free buf: buffer already free");
        }
        free_stack.push(id);
    }

    /// Size in bytes of a single pool entry.
    pub fn entry_size(&self) -> usize {
        self.entry_size
    }

    /// Virtual address of buffer `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range.
    pub(crate) fn get_virt_addr(&self, id: usize) -> *mut u8 {
        assert!(
            id < self.num_entries,
            "buffer outside of memory pool, id: {id}"
        );
        // SAFETY: `id` was bounds-checked above, so the offset stays inside
        // the pool's DMA region.
        unsafe { self.base_addr.add(id * self.entry_size) }
    }

    /// Physical address of buffer `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is out of range.
    pub fn get_phys_addr(&self, id: usize) -> usize {
        // Same explicit bounds check as `get_virt_addr` for a consistent
        // panic message instead of a bare slice-index panic.
        assert!(
            id < self.num_entries,
            "buffer outside of memory pool, id: {id}"
        );
        self.phys_addr[id]
    }
}
/// A typed DMA allocation: the same memory seen through both its virtual
/// and physical address.
pub struct Dma<T, H: IxgbeHal> {
    // Virtual address of the allocation, usable by the CPU.
    pub virt: *mut T,
    // Physical address of the allocation, usable by the device.
    pub phys: usize,
    // Ties the allocation to the HAL that produced it without storing it.
    _marker: PhantomData<H>,
}
impl<T, H: IxgbeHal> Dma<T, H> {
    /// Allocates `size` bytes of DMA memory through the HAL and records both
    /// its physical and virtual address.
    ///
    /// The `_require_contiguous` flag is currently unused; the HAL's
    /// allocation is taken as-is.
    pub fn allocate(size: usize, _require_contiguous: bool) -> IxgbeResult<Dma<T, H>> {
        let (phys, va) = H::dma_alloc(size);
        let virt = va.as_ptr() as *mut T;
        info!(
            "allocated DMA memory @pa: {:#x}, va: {:#x}, size: {:#x}",
            phys, virt as usize, size
        );
        Ok(Self {
            virt,
            phys,
            _marker: PhantomData,
        })
    }
}
/// A single packet buffer checked out of a [`MemPool`].
///
/// Derefs to `[u8]` for the payload bytes; the underlying buffer is returned
/// to the pool when the packet is dropped.
pub struct Packet {
    // Virtual address of the payload start.
    pub(crate) addr_virt: NonNull<u8>,
    // Physical address matching `addr_virt`.
    pub(crate) addr_phys: usize,
    // Payload length in bytes.
    pub(crate) len: usize,
    // Pool that owns the underlying buffer.
    pub(crate) pool: Arc<MemPool>,
    // Buffer id inside `pool`; handed back on drop.
    pub(crate) pool_entry: usize,
}
impl Clone for Packet {
    /// Clones the packet by taking a fresh buffer of the same length from the
    /// same pool and copying the payload bytes into it.
    ///
    /// Panics if the pool has no free buffer left.
    fn clone(&self) -> Self {
        let mut duplicate = alloc_pkt(&self.pool, self.len).expect("no buffer available");
        // `u8: Copy`, so a plain memcpy-style copy is equivalent.
        duplicate.copy_from_slice(self);
        duplicate
    }
}
impl Deref for Packet {
    type Target = [u8];

    /// Borrows the packet payload as an immutable byte slice.
    fn deref(&self) -> &[u8] {
        let data = self.addr_virt.as_ptr();
        // SAFETY: `addr_virt` points at `len` valid bytes owned by this
        // packet for as long as `self` is borrowed (see `Packet::new`).
        unsafe { slice::from_raw_parts(data, self.len) }
    }
}
impl DerefMut for Packet {
    /// Borrows the packet payload as a mutable byte slice.
    fn deref_mut(&mut self) -> &mut [u8] {
        let data = self.addr_virt.as_ptr();
        // SAFETY: `addr_virt` points at `len` valid bytes exclusively owned
        // by this packet; `&mut self` guarantees unique access.
        unsafe { slice::from_raw_parts_mut(data, self.len) }
    }
}
impl Debug for Packet {
    /// Formats the payload bytes via the slice `Debug` impl.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&**self, f)
    }
}
impl Drop for Packet {
    /// Returns the underlying buffer to its memory pool.
    ///
    /// Panics (via `MemPool::free_buf`) if the buffer id is out of range or
    /// already free — either would indicate a driver bug.
    fn drop(&mut self) {
        self.pool.free_buf(self.pool_entry);
    }
}
impl Packet {
    /// Creates a packet view over an already-allocated pool buffer.
    ///
    /// # Safety
    ///
    /// - `addr_virt` must be non-null and point at `len` valid bytes, with
    ///   `PACKET_HEADROOM` valid bytes directly in front of it so that
    ///   `headroom_mut` stays in bounds (as guaranteed by `alloc_pkt`).
    /// - `addr_phys` must be the physical address corresponding to
    ///   `addr_virt`.
    /// - `pool_entry` must be a buffer id currently checked out of `pool`;
    ///   `Drop` hands it back via `free_buf`.
    pub(crate) unsafe fn new(
        addr_virt: *mut u8,
        addr_phys: usize,
        len: usize,
        pool: Arc<MemPool>,
        pool_entry: usize,
    ) -> Packet {
        Packet {
            // SAFETY: caller guarantees `addr_virt` is non-null.
            addr_virt: NonNull::new_unchecked(addr_virt),
            addr_phys,
            len,
            pool,
            pool_entry,
        }
    }

    /// Virtual address of the packet payload.
    pub fn get_virt_addr(&self) -> *mut u8 {
        self.addr_virt.as_ptr()
    }

    /// Physical address of the packet payload.
    pub fn get_phys_addr(&self) -> usize {
        self.addr_phys
    }

    /// Payload as an immutable byte slice (same view as `Deref`).
    pub fn as_bytes(&self) -> &[u8] {
        unsafe { slice::from_raw_parts(self.addr_virt.as_ptr(), self.len) }
    }

    /// Payload as a mutable byte slice (same view as `DerefMut`).
    pub fn as_mut_bytes(&mut self) -> &mut [u8] {
        unsafe { slice::from_raw_parts_mut(self.addr_virt.as_ptr(), self.len) }
    }

    /// Mutable view of the last `len` bytes of headroom, located directly in
    /// front of the payload.
    ///
    /// Panics if `len > PACKET_HEADROOM`.
    pub fn headroom_mut(&mut self, len: usize) -> &mut [u8] {
        assert!(len <= PACKET_HEADROOM);
        // SAFETY: `Packet::new`'s contract requires PACKET_HEADROOM valid
        // bytes in front of `addr_virt`, and `len` was just bounds-checked.
        unsafe { slice::from_raw_parts_mut(self.addr_virt.as_ptr().sub(len), len) }
    }

    // NOTE(review): method name is misspelled ("prefrtch" -> "prefetch");
    // it is pub(crate), so renaming would touch call sites elsewhere in the
    // crate — left unchanged here.
    /// Issues an x86 prefetch for the start of the payload with `hint`.
    #[cfg(target_arch = "x86_64")]
    #[inline(always)]
    pub(crate) fn prefrtch(&self, hint: Prefetch) {
        // `_mm_prefetch` requires SSE; checked at runtime via `core_detect`.
        if core_detect::is_x86_feature_detected!("sse") {
            let addr = self.get_virt_addr() as *const _;
            // SAFETY: SSE support was just verified, and prefetch is a hint
            // that does not fault on the given address.
            unsafe {
                use core::arch::x86_64;
                match hint {
                    Prefetch::Time0 => x86_64::_mm_prefetch(addr, x86_64::_MM_HINT_T0),
                    Prefetch::Time1 => x86_64::_mm_prefetch(addr, x86_64::_MM_HINT_T1),
                    Prefetch::Time2 => x86_64::_mm_prefetch(addr, x86_64::_MM_HINT_T2),
                    Prefetch::NonTemporal => x86_64::_mm_prefetch(addr, x86_64::_MM_HINT_NTA),
                }
            }
        }
    }
}
/// Allocates a packet with a payload of `size` bytes from `pool`.
///
/// The payload starts `PACKET_HEADROOM` bytes into the pool entry so that
/// callers can later prepend data via `Packet::headroom_mut`.
///
/// Returns `None` when the pool is exhausted, or when `size` plus the
/// headroom does not fit into a single pool entry.
pub fn alloc_pkt(pool: &Arc<MemPool>, size: usize) -> Option<Packet> {
    // checked_sub: a pool whose entries are smaller than PACKET_HEADROOM
    // (which MemPool::allocate permits) must reject every request instead of
    // wrapping around in release builds and admitting oversized packets.
    let max_payload = pool.entry_size.checked_sub(PACKET_HEADROOM)?;
    if size > max_payload {
        return None;
    }
    pool.alloc_buf().map(|id| unsafe {
        // SAFETY: `id` came off the pool's free stack, so both addresses
        // refer to an in-bounds pool entry that no other packet aliases, and
        // the size check above keeps headroom + payload within the entry.
        Packet::new(
            pool.get_virt_addr(id).add(PACKET_HEADROOM),
            pool.get_phys_addr(id) + PACKET_HEADROOM,
            size,
            Arc::clone(pool),
            id,
        )
    })
}
/// Temporal-locality hints mapped onto `_mm_prefetch` hint constants in
/// `Packet::prefrtch`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum Prefetch {
    /// Maps to `_MM_HINT_T0`.
    Time0,
    /// Maps to `_MM_HINT_T1`.
    Time1,
    /// Maps to `_MM_HINT_T2`.
    Time2,
    /// Maps to `_MM_HINT_NTA` (non-temporal).
    NonTemporal,
}
// SAFETY(review): `MemPool` holds a `RefCell`, which is NOT `Sync` — this
// impl asserts that a pool is never mutated from two threads concurrently
// (i.e. it is confined to one thread or externally synchronized). If pools
// can be shared across threads, `free_stack` needs a real lock.
// TODO confirm the driver's threading model supports this claim.
unsafe impl Sync for MemPool {}
// `Send` additionally asserts the raw `base_addr` pointer may move between
// threads along with the pool that owns it.
unsafe impl Send for MemPool {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins down the module's compile-time constants.
    #[test]
    fn test_constants() {
        assert_eq!(PACKET_HEADROOM, 32);
        assert_eq!(HUGE_PAGE_BITS, 21);
        assert_eq!(HUGE_PAGE_SIZE, 1 << 21);
        assert_eq!(HUGE_PAGE_SIZE, 0x200000);
    }

    /// The derived ordering follows declaration order.
    #[test]
    fn test_prefetch_ord() {
        let hints = [
            Prefetch::Time0,
            Prefetch::Time1,
            Prefetch::Time2,
            Prefetch::NonTemporal,
        ];
        assert!(hints.windows(2).all(|pair| pair[0] < pair[1]));
    }

    /// Derived equality distinguishes variants.
    #[test]
    fn test_prefetch_eq() {
        assert_eq!(Prefetch::Time0, Prefetch::Time0);
        assert_ne!(Prefetch::Time0, Prefetch::Time1);
    }

    /// `Prefetch` is `Copy`: assigning does not move.
    #[test]
    fn test_prefetch_copy() {
        let hint = Prefetch::Time0;
        let copied = hint;
        assert_eq!(hint, copied);
    }

    /// Every power of two from 2 KiB to 512 KiB divides the huge-page size.
    #[test]
    fn test_mempool_allocation_alignment() {
        for size in (11..=19).map(|bits: u32| 1usize << bits) {
            assert_eq!(
                HUGE_PAGE_SIZE % size,
                0,
                "Size {} should divide page size",
                size
            );
        }
    }

    /// Non-power-of-two-ish sizes must be rejected by the divisibility check.
    #[test]
    fn test_mempool_invalid_alignment() {
        for &size in [100usize, 1536, 3000, 5000].iter() {
            assert_ne!(
                HUGE_PAGE_SIZE % size,
                0,
                "Size {} should not divide page size evenly",
                size
            );
        }
    }

    /// A requested size of 0 falls back to the 2048-byte default.
    #[test]
    fn test_mempool_entry_size_default() {
        let size: usize = 0;
        let entry_size = if size == 0 { 2048 } else { size };
        assert_eq!(entry_size, 2048);
    }
}