use alloc::collections::BTreeMap;
use alloc::vec::Vec;
use lazy_static::lazy_static;
use spin::Mutex;
/// Health state of a single I/O path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PathState {
    /// Path is healthy and may carry I/O.
    Active,
    /// Path shows an elevated error rate but may still carry I/O.
    Degraded,
    /// Path has failed and must not be selected.
    Failed,
    /// Path is not selectable; presumably it is being recovered
    /// (NOTE(review): nothing in this file assigns this state — confirm usage).
    Repairing,
}

impl PathState {
    /// Returns `true` when a path in this state may be chosen for new I/O,
    /// i.e. for `Active` and `Degraded`.
    pub fn is_usable(&self) -> bool {
        // Spelled out exhaustively so that adding a variant forces a decision here.
        match self {
            PathState::Active | PathState::Degraded => true,
            PathState::Failed | PathState::Repairing => false,
        }
    }
}
/// Strategy a `PathGroup` uses to pick one of its usable paths.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancePolicy {
    /// Rotate through the usable paths in turn.
    RoundRobin,
    /// Pick the usable path with the smallest `queue_depth`.
    LeastQueue,
    /// Pick the usable path with the smallest `avg_latency_us`.
    LeastLatency,
    /// Pick the usable path with the smallest `IoPath::service_time()`.
    ServiceTime,
}
/// One I/O path to a storage target together with its runtime counters.
#[derive(Debug, Clone)]
pub struct IoPath {
    /// Identifier of this path; `MpioManager::complete_io` looks paths up by it.
    pub id: u64,
    /// Identifier of the target this path was created for.
    pub target_id: u64,
    /// Ordering key within a group: `PathGroup::add_path` sorts ascending,
    /// so lower values come first.
    pub priority: u8,
    /// Current health state.
    pub state: PathState,
    /// Number of I/Os currently accounted as in flight on this path.
    pub queue_depth: u32,
    /// Moving average of the latency of successful I/Os; `0` means no sample
    /// has been recorded yet.
    pub avg_latency_us: u64,
    /// Number of successful I/Os recorded.
    pub io_count: u64,
    /// Bytes moved by successful I/Os.
    pub bytes_transferred: u64,
    /// Number of failed I/Os recorded.
    pub error_count: u64,
    /// `current_time` value passed to the most recent `check_health` call.
    pub last_health_check: u64,
}

impl IoPath {
    /// Creates a path in the `Active` state with every counter zeroed.
    pub fn new(id: u64, target_id: u64, priority: u8) -> Self {
        Self {
            id,
            target_id,
            priority,
            state: PathState::Active,
            queue_depth: 0,
            avg_latency_us: 0,
            io_count: 0,
            bytes_transferred: 0,
            error_count: 0,
            last_health_check: 0,
        }
    }

    /// Estimated cost of issuing the next I/O on this path: the average
    /// latency plus a fixed penalty of 100 per queued I/O.
    ///
    /// The result saturates at `u64::MAX` instead of overflowing.
    pub fn service_time(&self) -> u64 {
        // u32::MAX * 100 always fits in a u64, so only the addition can overflow.
        let queue_penalty = u64::from(self.queue_depth) * 100;
        self.avg_latency_us.saturating_add(queue_penalty)
    }

    /// Records the outcome of one I/O.
    ///
    /// A failed I/O only bumps `error_count`. A successful one bumps
    /// `io_count` and `bytes_transferred` and folds `latency_us` into the
    /// moving average with weight 1/5 (the first sample is taken verbatim).
    /// All counters saturate rather than overflow.
    pub fn record_io(&mut self, bytes: u64, latency_us: u64, success: bool) {
        if !success {
            self.error_count = self.error_count.saturating_add(1);
            return;
        }

        self.io_count = self.io_count.saturating_add(1);
        self.bytes_transferred = self.bytes_transferred.saturating_add(bytes);

        self.avg_latency_us = if self.avg_latency_us == 0 {
            latency_us
        } else {
            // Widen to u128 so `avg * 4 + sample` cannot overflow; the quotient
            // is at most max(avg, sample), so narrowing back is lossless.
            let weighted = u128::from(self.avg_latency_us) * 4 + u128::from(latency_us);
            (weighted / 5) as u64
        };
    }

