use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use lazy_static::lazy_static;
use spin::Mutex;
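/// I/O priority classes, ordered from most to least urgent. The derived
/// `Ord` follows the discriminants, so `Critical` sorts before `Idle`.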
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum IoPriority {
Critical = 0,
High = 1,
Normal = 2,
Low = 3,
Idle = 4,
}
impl IoPriority {
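    /// Relative weight of this priority class; a larger weight marks a
    /// class that should receive a bigger share of I/O bandwidth.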
pub fn weight(&self) -> u32 {
match self {
IoPriority::Critical => 1000,
IoPriority::High => 500,
IoPriority::Normal => 100,
IoPriority::Low => 50,
IoPriority::Idle => 10,
}
}
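    /// Human-readable name for logs and diagnostics.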
pub fn name(&self) -> &'static str {
match self {
IoPriority::Critical => "critical",
IoPriority::High => "high",
IoPriority::Normal => "normal",
IoPriority::Low => "low",
IoPriority::Idle => "idle",
}
}
}
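/// Classic token bucket: tokens (bytes) accumulate at `refill_rate` per time
/// unit up to `capacity`, permitting short bursts while bounding the average rate.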
#[derive(Debug, Clone)]
pub struct TokenBucket {
capacity: u64,
tokens: u64,
refill_rate: u64,
last_refill: u64,
}
impl TokenBucket {
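    /// Creates a bucket that starts full, so an initial burst of up to
    /// `capacity` bytes is allowed before throttling kicks in.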
pub fn new(capacity: u64, refill_rate: u64) -> Self {
Self {
capacity,
            tokens: capacity,
            refill_rate,
last_refill: 0,
}
}
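    /// Credits tokens for the time elapsed since the last refill, capped at capacity.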
pub fn refill(&mut self, current_time: u64) {
let elapsed = current_time.saturating_sub(self.last_refill);
if elapsed > 0 {
            // Saturate so a very large elapsed interval cannot overflow.
            let new_tokens = elapsed.saturating_mul(self.refill_rate);
            self.tokens = self.tokens.saturating_add(new_tokens).min(self.capacity);
self.last_refill = current_time;
}
}
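    /// Attempts to take `bytes` tokens. On failure, returns the estimated
    /// number of time units to wait before the request could succeed.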
pub fn consume(&mut self, bytes: u64, current_time: u64) -> Result<(), u64> {
self.refill(current_time);
if self.tokens >= bytes {
self.tokens -= bytes;
Ok(())
} else {
let needed = bytes - self.tokens;
            // Policies only create buckets with a non-zero rate, so this cannot divide by zero.
            let wait_time = needed.div_ceil(self.refill_rate);
            Err(wait_time)
}
}
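    /// Fraction of capacity currently available, in `[0.0, 1.0]`.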
pub fn fill_level(&self) -> f32 {
if self.capacity == 0 {
return 0.0;
}
self.tokens as f32 / self.capacity as f32
}
}
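/// Per-dataset QoS policy: a priority class plus optional read and write rate
/// limits. A limit of zero means unlimited, in which case no bucket is kept.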
#[derive(Debug, Clone)]
pub struct QosPolicy {
pub dataset_id: u64,
pub priority: IoPriority,
pub read_limit_bps: u64,
pub write_limit_bps: u64,
read_bucket: Option<TokenBucket>,
write_bucket: Option<TokenBucket>,
}
impl QosPolicy {
pub fn new(dataset_id: u64, priority: IoPriority) -> Self {
Self {
dataset_id,
priority,
            read_limit_bps: 0,
            write_limit_bps: 0,
            read_bucket: None,
write_bucket: None,
}
}
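    /// Sets the read limit in bytes per second; `0` removes the limit. The
    /// burst size defaults to ten seconds' worth of bandwidth when not given.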
pub fn set_read_limit(&mut self, bytes_per_sec: u64, burst_size: Option<u64>) {
self.read_limit_bps = bytes_per_sec;
if bytes_per_sec > 0 {
            // Default burst: ten seconds' worth of bandwidth (saturating to avoid overflow).
            let burst = burst_size.unwrap_or(bytes_per_sec.saturating_mul(10));
self.read_bucket = Some(TokenBucket::new(burst, bytes_per_sec));
} else {
self.read_bucket = None;
}
}
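    /// Sets the write limit in bytes per second; semantics mirror `set_read_limit`.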
pub fn set_write_limit(&mut self, bytes_per_sec: u64, burst_size: Option<u64>) {
self.write_limit_bps = bytes_per_sec;
if bytes_per_sec > 0 {
            // Default burst: ten seconds' worth of bandwidth (saturating to avoid overflow).
            let burst = burst_size.unwrap_or(bytes_per_sec.saturating_mul(10));
self.write_bucket = Some(TokenBucket::new(burst, bytes_per_sec));
} else {
self.write_bucket = None;
}
}
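    /// Checks a read of `bytes` against the limit; `Err(wait)` is the throttle delay.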
pub fn check_read(&mut self, bytes: u64, current_time: u64) -> Result<(), u64> {
if let Some(bucket) = &mut self.read_bucket {
bucket.consume(bytes, current_time)
} else {
            Ok(())
        }
}
pub fn check_write(&mut self, bytes: u64, current_time: u64) -> Result<(), u64> {
if let Some(bucket) = &mut self.write_bucket {
bucket.consume(bytes, current_time)
} else {
            Ok(())
        }
}
}
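/// Cumulative per-dataset accounting. `throttle_time` sums the wait hints
/// returned to throttled callers, in the same time units they supplied.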
#[derive(Debug, Clone, Default)]
pub struct QosStats {
pub bytes_read: u64,
pub bytes_written: u64,
pub throttle_count: u64,
pub throttle_time: u64,
pub ops_completed: u64,
}
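// Global singleton behind a spinlock; external code goes through `QosEngine`.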
lazy_static! {
static ref QOS_MANAGER: Mutex<QosManager> = Mutex::new(QosManager::new());
}
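/// Central registry of QoS policies and statistics, keyed by dataset id.
/// Datasets without an explicit policy fall back to `default_priority`.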
pub struct QosManager {
policies: BTreeMap<u64, QosPolicy>,
stats: BTreeMap<u64, QosStats>,
default_priority: IoPriority,
}
impl Default for QosManager {
fn default() -> Self {
Self::new()
}
}
impl QosManager {
pub fn new() -> Self {
Self {
policies: BTreeMap::new(),
stats: BTreeMap::new(),
default_priority: IoPriority::Normal,
}
}
pub fn set_policy(&mut self, policy: QosPolicy) {
self.policies.insert(policy.dataset_id, policy);
}
pub fn get_policy(&self, dataset_id: u64) -> Option<&QosPolicy> {
self.policies.get(&dataset_id)
}
pub fn get_policy_mut(&mut self, dataset_id: u64) -> Option<&mut QosPolicy> {
self.policies.get_mut(&dataset_id)
}
pub fn set_priority(&mut self, dataset_id: u64, priority: IoPriority) {
self.policies
.entry(dataset_id)
.or_insert_with(|| QosPolicy::new(dataset_id, priority))
.priority = priority;
}
pub fn set_read_limit(&mut self, dataset_id: u64, bytes_per_sec: u64) {
        let default_priority = self.default_priority;
        let policy = self
            .policies
            .entry(dataset_id)
            .or_insert_with(|| QosPolicy::new(dataset_id, default_priority));
policy.set_read_limit(bytes_per_sec, None);
}
pub fn set_write_limit(&mut self, dataset_id: u64, bytes_per_sec: u64) {
        let default_priority = self.default_priority;
        let policy = self
            .policies
            .entry(dataset_id)
            .or_insert_with(|| QosPolicy::new(dataset_id, default_priority));
policy.set_write_limit(bytes_per_sec, None);
}
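    /// Accounts for the read and applies the dataset's rate limit, if any.
    /// Datasets without a policy are never throttled but are still counted.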
pub fn throttle_read(
&mut self,
dataset_id: u64,
bytes: u64,
current_time: u64,
) -> Result<(), u64> {
if let Some(policy) = self.policies.get_mut(&dataset_id) {
let result = policy.check_read(bytes, current_time);
let stats = self.stats.entry(dataset_id).or_default();
stats.bytes_read += bytes;
if let Err(wait) = result {
stats.throttle_count += 1;
stats.throttle_time += wait;
}
result
} else {
let stats = self.stats.entry(dataset_id).or_default();
stats.bytes_read += bytes;
Ok(())
}
}
pub fn throttle_write(
&mut self,
dataset_id: u64,
bytes: u64,
current_time: u64,
) -> Result<(), u64> {
if let Some(policy) = self.policies.get_mut(&dataset_id) {
let result = policy.check_write(bytes, current_time);
let stats = self.stats.entry(dataset_id).or_default();
stats.bytes_written += bytes;
if let Err(wait) = result {
stats.throttle_count += 1;
stats.throttle_time += wait;
}
result
} else {
let stats = self.stats.entry(dataset_id).or_default();
stats.bytes_written += bytes;
Ok(())
}
}
pub fn record_completion(&mut self, dataset_id: u64) {
self.stats.entry(dataset_id).or_default().ops_completed += 1;
}
pub fn get_stats(&self, dataset_id: u64) -> Option<QosStats> {
self.stats.get(&dataset_id).cloned()
}
pub fn get_priority(&self, dataset_id: u64) -> IoPriority {
self.policies
.get(&dataset_id)
.map(|p| p.priority)
.unwrap_or(self.default_priority)
}
}
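/// Stateless facade over the global `QOS_MANAGER`; every call acquires the
/// spinlock for the duration of the operation.
///
/// A minimal usage sketch (`now` stands for whatever clock value the caller
/// uses; it must be in the same time units as the configured refill rate):
///
/// ```ignore
/// QosEngine::set_read_limit(42, 1_000_000); // ~1 MB per time unit
/// if let Err(wait) = QosEngine::throttle_read(42, 65_536, now) {
///     // Back off for roughly `wait` time units before retrying.
/// }
/// ```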
pub struct QosEngine;
impl QosEngine {
pub fn set_priority(dataset_id: u64, priority: IoPriority) {
let mut mgr = QOS_MANAGER.lock();
mgr.set_priority(dataset_id, priority);
}
pub fn set_read_limit(dataset_id: u64, bytes_per_sec: u64) {
let mut mgr = QOS_MANAGER.lock();
mgr.set_read_limit(dataset_id, bytes_per_sec);
}
pub fn set_write_limit(dataset_id: u64, bytes_per_sec: u64) {
let mut mgr = QOS_MANAGER.lock();
mgr.set_write_limit(dataset_id, bytes_per_sec);
}
pub fn throttle_read(dataset_id: u64, bytes: u64, current_time: u64) -> Result<(), u64> {
let mut mgr = QOS_MANAGER.lock();
mgr.throttle_read(dataset_id, bytes, current_time)
}
pub fn throttle_write(dataset_id: u64, bytes: u64, current_time: u64) -> Result<(), u64> {
let mut mgr = QOS_MANAGER.lock();
mgr.throttle_write(dataset_id, bytes, current_time)
}
pub fn record_completion(dataset_id: u64) {
let mut mgr = QOS_MANAGER.lock();
mgr.record_completion(dataset_id);
}
pub fn stats(dataset_id: u64) -> Option<QosStats> {
let mgr = QOS_MANAGER.lock();
mgr.get_stats(dataset_id)
}
pub fn get_priority(dataset_id: u64) -> IoPriority {
let mgr = QOS_MANAGER.lock();
mgr.get_priority(dataset_id)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_priority_weights() {
assert!(IoPriority::Critical.weight() > IoPriority::High.weight());
assert!(IoPriority::High.weight() > IoPriority::Normal.weight());
assert!(IoPriority::Normal.weight() > IoPriority::Low.weight());
assert!(IoPriority::Low.weight() > IoPriority::Idle.weight());
}
#[test]
fn test_token_bucket_basic() {
let mut bucket = TokenBucket::new(1000, 100);
assert!(bucket.consume(1000, 0).is_ok());
assert!(bucket.consume(1, 0).is_err());
}
#[test]
fn test_token_bucket_refill() {
let mut bucket = TokenBucket::new(1000, 100);
bucket
.consume(1000, 0)
.expect("test: operation should succeed");
assert!(bucket.consume(500, 5).is_ok());
assert!(bucket.consume(1, 5).is_err());
}
#[test]
fn test_token_bucket_burst() {
let mut bucket = TokenBucket::new(10000, 1000);
assert!(bucket.consume(10000, 0).is_ok());
        assert!(bucket.consume(5000, 3).is_err());
    }
#[test]
fn test_qos_policy_unlimited() {
let mut policy = QosPolicy::new(1, IoPriority::Normal);
assert!(policy.check_read(1000000, 0).is_ok());
assert!(policy.check_write(1000000, 0).is_ok());
}
#[test]
fn test_qos_policy_limits() {
let mut policy = QosPolicy::new(1, IoPriority::Normal);
policy.set_read_limit(1000, None);
assert!(policy.check_read(10000, 0).is_ok());
assert!(policy.check_read(1, 0).is_err());
}
#[test]
fn test_qos_manager_priority() {
let mut mgr = QosManager::new();
mgr.set_priority(100, IoPriority::High);
assert_eq!(mgr.get_priority(100), IoPriority::High);
assert_eq!(mgr.get_priority(200), IoPriority::Normal);
}
#[test]
fn test_qos_manager_throttling() {
let mut mgr = QosManager::new();
mgr.set_read_limit(100, 1000);
assert!(mgr.throttle_read(100, 10000, 0).is_ok());
assert!(mgr.throttle_read(100, 1000, 0).is_err());
assert!(mgr.throttle_read(100, 1000, 2).is_ok());
}
#[test]
fn test_qos_statistics() {
let mut mgr = QosManager::new();
mgr.set_read_limit(100, 1000);
        let _ = mgr.throttle_read(100, 10000, 0);
        let _ = mgr.throttle_read(100, 1000, 0);
        mgr.record_completion(100);
let stats = mgr.get_stats(100).expect("test: operation should succeed");
assert_eq!(stats.bytes_read, 11000);
assert_eq!(stats.throttle_count, 1);
assert_eq!(stats.ops_completed, 1);
}
#[test]
fn test_fill_level() {
let mut bucket = TokenBucket::new(1000, 100);
assert!((bucket.fill_level() - 1.0).abs() < 0.01);
bucket
.consume(500, 0)
.expect("test: operation should succeed");
        assert!((bucket.fill_level() - 0.5).abs() < 0.01);
    }
}