// ringkernel_core/multi_gpu.rs

//! Multi-GPU coordination, topology discovery, and cross-GPU messaging.
//!
//! This module provides infrastructure for coordinating work across
//! multiple GPUs, including:
//!
//! - **Device Selection** - Load balancing strategies for kernel placement
//! - **Topology Discovery** - NVLink/PCIe detection and bandwidth estimation
//! - **Cross-GPU K2K Router** - Kernel-to-kernel messaging across GPUs
//! - **Kernel Migration** - Move kernels between GPUs with state transfer
//!
//! ## Example
//!
//! ```ignore
//! use ringkernel_core::multi_gpu::{
//!     CrossGpuK2KRouter, LoadBalancingStrategy, MultiGpuBuilder,
//! };
//!
//! let coordinator = MultiGpuBuilder::new()
//!     .load_balancing(LoadBalancingStrategy::LeastLoaded)
//!     .enable_p2p(true)
//!     .build();
//!
//! // Discover topology
//! let topology = coordinator.discover_topology();
//!
//! // Create a cross-GPU router and route a message between kernels
//! let router = CrossGpuK2KRouter::new(coordinator.clone());
//! let decision = router.route_message(&source_kernel, &dest_kernel, message)?;
//! ```

use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use crate::error::{Result, RingKernelError};
use crate::k2k::K2KMessage;
use crate::runtime::{Backend, KernelId, LaunchOptions};

/// Configuration for multi-GPU coordination.
#[derive(Debug, Clone)]
pub struct MultiGpuConfig {
    /// Load balancing strategy.
    pub load_balancing: LoadBalancingStrategy,
    /// Enable automatic device selection.
    pub auto_select_device: bool,
    /// Maximum kernels per device.
    pub max_kernels_per_device: usize,
    /// Enable peer-to-peer transfers when available.
    pub enable_p2p: bool,
    /// Preferred devices (by index).
    pub preferred_devices: Vec<usize>,
}

impl Default for MultiGpuConfig {
    fn default() -> Self {
        Self {
            load_balancing: LoadBalancingStrategy::LeastLoaded,
            auto_select_device: true,
            max_kernels_per_device: 64,
            enable_p2p: true,
            preferred_devices: vec![],
        }
    }
}

/// Strategy for balancing load across devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Always use the first available device.
    FirstAvailable,
    /// Use the device with fewest kernels.
    LeastLoaded,
    /// Round-robin across devices.
    RoundRobin,
    /// Select based on memory availability.
    MemoryBased,
    /// Select based on compute capability.
    ComputeCapability,
    /// Custom selection function.
    Custom,
}

/// Information about a GPU device.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Device index.
    pub index: usize,
    /// Device name.
    pub name: String,
    /// Backend type.
    pub backend: Backend,
    /// Total memory in bytes.
    pub total_memory: u64,
    /// Available memory in bytes.
    pub available_memory: u64,
    /// Compute capability (for CUDA).
    pub compute_capability: Option<(u32, u32)>,
    /// Maximum threads per block.
    pub max_threads_per_block: u32,
    /// Number of multiprocessors.
    pub multiprocessor_count: u32,
    /// Whether device supports P2P with other devices.
    pub p2p_capable: bool,
}

impl DeviceInfo {
    /// Create a new device info.
    pub fn new(index: usize, name: String, backend: Backend) -> Self {
        Self {
            index,
            name,
            backend,
            total_memory: 0,
            available_memory: 0,
            compute_capability: None,
            max_threads_per_block: 1024,
            multiprocessor_count: 1,
            p2p_capable: false,
        }
    }

    /// Get memory utilization (0.0-1.0).
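    ///
    /// A minimal sketch (the sizes are hypothetical): a device with 16 GB
    /// total and 4 GB free is 75% utilized.
    ///
    /// ```ignore
    /// let mut dev = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
    /// dev.total_memory = 16_000_000_000;
    /// dev.available_memory = 4_000_000_000;
    /// assert!((dev.memory_utilization() - 0.75).abs() < 1e-9);
    /// ```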
    pub fn memory_utilization(&self) -> f64 {
        if self.total_memory == 0 {
            0.0
        } else {
            1.0 - (self.available_memory as f64 / self.total_memory as f64)
        }
    }
}

// ============================================================================
// GPU Topology Discovery
// ============================================================================

/// Type of interconnect between GPUs.
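///
/// The per-link bandwidth and latency figures returned by the estimation
/// methods below are rough, hardware-generation-dependent defaults, not
/// measurements. A sketch of how they compare:
///
/// ```ignore
/// let nvlink = InterconnectType::NvLink;
/// let pcie = InterconnectType::Pcie;
/// assert!(nvlink.estimated_bandwidth_gbps() > pcie.estimated_bandwidth_gbps());
/// assert!(!InterconnectType::None.supports_p2p());
/// ```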
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterconnectType {
    /// No direct connection (must go through host).
    None,
    /// PCIe peer-to-peer.
    Pcie,
    /// NVIDIA NVLink.
    NvLink,
    /// NVIDIA NVSwitch (datacenter).
    NvSwitch,
    /// AMD Infinity Fabric.
    InfinityFabric,
    /// Intel Xe Link.
    XeLink,
    /// Same GPU (for self-connections).
    SameDevice,
}

impl InterconnectType {
    /// Estimated bandwidth in GB/s for this interconnect type.
    pub fn estimated_bandwidth_gbps(&self) -> f64 {
        match self {
            InterconnectType::None => 16.0,      // PCIe 3.0 x16 through host
            InterconnectType::Pcie => 32.0,      // PCIe 4.0 x16 P2P
            InterconnectType::NvLink => 300.0,   // NVLink 3.0 (A100)
            InterconnectType::NvSwitch => 600.0, // NVSwitch full bisection
            InterconnectType::InfinityFabric => 200.0, // MI250X
            InterconnectType::XeLink => 100.0,   // Intel Data Center GPUs
            InterconnectType::SameDevice => 2000.0, // Internal bandwidth
        }
    }

    /// Estimated latency in microseconds.
    pub fn estimated_latency_us(&self) -> f64 {
        match self {
            InterconnectType::None => 10.0,    // Through host memory
            InterconnectType::Pcie => 5.0,     // P2P PCIe
            InterconnectType::NvLink => 1.0,   // Direct NVLink
            InterconnectType::NvSwitch => 2.0, // Through switch
            InterconnectType::InfinityFabric => 1.5,
            InterconnectType::XeLink => 2.0,
            InterconnectType::SameDevice => 0.0,
        }
    }

    /// Whether this interconnect supports direct P2P memory access.
    pub fn supports_p2p(&self) -> bool {
        !matches!(self, InterconnectType::None)
    }
}

/// Connection between two GPUs.
#[derive(Debug, Clone)]
pub struct GpuConnection {
    /// Source device index.
    pub source: usize,
    /// Destination device index.
    pub destination: usize,
    /// Type of interconnect.
    pub interconnect: InterconnectType,
    /// Measured or estimated bandwidth in GB/s.
    pub bandwidth_gbps: f64,
    /// Measured or estimated latency in microseconds.
    pub latency_us: f64,
    /// Whether connection is bidirectional with same characteristics.
    pub bidirectional: bool,
    /// Number of hops (for multi-hop topologies).
    pub hops: u32,
}

impl GpuConnection {
    /// Create a new GPU connection.
    pub fn new(source: usize, destination: usize, interconnect: InterconnectType) -> Self {
        Self {
            source,
            destination,
            interconnect,
            bandwidth_gbps: interconnect.estimated_bandwidth_gbps(),
            latency_us: interconnect.estimated_latency_us(),
            bidirectional: true,
            hops: if source == destination { 0 } else { 1 },
        }
    }

    /// Set measured bandwidth.
    pub fn with_bandwidth(mut self, gbps: f64) -> Self {
        self.bandwidth_gbps = gbps;
        self
    }

    /// Set measured latency.
    pub fn with_latency(mut self, us: f64) -> Self {
        self.latency_us = us;
        self
    }

    /// Set hop count.
    pub fn with_hops(mut self, hops: u32) -> Self {
        self.hops = hops;
        self
    }
}

/// GPU topology graph describing all device interconnections.
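///
/// Construction sketch for a hypothetical two-GPU NVLink pair:
///
/// ```ignore
/// let mut topo = GpuTopology::new(2);
/// topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
/// assert!(topo.is_fully_connected());
/// assert_eq!(topo.neighbors(0), vec![1]);
/// ```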
#[derive(Debug, Clone)]
pub struct GpuTopology {
    /// Number of devices in topology.
    pub device_count: usize,
    /// Connection matrix (device_count x device_count).
    connections: Vec<Vec<Option<GpuConnection>>>,
    /// NUMA node assignments for each device.
    pub numa_nodes: Vec<Option<u32>>,
    /// Whether topology has been probed (vs estimated).
    pub probed: bool,
    /// Timestamp of last topology update.
    pub last_updated: Instant,
}

impl GpuTopology {
    /// Create a new topology for N devices.
    pub fn new(device_count: usize) -> Self {
        let mut connections = vec![vec![None; device_count]; device_count];

        // Initialize self-connections
        for (i, row) in connections.iter_mut().enumerate().take(device_count) {
            row[i] = Some(GpuConnection::new(i, i, InterconnectType::SameDevice));
        }

        Self {
            device_count,
            connections,
            numa_nodes: vec![None; device_count],
            probed: false,
            last_updated: Instant::now(),
        }
    }

    /// Set connection between two devices.
    pub fn set_connection(&mut self, connection: GpuConnection) {
        let src = connection.source;
        let dst = connection.destination;
        if src < self.device_count && dst < self.device_count {
            self.connections[src][dst] = Some(connection.clone());
            if connection.bidirectional && src != dst {
                let reverse = GpuConnection {
                    source: dst,
                    destination: src,
                    ..connection
                };
                self.connections[dst][src] = Some(reverse);
            }
        }
    }

    /// Get connection between two devices.
    pub fn get_connection(&self, source: usize, destination: usize) -> Option<&GpuConnection> {
        self.connections
            .get(source)
            .and_then(|row| row.get(destination))
            .and_then(|c| c.as_ref())
    }

    /// Get the best path between two devices, returned as the full device
    /// sequence including the source and destination endpoints.
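    ///
    /// Sketch (hypothetical three-GPU layout): with no usable direct link
    /// between devices 0 and 2 but NVLink hops via device 1, the path relays
    /// through 1.
    ///
    /// ```ignore
    /// let mut topo = GpuTopology::new(3);
    /// topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
    /// topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
    /// topo.set_connection(GpuConnection::new(0, 2, InterconnectType::None));
    /// assert_eq!(topo.best_path(0, 2), vec![0, 1, 2]);
    /// ```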
    pub fn best_path(&self, source: usize, destination: usize) -> Vec<usize> {
        if source == destination {
            return vec![source];
        }

        // Direct connection available?
        if let Some(conn) = self.get_connection(source, destination) {
            if conn.interconnect != InterconnectType::None {
                return vec![source, destination];
            }
        }

        // Greedy search over single intermediates (not a full Dijkstra):
        // pick the relay whose slowest link has the highest bandwidth.
        let mut best_path = vec![source, destination]; // Default to direct
        let mut best_bandwidth = 0.0;

        // Check all intermediate nodes
        for intermediate in 0..self.device_count {
            if intermediate == source || intermediate == destination {
                continue;
            }

            if let (Some(c1), Some(c2)) = (
                self.get_connection(source, intermediate),
                self.get_connection(intermediate, destination),
            ) {
                // Bandwidth limited by slowest link
                let path_bandwidth = c1.bandwidth_gbps.min(c2.bandwidth_gbps);
                if path_bandwidth > best_bandwidth {
                    best_bandwidth = path_bandwidth;
                    best_path = vec![source, intermediate, destination];
                }
            }
        }

        best_path
    }

    /// Get all devices directly connected to a device.
    pub fn neighbors(&self, device: usize) -> Vec<usize> {
        if device >= self.device_count {
            return vec![];
        }

        self.connections[device]
            .iter()
            .enumerate()
            .filter_map(|(i, conn)| {
                if i != device
                    && conn
                        .as_ref()
                        .map(|c| c.interconnect.supports_p2p())
                        .unwrap_or(false)
                {
                    Some(i)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Calculate total bisection bandwidth of the topology.
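    ///
    /// Bisection here sums the bandwidth of every link crossing the split
    /// between the first and second half of the device indices. For example,
    /// a hypothetical 4-GPU all-to-all NVLink mesh at 300 GB/s per link has
    /// four links crossing the {0, 1} / {2, 3} cut, i.e. 4 * 300 = 1200 GB/s.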
    pub fn bisection_bandwidth_gbps(&self) -> f64 {
        let half = self.device_count / 2;
        if half == 0 {
            return 0.0;
        }

        let mut total = 0.0;
        for src in 0..half {
            for dst in half..self.device_count {
                if let Some(conn) = self.get_connection(src, dst) {
                    total += conn.bandwidth_gbps;
                }
            }
        }
        total
    }

    /// Check if all devices have P2P connectivity.
    pub fn is_fully_connected(&self) -> bool {
        for src in 0..self.device_count {
            for dst in 0..self.device_count {
                if src != dst {
                    if let Some(conn) = self.get_connection(src, dst) {
                        if !conn.interconnect.supports_p2p() {
                            return false;
                        }
                    } else {
                        return false;
                    }
                }
            }
        }
        true
    }

    /// Get devices in the same NUMA domain.
    pub fn numa_neighbors(&self, device: usize) -> Vec<usize> {
        let target_numa = self.numa_nodes.get(device).copied().flatten();
        if target_numa.is_none() {
            return vec![];
        }

        self.numa_nodes
            .iter()
            .enumerate()
            .filter_map(|(i, numa)| {
                if i != device && *numa == target_numa {
                    Some(i)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Set NUMA node for a device.
    pub fn set_numa_node(&mut self, device: usize, numa_node: u32) {
        if device < self.numa_nodes.len() {
            self.numa_nodes[device] = Some(numa_node);
        }
    }

    /// Mark topology as probed (not estimated).
    pub fn mark_probed(&mut self) {
        self.probed = true;
        self.last_updated = Instant::now();
    }
}

/// Status of a device in the multi-GPU coordinator.
#[derive(Debug, Clone)]
pub struct DeviceStatus {
    /// Device info.
    pub info: DeviceInfo,
    /// Number of kernels running on this device.
    pub kernel_count: usize,
    /// Kernels running on this device.
    pub kernels: Vec<KernelId>,
    /// Whether device is available for new kernels.
    pub available: bool,
    /// Current load estimate (0.0-1.0).
    pub load: f64,
}

/// Result of unregistering a device from the coordinator.
#[derive(Debug, Clone)]
pub struct DeviceUnregisterResult {
    /// Index of the unregistered device.
    pub device_index: usize,
    /// Kernels that were on this device and need migration.
    pub kernels_to_migrate: Vec<KernelMigrationPlan>,
    /// Kernels that could not be migrated (no available target).
    pub orphaned_kernels: Vec<KernelId>,
    /// Whether the device was successfully unregistered.
    pub success: bool,
}

/// Plan for migrating a single kernel during device unregister.
#[derive(Debug, Clone)]
pub struct KernelMigrationPlan {
    /// Kernel to migrate.
    pub kernel_id: KernelId,
    /// Source device (the unregistered device).
    pub source_device: usize,
    /// Target device selected for migration.
    pub target_device: usize,
    /// Estimated migration priority (based on kernel load).
    pub priority: MigrationPriority,
}

/// Priority for kernel migration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationPriority {
    /// Low priority - can be migrated lazily.
    Low,
    /// Normal priority - migrate in reasonable time.
    Normal,
    /// High priority - migrate as soon as possible.
    High,
    /// Critical - must migrate immediately.
    Critical,
}

/// Multi-GPU coordinator for managing kernels across devices.
pub struct MultiGpuCoordinator {
    /// Configuration.
    config: MultiGpuConfig,
    /// Available devices.
    devices: RwLock<Vec<DeviceInfo>>,
    /// Kernel-to-device mapping.
    kernel_device_map: RwLock<HashMap<KernelId, usize>>,
    /// Device kernel counts.
    device_kernel_counts: RwLock<Vec<AtomicUsize>>,
    /// Round-robin counter.
    round_robin_counter: AtomicUsize,
    /// Total kernels launched.
    total_kernels: AtomicU64,
    /// Device selection callbacks (for custom strategy).
    #[allow(clippy::type_complexity)]
    custom_selector:
        RwLock<Option<Arc<dyn Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync>>>,
    /// GPU topology graph.
    topology: RwLock<Option<GpuTopology>>,
}

impl MultiGpuCoordinator {
    /// Create a new multi-GPU coordinator.
    pub fn new(config: MultiGpuConfig) -> Arc<Self> {
        Arc::new(Self {
            config,
            devices: RwLock::new(Vec::new()),
            kernel_device_map: RwLock::new(HashMap::new()),
            device_kernel_counts: RwLock::new(Vec::new()),
            round_robin_counter: AtomicUsize::new(0),
            total_kernels: AtomicU64::new(0),
            custom_selector: RwLock::new(None),
            topology: RwLock::new(None),
        })
    }

    /// Register a device with the coordinator.
    pub fn register_device(&self, device: DeviceInfo) {
        let index = device.index;
        let mut devices = self.devices.write();
        let mut counts = self.device_kernel_counts.write();

        // Ensure we have enough slots
        let mut current_len = devices.len();
        while current_len <= index {
            devices.push(DeviceInfo::new(
                current_len,
                "Unknown".to_string(),
                Backend::Cpu,
            ));
            counts.push(AtomicUsize::new(0));
            current_len += 1;
        }

        devices[index] = device;
    }

    /// Unregister a device and plan kernel migrations.
    ///
    /// This method:
    /// 1. Identifies all kernels on the device being removed
    /// 2. Finds target devices for each kernel using load balancing
    /// 3. Creates migration plans for kernels that can be moved
    /// 4. Marks orphaned kernels that have no migration target
    /// 5. Updates internal routing tables
    ///
    /// The caller is responsible for executing the actual migrations using
    /// [`KernelMigrator`] with the returned [`KernelMigrationPlan`] entries.
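    ///
    /// Sketch of inspecting the result (illustrative only):
    ///
    /// ```ignore
    /// let result = coordinator.unregister_device(0);
    /// for plan in &result.kernels_to_migrate {
    ///     println!(
    ///         "move {:?}: GPU {} -> GPU {} ({:?})",
    ///         plan.kernel_id, plan.source_device, plan.target_device, plan.priority
    ///     );
    /// }
    /// for orphan in &result.orphaned_kernels {
    ///     eprintln!("kernel {:?} has no migration target", orphan);
    /// }
    /// ```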
    pub fn unregister_device(&self, index: usize) -> DeviceUnregisterResult {
        let devices = self.devices.read();

        // Check if device exists
        if index >= devices.len() {
            return DeviceUnregisterResult {
                device_index: index,
                kernels_to_migrate: Vec::new(),
                orphaned_kernels: Vec::new(),
                success: false,
            };
        }

        // Get all kernels on this device
        let kernels_on_device = self.kernels_on_device(index);

        // Find available target devices (excluding the one being unregistered)
        let available_targets: Vec<usize> = devices
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != index)
            .map(|(i, _)| i)
            .collect();

        drop(devices); // Release read lock before acquiring write lock

        let mut kernels_to_migrate = Vec::new();
        let mut orphaned_kernels = Vec::new();

        if available_targets.is_empty() {
            // No other devices available - all kernels are orphaned
            orphaned_kernels = kernels_on_device;
        } else {
            // Plan migrations for each kernel
            for kernel_id in kernels_on_device {
                // Select target based on current load
                if let Some(target) = self.select_migration_target(&available_targets) {
                    let priority = self.calculate_migration_priority(&kernel_id);
                    kernels_to_migrate.push(KernelMigrationPlan {
                        kernel_id,
                        source_device: index,
                        target_device: target,
                        priority,
                    });
                } else {
                    orphaned_kernels.push(kernel_id);
                }
            }
        }

        // Update kernel-device mappings for planned migrations
        {
            let mut kernel_map = self.kernel_device_map.write();
            let counts = self.device_kernel_counts.read();

            for plan in &kernels_to_migrate {
                // Update mapping to target device
                kernel_map.insert(plan.kernel_id.clone(), plan.target_device);

                // Update kernel counts
                if index < counts.len() {
                    counts[index].fetch_sub(1, Ordering::Relaxed);
                }
                if plan.target_device < counts.len() {
                    counts[plan.target_device].fetch_add(1, Ordering::Relaxed);
                }
            }

            // Remove orphaned kernels from mapping
            for kernel_id in &orphaned_kernels {
                kernel_map.remove(kernel_id);
                if index < counts.len() {
                    counts[index].fetch_sub(1, Ordering::Relaxed);
                }
            }
        }

        // Mark device as unavailable (but don't remove it to preserve indices)
        {
            let mut devices = self.devices.write();
            if index < devices.len() {
                devices[index].available_memory = 0;
                devices[index].name = format!("{} (unregistered)", devices[index].name);
            }
        }

        DeviceUnregisterResult {
            device_index: index,
            kernels_to_migrate,
            orphaned_kernels,
            success: true,
        }
    }

    /// Select the best target device for migration.
    fn select_migration_target(&self, candidates: &[usize]) -> Option<usize> {
        if candidates.is_empty() {
            return None;
        }

        let counts = self.device_kernel_counts.read();

        // Find device with lowest kernel count
        candidates
            .iter()
            .filter_map(|&idx| {
                if idx < counts.len() {
                    Some((idx, counts[idx].load(Ordering::Relaxed)))
                } else {
                    None
                }
            })
            .min_by_key(|(_, count)| *count)
            .map(|(idx, _)| idx)
    }

    /// Calculate migration priority for a kernel.
    fn calculate_migration_priority(&self, _kernel_id: &KernelId) -> MigrationPriority {
        // In a real implementation, this would check:
        // - Message queue depth
        // - Time since last activity
        // - Kernel type/importance
        // For now, use normal priority
        MigrationPriority::Normal
    }

    /// Get all registered devices.
    pub fn devices(&self) -> Vec<DeviceInfo> {
        self.devices.read().clone()
    }

    /// Get device info by index.
    pub fn device(&self, index: usize) -> Option<DeviceInfo> {
        self.devices.read().get(index).cloned()
    }

    /// Get number of devices.
    pub fn device_count(&self) -> usize {
        self.devices.read().len()
    }

    /// Select a device for launching a kernel.
    pub fn select_device(&self, options: &LaunchOptions) -> Result<usize> {
        let devices = self.devices.read();
        if devices.is_empty() {
            return Err(RingKernelError::BackendUnavailable(
                "No GPU devices available".to_string(),
            ));
        }

        // Get current status
        let status = self.get_all_status();

        // Check for custom selector
        if self.config.load_balancing == LoadBalancingStrategy::Custom {
            if let Some(selector) = &*self.custom_selector.read() {
                return Ok(selector(&status, options));
            }
        }

        // Apply preferred devices filter if specified
        let candidates: Vec<_> = if !self.config.preferred_devices.is_empty() {
            status
                .into_iter()
                .filter(|s| self.config.preferred_devices.contains(&s.info.index))
                .collect()
        } else {
            status
        };

        if candidates.is_empty() {
            return Err(RingKernelError::BackendUnavailable(
                "No suitable GPU device available".to_string(),
            ));
        }

        let selected = match self.config.load_balancing {
            LoadBalancingStrategy::FirstAvailable => {
                candidates.first().map(|s| s.info.index).unwrap_or(0)
            }
            LoadBalancingStrategy::LeastLoaded => candidates
                .iter()
                .filter(|s| s.available && s.kernel_count < self.config.max_kernels_per_device)
                .min_by(|a, b| a.kernel_count.cmp(&b.kernel_count))
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::RoundRobin => {
                let available: Vec<_> = candidates.iter().filter(|s| s.available).collect();

                if available.is_empty() {
                    candidates.first().map(|s| s.info.index).unwrap_or(0)
                } else {
                    let idx =
                        self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % available.len();
                    available[idx].info.index
                }
            }
            LoadBalancingStrategy::MemoryBased => candidates
                .iter()
                .filter(|s| s.available)
                .max_by(|a, b| a.info.available_memory.cmp(&b.info.available_memory))
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::ComputeCapability => candidates
                .iter()
                .filter(|s| s.available)
                .max_by(|a, b| {
                    let a_cap = a.info.compute_capability.unwrap_or((0, 0));
                    let b_cap = b.info.compute_capability.unwrap_or((0, 0));
                    a_cap.cmp(&b_cap)
                })
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::Custom => {
                // Should have been handled above
                0
            }
        };

        Ok(selected)
    }

    /// Assign a kernel to a device.
    pub fn assign_kernel(&self, kernel_id: KernelId, device_index: usize) {
        self.kernel_device_map
            .write()
            .insert(kernel_id, device_index);

        let counts = self.device_kernel_counts.read();
        if device_index < counts.len() {
            counts[device_index].fetch_add(1, Ordering::Relaxed);
        }

        self.total_kernels.fetch_add(1, Ordering::Relaxed);
    }

    /// Remove a kernel assignment.
    pub fn remove_kernel(&self, kernel_id: &KernelId) {
        if let Some(device_index) = self.kernel_device_map.write().remove(kernel_id) {
            let counts = self.device_kernel_counts.read();
            if device_index < counts.len() {
                counts[device_index].fetch_sub(1, Ordering::Relaxed);
            }
        }
    }

    /// Get device for a kernel.
    pub fn get_kernel_device(&self, kernel_id: &KernelId) -> Option<usize> {
        self.kernel_device_map.read().get(kernel_id).copied()
    }

    /// Get all kernels on a device.
    pub fn kernels_on_device(&self, device_index: usize) -> Vec<KernelId> {
        self.kernel_device_map
            .read()
            .iter()
            .filter(|(_, &idx)| idx == device_index)
            .map(|(k, _)| k.clone())
            .collect()
    }

    /// Get status of all devices.
    pub fn get_all_status(&self) -> Vec<DeviceStatus> {
        let devices = self.devices.read();
        let kernel_map = self.kernel_device_map.read();
        let counts = self.device_kernel_counts.read();

        devices
            .iter()
            .enumerate()
            .map(|(idx, info)| {
                let kernel_count = if idx < counts.len() {
                    counts[idx].load(Ordering::Relaxed)
                } else {
                    0
                };

                let kernels: Vec<_> = kernel_map
                    .iter()
                    .filter(|(_, &dev_idx)| dev_idx == idx)
                    .map(|(k, _)| k.clone())
                    .collect();

                let load = kernel_count as f64 / self.config.max_kernels_per_device as f64;
                let available = kernel_count < self.config.max_kernels_per_device;

                DeviceStatus {
                    info: info.clone(),
                    kernel_count,
                    kernels,
                    available,
                    load,
                }
            })
            .collect()
    }

    /// Get status of a specific device.
    pub fn get_device_status(&self, device_index: usize) -> Option<DeviceStatus> {
        self.get_all_status().into_iter().nth(device_index)
    }

    /// Set custom device selector.
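    ///
    /// Sketch: a selector that always picks the device with the most free
    /// memory. It only takes effect with `LoadBalancingStrategy::Custom`.
    ///
    /// ```ignore
    /// coordinator.set_custom_selector(|status, _options| {
    ///     status
    ///         .iter()
    ///         .max_by_key(|s| s.info.available_memory)
    ///         .map(|s| s.info.index)
    ///         .unwrap_or(0)
    /// });
    /// ```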
    pub fn set_custom_selector<F>(&self, selector: F)
    where
        F: Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync + 'static,
    {
        *self.custom_selector.write() = Some(Arc::new(selector));
    }

    /// Get coordinator statistics.
    pub fn stats(&self) -> MultiGpuStats {
        let status = self.get_all_status();
        let total_kernels: usize = status.iter().map(|s| s.kernel_count).sum();
        let total_memory: u64 = status.iter().map(|s| s.info.total_memory).sum();
        let available_memory: u64 = status.iter().map(|s| s.info.available_memory).sum();

        MultiGpuStats {
            device_count: status.len(),
            total_kernels,
            total_memory,
            available_memory,
            kernels_launched: self.total_kernels.load(Ordering::Relaxed),
        }
    }

    /// Check if P2P is available between two devices.
    pub fn can_p2p(&self, device_a: usize, device_b: usize) -> bool {
        if !self.config.enable_p2p {
            return false;
        }

        let devices = self.devices.read();
        if let (Some(a), Some(b)) = (devices.get(device_a), devices.get(device_b)) {
            a.p2p_capable && b.p2p_capable
        } else {
            false
        }
    }

    /// Update device memory info.
    pub fn update_device_memory(&self, device_index: usize, available_memory: u64) {
        let mut devices = self.devices.write();
        if let Some(device) = devices.get_mut(device_index) {
            device.available_memory = available_memory;
        }
    }

    // ========================================================================
    // Topology Discovery
    // ========================================================================

    /// Discover GPU topology (estimates if probing not available).
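    ///
    /// Sketch (after registering devices; the discovered links are estimates
    /// until the topology is probed):
    ///
    /// ```ignore
    /// let topo = coordinator.discover_topology();
    /// for dev in 0..topo.device_count {
    ///     println!("GPU {} P2P peers: {:?}", dev, topo.neighbors(dev));
    /// }
    /// ```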
    pub fn discover_topology(&self) -> GpuTopology {
        let devices = self.devices.read();
        let device_count = devices.len();

        if device_count == 0 {
            return GpuTopology::new(0);
        }

        let mut topo = GpuTopology::new(device_count);

        // Set up connections based on device info
        for (i, dev_i) in devices.iter().enumerate() {
            for (j, dev_j) in devices.iter().enumerate() {
                if i == j {
                    continue;
                }

                // Determine interconnect type based on device capabilities
                let interconnect = if dev_i.p2p_capable && dev_j.p2p_capable {
                    // Check if same backend (can do P2P)
                    if dev_i.backend == dev_j.backend {
                        match dev_i.backend {
                            Backend::Cuda => {
                                // For CUDA, check compute capability for NVLink
                                let cc_i = dev_i.compute_capability.unwrap_or((0, 0));
                                let cc_j = dev_j.compute_capability.unwrap_or((0, 0));

                                // Ampere+ (SM 80+) likely has NVLink
                                if cc_i.0 >= 8 && cc_j.0 >= 8 {
                                    InterconnectType::NvLink
                                } else {
                                    InterconnectType::Pcie
                                }
                            }
                            _ => InterconnectType::Pcie,
                        }
                    } else {
                        InterconnectType::None
                    }
                } else {
                    InterconnectType::None
                };

                topo.set_connection(GpuConnection::new(i, j, interconnect));
            }
        }

        // Store topology
        *self.topology.write() = Some(topo.clone());

        topo
    }

    /// Get current topology (discovers if not cached).
    pub fn topology(&self) -> GpuTopology {
        {
            let topo = self.topology.read();
            if let Some(ref t) = *topo {
                return t.clone();
            }
        }
        self.discover_topology()
    }

    /// Set custom topology (for testing or manual configuration).
    pub fn set_topology(&self, topology: GpuTopology) {
        *self.topology.write() = Some(topology);
    }

    /// Get best device for communicating with a source kernel.
    pub fn select_device_for_k2k(&self, source_kernel: &KernelId) -> Result<usize> {
        let source_device = self.get_kernel_device(source_kernel);
        if source_device.is_none() {
            return self.select_device(&LaunchOptions::default());
        }

        let source_idx = source_device.unwrap();
        let topo = self.topology();
        let status = self.get_all_status();

        // Find best device based on connectivity and load
        let neighbors = topo.neighbors(source_idx);

        if neighbors.is_empty() {
            // No P2P neighbors, fall back to normal selection
            return self.select_device(&LaunchOptions::default());
        }

        // Score devices by: connectivity bandwidth / (load + 0.1); the 0.1
        // offset keeps idle devices (load 0.0) from dividing by zero.
        let best = neighbors
            .iter()
            .filter_map(|&dev_idx| {
                status.iter().find(|s| s.info.index == dev_idx).map(|s| {
                    let conn = topo.get_connection(source_idx, dev_idx);
                    let bandwidth = conn.map(|c| c.bandwidth_gbps).unwrap_or(1.0);
                    let score = bandwidth / (s.load + 0.1);
                    (dev_idx, score)
                })
            })
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx);

        best.ok_or_else(|| {
            RingKernelError::BackendUnavailable("No suitable K2K device found".to_string())
        })
    }

    // ========================================================================
    // Kernel Migration
    // ========================================================================

    /// Request to migrate a kernel to another device.
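    ///
    /// Sketch (assumes `kernel_id` is currently assigned to a device other
    /// than device 1):
    ///
    /// ```ignore
    /// let request = coordinator.request_migration(&kernel_id, 1)?;
    /// let eta = request.estimate_transfer_time(64 * 1024 * 1024); // 64 MiB of state
    /// println!("migrating via {:?}, ~{:?}", request.path, eta);
    /// coordinator.complete_migration(&request)?;
    /// ```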
    pub fn request_migration(
        &self,
        kernel_id: &KernelId,
        target_device: usize,
    ) -> Result<MigrationRequest> {
        let source_device = self
            .get_kernel_device(kernel_id)
            .ok_or_else(|| RingKernelError::KernelNotFound(kernel_id.as_str().to_string()))?;

        if source_device == target_device {
            return Err(RingKernelError::InvalidConfig(
                "Cannot migrate to same device".to_string(),
            ));
        }

        let devices = self.devices.read();
        if target_device >= devices.len() {
            return Err(RingKernelError::DeviceNotAvailable(format!(
                "Device {} not available",
                target_device
            )));
        }

        let topo = self.topology();
        let path = topo.best_path(source_device, target_device);
        let connection = topo.get_connection(source_device, target_device);

        Ok(MigrationRequest {
            kernel_id: kernel_id.clone(),
            source_device,
            target_device,
            path,
            estimated_bandwidth_gbps: connection.map(|c| c.bandwidth_gbps).unwrap_or(16.0),
            estimated_latency_us: connection.map(|c| c.latency_us).unwrap_or(10.0),
            state: MigrationState::Pending,
            started_at: None,
        })
    }

    /// Complete a migration (updates internal mappings).
    pub fn complete_migration(&self, request: &MigrationRequest) -> Result<()> {
        // Update kernel-device mapping
        {
            let mut map = self.kernel_device_map.write();
            if let Some(dev) = map.get_mut(&request.kernel_id) {
                *dev = request.target_device;
            }
        }

        // Update kernel counts
        {
            let counts = self.device_kernel_counts.read();
            if request.source_device < counts.len() {
                counts[request.source_device].fetch_sub(1, Ordering::Relaxed);
            }
            if request.target_device < counts.len() {
                counts[request.target_device].fetch_add(1, Ordering::Relaxed);
            }
        }

        Ok(())
    }
}

// ============================================================================
// Kernel Migration Types
// ============================================================================

/// State of a kernel migration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationState {
    /// Migration is pending, not yet started.
    Pending,
    /// Kernel is being quiesced (draining messages).
    Quiescing,
    /// Checkpoint is being created.
    Checkpointing,
    /// State is being transferred.
    Transferring,
    /// Kernel is being restored on target.
    Restoring,
    /// Migration completed successfully.
    Completed,
    /// Migration failed.
    Failed,
    /// Migration was cancelled.
    Cancelled,
}

/// Request to migrate a kernel between devices.
#[derive(Debug, Clone)]
pub struct MigrationRequest {
    /// Kernel to migrate.
    pub kernel_id: KernelId,
    /// Source device index.
    pub source_device: usize,
    /// Target device index.
    pub target_device: usize,
    /// Path of devices for multi-hop migration.
    pub path: Vec<usize>,
    /// Estimated bandwidth for transfer.
    pub estimated_bandwidth_gbps: f64,
    /// Estimated latency.
    pub estimated_latency_us: f64,
    /// Current state.
    pub state: MigrationState,
    /// When migration started.
    pub started_at: Option<Instant>,
}

impl MigrationRequest {
    /// Estimate transfer time for given state size.
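    ///
    /// For example, 1 GB over an NVLink-class link (300 GB/s, 1 us latency)
    /// comes out to roughly 1/300 s + 1 us, i.e. about 3.3 ms.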
    pub fn estimate_transfer_time(&self, state_size_bytes: usize) -> Duration {
        // time = size / bandwidth + latency
        let size_gb = state_size_bytes as f64 / 1_000_000_000.0;
        let transfer_time_s = size_gb / self.estimated_bandwidth_gbps;
        let total_us = (transfer_time_s * 1_000_000.0) + self.estimated_latency_us;
        Duration::from_micros(total_us as u64)
    }
}

// ============================================================================
// Cross-GPU K2K Router
// ============================================================================

/// Routes K2K messages across GPU boundaries.
pub struct CrossGpuK2KRouter {
    /// Multi-GPU coordinator.
    coordinator: Arc<MultiGpuCoordinator>,
    /// Message queues for pending cross-device messages.
    pending_queues: RwLock<HashMap<(usize, usize), Vec<PendingK2KMessage>>>,
    /// Statistics.
    stats: CrossGpuRouterStats,
}

/// A pending cross-GPU K2K message.
#[derive(Debug, Clone)]
pub struct PendingK2KMessage {
    /// Source kernel ID.
    pub source_kernel: KernelId,
    /// Destination kernel ID.
    pub dest_kernel: KernelId,
    /// Message payload.
    pub message: K2KMessage,
    /// Timestamp when queued.
    pub queued_at: Instant,
    /// Number of routing hops.
    pub hops: u32,
}

/// Statistics for cross-GPU K2K routing.
#[derive(Debug, Default)]
pub struct CrossGpuRouterStats {
    /// Total messages routed.
    messages_routed: AtomicU64,
    /// Total bytes transferred.
    bytes_transferred: AtomicU64,
    /// Messages currently pending.
    messages_pending: AtomicUsize,
    /// Total routing latency (microseconds).
    total_latency_us: AtomicU64,
    /// Failed routing attempts.
    routing_failures: AtomicU64,
}

impl CrossGpuK2KRouter {
    /// Create a new cross-GPU K2K router.
    pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Arc<Self> {
        Arc::new(Self {
            coordinator,
            pending_queues: RwLock::new(HashMap::new()),
            stats: CrossGpuRouterStats::default(),
        })
    }

    /// Route a message from source kernel to destination kernel.
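    ///
    /// Sketch of acting on the returned decision (`source`, `dest`, and
    /// `message` are assumed to exist):
    ///
    /// ```ignore
    /// match router.route_message(&source, &dest, message)? {
    ///     RoutingDecision::SameDevice => { /* deliver via the regular K2K path */ }
    ///     RoutingDecision::DirectP2P { bandwidth_gbps, .. } => {
    ///         println!("P2P transfer at ~{bandwidth_gbps} GB/s");
    ///     }
    ///     RoutingDecision::MultiHop { path, .. } => println!("relay via {path:?}"),
    ///     RoutingDecision::HostMediated { .. } => { /* staged through host memory */ }
    /// }
    /// ```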
    pub fn route_message(
        &self,
        source_kernel: &KernelId,
        dest_kernel: &KernelId,
        message: K2KMessage,
    ) -> Result<RoutingDecision> {
        let source_device = self
            .coordinator
            .get_kernel_device(source_kernel)
            .ok_or_else(|| {
                RingKernelError::K2KDestinationNotFound(source_kernel.as_str().to_string())
            })?;

        let dest_device = self
            .coordinator
            .get_kernel_device(dest_kernel)
            .ok_or_else(|| {
                RingKernelError::K2KDestinationNotFound(dest_kernel.as_str().to_string())
            })?;

        // Same device - use regular K2K
        if source_device == dest_device {
            return Ok(RoutingDecision::SameDevice);
        }

        // Get topology for routing
        let topo = self.coordinator.topology();
        let path = topo.best_path(source_device, dest_device);

        // Check if direct P2P is available
        if let Some(conn) = topo.get_connection(source_device, dest_device) {
            if conn.interconnect.supports_p2p() {
                // Queue for direct P2P transfer
                let pending = PendingK2KMessage {
                    source_kernel: source_kernel.clone(),
                    dest_kernel: dest_kernel.clone(),
                    message,
                    queued_at: Instant::now(),
                    hops: 1,
                };

                self.enqueue_pending(source_device, dest_device, pending);
                self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

                return Ok(RoutingDecision::DirectP2P {
                    source_device,
                    dest_device,
                    bandwidth_gbps: conn.bandwidth_gbps,
                });
            }
        }

        // Multi-hop routing required
        if path.len() > 2 {
            let pending = PendingK2KMessage {
                source_kernel: source_kernel.clone(),
                dest_kernel: dest_kernel.clone(),
                message,
                queued_at: Instant::now(),
                hops: (path.len() - 1) as u32,
            };

            // Queue for first hop
            self.enqueue_pending(source_device, path[1], pending);
            self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

            return Ok(RoutingDecision::MultiHop {
                path: path.clone(),
                total_hops: (path.len() - 1) as u32,
            });
        }

        // Fall back to host-mediated transfer
        let pending = PendingK2KMessage {
            source_kernel: source_kernel.clone(),
            dest_kernel: dest_kernel.clone(),
            message,
            queued_at: Instant::now(),
            hops: 2, // device->host->device
        };

        self.enqueue_pending(source_device, dest_device, pending);
        self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

        Ok(RoutingDecision::HostMediated {
            source_device,
            dest_device,
        })
    }

    /// Get pending messages for a device pair.
    pub fn drain_pending(&self, source: usize, dest: usize) -> Vec<PendingK2KMessage> {
        let mut queues = self.pending_queues.write();
        let messages = queues.remove(&(source, dest)).unwrap_or_default();
        self.stats
            .messages_pending
            .fetch_sub(messages.len(), Ordering::Relaxed);
        messages
    }

    /// Record successful message delivery.
    pub fn record_delivery(&self, message: &PendingK2KMessage, payload_size: usize) {
        self.stats.messages_routed.fetch_add(1, Ordering::Relaxed);
        self.stats
            .bytes_transferred
            .fetch_add(payload_size as u64, Ordering::Relaxed);

        let latency = message.queued_at.elapsed().as_micros() as u64;
        self.stats
            .total_latency_us
            .fetch_add(latency, Ordering::Relaxed);
    }

    /// Record routing failure.
    pub fn record_failure(&self) {
        self.stats.routing_failures.fetch_add(1, Ordering::Relaxed);
    }

    /// Get router statistics.
    pub fn stats(&self) -> CrossGpuRouterStatsSnapshot {
        let messages_routed = self.stats.messages_routed.load(Ordering::Relaxed);
        let total_latency = self.stats.total_latency_us.load(Ordering::Relaxed);

        CrossGpuRouterStatsSnapshot {
            messages_routed,
            bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
            messages_pending: self.stats.messages_pending.load(Ordering::Relaxed),
            avg_latency_us: if messages_routed > 0 {
                total_latency as f64 / messages_routed as f64
            } else {
                0.0
            },
            routing_failures: self.stats.routing_failures.load(Ordering::Relaxed),
        }
    }

    fn enqueue_pending(&self, source: usize, dest: usize, message: PendingK2KMessage) {
        let mut queues = self.pending_queues.write();
        queues.entry((source, dest)).or_default().push(message);
    }
}

/// Snapshot of router statistics.
#[derive(Debug, Clone)]
pub struct CrossGpuRouterStatsSnapshot {
    /// Total messages successfully routed.
    pub messages_routed: u64,
    /// Total bytes transferred.
    pub bytes_transferred: u64,
    /// Messages currently pending.
    pub messages_pending: usize,
    /// Average routing latency in microseconds.
    pub avg_latency_us: f64,
    /// Total routing failures.
    pub routing_failures: u64,
}

/// Decision for how to route a K2K message.
#[derive(Debug, Clone)]
pub enum RoutingDecision {
    /// Source and destination on same device.
    SameDevice,
    /// Direct peer-to-peer transfer.
    DirectP2P {
        /// Source device index.
        source_device: usize,
        /// Destination device index.
        dest_device: usize,
        /// Available bandwidth.
        bandwidth_gbps: f64,
    },
    /// Multi-hop routing through intermediate devices.
    MultiHop {
        /// Device path.
        path: Vec<usize>,
        /// Total number of hops.
        total_hops: u32,
    },
    /// Route through host memory (slowest).
    HostMediated {
        /// Source device index.
        source_device: usize,
        /// Destination device index.
        dest_device: usize,
    },
}

/// Multi-GPU coordinator statistics.
#[derive(Debug, Clone, Default)]
pub struct MultiGpuStats {
    /// Number of registered devices.
    pub device_count: usize,
    /// Total kernels across all devices.
    pub total_kernels: usize,
    /// Total memory across all devices.
    pub total_memory: u64,
    /// Available memory across all devices.
    pub available_memory: u64,
    /// Total kernels launched since start.
    pub kernels_launched: u64,
}

/// Builder for multi-GPU coordinator.
pub struct MultiGpuBuilder {
    config: MultiGpuConfig,
}

impl MultiGpuBuilder {
    /// Create a new builder.
    pub fn new() -> Self {
        Self {
            config: MultiGpuConfig::default(),
        }
    }

    /// Set load balancing strategy.
    pub fn load_balancing(mut self, strategy: LoadBalancingStrategy) -> Self {
        self.config.load_balancing = strategy;
        self
    }

    /// Set auto device selection.
    pub fn auto_select_device(mut self, enable: bool) -> Self {
        self.config.auto_select_device = enable;
        self
    }

    /// Set max kernels per device.
    pub fn max_kernels_per_device(mut self, max: usize) -> Self {
        self.config.max_kernels_per_device = max;
        self
    }

    /// Enable P2P transfers.
    pub fn enable_p2p(mut self, enable: bool) -> Self {
        self.config.enable_p2p = enable;
        self
    }

    /// Set preferred devices.
    pub fn preferred_devices(mut self, devices: Vec<usize>) -> Self {
        self.config.preferred_devices = devices;
        self
    }

    /// Build the coordinator.
    pub fn build(self) -> Arc<MultiGpuCoordinator> {
        MultiGpuCoordinator::new(self.config)
    }
}

impl Default for MultiGpuBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// Helper for cross-device data transfer.
pub struct CrossDeviceTransfer {
    /// Source device index.
    pub source_device: usize,
    /// Destination device index.
    pub dest_device: usize,
    /// Data size in bytes.
    pub size: usize,
    /// Use P2P if available.
    pub use_p2p: bool,
}

impl CrossDeviceTransfer {
    /// Create a new transfer specification.
    pub fn new(source: usize, dest: usize, size: usize) -> Self {
        Self {
            source_device: source,
            dest_device: dest,
            size,
            use_p2p: true,
        }
    }

    /// Disable P2P for this transfer.
    pub fn without_p2p(mut self) -> Self {
        self.use_p2p = false;
        self
    }
}

// ============================================================================
// Kernel Migrator with Checkpoint Integration
// ============================================================================

use crate::checkpoint::{CheckpointStorage, CheckpointableKernel, MemoryStorage};

/// Migrator that uses checkpoints for kernel state transfer between GPUs.
///
/// This integrates the checkpoint infrastructure with the multi-GPU migration
/// system to enable live migration of persistent kernels.
///
/// # Example
///
/// ```ignore
/// use ringkernel_core::multi_gpu::{KernelMigrator, MultiGpuBuilder};
///
/// let coordinator = MultiGpuBuilder::new().build();
/// let migrator = KernelMigrator::new(coordinator);
///
/// // Migrate kernel from GPU 0 to GPU 1
/// migrator.migrate_with_checkpoint(&kernel, &mut request)?;
/// ```
pub struct KernelMigrator {
    /// Multi-GPU coordinator.
    coordinator: Arc<MultiGpuCoordinator>,
    /// Checkpoint storage for migration state.
    storage: Arc<dyn CheckpointStorage>,
    /// Statistics.
    stats: MigrationStats,
}

/// Statistics for kernel migrations.
#[derive(Debug, Default)]
pub struct MigrationStats {
    /// Total successful migrations.
    pub successful_migrations: AtomicU64,
    /// Total failed migrations.
    pub failed_migrations: AtomicU64,
    /// Total bytes transferred during migrations.
    pub bytes_transferred: AtomicU64,
    /// Total checkpoint time (microseconds).
    pub checkpoint_time_us: AtomicU64,
    /// Total restore time (microseconds).
    pub restore_time_us: AtomicU64,
}

/// Result of a completed migration.
#[derive(Debug, Clone)]
pub struct MigrationResult {
    /// Kernel that was migrated.
    pub kernel_id: KernelId,
    /// Source device.
    pub source_device: usize,
    /// Target device.
    pub target_device: usize,
    /// Checkpoint size in bytes.
    pub checkpoint_size: usize,
    /// Time spent creating checkpoint.
    pub checkpoint_duration: Duration,
    /// Time spent transferring state.
    pub transfer_duration: Duration,
    /// Time spent restoring kernel.
    pub restore_duration: Duration,
    /// Total migration time.
    pub total_duration: Duration,
}

impl KernelMigrator {
    /// Create a new kernel migrator with default in-memory storage.
    pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Self {
        Self {
            coordinator,
            storage: Arc::new(MemoryStorage::new()),
            stats: MigrationStats::default(),
        }
    }

    /// Create a migrator with custom checkpoint storage.
    pub fn with_storage(
        coordinator: Arc<MultiGpuCoordinator>,
        storage: Arc<dyn CheckpointStorage>,
    ) -> Self {
        Self {
            coordinator,
            storage,
            stats: MigrationStats::default(),
        }
    }

    /// Perform a complete migration using checkpoint-based state transfer.
    ///
    /// Steps:
    /// 1. Quiesce the source kernel (drain pending messages)
    /// 2. Create checkpoint of kernel state
    /// 3. Transfer checkpoint to target device
    /// 4. Restore kernel on target device
    /// 5. Update coordinator routing tables
    pub fn migrate_with_checkpoint<K: CheckpointableKernel>(
        &self,
        kernel: &K,
        request: &mut MigrationRequest,
    ) -> Result<MigrationResult> {
        let start_time = Instant::now();
        request.started_at = Some(start_time);

        // Step 1: Quiesce
        request.state = MigrationState::Quiescing;
        // In a real implementation, this would drain message queues
        // For now, we assume the kernel is ready for checkpointing

        // Step 2: Create checkpoint
        request.state = MigrationState::Checkpointing;
        let checkpoint_start = Instant::now();
        let checkpoint = kernel.create_checkpoint().map_err(|e| {
            self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
            request.state = MigrationState::Failed;
            RingKernelError::MigrationFailed(format!("Checkpoint creation failed: {}", e))
        })?;
        let checkpoint_duration = checkpoint_start.elapsed();
        let checkpoint_size = checkpoint.total_size();

        self.stats
            .checkpoint_time_us
            .fetch_add(checkpoint_duration.as_micros() as u64, Ordering::Relaxed);

        // Step 3: Transfer
        request.state = MigrationState::Transferring;
        let transfer_start = Instant::now();

        // Store checkpoint (simulates transfer)
        let checkpoint_name = format!(
            "migration_{}_{}_{}",
            request.kernel_id.as_str(),
            request.source_device,
            request.target_device
        );
        self.storage
            .save(&checkpoint, &checkpoint_name)
            .map_err(|e| {
                self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
                request.state = MigrationState::Failed;
                RingKernelError::MigrationFailed(format!("Checkpoint transfer failed: {}", e))
            })?;

        let transfer_duration = transfer_start.elapsed();
        self.stats
            .bytes_transferred
            .fetch_add(checkpoint_size as u64, Ordering::Relaxed);

        // Step 4: Restore (would be done on target kernel)
        request.state = MigrationState::Restoring;
        let restore_start = Instant::now();

        // Load checkpoint to verify it's valid
        let _restored = self.storage.load(&checkpoint_name).map_err(|e| {
            self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
            request.state = MigrationState::Failed;
            RingKernelError::MigrationFailed(format!("Checkpoint restore failed: {}", e))
        })?;

        let restore_duration = restore_start.elapsed();
        self.stats
            .restore_time_us
            .fetch_add(restore_duration.as_micros() as u64, Ordering::Relaxed);
1647
        // Step 5: Update routing; mark completed only after the tables update
        self.coordinator.complete_migration(request)?;
        request.state = MigrationState::Completed;
1651
1652        // Clean up checkpoint
1653        let _ = self.storage.delete(&checkpoint_name);
1654
1655        self.stats
1656            .successful_migrations
1657            .fetch_add(1, Ordering::Relaxed);
1658
1659        Ok(MigrationResult {
1660            kernel_id: request.kernel_id.clone(),
1661            source_device: request.source_device,
1662            target_device: request.target_device,
1663            checkpoint_size,
1664            checkpoint_duration,
1665            transfer_duration,
1666            restore_duration,
1667            total_duration: start_time.elapsed(),
1668        })
1669    }
1670
1671    /// Get a reference to the coordinator.
1672    pub fn coordinator(&self) -> &Arc<MultiGpuCoordinator> {
1673        &self.coordinator
1674    }
1675
1676    /// Get migration statistics snapshot.
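    ///
    /// A reading sketch (assumes a `migrator` built as in the type-level example):
    ///
    /// ```ignore
    /// let snap = migrator.stats();
    /// println!(
    ///     "migrations ok={} failed={} bytes={} avg checkpoint={:?}",
    ///     snap.successful_migrations,
    ///     snap.failed_migrations,
    ///     snap.bytes_transferred,
    ///     snap.avg_checkpoint_time,
    /// );
    /// ```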
1677    pub fn stats(&self) -> MigrationStatsSnapshot {
1678        let successful = self.stats.successful_migrations.load(Ordering::Relaxed);
1679        let failed = self.stats.failed_migrations.load(Ordering::Relaxed);
        // Phase times are only recorded on the success path (failed migrations
        // return early), so average over successful migrations only.
        let total = successful;
1681        let checkpoint_us = self.stats.checkpoint_time_us.load(Ordering::Relaxed);
1682        let restore_us = self.stats.restore_time_us.load(Ordering::Relaxed);
1683
1684        MigrationStatsSnapshot {
1685            successful_migrations: successful,
1686            failed_migrations: failed,
1687            bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
1688            avg_checkpoint_time: checkpoint_us
1689                .checked_div(total)
1690                .map(Duration::from_micros)
1691                .unwrap_or(Duration::ZERO),
1692            avg_restore_time: restore_us
1693                .checked_div(total)
1694                .map(Duration::from_micros)
1695                .unwrap_or(Duration::ZERO),
1696        }
1697    }
1698}
1699
1700/// Snapshot of migration statistics.
1701#[derive(Debug, Clone)]
1702pub struct MigrationStatsSnapshot {
1703    /// Total successful migrations.
1704    pub successful_migrations: u64,
1705    /// Total failed migrations.
1706    pub failed_migrations: u64,
1707    /// Total bytes transferred.
1708    pub bytes_transferred: u64,
1709    /// Average checkpoint creation time.
1710    pub avg_checkpoint_time: Duration,
1711    /// Average restore time.
1712    pub avg_restore_time: Duration,
1713}
1714
1715/// Trait for kernels that support live migration.
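///
/// A minimal implementation sketch for a hypothetical `MyKernel`; the queue
/// helpers (`pause_intake`, `drain_queue`, `resume_intake`, `queue_len`) and
/// `state_bytes` are assumptions, and the `CheckpointableKernel` impl is
/// omitted:
///
/// ```ignore
/// impl MigratableKernel for MyKernel {
///     fn prepare_for_migration(&mut self) -> Result<()> {
///         self.pause_intake(); // stop accepting new messages (assumed helper)
///         self.drain_queue()?; // finish what is already queued (assumed helper)
///         Ok(())
///     }
///
///     fn cancel_migration(&mut self) -> Result<()> {
///         self.resume_intake(); // assumed helper
///         Ok(())
///     }
///
///     fn is_quiescent(&self) -> bool {
///         self.queue_len() == 0
///     }
///
///     fn estimated_state_size(&self) -> usize {
///         self.state_bytes.len()
///     }
/// }
/// ```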
1716pub trait MigratableKernel: CheckpointableKernel {
1717    /// Prepare kernel for migration (quiesce, drain messages).
1718    fn prepare_for_migration(&mut self) -> Result<()>;
1719
1720    /// Resume kernel after migration is cancelled.
1721    fn cancel_migration(&mut self) -> Result<()>;
1722
1723    /// Check if kernel is ready to be checkpointed.
1724    fn is_quiescent(&self) -> bool;
1725
1726    /// Get estimated state size for migration planning.
1727    fn estimated_state_size(&self) -> usize;
1728}
1729
1730// ============================================================================
1731// Hot Reload Support
1732// ============================================================================
1733
1734/// Configuration for kernel hot reload operations.
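///
/// A small builder sketch using the defaults as a base:
///
/// ```ignore
/// let config = HotReloadConfig::new()
///     .with_timeout(Duration::from_secs(10))
///     .with_preserve_state(true)
///     .with_max_retries(5);
/// ```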
1735#[derive(Debug, Clone)]
1736pub struct HotReloadConfig {
1737    /// Enable hot reload functionality.
1738    pub enabled: bool,
1739    /// Timeout for reload operations.
1740    pub reload_timeout: Duration,
1741    /// Whether to preserve kernel state during reload.
1742    pub preserve_state: bool,
1743    /// Maximum retries for failed reloads.
1744    pub max_retries: u32,
1745    /// Backoff duration between retries.
1746    pub retry_backoff: Duration,
1747    /// Whether to validate new code before swapping.
1748    pub validate_before_swap: bool,
1749    /// Keep old code as fallback in case of failure.
1750    pub keep_fallback: bool,
1751}
1752
1753impl Default for HotReloadConfig {
1754    fn default() -> Self {
1755        Self {
1756            enabled: true,
1757            reload_timeout: Duration::from_secs(30),
1758            preserve_state: true,
1759            max_retries: 3,
1760            retry_backoff: Duration::from_millis(500),
1761            validate_before_swap: true,
1762            keep_fallback: true,
1763        }
1764    }
1765}
1766
1767impl HotReloadConfig {
1768    /// Create a new hot reload configuration.
1769    pub fn new() -> Self {
1770        Self::default()
1771    }
1772
1773    /// Enable or disable hot reload.
1774    pub fn with_enabled(mut self, enabled: bool) -> Self {
1775        self.enabled = enabled;
1776        self
1777    }
1778
1779    /// Set reload timeout.
1780    pub fn with_timeout(mut self, timeout: Duration) -> Self {
1781        self.reload_timeout = timeout;
1782        self
1783    }
1784
1785    /// Enable or disable state preservation.
1786    pub fn with_preserve_state(mut self, preserve: bool) -> Self {
1787        self.preserve_state = preserve;
1788        self
1789    }
1790
1791    /// Set maximum retries.
1792    pub fn with_max_retries(mut self, retries: u32) -> Self {
1793        self.max_retries = retries;
1794        self
1795    }
1796}
1797
1798/// State of a hot reload operation.
1799#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1800pub enum HotReloadState {
1801    /// Reload not started.
1802    Idle,
1803    /// Draining pending messages from kernel.
1804    Draining,
1805    /// Creating checkpoint of kernel state.
1806    Checkpointing,
    /// Validating new kernel code.
    Validating,
    /// Compiling new kernel code.
    Compiling,
1811    /// Swapping old kernel with new.
1812    Swapping,
1813    /// Restoring state to new kernel.
1814    Restoring,
1815    /// Hot reload completed successfully.
1816    Completed,
1817    /// Hot reload failed.
1818    Failed,
1819    /// Rolling back to previous version.
1820    RollingBack,
1821}
1822
1823/// Kernel code format.
1824#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1825pub enum KernelCodeFormat {
1826    /// NVIDIA PTX assembly.
1827    Ptx,
1828    /// NVIDIA CUBIN binary.
1829    Cubin,
1830    /// SPIR-V for Vulkan/WebGPU.
1831    SpirV,
1832    /// WGSL shader text.
1833    Wgsl,
1834    /// Metal Shading Language.
1835    Msl,
1836    /// Metal compiled library.
1837    MetalLib,
1838    /// Source code (requires compilation).
1839    Source,
1840}
1841
1842/// Kernel code source for hot reload.
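///
/// A construction sketch (the PTX text is a placeholder, not real kernel code):
///
/// ```ignore
/// let code = KernelCodeSource::from_ptx(".version 8.0\n.target sm_80\n// ...", "my_kernel")
///     .with_metadata("opt_level", "3");
/// assert_eq!(code.format, KernelCodeFormat::Ptx);
/// assert_eq!(code.entry_point, "my_kernel");
/// ```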
1843#[derive(Debug, Clone)]
1844pub struct KernelCodeSource {
1845    /// Unique identifier for this code version.
1846    pub version_id: u64,
1847    /// Code format.
1848    pub format: KernelCodeFormat,
1849    /// Raw code bytes.
1850    pub code: Vec<u8>,
1851    /// Entry point function name.
1852    pub entry_point: String,
1853    /// Optional metadata (compile flags, etc.).
1854    pub metadata: HashMap<String, String>,
1855    /// Timestamp when code was created.
1856    pub created_at: Instant,
    /// 256-bit content hash of the code (non-cryptographic; see `compute_hash`).
1858    pub hash: [u8; 32],
1859}
1860
1861impl KernelCodeSource {
1862    /// Create a new kernel code source.
1863    pub fn new(format: KernelCodeFormat, code: Vec<u8>, entry_point: impl Into<String>) -> Self {
1864        let hash = Self::compute_hash(&code);
1865        Self {
1866            version_id: 0,
1867            format,
1868            code,
1869            entry_point: entry_point.into(),
1870            metadata: HashMap::new(),
1871            created_at: Instant::now(),
1872            hash,
1873        }
1874    }
1875
1876    /// Create from PTX code.
1877    pub fn from_ptx(ptx: &str, entry_point: impl Into<String>) -> Self {
1878        Self::new(KernelCodeFormat::Ptx, ptx.as_bytes().to_vec(), entry_point)
1879    }
1880
1881    /// Create from WGSL code.
1882    pub fn from_wgsl(wgsl: &str, entry_point: impl Into<String>) -> Self {
1883        Self::new(
1884            KernelCodeFormat::Wgsl,
1885            wgsl.as_bytes().to_vec(),
1886            entry_point,
1887        )
1888    }
1889
1890    /// Create from MSL code.
1891    pub fn from_msl(msl: &str, entry_point: impl Into<String>) -> Self {
1892        Self::new(KernelCodeFormat::Msl, msl.as_bytes().to_vec(), entry_point)
1893    }
1894
1895    /// Set version ID.
1896    pub fn with_version(mut self, version: u64) -> Self {
1897        self.version_id = version;
1898        self
1899    }
1900
1901    /// Add metadata.
1902    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1903        self.metadata.insert(key.into(), value.into());
1904        self
1905    }
1906
    /// Derive a 256-bit content digest by chaining four 64-bit rounds of
    /// `DefaultHasher`. This is a cheap change-detection fingerprint, not a
    /// cryptographic hash.
    fn compute_hash(data: &[u8]) -> [u8; 32] {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        data.hash(&mut hasher);
        let h1 = hasher.finish();
        // Feed the first word back in so each round draws a fresh word from
        // the accumulated hasher state.
        h1.hash(&mut hasher);
        let h2 = hasher.finish();
        h1.hash(&mut hasher);
        let h3 = hasher.finish();
        h1.hash(&mut hasher);
        let h4 = hasher.finish();

        let mut hash = [0u8; 32];
        hash[0..8].copy_from_slice(&h1.to_le_bytes());
        hash[8..16].copy_from_slice(&h2.to_le_bytes());
        hash[16..24].copy_from_slice(&h3.to_le_bytes());
        hash[24..32].copy_from_slice(&h4.to_le_bytes());
        hash
    }
1926
1927    /// Get code as string (if text format).
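    ///
    /// A quick sketch: text formats round-trip, binary formats return `None`.
    ///
    /// ```ignore
    /// let wgsl = KernelCodeSource::from_wgsl("@compute fn main() {}", "main");
    /// assert!(wgsl.as_str().is_some());
    ///
    /// let bin = KernelCodeSource::new(KernelCodeFormat::Cubin, vec![0xFF], "k");
    /// assert!(bin.as_str().is_none());
    /// ```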
1928    pub fn as_str(&self) -> Option<&str> {
1929        match self.format {
1930            KernelCodeFormat::Ptx
1931            | KernelCodeFormat::Wgsl
1932            | KernelCodeFormat::Msl
1933            | KernelCodeFormat::Source => std::str::from_utf8(&self.code).ok(),
1934            _ => None,
1935        }
1936    }
1937
1938    /// Get code size in bytes.
1939    pub fn size(&self) -> usize {
1940        self.code.len()
1941    }
1942}
1943
1944/// Request to hot reload a kernel.
1945#[derive(Debug)]
1946pub struct HotReloadRequest {
1947    /// Target kernel ID.
1948    pub kernel_id: KernelId,
1949    /// New kernel code.
1950    pub new_code: KernelCodeSource,
1951    /// Current state of the reload operation.
1952    pub state: HotReloadState,
1953    /// When the request was created.
1954    pub created_at: Instant,
1955    /// When the reload started.
1956    pub started_at: Option<Instant>,
1957    /// Retry count.
1958    pub retry_count: u32,
1959    /// Error message if failed.
1960    pub error: Option<String>,
1961    /// Checkpoint data (if preserving state).
1962    checkpoint_data: Option<Vec<u8>>,
1963}
1964
1965impl HotReloadRequest {
1966    /// Create a new hot reload request.
1967    pub fn new(kernel_id: KernelId, new_code: KernelCodeSource) -> Self {
1968        Self {
1969            kernel_id,
1970            new_code,
1971            state: HotReloadState::Idle,
1972            created_at: Instant::now(),
1973            started_at: None,
1974            retry_count: 0,
1975            error: None,
1976            checkpoint_data: None,
1977        }
1978    }
1979
1980    /// Check if reload is in progress.
1981    pub fn is_in_progress(&self) -> bool {
1982        !matches!(
1983            self.state,
1984            HotReloadState::Idle | HotReloadState::Completed | HotReloadState::Failed
1985        )
1986    }
1987
1988    /// Check if reload completed successfully.
1989    pub fn is_completed(&self) -> bool {
1990        self.state == HotReloadState::Completed
1991    }
1992
1993    /// Check if reload failed.
1994    pub fn is_failed(&self) -> bool {
1995        self.state == HotReloadState::Failed
1996    }
1997
1998    /// Get elapsed time since request creation.
1999    pub fn elapsed(&self) -> Duration {
2000        self.created_at.elapsed()
2001    }
2002
2003    /// Get elapsed time since reload started.
2004    pub fn reload_elapsed(&self) -> Option<Duration> {
2005        self.started_at.map(|s| s.elapsed())
2006    }
2007}
2008
2009/// Result of a completed hot reload.
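///
/// An inspection sketch (`result` as returned by `HotReloadManager::execute_reload`):
///
/// ```ignore
/// println!(
///     "kernel {} reloaded v{} -> v{} in {:?} (state preserved: {})",
///     result.kernel_id.as_str(),
///     result.old_version,
///     result.new_version,
///     result.total_duration,
///     result.state_preserved,
/// );
/// ```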
2010#[derive(Debug, Clone)]
2011pub struct HotReloadResult {
2012    /// Target kernel ID.
2013    pub kernel_id: KernelId,
2014    /// Previous code version.
2015    pub old_version: u64,
2016    /// New code version.
2017    pub new_version: u64,
2018    /// Whether state was preserved.
2019    pub state_preserved: bool,
2020    /// Size of checkpoint data (if any).
2021    pub checkpoint_size: usize,
2022    /// Time to drain messages.
2023    pub drain_duration: Duration,
2024    /// Time to create checkpoint.
2025    pub checkpoint_duration: Duration,
2026    /// Time to compile new code.
2027    pub compile_duration: Duration,
2028    /// Time to swap kernels.
2029    pub swap_duration: Duration,
2030    /// Time to restore state.
2031    pub restore_duration: Duration,
2032    /// Total reload duration.
2033    pub total_duration: Duration,
2034}
2035
2036/// Statistics for hot reload operations.
2037#[derive(Debug, Default)]
2038struct HotReloadStats {
2039    successful_reloads: AtomicU64,
2040    failed_reloads: AtomicU64,
2041    rollbacks: AtomicU64,
2042    total_drain_time_us: AtomicU64,
2043    total_compile_time_us: AtomicU64,
2044    total_swap_time_us: AtomicU64,
2045    state_preserved_count: AtomicU64,
2046}
2047
2048/// Snapshot of hot reload statistics.
2049#[derive(Debug, Clone)]
2050pub struct HotReloadStatsSnapshot {
2051    /// Total successful reloads.
2052    pub successful_reloads: u64,
2053    /// Total failed reloads.
2054    pub failed_reloads: u64,
2055    /// Total rollbacks performed.
2056    pub rollbacks: u64,
2057    /// Average drain time.
2058    pub avg_drain_time: Duration,
2059    /// Average compile time.
2060    pub avg_compile_time: Duration,
2061    /// Average swap time.
2062    pub avg_swap_time: Duration,
2063    /// Number of reloads with preserved state.
2064    pub state_preserved_count: u64,
2065}
2066
2067/// Manager for kernel hot reload operations.
2068///
2069/// Provides seamless kernel code updates without stopping the system:
2070///
2071/// 1. Drain pending messages from kernel input queue
2072/// 2. Checkpoint kernel state (if preserving state)
2073/// 3. Compile/validate new kernel code
2074/// 4. Swap old kernel with new kernel
2075/// 5. Restore state to new kernel
2076/// 6. Resume processing
2077///
2078/// # Example
2079///
2080/// ```ignore
2081/// use ringkernel_core::multi_gpu::{HotReloadManager, HotReloadConfig, KernelCodeSource};
2082///
2083/// let manager = HotReloadManager::new(HotReloadConfig::default());
2084///
2085/// // Register a reloadable kernel
2086/// manager.register_kernel(&kernel_id, current_code);
2087///
2088/// // Request hot reload with new PTX
2089/// let new_code = KernelCodeSource::from_ptx(new_ptx, "my_kernel");
/// let mut request = manager.request_reload(&kernel_id, new_code)?;
///
/// // Execute the reload
/// let result = manager.execute_reload(&mut request, &kernel)?;
2094/// println!("Reload completed in {:?}", result.total_duration);
2095/// ```
2096pub struct HotReloadManager {
2097    /// Configuration.
2098    config: HotReloadConfig,
2099    /// Registered kernels and their current code.
2100    kernels: RwLock<HashMap<KernelId, KernelCodeSource>>,
2101    /// Fallback code for registered kernels.
2102    fallbacks: RwLock<HashMap<KernelId, KernelCodeSource>>,
2103    /// Active reload requests.
2104    active_requests: RwLock<HashMap<KernelId, HotReloadRequest>>,
2105    /// Version counter for code versions.
2106    version_counter: AtomicU64,
2107    /// Statistics.
2108    stats: HotReloadStats,
2109}
2110
2111impl HotReloadManager {
2112    /// Create a new hot reload manager.
2113    pub fn new(config: HotReloadConfig) -> Arc<Self> {
2114        Arc::new(Self {
2115            config,
2116            kernels: RwLock::new(HashMap::new()),
2117            fallbacks: RwLock::new(HashMap::new()),
2118            active_requests: RwLock::new(HashMap::new()),
2119            version_counter: AtomicU64::new(1),
2120            stats: HotReloadStats::default(),
2121        })
2122    }
2123
2124    /// Create with default configuration.
2125    pub fn with_defaults() -> Arc<Self> {
2126        Self::new(HotReloadConfig::default())
2127    }
2128
2129    /// Check if hot reload is enabled.
2130    pub fn is_enabled(&self) -> bool {
2131        self.config.enabled
2132    }
2133
2134    /// Register a kernel for hot reload.
2135    pub fn register_kernel(&self, kernel_id: &KernelId, code: KernelCodeSource) {
2136        let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
2137        let code = code.with_version(version);
2138        self.kernels.write().insert(kernel_id.clone(), code);
2139    }
2140
2141    /// Unregister a kernel from hot reload.
2142    pub fn unregister_kernel(&self, kernel_id: &KernelId) {
2143        self.kernels.write().remove(kernel_id);
2144        self.fallbacks.write().remove(kernel_id);
2145        self.active_requests.write().remove(kernel_id);
2146    }
2147
2148    /// Get current code version for a kernel.
2149    pub fn get_current_version(&self, kernel_id: &KernelId) -> Option<u64> {
2150        self.kernels.read().get(kernel_id).map(|c| c.version_id)
2151    }
2152
2153    /// Get current code for a kernel.
2154    pub fn get_current_code(&self, kernel_id: &KernelId) -> Option<KernelCodeSource> {
2155        self.kernels.read().get(kernel_id).cloned()
2156    }
2157
2158    /// Request a hot reload for a kernel.
2159    pub fn request_reload(
2160        &self,
2161        kernel_id: &KernelId,
2162        new_code: KernelCodeSource,
2163    ) -> Result<HotReloadRequest> {
2164        if !self.config.enabled {
2165            return Err(RingKernelError::ValidationError(
2166                "Hot reload is disabled".to_string(),
2167            ));
2168        }
2169
2170        // Check kernel is registered
2171        if !self.kernels.read().contains_key(kernel_id) {
2172            return Err(RingKernelError::KernelNotFound(
2173                kernel_id.as_str().to_string(),
2174            ));
2175        }
2176
2177        // Check no reload already in progress
2178        {
2179            let active = self.active_requests.read();
2180            if let Some(existing) = active.get(kernel_id) {
2181                if existing.is_in_progress() {
2182                    return Err(RingKernelError::ValidationError(
2183                        "Hot reload already in progress for this kernel".to_string(),
2184                    ));
2185                }
2186            }
2187        }
2188
2189        // Assign version to new code
2190        let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
2191        let new_code = new_code.with_version(version);
2192
        // `HotReloadRequest` is not `Clone` (it carries progress state), so
        // build a second, identical request to track in the active table and
        // hand the first back to the caller.
        let request = HotReloadRequest::new(kernel_id.clone(), new_code);
        self.active_requests.write().insert(
            kernel_id.clone(),
            HotReloadRequest::new(kernel_id.clone(), request.new_code.clone()),
        );
2198
2199        Ok(request)
2200    }
2201
2202    /// Execute a hot reload operation.
2203    ///
2204    /// This performs the full reload sequence:
    /// 1. Drain pending messages
    /// 2. Checkpoint state (if enabled)
    /// 3. Validate new code (if enabled)
    /// 4. Compile new code
    /// 5. Swap kernels
    /// 6. Restore state (if enabled)
2210    pub fn execute_reload<K: CheckpointableKernel>(
2211        &self,
2212        request: &mut HotReloadRequest,
2213        kernel: &K,
2214    ) -> Result<HotReloadResult> {
2215        let start_time = Instant::now();
2216        request.started_at = Some(start_time);
2217
2218        // Get old version
2219        let old_version = self
2220            .kernels
2221            .read()
2222            .get(&request.kernel_id)
2223            .map(|c| c.version_id)
2224            .unwrap_or(0);
2225
2226        // Phase 1: Drain (simulated - actual drain would wait for queue empty)
2227        request.state = HotReloadState::Draining;
2228        let drain_start = Instant::now();
2229        // In a real implementation, wait for input queue to drain
2230        std::thread::sleep(Duration::from_micros(10));
2231        let drain_duration = drain_start.elapsed();
2232        self.stats
2233            .total_drain_time_us
2234            .fetch_add(drain_duration.as_micros() as u64, Ordering::Relaxed);
2235
2236        // Phase 2: Checkpoint (if preserving state)
2237        request.state = HotReloadState::Checkpointing;
2238        let checkpoint_start = Instant::now();
        let checkpoint_size = if self.config.preserve_state {
            // Record the failure before propagating so stats and request
            // state stay consistent (mirrors KernelMigrator's error paths).
            let checkpoint = kernel.create_checkpoint().map_err(|e| {
                self.stats.failed_reloads.fetch_add(1, Ordering::Relaxed);
                request.state = HotReloadState::Failed;
                e
            })?;
            let data = checkpoint.to_bytes();
            request.checkpoint_data = Some(data.clone());
            data.len()
        } else {
            0
        };
2247        let checkpoint_duration = checkpoint_start.elapsed();
2248
        // Phase 3: Validate new code (count failures for the stats snapshot)
        request.state = HotReloadState::Validating;
        if self.config.validate_before_swap {
            self.validate_code(&request.new_code).map_err(|e| {
                self.stats.failed_reloads.fetch_add(1, Ordering::Relaxed);
                request.state = HotReloadState::Failed;
                e
            })?;
        }
2254
2255        // Phase 4: Compile (simulated)
2256        request.state = HotReloadState::Compiling;
2257        let compile_start = Instant::now();
        // In a real implementation, compile PTX/WGSL to native code here
2259        std::thread::sleep(Duration::from_micros(10));
2260        let compile_duration = compile_start.elapsed();
2261        self.stats
2262            .total_compile_time_us
2263            .fetch_add(compile_duration.as_micros() as u64, Ordering::Relaxed);
2264
2265        // Phase 5: Swap
2266        request.state = HotReloadState::Swapping;
2267        let swap_start = Instant::now();
2268
2269        // Save fallback
2270        if self.config.keep_fallback {
2271            if let Some(old_code) = self.kernels.read().get(&request.kernel_id).cloned() {
2272                self.fallbacks
2273                    .write()
2274                    .insert(request.kernel_id.clone(), old_code);
2275            }
2276        }
2277
2278        // Install new code
2279        self.kernels
2280            .write()
2281            .insert(request.kernel_id.clone(), request.new_code.clone());
2282        let swap_duration = swap_start.elapsed();
2283        self.stats
2284            .total_swap_time_us
2285            .fetch_add(swap_duration.as_micros() as u64, Ordering::Relaxed);
2286
2287        // Phase 6: Restore state
2288        request.state = HotReloadState::Restoring;
2289        let restore_start = Instant::now();
        // In a real implementation, restore the checkpoint to the new kernel here
2291        let restore_duration = restore_start.elapsed();
2292
2293        // Mark completed
2294        request.state = HotReloadState::Completed;
2295        self.stats
2296            .successful_reloads
2297            .fetch_add(1, Ordering::Relaxed);
2298        if self.config.preserve_state && checkpoint_size > 0 {
2299            self.stats
2300                .state_preserved_count
2301                .fetch_add(1, Ordering::Relaxed);
2302        }
2303
2304        // Clean up active request
2305        self.active_requests.write().remove(&request.kernel_id);
2306
2307        Ok(HotReloadResult {
2308            kernel_id: request.kernel_id.clone(),
2309            old_version,
2310            new_version: request.new_code.version_id,
2311            state_preserved: self.config.preserve_state && checkpoint_size > 0,
2312            checkpoint_size,
2313            drain_duration,
2314            checkpoint_duration,
2315            compile_duration,
2316            swap_duration,
2317            restore_duration,
2318            total_duration: start_time.elapsed(),
2319        })
2320    }
2321
2322    /// Rollback to previous kernel version.
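    ///
    /// A usage sketch (assumes `keep_fallback` is enabled and a reload has
    /// already installed newer code; `reload_failed` is a hypothetical flag):
    ///
    /// ```ignore
    /// if reload_failed {
    ///     // Reinstate the previous code version saved during the swap phase.
    ///     manager.rollback(&kernel_id)?;
    /// }
    /// ```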
2323    pub fn rollback(&self, kernel_id: &KernelId) -> Result<()> {
2324        let fallback =
2325            self.fallbacks.write().remove(kernel_id).ok_or_else(|| {
2326                RingKernelError::ValidationError("No fallback available".to_string())
2327            })?;
2328
2329        self.kernels.write().insert(kernel_id.clone(), fallback);
2330        self.stats.rollbacks.fetch_add(1, Ordering::Relaxed);
2331
2332        // Update any active request
2333        if let Some(request) = self.active_requests.write().get_mut(kernel_id) {
2334            request.state = HotReloadState::RollingBack;
2335        }
2336
2337        Ok(())
2338    }
2339
2340    /// Validate kernel code before swap.
2341    fn validate_code(&self, code: &KernelCodeSource) -> Result<()> {
2342        // Basic validation
2343        if code.code.is_empty() {
2344            return Err(RingKernelError::ValidationError(
2345                "Kernel code is empty".to_string(),
2346            ));
2347        }
2348
2349        if code.entry_point.is_empty() {
2350            return Err(RingKernelError::ValidationError(
2351                "Entry point is empty".to_string(),
2352            ));
2353        }
2354
2355        // Format-specific validation
2356        match code.format {
2357            KernelCodeFormat::Ptx => {
2358                // Check for valid PTX header
2359                if let Some(text) = code.as_str() {
2360                    if !text.contains(".version") && !text.contains(".target") {
2361                        return Err(RingKernelError::ValidationError(
2362                            "PTX code missing version/target directive".to_string(),
2363                        ));
2364                    }
2365                }
2366            }
2367            KernelCodeFormat::Wgsl => {
2368                // Check for basic WGSL structure
2369                if let Some(text) = code.as_str() {
2370                    if !text.contains("@compute") && !text.contains("fn ") {
2371                        return Err(RingKernelError::ValidationError(
2372                            "WGSL code missing compute shader or function".to_string(),
2373                        ));
2374                    }
2375                }
2376            }
2377            KernelCodeFormat::Msl => {
2378                // Check for Metal kernel
2379                if let Some(text) = code.as_str() {
2380                    if !text.contains("kernel ") {
2381                        return Err(RingKernelError::ValidationError(
2382                            "MSL code missing kernel function".to_string(),
2383                        ));
2384                    }
2385                }
2386            }
2387            _ => {}
2388        }
2389
2390        Ok(())
2391    }
2392
2393    /// Get statistics snapshot.
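    ///
    /// A reading sketch:
    ///
    /// ```ignore
    /// let snap = manager.stats();
    /// println!(
    ///     "reloads ok={} failed={} rollbacks={} avg swap={:?}",
    ///     snap.successful_reloads,
    ///     snap.failed_reloads,
    ///     snap.rollbacks,
    ///     snap.avg_swap_time,
    /// );
    /// ```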
2394    pub fn stats(&self) -> HotReloadStatsSnapshot {
2395        let successful = self.stats.successful_reloads.load(Ordering::Relaxed);
2396        let failed = self.stats.failed_reloads.load(Ordering::Relaxed);
        // Phase times are only recorded on success, so average over the
        // successful count; max(1) guards the division before any reload.
        let total = successful.max(1);
2398
2399        HotReloadStatsSnapshot {
2400            successful_reloads: successful,
2401            failed_reloads: failed,
2402            rollbacks: self.stats.rollbacks.load(Ordering::Relaxed),
2403            avg_drain_time: Duration::from_micros(
2404                self.stats.total_drain_time_us.load(Ordering::Relaxed) / total,
2405            ),
2406            avg_compile_time: Duration::from_micros(
2407                self.stats.total_compile_time_us.load(Ordering::Relaxed) / total,
2408            ),
2409            avg_swap_time: Duration::from_micros(
2410                self.stats.total_swap_time_us.load(Ordering::Relaxed) / total,
2411            ),
2412            state_preserved_count: self.stats.state_preserved_count.load(Ordering::Relaxed),
2413        }
2414    }
2415
2416    /// List all registered kernels.
2417    pub fn list_kernels(&self) -> Vec<KernelId> {
2418        self.kernels.read().keys().cloned().collect()
2419    }
2420
2421    /// Check if a kernel is registered.
2422    pub fn is_registered(&self, kernel_id: &KernelId) -> bool {
2423        self.kernels.read().contains_key(kernel_id)
2424    }
2425
2426    /// Check if a reload is in progress for a kernel.
2427    pub fn is_reload_in_progress(&self, kernel_id: &KernelId) -> bool {
2428        self.active_requests
2429            .read()
2430            .get(kernel_id)
2431            .map(|r| r.is_in_progress())
2432            .unwrap_or(false)
2433    }
2434
2435    /// Get the configuration.
2436    pub fn config(&self) -> &HotReloadConfig {
2437        &self.config
2438    }
2439}
2440
2441/// Trait for kernels that support hot reload.
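///
/// A minimal sketch for a hypothetical `MyKernel`; `pause`, `resume`,
/// `queue_is_empty`, and the `backend.load` call are assumptions standing in
/// for real module-loading plumbing:
///
/// ```ignore
/// impl HotReloadableKernel for MyKernel {
///     fn prepare_for_reload(&mut self) -> Result<()> {
///         self.pause(); // stop processing so the swap is race-free (assumed helper)
///         Ok(())
///     }
///
///     fn apply_code(&mut self, code: &KernelCodeSource) -> Result<()> {
///         self.module = self.backend.load(code)?; // assumed module loader
///         Ok(())
///     }
///
///     fn resume_after_reload(&mut self) -> Result<()> {
///         self.resume(); // assumed helper
///         Ok(())
///     }
///
///     fn is_ready_for_reload(&self) -> bool {
///         self.queue_is_empty()
///     }
/// }
/// ```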
2442pub trait HotReloadableKernel: CheckpointableKernel {
2443    /// Prepare kernel for code swap (drain messages, pause processing).
2444    fn prepare_for_reload(&mut self) -> Result<()>;
2445
2446    /// Apply new code to the kernel.
2447    fn apply_code(&mut self, code: &KernelCodeSource) -> Result<()>;
2448
2449    /// Resume processing after reload.
2450    fn resume_after_reload(&mut self) -> Result<()>;
2451
2452    /// Check if kernel is ready for reload.
2453    fn is_ready_for_reload(&self) -> bool;
2454}
2455
2456#[cfg(test)]
2457mod tests {
2458    use super::*;
2459
2460    #[test]
2461    fn test_device_info() {
2462        let info = DeviceInfo::new(0, "Test GPU".to_string(), Backend::Cuda);
2463        assert_eq!(info.index, 0);
2464        assert_eq!(info.name, "Test GPU");
2465        assert_eq!(info.memory_utilization(), 0.0);
2466    }
2467
2468    #[test]
2469    fn test_coordinator_registration() {
2470        let coord = MultiGpuBuilder::new().build();
2471
2472        let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
2473        coord.register_device(device);
2474
2475        assert_eq!(coord.device_count(), 1);
2476        assert!(coord.device(0).is_some());
2477    }
2478
2479    #[test]
2480    fn test_kernel_assignment() {
2481        let coord = MultiGpuBuilder::new().build();
2482
2483        let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
2484        coord.register_device(device);
2485
2486        let kernel_id = KernelId::new("test_kernel");
2487        coord.assign_kernel(kernel_id.clone(), 0);
2488
2489        assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));
2490        assert_eq!(coord.kernels_on_device(0).len(), 1);
2491    }
2492
2493    #[test]
2494    fn test_load_balancing_least_loaded() {
2495        let coord = MultiGpuBuilder::new()
2496            .load_balancing(LoadBalancingStrategy::LeastLoaded)
2497            .build();
2498
2499        // Register two devices
2500        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2501        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
2502
2503        // Assign a kernel to device 0
2504        coord.assign_kernel(KernelId::new("k1"), 0);
2505
2506        // Next kernel should go to device 1 (least loaded)
2507        let selected = coord.select_device(&LaunchOptions::default()).unwrap();
2508        assert_eq!(selected, 1);
2509    }
2510
2511    #[test]
2512    fn test_round_robin() {
2513        let coord = MultiGpuBuilder::new()
2514            .load_balancing(LoadBalancingStrategy::RoundRobin)
2515            .build();
2516
2517        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2518        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
2519
2520        let d1 = coord.select_device(&LaunchOptions::default()).unwrap();
2521        let d2 = coord.select_device(&LaunchOptions::default()).unwrap();
2522        let d3 = coord.select_device(&LaunchOptions::default()).unwrap();
2523
2524        // Should cycle through devices
2525        assert_ne!(d1, d2);
2526        assert_eq!(d1, d3);
2527    }
2528
2529    // ========================================================================
2530    // Topology Tests
2531    // ========================================================================
2532
2533    #[test]
2534    fn test_interconnect_bandwidth() {
2535        assert!(
2536            InterconnectType::NvLink.estimated_bandwidth_gbps()
2537                > InterconnectType::Pcie.estimated_bandwidth_gbps()
2538        );
2539        assert!(
2540            InterconnectType::Pcie.estimated_bandwidth_gbps()
2541                > InterconnectType::None.estimated_bandwidth_gbps()
2542        );
2543        assert!(
2544            InterconnectType::SameDevice.estimated_bandwidth_gbps()
2545                > InterconnectType::NvLink.estimated_bandwidth_gbps()
2546        );
2547    }
2548
2549    #[test]
2550    fn test_interconnect_p2p_support() {
2551        assert!(!InterconnectType::None.supports_p2p());
2552        assert!(InterconnectType::Pcie.supports_p2p());
2553        assert!(InterconnectType::NvLink.supports_p2p());
2554        assert!(InterconnectType::NvSwitch.supports_p2p());
2555    }
2556
2557    #[test]
2558    fn test_gpu_topology_creation() {
2559        let topo = GpuTopology::new(4);
2560        assert_eq!(topo.device_count, 4);
2561
2562        // Self-connections should exist
2563        for i in 0..4 {
2564            let conn = topo.get_connection(i, i);
2565            assert!(conn.is_some());
2566            assert_eq!(conn.unwrap().interconnect, InterconnectType::SameDevice);
2567        }
2568    }
2569
2570    #[test]
2571    fn test_gpu_topology_set_connection() {
2572        let mut topo = GpuTopology::new(4);
2573
2574        // Set NVLink between GPU 0 and 1
2575        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
2576
2577        let conn_01 = topo.get_connection(0, 1);
2578        assert!(conn_01.is_some());
2579        assert_eq!(conn_01.unwrap().interconnect, InterconnectType::NvLink);
2580
2581        // Bidirectional by default
2582        let conn_10 = topo.get_connection(1, 0);
2583        assert!(conn_10.is_some());
2584        assert_eq!(conn_10.unwrap().interconnect, InterconnectType::NvLink);
2585    }
2586
2587    #[test]
2588    fn test_gpu_topology_neighbors() {
2589        let mut topo = GpuTopology::new(4);
2590
2591        // Ring topology: 0-1-2-3-0
2592        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
2593        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
2594        topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
2595        topo.set_connection(GpuConnection::new(3, 0, InterconnectType::NvLink));
2596
2597        let neighbors_0 = topo.neighbors(0);
2598        assert_eq!(neighbors_0.len(), 2);
2599        assert!(neighbors_0.contains(&1));
2600        assert!(neighbors_0.contains(&3));
2601    }
2602
2603    #[test]
2604    fn test_gpu_topology_best_path() {
2605        let mut topo = GpuTopology::new(4);
2606
2607        // Create connections: 0-1, 1-2, 2-3 (no direct 0-3)
2608        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
2609        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
2610        topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
2611        topo.set_connection(GpuConnection::new(0, 3, InterconnectType::None)); // No direct P2P
2612
2613        // Direct path should work for adjacent nodes
2614        let path_01 = topo.best_path(0, 1);
2615        assert_eq!(path_01, vec![0, 1]);
2616
2617        // Same device
2618        let path_00 = topo.best_path(0, 0);
2619        assert_eq!(path_00, vec![0]);
2620    }
2621
2622    #[test]
2623    fn test_gpu_topology_fully_connected() {
2624        let mut topo = GpuTopology::new(3);
2625
2626        // Not fully connected initially
2627        assert!(!topo.is_fully_connected());
2628
2629        // Make fully connected mesh
2630        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
2631        topo.set_connection(GpuConnection::new(0, 2, InterconnectType::NvLink));
2632        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
2633
2634        assert!(topo.is_fully_connected());
2635    }
2636
2637    #[test]
2638    fn test_gpu_topology_numa() {
2639        let mut topo = GpuTopology::new(4);
2640
2641        // GPUs 0,1 on NUMA 0; GPUs 2,3 on NUMA 1
2642        topo.set_numa_node(0, 0);
2643        topo.set_numa_node(1, 0);
2644        topo.set_numa_node(2, 1);
2645        topo.set_numa_node(3, 1);
2646
2647        let numa_neighbors_0 = topo.numa_neighbors(0);
2648        assert_eq!(numa_neighbors_0, vec![1]);
2649
2650        let numa_neighbors_2 = topo.numa_neighbors(2);
2651        assert_eq!(numa_neighbors_2, vec![3]);
2652    }
2653
2654    // ========================================================================
2655    // Topology Discovery Tests
2656    // ========================================================================
2657
2658    #[test]
2659    fn test_coordinator_topology_discovery() {
2660        let coord = MultiGpuBuilder::new().enable_p2p(true).build();
2661
2662        // Register P2P capable devices
2663        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
2664        dev0.p2p_capable = true;
2665        dev0.compute_capability = Some((8, 0)); // Ampere
2666
2667        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
2668        dev1.p2p_capable = true;
2669        dev1.compute_capability = Some((8, 6)); // Ampere
2670
2671        coord.register_device(dev0);
2672        coord.register_device(dev1);
2673
2674        let topo = coord.discover_topology();
2675
2676        assert_eq!(topo.device_count, 2);
2677
2678        // Should detect NVLink for Ampere GPUs
2679        let conn = topo.get_connection(0, 1);
2680        assert!(conn.is_some());
2681        assert_eq!(conn.unwrap().interconnect, InterconnectType::NvLink);
2682    }
2683
2684    // ========================================================================
2685    // Migration Tests
2686    // ========================================================================
2687
2688    #[test]
2689    fn test_migration_request() {
2690        let coord = MultiGpuBuilder::new().build();
2691
2692        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2693        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
2694
2695        let kernel_id = KernelId::new("migrating_kernel");
2696        coord.assign_kernel(kernel_id.clone(), 0);
2697
2698        let request = coord.request_migration(&kernel_id, 1).unwrap();
2699
2700        assert_eq!(request.source_device, 0);
2701        assert_eq!(request.target_device, 1);
2702        assert_eq!(request.state, MigrationState::Pending);
2703    }
2704
2705    #[test]
2706    fn test_migration_same_device_error() {
2707        let coord = MultiGpuBuilder::new().build();
2708
2709        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2710
2711        let kernel_id = KernelId::new("kernel");
2712        coord.assign_kernel(kernel_id.clone(), 0);
2713
2714        let result = coord.request_migration(&kernel_id, 0);
2715        assert!(result.is_err());
2716    }
2717
2718    #[test]
2719    fn test_migration_complete() {
2720        let coord = MultiGpuBuilder::new().build();
2721
2722        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2723        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
2724
2725        let kernel_id = KernelId::new("migrating_kernel");
2726        coord.assign_kernel(kernel_id.clone(), 0);
2727
2728        assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));
2729
2730        let request = coord.request_migration(&kernel_id, 1).unwrap();
2731        coord.complete_migration(&request).unwrap();
2732
2733        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));
2734    }
2735
2736    #[test]
2737    fn test_migration_transfer_time_estimate() {
2738        let request = MigrationRequest {
2739            kernel_id: KernelId::new("test"),
2740            source_device: 0,
2741            target_device: 1,
2742            path: vec![0, 1],
2743            estimated_bandwidth_gbps: 300.0, // NVLink
2744            estimated_latency_us: 1.0,
2745            state: MigrationState::Pending,
2746            started_at: None,
2747        };
2748
2749        // 1GB transfer at 300GB/s = ~3.3ms + 1us latency
2750        let time = request.estimate_transfer_time(1_000_000_000);
2751        assert!(time.as_micros() > 3000);
2752        assert!(time.as_micros() < 4000);
2753    }
2754
2755    // ========================================================================
2756    // Cross-GPU K2K Router Tests
2757    // ========================================================================
2758
2759    use crate::hlc::HlcTimestamp;
2760    use crate::message::MessageEnvelope;
2761
2762    fn make_test_k2k_message(source: &KernelId, dest: &KernelId) -> K2KMessage {
2763        let timestamp = HlcTimestamp::now(42);
2764        let envelope = MessageEnvelope::empty(1, 2, timestamp);
2765        K2KMessage::new(source.clone(), dest.clone(), envelope, timestamp)
2766    }
2767
2768    #[test]
2769    fn test_router_same_device() {
2770        let coord = MultiGpuBuilder::new().build();
2771        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2772
2773        let k1 = KernelId::new("k1");
2774        let k2 = KernelId::new("k2");
2775        coord.assign_kernel(k1.clone(), 0);
2776        coord.assign_kernel(k2.clone(), 0);
2777
2778        let router = CrossGpuK2KRouter::new(coord);
2779
2780        let msg = make_test_k2k_message(&k1, &k2);
2781        let decision = router.route_message(&k1, &k2, msg).unwrap();
2782
        assert!(matches!(decision, RoutingDecision::SameDevice));
2784    }
2785
2786    #[test]
2787    fn test_router_cross_device() {
2788        let coord = MultiGpuBuilder::new().enable_p2p(true).build();
2789
2790        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
2791        dev0.p2p_capable = true;
2792        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
2793        dev1.p2p_capable = true;
2794
2795        coord.register_device(dev0);
2796        coord.register_device(dev1);
2797
2798        let k1 = KernelId::new("k1");
2799        let k2 = KernelId::new("k2");
2800        coord.assign_kernel(k1.clone(), 0);
2801        coord.assign_kernel(k2.clone(), 1);
2802
2803        let router = CrossGpuK2KRouter::new(coord);
2804
2805        let msg = make_test_k2k_message(&k1, &k2);
2806        let decision = router.route_message(&k1, &k2, msg).unwrap();
2807
2808        match decision {
2809            RoutingDecision::DirectP2P {
2810                source_device,
2811                dest_device,
2812                ..
2813            } => {
2814                assert_eq!(source_device, 0);
2815                assert_eq!(dest_device, 1);
2816            }
2817            _ => panic!("Expected DirectP2P routing"),
2818        }
2819    }
2820
2821    #[test]
2822    fn test_router_pending_messages() {
2823        let coord = MultiGpuBuilder::new().enable_p2p(true).build();
2824
2825        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
2826        dev0.p2p_capable = true;
2827        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
2828        dev1.p2p_capable = true;
2829
2830        coord.register_device(dev0);
2831        coord.register_device(dev1);
2832
2833        let k1 = KernelId::new("k1");
2834        let k2 = KernelId::new("k2");
2835        coord.assign_kernel(k1.clone(), 0);
2836        coord.assign_kernel(k2.clone(), 1);
2837
2838        let router = CrossGpuK2KRouter::new(coord);
2839
2840        // Route 3 messages
2841        for _ in 0..3 {
2842            let msg = make_test_k2k_message(&k1, &k2);
2843            router.route_message(&k1, &k2, msg).unwrap();
2844        }
2845
2846        assert_eq!(router.stats().messages_pending, 3);
2847
2848        // Drain pending
2849        let pending = router.drain_pending(0, 1);
2850        assert_eq!(pending.len(), 3);
2851        assert_eq!(router.stats().messages_pending, 0);
2852    }
2853
2854    #[test]
2855    fn test_router_stats() {
2856        let coord = MultiGpuBuilder::new().build();
2857        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2858
2859        let k1 = KernelId::new("k1");
2860        let k2 = KernelId::new("k2");
2861        coord.assign_kernel(k1.clone(), 0);
2862        coord.assign_kernel(k2.clone(), 0);
2863
2864        let router = CrossGpuK2KRouter::new(coord);
2865
2866        let stats = router.stats();
2867        assert_eq!(stats.messages_routed, 0);
2868        assert_eq!(stats.bytes_transferred, 0);
2869        assert_eq!(stats.routing_failures, 0);
2870    }
2871
2872    // ========================================================================
2873    // Kernel Migrator Tests
2874    // ========================================================================
2875
2876    use crate::checkpoint::{Checkpoint, CheckpointBuilder};
2877
2878    /// Mock checkpointable kernel for testing.
2879    struct MockCheckpointableKernel {
2880        kernel_id: String,
2881        kernel_type: String,
2882        state_data: Vec<u8>,
2883        step: u64,
2884    }
2885
2886    impl MockCheckpointableKernel {
2887        fn new(kernel_id: &str, state_size: usize) -> Self {
2888            Self {
2889                kernel_id: kernel_id.to_string(),
2890                kernel_type: "mock_kernel".to_string(),
2891                state_data: vec![0xAB; state_size],
2892                step: 1000,
2893            }
2894        }
2895    }
2896
2897    impl CheckpointableKernel for MockCheckpointableKernel {
2898        fn create_checkpoint(&self) -> Result<Checkpoint> {
2899            let checkpoint = CheckpointBuilder::new(&self.kernel_id, &self.kernel_type)
2900                .step(self.step)
2901                .grid_size(64, 64, 64)
2902                .control_block(vec![1, 2, 3, 4])
2903                .device_memory("state", self.state_data.clone())
2904                .build();
2905            Ok(checkpoint)
2906        }
2907
2908        fn restore_from_checkpoint(&mut self, checkpoint: &Checkpoint) -> Result<()> {
2909            self.step = checkpoint.metadata.current_step;
2910            Ok(())
2911        }
2912
2913        fn checkpoint_kernel_id(&self) -> &str {
2914            &self.kernel_id
2915        }
2916
2917        fn checkpoint_kernel_type(&self) -> &str {
2918            &self.kernel_type
2919        }
2920    }
2921
2922    #[test]
2923    fn test_migrator_creation() {
2924        let coord = MultiGpuBuilder::new().build();
2925        let migrator = KernelMigrator::new(coord);
2926
2927        let stats = migrator.stats();
2928        assert_eq!(stats.successful_migrations, 0);
2929        assert_eq!(stats.failed_migrations, 0);
2930        assert_eq!(stats.bytes_transferred, 0);
2931    }
2932
2933    #[test]
2934    fn test_migrator_with_custom_storage() {
2935        let coord = MultiGpuBuilder::new().build();
2936        let storage = Arc::new(MemoryStorage::new());
2937        let migrator = KernelMigrator::with_storage(coord.clone(), storage);
2938
2939        // Verify we can access the coordinator
2940        assert!(Arc::ptr_eq(migrator.coordinator(), &coord));
2941    }
2942
2943    #[test]
2944    fn test_successful_migration() {
2945        let coord = MultiGpuBuilder::new().build();
2946
2947        // Register devices
2948        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2949        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
2950
2951        // Assign kernel to device 0
2952        let kernel_id = KernelId::new("migratable_kernel");
2953        coord.assign_kernel(kernel_id.clone(), 0);
2954
2955        let migrator = KernelMigrator::new(coord.clone());
2956
2957        // Create mock kernel
2958        let kernel = MockCheckpointableKernel::new("migratable_kernel", 1024);
2959
2960        // Request migration
2961        let mut request = coord.request_migration(&kernel_id, 1).unwrap();
2962        assert_eq!(request.state, MigrationState::Pending);
2963
2964        // Perform migration
2965        let result = migrator
2966            .migrate_with_checkpoint(&kernel, &mut request)
2967            .unwrap();
2968
2969        // Verify result
2970        assert_eq!(result.kernel_id.as_str(), "migratable_kernel");
2971        assert_eq!(result.source_device, 0);
2972        assert_eq!(result.target_device, 1);
2973        assert!(result.checkpoint_size > 0);
2974        assert!(result.total_duration > Duration::ZERO);
2975
2976        // Verify kernel was moved
2977        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));
2978
2979        // Verify stats
2980        let stats = migrator.stats();
2981        assert_eq!(stats.successful_migrations, 1);
2982        assert_eq!(stats.failed_migrations, 0);
2983        assert!(stats.bytes_transferred > 0);
2984    }
2985
2986    #[test]
2987    fn test_migration_result_fields() {
2988        let coord = MultiGpuBuilder::new().build();
2989
2990        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2991        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
2992
2993        let kernel_id = KernelId::new("test_kernel");
2994        coord.assign_kernel(kernel_id.clone(), 0);
2995
2996        let migrator = KernelMigrator::new(coord.clone());
2997        let kernel = MockCheckpointableKernel::new("test_kernel", 4096);
2998        let mut request = coord.request_migration(&kernel_id, 1).unwrap();
2999
3000        let result = migrator
3001            .migrate_with_checkpoint(&kernel, &mut request)
3002            .unwrap();
3003
3004        // All durations should be non-negative
3005        assert!(result.checkpoint_duration >= Duration::ZERO);
3006        assert!(result.transfer_duration >= Duration::ZERO);
3007        assert!(result.restore_duration >= Duration::ZERO);
3008
        // Total should be >= sum of the measured phases
        let phases = result.checkpoint_duration + result.transfer_duration + result.restore_duration;
        assert!(result.total_duration >= phases);
3011    }
3012
3013    #[test]
3014    fn test_migration_stats_accumulate() {
3015        let coord = MultiGpuBuilder::new().build();
3016
3017        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
3018        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
3019
3020        let migrator = KernelMigrator::new(coord.clone());
3021
3022        // Migrate kernel 1: 0 -> 1
3023        let k1 = KernelId::new("k1");
3024        coord.assign_kernel(k1.clone(), 0);
3025        let kernel1 = MockCheckpointableKernel::new("k1", 1000);
3026        let mut req1 = coord.request_migration(&k1, 1).unwrap();
3027        migrator
3028            .migrate_with_checkpoint(&kernel1, &mut req1)
3029            .unwrap();
3030
3031        // Migrate kernel 2: 0 -> 1
3032        let k2 = KernelId::new("k2");
3033        coord.assign_kernel(k2.clone(), 0);
3034        let kernel2 = MockCheckpointableKernel::new("k2", 2000);
3035        let mut req2 = coord.request_migration(&k2, 1).unwrap();
3036        migrator
3037            .migrate_with_checkpoint(&kernel2, &mut req2)
3038            .unwrap();
3039
3040        let stats = migrator.stats();
3041        assert_eq!(stats.successful_migrations, 2);
3042        assert_eq!(stats.failed_migrations, 0);
3043        // Both checkpoints should have been transferred
3044        assert!(stats.bytes_transferred > 0);
3045    }

    // ========================================================================
    // Device Unregister Tests
    // ========================================================================

    #[test]
    fn test_unregister_device_no_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.device_index, 0);
        assert!(result.kernels_to_migrate.is_empty());
        assert!(result.orphaned_kernels.is_empty());
    }

    #[test]
    fn test_unregister_device_with_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        // Assign kernels to device 0
        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.kernels_to_migrate.len(), 2);
        assert!(result.orphaned_kernels.is_empty());

        // All kernels should migrate to device 1
        for plan in &result.kernels_to_migrate {
            assert_eq!(plan.source_device, 0);
            assert_eq!(plan.target_device, 1);
        }

        // Verify kernel mappings were updated
        assert_eq!(coord.get_kernel_device(&k1), Some(1));
        assert_eq!(coord.get_kernel_device(&k2), Some(1));
    }
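
    // A follow-on sketch (not in the original suite), assuming unregister
    // semantics compose: once kernels have migrated from device 0 to
    // device 1, removing device 1 as well should orphan them, because no
    // migration target remains.
    #[test]
    fn test_unregister_last_device_after_migration() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        coord.assign_kernel(k1.clone(), 0);

        // First unregister migrates the kernel to device 1.
        let first = coord.unregister_device(0);
        assert!(first.success);
        assert_eq!(coord.get_kernel_device(&k1), Some(1));

        // Second unregister leaves no device, so the kernel is orphaned.
        let second = coord.unregister_device(1);
        assert!(second.success);
        assert!(second.kernels_to_migrate.is_empty());
        assert_eq!(second.orphaned_kernels.len(), 1);
        assert!(coord.get_kernel_device(&k1).is_none());
    }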

    #[test]
    fn test_unregister_single_device_orphans_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        // Assign a kernel to device 0
        let k1 = KernelId::new("k1");
        coord.assign_kernel(k1.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert!(result.kernels_to_migrate.is_empty());
        assert_eq!(result.orphaned_kernels.len(), 1);
        assert_eq!(result.orphaned_kernels[0], k1);

        // Kernel should no longer have a device
        assert!(coord.get_kernel_device(&k1).is_none());
    }

    #[test]
    fn test_unregister_nonexistent_device() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let result = coord.unregister_device(99);

        assert!(!result.success);
        assert_eq!(result.device_index, 99);
    }

    #[test]
    fn test_unregister_distributes_to_least_loaded() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(2, "GPU 2".to_string(), Backend::Cuda));

        // Preload device 1 with kernels
        coord.assign_kernel(KernelId::new("pre1"), 1);
        coord.assign_kernel(KernelId::new("pre2"), 1);
        coord.assign_kernel(KernelId::new("pre3"), 1);

        // Assign the kernel under test to device 0
        let k1 = KernelId::new("migrate_me");
        coord.assign_kernel(k1.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.kernels_to_migrate.len(), 1);

        // Should migrate to device 2 (least loaded)
        let plan = &result.kernels_to_migrate[0];
        assert_eq!(plan.target_device, 2);
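
        // Added check (an assumption mirroring test_unregister_device_with_kernels):
        // the coordinator's kernel map should also reflect the chosen target.
        assert_eq!(coord.get_kernel_device(&k1), Some(2));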
    }

    #[test]
    fn test_migration_priority_enum() {
        let low = MigrationPriority::Low;
        let normal = MigrationPriority::Normal;
        let high = MigrationPriority::High;
        let critical = MigrationPriority::Critical;

        assert_ne!(low, normal);
        assert_ne!(normal, high);
        assert_ne!(high, critical);
        assert_eq!(low, MigrationPriority::Low);
    }

    // ========================================================================
    // Hot Reload Tests
    // ========================================================================

    #[test]
    fn test_hot_reload_config_default() {
        let config = HotReloadConfig::default();
        assert!(config.enabled);
        assert!(config.preserve_state);
        assert!(config.validate_before_swap);
        assert!(config.keep_fallback);
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_hot_reload_config_builder() {
        let config = HotReloadConfig::new()
            .with_enabled(false)
            .with_preserve_state(false)
            .with_max_retries(5)
            .with_timeout(Duration::from_secs(60));

        assert!(!config.enabled);
        assert!(!config.preserve_state);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.reload_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_kernel_code_source_ptx() {
        let ptx = ".version 7.0\n.target sm_80\nkernel: ret;";
        let code = KernelCodeSource::from_ptx(ptx, "kernel");

        assert_eq!(code.format, KernelCodeFormat::Ptx);
        assert_eq!(code.entry_point, "kernel");
        assert_eq!(code.as_str(), Some(ptx));
        assert_eq!(code.size(), ptx.len());
    }

    #[test]
    fn test_kernel_code_source_wgsl() {
        let wgsl = "@compute fn main() {}";
        let code = KernelCodeSource::from_wgsl(wgsl, "main");

        assert_eq!(code.format, KernelCodeFormat::Wgsl);
        assert_eq!(code.entry_point, "main");
        assert_eq!(code.as_str(), Some(wgsl));
    }

    #[test]
    fn test_kernel_code_source_msl() {
        let msl = "kernel void my_kernel() {}";
        let code = KernelCodeSource::from_msl(msl, "my_kernel");

        assert_eq!(code.format, KernelCodeFormat::Msl);
        assert_eq!(code.entry_point, "my_kernel");
        assert_eq!(code.as_str(), Some(msl));
    }
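
    // A companion sketch (not in the original suite): size() should track the
    // source length for the other text formats as well, assuming it behaves
    // the same way it does in the PTX test above.
    #[test]
    fn test_kernel_code_source_size_text_formats() {
        let wgsl = "@compute fn main() {}";
        let msl = "kernel void my_kernel() {}";

        assert_eq!(KernelCodeSource::from_wgsl(wgsl, "main").size(), wgsl.len());
        assert_eq!(KernelCodeSource::from_msl(msl, "my_kernel").size(), msl.len());
    }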

    #[test]
    fn test_hot_reload_manager_creation() {
        let manager = HotReloadManager::with_defaults();
        assert!(manager.is_enabled());
        assert!(manager.list_kernels().is_empty());
    }

    #[test]
    fn test_hot_reload_manager_register_kernel() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        manager.register_kernel(&kernel_id, code);

        assert!(manager.is_registered(&kernel_id));
        assert!(!manager.is_reload_in_progress(&kernel_id));
        assert!(manager.get_current_version(&kernel_id).is_some());
    }

    #[test]
    fn test_hot_reload_request_states() {
        let kernel_id = KernelId::new("test");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");
        let request = HotReloadRequest::new(kernel_id, code);

        assert_eq!(request.state, HotReloadState::Idle);
        assert!(!request.is_in_progress());
        assert!(!request.is_completed());
        assert!(!request.is_failed());
    }

    #[test]
    fn test_hot_reload_disabled() {
        let config = HotReloadConfig::new().with_enabled(false);
        let manager = HotReloadManager::new(config);
        let kernel_id = KernelId::new("test");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        manager.register_kernel(&kernel_id, code.clone());
        let result = manager.request_reload(&kernel_id, code);
        assert!(result.is_err());
    }
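
    // A companion sketch (not in the original suite): a manager built from a
    // disabled config should also report itself disabled through
    // is_enabled(), the accessor used in test_hot_reload_manager_creation.
    #[test]
    fn test_hot_reload_disabled_reports_not_enabled() {
        let config = HotReloadConfig::new().with_enabled(false);
        let manager = HotReloadManager::new(config);

        assert!(!manager.is_enabled());
    }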

    #[test]
    fn test_hot_reload_stats() {
        let manager = HotReloadManager::with_defaults();
        let stats = manager.stats();

        assert_eq!(stats.successful_reloads, 0);
        assert_eq!(stats.failed_reloads, 0);
        assert_eq!(stats.rollbacks, 0);
    }

    #[test]
    fn test_hot_reload_code_formats() {
        let formats = [
            KernelCodeFormat::Ptx,
            KernelCodeFormat::Cubin,
            KernelCodeFormat::SpirV,
            KernelCodeFormat::Wgsl,
            KernelCodeFormat::Msl,
            KernelCodeFormat::MetalLib,
            KernelCodeFormat::Source,
        ];

        // Verify all formats are distinct
        for (i, f1) in formats.iter().enumerate() {
            for (j, f2) in formats.iter().enumerate() {
                if i != j {
                    assert_ne!(f1, f2);
                }
            }
        }
    }

    #[test]
    fn test_hot_reload_state_transitions() {
        let states = [
            HotReloadState::Idle,
            HotReloadState::Draining,
            HotReloadState::Checkpointing,
            HotReloadState::Compiling,
            HotReloadState::Validating,
            HotReloadState::Swapping,
            HotReloadState::Restoring,
            HotReloadState::Completed,
            HotReloadState::Failed,
            HotReloadState::RollingBack,
        ];

        // Verify all states are distinct
        for (i, s1) in states.iter().enumerate() {
            for (j, s2) in states.iter().enumerate() {
                if i != j {
                    assert_ne!(s1, s2);
                }
            }
        }
    }

    #[test]
    fn test_hot_reload_execute() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");

        let initial_code = KernelCodeSource::from_ptx(".version 7.0\n.target sm_80", "kernel");
        manager.register_kernel(&kernel_id, initial_code);

        let new_code = KernelCodeSource::from_ptx(".version 8.0\n.target sm_90", "kernel");
        let mut request = manager.request_reload(&kernel_id, new_code).unwrap();

        // Create a mock kernel for the checkpoint
        let mock_kernel = MockCheckpointableKernel::new("test_kernel", 512);

        let result = manager.execute_reload(&mut request, &mock_kernel).unwrap();

        assert!(request.is_completed());
        assert_eq!(result.kernel_id.as_str(), "test_kernel");
        assert!(result.state_preserved);
        assert!(result.checkpoint_size > 0);
        assert!(result.total_duration > Duration::ZERO);

        // Stats should be updated
        let stats = manager.stats();
        assert_eq!(stats.successful_reloads, 1);
    }
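
    // A follow-up sketch (not in the original suite), reusing only the calls
    // exercised in test_hot_reload_execute: a second reload on the same
    // kernel should bump the success counter again and leave no reload
    // marked in progress. That the counter increments per reload is an
    // assumption based on the single-reload stats check above.
    #[test]
    fn test_hot_reload_execute_twice() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");
        manager.register_kernel(
            &kernel_id,
            KernelCodeSource::from_ptx(".version 7.0\n.target sm_80", "kernel"),
        );
        let mock_kernel = MockCheckpointableKernel::new("test_kernel", 512);

        for ptx in [".version 7.5\n.target sm_80", ".version 8.0\n.target sm_90"] {
            let code = KernelCodeSource::from_ptx(ptx, "kernel");
            let mut request = manager.request_reload(&kernel_id, code).unwrap();
            manager.execute_reload(&mut request, &mock_kernel).unwrap();
            assert!(request.is_completed());
        }

        assert_eq!(manager.stats().successful_reloads, 2);
        assert!(!manager.is_reload_in_progress(&kernel_id));
    }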

    #[test]
    fn test_hot_reload_list_kernels() {
        let manager = HotReloadManager::with_defaults();

        let k1 = KernelId::new("kernel1");
        let k2 = KernelId::new("kernel2");
        let k3 = KernelId::new("kernel3");

        manager.register_kernel(&k1, KernelCodeSource::from_ptx(".version 7.0", "k1"));
        manager.register_kernel(&k2, KernelCodeSource::from_ptx(".version 7.0", "k2"));
        manager.register_kernel(&k3, KernelCodeSource::from_ptx(".version 7.0", "k3"));

        let kernels = manager.list_kernels();
        assert_eq!(kernels.len(), 3);
        assert!(kernels.contains(&k1));
        assert!(kernels.contains(&k2));
        assert!(kernels.contains(&k3));
    }
}