use kapsl_hal::device_mesh::{DeviceMesh, MeshTopology};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// Policy a [`MeshRouter`] uses to map an incoming request to a worker index.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RoutingStrategy {
    /// Cycle through the workers one request at a time.
    RoundRobin,
    /// Session-sticky routing. `MeshRouter::strategy` never reports this
    /// variant; `MeshRouter::route` treats it like `RoundRobin` when no
    /// session id is supplied.
    SessionAffinity,
    /// Chosen for tensor-parallel and mixed topologies; requests are aligned
    /// to tensor-parallel group boundaries.
    TensorParallel,
    /// Chosen for pipeline-parallel topologies; requests without a session
    /// id are routed to worker 0.
    PipelineParallel,
}
/// Maps requests to worker indices according to an optional device mesh.
///
/// The mesh topology determines the [`RoutingStrategy`]; an atomic counter
/// drives round-robin selection so the router can be shared by reference.
#[derive(Debug)]
pub struct MeshRouter {
    /// Device mesh whose topology selects the routing strategy. `None` makes
    /// the router use the round-robin strategy.
    mesh: Option<Arc<DeviceMesh>>,
    /// Number of worker slots requests are spread over.
    num_workers: usize,
    /// Shared request counter, advanced on every round-robin pick.
    rr_counter: AtomicUsize,
}
impl MeshRouter {
    /// Creates a router that distributes requests over `num_workers` workers.
    ///
    /// `mesh` selects the routing strategy through its topology (see
    /// [`MeshRouter::strategy`]); pass `None` to route without a mesh.
    pub fn new(mesh: Option<Arc<DeviceMesh>>, num_workers: usize) -> Self {
        Self {
            mesh,
            num_workers,
            rr_counter: AtomicUsize::new(0),
        }
    }

    /// Returns the routing strategy implied by the mesh topology.
    ///
    /// * no mesh, or `DataParallel` → [`RoutingStrategy::RoundRobin`]
    /// * `TensorParallel` or `Mixed` → [`RoutingStrategy::TensorParallel`]
    /// * `PipelineParallel` → [`RoutingStrategy::PipelineParallel`]
    ///
    /// [`RoutingStrategy::SessionAffinity`] is never returned; session
    /// stickiness is applied per request by [`MeshRouter::route`].
    pub fn strategy(&self) -> RoutingStrategy {
        match &self.mesh {
            None => RoutingStrategy::RoundRobin,
            Some(mesh) => match &mesh.topology {
                MeshTopology::DataParallel => RoutingStrategy::RoundRobin,
                MeshTopology::TensorParallel { .. } => RoutingStrategy::TensorParallel,
                MeshTopology::PipelineParallel { .. } => RoutingStrategy::PipelineParallel,
                // Mixed topologies are routed by their tensor-parallel dimension.
                MeshTopology::Mixed { .. } => RoutingStrategy::TensorParallel,
            },
        }
    }

    /// Chooses the worker index that should handle a request.
    ///
    /// * [`RoutingStrategy::TensorParallel`]: the request goes to the first
    ///   worker of a tensor-parallel group. The group is `tp_group_hint`
    ///   (modulo the number of groups) if given, otherwise a hash of
    ///   `session_id`, otherwise round-robin over groups. If the topology has
    ///   no usable groups (degree ≤ 1, or fewer workers than one group), the
    ///   workers are treated as a flat pool: session hash, else round-robin.
    /// * Any other strategy: a `session_id` pins the request to
    ///   `hash(session_id) % num_workers`; without one, `RoundRobin` cycles
    ///   through the workers and `PipelineParallel` always returns worker 0.
    ///
    /// `tp_group_hint` is only consulted for tensor-parallel routing.
    /// Returns 0 when the router has no workers.
    pub fn route(&self, session_id: &Option<String>, tp_group_hint: Option<usize>) -> usize {
        if self.num_workers == 0 {
            return 0;
        }
        match (self.strategy(), session_id) {
            // Tensor parallelism handles the session itself. Hashing the session
            // over *all* workers (next arm) would land on an arbitrary rank inside
            // a TP group rather than a group's first worker, and would drop the
            // caller's `tp_group_hint`.
            (RoutingStrategy::TensorParallel, _) => {
                self.route_tensor_parallel(session_id, tp_group_hint)
            }
            (_, Some(id)) => self.route_by_session(id),
            (RoutingStrategy::RoundRobin, None) | (RoutingStrategy::SessionAffinity, None) => {
                self.route_round_robin()
            }
            (RoutingStrategy::PipelineParallel, None) => self.route_pipeline_parallel(),
        }
    }

