use std::{
num::NonZeroUsize,
ptr::NonNull,
sync::{Arc, Mutex},
};
use dma_api::*;
/// A record of one DMA-related call observed by [`TrackingDmaOp`].
///
/// Each variant captures the arguments the corresponding `DmaOp` trait
/// method was invoked with, so tests can assert on the exact sequence of
/// cache-maintenance, mapping, and allocation operations performed.
#[derive(Debug, Clone, PartialEq)]
pub enum DmaOperation {
    /// Cache flush (write-back) request over `[addr, addr + size)`.
    Flush { addr: usize, size: usize },
    /// Cache invalidate request over `[addr, addr + size)`.
    Invalidate { addr: usize, size: usize },
    /// Streaming mapping of an existing buffer.
    MapSingle {
        virt_addr: usize,
        size: usize,
        direction: DmaDirection,
    },
    /// Teardown of a streaming mapping; only the mapped size is recorded.
    UnmapSingle { size: usize },
    /// Allocation of a coherent DMA buffer with the given size and alignment.
    AllocCoherent { size: usize, align: usize },
    /// Release of a coherent DMA buffer; only the size is recorded.
    DeallocCoherent { size: usize },
}
/// Test double for the `DmaOp` trait that records every DMA call it receives.
///
/// The log lives behind `Arc<Mutex<_>>` so it can be read and cleared through
/// `&self` while the code under test drives the trait methods.
pub struct TrackingDmaOp {
    // Ordered log of every DMA call observed so far.
    operations: Arc<Mutex<Vec<DmaOperation>>>,
    // Base virtual address of the buffer under test; used by
    // `find_flush_at`/`find_inv_at` to turn relative offsets into
    // absolute addresses.
    base_addr: usize,
}
impl TrackingDmaOp {
    /// Creates a tracker with an empty operation log.
    ///
    /// `base_addr` is the virtual base address of the buffer under test; it
    /// anchors the relative offsets accepted by
    /// [`Self::find_flush_at`]/[`Self::find_inv_at`].
    pub fn new(base_addr: usize) -> Self {
        Self {
            operations: Arc::new(Mutex::new(Vec::new())),
            base_addr,
        }
    }

    /// Returns a snapshot (clone) of the recorded operations, oldest first.
    pub fn get_operations(&self) -> Vec<DmaOperation> {
        let ops = self.operations.lock().unwrap();
        ops.clone()
    }

    /// Discards every recorded operation.
    pub fn clear(&self) {
        let mut ops = self.operations.lock().unwrap();
        ops.clear();
    }

    /// Number of `Flush` operations recorded so far.
    pub fn count_flush(&self) -> usize {
        let ops = self.operations.lock().unwrap();
        ops.iter().fold(0, |n, op| match op {
            DmaOperation::Flush { .. } => n + 1,
            _ => n,
        })
    }

    /// Number of `Invalidate` operations recorded so far.
    pub fn count_invalidate(&self) -> usize {
        let ops = self.operations.lock().unwrap();
        ops.iter().fold(0, |n, op| match op {
            DmaOperation::Invalidate { .. } => n + 1,
            _ => n,
        })
    }

    /// True if a `Flush` of exactly `size` bytes was recorded at
    /// `base_addr + offset`.
    pub fn find_flush_at(&self, offset: usize, size: usize) -> bool {
        let want = self.base_addr + offset;
        let ops = self.operations.lock().unwrap();
        ops.iter().any(|op| {
            if let DmaOperation::Flush { addr, size: s } = op {
                *addr == want && *s == size
            } else {
                false
            }
        })
    }

    /// True if an `Invalidate` of exactly `size` bytes was recorded at
    /// `base_addr + offset`.
    pub fn find_inv_at(&self, offset: usize, size: usize) -> bool {
        let want = self.base_addr + offset;
        let ops = self.operations.lock().unwrap();
        ops.iter().any(|op| {
            if let DmaOperation::Invalidate { addr, size: s } = op {
                *addr == want && *s == size
            } else {
                false
            }
        })
    }

    /// The most recently recorded `Flush` as `(addr, size)`, if any.
    pub fn last_flush(&self) -> Option<(usize, usize)> {
        let ops = self.operations.lock().unwrap();
        ops.iter().rev().find_map(|op| match op {
            DmaOperation::Flush { addr, size } => Some((*addr, *size)),
            _ => None,
        })
    }