    /// Re-evaluates `state` from the cumulative error rate and stores
    /// `current_time` in `last_health_check`.
    ///
    /// The state is only touched once more than 100 operations have been
    /// recorded: an integer error percentage above 10 yields `Failed`, above 5
    /// yields `Degraded`, anything else yields `Active`.
    pub fn check_health(&mut self, current_time: u64) {
        self.last_health_check = current_time;

        // u128 keeps both the sum and `error_count * 100` free of overflow.
        let total_ops = u128::from(self.io_count) + u128::from(self.error_count);
        if total_ops <= 100 {
            return;
        }

        let error_rate = u128::from(self.error_count) * 100 / total_ops;
        self.state = if error_rate > 10 {
            PathState::Failed
        } else if error_rate > 5 {
            PathState::Degraded
        } else {
            PathState::Active
        };
    }
}
/// The set of paths leading to one target, plus the policy used to choose
/// between them.
#[derive(Debug, Clone)]
pub struct PathGroup {
    /// Identifier of the target this group serves.
    pub target_id: u64,
    /// Member paths, kept sorted by ascending `priority` by `add_path`.
    pub paths: Vec<IoPath>,
    /// Selection strategy applied by `select_path`.
    pub policy: LoadBalancePolicy,
    /// Monotonic (wrapping) counter driving round-robin selection.
    rr_index: usize,
}

impl PathGroup {
    /// Creates an empty group for `target_id` using `policy`.
    pub fn new(target_id: u64, policy: LoadBalancePolicy) -> Self {
        Self {
            target_id,
            paths: Vec::new(),
            policy,
            rr_index: 0,
        }
    }

    /// Adds `path` and re-sorts the group by ascending priority.
    pub fn add_path(&mut self, path: IoPath) {
        self.paths.push(path);
        // The sort is stable, so paths of equal priority keep insertion order.
        self.paths.sort_by_key(|p| p.priority);
    }