    /// Returns the next worker in round-robin order.
    ///
    /// `Relaxed` suffices: only the atomicity of the increment matters, not its
    /// ordering relative to other memory. `fetch_add` wraps on overflow.
    /// Callers must ensure `num_workers > 0`.
    fn route_round_robin(&self) -> usize {
        self.rr_counter.fetch_add(1, Ordering::Relaxed) % self.num_workers
    }

    /// Routes over the workers as a flat pool: by session hash when a session
    /// id is present, otherwise round-robin. Callers must ensure
    /// `num_workers > 0`.
    fn route_session_or_round_robin(&self, session_id: &Option<String>) -> usize {
        match session_id {
            Some(id) => self.route_by_session(id),
            None => self.route_round_robin(),
        }
    }

    /// Maps a session id to a fixed worker index by hashing it.
    ///
    /// Callers must ensure `num_workers > 0`.
    fn route_by_session(&self, session_id: &str) -> usize {
        (Self::hash_session(session_id) as usize) % self.num_workers
    }

    /// Hashes a session id with std's `DefaultHasher`.
    ///
    /// `DefaultHasher::new()` gives the same result for the same id within a
    /// build, but its algorithm is unspecified, so the session→worker mapping
    /// may change across Rust releases.
    fn hash_session(session_id: &str) -> u64 {
        let mut hasher = DefaultHasher::new();
        session_id.hash(&mut hasher);
        hasher.finish()
    }

    /// Tensor-parallel degree of the mesh topology, if it has a TP dimension.
    ///
    /// Returns `None` without a mesh or for topologies other than
    /// `TensorParallel` and `Mixed`.
    fn tensor_parallel_degree(&self) -> Option<usize> {
        match &self.mesh.as_ref()?.topology {
            MeshTopology::TensorParallel { degree, .. } => Some(*degree),
            MeshTopology::Mixed { tp, .. } => Some(*tp),
            _ => None,
        }
    }

    /// Routes a request to the first worker of a tensor-parallel group.
    ///
    /// Group selection priority: `tp_group_hint`, then session hash, then
    /// round-robin. Falls back to flat-pool routing when no full TP group
    /// exists. Callers must ensure `num_workers > 0`.
    fn route_tensor_parallel(
        &self,
        session_id: &Option<String>,
        tp_group_hint: Option<usize>,
    ) -> usize {
        let tp_degree = self.tensor_parallel_degree().unwrap_or(1);
        // Degree 0/1 means there are no groups to align to (and avoids a
        // division by zero below).
        if tp_degree <= 1 {
            return self.route_session_or_round_robin(session_id);
        }
        let num_tp_groups = self.num_workers / tp_degree;
        // Fewer workers than a single full group.
        if num_tp_groups == 0 {
            return self.route_session_or_round_robin(session_id);
        }
        let tp_group = match (tp_group_hint, session_id) {
            (Some(hint), _) => hint % num_tp_groups,
            (None, Some(id)) => (Self::hash_session(id) as usize) % num_tp_groups,
            // NOTE(review): dividing the counter by `tp_degree` advances the
            // group only every `tp_degree` requests, so bursts of that many
            // consecutive requests share a group — confirm this is intended.
            (None, None) => {
                (self.rr_counter.fetch_add(1, Ordering::Relaxed) / tp_degree) % num_tp_groups
            }
        };
        // Groups occupy contiguous runs of `tp_degree` workers; return the first.
        // `tp_group < num_tp_groups`, so this stays below `num_workers`.
        tp_group * tp_degree
    }

    /// Pipeline-parallel routing: every request is sent to worker 0
    /// (presumably the first pipeline stage — confirm against the mesh layout).
    fn route_pipeline_parallel(&self) -> usize {
        0
    }

    /// Returns the tensor-parallel group that `worker_idx` belongs to,
    /// computed as `worker_idx / degree`.
    ///
    /// Returns `None` without a mesh, for topologies without a tensor-parallel
    /// dimension, or when the degree is ≤ 1. `worker_idx` is not range-checked
    /// against the router's worker count.
    pub fn get_tp_group(&self, worker_idx: usize) -> Option<usize> {
        let tp_degree = self.tensor_parallel_degree()?;
        if tp_degree <= 1 {
            return None;
        }
        Some(worker_idx / tp_degree)
    }

    /// Lists the worker indices that make up tensor-parallel group `tp_group`.
    ///
    /// The group covers `tp_group * degree .. (tp_group + 1) * degree`,
    /// truncated to the router's worker count. Returns an empty list without a
    /// mesh, for topologies without a tensor-parallel dimension, or when the
    /// group lies beyond the last worker (including ids so large that their
    /// index arithmetic would overflow).
    pub fn get_tp_group_workers(&self, tp_group: usize) -> Vec<usize> {
        let tp_degree = match self.tensor_parallel_degree() {
            Some(degree) => degree,
            None => return Vec::new(),
        };
        // Unchecked multiplication would panic in debug builds and wrap to a
        // bogus (in-range) start index in release builds.
        let start = match tp_group.checked_mul(tp_degree) {
            Some(start) => start,
            None => return Vec::new(),
        };
        let end = start.saturating_add(tp_degree).min(self.num_workers);
        // An inverted range (start >= end) is simply empty.
        (start..end).collect()
    }