    /// The most recently recorded `Invalidate` as `(addr, size)`, if any.
    pub fn last_invalidate(&self) -> Option<(usize, usize)> {
        let ops = self.operations.lock().unwrap();
        ops.iter().rev().find_map(|op| match op {
            DmaOperation::Invalidate { addr, size } => Some((*addr, *size)),
            _ => None,
        })
    }
}
impl DmaOp for TrackingDmaOp {
    /// Fixed 4 KiB page size for the mock.
    fn page_size(&self) -> usize {
        0x1000
    }

    /// Records the mapping request and returns an identity mapping: the
    /// reported bus address equals the buffer's virtual address.
    ///
    /// NOTE(review): the caller-supplied `_align` is ignored and the handle's
    /// layout always uses 8-byte alignment — confirm this is intentional for
    /// the tests consuming the handle.
    unsafe fn map_single(
        &self,
        _dma_mask: u64,
        addr: NonNull<u8>,
        size: NonZeroUsize,
        _align: usize,
        direction: DmaDirection,
    ) -> Result<DmaMapHandle, DmaError> {
        self.operations
            .lock()
            .unwrap()
            .push(DmaOperation::MapSingle {
                virt_addr: addr.as_ptr() as usize,
                size: size.get(),
                direction,
            });
        let layout = core::alloc::Layout::from_size_align(size.get(), 8)?;
        // SAFETY: `addr` is the caller's valid buffer; the bus address merely
        // mirrors the virtual address for later inspection.
        Ok(unsafe { DmaMapHandle::new(addr, (addr.as_ptr() as u64).into(), layout, None) })
    }

    /// Records the unmap; nothing to free since `map_single` never allocated.
    unsafe fn unmap_single(&self, handle: DmaMapHandle) {
        self.operations
            .lock()
            .unwrap()
            .push(DmaOperation::UnmapSingle {
                size: handle.size(),
            });
    }

    /// Logs a flush request (no cache maintenance is actually performed).
    fn flush(&self, addr: NonNull<u8>, size: usize) {
        self.operations.lock().unwrap().push(DmaOperation::Flush {
            addr: addr.as_ptr() as usize,
            size,
        });
    }

    /// Logs an invalidate request (no cache maintenance is actually performed).
    fn invalidate(&self, addr: NonNull<u8>, size: usize) {
        self.operations
            .lock()
            .unwrap()
            .push(DmaOperation::Invalidate {
                addr: addr.as_ptr() as usize,
                size,
            });
    }

    /// Records the request and returns a zero-initialized heap buffer with an
    /// identity bus address, or `None` on allocation failure.
    unsafe fn alloc_coherent(
        &self,
        _dma_mask: u64,
        layout: core::alloc::Layout,
    ) -> Option<DmaHandle> {
        self.operations
            .lock()
            .unwrap()
            .push(DmaOperation::AllocCoherent {
                size: layout.size(),
                align: layout.align(),
            });
        // `GlobalAlloc::alloc_zeroed` is undefined behavior for a zero-sized
        // layout; treat a zero-size request as an allocation failure instead.
        if layout.size() == 0 {
            return None;
        }
        // SAFETY: `layout` has non-zero size (checked above).
        let raw = unsafe { std::alloc::alloc_zeroed(layout) };
        let ptr = NonNull::new(raw)?;
        // SAFETY: `ptr` points to a live, zeroed allocation of `layout` bytes
        // obtained from the global allocator above.
        Some(unsafe { DmaHandle::new(ptr, (raw as u64).into(), layout) })
    }

    /// Records the release and returns the buffer to the global allocator.
    unsafe fn dealloc_coherent(&self, handle: DmaHandle) {
        self.operations
            .lock()
            .unwrap()
            .push(DmaOperation::DeallocCoherent {
                size: handle.size(),
            });
        // SAFETY: the pointer and layout come from a matching `alloc_coherent`
        // call, which obtained them from `std::alloc::alloc_zeroed`.
        unsafe { std::alloc::dealloc(handle.as_ptr().as_ptr(), handle.layout()) };
    }
}