    /// Iterates over `(index, path)` for every path whose state is usable,
    /// in storage (priority) order.
    fn usable_paths<'a>(&'a self) -> impl Iterator<Item = (usize, &'a IoPath)> + 'a {
        self.paths
            .iter()
            .enumerate()
            .filter(|&(_, p)| p.state.is_usable())
    }

    /// Picks a usable path according to the group's policy.
    ///
    /// Returns `None` when no path is usable. For the `Least*`/`ServiceTime`
    /// policies, ties go to the path that comes first in `paths`. The
    /// round-robin counter only advances when a path is actually returned and
    /// wraps around instead of overflowing.
    pub fn select_path(&mut self) -> Option<&mut IoPath> {
        let idx = match self.policy {
            LoadBalancePolicy::RoundRobin => {
                let usable_count = self.usable_paths().count();
                if usable_count == 0 {
                    return None;
                }
                let slot = self.rr_index % usable_count;
                self.rr_index = self.rr_index.wrapping_add(1);
                self.usable_paths().nth(slot).map(|(i, _)| i)?
            }
            LoadBalancePolicy::LeastQueue => self
                .usable_paths()
                .min_by_key(|&(_, p)| p.queue_depth)
                .map(|(i, _)| i)?,
            LoadBalancePolicy::LeastLatency => self
                .usable_paths()
                .min_by_key(|&(_, p)| p.avg_latency_us)
                .map(|(i, _)| i)?,
            LoadBalancePolicy::ServiceTime => self
                .usable_paths()
                .min_by_key(|&(_, p)| p.service_time())
                .map(|(i, _)| i)?,
        };
        Some(&mut self.paths[idx])
    }

    /// Number of paths whose state is exactly `Active` (degraded paths are
    /// not counted).
    pub fn active_count(&self) -> usize {
        self.paths
            .iter()
            .filter(|p| p.state == PathState::Active)
            .count()
    }

    /// Runs `IoPath::check_health` on every path in the group.
    pub fn health_check(&mut self, current_time: u64) {
        for path in &mut self.paths {
            path.check_health(current_time);
        }
    }
}
/// Aggregate counters maintained by the MPIO manager.
#[derive(Debug, Clone, Default)]
pub struct MpioStats {
    /// I/O completions reported, successful or not.
    pub total_ios: u64,
    /// Failover events recorded when a path is newly marked failed.
    pub failovers: u64,
    /// Paths newly marked failed after an unsuccessful I/O.
    pub path_failures: u64,
    /// Path repairs (NOTE(review): never incremented in this file — confirm).
    pub path_repairs: u64,
    /// Byte counts passed to I/O completions, summed.
    pub total_bytes: u64,
}
// Global MPIO manager shared by the `Mpio` facade functions. It is built by
// `MpioManager::new()` through `lazy_static!` and wrapped in a `spin::Mutex`,
// which every facade call locks before touching the manager.
lazy_static! {
static ref MPIO_MANAGER: Mutex<MpioManager> = Mutex::new(MpioManager::new());
}
/// Owner of all path groups and of the global MPIO statistics.
pub struct MpioManager {
    /// Path groups keyed by target id.
    groups: BTreeMap<u64, PathGroup>,
    /// Running counters, returned by value from `stats()`.
    stats: MpioStats,
    /// Minimum spacing between two effective `health_check` runs, in the
    /// same unit as the `current_time` argument.
    health_check_interval: u64,
    /// `current_time` of the last health check that actually ran.
    last_health_check: u64,
}
impl Default for MpioManager {
fn default() -> Self {
Self::new()
}
}
impl MpioManager {
    /// Creates a manager with no groups, zeroed statistics and a health-check
    /// interval of 10 000 time units.
    pub fn new() -> Self {
        Self {
            groups: BTreeMap::new(),
            stats: MpioStats::default(),
            health_check_interval: 10_000,
            last_health_check: 0,
        }
    }

    /// Registers an empty path group for `target_id` and logs the event.
    ///
    /// An existing group for the same target is replaced, together with its
    /// paths.
    pub fn create_group(&mut self, target_id: u64, policy: LoadBalancePolicy) {
        let group = PathGroup::new(target_id, policy);
        self.groups.insert(target_id, group);
        crate::lcpfs_println!(
            "[ MPIO ] Created path group for target {} (policy: {:?})",
            target_id,
            policy
        );
    }