    /// Summarises the router configuration; `None` when there is no mesh.
    pub fn mesh_stats(&self) -> Option<MeshRouterStats> {
        let mesh = self.mesh.as_ref()?;
        Some(MeshRouterStats {
            world_size: mesh.world_size,
            topology: format!("{:?}", mesh.topology),
            num_workers: self.num_workers,
            strategy: self.strategy(),
        })
    }
}
/// Point-in-time summary of a [`MeshRouter`]'s configuration, as returned by
/// `MeshRouter::mesh_stats`.
#[derive(Clone, Debug)]
pub struct MeshRouterStats {
    /// `world_size` reported by the device mesh.
    pub world_size: usize,
    /// `Debug` rendering of the mesh topology.
    pub topology: String,
    /// Number of workers the router spreads requests over.
    pub num_workers: usize,
    /// Strategy the router derived from the mesh topology.
    pub strategy: RoutingStrategy,
}
#[cfg(test)]
mod tests {
    use super::*;
    use kapsl_hal::device::{Device, DeviceBackend};

    /// Builds `count` identical fake CUDA devices with ids `0..count`.
    fn fake_cuda_devices(count: usize) -> Vec<Device> {
        let make_device = |id: usize| Device {
            id,
            name: format!("GPU_{}", id),
            backend: DeviceBackend::Cuda,
            memory_mb: 16000,
            compute_units: 80,
            pci_bus_id: None,
            driver_version: None,
            compute_capability: None,
            utilization_gpu_pct: None,
            temperature_c: None,
            supports_fp16: true,
            supports_int8: true,
            cuda_version: Some("12.0".to_string()),
            partition_id: None,
        };
        (0..count).map(make_device).collect()
    }

    /// Router over 8 workers split into two tensor-parallel groups of 4.
    fn tp4_router_with_8_workers() -> MeshRouter {
        let topology = MeshTopology::TensorParallel {
            degree: 4,
            mesh_shape: (2, 4),
        };
        let mesh = DeviceMesh::with_topology(fake_cuda_devices(8), topology).unwrap();
        MeshRouter::new(Some(Arc::new(mesh)), 8)
    }

    #[test]
    fn test_round_robin_routing() {
        let router = MeshRouter::new(None, 4);
        let routed: Vec<usize> = std::iter::repeat_with(|| router.route(&None, None))
            .take(8)
            .collect();
        assert_eq!(routed, vec![0, 1, 2, 3, 0, 1, 2, 3]);
    }

    #[test]
    fn test_session_affinity_routing() {
        let mesh = Arc::new(DeviceMesh::new(fake_cuda_devices(4)));
        let router = MeshRouter::new(Some(mesh), 4);
        let session = Some(String::from("user-123"));
        let (first, second) = (router.route(&session, None), router.route(&session, None));
        assert_eq!(first, second);
    }

    #[test]
    fn test_tensor_parallel_routing() {
        let router = tp4_router_with_8_workers();
        assert_eq!(router.strategy(), RoutingStrategy::TensorParallel);
        assert!(router.route(&None, Some(0)) < 4);
        assert!((4..8).contains(&router.route(&None, Some(1))));
    }

    #[test]
    fn test_pipeline_parallel_routing() {
        let mesh = DeviceMesh::with_topology(
            fake_cuda_devices(4),
            MeshTopology::PipelineParallel { stages: 4 },
        )
        .unwrap();
        let router = MeshRouter::new(Some(Arc::new(mesh)), 4);
        assert_eq!(router.strategy(), RoutingStrategy::PipelineParallel);
        assert!((0..10).all(|_| router.route(&None, None) == 0));
    }

    #[test]
    fn test_get_tp_group_workers() {
        let router = tp4_router_with_8_workers();
        assert_eq!(router.get_tp_group_workers(0), vec![0, 1, 2, 3]);
        assert_eq!(router.get_tp_group_workers(1), vec![4, 5, 6, 7]);
    }

    #[test]
    fn test_mesh_stats() {
        let mesh = Arc::new(DeviceMesh::new(fake_cuda_devices(4)));
        let stats = MeshRouter::new(Some(mesh), 4).mesh_stats().unwrap();
        assert_eq!((stats.world_size, stats.num_workers), (4, 4));
        assert_eq!(stats.strategy, RoutingStrategy::RoundRobin);
    }
}