use crate::types::FaultRecord;
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
/// Errors returned by [`FaultQueue`] operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FaultQueueError {
/// `push` was rejected because the queue already held `capacity` records.
QueueFull,
/// No record was available to return.
QueueEmpty,
/// The queue's internal mutex is poisoned (a thread panicked while holding it).
LockPoisoned,
}
impl std::fmt::Display for FaultQueueError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::QueueFull => write!(f, "Fault queue is full"),
Self::QueueEmpty => write!(f, "Fault queue is empty"),
Self::LockPoisoned => write!(f, "Fault queue lock is poisoned"),
}
}
}
// Lets `FaultQueueError` be used wherever a `std::error::Error` is expected
// (e.g. boxed as `dyn Error`). No variant wraps an underlying cause, so the
// default `source()` (which returns `None`) is kept.
impl std::error::Error for FaultQueueError {}
/// A bounded, first-in/first-out queue of [`FaultRecord`]s.
///
/// All queue operations take `&self`; access to the records is serialized
/// through an internal `Mutex`. `Clone` copies the queued records into a new
/// `Arc<Mutex<..>>`, so a clone is an independent snapshot rather than a
/// second handle to the same queue.
// NOTE(review): the `Arc` is never shared by any code visible here (Clone
// deep-copies); presumably kept for future handle-style sharing — confirm.
#[derive(Debug)]
pub struct FaultQueue {
/// Queued records, oldest at the front (`push_back` enqueues, `pop_front` dequeues).
queue: Arc<Mutex<VecDeque<FaultRecord>>>,
/// Maximum number of records `push` accepts before returning `QueueFull`.
capacity: usize,
}
impl FaultQueue {
    /// Creates an empty queue that accepts at most `capacity` records.
    ///
    /// Storage for `capacity` records is reserved up front. A `capacity` of
    /// zero yields a queue that rejects every `push` with `QueueFull`.
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        Self {
            queue: Arc::new(Mutex::new(VecDeque::with_capacity(capacity))),
            capacity,
        }
    }

    /// Appends `fault` to the back of the queue.
    ///
    /// # Errors
    ///
    /// * [`FaultQueueError::QueueFull`] if the queue already holds `capacity`
    ///   records; `fault` is dropped in that case.
    /// * [`FaultQueueError::LockPoisoned`] if the internal mutex is poisoned.
    pub fn push(&self, fault: FaultRecord) -> Result<(), FaultQueueError> {
        let mut queue = self.queue.lock().map_err(|_| FaultQueueError::LockPoisoned)?;
        // Check and insert under the same guard so concurrent pushes cannot
        // overshoot the capacity.
        if queue.len() >= self.capacity {
            return Err(FaultQueueError::QueueFull);
        }
        queue.push_back(fault);
        Ok(())
    }

    /// Removes and returns the oldest record.
    ///
    /// Returns `None` both when the queue is empty and when the lock is
    /// poisoned; use [`FaultQueue::try_pop`] to tell those cases apart.
    #[must_use]
    pub fn pop(&self) -> Option<FaultRecord> {
        self.try_pop().ok()
    }

    /// Removes and returns the oldest record, reporting why none was returned.
    ///
    /// # Errors
    ///
    /// * [`FaultQueueError::QueueEmpty`] if there is no record to return.
    /// * [`FaultQueueError::LockPoisoned`] if the internal mutex is poisoned.
    pub fn try_pop(&self) -> Result<FaultRecord, FaultQueueError> {
        self.queue
            .lock()
            .map_err(|_| FaultQueueError::LockPoisoned)?
            .pop_front()
            .ok_or(FaultQueueError::QueueEmpty)
    }

    /// Returns the number of queued records, or `0` if the lock is poisoned.
    ///
    /// The lock is released before returning, so the value may be stale if
    /// other threads use the queue concurrently.
    #[must_use]
    pub fn len(&self) -> usize {
        self.queue.lock().map(|q| q.len()).unwrap_or(0)
    }

    /// Returns `true` when no records are queued.
    ///
    /// Also returns `true` when the lock is poisoned, because [`FaultQueue::len`]
    /// reports `0` in that case.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` when `len() >= capacity`, i.e. a `push` right now would
    /// fail with `QueueFull`.
    ///
    /// Because [`FaultQueue::len`] reports `0` on a poisoned lock, this returns
    /// `false` then (unless `capacity` is zero), even though `push` would fail
    /// with `LockPoisoned`. The answer may also be stale under concurrency;
    /// `push`'s error is the authoritative signal.
    #[must_use]
    pub fn is_full(&self) -> bool {
        self.len() >= self.capacity
    }

    /// Returns the maximum number of records the queue accepts.
    #[must_use]
    pub const fn capacity(&self) -> usize {
        self.capacity
    }

    /// Removes every queued record. Does nothing if the lock is poisoned.
    pub fn clear(&self) {
        if let Ok(mut queue) = self.queue.lock() {
            queue.clear();
        }
    }

    /// Returns a clone of the oldest record without removing it.
    ///
    /// Returns `None` if the queue is empty or the lock is poisoned.
    #[must_use]
    pub fn peek(&self) -> Option<FaultRecord> {
        self.queue.lock().ok()?.front().cloned()
    }

    /// Returns clones of all queued records, oldest first.
    ///
    /// Returns an empty `Vec` if the lock is poisoned.
    #[must_use]
    pub fn get_all(&self) -> Vec<FaultRecord> {
        self.queue.lock().map(|q| q.iter().cloned().collect()).unwrap_or_default()
    }
}
impl Clone for FaultQueue {
fn clone(&self) -> Self {
let queue = self.queue.lock().unwrap();
Self {
queue: Arc::new(Mutex::new(queue.clone())),
capacity: self.capacity,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AccessType, FaultType, StreamID, IOVA, PASID};

    /// Builds a translation-fault record for `stream_id` with a fixed PASID,
    /// address and access type, so records are distinguishable by stream ID.
    fn create_test_fault(stream_id: u32) -> FaultRecord {
        FaultRecord::builder()
            .stream_id(StreamID::new(stream_id).unwrap())
            .pasid(PASID::new(1).unwrap())
            .address(IOVA::new(0x1000_0000).unwrap())
            .fault_type(FaultType::TranslationFault)
            .access_type(AccessType::Read)
            .build()
    }

    /// Stream IDs of every queued record, oldest first.
    fn stream_ids(queue: &FaultQueue) -> Vec<u32> {
        queue
            .get_all()
            .into_iter()
            .map(|fault| fault.stream_id().as_u32())
            .collect()
    }

    #[test]
    fn test_new_queue() {
        let queue = FaultQueue::new(100);
        assert_eq!(queue.capacity(), 100);
        assert_eq!(queue.len(), 0);
        assert!(queue.is_empty());
        assert!(!queue.is_full());
    }

    #[test]
    fn test_push_pop() {
        let queue = FaultQueue::new(10);
        let fault = create_test_fault(0x100);
        queue.push(fault.clone()).expect("push into an empty queue succeeds");
        assert_eq!(queue.len(), 1);
        let popped = queue.pop().expect("queue holds exactly one record");
        assert_eq!(popped.stream_id(), fault.stream_id());
        assert!(queue.is_empty());
    }

    #[test]
    fn test_pop_empty_returns_none() {
        let queue = FaultQueue::new(1);
        assert!(queue.pop().is_none());
        assert!(queue.peek().is_none());
    }

    #[test]
    fn test_capacity_limit() {
        let queue = FaultQueue::new(2);
        assert_eq!(queue.push(create_test_fault(0x100)), Ok(()));
        assert_eq!(queue.push(create_test_fault(0x200)), Ok(()));
        assert!(queue.is_full());
        assert_eq!(queue.push(create_test_fault(0x300)), Err(FaultQueueError::QueueFull));
        // The rejected record must not have displaced anything.
        assert_eq!(stream_ids(&queue), vec![0x100, 0x200]);
    }

    #[test]
    fn test_zero_capacity_rejects_every_push() {
        let queue = FaultQueue::new(0);
        assert!(queue.is_full());
        assert_eq!(queue.push(create_test_fault(0x100)), Err(FaultQueueError::QueueFull));
        assert!(queue.is_empty());
    }

    #[test]
    fn test_full_queue_accepts_after_pop() {
        let queue = FaultQueue::new(1);
        queue.push(create_test_fault(0x100)).unwrap();
        assert!(queue.pop().is_some());
        assert_eq!(queue.push(create_test_fault(0x200)), Ok(()));
        assert_eq!(stream_ids(&queue), vec![0x200]);
    }

    #[test]
    fn test_fifo_order() {
        let queue = FaultQueue::new(10);
        for id in [0x100, 0x200, 0x300] {
            queue.push(create_test_fault(id)).unwrap();
        }
        for expected in [0x100, 0x200, 0x300] {
            let popped = queue.pop().expect("record queued above");
            assert_eq!(popped.stream_id().as_u32(), expected);
        }
        assert!(queue.is_empty());
    }

    #[test]
    fn test_peek_does_not_remove() {
        let queue = FaultQueue::new(4);
        queue.push(create_test_fault(0x100)).unwrap();
        queue.push(create_test_fault(0x200)).unwrap();
        let front = queue.peek().expect("queue is not empty");
        assert_eq!(front.stream_id().as_u32(), 0x100);
        assert_eq!(queue.len(), 2);
    }

    #[test]
    fn test_get_all_and_clear() {
        let queue = FaultQueue::new(4);
        queue.push(create_test_fault(0x100)).unwrap();
        queue.push(create_test_fault(0x200)).unwrap();
        assert_eq!(stream_ids(&queue), vec![0x100, 0x200]);
        // get_all is a snapshot; it must not drain the queue.
        assert_eq!(queue.len(), 2);
        queue.clear();
        assert!(queue.is_empty());
        assert!(stream_ids(&queue).is_empty());
    }

    #[test]
    fn test_clone_is_independent() {
        let queue = FaultQueue::new(4);
        queue.push(create_test_fault(0x100)).unwrap();
        let snapshot = queue.clone();
        queue.push(create_test_fault(0x200)).unwrap();
        assert_eq!(snapshot.capacity(), 4);
        assert_eq!(stream_ids(&snapshot), vec![0x100]);
        assert_eq!(stream_ids(&queue), vec![0x100, 0x200]);
    }

    #[test]
    fn test_error_display() {
        assert_eq!(FaultQueueError::QueueFull.to_string(), "Fault queue is full");
        assert_eq!(FaultQueueError::QueueEmpty.to_string(), "Fault queue is empty");
        assert_eq!(
            FaultQueueError::LockPoisoned.to_string(),
            "Fault queue lock is poisoned"
        );
    }
}