    /// Adds `path` to the group of `target_id`.
    ///
    /// # Errors
    /// Returns `"Target not found"` when no group exists for `target_id`.
    pub fn add_path(&mut self, target_id: u64, path: IoPath) -> Result<(), &'static str> {
        let group = self.groups.get_mut(&target_id).ok_or("Target not found")?;
        // Log first: `path` is moved into the group right after.
        crate::lcpfs_println!(
            "[ MPIO ] Added path {} to target {} (priority: {})",
            path.id,
            target_id,
            path.priority
        );
        group.add_path(path);
        Ok(())
    }

    /// Chooses a path for the next I/O to `target_id`, bumps its queue depth
    /// and returns its id.
    ///
    /// # Errors
    /// `"Target not found"` for an unknown target, `"No usable paths"` when
    /// the group has no usable path.
    pub fn select_path(&mut self, target_id: u64) -> Result<u64, &'static str> {
        let group = self.groups.get_mut(&target_id).ok_or("Target not found")?;
        let path = group.select_path().ok_or("No usable paths")?;
        // Saturate so lost completions can never overflow the in-flight counter.
        path.queue_depth = path.queue_depth.saturating_add(1);
        Ok(path.id)
    }

    /// Reports the completion of an I/O previously issued on `path_id`.
    ///
    /// Decrements the path's queue depth, records the outcome on the path,
    /// updates the global counters and, on failure, marks the path failed.
    ///
    /// # Errors
    /// `"Target not found"` or `"Path not found"` when the ids do not resolve.
    pub fn complete_io(
        &mut self,
        target_id: u64,
        path_id: u64,
        bytes: u64,
        latency_us: u64,
        success: bool,
    ) -> Result<(), &'static str> {
        let group = self.groups.get_mut(&target_id).ok_or("Target not found")?;
        let path = group
            .paths
            .iter_mut()
            .find(|p| p.id == path_id)
            .ok_or("Path not found")?;

        path.queue_depth = path.queue_depth.saturating_sub(1);
        path.record_io(bytes, latency_us, success);

        // Global counters saturate instead of overflowing on long uptimes.
        self.stats.total_ios = self.stats.total_ios.saturating_add(1);
        self.stats.total_bytes = self.stats.total_bytes.saturating_add(bytes);

        if !success {
            self.handle_path_failure(target_id, path_id)?;
        }
        Ok(())
    }

    /// Marks `path_id` as failed (once) and accounts for the failure.
    ///
    /// A path that is already `Failed`, or an unknown `path_id`, is ignored.
    fn handle_path_failure(&mut self, target_id: u64, path_id: u64) -> Result<(), &'static str> {
        let group = self.groups.get_mut(&target_id).ok_or("Target not found")?;
        if let Some(path) = group.paths.iter_mut().find(|p| p.id == path_id) {
            if path.state != PathState::Failed {
                crate::lcpfs_println!(
                    "[ MPIO ] Path {} to target {} FAILED - triggering failover",
                    path_id,
                    target_id
                );
                path.state = PathState::Failed;
                self.stats.path_failures = self.stats.path_failures.saturating_add(1);
                self.stats.failovers = self.stats.failovers.saturating_add(1);
            }
        }
        Ok(())
    }

    /// Runs the health check of every group, at most once per
    /// `health_check_interval`.
    ///
    /// Calls arriving before the interval has elapsed since the last
    /// effective run return without doing anything.
    pub fn health_check(&mut self, current_time: u64) {
        // Compare elapsed time rather than `last + interval`, which could overflow.
        let elapsed = current_time.saturating_sub(self.last_health_check);
        if elapsed < self.health_check_interval {
            return;
        }
        self.last_health_check = current_time;
        for group in self.groups.values_mut() {
            group.health_check(current_time);
        }
    }

    /// Returns `(total paths, active paths)` for `target_id`, or `None` when
    /// the target has no group.
    pub fn group_info(&self, target_id: u64) -> Option<(usize, usize)> {
        self.groups
            .get(&target_id)
            .map(|g| (g.paths.len(), g.active_count()))
    }

    /// Returns a snapshot copy of the global counters.
    pub fn stats(&self) -> MpioStats {
        self.stats.clone()
    }
}
/// Stateless facade over the global `MPIO_MANAGER`.
///
/// Every function locks the global manager, forwards the call to the
/// `MpioManager` method of the same name and releases the lock on return.
pub struct Mpio;

impl Mpio {
    /// Locked forward to `MpioManager::create_group`.
    pub fn create_group(target_id: u64, policy: LoadBalancePolicy) {
        MPIO_MANAGER.lock().create_group(target_id, policy);
    }

