use crate::frame::Frame;
use tokio::sync::Mutex;
use std::collections::{BinaryHeap, HashMap, VecDeque};
/// Priority class attached to every queued write.
///
/// `Ctrl` carries the smaller discriminant, so it compares less than `Data`
/// under the derived ordering.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ClassId {
    /// Control frames.
    Ctrl = 0,
    /// Payload (data) frames.
    Data = 1,
}
/// One pending frame write, tagged with the metadata the shaper orders by.
#[derive(Clone, Debug)]
pub struct WriteRequest {
    /// Priority class; see `ClassId`.
    pub class: ClassId,
    /// The frame to be written. `ShaperQueue` groups requests by `frame.sid`.
    pub frame: Frame,
    /// Sequence number used as the tie-breaker within a class.
    pub seq: u32,
}
/// Outcome of a write: how many bytes went out and an optional error message.
#[derive(Debug)]
pub struct WriteResult {
    /// Number of bytes written.
    pub n: usize,
    /// Error description, or `None` on success.
    pub err: Option<String>,
}
/// Write queue that alternates between streams in round-robin order.
///
/// Within each stream, requests sit in a priority heap.
pub struct ShaperQueue {
    /// Per-stream heaps, keyed by stream id.
    streams: HashMap<u32, BinaryHeap<WriteRequest>>,
    /// Stream ids in rotation order; each id appears at most once.
    rr_list: VecDeque<u32>,
    /// Index into `rr_list` where the next scan begins.
    next: usize,
    /// Total number of requests across all heaps.
    count: usize,
    // NOTE(review): every mutating method takes `&mut self`, which already
    // guarantees exclusive access, so this lock is never contended.
    mu: Mutex<()>,
}
impl ShaperQueue {
    /// Creates an empty queue with no registered streams.
    pub fn new() -> Self {
        Self {
            streams: HashMap::new(),
            rr_list: VecDeque::new(),
            next: 0,
            count: 0,
            mu: Mutex::new(()),
        }
    }

    /// Enqueues `req` on the heap of the stream `req.frame.sid`.
    ///
    /// A stream seen for the first time is appended to the end of the
    /// round-robin rotation.
    pub async fn push(&mut self, req: WriteRequest) {
        // NOTE(review): `&mut self` already guarantees exclusive access, so this
        // lock can never be contended; it is kept to preserve the existing layout.
        let _guard = self.mu.lock().await;
        let sid = req.frame.sid;
        if !self.streams.contains_key(&sid) {
            self.rr_list.push_back(sid);
            if self.rr_list.len() == 1 {
                self.next = 0;
            }
        }
        self.streams.entry(sid).or_default().push(req);
        self.count += 1;
    }

    /// Removes and returns the next request to write, or `None` when the queue is empty.
    ///
    /// Streams are visited in round-robin order starting at the rotation cursor.
    /// The first pass looks for a stream whose top request is a
    /// [`ClassId::Ctrl`] frame. Only if none exists does a second pass take the
    /// top request of the first non-empty stream. Afterwards the cursor points at
    /// the stream following the one served. A stream whose heap becomes empty
    /// is dropped from the rotation.
    pub async fn pop(&mut self) -> Option<WriteRequest> {
        let _guard = self.mu.lock().await;
        if self.rr_list.is_empty() || self.count == 0 {
            return None;
        }

        let start = self.next;
        let (rr_list, streams) = (&self.rr_list, &self.streams);
        // `WriteRequest` ranks `Ctrl` above `Data`, so a heap holds a control
        // frame exactly when its top request is one.
        let has_ctrl = |heap: &BinaryHeap<WriteRequest>| {
            heap.peek().map_or(false, |top| top.class == ClassId::Ctrl)
        };
        let pos = Self::scan(rr_list, streams, start, has_ctrl)
            .or_else(|| Self::scan(rr_list, streams, start, |heap| !heap.is_empty()))?;

        let sid = self.rr_list[pos];
        let heap = self.streams.get_mut(&sid)?;
        let req = heap.pop()?;
        let drained = heap.is_empty();
        self.count -= 1;

        if drained {
            self.streams.remove(&sid);
            self.rr_list.remove(pos);
            // Removal shifts every later entry down by one, so the stream that
            // followed `sid` now sits at `pos`. Advancing to `pos + 1` here
            // would skip it.
            self.next = if pos < self.rr_list.len() { pos } else { 0 };
        } else {
            self.next = (pos + 1) % self.rr_list.len();
        }
        Some(req)
    }

    /// Returns `true` when no requests are queued.
    pub async fn is_empty(&self) -> bool {
        let _guard = self.mu.lock().await;
        self.count == 0
    }

    /// Returns the total number of queued requests across all streams.
    pub async fn len(&self) -> usize {
        let _guard = self.mu.lock().await;
        self.count
    }

    /// Finds the first position whose stream heap satisfies `pred`.
    ///
    /// The walk goes round-robin through `rr_list`, starting at `start` and
    /// wrapping at most once. Returns `None` if no position qualifies or
    /// `rr_list` is empty.
    fn scan(
        rr_list: &VecDeque<u32>,
        streams: &HashMap<u32, BinaryHeap<WriteRequest>>,
        start: usize,
        pred: impl Fn(&BinaryHeap<WriteRequest>) -> bool,
    ) -> Option<usize> {
        let n = rr_list.len();
        (0..n)
            .map(|offset| (start + offset) % n)
            .find(|&pos| streams.get(&rr_list[pos]).map_or(false, |heap| pred(heap)))
    }
}
impl Default for ShaperQueue {
fn default() -> Self {
Self::new()
}
}
impl Ord for WriteRequest {
    /// Heap ordering: `BinaryHeap` pops its greatest element first.
    ///
    /// The lower `class` is treated as greater, so `Ctrl` comes before
    /// `Data`. Within a class, the lower `seq` is treated as greater, so the
    /// oldest request comes first. `frame` is ignored.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Operands are deliberately swapped: the lexicographically smaller
        // `(class, seq)` pair must compare as the larger request.
        // NOTE(review): `seq` is compared as a plain u32. If the producer lets it
        // wrap, post-wrap requests would overtake pre-wrap ones; confirm whether
        // a wrap-aware comparison is needed.
        (other.class, other.seq).cmp(&(self.class, self.seq))
    }
}
impl PartialOrd for WriteRequest {
    /// Total ordering; defers to the `Ord` implementation.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl PartialEq for WriteRequest {
    /// Two requests are equal when `class` and `seq` match.
    ///
    /// `frame` is ignored, which keeps equality consistent with `Ord`.
    fn eq(&self, other: &Self) -> bool {
        (self.class, self.seq) == (other.class, other.seq)
    }
}
/// Equality on `(class, seq)` is reflexive, so `WriteRequest` is `Eq`.
impl std::cmp::Eq for WriteRequest {}