use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use parking_lot::RwLock;

use crate::checkpoint::{CheckpointStorage, CheckpointableKernel, MemoryStorage};
use crate::error::{Result, RingKernelError};
use crate::k2k::K2KMessage;
use crate::runtime::{Backend, KernelId, LaunchOptions};
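/// Configuration for multi-GPU kernel placement and coordination.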
#[derive(Debug, Clone)]
pub struct MultiGpuConfig {
pub load_balancing: LoadBalancingStrategy,
pub auto_select_device: bool,
pub max_kernels_per_device: usize,
pub enable_p2p: bool,
pub preferred_devices: Vec<usize>,
}
impl Default for MultiGpuConfig {
fn default() -> Self {
Self {
load_balancing: LoadBalancingStrategy::LeastLoaded,
auto_select_device: true,
max_kernels_per_device: 64,
enable_p2p: true,
preferred_devices: vec![],
}
}
}
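/// Strategy used by [`MultiGpuCoordinator::select_device`] to choose a
/// device for a new kernel launch.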
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
FirstAvailable,
LeastLoaded,
RoundRobin,
MemoryBased,
ComputeCapability,
Custom,
}
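/// Static and dynamic properties of a single GPU device.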
#[derive(Debug, Clone)]
pub struct DeviceInfo {
pub index: usize,
pub name: String,
pub backend: Backend,
pub total_memory: u64,
pub available_memory: u64,
pub compute_capability: Option<(u32, u32)>,
pub max_threads_per_block: u32,
pub multiprocessor_count: u32,
pub p2p_capable: bool,
}
impl DeviceInfo {
pub fn new(index: usize, name: String, backend: Backend) -> Self {
Self {
index,
name,
backend,
total_memory: 0,
available_memory: 0,
compute_capability: None,
max_threads_per_block: 1024,
multiprocessor_count: 1,
p2p_capable: false,
}
}
pub fn memory_utilization(&self) -> f64 {
if self.total_memory == 0 {
0.0
} else {
1.0 - (self.available_memory as f64 / self.total_memory as f64)
}
}
}
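/// The physical link type between a pair of devices. The bandwidth and
/// latency figures attached to each variant are rough, order-of-magnitude
/// defaults, not measured values.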
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterconnectType {
None,
Pcie,
NvLink,
NvSwitch,
InfinityFabric,
XeLink,
SameDevice,
}
impl InterconnectType {
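/// Rough peak bandwidth estimate. Note: despite the `gbps` suffix, values
/// are interpreted as gigabytes per second throughout this module (e.g. by
/// `MigrationRequest::estimate_transfer_time`).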
pub fn estimated_bandwidth_gbps(&self) -> f64 {
match self {
InterconnectType::None => 16.0,
InterconnectType::Pcie => 32.0,
InterconnectType::NvLink => 300.0,
InterconnectType::NvSwitch => 600.0,
InterconnectType::InfinityFabric => 200.0,
InterconnectType::XeLink => 100.0,
InterconnectType::SameDevice => 2000.0,
}
}
pub fn estimated_latency_us(&self) -> f64 {
match self {
InterconnectType::None => 10.0,
InterconnectType::Pcie => 5.0,
InterconnectType::NvLink => 1.0,
InterconnectType::NvSwitch => 2.0,
InterconnectType::InfinityFabric => 1.5,
InterconnectType::XeLink => 2.0,
InterconnectType::SameDevice => 0.0,
}
}
pub fn supports_p2p(&self) -> bool {
!matches!(self, InterconnectType::None)
}
}
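/// A directed link between two devices, annotated with estimated
/// bandwidth, latency, and hop count.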
#[derive(Debug, Clone)]
pub struct GpuConnection {
pub source: usize,
pub destination: usize,
pub interconnect: InterconnectType,
pub bandwidth_gbps: f64,
pub latency_us: f64,
pub bidirectional: bool,
pub hops: u32,
}
impl GpuConnection {
pub fn new(source: usize, destination: usize, interconnect: InterconnectType) -> Self {
Self {
source,
destination,
interconnect,
bandwidth_gbps: interconnect.estimated_bandwidth_gbps(),
latency_us: interconnect.estimated_latency_us(),
bidirectional: true,
hops: if source == destination { 0 } else { 1 },
}
}
pub fn with_bandwidth(mut self, gbps: f64) -> Self {
self.bandwidth_gbps = gbps;
self
}
pub fn with_latency(mut self, us: f64) -> Self {
self.latency_us = us;
self
}
pub fn with_hops(mut self, hops: u32) -> Self {
self.hops = hops;
self
}
}
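/// Adjacency-matrix view of the inter-GPU interconnect. Every device is
/// connected to itself with `InterconnectType::SameDevice`; all other
/// links must be set explicitly or discovered.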
#[derive(Debug, Clone)]
pub struct GpuTopology {
pub device_count: usize,
connections: Vec<Vec<Option<GpuConnection>>>,
pub numa_nodes: Vec<Option<u32>>,
pub probed: bool,
pub last_updated: Instant,
}
impl GpuTopology {
pub fn new(device_count: usize) -> Self {
let mut connections = vec![vec![None; device_count]; device_count];
for (i, row) in connections.iter_mut().enumerate().take(device_count) {
row[i] = Some(GpuConnection::new(i, i, InterconnectType::SameDevice));
}
Self {
device_count,
connections,
numa_nodes: vec![None; device_count],
probed: false,
last_updated: Instant::now(),
}
}
pub fn set_connection(&mut self, connection: GpuConnection) {
let src = connection.source;
let dst = connection.destination;
if src < self.device_count && dst < self.device_count {
self.connections[src][dst] = Some(connection.clone());
if connection.bidirectional && src != dst {
let reverse = GpuConnection {
source: dst,
destination: src,
..connection
};
self.connections[dst][src] = Some(reverse);
}
}
}
pub fn get_connection(&self, source: usize, destination: usize) -> Option<&GpuConnection> {
self.connections
.get(source)
.and_then(|row| row.get(destination))
.and_then(|c| c.as_ref())
}
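/// Returns the best path from `source` to `destination`. A direct P2P
/// link wins; otherwise a single intermediate device is considered,
/// picking the relay that maximizes the bottleneck bandwidth. Paths with
/// more than one intermediate hop are not searched.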
pub fn best_path(&self, source: usize, destination: usize) -> Vec<usize> {
if source == destination {
return vec![source];
}
if let Some(conn) = self.get_connection(source, destination) {
if conn.interconnect != InterconnectType::None {
return vec![source, destination];
}
}
let mut best_path = vec![source, destination];
let mut best_bandwidth = 0.0;
for intermediate in 0..self.device_count {
if intermediate == source || intermediate == destination {
continue;
}
if let (Some(c1), Some(c2)) = (
self.get_connection(source, intermediate),
self.get_connection(intermediate, destination),
) {
let path_bandwidth = c1.bandwidth_gbps.min(c2.bandwidth_gbps);
if path_bandwidth > best_bandwidth {
best_bandwidth = path_bandwidth;
best_path = vec![source, intermediate, destination];
}
}
}
best_path
}
pub fn neighbors(&self, device: usize) -> Vec<usize> {
if device >= self.device_count {
return vec![];
}
self.connections[device]
.iter()
.enumerate()
.filter_map(|(i, conn)| {
if i != device
&& conn
.as_ref()
.map(|c| c.interconnect.supports_p2p())
.unwrap_or(false)
{
Some(i)
} else {
None
}
})
.collect()
}
pub fn bisection_bandwidth_gbps(&self) -> f64 {
let half = self.device_count / 2;
if half == 0 {
return 0.0;
}
let mut total = 0.0;
for src in 0..half {
for dst in half..self.device_count {
if let Some(conn) = self.get_connection(src, dst) {
total += conn.bandwidth_gbps;
}
}
}
total
}
pub fn is_fully_connected(&self) -> bool {
for src in 0..self.device_count {
for dst in 0..self.device_count {
if src != dst {
if let Some(conn) = self.get_connection(src, dst) {
if !conn.interconnect.supports_p2p() {
return false;
}
} else {
return false;
}
}
}
}
true
}
pub fn numa_neighbors(&self, device: usize) -> Vec<usize> {
let target_numa = self.numa_nodes.get(device).copied().flatten();
if target_numa.is_none() {
return vec![];
}
self.numa_nodes
.iter()
.enumerate()
.filter_map(|(i, numa)| {
if i != device && *numa == target_numa {
Some(i)
} else {
None
}
})
.collect()
}
pub fn set_numa_node(&mut self, device: usize, numa_node: u32) {
if device < self.numa_nodes.len() {
self.numa_nodes[device] = Some(numa_node);
}
}
pub fn mark_probed(&mut self) {
self.probed = true;
self.last_updated = Instant::now();
}
}
#[derive(Debug, Clone)]
pub struct DeviceStatus {
pub info: DeviceInfo,
pub kernel_count: usize,
pub kernels: Vec<KernelId>,
pub available: bool,
pub load: f64,
}
#[derive(Debug, Clone)]
pub struct DeviceUnregisterResult {
pub device_index: usize,
pub kernels_to_migrate: Vec<KernelMigrationPlan>,
pub orphaned_kernels: Vec<KernelId>,
pub success: bool,
}
#[derive(Debug, Clone)]
pub struct KernelMigrationPlan {
pub kernel_id: KernelId,
pub source_device: usize,
pub target_device: usize,
pub priority: MigrationPriority,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationPriority {
Low,
Normal,
High,
Critical,
}
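/// Tracks registered devices, assigns kernels to them according to the
/// configured [`LoadBalancingStrategy`], and answers topology queries.
///
/// A minimal usage sketch (the device properties are illustrative):
///
/// ```ignore
/// let coord = MultiGpuBuilder::new()
///     .load_balancing(LoadBalancingStrategy::LeastLoaded)
///     .build();
/// coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
/// coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
/// let device = coord.select_device(&LaunchOptions::default())?;
/// coord.assign_kernel(KernelId::new("stencil"), device);
/// ```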
pub struct MultiGpuCoordinator {
config: MultiGpuConfig,
devices: RwLock<Vec<DeviceInfo>>,
kernel_device_map: RwLock<HashMap<KernelId, usize>>,
device_kernel_counts: RwLock<Vec<AtomicUsize>>,
round_robin_counter: AtomicUsize,
total_kernels: AtomicU64,
#[allow(clippy::type_complexity)]
custom_selector:
RwLock<Option<Arc<dyn Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync>>>,
topology: RwLock<Option<GpuTopology>>,
}
impl MultiGpuCoordinator {
pub fn new(config: MultiGpuConfig) -> Arc<Self> {
Arc::new(Self {
config,
devices: RwLock::new(Vec::new()),
kernel_device_map: RwLock::new(HashMap::new()),
device_kernel_counts: RwLock::new(Vec::new()),
round_robin_counter: AtomicUsize::new(0),
total_kernels: AtomicU64::new(0),
custom_selector: RwLock::new(None),
topology: RwLock::new(None),
})
}
pub fn register_device(&self, device: DeviceInfo) {
let index = device.index;
let mut devices = self.devices.write();
let mut counts = self.device_kernel_counts.write();
let mut current_len = devices.len();
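// Grow the device table with placeholder entries so `index` is always
// addressable; placeholders are overwritten when the real device registers.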
while current_len <= index {
devices.push(DeviceInfo::new(
current_len,
"Unknown".to_string(),
Backend::Cpu,
));
counts.push(AtomicUsize::new(0));
current_len += 1;
}
devices[index] = device;
}
pub fn unregister_device(&self, index: usize) -> DeviceUnregisterResult {
let devices = self.devices.read();
if index >= devices.len() {
return DeviceUnregisterResult {
device_index: index,
kernels_to_migrate: Vec::new(),
orphaned_kernels: Vec::new(),
success: false,
};
}
let kernels_on_device = self.kernels_on_device(index);
let available_targets: Vec<usize> = devices
.iter()
.enumerate()
.filter(|(i, _)| *i != index)
.map(|(i, _)| i)
.collect();
drop(devices);
let mut kernels_to_migrate = Vec::new();
let mut orphaned_kernels = Vec::new();
if available_targets.is_empty() {
orphaned_kernels = kernels_on_device;
} else {
for kernel_id in kernels_on_device {
if let Some(target) = self.select_migration_target(&available_targets) {
let priority = self.calculate_migration_priority(&kernel_id);
kernels_to_migrate.push(KernelMigrationPlan {
kernel_id,
source_device: index,
target_device: target,
priority,
});
} else {
orphaned_kernels.push(kernel_id);
}
}
}
{
let mut kernel_map = self.kernel_device_map.write();
let counts = self.device_kernel_counts.read();
for plan in &kernels_to_migrate {
kernel_map.insert(plan.kernel_id.clone(), plan.target_device);
if index < counts.len() {
counts[index].fetch_sub(1, Ordering::Relaxed);
}
if plan.target_device < counts.len() {
counts[plan.target_device].fetch_add(1, Ordering::Relaxed);
}
}
for kernel_id in &orphaned_kernels {
kernel_map.remove(kernel_id);
if index < counts.len() {
counts[index].fetch_sub(1, Ordering::Relaxed);
}
}
}
{
let mut devices = self.devices.write();
if index < devices.len() {
devices[index].available_memory = 0;
devices[index].name = format!("{} (unregistered)", devices[index].name);
}
}
DeviceUnregisterResult {
device_index: index,
kernels_to_migrate,
orphaned_kernels,
success: true,
}
}
fn select_migration_target(&self, candidates: &[usize]) -> Option<usize> {
if candidates.is_empty() {
return None;
}
let counts = self.device_kernel_counts.read();
candidates
.iter()
.filter_map(|&idx| {
if idx < counts.len() {
Some((idx, counts[idx].load(Ordering::Relaxed)))
} else {
None
}
})
.min_by_key(|(_, count)| *count)
.map(|(idx, _)| idx)
}
fn calculate_migration_priority(&self, _kernel_id: &KernelId) -> MigrationPriority {
MigrationPriority::Normal
}
pub fn devices(&self) -> Vec<DeviceInfo> {
self.devices.read().clone()
}
pub fn device(&self, index: usize) -> Option<DeviceInfo> {
self.devices.read().get(index).cloned()
}
pub fn device_count(&self) -> usize {
self.devices.read().len()
}
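/// Selects a device for a new kernel launch. Preferred devices (if any)
/// are filtered first, then the configured strategy is applied; most
/// strategies fall back to device 0 when no candidate qualifies.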
pub fn select_device(&self, options: &LaunchOptions) -> Result<usize> {
let devices = self.devices.read();
if devices.is_empty() {
return Err(RingKernelError::BackendUnavailable(
"No GPU devices available".to_string(),
));
}
let status = self.get_all_status();
if self.config.load_balancing == LoadBalancingStrategy::Custom {
if let Some(selector) = &*self.custom_selector.read() {
return Ok(selector(&status, options));
}
}
let candidates: Vec<_> = if !self.config.preferred_devices.is_empty() {
status
.into_iter()
.filter(|s| self.config.preferred_devices.contains(&s.info.index))
.collect()
} else {
status
};
if candidates.is_empty() {
return Err(RingKernelError::BackendUnavailable(
"No suitable GPU device available".to_string(),
));
}
let selected = match self.config.load_balancing {
LoadBalancingStrategy::FirstAvailable => candidates
.iter()
.find(|s| s.available)
.or_else(|| candidates.first())
.map(|s| s.info.index)
.unwrap_or(0),
LoadBalancingStrategy::LeastLoaded => candidates
.iter()
.filter(|s| s.available && s.kernel_count < self.config.max_kernels_per_device)
.min_by_key(|s| s.kernel_count)
.map(|s| s.info.index)
.unwrap_or(0),
LoadBalancingStrategy::RoundRobin => {
let available: Vec<_> = candidates.iter().filter(|s| s.available).collect();
if available.is_empty() {
candidates.first().map(|s| s.info.index).unwrap_or(0)
} else {
let idx =
self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % available.len();
available[idx].info.index
}
}
LoadBalancingStrategy::MemoryBased => candidates
.iter()
.filter(|s| s.available)
.max_by_key(|s| s.info.available_memory)
.map(|s| s.info.index)
.unwrap_or(0),
LoadBalancingStrategy::ComputeCapability => candidates
.iter()
.filter(|s| s.available)
.max_by_key(|s| s.info.compute_capability.unwrap_or((0, 0)))
.map(|s| s.info.index)
.unwrap_or(0),
// A custom strategy without a registered selector falls back to device 0.
LoadBalancingStrategy::Custom => 0,
};
Ok(selected)
}
pub fn assign_kernel(&self, kernel_id: KernelId, device_index: usize) {
self.kernel_device_map
.write()
.insert(kernel_id, device_index);
let counts = self.device_kernel_counts.read();
if device_index < counts.len() {
counts[device_index].fetch_add(1, Ordering::Relaxed);
}
self.total_kernels.fetch_add(1, Ordering::Relaxed);
}
pub fn remove_kernel(&self, kernel_id: &KernelId) {
if let Some(device_index) = self.kernel_device_map.write().remove(kernel_id) {
let counts = self.device_kernel_counts.read();
if device_index < counts.len() {
counts[device_index].fetch_sub(1, Ordering::Relaxed);
}
}
}
pub fn get_kernel_device(&self, kernel_id: &KernelId) -> Option<usize> {
self.kernel_device_map.read().get(kernel_id).copied()
}
pub fn kernels_on_device(&self, device_index: usize) -> Vec<KernelId> {
self.kernel_device_map
.read()
.iter()
.filter(|(_, &idx)| idx == device_index)
.map(|(k, _)| k.clone())
.collect()
}
pub fn get_all_status(&self) -> Vec<DeviceStatus> {
let devices = self.devices.read();
let kernel_map = self.kernel_device_map.read();
let counts = self.device_kernel_counts.read();
devices
.iter()
.enumerate()
.map(|(idx, info)| {
let kernel_count = if idx < counts.len() {
counts[idx].load(Ordering::Relaxed)
} else {
0
};
let kernels: Vec<_> = kernel_map
.iter()
.filter(|(_, &dev_idx)| dev_idx == idx)
.map(|(k, _)| k.clone())
.collect();
let load = kernel_count as f64 / self.config.max_kernels_per_device as f64;
let available = kernel_count < self.config.max_kernels_per_device;
DeviceStatus {
info: info.clone(),
kernel_count,
kernels,
available,
load,
}
})
.collect()
}
pub fn get_device_status(&self, device_index: usize) -> Option<DeviceStatus> {
self.get_all_status().into_iter().nth(device_index)
}
pub fn set_custom_selector<F>(&self, selector: F)
where
F: Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync + 'static,
{
*self.custom_selector.write() = Some(Arc::new(selector));
}
pub fn stats(&self) -> MultiGpuStats {
let status = self.get_all_status();
let total_kernels: usize = status.iter().map(|s| s.kernel_count).sum();
let total_memory: u64 = status.iter().map(|s| s.info.total_memory).sum();
let available_memory: u64 = status.iter().map(|s| s.info.available_memory).sum();
MultiGpuStats {
device_count: status.len(),
total_kernels,
total_memory,
available_memory,
kernels_launched: self.total_kernels.load(Ordering::Relaxed),
}
}
pub fn can_p2p(&self, device_a: usize, device_b: usize) -> bool {
if !self.config.enable_p2p {
return false;
}
let devices = self.devices.read();
if let (Some(a), Some(b)) = (devices.get(device_a), devices.get(device_b)) {
a.p2p_capable && b.p2p_capable
} else {
false
}
}
pub fn update_device_memory(&self, device_index: usize, available_memory: u64) {
let mut devices = self.devices.write();
if let Some(device) = devices.get_mut(device_index) {
device.available_memory = available_memory;
}
}
pub fn discover_topology(&self) -> GpuTopology {
let devices = self.devices.read();
let device_count = devices.len();
if device_count == 0 {
return GpuTopology::new(0);
}
let mut topo = GpuTopology::new(device_count);
for (i, dev_i) in devices.iter().enumerate() {
for (j, dev_j) in devices.iter().enumerate() {
if i == j {
continue;
}
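// Heuristic: CUDA devices with compute capability >= 8.0 on both ends
// are assumed to be NVLink-connected; otherwise fall back to PCIe.
// Real probing (e.g. via the driver API) would replace this guess.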
let interconnect = if dev_i.p2p_capable && dev_j.p2p_capable {
if dev_i.backend == dev_j.backend {
match dev_i.backend {
Backend::Cuda => {
let cc_i = dev_i.compute_capability.unwrap_or((0, 0));
let cc_j = dev_j.compute_capability.unwrap_or((0, 0));
if cc_i.0 >= 8 && cc_j.0 >= 8 {
InterconnectType::NvLink
} else {
InterconnectType::Pcie
}
}
_ => InterconnectType::Pcie,
}
} else {
InterconnectType::None
}
} else {
InterconnectType::None
};
topo.set_connection(GpuConnection::new(i, j, interconnect));
}
}
*self.topology.write() = Some(topo.clone());
topo
}
pub fn topology(&self) -> GpuTopology {
{
let topo = self.topology.read();
if let Some(ref t) = *topo {
return t.clone();
}
}
self.discover_topology()
}
pub fn set_topology(&self, topology: GpuTopology) {
*self.topology.write() = Some(topology);
}
pub fn select_device_for_k2k(&self, source_kernel: &KernelId) -> Result<usize> {
let source_idx = match self.get_kernel_device(source_kernel) {
Some(idx) => idx,
None => return self.select_device(&LaunchOptions::default()),
};
let topo = self.topology();
let status = self.get_all_status();
let neighbors = topo.neighbors(source_idx);
if neighbors.is_empty() {
return self.select_device(&LaunchOptions::default());
}
let best = neighbors
.iter()
.filter_map(|&dev_idx| {
status.iter().find(|s| s.info.index == dev_idx).map(|s| {
let conn = topo.get_connection(source_idx, dev_idx);
let bandwidth = conn.map(|c| c.bandwidth_gbps).unwrap_or(1.0);
let score = bandwidth / (s.load + 0.1);
(dev_idx, score)
})
})
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(idx, _)| idx);
best.ok_or_else(|| {
RingKernelError::BackendUnavailable("No suitable K2K device found".to_string())
})
}
pub fn request_migration(
&self,
kernel_id: &KernelId,
target_device: usize,
) -> Result<MigrationRequest> {
let source_device = self
.get_kernel_device(kernel_id)
.ok_or_else(|| RingKernelError::KernelNotFound(kernel_id.as_str().to_string()))?;
if source_device == target_device {
return Err(RingKernelError::InvalidConfig(
"Cannot migrate to same device".to_string(),
));
}
let devices = self.devices.read();
if target_device >= devices.len() {
return Err(RingKernelError::DeviceNotAvailable(format!(
"Device {} not available",
target_device
)));
}
let topo = self.topology();
let path = topo.best_path(source_device, target_device);
let connection = topo.get_connection(source_device, target_device);
Ok(MigrationRequest {
kernel_id: kernel_id.clone(),
source_device,
target_device,
path,
estimated_bandwidth_gbps: connection.map(|c| c.bandwidth_gbps).unwrap_or(16.0),
estimated_latency_us: connection.map(|c| c.latency_us).unwrap_or(10.0),
state: MigrationState::Pending,
started_at: None,
})
}
pub fn complete_migration(&self, request: &MigrationRequest) -> Result<()> {
{
let mut map = self.kernel_device_map.write();
if let Some(dev) = map.get_mut(&request.kernel_id) {
*dev = request.target_device;
}
}
{
let counts = self.device_kernel_counts.read();
if request.source_device < counts.len() {
counts[request.source_device].fetch_sub(1, Ordering::Relaxed);
}
if request.target_device < counts.len() {
counts[request.target_device].fetch_add(1, Ordering::Relaxed);
}
}
Ok(())
}
}
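/// Lifecycle of a kernel migration, from request to completion.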
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationState {
Pending,
Quiescing,
Checkpointing,
Transferring,
Restoring,
Completed,
Failed,
Cancelled,
}
#[derive(Debug, Clone)]
pub struct MigrationRequest {
pub kernel_id: KernelId,
pub source_device: usize,
pub target_device: usize,
pub path: Vec<usize>,
pub estimated_bandwidth_gbps: f64,
pub estimated_latency_us: f64,
pub state: MigrationState,
pub started_at: Option<Instant>,
}
impl MigrationRequest {
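/// Estimates the end-to-end transfer time for a state blob of the given
/// size: size divided by link bandwidth, plus the link's base latency.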
pub fn estimate_transfer_time(&self, state_size_bytes: usize) -> Duration {
let size_gb = state_size_bytes as f64 / 1_000_000_000.0;
let transfer_time_s = size_gb / self.estimated_bandwidth_gbps;
let total_us = (transfer_time_s * 1_000_000.0) + self.estimated_latency_us;
Duration::from_micros(total_us as u64)
}
}
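/// Routes K2K messages between kernels that may live on different GPUs,
/// queueing payloads per device pair until they are drained for delivery.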
pub struct CrossGpuK2KRouter {
coordinator: Arc<MultiGpuCoordinator>,
pending_queues: RwLock<HashMap<(usize, usize), Vec<PendingK2KMessage>>>,
stats: CrossGpuRouterStats,
}
#[derive(Debug, Clone)]
pub struct PendingK2KMessage {
pub source_kernel: KernelId,
pub dest_kernel: KernelId,
pub message: K2KMessage,
pub queued_at: Instant,
pub hops: u32,
}
#[derive(Debug, Default)]
pub struct CrossGpuRouterStats {
messages_routed: AtomicU64,
bytes_transferred: AtomicU64,
messages_pending: AtomicUsize,
total_latency_us: AtomicU64,
routing_failures: AtomicU64,
}
impl CrossGpuK2KRouter {
pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Arc<Self> {
Arc::new(Self {
coordinator,
pending_queues: RwLock::new(HashMap::new()),
stats: CrossGpuRouterStats::default(),
})
}
pub fn route_message(
&self,
source_kernel: &KernelId,
dest_kernel: &KernelId,
message: K2KMessage,
) -> Result<RoutingDecision> {
let source_device = self
.coordinator
.get_kernel_device(source_kernel)
.ok_or_else(|| {
RingKernelError::K2KDestinationNotFound(source_kernel.as_str().to_string())
})?;
let dest_device = self
.coordinator
.get_kernel_device(dest_kernel)
.ok_or_else(|| {
RingKernelError::K2KDestinationNotFound(dest_kernel.as_str().to_string())
})?;
if source_device == dest_device {
return Ok(RoutingDecision::SameDevice);
}
let topo = self.coordinator.topology();
let path = topo.best_path(source_device, dest_device);
if let Some(conn) = topo.get_connection(source_device, dest_device) {
if conn.interconnect.supports_p2p() {
let pending = PendingK2KMessage {
source_kernel: source_kernel.clone(),
dest_kernel: dest_kernel.clone(),
message,
queued_at: Instant::now(),
hops: 1,
};
self.enqueue_pending(source_device, dest_device, pending);
self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);
return Ok(RoutingDecision::DirectP2P {
source_device,
dest_device,
bandwidth_gbps: conn.bandwidth_gbps,
});
}
}
if path.len() > 2 {
let pending = PendingK2KMessage {
source_kernel: source_kernel.clone(),
dest_kernel: dest_kernel.clone(),
message,
queued_at: Instant::now(),
hops: (path.len() - 1) as u32,
};
self.enqueue_pending(source_device, path[1], pending);
self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);
return Ok(RoutingDecision::MultiHop {
path: path.clone(),
total_hops: (path.len() - 1) as u32,
});
}
let pending = PendingK2KMessage {
source_kernel: source_kernel.clone(),
dest_kernel: dest_kernel.clone(),
message,
queued_at: Instant::now(),
hops: 2, // host-mediated: device -> host -> device
};
self.enqueue_pending(source_device, dest_device, pending);
self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);
Ok(RoutingDecision::HostMediated {
source_device,
dest_device,
})
}
pub fn drain_pending(&self, source: usize, dest: usize) -> Vec<PendingK2KMessage> {
let mut queues = self.pending_queues.write();
let messages = queues.remove(&(source, dest)).unwrap_or_default();
self.stats
.messages_pending
.fetch_sub(messages.len(), Ordering::Relaxed);
messages
}
pub fn record_delivery(&self, message: &PendingK2KMessage, payload_size: usize) {
self.stats.messages_routed.fetch_add(1, Ordering::Relaxed);
self.stats
.bytes_transferred
.fetch_add(payload_size as u64, Ordering::Relaxed);
let latency = message.queued_at.elapsed().as_micros() as u64;
self.stats
.total_latency_us
.fetch_add(latency, Ordering::Relaxed);
}
pub fn record_failure(&self) {
self.stats.routing_failures.fetch_add(1, Ordering::Relaxed);
}
pub fn stats(&self) -> CrossGpuRouterStatsSnapshot {
let messages_routed = self.stats.messages_routed.load(Ordering::Relaxed);
let total_latency = self.stats.total_latency_us.load(Ordering::Relaxed);
CrossGpuRouterStatsSnapshot {
messages_routed,
bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
messages_pending: self.stats.messages_pending.load(Ordering::Relaxed),
avg_latency_us: if messages_routed > 0 {
total_latency as f64 / messages_routed as f64
} else {
0.0
},
routing_failures: self.stats.routing_failures.load(Ordering::Relaxed),
}
}
fn enqueue_pending(&self, source: usize, dest: usize, message: PendingK2KMessage) {
let mut queues = self.pending_queues.write();
queues.entry((source, dest)).or_default().push(message);
}
}
#[derive(Debug, Clone)]
pub struct CrossGpuRouterStatsSnapshot {
pub messages_routed: u64,
pub bytes_transferred: u64,
pub messages_pending: usize,
pub avg_latency_us: f64,
pub routing_failures: u64,
}
#[derive(Debug, Clone)]
pub enum RoutingDecision {
SameDevice,
DirectP2P {
source_device: usize,
dest_device: usize,
bandwidth_gbps: f64,
},
MultiHop {
path: Vec<usize>,
total_hops: u32,
},
HostMediated {
source_device: usize,
dest_device: usize,
},
}
#[derive(Debug, Clone, Default)]
pub struct MultiGpuStats {
pub device_count: usize,
pub total_kernels: usize,
pub total_memory: u64,
pub available_memory: u64,
pub kernels_launched: u64,
}
pub struct MultiGpuBuilder {
config: MultiGpuConfig,
}
impl MultiGpuBuilder {
pub fn new() -> Self {
Self {
config: MultiGpuConfig::default(),
}
}
pub fn load_balancing(mut self, strategy: LoadBalancingStrategy) -> Self {
self.config.load_balancing = strategy;
self
}
pub fn auto_select_device(mut self, enable: bool) -> Self {
self.config.auto_select_device = enable;
self
}
pub fn max_kernels_per_device(mut self, max: usize) -> Self {
self.config.max_kernels_per_device = max;
self
}
pub fn enable_p2p(mut self, enable: bool) -> Self {
self.config.enable_p2p = enable;
self
}
pub fn preferred_devices(mut self, devices: Vec<usize>) -> Self {
self.config.preferred_devices = devices;
self
}
pub fn build(self) -> Arc<MultiGpuCoordinator> {
MultiGpuCoordinator::new(self.config)
}
}
impl Default for MultiGpuBuilder {
fn default() -> Self {
Self::new()
}
}
pub struct CrossDeviceTransfer {
pub source_device: usize,
pub dest_device: usize,
pub size: usize,
pub use_p2p: bool,
}
impl CrossDeviceTransfer {
pub fn new(source: usize, dest: usize, size: usize) -> Self {
Self {
source_device: source,
dest_device: dest,
size,
use_p2p: true,
}
}
pub fn without_p2p(mut self) -> Self {
self.use_p2p = false;
self
}
}
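/// Migrates kernels between devices via checkpoint/restore: state is
/// checkpointed on the source, staged through the configured
/// [`CheckpointStorage`], and restored on the target.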
pub struct KernelMigrator {
coordinator: Arc<MultiGpuCoordinator>,
storage: Arc<dyn CheckpointStorage>,
stats: MigrationStats,
}
#[derive(Debug, Default)]
pub struct MigrationStats {
pub successful_migrations: AtomicU64,
pub failed_migrations: AtomicU64,
pub bytes_transferred: AtomicU64,
pub checkpoint_time_us: AtomicU64,
pub restore_time_us: AtomicU64,
}
#[derive(Debug, Clone)]
pub struct MigrationResult {
pub kernel_id: KernelId,
pub source_device: usize,
pub target_device: usize,
pub checkpoint_size: usize,
pub checkpoint_duration: Duration,
pub transfer_duration: Duration,
pub restore_duration: Duration,
pub total_duration: Duration,
}
impl KernelMigrator {
pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Self {
Self {
coordinator,
storage: Arc::new(MemoryStorage::new()),
stats: MigrationStats::default(),
}
}
pub fn with_storage(
coordinator: Arc<MultiGpuCoordinator>,
storage: Arc<dyn CheckpointStorage>,
) -> Self {
Self {
coordinator,
storage,
stats: MigrationStats::default(),
}
}
pub fn migrate_with_checkpoint<K: CheckpointableKernel>(
&self,
kernel: &K,
request: &mut MigrationRequest,
) -> Result<MigrationResult> {
let start_time = Instant::now();
request.started_at = Some(start_time);
request.state = MigrationState::Quiescing;
request.state = MigrationState::Checkpointing;
let checkpoint_start = Instant::now();
let checkpoint = kernel.create_checkpoint().map_err(|e| {
self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
request.state = MigrationState::Failed;
RingKernelError::MigrationFailed(format!("Checkpoint creation failed: {}", e))
})?;
let checkpoint_duration = checkpoint_start.elapsed();
let checkpoint_size = checkpoint.total_size();
self.stats
.checkpoint_time_us
.fetch_add(checkpoint_duration.as_micros() as u64, Ordering::Relaxed);
request.state = MigrationState::Transferring;
let transfer_start = Instant::now();
let checkpoint_name = format!(
"migration_{}_{}_{}",
request.kernel_id.as_str(),
request.source_device,
request.target_device
);
self.storage
.save(&checkpoint, &checkpoint_name)
.map_err(|e| {
self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
request.state = MigrationState::Failed;
RingKernelError::MigrationFailed(format!("Checkpoint transfer failed: {}", e))
})?;
let transfer_duration = transfer_start.elapsed();
self.stats
.bytes_transferred
.fetch_add(checkpoint_size as u64, Ordering::Relaxed);
request.state = MigrationState::Restoring;
let restore_start = Instant::now();
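// The restore is modeled as a storage round-trip; wiring the restored
// checkpoint into a kernel instance on the target device is the caller's
// responsibility.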
let _restored = self.storage.load(&checkpoint_name).map_err(|e| {
self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
request.state = MigrationState::Failed;
RingKernelError::MigrationFailed(format!("Checkpoint restore failed: {}", e))
})?;
let restore_duration = restore_start.elapsed();
self.stats
.restore_time_us
.fetch_add(restore_duration.as_micros() as u64, Ordering::Relaxed);
request.state = MigrationState::Completed;
self.coordinator.complete_migration(request)?;
let _ = self.storage.delete(&checkpoint_name);
self.stats
.successful_migrations
.fetch_add(1, Ordering::Relaxed);
Ok(MigrationResult {
kernel_id: request.kernel_id.clone(),
source_device: request.source_device,
target_device: request.target_device,
checkpoint_size,
checkpoint_duration,
transfer_duration,
restore_duration,
total_duration: start_time.elapsed(),
})
}
pub fn coordinator(&self) -> &Arc<MultiGpuCoordinator> {
&self.coordinator
}
pub fn stats(&self) -> MigrationStatsSnapshot {
let successful = self.stats.successful_migrations.load(Ordering::Relaxed);
let failed = self.stats.failed_migrations.load(Ordering::Relaxed);
let total = successful + failed;
let checkpoint_us = self.stats.checkpoint_time_us.load(Ordering::Relaxed);
let restore_us = self.stats.restore_time_us.load(Ordering::Relaxed);
MigrationStatsSnapshot {
successful_migrations: successful,
failed_migrations: failed,
bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
avg_checkpoint_time: checkpoint_us
.checked_div(total)
.map(Duration::from_micros)
.unwrap_or(Duration::ZERO),
avg_restore_time: restore_us
.checked_div(total)
.map(Duration::from_micros)
.unwrap_or(Duration::ZERO),
}
}
}
#[derive(Debug, Clone)]
pub struct MigrationStatsSnapshot {
pub successful_migrations: u64,
pub failed_migrations: u64,
pub bytes_transferred: u64,
pub avg_checkpoint_time: Duration,
pub avg_restore_time: Duration,
}
pub trait MigratableKernel: CheckpointableKernel {
fn prepare_for_migration(&mut self) -> Result<()>;
fn cancel_migration(&mut self) -> Result<()>;
fn is_quiescent(&self) -> bool;
fn estimated_state_size(&self) -> usize;
}
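/// Configuration for hot-reloading kernel code without tearing down the
/// surrounding runtime state.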
#[derive(Debug, Clone)]
pub struct HotReloadConfig {
pub enabled: bool,
pub reload_timeout: Duration,
pub preserve_state: bool,
pub max_retries: u32,
pub retry_backoff: Duration,
pub validate_before_swap: bool,
pub keep_fallback: bool,
pub max_rule_history: usize,
}
impl Default for HotReloadConfig {
fn default() -> Self {
Self {
enabled: true,
reload_timeout: Duration::from_secs(30),
preserve_state: true,
max_retries: 3,
retry_backoff: Duration::from_millis(500),
validate_before_swap: true,
keep_fallback: true,
max_rule_history: 5,
}
}
}
impl HotReloadConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_enabled(mut self, enabled: bool) -> Self {
self.enabled = enabled;
self
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.reload_timeout = timeout;
self
}
pub fn with_preserve_state(mut self, preserve: bool) -> Self {
self.preserve_state = preserve;
self
}
pub fn with_max_retries(mut self, retries: u32) -> Self {
self.max_retries = retries;
self
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HotReloadState {
Idle,
Draining,
Checkpointing,
Compiling,
Validating,
Swapping,
Restoring,
Completed,
Failed,
RollingBack,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelCodeFormat {
Ptx,
Cubin,
SpirV,
Wgsl,
Msl,
MetalLib,
Source,
}
#[derive(Debug, Clone)]
pub struct KernelCodeSource {
pub version_id: u64,
pub format: KernelCodeFormat,
pub code: Vec<u8>,
pub entry_point: String,
pub metadata: HashMap<String, String>,
pub created_at: Instant,
pub hash: [u8; 32],
}
impl KernelCodeSource {
pub fn new(format: KernelCodeFormat, code: Vec<u8>, entry_point: impl Into<String>) -> Self {
let hash = Self::compute_hash(&code);
Self {
version_id: 0,
format,
code,
entry_point: entry_point.into(),
metadata: HashMap::new(),
created_at: Instant::now(),
hash,
}
}
pub fn from_ptx(ptx: &str, entry_point: impl Into<String>) -> Self {
Self::new(KernelCodeFormat::Ptx, ptx.as_bytes().to_vec(), entry_point)
}
pub fn from_wgsl(wgsl: &str, entry_point: impl Into<String>) -> Self {
Self::new(
KernelCodeFormat::Wgsl,
wgsl.as_bytes().to_vec(),
entry_point,
)
}
pub fn from_msl(msl: &str, entry_point: impl Into<String>) -> Self {
Self::new(KernelCodeFormat::Msl, msl.as_bytes().to_vec(), entry_point)
}
pub fn with_version(mut self, version: u64) -> Self {
self.version_id = version;
self
}
pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.metadata.insert(key.into(), value.into());
self
}
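/// Derives a 32-byte content fingerprint by chaining `DefaultHasher`
/// output. This is a fast identity check, not a cryptographic hash.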
fn compute_hash(data: &[u8]) -> [u8; 32] {
use std::hash::{Hash, Hasher};
let mut hasher = std::collections::hash_map::DefaultHasher::new();
data.hash(&mut hasher);
let h1 = hasher.finish();
h1.hash(&mut hasher);
let h2 = hasher.finish();
h1.hash(&mut hasher);
let h3 = hasher.finish();
h1.hash(&mut hasher);
let h4 = hasher.finish();
let mut hash = [0u8; 32];
hash[0..8].copy_from_slice(&h1.to_le_bytes());
hash[8..16].copy_from_slice(&h2.to_le_bytes());
hash[16..24].copy_from_slice(&h3.to_le_bytes());
hash[24..32].copy_from_slice(&h4.to_le_bytes());
hash
}
pub fn as_str(&self) -> Option<&str> {
match self.format {
KernelCodeFormat::Ptx
| KernelCodeFormat::Wgsl
| KernelCodeFormat::Msl
| KernelCodeFormat::Source => std::str::from_utf8(&self.code).ok(),
_ => None,
}
}
pub fn size(&self) -> usize {
self.code.len()
}
}
#[derive(Debug)]
pub struct HotReloadRequest {
pub kernel_id: KernelId,
pub new_code: KernelCodeSource,
pub state: HotReloadState,
pub created_at: Instant,
pub started_at: Option<Instant>,
pub retry_count: u32,
pub error: Option<String>,
checkpoint_data: Option<Vec<u8>>,
}
impl HotReloadRequest {
pub fn new(kernel_id: KernelId, new_code: KernelCodeSource) -> Self {
Self {
kernel_id,
new_code,
state: HotReloadState::Idle,
created_at: Instant::now(),
started_at: None,
retry_count: 0,
error: None,
checkpoint_data: None,
}
}
pub fn is_in_progress(&self) -> bool {
!matches!(
self.state,
HotReloadState::Idle | HotReloadState::Completed | HotReloadState::Failed
)
}
pub fn is_completed(&self) -> bool {
self.state == HotReloadState::Completed
}
pub fn is_failed(&self) -> bool {
self.state == HotReloadState::Failed
}
pub fn elapsed(&self) -> Duration {
self.created_at.elapsed()
}
pub fn reload_elapsed(&self) -> Option<Duration> {
self.started_at.map(|s| s.elapsed())
}
}
#[derive(Debug, Clone)]
pub struct HotReloadResult {
pub kernel_id: KernelId,
pub old_version: u64,
pub new_version: u64,
pub state_preserved: bool,
pub checkpoint_size: usize,
pub drain_duration: Duration,
pub checkpoint_duration: Duration,
pub compile_duration: Duration,
pub swap_duration: Duration,
pub restore_duration: Duration,
pub total_duration: Duration,
}
#[derive(Debug, Default)]
struct HotReloadStats {
successful_reloads: AtomicU64,
failed_reloads: AtomicU64,
rollbacks: AtomicU64,
total_drain_time_us: AtomicU64,
total_compile_time_us: AtomicU64,
total_swap_time_us: AtomicU64,
state_preserved_count: AtomicU64,
}
#[derive(Debug, Clone)]
pub struct HotReloadStatsSnapshot {
pub successful_reloads: u64,
pub failed_reloads: u64,
pub rollbacks: u64,
pub avg_drain_time: Duration,
pub avg_compile_time: Duration,
pub avg_swap_time: Duration,
pub state_preserved_count: u64,
}
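/// Orchestrates the hot-reload pipeline: drain, checkpoint, validate,
/// compile, swap, restore. Previous code versions are retained as
/// fallbacks for rollback when `keep_fallback` is enabled.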
pub struct HotReloadManager {
config: HotReloadConfig,
kernels: RwLock<HashMap<KernelId, KernelCodeSource>>,
fallbacks: RwLock<HashMap<KernelId, KernelCodeSource>>,
active_requests: RwLock<HashMap<KernelId, HotReloadRequest>>,
version_counter: AtomicU64,
stats: HotReloadStats,
rule_registry: Arc<crate::rules::RuleRegistry>,
}
impl HotReloadManager {
pub fn new(config: HotReloadConfig) -> Arc<Self> {
Self::with_rule_backend(config, Arc::new(crate::rules::NoopSwapBackend))
}
pub fn with_rule_backend(
config: HotReloadConfig,
rule_backend: Arc<dyn crate::rules::RuleSwapBackend>,
) -> Arc<Self> {
let rule_registry = Arc::new(crate::rules::RuleRegistry::new(
config.max_rule_history,
rule_backend,
));
Arc::new(Self {
config,
kernels: RwLock::new(HashMap::new()),
fallbacks: RwLock::new(HashMap::new()),
active_requests: RwLock::new(HashMap::new()),
version_counter: AtomicU64::new(1),
stats: HotReloadStats::default(),
rule_registry,
})
}
pub fn rule_registry(&self) -> &Arc<crate::rules::RuleRegistry> {
&self.rule_registry
}
pub fn with_defaults() -> Arc<Self> {
Self::new(HotReloadConfig::default())
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
pub fn register_kernel(&self, kernel_id: &KernelId, code: KernelCodeSource) {
let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
let code = code.with_version(version);
self.kernels.write().insert(kernel_id.clone(), code);
}
pub fn unregister_kernel(&self, kernel_id: &KernelId) {
self.kernels.write().remove(kernel_id);
self.fallbacks.write().remove(kernel_id);
self.active_requests.write().remove(kernel_id);
}
pub fn get_current_version(&self, kernel_id: &KernelId) -> Option<u64> {
self.kernels.read().get(kernel_id).map(|c| c.version_id)
}
pub fn get_current_code(&self, kernel_id: &KernelId) -> Option<KernelCodeSource> {
self.kernels.read().get(kernel_id).cloned()
}
pub fn request_reload(
&self,
kernel_id: &KernelId,
new_code: KernelCodeSource,
) -> Result<HotReloadRequest> {
if !self.config.enabled {
return Err(RingKernelError::ValidationError(
"Hot reload is disabled".to_string(),
));
}
if !self.kernels.read().contains_key(kernel_id) {
return Err(RingKernelError::KernelNotFound(
kernel_id.as_str().to_string(),
));
}
{
let active = self.active_requests.read();
if let Some(existing) = active.get(kernel_id) {
if existing.is_in_progress() {
return Err(RingKernelError::ValidationError(
"Hot reload already in progress for this kernel".to_string(),
));
}
}
}
let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
let new_code = new_code.with_version(version);
let request = HotReloadRequest::new(kernel_id.clone(), new_code);
// HotReloadRequest is not Clone (it carries checkpoint data), so an
// equivalent copy is constructed for internal progress tracking.
self.active_requests.write().insert(
kernel_id.clone(),
HotReloadRequest::new(kernel_id.clone(), request.new_code.clone()),
);
Ok(request)
}
pub fn execute_reload<K: CheckpointableKernel>(
&self,
request: &mut HotReloadRequest,
kernel: &K,
) -> Result<HotReloadResult> {
let start_time = Instant::now();
request.started_at = Some(start_time);
let old_version = self
.kernels
.read()
.get(&request.kernel_id)
.map(|c| c.version_id)
.unwrap_or(0);
request.state = HotReloadState::Draining;
let drain_start = Instant::now();
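// Placeholder drain: a real implementation would wait for in-flight
// work on the kernel to quiesce before checkpointing.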
std::thread::sleep(Duration::from_micros(10));
let drain_duration = drain_start.elapsed();
self.stats
.total_drain_time_us
.fetch_add(drain_duration.as_micros() as u64, Ordering::Relaxed);
request.state = HotReloadState::Checkpointing;
let checkpoint_start = Instant::now();
let checkpoint_size = if self.config.preserve_state {
let checkpoint = kernel.create_checkpoint().map_err(|e| {
self.stats.failed_reloads.fetch_add(1, Ordering::Relaxed);
request.state = HotReloadState::Failed;
e
})?;
let data = checkpoint.to_bytes();
request.checkpoint_data = Some(data.clone());
data.len()
} else {
0
};
let checkpoint_duration = checkpoint_start.elapsed();
request.state = HotReloadState::Validating;
if self.config.validate_before_swap {
self.validate_code(&request.new_code).map_err(|e| {
self.stats.failed_reloads.fetch_add(1, Ordering::Relaxed);
request.state = HotReloadState::Failed;
e
})?;
}
request.state = HotReloadState::Compiling;
let compile_start = Instant::now();
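// Placeholder compile: backend-specific compilation of the new code
// would happen here.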
std::thread::sleep(Duration::from_micros(10));
let compile_duration = compile_start.elapsed();
self.stats
.total_compile_time_us
.fetch_add(compile_duration.as_micros() as u64, Ordering::Relaxed);
request.state = HotReloadState::Swapping;
let swap_start = Instant::now();
if self.config.keep_fallback {
if let Some(old_code) = self.kernels.read().get(&request.kernel_id).cloned() {
self.fallbacks
.write()
.insert(request.kernel_id.clone(), old_code);
}
}
self.kernels
.write()
.insert(request.kernel_id.clone(), request.new_code.clone());
let swap_duration = swap_start.elapsed();
self.stats
.total_swap_time_us
.fetch_add(swap_duration.as_micros() as u64, Ordering::Relaxed);
request.state = HotReloadState::Restoring;
let restore_start = Instant::now();
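// No in-process restore work is modeled here; reapplying checkpointed
// state to the swapped kernel is left to the caller (see
// `HotReloadableKernel`), so this duration is effectively zero.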
let restore_duration = restore_start.elapsed();
request.state = HotReloadState::Completed;
self.stats
.successful_reloads
.fetch_add(1, Ordering::Relaxed);
if self.config.preserve_state && checkpoint_size > 0 {
self.stats
.state_preserved_count
.fetch_add(1, Ordering::Relaxed);
}
self.active_requests.write().remove(&request.kernel_id);
Ok(HotReloadResult {
kernel_id: request.kernel_id.clone(),
old_version,
new_version: request.new_code.version_id,
state_preserved: self.config.preserve_state && checkpoint_size > 0,
checkpoint_size,
drain_duration,
checkpoint_duration,
compile_duration,
swap_duration,
restore_duration,
total_duration: start_time.elapsed(),
})
}
pub fn rollback(&self, kernel_id: &KernelId) -> Result<()> {
let fallback =
self.fallbacks.write().remove(kernel_id).ok_or_else(|| {
RingKernelError::ValidationError("No fallback available".to_string())
})?;
self.kernels.write().insert(kernel_id.clone(), fallback);
self.stats.rollbacks.fetch_add(1, Ordering::Relaxed);
if let Some(request) = self.active_requests.write().get_mut(kernel_id) {
request.state = HotReloadState::RollingBack;
}
Ok(())
}
fn validate_code(&self, code: &KernelCodeSource) -> Result<()> {
if code.code.is_empty() {
return Err(RingKernelError::ValidationError(
"Kernel code is empty".to_string(),
));
}
if code.entry_point.is_empty() {
return Err(RingKernelError::ValidationError(
"Entry point is empty".to_string(),
));
}
match code.format {
KernelCodeFormat::Ptx => {
if let Some(text) = code.as_str() {
if !text.contains(".version") && !text.contains(".target") {
return Err(RingKernelError::ValidationError(
"PTX code missing version/target directive".to_string(),
));
}
}
}
KernelCodeFormat::Wgsl => {
if let Some(text) = code.as_str() {
if !text.contains("@compute") && !text.contains("fn ") {
return Err(RingKernelError::ValidationError(
"WGSL code missing compute shader or function".to_string(),
));
}
}
}
KernelCodeFormat::Msl => {
if let Some(text) = code.as_str() {
if !text.contains("kernel ") {
return Err(RingKernelError::ValidationError(
"MSL code missing kernel function".to_string(),
));
}
}
}
_ => {}
}
Ok(())
}
pub fn stats(&self) -> HotReloadStatsSnapshot {
let successful = self.stats.successful_reloads.load(Ordering::Relaxed);
let failed = self.stats.failed_reloads.load(Ordering::Relaxed);
// Average over successful reloads (clamped to avoid division by zero);
// timings from reloads that failed mid-pipeline still contribute to the
// accumulated totals.
let total = successful.max(1);
HotReloadStatsSnapshot {
successful_reloads: successful,
failed_reloads: failed,
rollbacks: self.stats.rollbacks.load(Ordering::Relaxed),
avg_drain_time: Duration::from_micros(
self.stats.total_drain_time_us.load(Ordering::Relaxed) / total,
),
avg_compile_time: Duration::from_micros(
self.stats.total_compile_time_us.load(Ordering::Relaxed) / total,
),
avg_swap_time: Duration::from_micros(
self.stats.total_swap_time_us.load(Ordering::Relaxed) / total,
),
state_preserved_count: self.stats.state_preserved_count.load(Ordering::Relaxed),
}
}
pub fn list_kernels(&self) -> Vec<KernelId> {
self.kernels.read().keys().cloned().collect()
}
pub fn is_registered(&self, kernel_id: &KernelId) -> bool {
self.kernels.read().contains_key(kernel_id)
}
pub fn is_reload_in_progress(&self, kernel_id: &KernelId) -> bool {
self.active_requests
.read()
.get(kernel_id)
.map(|r| r.is_in_progress())
.unwrap_or(false)
}
pub fn config(&self) -> &HotReloadConfig {
&self.config
}
}
pub trait HotReloadableKernel: CheckpointableKernel {
fn prepare_for_reload(&mut self) -> Result<()>;
fn apply_code(&mut self, code: &KernelCodeSource) -> Result<()>;
fn resume_after_reload(&mut self) -> Result<()>;
fn is_ready_for_reload(&self) -> bool;
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_device_info() {
let info = DeviceInfo::new(0, "Test GPU".to_string(), Backend::Cuda);
assert_eq!(info.index, 0);
assert_eq!(info.name, "Test GPU");
assert_eq!(info.memory_utilization(), 0.0);
}
#[test]
fn test_coordinator_registration() {
let coord = MultiGpuBuilder::new().build();
let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
coord.register_device(device);
assert_eq!(coord.device_count(), 1);
assert!(coord.device(0).is_some());
}
#[test]
fn test_kernel_assignment() {
let coord = MultiGpuBuilder::new().build();
let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
coord.register_device(device);
let kernel_id = KernelId::new("test_kernel");
coord.assign_kernel(kernel_id.clone(), 0);
assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));
assert_eq!(coord.kernels_on_device(0).len(), 1);
}
#[test]
fn test_load_balancing_least_loaded() {
let coord = MultiGpuBuilder::new()
.load_balancing(LoadBalancingStrategy::LeastLoaded)
.build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
coord.assign_kernel(KernelId::new("k1"), 0);
let selected = coord.select_device(&LaunchOptions::default()).unwrap();
assert_eq!(selected, 1);
}
#[test]
fn test_round_robin() {
let coord = MultiGpuBuilder::new()
.load_balancing(LoadBalancingStrategy::RoundRobin)
.build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
let d1 = coord.select_device(&LaunchOptions::default()).unwrap();
let d2 = coord.select_device(&LaunchOptions::default()).unwrap();
let d3 = coord.select_device(&LaunchOptions::default()).unwrap();
assert_ne!(d1, d2);
assert_eq!(d1, d3);
}
#[test]
fn test_interconnect_bandwidth() {
assert!(
InterconnectType::NvLink.estimated_bandwidth_gbps()
> InterconnectType::Pcie.estimated_bandwidth_gbps()
);
assert!(
InterconnectType::Pcie.estimated_bandwidth_gbps()
> InterconnectType::None.estimated_bandwidth_gbps()
);
assert!(
InterconnectType::SameDevice.estimated_bandwidth_gbps()
> InterconnectType::NvLink.estimated_bandwidth_gbps()
);
}
#[test]
fn test_interconnect_p2p_support() {
assert!(!InterconnectType::None.supports_p2p());
assert!(InterconnectType::Pcie.supports_p2p());
assert!(InterconnectType::NvLink.supports_p2p());
assert!(InterconnectType::NvSwitch.supports_p2p());
}
#[test]
fn test_gpu_topology_creation() {
let topo = GpuTopology::new(4);
assert_eq!(topo.device_count, 4);
for i in 0..4 {
let conn = topo.get_connection(i, i);
assert!(conn.is_some());
assert_eq!(conn.unwrap().interconnect, InterconnectType::SameDevice);
}
}
#[test]
fn test_gpu_topology_set_connection() {
let mut topo = GpuTopology::new(4);
topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
let conn_01 = topo.get_connection(0, 1);
assert!(conn_01.is_some());
assert_eq!(conn_01.unwrap().interconnect, InterconnectType::NvLink);
let conn_10 = topo.get_connection(1, 0);
assert!(conn_10.is_some());
assert_eq!(conn_10.unwrap().interconnect, InterconnectType::NvLink);
}
#[test]
fn test_gpu_topology_neighbors() {
let mut topo = GpuTopology::new(4);
topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
topo.set_connection(GpuConnection::new(3, 0, InterconnectType::NvLink));
let neighbors_0 = topo.neighbors(0);
assert_eq!(neighbors_0.len(), 2);
assert!(neighbors_0.contains(&1));
assert!(neighbors_0.contains(&3));
}
#[test]
fn test_gpu_topology_best_path() {
let mut topo = GpuTopology::new(4);
topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
topo.set_connection(GpuConnection::new(0, 3, InterconnectType::None));
let path_01 = topo.best_path(0, 1);
assert_eq!(path_01, vec![0, 1]);
let path_00 = topo.best_path(0, 0);
assert_eq!(path_00, vec![0]);
}
#[test]
fn test_gpu_topology_fully_connected() {
let mut topo = GpuTopology::new(3);
assert!(!topo.is_fully_connected());
topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
topo.set_connection(GpuConnection::new(0, 2, InterconnectType::NvLink));
topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
assert!(topo.is_fully_connected());
}
#[test]
fn test_gpu_topology_numa() {
let mut topo = GpuTopology::new(4);
topo.set_numa_node(0, 0);
topo.set_numa_node(1, 0);
topo.set_numa_node(2, 1);
topo.set_numa_node(3, 1);
let numa_neighbors_0 = topo.numa_neighbors(0);
assert_eq!(numa_neighbors_0, vec![1]);
let numa_neighbors_2 = topo.numa_neighbors(2);
assert_eq!(numa_neighbors_2, vec![3]);
}
#[test]
fn test_coordinator_topology_discovery() {
let coord = MultiGpuBuilder::new().enable_p2p(true).build();
let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
dev0.p2p_capable = true;
dev0.compute_capability = Some((8, 0));
let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
dev1.p2p_capable = true;
dev1.compute_capability = Some((8, 6));
coord.register_device(dev0);
coord.register_device(dev1);
let topo = coord.discover_topology();
assert_eq!(topo.device_count, 2);
let conn = topo.get_connection(0, 1);
assert!(conn.is_some());
assert_eq!(conn.unwrap().interconnect, InterconnectType::NvLink);
}
#[test]
fn test_migration_request() {
let coord = MultiGpuBuilder::new().build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
let kernel_id = KernelId::new("migrating_kernel");
coord.assign_kernel(kernel_id.clone(), 0);
let request = coord.request_migration(&kernel_id, 1).unwrap();
assert_eq!(request.source_device, 0);
assert_eq!(request.target_device, 1);
assert_eq!(request.state, MigrationState::Pending);
}
#[test]
fn test_migration_same_device_error() {
let coord = MultiGpuBuilder::new().build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
let kernel_id = KernelId::new("kernel");
coord.assign_kernel(kernel_id.clone(), 0);
let result = coord.request_migration(&kernel_id, 0);
assert!(result.is_err());
}
#[test]
fn test_migration_complete() {
let coord = MultiGpuBuilder::new().build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
let kernel_id = KernelId::new("migrating_kernel");
coord.assign_kernel(kernel_id.clone(), 0);
assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));
let request = coord.request_migration(&kernel_id, 1).unwrap();
coord.complete_migration(&request).unwrap();
assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));
}
#[test]
fn test_migration_transfer_time_estimate() {
let request = MigrationRequest {
kernel_id: KernelId::new("test"),
source_device: 0,
target_device: 1,
path: vec![0, 1],
estimated_bandwidth_gbps: 300.0,
estimated_latency_us: 1.0,
state: MigrationState::Pending,
started_at: None,
};
let time = request.estimate_transfer_time(1_000_000_000);
assert!(time.as_micros() > 3000);
assert!(time.as_micros() < 4000);
}
use crate::hlc::HlcTimestamp;
use crate::message::MessageEnvelope;
fn make_test_k2k_message(source: &KernelId, dest: &KernelId) -> K2KMessage {
let timestamp = HlcTimestamp::now(42);
let envelope = MessageEnvelope::empty(1, 2, timestamp);
K2KMessage::new(source.clone(), dest.clone(), envelope, timestamp)
}
#[test]
fn test_router_same_device() {
let coord = MultiGpuBuilder::new().build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
let k1 = KernelId::new("k1");
let k2 = KernelId::new("k2");
coord.assign_kernel(k1.clone(), 0);
coord.assign_kernel(k2.clone(), 0);
let router = CrossGpuK2KRouter::new(coord);
let msg = make_test_k2k_message(&k1, &k2);
let decision = router.route_message(&k1, &k2, msg).unwrap();
assert!(matches!(decision, RoutingDecision::SameDevice));
}
#[test]
fn test_router_cross_device() {
let coord = MultiGpuBuilder::new().enable_p2p(true).build();
let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
dev0.p2p_capable = true;
let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
dev1.p2p_capable = true;
coord.register_device(dev0);
coord.register_device(dev1);
let k1 = KernelId::new("k1");
let k2 = KernelId::new("k2");
coord.assign_kernel(k1.clone(), 0);
coord.assign_kernel(k2.clone(), 1);
let router = CrossGpuK2KRouter::new(coord);
let msg = make_test_k2k_message(&k1, &k2);
let decision = router.route_message(&k1, &k2, msg).unwrap();
match decision {
RoutingDecision::DirectP2P {
source_device,
dest_device,
..
} => {
assert_eq!(source_device, 0);
assert_eq!(dest_device, 1);
}
_ => panic!("Expected DirectP2P routing"),
}
}
#[test]
fn test_router_pending_messages() {
let coord = MultiGpuBuilder::new().enable_p2p(true).build();
let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
dev0.p2p_capable = true;
let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
dev1.p2p_capable = true;
coord.register_device(dev0);
coord.register_device(dev1);
let k1 = KernelId::new("k1");
let k2 = KernelId::new("k2");
coord.assign_kernel(k1.clone(), 0);
coord.assign_kernel(k2.clone(), 1);
let router = CrossGpuK2KRouter::new(coord);
for _ in 0..3 {
let msg = make_test_k2k_message(&k1, &k2);
router.route_message(&k1, &k2, msg).unwrap();
}
assert_eq!(router.stats().messages_pending, 3);
let pending = router.drain_pending(0, 1);
assert_eq!(pending.len(), 3);
assert_eq!(router.stats().messages_pending, 0);
}
#[test]
fn test_router_stats() {
let coord = MultiGpuBuilder::new().build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
let k1 = KernelId::new("k1");
let k2 = KernelId::new("k2");
coord.assign_kernel(k1.clone(), 0);
coord.assign_kernel(k2.clone(), 0);
let router = CrossGpuK2KRouter::new(coord);
let stats = router.stats();
assert_eq!(stats.messages_routed, 0);
assert_eq!(stats.bytes_transferred, 0);
assert_eq!(stats.routing_failures, 0);
}
use crate::checkpoint::{Checkpoint, CheckpointBuilder};
struct MockCheckpointableKernel {
kernel_id: String,
kernel_type: String,
state_data: Vec<u8>,
step: u64,
}
impl MockCheckpointableKernel {
fn new(kernel_id: &str, state_size: usize) -> Self {
Self {
kernel_id: kernel_id.to_string(),
kernel_type: "mock_kernel".to_string(),
state_data: vec![0xAB; state_size],
step: 1000,
}
}
}
impl CheckpointableKernel for MockCheckpointableKernel {
fn create_checkpoint(&self) -> Result<Checkpoint> {
let checkpoint = CheckpointBuilder::new(&self.kernel_id, &self.kernel_type)
.step(self.step)
.grid_size(64, 64, 64)
.control_block(vec![1, 2, 3, 4])
.device_memory("state", self.state_data.clone())
.build();
Ok(checkpoint)
}
fn restore_from_checkpoint(&mut self, checkpoint: &Checkpoint) -> Result<()> {
self.step = checkpoint.metadata.current_step;
Ok(())
}
fn checkpoint_kernel_id(&self) -> &str {
&self.kernel_id
}
fn checkpoint_kernel_type(&self) -> &str {
&self.kernel_type
}
}
#[test]
fn test_migrator_creation() {
let coord = MultiGpuBuilder::new().build();
let migrator = KernelMigrator::new(coord);
let stats = migrator.stats();
assert_eq!(stats.successful_migrations, 0);
assert_eq!(stats.failed_migrations, 0);
assert_eq!(stats.bytes_transferred, 0);
}
#[test]
fn test_migrator_with_custom_storage() {
let coord = MultiGpuBuilder::new().build();
let storage = Arc::new(MemoryStorage::new());
let migrator = KernelMigrator::with_storage(coord.clone(), storage);
assert!(Arc::ptr_eq(migrator.coordinator(), &coord));
}
#[test]
fn test_successful_migration() {
let coord = MultiGpuBuilder::new().build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
let kernel_id = KernelId::new("migratable_kernel");
coord.assign_kernel(kernel_id.clone(), 0);
let migrator = KernelMigrator::new(coord.clone());
let kernel = MockCheckpointableKernel::new("migratable_kernel", 1024);
let mut request = coord.request_migration(&kernel_id, 1).unwrap();
assert_eq!(request.state, MigrationState::Pending);
let result = migrator
.migrate_with_checkpoint(&kernel, &mut request)
.unwrap();
assert_eq!(result.kernel_id.as_str(), "migratable_kernel");
assert_eq!(result.source_device, 0);
assert_eq!(result.target_device, 1);
assert!(result.checkpoint_size > 0);
assert!(result.total_duration > Duration::ZERO);
assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));
let stats = migrator.stats();
assert_eq!(stats.successful_migrations, 1);
assert_eq!(stats.failed_migrations, 0);
assert!(stats.bytes_transferred > 0);
}
#[test]
fn test_migration_result_fields() {
let coord = MultiGpuBuilder::new().build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
let kernel_id = KernelId::new("test_kernel");
coord.assign_kernel(kernel_id.clone(), 0);
let migrator = KernelMigrator::new(coord.clone());
let kernel = MockCheckpointableKernel::new("test_kernel", 4096);
let mut request = coord.request_migration(&kernel_id, 1).unwrap();
let result = migrator
.migrate_with_checkpoint(&kernel, &mut request)
.unwrap();
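// Duration is unsigned, so the >= ZERO checks below only document that the
// fields exist; the final assertion is the meaningful ordering constraint.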
assert!(result.checkpoint_duration >= Duration::ZERO);
assert!(result.transfer_duration >= Duration::ZERO);
assert!(result.restore_duration >= Duration::ZERO);
assert!(result.total_duration >= result.checkpoint_duration);
}
#[test]
fn test_migration_stats_accumulate() {
let coord = MultiGpuBuilder::new().build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
let migrator = KernelMigrator::new(coord.clone());
let k1 = KernelId::new("k1");
coord.assign_kernel(k1.clone(), 0);
let kernel1 = MockCheckpointableKernel::new("k1", 1000);
let mut req1 = coord.request_migration(&k1, 1).unwrap();
migrator
.migrate_with_checkpoint(&kernel1, &mut req1)
.unwrap();
let k2 = KernelId::new("k2");
coord.assign_kernel(k2.clone(), 0);
let kernel2 = MockCheckpointableKernel::new("k2", 2000);
let mut req2 = coord.request_migration(&k2, 1).unwrap();
migrator
.migrate_with_checkpoint(&kernel2, &mut req2)
.unwrap();
let stats = migrator.stats();
assert_eq!(stats.successful_migrations, 2);
assert_eq!(stats.failed_migrations, 0);
assert!(stats.bytes_transferred > 0);
}
#[test]
fn test_unregister_device_no_kernels() {
let coord = MultiGpuBuilder::new().build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
let result = coord.unregister_device(0);
assert!(result.success);
assert_eq!(result.device_index, 0);
assert!(result.kernels_to_migrate.is_empty());
assert!(result.orphaned_kernels.is_empty());
}
#[test]
fn test_unregister_device_with_kernels() {
let coord = MultiGpuBuilder::new().build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
let k1 = KernelId::new("k1");
let k2 = KernelId::new("k2");
coord.assign_kernel(k1.clone(), 0);
coord.assign_kernel(k2.clone(), 0);
let result = coord.unregister_device(0);
assert!(result.success);
assert_eq!(result.kernels_to_migrate.len(), 2);
assert!(result.orphaned_kernels.is_empty());
for plan in &result.kernels_to_migrate {
assert_eq!(plan.source_device, 0);
assert_eq!(plan.target_device, 1);
}
assert_eq!(coord.get_kernel_device(&k1), Some(1));
assert_eq!(coord.get_kernel_device(&k2), Some(1));
}
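// With no other device to migrate to, unregistering must orphan the kernels
// rather than fail outright.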
#[test]
fn test_unregister_single_device_orphans_kernels() {
let coord = MultiGpuBuilder::new().build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
let k1 = KernelId::new("k1");
coord.assign_kernel(k1.clone(), 0);
let result = coord.unregister_device(0);
assert!(result.success);
assert!(result.kernels_to_migrate.is_empty());
assert_eq!(result.orphaned_kernels.len(), 1);
assert_eq!(result.orphaned_kernels[0], k1);
assert!(coord.get_kernel_device(&k1).is_none());
}
#[test]
fn test_unregister_nonexistent_device() {
let coord = MultiGpuBuilder::new().build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
let result = coord.unregister_device(99);
assert!(!result.success);
assert_eq!(result.device_index, 99);
}
#[test]
fn test_unregister_distributes_to_least_loaded() {
let coord = MultiGpuBuilder::new().build();
coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
coord.register_device(DeviceInfo::new(2, "GPU 2".to_string(), Backend::Cuda));
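// Preload device 1 so that device 2 (empty) is the least-loaded migration target.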
coord.assign_kernel(KernelId::new("pre1"), 1);
coord.assign_kernel(KernelId::new("pre2"), 1);
coord.assign_kernel(KernelId::new("pre3"), 1);
let k1 = KernelId::new("migrate_me");
coord.assign_kernel(k1.clone(), 0);
let result = coord.unregister_device(0);
assert!(result.success);
assert_eq!(result.kernels_to_migrate.len(), 1);
let plan = &result.kernels_to_migrate[0];
assert_eq!(plan.target_device, 2);
}
#[test]
fn test_migration_priority_enum() {
let low = MigrationPriority::Low;
let normal = MigrationPriority::Normal;
let high = MigrationPriority::High;
let critical = MigrationPriority::Critical;
assert_ne!(low, normal);
assert_ne!(normal, high);
assert_ne!(high, critical);
assert_eq!(low, MigrationPriority::Low);
}
#[test]
fn test_hot_reload_config_default() {
let config = HotReloadConfig::default();
assert!(config.enabled);
assert!(config.preserve_state);
assert!(config.validate_before_swap);
assert!(config.keep_fallback);
assert_eq!(config.max_retries, 3);
}
#[test]
fn test_hot_reload_config_builder() {
let config = HotReloadConfig::new()
.with_enabled(false)
.with_preserve_state(false)
.with_max_retries(5)
.with_timeout(Duration::from_secs(60));
assert!(!config.enabled);
assert!(!config.preserve_state);
assert_eq!(config.max_retries, 5);
assert_eq!(config.reload_timeout, Duration::from_secs(60));
}
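// KernelCodeSource constructors should tag the format, record the entry point,
// and expose the raw source text.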
#[test]
fn test_kernel_code_source_ptx() {
let ptx = ".version 7.0\n.target sm_80\nkernel: ret;";
let code = KernelCodeSource::from_ptx(ptx, "kernel");
assert_eq!(code.format, KernelCodeFormat::Ptx);
assert_eq!(code.entry_point, "kernel");
assert_eq!(code.as_str(), Some(ptx));
assert_eq!(code.size(), ptx.len());
}
#[test]
fn test_kernel_code_source_wgsl() {
let wgsl = "@compute fn main() {}";
let code = KernelCodeSource::from_wgsl(wgsl, "main");
assert_eq!(code.format, KernelCodeFormat::Wgsl);
assert_eq!(code.entry_point, "main");
assert_eq!(code.as_str(), Some(wgsl));
}
#[test]
fn test_kernel_code_source_msl() {
let msl = "kernel void my_kernel() {}";
let code = KernelCodeSource::from_msl(msl, "my_kernel");
assert_eq!(code.format, KernelCodeFormat::Msl);
assert_eq!(code.entry_point, "my_kernel");
assert_eq!(code.as_str(), Some(msl));
}
#[test]
fn test_hot_reload_manager_creation() {
let manager = HotReloadManager::with_defaults();
assert!(manager.is_enabled());
assert!(manager.list_kernels().is_empty());
}
#[test]
fn test_hot_reload_manager_register_kernel() {
let manager = HotReloadManager::with_defaults();
let kernel_id = KernelId::new("test_kernel");
let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");
manager.register_kernel(&kernel_id, code);
assert!(manager.is_registered(&kernel_id));
assert!(!manager.is_reload_in_progress(&kernel_id));
assert!(manager.get_current_version(&kernel_id).is_some());
}
#[test]
fn test_hot_reload_request_states() {
let kernel_id = KernelId::new("test");
let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");
let request = HotReloadRequest::new(kernel_id, code);
assert_eq!(request.state, HotReloadState::Idle);
assert!(!request.is_in_progress());
assert!(!request.is_completed());
assert!(!request.is_failed());
}
#[test]
fn test_hot_reload_disabled() {
let config = HotReloadConfig::new().with_enabled(false);
let manager = HotReloadManager::new(config);
let kernel_id = KernelId::new("test");
let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");
manager.register_kernel(&kernel_id, code.clone());
let result = manager.request_reload(&kernel_id, code);
assert!(result.is_err());
}
#[test]
fn test_hot_reload_stats() {
let manager = HotReloadManager::with_defaults();
let stats = manager.stats();
assert_eq!(stats.successful_reloads, 0);
assert_eq!(stats.failed_reloads, 0);
assert_eq!(stats.rollbacks, 0);
}
#[test]
fn test_hot_reload_code_formats() {
let formats = [
KernelCodeFormat::Ptx,
KernelCodeFormat::Cubin,
KernelCodeFormat::SpirV,
KernelCodeFormat::Wgsl,
KernelCodeFormat::Msl,
KernelCodeFormat::MetalLib,
KernelCodeFormat::Source,
];
for (i, f1) in formats.iter().enumerate() {
for (j, f2) in formats.iter().enumerate() {
if i != j {
assert_ne!(f1, f2);
}
}
}
}
#[test]
fn test_hot_reload_state_transitions() {
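// Checks only that the state variants are pairwise distinct; the actual
// transition sequence is driven end-to-end by test_hot_reload_execute below.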
let states = [
HotReloadState::Idle,
HotReloadState::Draining,
HotReloadState::Checkpointing,
HotReloadState::Compiling,
HotReloadState::Validating,
HotReloadState::Swapping,
HotReloadState::Restoring,
HotReloadState::Completed,
HotReloadState::Failed,
HotReloadState::RollingBack,
];
for (i, s1) in states.iter().enumerate() {
for (j, s2) in states.iter().enumerate() {
if i != j {
assert_ne!(s1, s2);
}
}
}
}
#[test]
fn test_hot_reload_execute() {
let manager = HotReloadManager::with_defaults();
let kernel_id = KernelId::new("test_kernel");
let initial_code = KernelCodeSource::from_ptx(".version 7.0\n.target sm_80", "kernel");
manager.register_kernel(&kernel_id, initial_code);
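// Swap sm_80 PTX for an sm_90 build; the mock kernel's checkpointed state
// should survive the swap.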
let new_code = KernelCodeSource::from_ptx(".version 8.0\n.target sm_90", "kernel");
let mut request = manager.request_reload(&kernel_id, new_code).unwrap();
let mock_kernel = MockCheckpointableKernel::new("test_kernel", 512);
let result = manager.execute_reload(&mut request, &mock_kernel).unwrap();
assert!(request.is_completed());
assert_eq!(result.kernel_id.as_str(), "test_kernel");
assert!(result.state_preserved);
assert!(result.checkpoint_size > 0);
assert!(result.total_duration > Duration::ZERO);
let stats = manager.stats();
assert_eq!(stats.successful_reloads, 1);
}
#[test]
fn test_hot_reload_list_kernels() {
let manager = HotReloadManager::with_defaults();
let k1 = KernelId::new("kernel1");
let k2 = KernelId::new("kernel2");
let k3 = KernelId::new("kernel3");
manager.register_kernel(&k1, KernelCodeSource::from_ptx(".version 7.0", "k1"));
manager.register_kernel(&k2, KernelCodeSource::from_ptx(".version 7.0", "k2"));
manager.register_kernel(&k3, KernelCodeSource::from_ptx(".version 7.0", "k3"));
let kernels = manager.list_kernels();
assert_eq!(kernels.len(), 3);
assert!(kernels.contains(&k1));
assert!(kernels.contains(&k2));
assert!(kernels.contains(&k3));
}
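// A small sanity sketch over the InterconnectType helpers defined in this
// module: every variant except None reports P2P support, and same-device
// copies are estimated faster (and lower-latency) than any inter-GPU link.
#[test]
fn test_interconnect_helpers_sketch() {
assert!(!InterconnectType::None.supports_p2p());
assert!(InterconnectType::Pcie.supports_p2p());
assert!(InterconnectType::NvLink.supports_p2p());
assert!(
InterconnectType::SameDevice.estimated_bandwidth_gbps()
> InterconnectType::NvSwitch.estimated_bandwidth_gbps()
);
assert!(
InterconnectType::SameDevice.estimated_latency_us()
< InterconnectType::Pcie.estimated_latency_us()
);
}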
}