    /// Locked forward to `MpioManager::add_path`.
    pub fn add_path(target_id: u64, path: IoPath) -> Result<(), &'static str> {
        MPIO_MANAGER.lock().add_path(target_id, path)
    }

    /// Locked forward to `MpioManager::select_path`.
    pub fn select_path(target_id: u64) -> Result<u64, &'static str> {
        MPIO_MANAGER.lock().select_path(target_id)
    }

    /// Locked forward to `MpioManager::complete_io`.
    pub fn complete_io(
        target_id: u64,
        path_id: u64,
        bytes: u64,
        latency_us: u64,
        success: bool,
    ) -> Result<(), &'static str> {
        MPIO_MANAGER
            .lock()
            .complete_io(target_id, path_id, bytes, latency_us, success)
    }

    /// Locked forward to `MpioManager::health_check`.
    pub fn health_check(current_time: u64) {
        MPIO_MANAGER.lock().health_check(current_time);
    }

    /// Locked forward to `MpioManager::stats`; returns a copy of the counters.
    pub fn stats() -> MpioStats {
        MPIO_MANAGER.lock().stats()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Active and degraded paths are usable, failed ones are not.
    #[test]
    fn test_path_state() {
        let active = PathState::Active;
        let degraded = PathState::Degraded;
        let failed = PathState::Failed;
        assert!(active.is_usable());
        assert!(degraded.is_usable());
        assert!(!failed.is_usable());
    }

    // A fresh path carries its ids and starts active with an empty queue.
    #[test]
    fn test_path_creation() {
        let fresh = IoPath::new(1, 100, 0);
        assert_eq!(fresh.id, 1);
        assert_eq!(fresh.target_id, 100);
        assert_eq!(fresh.state, PathState::Active);
        assert_eq!(fresh.queue_depth, 0);
    }

    // First sample seeds the average; later ones are blended in at 1/5.
    #[test]
    fn test_path_io_recording() {
        let mut p = IoPath::new(1, 100, 0);

        p.record_io(4096, 100, true);
        assert_eq!(p.io_count, 1);
        assert_eq!(p.bytes_transferred, 4096);
        assert_eq!(p.avg_latency_us, 100);

        p.record_io(4096, 200, true);
        assert_eq!(p.io_count, 2);
        // (100 * 4 + 200) / 5
        assert_eq!(p.avg_latency_us, 120);
    }

    // 16 errors out of 110 operations is above the 10% failure threshold.
    #[test]
    fn test_path_health_check() {
        let mut p = IoPath::new(1, 100, 0);
        for _ in 0..94 {
            p.record_io(4096, 100, true);
        }
        for _ in 0..16 {
            p.record_io(4096, 100, false);
        }
        p.check_health(1000);
        assert_eq!(p.state, PathState::Failed);
    }

    // Service time is latency plus 100 per queued I/O.
    #[test]
    fn test_service_time() {
        let mut p = IoPath::new(1, 100, 0);
        p.avg_latency_us = 50;
        p.queue_depth = 10;
        assert_eq!(p.service_time(), 1050);
    }

    // Paths added to a group are stored and counted as active.
    #[test]
    fn test_path_group_creation() {
        let mut grp = PathGroup::new(100, LoadBalancePolicy::RoundRobin);
        let first = IoPath::new(1, 100, 0);
        let second = IoPath::new(2, 100, 1);
        grp.add_path(first);
        grp.add_path(second);
        assert_eq!(grp.paths.len(), 2);
        assert_eq!(grp.active_count(), 2);
    }

    // Two consecutive round-robin picks land on different paths.
    #[test]
    fn test_round_robin_selection() {
        let mut grp = PathGroup::new(100, LoadBalancePolicy::RoundRobin);
        grp.add_path(IoPath::new(1, 100, 0));
        grp.add_path(IoPath::new(2, 100, 0));

        let first_id = grp
            .select_path()
            .expect("test: operation should succeed")
            .id;
        let second_id = grp
            .select_path()
            .expect("test: operation should succeed")
            .id;
        assert_ne!(first_id, second_id);
    }

    // The path with the shorter queue wins.
    #[test]
    fn test_least_queue_selection() {
        let mut grp = PathGroup::new(100, LoadBalancePolicy::LeastQueue);
        let mut busy = IoPath::new(1, 100, 0);
        let mut idle = IoPath::new(2, 100, 0);
        busy.queue_depth = 10;
        idle.queue_depth = 5;
        grp.add_path(busy);
        grp.add_path(idle);

        let chosen = grp.select_path().expect("test: operation should succeed");
        assert_eq!(chosen.id, 2);
    }

    // Creating a group and adding one path is reflected in group_info.
    #[test]
    fn test_manager_basic() {
        let mut manager = MpioManager::new();
        manager.create_group(100, LoadBalancePolicy::RoundRobin);
        manager
            .add_path(100, IoPath::new(1, 100, 0))
            .expect("test: operation should succeed");

        let (total, active) = manager
            .group_info(100)
            .expect("test: operation should succeed");
        assert_eq!(total, 1);
        assert_eq!(active, 1);
    }

    // A select/complete round trip updates the global counters.
    #[test]
    fn test_path_selection_and_completion() {
        let mut manager = MpioManager::new();
        manager.create_group(100, LoadBalancePolicy::RoundRobin);
        manager
            .add_path(100, IoPath::new(1, 100, 0))
            .expect("test: operation should succeed");

        let chosen_id = manager
            .select_path(100)
            .expect("test: operation should succeed");
        assert_eq!(chosen_id, 1);

        manager
            .complete_io(100, chosen_id, 4096, 100, true)
            .expect("test: operation should succeed");

        let snapshot = manager.stats();
        assert_eq!(snapshot.total_ios, 1);
        assert_eq!(snapshot.total_bytes, 4096);
    }

    // A failed completion marks the path failed and counts one failover.
    #[test]
    fn test_failover() {
        let mut manager = MpioManager::new();
        manager.create_group(100, LoadBalancePolicy::RoundRobin);
        for id in 1..3 {
            manager
                .add_path(100, IoPath::new(id, 100, 0))
                .expect("test: operation should succeed");
        }

        let chosen_id = manager
            .select_path(100)
            .expect("test: operation should succeed");
        manager
            .complete_io(100, chosen_id, 4096, 100, false)
            .expect("test: operation should succeed");

        let snapshot = manager.stats();
        assert_eq!(snapshot.failovers, 1);
        assert_eq!(snapshot.path_failures, 1);
    }

    // Selection fails when the only path is already failed.
    #[test]
    fn test_no_usable_paths() {
        let mut manager = MpioManager::new();
        manager.create_group(100, LoadBalancePolicy::RoundRobin);

        let mut dead = IoPath::new(1, 100, 0);
        dead.state = PathState::Failed;
        manager
            .add_path(100, dead)
            .expect("test: operation should succeed");

        let outcome = manager.select_path(100);
        assert!(outcome.is_err());
    }

    // add_path keeps the group sorted by ascending priority.
    #[test]
    fn test_priority_sorting() {
        let mut grp = PathGroup::new(100, LoadBalancePolicy::RoundRobin);
        grp.add_path(IoPath::new(1, 100, 2));
        grp.add_path(IoPath::new(2, 100, 0));
        grp.add_path(IoPath::new(3, 100, 1));

        assert_eq!(grp.paths[0].priority, 0);
        assert_eq!(grp.paths[1].priority, 1);
        assert_eq!(grp.paths[2].priority, 2);
    }

    // The path with the lower average latency wins.
    #[test]
    fn test_least_latency_policy() {
        let mut grp = PathGroup::new(100, LoadBalancePolicy::LeastLatency);
        let mut slow = IoPath::new(1, 100, 0);
        let mut fast = IoPath::new(2, 100, 0);
        slow.avg_latency_us = 200;
        fast.avg_latency_us = 100;
        grp.add_path(slow);
        grp.add_path(fast);

        let chosen = grp.select_path().expect("test: operation should succeed");
        assert_eq!(chosen.id, 2);
    }

    // Path 1: 50 + 20 * 100 = 2050; path 2: 100 + 5 * 100 = 600, so path 2 wins.
    #[test]
    fn test_service_time_policy() {
        let mut grp = PathGroup::new(100, LoadBalancePolicy::ServiceTime);
        let mut congested = IoPath::new(1, 100, 0);
        let mut relaxed = IoPath::new(2, 100, 0);
        congested.avg_latency_us = 50;
        congested.queue_depth = 20;
        relaxed.avg_latency_us = 100;
        relaxed.queue_depth = 5;
        grp.add_path(congested);
        grp.add_path(relaxed);

        let chosen = grp.select_path().expect("test: operation should succeed");
        assert_eq!(chosen.id, 2);
    }
}