// ringkernel_core/multi_gpu.rs

//! Multi-GPU coordination, topology discovery, and cross-GPU messaging.
//!
//! This module provides infrastructure for coordinating work across
//! multiple GPUs, including:
//!
//! - **Device Selection** - Load balancing strategies for kernel placement
//! - **Topology Discovery** - NVLink/PCIe detection and bandwidth estimation
//! - **Cross-GPU K2K Router** - Kernel-to-kernel messaging across GPUs
//! - **Kernel Migration** - Move kernels between GPUs with state transfer
//!
//! ## Example
//!
//! ```ignore
//! use ringkernel_core::multi_gpu::{
//!     CrossGpuK2KRouter, GpuTopology, LoadBalancingStrategy, MultiGpuBuilder,
//! };
//!
//! let coordinator = MultiGpuBuilder::new()
//!     .load_balancing(LoadBalancingStrategy::LeastLoaded)
//!     .enable_p2p(true)
//!     .build();
//!
//! // Discover topology
//! let topology = coordinator.discover_topology();
//!
//! // Create a cross-GPU router. `route_message` is synchronous and returns a
//! // `RoutingDecision` describing how the message should be delivered.
//! let router = CrossGpuK2KRouter::new(coordinator.clone());
//! let decision = router.route_message(&source_kernel, &dest_kernel, envelope)?;
//! ```

use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use crate::error::{Result, RingKernelError};
use crate::k2k::K2KMessage;
use crate::runtime::{Backend, KernelId, LaunchOptions};

/// Configuration for multi-GPU coordination.
#[derive(Debug, Clone)]
pub struct MultiGpuConfig {
    /// Load balancing strategy.
    pub load_balancing: LoadBalancingStrategy,
    /// Enable automatic device selection.
    pub auto_select_device: bool,
    /// Maximum kernels per device.
    pub max_kernels_per_device: usize,
    /// Enable peer-to-peer transfers when available.
    pub enable_p2p: bool,
    /// Preferred devices (by index).
    pub preferred_devices: Vec<usize>,
}

impl Default for MultiGpuConfig {
    fn default() -> Self {
        Self {
            load_balancing: LoadBalancingStrategy::LeastLoaded,
            auto_select_device: true,
            max_kernels_per_device: 64,
            enable_p2p: true,
            preferred_devices: vec![],
        }
    }
}

/// Strategy for balancing load across devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Always use the first available device.
    FirstAvailable,
    /// Use the device with fewest kernels.
    LeastLoaded,
    /// Round-robin across devices.
    RoundRobin,
    /// Select based on memory availability.
    MemoryBased,
    /// Select based on compute capability.
    ComputeCapability,
    /// Custom selection function.
    Custom,
}

/// Information about a GPU device.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Device index.
    pub index: usize,
    /// Device name.
    pub name: String,
    /// Backend type.
    pub backend: Backend,
    /// Total memory in bytes.
    pub total_memory: u64,
    /// Available memory in bytes.
    pub available_memory: u64,
    /// Compute capability (for CUDA).
    pub compute_capability: Option<(u32, u32)>,
    /// Maximum threads per block.
    pub max_threads_per_block: u32,
    /// Number of multiprocessors.
    pub multiprocessor_count: u32,
    /// Whether device supports P2P with other devices.
    pub p2p_capable: bool,
}

impl DeviceInfo {
    /// Create a new device info.
    pub fn new(index: usize, name: String, backend: Backend) -> Self {
        Self {
            index,
            name,
            backend,
            total_memory: 0,
            available_memory: 0,
            compute_capability: None,
            max_threads_per_block: 1024,
            multiprocessor_count: 1,
            p2p_capable: false,
        }
    }

    /// Get memory utilization (0.0-1.0).
    pub fn memory_utilization(&self) -> f64 {
        if self.total_memory == 0 {
            0.0
        } else {
            1.0 - (self.available_memory as f64 / self.total_memory as f64)
        }
    }
}

// ============================================================================
// GPU Topology Discovery
// ============================================================================

/// Type of interconnect between GPUs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterconnectType {
    /// No direct connection (must go through host).
    None,
    /// PCIe peer-to-peer.
    Pcie,
    /// NVIDIA NVLink.
    NvLink,
    /// NVIDIA NVSwitch (datacenter).
    NvSwitch,
    /// AMD Infinity Fabric.
    InfinityFabric,
    /// Intel Xe Link.
    XeLink,
    /// Same GPU (for self-connections).
    SameDevice,
}

impl InterconnectType {
    /// Estimated bandwidth in GB/s for this interconnect type.
    pub fn estimated_bandwidth_gbps(&self) -> f64 {
        match self {
            InterconnectType::None => 16.0,      // PCIe 3.0 x16 through host
            InterconnectType::Pcie => 32.0,      // PCIe 4.0 x16 P2P
            InterconnectType::NvLink => 300.0,   // NVLink 3.0 (A100)
            InterconnectType::NvSwitch => 600.0, // NVSwitch full bisection
            InterconnectType::InfinityFabric => 200.0, // MI250X
            InterconnectType::XeLink => 100.0,   // Intel Data Center GPUs
            InterconnectType::SameDevice => 2000.0, // Internal bandwidth
        }
    }

    /// Estimated latency in microseconds.
    pub fn estimated_latency_us(&self) -> f64 {
        match self {
            InterconnectType::None => 10.0,    // Through host memory
            InterconnectType::Pcie => 5.0,     // P2P PCIe
            InterconnectType::NvLink => 1.0,   // Direct NVLink
            InterconnectType::NvSwitch => 2.0, // Through switch
            InterconnectType::InfinityFabric => 1.5,
            InterconnectType::XeLink => 2.0,
            InterconnectType::SameDevice => 0.0,
        }
    }

    /// Whether this interconnect supports direct P2P memory access.
    pub fn supports_p2p(&self) -> bool {
        !matches!(self, InterconnectType::None)
    }
}

/// Connection between two GPUs.
#[derive(Debug, Clone)]
pub struct GpuConnection {
    /// Source device index.
    pub source: usize,
    /// Destination device index.
    pub destination: usize,
    /// Type of interconnect.
    pub interconnect: InterconnectType,
    /// Measured or estimated bandwidth in GB/s.
    pub bandwidth_gbps: f64,
    /// Measured or estimated latency in microseconds.
    pub latency_us: f64,
    /// Whether connection is bidirectional with same characteristics.
    pub bidirectional: bool,
    /// Number of hops (for multi-hop topologies).
    pub hops: u32,
}

impl GpuConnection {
    /// Create a new GPU connection.
    pub fn new(source: usize, destination: usize, interconnect: InterconnectType) -> Self {
        Self {
            source,
            destination,
            interconnect,
            bandwidth_gbps: interconnect.estimated_bandwidth_gbps(),
            latency_us: interconnect.estimated_latency_us(),
            bidirectional: true,
            hops: if source == destination { 0 } else { 1 },
        }
    }

    /// Set measured bandwidth.
    pub fn with_bandwidth(mut self, gbps: f64) -> Self {
        self.bandwidth_gbps = gbps;
        self
    }

    /// Set measured latency.
    pub fn with_latency(mut self, us: f64) -> Self {
        self.latency_us = us;
        self
    }

    /// Set hop count.
    pub fn with_hops(mut self, hops: u32) -> Self {
        self.hops = hops;
        self
    }
}

/// GPU topology graph describing all device interconnections.
#[derive(Debug, Clone)]
pub struct GpuTopology {
    /// Number of devices in topology.
    pub device_count: usize,
    /// Connection matrix (device_count x device_count).
    connections: Vec<Vec<Option<GpuConnection>>>,
    /// NUMA node assignments for each device.
    pub numa_nodes: Vec<Option<u32>>,
    /// Whether topology has been probed (vs estimated).
    pub probed: bool,
    /// Timestamp of last topology update.
    pub last_updated: Instant,
}

impl GpuTopology {
    /// Create a new topology for N devices.
    pub fn new(device_count: usize) -> Self {
        let mut connections = vec![vec![None; device_count]; device_count];

        // Initialize self-connections
        for (i, row) in connections.iter_mut().enumerate().take(device_count) {
            row[i] = Some(GpuConnection::new(i, i, InterconnectType::SameDevice));
        }

        Self {
            device_count,
            connections,
            numa_nodes: vec![None; device_count],
            probed: false,
            last_updated: Instant::now(),
        }
    }

    /// Set connection between two devices.
    pub fn set_connection(&mut self, connection: GpuConnection) {
        let src = connection.source;
        let dst = connection.destination;
        if src < self.device_count && dst < self.device_count {
            self.connections[src][dst] = Some(connection.clone());
            if connection.bidirectional && src != dst {
                let reverse = GpuConnection {
                    source: dst,
                    destination: src,
                    ..connection
                };
                self.connections[dst][src] = Some(reverse);
            }
        }
    }

    /// Get connection between two devices.
    pub fn get_connection(&self, source: usize, destination: usize) -> Option<&GpuConnection> {
        self.connections
            .get(source)
            .and_then(|row| row.get(destination))
            .and_then(|c| c.as_ref())
    }

    /// Get best path between two devices (returns intermediate hops).
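    ///
    /// A minimal sketch with illustrative indices: with links 0-1 and 1-2 but
    /// no direct 0-2 link, the two-hop route through device 1 wins.
    ///
    /// ```ignore
    /// let mut topo = GpuTopology::new(3);
    /// topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
    /// topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
    /// assert_eq!(topo.best_path(0, 2), vec![0, 1, 2]);
    /// ```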
    pub fn best_path(&self, source: usize, destination: usize) -> Vec<usize> {
        if source == destination {
            return vec![source];
        }

        // Direct connection available?
        if let Some(conn) = self.get_connection(source, destination) {
            if conn.interconnect != InterconnectType::None {
                return vec![source, destination];
            }
        }

        // Search two-hop paths through a single intermediate device (a
        // simplified stand-in for a full shortest-path search).
        let mut best_path = vec![source, destination]; // Default to direct
        let mut best_bandwidth = 0.0;

        // Check all intermediate nodes
        for intermediate in 0..self.device_count {
            if intermediate == source || intermediate == destination {
                continue;
            }

            if let (Some(c1), Some(c2)) = (
                self.get_connection(source, intermediate),
                self.get_connection(intermediate, destination),
            ) {
                // Bandwidth limited by slowest link
                let path_bandwidth = c1.bandwidth_gbps.min(c2.bandwidth_gbps);
                if path_bandwidth > best_bandwidth {
                    best_bandwidth = path_bandwidth;
                    best_path = vec![source, intermediate, destination];
                }
            }
        }

        best_path
    }

    /// Get all devices directly connected to a device.
    pub fn neighbors(&self, device: usize) -> Vec<usize> {
        if device >= self.device_count {
            return vec![];
        }

        self.connections[device]
            .iter()
            .enumerate()
            .filter_map(|(i, conn)| {
                if i != device
                    && conn
                        .as_ref()
                        .map(|c| c.interconnect.supports_p2p())
                        .unwrap_or(false)
                {
                    Some(i)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Calculate total bisection bandwidth of the topology.
    pub fn bisection_bandwidth_gbps(&self) -> f64 {
        let half = self.device_count / 2;
        if half == 0 {
            return 0.0;
        }

        let mut total = 0.0;
        for src in 0..half {
            for dst in half..self.device_count {
                if let Some(conn) = self.get_connection(src, dst) {
                    total += conn.bandwidth_gbps;
                }
            }
        }
        total
    }

    /// Check if all devices have P2P connectivity.
    pub fn is_fully_connected(&self) -> bool {
        for src in 0..self.device_count {
            for dst in 0..self.device_count {
                if src != dst {
                    if let Some(conn) = self.get_connection(src, dst) {
                        if !conn.interconnect.supports_p2p() {
                            return false;
                        }
                    } else {
                        return false;
                    }
                }
            }
        }
        true
    }

    /// Get devices in the same NUMA domain.
    pub fn numa_neighbors(&self, device: usize) -> Vec<usize> {
        let target_numa = self.numa_nodes.get(device).copied().flatten();
        if target_numa.is_none() {
            return vec![];
        }

        self.numa_nodes
            .iter()
            .enumerate()
            .filter_map(|(i, numa)| {
                if i != device && *numa == target_numa {
                    Some(i)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Set NUMA node for a device.
    pub fn set_numa_node(&mut self, device: usize, numa_node: u32) {
        if device < self.numa_nodes.len() {
            self.numa_nodes[device] = Some(numa_node);
        }
    }

    /// Mark topology as probed (not estimated).
    pub fn mark_probed(&mut self) {
        self.probed = true;
        self.last_updated = Instant::now();
    }
}

/// Status of a device in the multi-GPU coordinator.
#[derive(Debug, Clone)]
pub struct DeviceStatus {
    /// Device info.
    pub info: DeviceInfo,
    /// Number of kernels running on this device.
    pub kernel_count: usize,
    /// Kernels running on this device.
    pub kernels: Vec<KernelId>,
    /// Whether device is available for new kernels.
    pub available: bool,
    /// Current load estimate (0.0-1.0).
    pub load: f64,
}

/// Result of unregistering a device from the coordinator.
#[derive(Debug, Clone)]
pub struct DeviceUnregisterResult {
    /// Index of the unregistered device.
    pub device_index: usize,
    /// Kernels that were on this device and need migration.
    pub kernels_to_migrate: Vec<KernelMigrationPlan>,
    /// Kernels that could not be migrated (no available target).
    pub orphaned_kernels: Vec<KernelId>,
    /// Whether the device was successfully unregistered.
    pub success: bool,
}

/// Plan for migrating a single kernel during device unregister.
#[derive(Debug, Clone)]
pub struct KernelMigrationPlan {
    /// Kernel to migrate.
    pub kernel_id: KernelId,
    /// Source device (the unregistered device).
    pub source_device: usize,
    /// Target device selected for migration.
    pub target_device: usize,
    /// Estimated migration priority (based on kernel load).
    pub priority: MigrationPriority,
}

/// Priority for kernel migration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationPriority {
    /// Low priority - can be migrated lazily.
    Low,
    /// Normal priority - migrate in reasonable time.
    Normal,
    /// High priority - migrate as soon as possible.
    High,
    /// Critical - must migrate immediately.
    Critical,
}

/// Multi-GPU coordinator for managing kernels across devices.
pub struct MultiGpuCoordinator {
    /// Configuration.
    config: MultiGpuConfig,
    /// Available devices.
    devices: RwLock<Vec<DeviceInfo>>,
    /// Kernel-to-device mapping.
    kernel_device_map: RwLock<HashMap<KernelId, usize>>,
    /// Device kernel counts.
    device_kernel_counts: RwLock<Vec<AtomicUsize>>,
    /// Round-robin counter.
    round_robin_counter: AtomicUsize,
    /// Total kernels launched.
    total_kernels: AtomicU64,
    /// Device selection callbacks (for custom strategy).
    #[allow(clippy::type_complexity)]
    custom_selector:
        RwLock<Option<Arc<dyn Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync>>>,
    /// GPU topology graph.
    topology: RwLock<Option<GpuTopology>>,
}

impl MultiGpuCoordinator {
    /// Create a new multi-GPU coordinator.
    pub fn new(config: MultiGpuConfig) -> Arc<Self> {
        Arc::new(Self {
            config,
            devices: RwLock::new(Vec::new()),
            kernel_device_map: RwLock::new(HashMap::new()),
            device_kernel_counts: RwLock::new(Vec::new()),
            round_robin_counter: AtomicUsize::new(0),
            total_kernels: AtomicU64::new(0),
            custom_selector: RwLock::new(None),
            topology: RwLock::new(None),
        })
    }

    /// Register a device with the coordinator.
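    ///
    /// A sketch with illustrative values:
    ///
    /// ```ignore
    /// let mut info = DeviceInfo::new(0, "NVIDIA A100".to_string(), Backend::Cuda);
    /// info.compute_capability = Some((8, 0));
    /// info.p2p_capable = true;
    /// coordinator.register_device(info);
    /// ```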
    pub fn register_device(&self, device: DeviceInfo) {
        let index = device.index;
        let mut devices = self.devices.write();
        let mut counts = self.device_kernel_counts.write();

        // Ensure we have enough slots
        let mut current_len = devices.len();
        while current_len <= index {
            devices.push(DeviceInfo::new(
                current_len,
                "Unknown".to_string(),
                Backend::Cpu,
            ));
            counts.push(AtomicUsize::new(0));
            current_len += 1;
        }

        devices[index] = device;
    }

    /// Unregister a device and plan kernel migrations.
    ///
    /// This method:
    /// 1. Identifies all kernels on the device being removed
    /// 2. Finds target devices for each kernel using load balancing
    /// 3. Creates migration plans for kernels that can be moved
    /// 4. Marks orphaned kernels that have no migration target
    /// 5. Updates internal routing tables
    ///
    /// The caller is responsible for executing the actual migrations using
    /// [`KernelMigrator`] with the returned [`KernelMigrationPlan`] entries.
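    ///
    /// A usage sketch (the device index is illustrative):
    ///
    /// ```ignore
    /// let result = coordinator.unregister_device(0);
    /// for plan in &result.kernels_to_migrate {
    ///     println!(
    ///         "migrate {} from GPU {} to GPU {}",
    ///         plan.kernel_id.as_str(),
    ///         plan.source_device,
    ///         plan.target_device
    ///     );
    /// }
    /// // Kernels with no viable target end up in `result.orphaned_kernels`.
    /// ```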
    pub fn unregister_device(&self, index: usize) -> DeviceUnregisterResult {
        let devices = self.devices.read();

        // Check if device exists
        if index >= devices.len() {
            return DeviceUnregisterResult {
                device_index: index,
                kernels_to_migrate: Vec::new(),
                orphaned_kernels: Vec::new(),
                success: false,
            };
        }

        // Get all kernels on this device
        let kernels_on_device = self.kernels_on_device(index);

        // Find available target devices (excluding the one being unregistered)
        let available_targets: Vec<usize> = devices
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != index)
            .map(|(i, _)| i)
            .collect();

        drop(devices); // Release read lock before acquiring write lock

        let mut kernels_to_migrate = Vec::new();
        let mut orphaned_kernels = Vec::new();

        if available_targets.is_empty() {
            // No other devices available - all kernels are orphaned
            orphaned_kernels = kernels_on_device;
        } else {
            // Plan migrations for each kernel
            for kernel_id in kernels_on_device {
                // Select target based on current load
                if let Some(target) = self.select_migration_target(&available_targets) {
                    let priority = self.calculate_migration_priority(&kernel_id);
                    kernels_to_migrate.push(KernelMigrationPlan {
                        kernel_id,
                        source_device: index,
                        target_device: target,
                        priority,
                    });
                } else {
                    orphaned_kernels.push(kernel_id);
                }
            }
        }

        // Update kernel-device mappings for planned migrations
        {
            let mut kernel_map = self.kernel_device_map.write();
            let counts = self.device_kernel_counts.read();

            for plan in &kernels_to_migrate {
                // Update mapping to target device
                kernel_map.insert(plan.kernel_id.clone(), plan.target_device);

                // Update kernel counts
                if index < counts.len() {
                    counts[index].fetch_sub(1, Ordering::Relaxed);
                }
                if plan.target_device < counts.len() {
                    counts[plan.target_device].fetch_add(1, Ordering::Relaxed);
                }
            }

            // Remove orphaned kernels from mapping
            for kernel_id in &orphaned_kernels {
                kernel_map.remove(kernel_id);
                if index < counts.len() {
                    counts[index].fetch_sub(1, Ordering::Relaxed);
                }
            }
        }

        // Mark device as unavailable (but don't remove it to preserve indices)
        {
            let mut devices = self.devices.write();
            if index < devices.len() {
                devices[index].available_memory = 0;
                devices[index].name = format!("{} (unregistered)", devices[index].name);
            }
        }

        DeviceUnregisterResult {
            device_index: index,
            kernels_to_migrate,
            orphaned_kernels,
            success: true,
        }
    }

    /// Select the best target device for migration.
    fn select_migration_target(&self, candidates: &[usize]) -> Option<usize> {
        if candidates.is_empty() {
            return None;
        }

        let counts = self.device_kernel_counts.read();

        // Find device with lowest kernel count
        candidates
            .iter()
            .filter_map(|&idx| {
                if idx < counts.len() {
                    Some((idx, counts[idx].load(Ordering::Relaxed)))
                } else {
                    None
                }
            })
            .min_by_key(|(_, count)| *count)
            .map(|(idx, _)| idx)
    }

    /// Calculate migration priority for a kernel.
    fn calculate_migration_priority(&self, _kernel_id: &KernelId) -> MigrationPriority {
        // In a real implementation, this would check:
        // - Message queue depth
        // - Time since last activity
        // - Kernel type/importance
        // For now, use normal priority
        MigrationPriority::Normal
    }

    /// Get all registered devices.
    pub fn devices(&self) -> Vec<DeviceInfo> {
        self.devices.read().clone()
    }

    /// Get device info by index.
    pub fn device(&self, index: usize) -> Option<DeviceInfo> {
        self.devices.read().get(index).cloned()
    }

    /// Get number of devices.
    pub fn device_count(&self) -> usize {
        self.devices.read().len()
    }

    /// Select a device for launching a kernel.
    pub fn select_device(&self, options: &LaunchOptions) -> Result<usize> {
        let devices = self.devices.read();
        if devices.is_empty() {
            return Err(RingKernelError::BackendUnavailable(
                "No GPU devices available".to_string(),
            ));
        }

        // Get current status
        let status = self.get_all_status();

        // Check for custom selector
        if self.config.load_balancing == LoadBalancingStrategy::Custom {
            if let Some(selector) = &*self.custom_selector.read() {
                return Ok(selector(&status, options));
            }
        }

        // Apply preferred devices filter if specified
        let candidates: Vec<_> = if !self.config.preferred_devices.is_empty() {
            status
                .into_iter()
                .filter(|s| self.config.preferred_devices.contains(&s.info.index))
                .collect()
        } else {
            status
        };

        if candidates.is_empty() {
            return Err(RingKernelError::BackendUnavailable(
                "No suitable GPU device available".to_string(),
            ));
        }

        let selected = match self.config.load_balancing {
            LoadBalancingStrategy::FirstAvailable => {
                candidates.first().map(|s| s.info.index).unwrap_or(0)
            }
            LoadBalancingStrategy::LeastLoaded => candidates
                .iter()
                .filter(|s| s.available && s.kernel_count < self.config.max_kernels_per_device)
                .min_by(|a, b| a.kernel_count.cmp(&b.kernel_count))
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::RoundRobin => {
                let available: Vec<_> = candidates.iter().filter(|s| s.available).collect();

                if available.is_empty() {
                    candidates.first().map(|s| s.info.index).unwrap_or(0)
                } else {
                    let idx =
                        self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % available.len();
                    available[idx].info.index
                }
            }
            LoadBalancingStrategy::MemoryBased => candidates
                .iter()
                .filter(|s| s.available)
                .max_by(|a, b| a.info.available_memory.cmp(&b.info.available_memory))
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::ComputeCapability => candidates
                .iter()
                .filter(|s| s.available)
                .max_by(|a, b| {
                    let a_cap = a.info.compute_capability.unwrap_or((0, 0));
                    let b_cap = b.info.compute_capability.unwrap_or((0, 0));
                    a_cap.cmp(&b_cap)
                })
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::Custom => {
                // A custom selector was requested but none is registered;
                // fall back to the first device.
                0
            }
        };

        Ok(selected)
    }

    /// Assign a kernel to a device.
    pub fn assign_kernel(&self, kernel_id: KernelId, device_index: usize) {
        self.kernel_device_map
            .write()
            .insert(kernel_id, device_index);

        let counts = self.device_kernel_counts.read();
        if device_index < counts.len() {
            counts[device_index].fetch_add(1, Ordering::Relaxed);
        }

        self.total_kernels.fetch_add(1, Ordering::Relaxed);
    }

    /// Remove a kernel assignment.
    pub fn remove_kernel(&self, kernel_id: &KernelId) {
        if let Some(device_index) = self.kernel_device_map.write().remove(kernel_id) {
            let counts = self.device_kernel_counts.read();
            if device_index < counts.len() {
                counts[device_index].fetch_sub(1, Ordering::Relaxed);
            }
        }
    }

    /// Get device for a kernel.
    pub fn get_kernel_device(&self, kernel_id: &KernelId) -> Option<usize> {
        self.kernel_device_map.read().get(kernel_id).copied()
    }

    /// Get all kernels on a device.
    pub fn kernels_on_device(&self, device_index: usize) -> Vec<KernelId> {
        self.kernel_device_map
            .read()
            .iter()
            .filter(|(_, &idx)| idx == device_index)
            .map(|(k, _)| k.clone())
            .collect()
    }

    /// Get status of all devices.
    pub fn get_all_status(&self) -> Vec<DeviceStatus> {
        let devices = self.devices.read();
        let kernel_map = self.kernel_device_map.read();
        let counts = self.device_kernel_counts.read();

        devices
            .iter()
            .enumerate()
            .map(|(idx, info)| {
                let kernel_count = if idx < counts.len() {
                    counts[idx].load(Ordering::Relaxed)
                } else {
                    0
                };

                let kernels: Vec<_> = kernel_map
                    .iter()
                    .filter(|(_, &dev_idx)| dev_idx == idx)
                    .map(|(k, _)| k.clone())
                    .collect();

                let load = kernel_count as f64 / self.config.max_kernels_per_device as f64;
                let available = kernel_count < self.config.max_kernels_per_device;

                DeviceStatus {
                    info: info.clone(),
                    kernel_count,
                    kernels,
                    available,
                    load,
                }
            })
            .collect()
    }

    /// Get status of a specific device.
    pub fn get_device_status(&self, device_index: usize) -> Option<DeviceStatus> {
        self.get_all_status().into_iter().nth(device_index)
    }

    /// Set custom device selector.
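    ///
    /// Only consulted when [`LoadBalancingStrategy::Custom`] is configured.
    /// A sketch that picks the least-loaded device:
    ///
    /// ```ignore
    /// coordinator.set_custom_selector(|status, _options| {
    ///     status
    ///         .iter()
    ///         .min_by(|a, b| a.load.total_cmp(&b.load))
    ///         .map(|s| s.info.index)
    ///         .unwrap_or(0)
    /// });
    /// ```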
    pub fn set_custom_selector<F>(&self, selector: F)
    where
        F: Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync + 'static,
    {
        *self.custom_selector.write() = Some(Arc::new(selector));
    }

    /// Get coordinator statistics.
    pub fn stats(&self) -> MultiGpuStats {
        let status = self.get_all_status();
        let total_kernels: usize = status.iter().map(|s| s.kernel_count).sum();
        let total_memory: u64 = status.iter().map(|s| s.info.total_memory).sum();
        let available_memory: u64 = status.iter().map(|s| s.info.available_memory).sum();

        MultiGpuStats {
            device_count: status.len(),
            total_kernels,
            total_memory,
            available_memory,
            kernels_launched: self.total_kernels.load(Ordering::Relaxed),
        }
    }

    /// Check if P2P is available between two devices.
    pub fn can_p2p(&self, device_a: usize, device_b: usize) -> bool {
        if !self.config.enable_p2p {
            return false;
        }

        let devices = self.devices.read();
        if let (Some(a), Some(b)) = (devices.get(device_a), devices.get(device_b)) {
            a.p2p_capable && b.p2p_capable
        } else {
            false
        }
    }

    /// Update device memory info.
    pub fn update_device_memory(&self, device_index: usize, available_memory: u64) {
        let mut devices = self.devices.write();
        if let Some(device) = devices.get_mut(device_index) {
            device.available_memory = available_memory;
        }
    }

    // ========================================================================
    // Topology Discovery
    // ========================================================================

    /// Discover GPU topology (estimates if probing not available).
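    ///
    /// A sketch of inspecting the estimated topology:
    ///
    /// ```ignore
    /// let topo = coordinator.discover_topology();
    /// if topo.is_fully_connected() {
    ///     println!("bisection bandwidth: {} GB/s", topo.bisection_bandwidth_gbps());
    /// }
    /// ```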
    pub fn discover_topology(&self) -> GpuTopology {
        let devices = self.devices.read();
        let device_count = devices.len();

        if device_count == 0 {
            return GpuTopology::new(0);
        }

        let mut topo = GpuTopology::new(device_count);

        // Set up connections based on device info
        for (i, dev_i) in devices.iter().enumerate() {
            for (j, dev_j) in devices.iter().enumerate() {
                if i == j {
                    continue;
                }

                // Determine interconnect type based on device capabilities
                let interconnect = if dev_i.p2p_capable && dev_j.p2p_capable {
                    // Check if same backend (can do P2P)
                    if dev_i.backend == dev_j.backend {
                        match dev_i.backend {
                            Backend::Cuda => {
                                // For CUDA, use compute capability as an NVLink heuristic
                                let cc_i = dev_i.compute_capability.unwrap_or((0, 0));
                                let cc_j = dev_j.compute_capability.unwrap_or((0, 0));

                                // Ampere and newer (compute capability 8.x+) likely have NVLink
                                if cc_i.0 >= 8 && cc_j.0 >= 8 {
                                    InterconnectType::NvLink
                                } else {
                                    InterconnectType::Pcie
                                }
                            }
                            _ => InterconnectType::Pcie,
                        }
                    } else {
                        InterconnectType::None
                    }
                } else {
                    InterconnectType::None
                };

                topo.set_connection(GpuConnection::new(i, j, interconnect));
            }
        }

        // Store topology
        *self.topology.write() = Some(topo.clone());

        topo
    }

    /// Get current topology (discovers if not cached).
    pub fn topology(&self) -> GpuTopology {
        {
            let topo = self.topology.read();
            if let Some(ref t) = *topo {
                return t.clone();
            }
        }
        self.discover_topology()
    }

    /// Set custom topology (for testing or manual configuration).
    pub fn set_topology(&self, topology: GpuTopology) {
        *self.topology.write() = Some(topology);
    }

    /// Get best device for communicating with a source kernel.
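    ///
    /// A placement sketch (`producer_id` and `consumer_id` are hypothetical):
    ///
    /// ```ignore
    /// // Place the consumer near the producer so K2K traffic stays on fast links.
    /// let device = coordinator.select_device_for_k2k(&producer_id)?;
    /// coordinator.assign_kernel(consumer_id, device);
    /// ```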
    pub fn select_device_for_k2k(&self, source_kernel: &KernelId) -> Result<usize> {
        let source_device = self.get_kernel_device(source_kernel);
        if source_device.is_none() {
            return self.select_device(&LaunchOptions::default());
        }

        let source_idx = source_device.unwrap();
        let topo = self.topology();
        let status = self.get_all_status();

        // Find best device based on connectivity and load
        let neighbors = topo.neighbors(source_idx);

        if neighbors.is_empty() {
            // No P2P neighbors, fall back to normal selection
            return self.select_device(&LaunchOptions::default());
        }

        // Score devices by: connectivity bandwidth / (load + 0.1)
        let best = neighbors
            .iter()
            .filter_map(|&dev_idx| {
                status.iter().find(|s| s.info.index == dev_idx).map(|s| {
                    let conn = topo.get_connection(source_idx, dev_idx);
                    let bandwidth = conn.map(|c| c.bandwidth_gbps).unwrap_or(1.0);
                    let score = bandwidth / (s.load + 0.1);
                    (dev_idx, score)
                })
            })
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx);

        best.ok_or_else(|| {
            RingKernelError::BackendUnavailable("No suitable K2K device found".to_string())
        })
    }

    // ========================================================================
    // Kernel Migration
    // ========================================================================

    /// Request to migrate a kernel to another device.
    pub fn request_migration(
        &self,
        kernel_id: &KernelId,
        target_device: usize,
    ) -> Result<MigrationRequest> {
        let source_device = self
            .get_kernel_device(kernel_id)
            .ok_or_else(|| RingKernelError::KernelNotFound(kernel_id.as_str().to_string()))?;

        if source_device == target_device {
            return Err(RingKernelError::InvalidConfig(
                "Cannot migrate to same device".to_string(),
            ));
        }

        let devices = self.devices.read();
        if target_device >= devices.len() {
            return Err(RingKernelError::DeviceNotAvailable(format!(
                "Device {} not available",
                target_device
            )));
        }

        let topo = self.topology();
        let path = topo.best_path(source_device, target_device);
        let connection = topo.get_connection(source_device, target_device);

        Ok(MigrationRequest {
            kernel_id: kernel_id.clone(),
            source_device,
            target_device,
            path,
            estimated_bandwidth_gbps: connection.map(|c| c.bandwidth_gbps).unwrap_or(16.0),
            estimated_latency_us: connection.map(|c| c.latency_us).unwrap_or(10.0),
            state: MigrationState::Pending,
            started_at: None,
        })
    }

    /// Complete a migration (updates internal mappings).
    pub fn complete_migration(&self, request: &MigrationRequest) -> Result<()> {
        // Update kernel-device mapping
        {
            let mut map = self.kernel_device_map.write();
            if let Some(dev) = map.get_mut(&request.kernel_id) {
                *dev = request.target_device;
            }
        }

        // Update kernel counts
        {
            let counts = self.device_kernel_counts.read();
            if request.source_device < counts.len() {
                counts[request.source_device].fetch_sub(1, Ordering::Relaxed);
            }
            if request.target_device < counts.len() {
                counts[request.target_device].fetch_add(1, Ordering::Relaxed);
            }
        }

        Ok(())
    }
}

// ============================================================================
// Kernel Migration Types
// ============================================================================

/// State of a kernel migration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationState {
    /// Migration is pending, not yet started.
    Pending,
    /// Kernel is being quiesced (draining messages).
    Quiescing,
    /// Checkpoint is being created.
    Checkpointing,
    /// State is being transferred.
    Transferring,
    /// Kernel is being restored on target.
    Restoring,
    /// Migration completed successfully.
    Completed,
    /// Migration failed.
    Failed,
    /// Migration was cancelled.
    Cancelled,
}

/// Request to migrate a kernel between devices.
#[derive(Debug, Clone)]
pub struct MigrationRequest {
    /// Kernel to migrate.
    pub kernel_id: KernelId,
    /// Source device index.
    pub source_device: usize,
    /// Target device index.
    pub target_device: usize,
    /// Path of devices for multi-hop migration.
    pub path: Vec<usize>,
    /// Estimated bandwidth for transfer.
    pub estimated_bandwidth_gbps: f64,
    /// Estimated latency.
    pub estimated_latency_us: f64,
    /// Current state.
    pub state: MigrationState,
    /// When migration started.
    pub started_at: Option<Instant>,
}

impl MigrationRequest {
    /// Estimate transfer time for given state size.
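    ///
    /// For example, 1 GB over a 300 GB/s NVLink with 1 us latency comes to
    /// roughly 1/300 s (about 3.3 ms) of transfer time plus the 1 us hop.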
    pub fn estimate_transfer_time(&self, state_size_bytes: usize) -> Duration {
        // time = size / bandwidth + latency
        let size_gb = state_size_bytes as f64 / 1_000_000_000.0;
        let transfer_time_s = size_gb / self.estimated_bandwidth_gbps;
        let total_us = (transfer_time_s * 1_000_000.0) + self.estimated_latency_us;
        Duration::from_micros(total_us as u64)
    }
}

// ============================================================================
// Cross-GPU K2K Router
// ============================================================================

/// Routes K2K messages across GPU boundaries.
pub struct CrossGpuK2KRouter {
    /// Multi-GPU coordinator.
    coordinator: Arc<MultiGpuCoordinator>,
    /// Message queues for pending cross-device messages.
    pending_queues: RwLock<HashMap<(usize, usize), Vec<PendingK2KMessage>>>,
    /// Statistics.
    stats: CrossGpuRouterStats,
}

/// A pending cross-GPU K2K message.
#[derive(Debug, Clone)]
pub struct PendingK2KMessage {
    /// Source kernel ID.
    pub source_kernel: KernelId,
    /// Destination kernel ID.
    pub dest_kernel: KernelId,
    /// Message payload.
    pub message: K2KMessage,
    /// Timestamp when queued.
    pub queued_at: Instant,
    /// Number of routing hops.
    pub hops: u32,
}

/// Statistics for cross-GPU K2K routing.
#[derive(Debug, Default)]
pub struct CrossGpuRouterStats {
    /// Total messages routed.
    messages_routed: AtomicU64,
    /// Total bytes transferred.
    bytes_transferred: AtomicU64,
    /// Messages currently pending.
    messages_pending: AtomicUsize,
    /// Total routing latency (microseconds).
    total_latency_us: AtomicU64,
    /// Failed routing attempts.
    routing_failures: AtomicU64,
}

impl CrossGpuK2KRouter {
    /// Create a new cross-GPU K2K router.
    pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Arc<Self> {
        Arc::new(Self {
            coordinator,
            pending_queues: RwLock::new(HashMap::new()),
            stats: CrossGpuRouterStats::default(),
        })
    }

    /// Route a message from source kernel to destination kernel.
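    ///
    /// The caller acts on the returned [`RoutingDecision`]; a sketch:
    ///
    /// ```ignore
    /// match router.route_message(&src, &dst, msg)? {
    ///     RoutingDecision::SameDevice => { /* deliver via the local K2K path */ }
    ///     RoutingDecision::DirectP2P { .. } => { /* drain and copy peer-to-peer */ }
    ///     RoutingDecision::MultiHop { path, .. } => { /* forward along `path` */ }
    ///     RoutingDecision::HostMediated { .. } => { /* stage through host memory */ }
    /// }
    /// ```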
    pub fn route_message(
        &self,
        source_kernel: &KernelId,
        dest_kernel: &KernelId,
        message: K2KMessage,
    ) -> Result<RoutingDecision> {
        let source_device = self
            .coordinator
            .get_kernel_device(source_kernel)
            .ok_or_else(|| {
                RingKernelError::K2KDestinationNotFound(source_kernel.as_str().to_string())
            })?;

        let dest_device = self
            .coordinator
            .get_kernel_device(dest_kernel)
            .ok_or_else(|| {
                RingKernelError::K2KDestinationNotFound(dest_kernel.as_str().to_string())
            })?;

        // Same device - use regular K2K
        if source_device == dest_device {
            return Ok(RoutingDecision::SameDevice);
        }

        // Get topology for routing
        let topo = self.coordinator.topology();
        let path = topo.best_path(source_device, dest_device);

        // Check if direct P2P is available
        if let Some(conn) = topo.get_connection(source_device, dest_device) {
            if conn.interconnect.supports_p2p() {
                // Queue for direct P2P transfer
                let pending = PendingK2KMessage {
                    source_kernel: source_kernel.clone(),
                    dest_kernel: dest_kernel.clone(),
                    message,
                    queued_at: Instant::now(),
                    hops: 1,
                };

                self.enqueue_pending(source_device, dest_device, pending);
                self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

                return Ok(RoutingDecision::DirectP2P {
                    source_device,
                    dest_device,
                    bandwidth_gbps: conn.bandwidth_gbps,
                });
            }
        }

        // Multi-hop routing required
        if path.len() > 2 {
            let pending = PendingK2KMessage {
                source_kernel: source_kernel.clone(),
                dest_kernel: dest_kernel.clone(),
                message,
                queued_at: Instant::now(),
                hops: (path.len() - 1) as u32,
            };

            // Queue for first hop
            self.enqueue_pending(source_device, path[1], pending);
            self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

            return Ok(RoutingDecision::MultiHop {
                path: path.clone(),
                total_hops: (path.len() - 1) as u32,
            });
        }

        // Fall back to host-mediated transfer
        let pending = PendingK2KMessage {
            source_kernel: source_kernel.clone(),
            dest_kernel: dest_kernel.clone(),
            message,
            queued_at: Instant::now(),
            hops: 2, // device->host->device
        };

        self.enqueue_pending(source_device, dest_device, pending);
        self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

        Ok(RoutingDecision::HostMediated {
            source_device,
            dest_device,
        })
    }

    /// Get pending messages for a device pair.
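    ///
    /// A host-side delivery loop sketch; the actual transfer step and
    /// `payload_len` are placeholders:
    ///
    /// ```ignore
    /// for msg in router.drain_pending(0, 1) {
    ///     // ... copy msg.message to device 1 ...
    ///     router.record_delivery(&msg, payload_len);
    /// }
    /// ```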
    pub fn drain_pending(&self, source: usize, dest: usize) -> Vec<PendingK2KMessage> {
        let mut queues = self.pending_queues.write();
        let messages = queues.remove(&(source, dest)).unwrap_or_default();
        self.stats
            .messages_pending
            .fetch_sub(messages.len(), Ordering::Relaxed);
        messages
    }

    /// Record successful message delivery.
    pub fn record_delivery(&self, message: &PendingK2KMessage, payload_size: usize) {
        self.stats.messages_routed.fetch_add(1, Ordering::Relaxed);
        self.stats
            .bytes_transferred
            .fetch_add(payload_size as u64, Ordering::Relaxed);

        let latency = message.queued_at.elapsed().as_micros() as u64;
        self.stats
            .total_latency_us
            .fetch_add(latency, Ordering::Relaxed);
    }

    /// Record routing failure.
    pub fn record_failure(&self) {
        self.stats.routing_failures.fetch_add(1, Ordering::Relaxed);
    }

    /// Get router statistics.
    pub fn stats(&self) -> CrossGpuRouterStatsSnapshot {
        let messages_routed = self.stats.messages_routed.load(Ordering::Relaxed);
        let total_latency = self.stats.total_latency_us.load(Ordering::Relaxed);

        CrossGpuRouterStatsSnapshot {
            messages_routed,
            bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
            messages_pending: self.stats.messages_pending.load(Ordering::Relaxed),
            avg_latency_us: if messages_routed > 0 {
                total_latency as f64 / messages_routed as f64
            } else {
                0.0
            },
            routing_failures: self.stats.routing_failures.load(Ordering::Relaxed),
        }
    }

    fn enqueue_pending(&self, source: usize, dest: usize, message: PendingK2KMessage) {
        let mut queues = self.pending_queues.write();
        queues.entry((source, dest)).or_default().push(message);
    }
}

/// Snapshot of router statistics.
#[derive(Debug, Clone)]
pub struct CrossGpuRouterStatsSnapshot {
    /// Total messages successfully routed.
    pub messages_routed: u64,
    /// Total bytes transferred.
    pub bytes_transferred: u64,
    /// Messages currently pending.
    pub messages_pending: usize,
    /// Average routing latency in microseconds.
    pub avg_latency_us: f64,
    /// Total routing failures.
    pub routing_failures: u64,
}

/// Decision for how to route a K2K message.
#[derive(Debug, Clone)]
pub enum RoutingDecision {
    /// Source and destination on same device.
    SameDevice,
    /// Direct peer-to-peer transfer.
    DirectP2P {
        /// Source device index.
        source_device: usize,
        /// Destination device index.
        dest_device: usize,
        /// Available bandwidth.
        bandwidth_gbps: f64,
    },
    /// Multi-hop routing through intermediate devices.
    MultiHop {
        /// Device path.
        path: Vec<usize>,
        /// Total number of hops.
        total_hops: u32,
    },
    /// Route through host memory (slowest).
    HostMediated {
        /// Source device index.
        source_device: usize,
        /// Destination device index.
        dest_device: usize,
    },
}

/// Multi-GPU coordinator statistics.
#[derive(Debug, Clone, Default)]
pub struct MultiGpuStats {
    /// Number of registered devices.
    pub device_count: usize,
    /// Total kernels across all devices.
    pub total_kernels: usize,
    /// Total memory across all devices.
    pub total_memory: u64,
    /// Available memory across all devices.
    pub available_memory: u64,
    /// Total kernels launched since start.
    pub kernels_launched: u64,
}

/// Builder for multi-GPU coordinator.
pub struct MultiGpuBuilder {
    config: MultiGpuConfig,
}

impl MultiGpuBuilder {
    /// Create a new builder.
    pub fn new() -> Self {
        Self {
            config: MultiGpuConfig::default(),
        }
    }

    /// Set load balancing strategy.
    pub fn load_balancing(mut self, strategy: LoadBalancingStrategy) -> Self {
        self.config.load_balancing = strategy;
        self
    }

    /// Set auto device selection.
    pub fn auto_select_device(mut self, enable: bool) -> Self {
        self.config.auto_select_device = enable;
        self
    }

    /// Set max kernels per device.
    pub fn max_kernels_per_device(mut self, max: usize) -> Self {
        self.config.max_kernels_per_device = max;
        self
    }

    /// Enable P2P transfers.
    pub fn enable_p2p(mut self, enable: bool) -> Self {
        self.config.enable_p2p = enable;
        self
    }

    /// Set preferred devices.
    pub fn preferred_devices(mut self, devices: Vec<usize>) -> Self {
        self.config.preferred_devices = devices;
        self
    }

    /// Build the coordinator.
    pub fn build(self) -> Arc<MultiGpuCoordinator> {
        MultiGpuCoordinator::new(self.config)
    }
}

impl Default for MultiGpuBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// Helper for cross-device data transfer.
pub struct CrossDeviceTransfer {
    /// Source device index.
    pub source_device: usize,
    /// Destination device index.
    pub dest_device: usize,
    /// Data size in bytes.
    pub size: usize,
    /// Use P2P if available.
    pub use_p2p: bool,
}

impl CrossDeviceTransfer {
    /// Create a new transfer specification.
    pub fn new(source: usize, dest: usize, size: usize) -> Self {
        Self {
            source_device: source,
            dest_device: dest,
            size,
            use_p2p: true,
        }
    }

    /// Disable P2P for this transfer.
    pub fn without_p2p(mut self) -> Self {
        self.use_p2p = false;
        self
    }
}

// ============================================================================
// Kernel Migrator with Checkpoint Integration
// ============================================================================

use crate::checkpoint::{CheckpointStorage, CheckpointableKernel, MemoryStorage};

/// Migrator that uses checkpoints for kernel state transfer between GPUs.
///
/// This integrates the checkpoint infrastructure with the multi-GPU migration
/// system to enable live migration of persistent kernels.
///
/// # Example
///
/// ```ignore
/// use ringkernel_core::multi_gpu::{KernelMigrator, MultiGpuBuilder};
///
/// let coordinator = MultiGpuBuilder::new().build();
/// let migrator = KernelMigrator::new(coordinator);
///
/// // Migrate kernel from GPU 0 to GPU 1 (synchronous call)
/// migrator.migrate_with_checkpoint(&kernel, &mut request)?;
/// ```
pub struct KernelMigrator {
    /// Multi-GPU coordinator.
    coordinator: Arc<MultiGpuCoordinator>,
    /// Checkpoint storage for migration state.
    storage: Arc<dyn CheckpointStorage>,
    /// Statistics.
    stats: MigrationStats,
}

/// Statistics for kernel migrations.
#[derive(Debug, Default)]
pub struct MigrationStats {
    /// Total successful migrations.
    pub successful_migrations: AtomicU64,
    /// Total failed migrations.
    pub failed_migrations: AtomicU64,
    /// Total bytes transferred during migrations.
    pub bytes_transferred: AtomicU64,
    /// Total checkpoint time (microseconds).
    pub checkpoint_time_us: AtomicU64,
    /// Total restore time (microseconds).
    pub restore_time_us: AtomicU64,
}

/// Result of a completed migration.
#[derive(Debug, Clone)]
pub struct MigrationResult {
    /// Kernel that was migrated.
    pub kernel_id: KernelId,
    /// Source device.
    pub source_device: usize,
    /// Target device.
    pub target_device: usize,
    /// Checkpoint size in bytes.
    pub checkpoint_size: usize,
    /// Time spent creating checkpoint.
    pub checkpoint_duration: Duration,
    /// Time spent transferring state.
    pub transfer_duration: Duration,
    /// Time spent restoring kernel.
    pub restore_duration: Duration,
    /// Total migration time.
    pub total_duration: Duration,
}

impl KernelMigrator {
    /// Create a new kernel migrator with default in-memory storage.
    pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Self {
        Self {
            coordinator,
            storage: Arc::new(MemoryStorage::new()),
            stats: MigrationStats::default(),
        }
    }

    /// Create a migrator with custom checkpoint storage.
    pub fn with_storage(
        coordinator: Arc<MultiGpuCoordinator>,
        storage: Arc<dyn CheckpointStorage>,
    ) -> Self {
        Self {
            coordinator,
            storage,
            stats: MigrationStats::default(),
        }
    }

    /// Perform a complete migration using checkpoint-based state transfer.
    ///
    /// Steps:
    /// 1. Quiesce the source kernel (drain pending messages)
    /// 2. Create checkpoint of kernel state
    /// 3. Transfer checkpoint to target device
    /// 4. Restore kernel on target device
    /// 5. Update coordinator routing tables
1580    pub fn migrate_with_checkpoint<K: CheckpointableKernel>(
1581        &self,
1582        kernel: &K,
1583        request: &mut MigrationRequest,
1584    ) -> Result<MigrationResult> {
1585        let start_time = Instant::now();
1586        request.started_at = Some(start_time);
1587
1588        // Step 1: Quiesce
1589        request.state = MigrationState::Quiescing;
1590        // In a real implementation, this would drain message queues
1591        // For now, we assume the kernel is ready for checkpointing
1592
1593        // Step 2: Create checkpoint
1594        request.state = MigrationState::Checkpointing;
1595        let checkpoint_start = Instant::now();
1596        let checkpoint = kernel.create_checkpoint().map_err(|e| {
1597            self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
1598            request.state = MigrationState::Failed;
1599            RingKernelError::MigrationFailed(format!("Checkpoint creation failed: {}", e))
1600        })?;
1601        let checkpoint_duration = checkpoint_start.elapsed();
1602        let checkpoint_size = checkpoint.total_size();
1603
1604        self.stats
1605            .checkpoint_time_us
1606            .fetch_add(checkpoint_duration.as_micros() as u64, Ordering::Relaxed);
1607
1608        // Step 3: Transfer
1609        request.state = MigrationState::Transferring;
1610        let transfer_start = Instant::now();
1611
1612        // Store checkpoint (simulates transfer)
1613        let checkpoint_name = format!(
1614            "migration_{}_{}_{}",
1615            request.kernel_id.as_str(),
1616            request.source_device,
1617            request.target_device
1618        );
1619        self.storage
1620            .save(&checkpoint, &checkpoint_name)
1621            .map_err(|e| {
1622                self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
1623                request.state = MigrationState::Failed;
1624                RingKernelError::MigrationFailed(format!("Checkpoint transfer failed: {}", e))
1625            })?;
1626
1627        let transfer_duration = transfer_start.elapsed();
1628        self.stats
1629            .bytes_transferred
1630            .fetch_add(checkpoint_size as u64, Ordering::Relaxed);
1631
1632        // Step 4: Restore (would be done on target kernel)
1633        request.state = MigrationState::Restoring;
1634        let restore_start = Instant::now();
1635
1636        // Load checkpoint to verify it's valid
1637        let _restored = self.storage.load(&checkpoint_name).map_err(|e| {
1638            self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
1639            request.state = MigrationState::Failed;
1640            RingKernelError::MigrationFailed(format!("Checkpoint restore failed: {}", e))
1641        })?;
1642
1643        let restore_duration = restore_start.elapsed();
1644        self.stats
1645            .restore_time_us
1646            .fetch_add(restore_duration.as_micros() as u64, Ordering::Relaxed);
1647
1648        // Step 5: Update routing
1649        request.state = MigrationState::Completed;
1650        self.coordinator.complete_migration(request)?;
1651
1652        // Clean up checkpoint
1653        let _ = self.storage.delete(&checkpoint_name);
1654
1655        self.stats
1656            .successful_migrations
1657            .fetch_add(1, Ordering::Relaxed);
1658
1659        Ok(MigrationResult {
1660            kernel_id: request.kernel_id.clone(),
1661            source_device: request.source_device,
1662            target_device: request.target_device,
1663            checkpoint_size,
1664            checkpoint_duration,
1665            transfer_duration,
1666            restore_duration,
1667            total_duration: start_time.elapsed(),
1668        })
1669    }
1670
1671    /// Get a reference to the coordinator.
1672    pub fn coordinator(&self) -> &Arc<MultiGpuCoordinator> {
1673        &self.coordinator
1674    }
1675
1676    /// Get migration statistics snapshot.
1677    pub fn stats(&self) -> MigrationStatsSnapshot {
1678        let successful = self.stats.successful_migrations.load(Ordering::Relaxed);
1679        let failed = self.stats.failed_migrations.load(Ordering::Relaxed);
1680        let total = successful + failed;
1681        let checkpoint_us = self.stats.checkpoint_time_us.load(Ordering::Relaxed);
1682        let restore_us = self.stats.restore_time_us.load(Ordering::Relaxed);
1683
1684        MigrationStatsSnapshot {
1685            successful_migrations: successful,
1686            failed_migrations: failed,
1687            bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
1688            avg_checkpoint_time: if total > 0 {
1689                Duration::from_micros(checkpoint_us / total)
1690            } else {
1691                Duration::ZERO
1692            },
1693            avg_restore_time: if total > 0 {
1694                Duration::from_micros(restore_us / total)
1695            } else {
1696                Duration::ZERO
1697            },
1698        }
1699    }
1700}
1701
1702/// Snapshot of migration statistics.
1703#[derive(Debug, Clone)]
1704pub struct MigrationStatsSnapshot {
1705    /// Total successful migrations.
1706    pub successful_migrations: u64,
1707    /// Total failed migrations.
1708    pub failed_migrations: u64,
1709    /// Total bytes transferred.
1710    pub bytes_transferred: u64,
1711    /// Average checkpoint creation time.
1712    pub avg_checkpoint_time: Duration,
1713    /// Average restore time.
1714    pub avg_restore_time: Duration,
1715}
1716
1717/// Trait for kernels that support live migration.
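///
/// # Example
///
/// A minimal sketch; `MyKernel` and its `paused`, `pending`, and `state`
/// fields are hypothetical names for illustration. Any type that already
/// implements `CheckpointableKernel` can opt in the same way:
///
/// ```ignore
/// impl MigratableKernel for MyKernel {
///     fn prepare_for_migration(&mut self) -> Result<()> {
///         self.paused = true; // stop accepting new work
///         Ok(())
///     }
///
///     fn cancel_migration(&mut self) -> Result<()> {
///         self.paused = false; // resume processing in place
///         Ok(())
///     }
///
///     fn is_quiescent(&self) -> bool {
///         self.paused && self.pending == 0
///     }
///
///     fn estimated_state_size(&self) -> usize {
///         self.state.len()
///     }
/// }
/// ```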
pub trait MigratableKernel: CheckpointableKernel {
    /// Prepare kernel for migration (quiesce, drain messages).
    fn prepare_for_migration(&mut self) -> Result<()>;

    /// Resume kernel after migration is cancelled.
    fn cancel_migration(&mut self) -> Result<()>;

    /// Check if kernel is ready to be checkpointed.
    fn is_quiescent(&self) -> bool;

    /// Get estimated state size for migration planning.
    fn estimated_state_size(&self) -> usize;
}

// ============================================================================
// Hot Reload Support
// ============================================================================

/// Configuration for kernel hot reload operations.
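///
/// # Example
///
/// A builder sketch using only the methods defined on this type:
///
/// ```ignore
/// use std::time::Duration;
///
/// let config = HotReloadConfig::new()
///     .with_timeout(Duration::from_secs(10))
///     .with_preserve_state(true)
///     .with_max_retries(5);
/// ```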
#[derive(Debug, Clone)]
pub struct HotReloadConfig {
    /// Enable hot reload functionality.
    pub enabled: bool,
    /// Timeout for reload operations.
    pub reload_timeout: Duration,
    /// Whether to preserve kernel state during reload.
    pub preserve_state: bool,
    /// Maximum retries for failed reloads.
    pub max_retries: u32,
    /// Backoff duration between retries.
    pub retry_backoff: Duration,
    /// Whether to validate new code before swapping.
    pub validate_before_swap: bool,
    /// Keep old code as fallback in case of failure.
    pub keep_fallback: bool,
}

impl Default for HotReloadConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            reload_timeout: Duration::from_secs(30),
            preserve_state: true,
            max_retries: 3,
            retry_backoff: Duration::from_millis(500),
            validate_before_swap: true,
            keep_fallback: true,
        }
    }
}

impl HotReloadConfig {
    /// Create a new hot reload configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Enable or disable hot reload.
    pub fn with_enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Set reload timeout.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.reload_timeout = timeout;
        self
    }

    /// Enable or disable state preservation.
    pub fn with_preserve_state(mut self, preserve: bool) -> Self {
        self.preserve_state = preserve;
        self
    }

    /// Set maximum retries.
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }
}

/// State of a hot reload operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HotReloadState {
    /// Reload not started.
    Idle,
    /// Draining pending messages from kernel.
    Draining,
    /// Creating checkpoint of kernel state.
    Checkpointing,
    /// Compiling new kernel code.
    Compiling,
    /// Validating new kernel code.
    Validating,
    /// Swapping old kernel with new.
    Swapping,
    /// Restoring state to new kernel.
    Restoring,
    /// Hot reload completed successfully.
    Completed,
    /// Hot reload failed.
    Failed,
    /// Rolling back to previous version.
    RollingBack,
}

/// Kernel code format.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelCodeFormat {
    /// NVIDIA PTX assembly.
    Ptx,
    /// NVIDIA CUBIN binary.
    Cubin,
    /// SPIR-V for Vulkan/WebGPU.
    SpirV,
    /// WGSL shader text.
    Wgsl,
    /// Metal Shading Language.
    Msl,
    /// Metal compiled library.
    MetalLib,
    /// Source code (requires compilation).
    Source,
}

/// Kernel code source for hot reload.
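///
/// # Example
///
/// A sketch with the constructors below; `shader_text` stands in for real
/// WGSL source:
///
/// ```ignore
/// let code = KernelCodeSource::from_wgsl(shader_text, "main")
///     .with_metadata("opt_level", "3");
/// assert!(code.as_str().is_some());
/// ```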
#[derive(Debug, Clone)]
pub struct KernelCodeSource {
    /// Unique identifier for this code version.
    pub version_id: u64,
    /// Code format.
    pub format: KernelCodeFormat,
    /// Raw code bytes.
    pub code: Vec<u8>,
    /// Entry point function name.
    pub entry_point: String,
    /// Optional metadata (compile flags, etc.).
    pub metadata: HashMap<String, String>,
    /// Timestamp when code was created.
    pub created_at: Instant,
    /// Content hash of the code (256-bit, non-cryptographic).
    pub hash: [u8; 32],
}

impl KernelCodeSource {
    /// Create a new kernel code source.
    pub fn new(format: KernelCodeFormat, code: Vec<u8>, entry_point: impl Into<String>) -> Self {
        let hash = Self::compute_hash(&code);
        Self {
            version_id: 0,
            format,
            code,
            entry_point: entry_point.into(),
            metadata: HashMap::new(),
            created_at: Instant::now(),
            hash,
        }
    }

    /// Create from PTX code.
    pub fn from_ptx(ptx: &str, entry_point: impl Into<String>) -> Self {
        Self::new(KernelCodeFormat::Ptx, ptx.as_bytes().to_vec(), entry_point)
    }

    /// Create from WGSL code.
    pub fn from_wgsl(wgsl: &str, entry_point: impl Into<String>) -> Self {
        Self::new(
            KernelCodeFormat::Wgsl,
            wgsl.as_bytes().to_vec(),
            entry_point,
        )
    }

    /// Create from MSL code.
    pub fn from_msl(msl: &str, entry_point: impl Into<String>) -> Self {
        Self::new(KernelCodeFormat::Msl, msl.as_bytes().to_vec(), entry_point)
    }

    /// Set version ID.
    pub fn with_version(mut self, version: u64) -> Self {
        self.version_id = version;
        self
    }

    /// Add metadata.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Derive a 256-bit content digest by chaining `DefaultHasher` rounds.
    ///
    /// This is a fast, non-cryptographic fingerprint used only to detect
    /// code changes; it is not a SHA-256 digest.
    fn compute_hash(data: &[u8]) -> [u8; 32] {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        data.hash(&mut hasher);
        let h1 = hasher.finish();
        // Feed each round's output back in to derive three more 64-bit words.
        h1.hash(&mut hasher);
        let h2 = hasher.finish();
        h2.hash(&mut hasher);
        let h3 = hasher.finish();
        h3.hash(&mut hasher);
        let h4 = hasher.finish();

        let mut hash = [0u8; 32];
        hash[0..8].copy_from_slice(&h1.to_le_bytes());
        hash[8..16].copy_from_slice(&h2.to_le_bytes());
        hash[16..24].copy_from_slice(&h3.to_le_bytes());
        hash[24..32].copy_from_slice(&h4.to_le_bytes());
        hash
    }

    /// Get code as string (if text format).
    pub fn as_str(&self) -> Option<&str> {
        match self.format {
            KernelCodeFormat::Ptx
            | KernelCodeFormat::Wgsl
            | KernelCodeFormat::Msl
            | KernelCodeFormat::Source => std::str::from_utf8(&self.code).ok(),
            _ => None,
        }
    }

    /// Get code size in bytes.
    pub fn size(&self) -> usize {
        self.code.len()
    }
}

/// Request to hot reload a kernel.
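///
/// A freshly created request starts in [`HotReloadState::Idle`]:
///
/// ```ignore
/// let request = HotReloadRequest::new(kernel_id, new_code);
/// assert!(!request.is_in_progress());
/// assert!(!request.is_completed());
/// ```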
#[derive(Debug)]
pub struct HotReloadRequest {
    /// Target kernel ID.
    pub kernel_id: KernelId,
    /// New kernel code.
    pub new_code: KernelCodeSource,
    /// Current state of the reload operation.
    pub state: HotReloadState,
    /// When the request was created.
    pub created_at: Instant,
    /// When the reload started.
    pub started_at: Option<Instant>,
    /// Retry count.
    pub retry_count: u32,
    /// Error message if failed.
    pub error: Option<String>,
    /// Checkpoint data (if preserving state).
    checkpoint_data: Option<Vec<u8>>,
}

impl HotReloadRequest {
    /// Create a new hot reload request.
    pub fn new(kernel_id: KernelId, new_code: KernelCodeSource) -> Self {
        Self {
            kernel_id,
            new_code,
            state: HotReloadState::Idle,
            created_at: Instant::now(),
            started_at: None,
            retry_count: 0,
            error: None,
            checkpoint_data: None,
        }
    }

    /// Check if reload is in progress.
    pub fn is_in_progress(&self) -> bool {
        !matches!(
            self.state,
            HotReloadState::Idle | HotReloadState::Completed | HotReloadState::Failed
        )
    }

    /// Check if reload completed successfully.
    pub fn is_completed(&self) -> bool {
        self.state == HotReloadState::Completed
    }

    /// Check if reload failed.
    pub fn is_failed(&self) -> bool {
        self.state == HotReloadState::Failed
    }

    /// Get elapsed time since request creation.
    pub fn elapsed(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// Get elapsed time since reload started.
    pub fn reload_elapsed(&self) -> Option<Duration> {
        self.started_at.map(|s| s.elapsed())
    }
}

/// Result of a completed hot reload.
#[derive(Debug, Clone)]
pub struct HotReloadResult {
    /// Target kernel ID.
    pub kernel_id: KernelId,
    /// Previous code version.
    pub old_version: u64,
    /// New code version.
    pub new_version: u64,
    /// Whether state was preserved.
    pub state_preserved: bool,
    /// Size of checkpoint data (if any).
    pub checkpoint_size: usize,
    /// Time to drain messages.
    pub drain_duration: Duration,
    /// Time to create checkpoint.
    pub checkpoint_duration: Duration,
    /// Time to compile new code.
    pub compile_duration: Duration,
    /// Time to swap kernels.
    pub swap_duration: Duration,
    /// Time to restore state.
    pub restore_duration: Duration,
    /// Total reload duration.
    pub total_duration: Duration,
}

/// Statistics for hot reload operations.
#[derive(Debug, Default)]
struct HotReloadStats {
    successful_reloads: AtomicU64,
    failed_reloads: AtomicU64,
    rollbacks: AtomicU64,
    total_drain_time_us: AtomicU64,
    total_compile_time_us: AtomicU64,
    total_swap_time_us: AtomicU64,
    state_preserved_count: AtomicU64,
}

/// Snapshot of hot reload statistics.
#[derive(Debug, Clone)]
pub struct HotReloadStatsSnapshot {
    /// Total successful reloads.
    pub successful_reloads: u64,
    /// Total failed reloads.
    pub failed_reloads: u64,
    /// Total rollbacks performed.
    pub rollbacks: u64,
    /// Average drain time.
    pub avg_drain_time: Duration,
    /// Average compile time.
    pub avg_compile_time: Duration,
    /// Average swap time.
    pub avg_swap_time: Duration,
    /// Number of reloads with preserved state.
    pub state_preserved_count: u64,
}

/// Manager for kernel hot reload operations.
///
/// Provides seamless kernel code updates without stopping the system:
///
/// 1. Drain pending messages from kernel input queue
/// 2. Checkpoint kernel state (if preserving state)
/// 3. Compile/validate new kernel code
/// 4. Swap old kernel with new kernel
/// 5. Restore state to new kernel
/// 6. Resume processing
///
/// # Example
///
/// ```ignore
/// use ringkernel_core::multi_gpu::{HotReloadManager, HotReloadConfig, KernelCodeSource};
///
/// let manager = HotReloadManager::new(HotReloadConfig::default());
///
/// // Register a reloadable kernel
/// manager.register_kernel(&kernel_id, current_code);
///
/// // Request hot reload with new PTX
/// let new_code = KernelCodeSource::from_ptx(new_ptx, "my_kernel");
/// let mut request = manager.request_reload(&kernel_id, new_code)?;
///
/// // Execute the reload
/// let result = manager.execute_reload(&mut request, &kernel)?;
/// println!("Reload completed in {:?}", result.total_duration);
/// ```
pub struct HotReloadManager {
    /// Configuration.
    config: HotReloadConfig,
    /// Registered kernels and their current code.
    kernels: RwLock<HashMap<KernelId, KernelCodeSource>>,
    /// Fallback code for registered kernels.
    fallbacks: RwLock<HashMap<KernelId, KernelCodeSource>>,
    /// Active reload requests.
    active_requests: RwLock<HashMap<KernelId, HotReloadRequest>>,
    /// Version counter for code versions.
    version_counter: AtomicU64,
    /// Statistics.
    stats: HotReloadStats,
}

impl HotReloadManager {
    /// Create a new hot reload manager.
    pub fn new(config: HotReloadConfig) -> Arc<Self> {
        Arc::new(Self {
            config,
            kernels: RwLock::new(HashMap::new()),
            fallbacks: RwLock::new(HashMap::new()),
            active_requests: RwLock::new(HashMap::new()),
            version_counter: AtomicU64::new(1),
            stats: HotReloadStats::default(),
        })
    }

    /// Create with default configuration.
    pub fn with_defaults() -> Arc<Self> {
        Self::new(HotReloadConfig::default())
    }

    /// Check if hot reload is enabled.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Register a kernel for hot reload.
    pub fn register_kernel(&self, kernel_id: &KernelId, code: KernelCodeSource) {
        let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
        let code = code.with_version(version);
        self.kernels.write().insert(kernel_id.clone(), code);
    }

    /// Unregister a kernel from hot reload.
    pub fn unregister_kernel(&self, kernel_id: &KernelId) {
        self.kernels.write().remove(kernel_id);
        self.fallbacks.write().remove(kernel_id);
        self.active_requests.write().remove(kernel_id);
    }

    /// Get current code version for a kernel.
    pub fn get_current_version(&self, kernel_id: &KernelId) -> Option<u64> {
        self.kernels.read().get(kernel_id).map(|c| c.version_id)
    }

    /// Get current code for a kernel.
    pub fn get_current_code(&self, kernel_id: &KernelId) -> Option<KernelCodeSource> {
        self.kernels.read().get(kernel_id).cloned()
    }

    /// Request a hot reload for a kernel.
    pub fn request_reload(
        &self,
        kernel_id: &KernelId,
        new_code: KernelCodeSource,
    ) -> Result<HotReloadRequest> {
        if !self.config.enabled {
            return Err(RingKernelError::ValidationError(
                "Hot reload is disabled".to_string(),
            ));
        }

        // Check kernel is registered
        if !self.kernels.read().contains_key(kernel_id) {
            return Err(RingKernelError::KernelNotFound(
                kernel_id.as_str().to_string(),
            ));
        }

        // Reject if a reload is already queued or in progress. Tracking
        // entries are removed on completion or failure, so any non-terminal
        // entry (including a queued `Idle` one) blocks a new request.
        {
            let active = self.active_requests.read();
            if let Some(existing) = active.get(kernel_id) {
                if !existing.is_completed() && !existing.is_failed() {
                    return Err(RingKernelError::ValidationError(
                        "Hot reload already in progress for this kernel".to_string(),
                    ));
                }
            }
        }

        // Assign version to new code
        let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
        let new_code = new_code.with_version(version);

        // The returned request drives the reload; a tracking copy is kept so
        // callers can query `is_reload_in_progress` in the meantime.
        let request = HotReloadRequest::new(kernel_id.clone(), new_code);
        self.active_requests.write().insert(
            kernel_id.clone(),
            HotReloadRequest::new(kernel_id.clone(), request.new_code.clone()),
        );

        Ok(request)
    }

    /// Execute a hot reload operation.
    ///
    /// This performs the full reload sequence:
    /// 1. Drain pending messages
    /// 2. Checkpoint state (if enabled)
    /// 3. Validate new code
    /// 4. Swap kernels
    /// 5. Restore state (if enabled)
    pub fn execute_reload<K: CheckpointableKernel>(
        &self,
        request: &mut HotReloadRequest,
        kernel: &K,
    ) -> Result<HotReloadResult> {
        let start_time = Instant::now();
        request.started_at = Some(start_time);

        // Get old version
        let old_version = self
            .kernels
            .read()
            .get(&request.kernel_id)
            .map(|c| c.version_id)
            .unwrap_or(0);

        // Phase 1: Drain (simulated - actual drain would wait for queue empty)
        request.state = HotReloadState::Draining;
        let drain_start = Instant::now();
        // In a real implementation, wait for input queue to drain
        std::thread::sleep(Duration::from_micros(10));
        let drain_duration = drain_start.elapsed();
        self.stats
            .total_drain_time_us
            .fetch_add(drain_duration.as_micros() as u64, Ordering::Relaxed);

        // Phase 2: Checkpoint (if preserving state)
        request.state = HotReloadState::Checkpointing;
        let checkpoint_start = Instant::now();
        let checkpoint_size = if self.config.preserve_state {
            // Record the failure and release the tracking entry so a retry
            // is not blocked forever.
            let checkpoint = kernel.create_checkpoint().map_err(|e| {
                self.stats.failed_reloads.fetch_add(1, Ordering::Relaxed);
                request.state = HotReloadState::Failed;
                self.active_requests.write().remove(&request.kernel_id);
                e
            })?;
            let data = checkpoint.to_bytes();
            let size = data.len();
            request.checkpoint_data = Some(data);
            size
        } else {
            0
        };
        let checkpoint_duration = checkpoint_start.elapsed();

        // Phase 3: Validate new code
        request.state = HotReloadState::Validating;
        if self.config.validate_before_swap {
            self.validate_code(&request.new_code).map_err(|e| {
                self.stats.failed_reloads.fetch_add(1, Ordering::Relaxed);
                request.state = HotReloadState::Failed;
                self.active_requests.write().remove(&request.kernel_id);
                e
            })?;
        }

        // Phase 4: Compile (simulated)
        request.state = HotReloadState::Compiling;
        let compile_start = Instant::now();
        // In real implementation, compile PTX/WGSL to native code
        std::thread::sleep(Duration::from_micros(10));
        let compile_duration = compile_start.elapsed();
        self.stats
            .total_compile_time_us
            .fetch_add(compile_duration.as_micros() as u64, Ordering::Relaxed);

        // Phase 5: Swap
        request.state = HotReloadState::Swapping;
        let swap_start = Instant::now();

        // Save fallback
        if self.config.keep_fallback {
            if let Some(old_code) = self.kernels.read().get(&request.kernel_id).cloned() {
                self.fallbacks
                    .write()
                    .insert(request.kernel_id.clone(), old_code);
            }
        }

        // Install new code
        self.kernels
            .write()
            .insert(request.kernel_id.clone(), request.new_code.clone());
        let swap_duration = swap_start.elapsed();
        self.stats
            .total_swap_time_us
            .fetch_add(swap_duration.as_micros() as u64, Ordering::Relaxed);

        // Phase 6: Restore state
        request.state = HotReloadState::Restoring;
        let restore_start = Instant::now();
        // In real implementation, restore checkpoint to new kernel
        let restore_duration = restore_start.elapsed();

        // Mark completed
        request.state = HotReloadState::Completed;
        self.stats
            .successful_reloads
            .fetch_add(1, Ordering::Relaxed);
        if self.config.preserve_state && checkpoint_size > 0 {
            self.stats
                .state_preserved_count
                .fetch_add(1, Ordering::Relaxed);
        }

        // Clean up active request
        self.active_requests.write().remove(&request.kernel_id);

        Ok(HotReloadResult {
            kernel_id: request.kernel_id.clone(),
            old_version,
            new_version: request.new_code.version_id,
            state_preserved: self.config.preserve_state && checkpoint_size > 0,
            checkpoint_size,
            drain_duration,
            checkpoint_duration,
            compile_duration,
            swap_duration,
            restore_duration,
            total_duration: start_time.elapsed(),
        })
    }

    /// Rollback to previous kernel version.
    pub fn rollback(&self, kernel_id: &KernelId) -> Result<()> {
        let fallback =
            self.fallbacks.write().remove(kernel_id).ok_or_else(|| {
                RingKernelError::ValidationError("No fallback available".to_string())
            })?;

        self.kernels.write().insert(kernel_id.clone(), fallback);
        self.stats.rollbacks.fetch_add(1, Ordering::Relaxed);

        // Update any active request
        if let Some(request) = self.active_requests.write().get_mut(kernel_id) {
            request.state = HotReloadState::RollingBack;
        }

        Ok(())
    }

    /// Validate kernel code before swap.
    fn validate_code(&self, code: &KernelCodeSource) -> Result<()> {
        // Basic validation
        if code.code.is_empty() {
            return Err(RingKernelError::ValidationError(
                "Kernel code is empty".to_string(),
            ));
        }

        if code.entry_point.is_empty() {
            return Err(RingKernelError::ValidationError(
                "Entry point is empty".to_string(),
            ));
        }

        // Format-specific validation
        match code.format {
            KernelCodeFormat::Ptx => {
                // Check for valid PTX header
                if let Some(text) = code.as_str() {
                    if !text.contains(".version") && !text.contains(".target") {
                        return Err(RingKernelError::ValidationError(
                            "PTX code missing version/target directive".to_string(),
                        ));
                    }
                }
            }
            KernelCodeFormat::Wgsl => {
                // Check for basic WGSL structure
                if let Some(text) = code.as_str() {
                    if !text.contains("@compute") && !text.contains("fn ") {
                        return Err(RingKernelError::ValidationError(
                            "WGSL code missing compute shader or function".to_string(),
                        ));
                    }
                }
            }
            KernelCodeFormat::Msl => {
                // Check for Metal kernel
                if let Some(text) = code.as_str() {
                    if !text.contains("kernel ") {
                        return Err(RingKernelError::ValidationError(
                            "MSL code missing kernel function".to_string(),
                        ));
                    }
                }
            }
            _ => {}
        }

        Ok(())
    }

    /// Get statistics snapshot.
    pub fn stats(&self) -> HotReloadStatsSnapshot {
        let successful = self.stats.successful_reloads.load(Ordering::Relaxed);
        let failed = self.stats.failed_reloads.load(Ordering::Relaxed);
        // Approximate averages across all attempts (failed reloads may
        // contribute drain time but not compile/swap time); clamp to avoid
        // division by zero.
        let total = (successful + failed).max(1);

        HotReloadStatsSnapshot {
            successful_reloads: successful,
            failed_reloads: failed,
            rollbacks: self.stats.rollbacks.load(Ordering::Relaxed),
            avg_drain_time: Duration::from_micros(
                self.stats.total_drain_time_us.load(Ordering::Relaxed) / total,
            ),
            avg_compile_time: Duration::from_micros(
                self.stats.total_compile_time_us.load(Ordering::Relaxed) / total,
            ),
            avg_swap_time: Duration::from_micros(
                self.stats.total_swap_time_us.load(Ordering::Relaxed) / total,
            ),
            state_preserved_count: self.stats.state_preserved_count.load(Ordering::Relaxed),
        }
    }

    /// List all registered kernels.
    pub fn list_kernels(&self) -> Vec<KernelId> {
        self.kernels.read().keys().cloned().collect()
    }

    /// Check if a kernel is registered.
    pub fn is_registered(&self, kernel_id: &KernelId) -> bool {
        self.kernels.read().contains_key(kernel_id)
    }

    /// Check if a reload is queued or in progress for a kernel.
    pub fn is_reload_in_progress(&self, kernel_id: &KernelId) -> bool {
        self.active_requests
            .read()
            .get(kernel_id)
            .map(|r| !r.is_completed() && !r.is_failed())
            .unwrap_or(false)
    }

    /// Get the configuration.
    pub fn config(&self) -> &HotReloadConfig {
        &self.config
    }
}

/// Trait for kernels that support hot reload.
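///
/// # Example
///
/// A minimal sketch; `MyKernel`, its fields, and `backend.load_module` are
/// hypothetical names for illustration of the swap lifecycle:
///
/// ```ignore
/// impl HotReloadableKernel for MyKernel {
///     fn prepare_for_reload(&mut self) -> Result<()> {
///         self.paused = true; // stop pulling from the input queue
///         Ok(())
///     }
///
///     fn apply_code(&mut self, code: &KernelCodeSource) -> Result<()> {
///         self.module = self.backend.load_module(&code.code)?;
///         Ok(())
///     }
///
///     fn resume_after_reload(&mut self) -> Result<()> {
///         self.paused = false;
///         Ok(())
///     }
///
///     fn is_ready_for_reload(&self) -> bool {
///         self.paused
///     }
/// }
/// ```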
pub trait HotReloadableKernel: CheckpointableKernel {
    /// Prepare kernel for code swap (drain messages, pause processing).
    fn prepare_for_reload(&mut self) -> Result<()>;

    /// Apply new code to the kernel.
    fn apply_code(&mut self, code: &KernelCodeSource) -> Result<()>;

    /// Resume processing after reload.
    fn resume_after_reload(&mut self) -> Result<()>;

    /// Check if kernel is ready for reload.
    fn is_ready_for_reload(&self) -> bool;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_info() {
        let info = DeviceInfo::new(0, "Test GPU".to_string(), Backend::Cuda);
        assert_eq!(info.index, 0);
        assert_eq!(info.name, "Test GPU");
        assert_eq!(info.memory_utilization(), 0.0);
    }

    #[test]
    fn test_coordinator_registration() {
        let coord = MultiGpuBuilder::new().build();

        let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        coord.register_device(device);

        assert_eq!(coord.device_count(), 1);
        assert!(coord.device(0).is_some());
    }

    #[test]
    fn test_kernel_assignment() {
        let coord = MultiGpuBuilder::new().build();

        let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        coord.register_device(device);

        let kernel_id = KernelId::new("test_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));
        assert_eq!(coord.kernels_on_device(0).len(), 1);
    }

    #[test]
    fn test_load_balancing_least_loaded() {
        let coord = MultiGpuBuilder::new()
            .load_balancing(LoadBalancingStrategy::LeastLoaded)
            .build();

        // Register two devices
        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        // Assign a kernel to device 0
        coord.assign_kernel(KernelId::new("k1"), 0);

        // Next kernel should go to device 1 (least loaded)
        let selected = coord.select_device(&LaunchOptions::default()).unwrap();
        assert_eq!(selected, 1);
    }

    #[test]
    fn test_round_robin() {
        let coord = MultiGpuBuilder::new()
            .load_balancing(LoadBalancingStrategy::RoundRobin)
            .build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let d1 = coord.select_device(&LaunchOptions::default()).unwrap();
        let d2 = coord.select_device(&LaunchOptions::default()).unwrap();
        let d3 = coord.select_device(&LaunchOptions::default()).unwrap();

        // Should cycle through devices
        assert_ne!(d1, d2);
        assert_eq!(d1, d3);
    }

    // ========================================================================
    // Topology Tests
    // ========================================================================

    #[test]
    fn test_interconnect_bandwidth() {
        assert!(
            InterconnectType::NvLink.estimated_bandwidth_gbps()
                > InterconnectType::Pcie.estimated_bandwidth_gbps()
        );
        assert!(
            InterconnectType::Pcie.estimated_bandwidth_gbps()
                > InterconnectType::None.estimated_bandwidth_gbps()
        );
        assert!(
            InterconnectType::SameDevice.estimated_bandwidth_gbps()
                > InterconnectType::NvLink.estimated_bandwidth_gbps()
        );
    }

    #[test]
    fn test_interconnect_p2p_support() {
        assert!(!InterconnectType::None.supports_p2p());
        assert!(InterconnectType::Pcie.supports_p2p());
        assert!(InterconnectType::NvLink.supports_p2p());
        assert!(InterconnectType::NvSwitch.supports_p2p());
    }

    #[test]
    fn test_gpu_topology_creation() {
        let topo = GpuTopology::new(4);
        assert_eq!(topo.device_count, 4);

        // Self-connections should exist
        for i in 0..4 {
            let conn = topo.get_connection(i, i);
            assert!(conn.is_some());
            assert_eq!(conn.unwrap().interconnect, InterconnectType::SameDevice);
        }
    }

    #[test]
    fn test_gpu_topology_set_connection() {
        let mut topo = GpuTopology::new(4);

        // Set NVLink between GPU 0 and 1
        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));

        let conn_01 = topo.get_connection(0, 1);
        assert!(conn_01.is_some());
        assert_eq!(conn_01.unwrap().interconnect, InterconnectType::NvLink);

        // Bidirectional by default
        let conn_10 = topo.get_connection(1, 0);
        assert!(conn_10.is_some());
        assert_eq!(conn_10.unwrap().interconnect, InterconnectType::NvLink);
    }

    #[test]
    fn test_gpu_topology_neighbors() {
        let mut topo = GpuTopology::new(4);

        // Ring topology: 0-1-2-3-0
        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(3, 0, InterconnectType::NvLink));

        let neighbors_0 = topo.neighbors(0);
        assert_eq!(neighbors_0.len(), 2);
        assert!(neighbors_0.contains(&1));
        assert!(neighbors_0.contains(&3));
    }

    #[test]
    fn test_gpu_topology_best_path() {
        let mut topo = GpuTopology::new(4);

        // Create connections: 0-1, 1-2, 2-3 (no direct 0-3)
        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(0, 3, InterconnectType::None)); // No direct P2P

        // Direct path should work for adjacent nodes
        let path_01 = topo.best_path(0, 1);
        assert_eq!(path_01, vec![0, 1]);

        // Same device
        let path_00 = topo.best_path(0, 0);
        assert_eq!(path_00, vec![0]);
    }

    #[test]
    fn test_gpu_topology_fully_connected() {
        let mut topo = GpuTopology::new(3);

        // Not fully connected initially
        assert!(!topo.is_fully_connected());

        // Make fully connected mesh
        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(0, 2, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));

        assert!(topo.is_fully_connected());
    }

    #[test]
    fn test_gpu_topology_numa() {
        let mut topo = GpuTopology::new(4);

        // GPUs 0,1 on NUMA 0; GPUs 2,3 on NUMA 1
        topo.set_numa_node(0, 0);
        topo.set_numa_node(1, 0);
        topo.set_numa_node(2, 1);
        topo.set_numa_node(3, 1);

        let numa_neighbors_0 = topo.numa_neighbors(0);
        assert_eq!(numa_neighbors_0, vec![1]);

        let numa_neighbors_2 = topo.numa_neighbors(2);
        assert_eq!(numa_neighbors_2, vec![3]);
    }

    // ========================================================================
    // Topology Discovery Tests
    // ========================================================================

    #[test]
    fn test_coordinator_topology_discovery() {
        let coord = MultiGpuBuilder::new().enable_p2p(true).build();

        // Register P2P capable devices
        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        dev0.p2p_capable = true;
        dev0.compute_capability = Some((8, 0)); // Ampere

        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
        dev1.p2p_capable = true;
        dev1.compute_capability = Some((8, 6)); // Ampere

        coord.register_device(dev0);
        coord.register_device(dev1);

        let topo = coord.discover_topology();

        assert_eq!(topo.device_count, 2);

        // Should detect NVLink for Ampere GPUs
        let conn = topo.get_connection(0, 1);
        assert!(conn.is_some());
        assert_eq!(conn.unwrap().interconnect, InterconnectType::NvLink);
    }

    // ========================================================================
    // Migration Tests
    // ========================================================================

    #[test]
    fn test_migration_request() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("migrating_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let request = coord.request_migration(&kernel_id, 1).unwrap();

        assert_eq!(request.source_device, 0);
        assert_eq!(request.target_device, 1);
        assert_eq!(request.state, MigrationState::Pending);
    }

    #[test]
    fn test_migration_same_device_error() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let result = coord.request_migration(&kernel_id, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_migration_complete() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("migrating_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));

        let request = coord.request_migration(&kernel_id, 1).unwrap();
        coord.complete_migration(&request).unwrap();

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));
    }

    #[test]
    fn test_migration_transfer_time_estimate() {
        let request = MigrationRequest {
            kernel_id: KernelId::new("test"),
            source_device: 0,
            target_device: 1,
            path: vec![0, 1],
            estimated_bandwidth_gbps: 300.0, // NVLink
            estimated_latency_us: 1.0,
            state: MigrationState::Pending,
            started_at: None,
        };

        // 1GB transfer at 300GB/s = ~3.3ms + 1us latency
        let time = request.estimate_transfer_time(1_000_000_000);
        assert!(time.as_micros() > 3000);
        assert!(time.as_micros() < 4000);
    }

    // ========================================================================
    // Cross-GPU K2K Router Tests
    // ========================================================================

    use crate::hlc::HlcTimestamp;
    use crate::message::MessageEnvelope;

    fn make_test_k2k_message(source: &KernelId, dest: &KernelId) -> K2KMessage {
        let timestamp = HlcTimestamp::now(42);
        let envelope = MessageEnvelope::empty(1, 2, timestamp);
        K2KMessage::new(source.clone(), dest.clone(), envelope, timestamp)
    }

    #[test]
    fn test_router_same_device() {
        let coord = MultiGpuBuilder::new().build();
        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let router = CrossGpuK2KRouter::new(coord);

        let msg = make_test_k2k_message(&k1, &k2);
        let decision = router.route_message(&k1, &k2, msg).unwrap();

        assert!(matches!(decision, RoutingDecision::SameDevice));
    }

    #[test]
    fn test_router_cross_device() {
        let coord = MultiGpuBuilder::new().enable_p2p(true).build();

        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        dev0.p2p_capable = true;
        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
        dev1.p2p_capable = true;

        coord.register_device(dev0);
        coord.register_device(dev1);

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 1);

        let router = CrossGpuK2KRouter::new(coord);

        let msg = make_test_k2k_message(&k1, &k2);
        let decision = router.route_message(&k1, &k2, msg).unwrap();

        match decision {
            RoutingDecision::DirectP2P {
                source_device,
                dest_device,
                ..
            } => {
                assert_eq!(source_device, 0);
                assert_eq!(dest_device, 1);
            }
            _ => panic!("Expected DirectP2P routing"),
        }
    }

    #[test]
    fn test_router_pending_messages() {
        let coord = MultiGpuBuilder::new().enable_p2p(true).build();

        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        dev0.p2p_capable = true;
        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
        dev1.p2p_capable = true;

        coord.register_device(dev0);
        coord.register_device(dev1);

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 1);

        let router = CrossGpuK2KRouter::new(coord);

        // Route 3 messages
        for _ in 0..3 {
            let msg = make_test_k2k_message(&k1, &k2);
            router.route_message(&k1, &k2, msg).unwrap();
        }

        assert_eq!(router.stats().messages_pending, 3);

        // Drain pending
        let pending = router.drain_pending(0, 1);
        assert_eq!(pending.len(), 3);
        assert_eq!(router.stats().messages_pending, 0);
    }

    #[test]
    fn test_router_stats() {
        let coord = MultiGpuBuilder::new().build();
        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let router = CrossGpuK2KRouter::new(coord);

        let stats = router.stats();
        assert_eq!(stats.messages_routed, 0);
        assert_eq!(stats.bytes_transferred, 0);
        assert_eq!(stats.routing_failures, 0);
    }

    // ========================================================================
    // Kernel Migrator Tests
    // ========================================================================

    use crate::checkpoint::{Checkpoint, CheckpointBuilder};

    /// Mock checkpointable kernel for testing.
    struct MockCheckpointableKernel {
        kernel_id: String,
        kernel_type: String,
        state_data: Vec<u8>,
        step: u64,
    }

    impl MockCheckpointableKernel {
        fn new(kernel_id: &str, state_size: usize) -> Self {
            Self {
                kernel_id: kernel_id.to_string(),
                kernel_type: "mock_kernel".to_string(),
                state_data: vec![0xAB; state_size],
                step: 1000,
            }
        }
    }

    impl CheckpointableKernel for MockCheckpointableKernel {
        fn create_checkpoint(&self) -> Result<Checkpoint> {
            let checkpoint = CheckpointBuilder::new(&self.kernel_id, &self.kernel_type)
                .step(self.step)
                .grid_size(64, 64, 64)
                .control_block(vec![1, 2, 3, 4])
                .device_memory("state", self.state_data.clone())
                .build();
            Ok(checkpoint)
        }

        fn restore_from_checkpoint(&mut self, checkpoint: &Checkpoint) -> Result<()> {
            self.step = checkpoint.metadata.current_step;
            Ok(())
        }

        fn checkpoint_kernel_id(&self) -> &str {
            &self.kernel_id
        }

        fn checkpoint_kernel_type(&self) -> &str {
            &self.kernel_type
        }
    }

    #[test]
    fn test_migrator_creation() {
        let coord = MultiGpuBuilder::new().build();
        let migrator = KernelMigrator::new(coord);

        let stats = migrator.stats();
        assert_eq!(stats.successful_migrations, 0);
        assert_eq!(stats.failed_migrations, 0);
        assert_eq!(stats.bytes_transferred, 0);
    }

    #[test]
    fn test_migrator_with_custom_storage() {
        let coord = MultiGpuBuilder::new().build();
        let storage = Arc::new(MemoryStorage::new());
        let migrator = KernelMigrator::with_storage(coord.clone(), storage);

        // Verify we can access the coordinator
        assert!(Arc::ptr_eq(migrator.coordinator(), &coord));
    }

    #[test]
    fn test_successful_migration() {
        let coord = MultiGpuBuilder::new().build();

        // Register devices
        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        // Assign kernel to device 0
        let kernel_id = KernelId::new("migratable_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let migrator = KernelMigrator::new(coord.clone());

        // Create mock kernel
        let kernel = MockCheckpointableKernel::new("migratable_kernel", 1024);

        // Request migration
        let mut request = coord.request_migration(&kernel_id, 1).unwrap();
        assert_eq!(request.state, MigrationState::Pending);

        // Perform migration
        let result = migrator
            .migrate_with_checkpoint(&kernel, &mut request)
            .unwrap();

        // Verify result
        assert_eq!(result.kernel_id.as_str(), "migratable_kernel");
        assert_eq!(result.source_device, 0);
        assert_eq!(result.target_device, 1);
        assert!(result.checkpoint_size > 0);
        assert!(result.total_duration > Duration::ZERO);

        // Verify kernel was moved
        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));

        // Verify stats
        let stats = migrator.stats();
        assert_eq!(stats.successful_migrations, 1);
        assert_eq!(stats.failed_migrations, 0);
        assert!(stats.bytes_transferred > 0);
    }

    #[test]
    fn test_migration_result_fields() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("test_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let migrator = KernelMigrator::new(coord.clone());
        let kernel = MockCheckpointableKernel::new("test_kernel", 4096);
        let mut request = coord.request_migration(&kernel_id, 1).unwrap();

        let result = migrator
            .migrate_with_checkpoint(&kernel, &mut request)
            .unwrap();

        // Durations are unsigned; sanity-check they were populated
        assert!(result.checkpoint_duration >= Duration::ZERO);
        assert!(result.transfer_duration >= Duration::ZERO);
        assert!(result.restore_duration >= Duration::ZERO);

        // Total should be >= sum of the measured phases
        assert!(
            result.total_duration
                >= result.checkpoint_duration + result.transfer_duration + result.restore_duration
        );
3013    }
3014
3015    #[test]
3016    fn test_migration_stats_accumulate() {
3017        let coord = MultiGpuBuilder::new().build();
3018
3019        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
3020        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
3021
3022        let migrator = KernelMigrator::new(coord.clone());
3023
3024        // Migrate kernel 1: 0 -> 1
3025        let k1 = KernelId::new("k1");
3026        coord.assign_kernel(k1.clone(), 0);
3027        let kernel1 = MockCheckpointableKernel::new("k1", 1000);
3028        let mut req1 = coord.request_migration(&k1, 1).unwrap();
3029        migrator
3030            .migrate_with_checkpoint(&kernel1, &mut req1)
3031            .unwrap();
3032
3033        // Migrate kernel 2: 0 -> 1
3034        let k2 = KernelId::new("k2");
3035        coord.assign_kernel(k2.clone(), 0);
3036        let kernel2 = MockCheckpointableKernel::new("k2", 2000);
3037        let mut req2 = coord.request_migration(&k2, 1).unwrap();
3038        migrator
3039            .migrate_with_checkpoint(&kernel2, &mut req2)
3040            .unwrap();
3041
3042        let stats = migrator.stats();
3043        assert_eq!(stats.successful_migrations, 2);
3044        assert_eq!(stats.failed_migrations, 0);
3045        // Both checkpoints should have been transferred
3046        assert!(stats.bytes_transferred > 0);
3047    }
3048
3049    // ========================================================================
3050    // Device Unregister Tests
3051    // ========================================================================
3052
    #[test]
    fn test_unregister_device_no_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.device_index, 0);
        assert!(result.kernels_to_migrate.is_empty());
        assert!(result.orphaned_kernels.is_empty());
    }

    #[test]
    fn test_unregister_device_with_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        // Assign kernels to device 0
        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.kernels_to_migrate.len(), 2);
        assert!(result.orphaned_kernels.is_empty());

        // All kernels should migrate to device 1
        for plan in &result.kernels_to_migrate {
            assert_eq!(plan.source_device, 0);
            assert_eq!(plan.target_device, 1);
        }

        // Verify kernel mappings were updated
        assert_eq!(coord.get_kernel_device(&k1), Some(1));
        assert_eq!(coord.get_kernel_device(&k2), Some(1));
    }

    #[test]
    fn test_unregister_single_device_orphans_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        // Assign a kernel to device 0
        let k1 = KernelId::new("k1");
        coord.assign_kernel(k1.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert!(result.kernels_to_migrate.is_empty());
        assert_eq!(result.orphaned_kernels.len(), 1);
        assert_eq!(result.orphaned_kernels[0], k1);

        // Kernel should no longer have a device
        assert!(coord.get_kernel_device(&k1).is_none());
    }

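    // Recovery sketch: an orphaned kernel is assumed to be re-assignable once
    // a device is available again; only APIs exercised above are used.
    #[test]
    fn test_orphaned_kernel_can_be_reassigned() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        let k1 = KernelId::new("k1");
        coord.assign_kernel(k1.clone(), 0);

        coord.unregister_device(0);
        assert!(coord.get_kernel_device(&k1).is_none());

        // Bring a device back and re-home the orphan.
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
        coord.assign_kernel(k1.clone(), 1);
        assert_eq!(coord.get_kernel_device(&k1), Some(1));
    }
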
    #[test]
    fn test_unregister_nonexistent_device() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let result = coord.unregister_device(99);

        assert!(!result.success);
        assert_eq!(result.device_index, 99);
    }

    #[test]
    fn test_unregister_distributes_to_least_loaded() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(2, "GPU 2".to_string(), Backend::Cuda));

        // Preload device 1 with kernels
        coord.assign_kernel(KernelId::new("pre1"), 1);
        coord.assign_kernel(KernelId::new("pre2"), 1);
        coord.assign_kernel(KernelId::new("pre3"), 1);

        // Assign kernel to device 0
        let k1 = KernelId::new("migrate_me");
        coord.assign_kernel(k1.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.kernels_to_migrate.len(), 1);

        // Should migrate to device 2 (least loaded)
        let plan = &result.kernels_to_migrate[0];
        assert_eq!(plan.target_device, 2);
    }

    #[test]
    fn test_migration_priority_enum() {
        let low = MigrationPriority::Low;
        let normal = MigrationPriority::Normal;
        let high = MigrationPriority::High;
        let critical = MigrationPriority::Critical;

        assert_ne!(low, normal);
        assert_ne!(normal, high);
        assert_ne!(high, critical);
        assert_eq!(low, MigrationPriority::Low);
    }

    // ========================================================================
    // Hot Reload Tests
    // ========================================================================

    #[test]
    fn test_hot_reload_config_default() {
        let config = HotReloadConfig::default();
        assert!(config.enabled);
        assert!(config.preserve_state);
        assert!(config.validate_before_swap);
        assert!(config.keep_fallback);
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_hot_reload_config_builder() {
        let config = HotReloadConfig::new()
            .with_enabled(false)
            .with_preserve_state(false)
            .with_max_retries(5)
            .with_timeout(Duration::from_secs(60));

        assert!(!config.enabled);
        assert!(!config.preserve_state);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.reload_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_kernel_code_source_ptx() {
        let ptx = ".version 7.0\n.target sm_80\nkernel: ret;";
        let code = KernelCodeSource::from_ptx(ptx, "kernel");

        assert_eq!(code.format, KernelCodeFormat::Ptx);
        assert_eq!(code.entry_point, "kernel");
        assert_eq!(code.as_str(), Some(ptx));
        assert_eq!(code.size(), ptx.len());
    }

    #[test]
    fn test_kernel_code_source_wgsl() {
        let wgsl = "@compute fn main() {}";
        let code = KernelCodeSource::from_wgsl(wgsl, "main");

        assert_eq!(code.format, KernelCodeFormat::Wgsl);
        assert_eq!(code.entry_point, "main");
        assert_eq!(code.as_str(), Some(wgsl));
    }

    #[test]
    fn test_kernel_code_source_msl() {
        let msl = "kernel void my_kernel() {}";
        let code = KernelCodeSource::from_msl(msl, "my_kernel");

        assert_eq!(code.format, KernelCodeFormat::Msl);
        assert_eq!(code.entry_point, "my_kernel");
        assert_eq!(code.as_str(), Some(msl));
    }

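    // Consistency sketch across the three text-based constructors exercised
    // above; no API beyond `as_str` and `size` is assumed.
    #[test]
    fn test_kernel_code_source_text_formats_expose_source() {
        let sources = [
            KernelCodeSource::from_ptx(".version 7.0", "k"),
            KernelCodeSource::from_wgsl("@compute fn k() {}", "k"),
            KernelCodeSource::from_msl("kernel void k() {}", "k"),
        ];

        for code in &sources {
            // Text formats should expose their body, and `size` should match.
            let body = code.as_str().expect("text formats expose source");
            assert_eq!(code.size(), body.len());
        }
    }
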
    #[test]
    fn test_hot_reload_manager_creation() {
        let manager = HotReloadManager::with_defaults();
        assert!(manager.is_enabled());
        assert!(manager.list_kernels().is_empty());
    }

    #[test]
    fn test_hot_reload_manager_register_kernel() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        manager.register_kernel(&kernel_id, code);

        assert!(manager.is_registered(&kernel_id));
        assert!(!manager.is_reload_in_progress(&kernel_id));
        assert!(manager.get_current_version(&kernel_id).is_some());
    }

    #[test]
    fn test_hot_reload_request_states() {
        let kernel_id = KernelId::new("test");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");
        let request = HotReloadRequest::new(kernel_id, code);

        assert_eq!(request.state, HotReloadState::Idle);
        assert!(!request.is_in_progress());
        assert!(!request.is_completed());
        assert!(!request.is_failed());
    }

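    // Predicate sketch: assumes the public `state` field (read directly
    // above) is also writable in tests, and that the predicate methods
    // derive their answers from it.
    #[test]
    fn test_hot_reload_request_predicates_track_state() {
        let kernel_id = KernelId::new("test");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");
        let mut request = HotReloadRequest::new(kernel_id, code);

        request.state = HotReloadState::Completed;
        assert!(request.is_completed());
        assert!(!request.is_in_progress());

        request.state = HotReloadState::Failed;
        assert!(request.is_failed());
    }
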
    #[test]
    fn test_hot_reload_disabled() {
        let config = HotReloadConfig::new().with_enabled(false);
        let manager = HotReloadManager::new(config);
        let kernel_id = KernelId::new("test");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        manager.register_kernel(&kernel_id, code.clone());
        let result = manager.request_reload(&kernel_id, code);
        assert!(result.is_err());
    }

    #[test]
    fn test_hot_reload_stats() {
        let manager = HotReloadManager::with_defaults();
        let stats = manager.stats();

        assert_eq!(stats.successful_reloads, 0);
        assert_eq!(stats.failed_reloads, 0);
        assert_eq!(stats.rollbacks, 0);
    }

    #[test]
    fn test_hot_reload_code_formats() {
        let formats = [
            KernelCodeFormat::Ptx,
            KernelCodeFormat::Cubin,
            KernelCodeFormat::SpirV,
            KernelCodeFormat::Wgsl,
            KernelCodeFormat::Msl,
            KernelCodeFormat::MetalLib,
            KernelCodeFormat::Source,
        ];

        // Verify all formats are distinct
        for (i, f1) in formats.iter().enumerate() {
            for (j, f2) in formats.iter().enumerate() {
                if i != j {
                    assert_ne!(f1, f2);
                }
            }
        }
    }

    #[test]
    fn test_hot_reload_states_distinct() {
        let states = [
            HotReloadState::Idle,
            HotReloadState::Draining,
            HotReloadState::Checkpointing,
            HotReloadState::Compiling,
            HotReloadState::Validating,
            HotReloadState::Swapping,
            HotReloadState::Restoring,
            HotReloadState::Completed,
            HotReloadState::Failed,
            HotReloadState::RollingBack,
        ];

        // Verify all states are distinct (this checks identity, not the
        // legality of transitions between them)
        for (i, s1) in states.iter().enumerate() {
            for (j, s2) in states.iter().enumerate() {
                if i != j {
                    assert_ne!(s1, s2);
                }
            }
        }
    }

    #[test]
    fn test_hot_reload_execute() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");

        let initial_code = KernelCodeSource::from_ptx(".version 7.0\n.target sm_80", "kernel");
        manager.register_kernel(&kernel_id, initial_code);

        let new_code = KernelCodeSource::from_ptx(".version 8.0\n.target sm_90", "kernel");
        let mut request = manager.request_reload(&kernel_id, new_code).unwrap();

        // Create mock kernel for checkpoint
        let mock_kernel = MockCheckpointableKernel::new("test_kernel", 512);

        let result = manager.execute_reload(&mut request, &mock_kernel).unwrap();

        assert!(request.is_completed());
        assert_eq!(result.kernel_id.as_str(), "test_kernel");
        assert!(result.state_preserved);
        assert!(result.checkpoint_size > 0);
        assert!(result.total_duration > Duration::ZERO);

        // Stats should be updated
        let stats = manager.stats();
        assert_eq!(stats.successful_reloads, 1);
    }

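    // Error-path sketch: requesting a reload for a kernel id that was never
    // registered is assumed to be rejected, mirroring the disabled-manager
    // case above.
    #[test]
    fn test_hot_reload_unregistered_kernel_fails() {
        let manager = HotReloadManager::with_defaults();
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        let result = manager.request_reload(&KernelId::new("ghost"), code);
        assert!(result.is_err());
    }
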
    #[test]
    fn test_hot_reload_list_kernels() {
        let manager = HotReloadManager::with_defaults();

        let k1 = KernelId::new("kernel1");
        let k2 = KernelId::new("kernel2");
        let k3 = KernelId::new("kernel3");

        manager.register_kernel(&k1, KernelCodeSource::from_ptx(".version 7.0", "k1"));
        manager.register_kernel(&k2, KernelCodeSource::from_ptx(".version 7.0", "k2"));
        manager.register_kernel(&k3, KernelCodeSource::from_ptx(".version 7.0", "k3"));

        let kernels = manager.list_kernels();
        assert_eq!(kernels.len(), 3);
        assert!(kernels.contains(&k1));
        assert!(kernels.contains(&k2));
        assert!(kernels.contains(&k3));
    }
}