//! Multi-GPU coordination, topology discovery, and cross-GPU messaging.
//!
//! This module provides infrastructure for coordinating work across
//! multiple GPUs, including:
//!
//! - **Device Selection** - Load balancing strategies for kernel placement
//! - **Topology Discovery** - NVLink/PCIe detection and bandwidth estimation
//! - **Cross-GPU K2K Router** - Kernel-to-kernel messaging across GPUs
//! - **Kernel Migration** - Move kernels between GPUs with state transfer
//!
//! ## Example
//!
//! ```ignore
//! use ringkernel_core::multi_gpu::{
//!     CrossGpuK2KRouter, GpuTopology, LoadBalancingStrategy, MultiGpuBuilder,
//! };
//!
//! let coordinator = MultiGpuBuilder::new()
//!     .load_balancing(LoadBalancingStrategy::LeastLoaded)
//!     .enable_p2p(true)
//!     .build();
//!
//! // Discover topology
//! let topology = coordinator.discover_topology();
//!
//! // Create a cross-GPU router and route a message between kernels.
//! let router = CrossGpuK2KRouter::new(coordinator.clone());
//! let decision = router.route_message(&source_kernel, &dest_kernel, message)?;
//! ```

use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use crate::error::{Result, RingKernelError};
use crate::k2k::K2KMessage;
use crate::runtime::{Backend, KernelId, LaunchOptions};

/// Configuration for multi-GPU coordination.
#[derive(Debug, Clone)]
pub struct MultiGpuConfig {
    /// Load balancing strategy.
    pub load_balancing: LoadBalancingStrategy,
    /// Enable automatic device selection.
    pub auto_select_device: bool,
    /// Maximum kernels per device.
    pub max_kernels_per_device: usize,
    /// Enable peer-to-peer transfers when available.
    pub enable_p2p: bool,
    /// Preferred devices (by index).
    pub preferred_devices: Vec<usize>,
}

impl Default for MultiGpuConfig {
    fn default() -> Self {
        Self {
            load_balancing: LoadBalancingStrategy::LeastLoaded,
            auto_select_device: true,
            max_kernels_per_device: 64,
            enable_p2p: true,
            preferred_devices: vec![],
        }
    }
}

/// Strategy for balancing load across devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Always use the first available device.
    FirstAvailable,
    /// Use the device with fewest kernels.
    LeastLoaded,
    /// Round-robin across devices.
    RoundRobin,
    /// Select based on memory availability.
    MemoryBased,
    /// Select based on compute capability.
    ComputeCapability,
    /// Custom selection function.
    Custom,
}

/// Information about a GPU device.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Device index.
    pub index: usize,
    /// Device name.
    pub name: String,
    /// Backend type.
    pub backend: Backend,
    /// Total memory in bytes.
    pub total_memory: u64,
    /// Available memory in bytes.
    pub available_memory: u64,
    /// Compute capability (for CUDA).
    pub compute_capability: Option<(u32, u32)>,
    /// Maximum threads per block.
    pub max_threads_per_block: u32,
    /// Number of multiprocessors.
    pub multiprocessor_count: u32,
    /// Whether device supports P2P with other devices.
    pub p2p_capable: bool,
}

impl DeviceInfo {
    /// Create a new device info.
    pub fn new(index: usize, name: String, backend: Backend) -> Self {
        Self {
            index,
            name,
            backend,
            total_memory: 0,
            available_memory: 0,
            compute_capability: None,
            max_threads_per_block: 1024,
            multiprocessor_count: 1,
            p2p_capable: false,
        }
    }

    /// Get memory utilization (0.0-1.0).
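    ///
    /// # Example
    ///
    /// A quick numeric check (illustrative values):
    ///
    /// ```ignore
    /// let mut info = DeviceInfo::new(0, "GPU".to_string(), Backend::Cuda);
    /// info.total_memory = 4;
    /// info.available_memory = 3;
    /// assert_eq!(info.memory_utilization(), 0.25);
    /// ```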
    pub fn memory_utilization(&self) -> f64 {
        if self.total_memory == 0 {
            0.0
        } else {
            1.0 - (self.available_memory as f64 / self.total_memory as f64)
        }
    }
}

// ============================================================================
// GPU Topology Discovery
// ============================================================================

/// Type of interconnect between GPUs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterconnectType {
    /// No direct connection (must go through host).
    None,
    /// PCIe peer-to-peer.
    Pcie,
    /// NVIDIA NVLink.
    NvLink,
    /// NVIDIA NVSwitch (datacenter).
    NvSwitch,
    /// AMD Infinity Fabric.
    InfinityFabric,
    /// Intel Xe Link.
    XeLink,
    /// Same GPU (for self-connections).
    SameDevice,
}

impl InterconnectType {
    /// Estimated bandwidth in GB/s for this interconnect type.
    pub fn estimated_bandwidth_gbps(&self) -> f64 {
        match self {
            InterconnectType::None => 16.0,      // PCIe 3.0 x16 through host
            InterconnectType::Pcie => 32.0,      // PCIe 4.0 x16 P2P
            InterconnectType::NvLink => 300.0,   // NVLink 3.0 (A100)
            InterconnectType::NvSwitch => 600.0, // NVSwitch full bisection
            InterconnectType::InfinityFabric => 200.0, // MI250X
            InterconnectType::XeLink => 100.0,   // Intel Data Center GPUs
            InterconnectType::SameDevice => 2000.0, // Internal bandwidth
        }
    }

    /// Estimated latency in microseconds.
    pub fn estimated_latency_us(&self) -> f64 {
        match self {
            InterconnectType::None => 10.0,    // Through host memory
            InterconnectType::Pcie => 5.0,     // P2P PCIe
            InterconnectType::NvLink => 1.0,   // Direct NVLink
            InterconnectType::NvSwitch => 2.0, // Through switch
            InterconnectType::InfinityFabric => 1.5,
            InterconnectType::XeLink => 2.0,
            InterconnectType::SameDevice => 0.0,
        }
    }

    /// Whether this interconnect supports direct P2P memory access.
    pub fn supports_p2p(&self) -> bool {
        !matches!(self, InterconnectType::None)
    }
}

/// Connection between two GPUs.
#[derive(Debug, Clone)]
pub struct GpuConnection {
    /// Source device index.
    pub source: usize,
    /// Destination device index.
    pub destination: usize,
    /// Type of interconnect.
    pub interconnect: InterconnectType,
    /// Measured or estimated bandwidth in GB/s.
    pub bandwidth_gbps: f64,
    /// Measured or estimated latency in microseconds.
    pub latency_us: f64,
    /// Whether connection is bidirectional with same characteristics.
    pub bidirectional: bool,
    /// Number of hops (for multi-hop topologies).
    pub hops: u32,
}

impl GpuConnection {
    /// Create a new GPU connection.
    pub fn new(source: usize, destination: usize, interconnect: InterconnectType) -> Self {
        Self {
            source,
            destination,
            interconnect,
            bandwidth_gbps: interconnect.estimated_bandwidth_gbps(),
            latency_us: interconnect.estimated_latency_us(),
            bidirectional: true,
            hops: if source == destination { 0 } else { 1 },
        }
    }

    /// Set measured bandwidth.
    pub fn with_bandwidth(mut self, gbps: f64) -> Self {
        self.bandwidth_gbps = gbps;
        self
    }

    /// Set measured latency.
    pub fn with_latency(mut self, us: f64) -> Self {
        self.latency_us = us;
        self
    }

    /// Set hop count.
    pub fn with_hops(mut self, hops: u32) -> Self {
        self.hops = hops;
        self
    }
}

/// GPU topology graph describing all device interconnections.
#[derive(Debug, Clone)]
pub struct GpuTopology {
    /// Number of devices in topology.
    pub device_count: usize,
    /// Connection matrix (device_count x device_count).
    connections: Vec<Vec<Option<GpuConnection>>>,
    /// NUMA node assignments for each device.
    pub numa_nodes: Vec<Option<u32>>,
    /// Whether topology has been probed (vs estimated).
    pub probed: bool,
    /// Timestamp of last topology update.
    pub last_updated: Instant,
}

impl GpuTopology {
    /// Create a new topology for N devices.
    pub fn new(device_count: usize) -> Self {
        let mut connections = vec![vec![None; device_count]; device_count];

        // Initialize self-connections
        for (i, row) in connections.iter_mut().enumerate().take(device_count) {
            row[i] = Some(GpuConnection::new(i, i, InterconnectType::SameDevice));
        }

        Self {
            device_count,
            connections,
            numa_nodes: vec![None; device_count],
            probed: false,
            last_updated: Instant::now(),
        }
    }

    /// Set connection between two devices.
    pub fn set_connection(&mut self, connection: GpuConnection) {
        let src = connection.source;
        let dst = connection.destination;
        if src < self.device_count && dst < self.device_count {
            self.connections[src][dst] = Some(connection.clone());
            if connection.bidirectional && src != dst {
                let reverse = GpuConnection {
                    source: dst,
                    destination: src,
                    ..connection
                };
                self.connections[dst][src] = Some(reverse);
            }
        }
    }

    /// Get connection between two devices.
    pub fn get_connection(&self, source: usize, destination: usize) -> Option<&GpuConnection> {
        self.connections
            .get(source)
            .and_then(|row| row.get(destination))
            .and_then(|c| c.as_ref())
    }

    /// Get the best path between two devices (returns the full device sequence, endpoints included).
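    ///
    /// # Example
    ///
    /// A sketch of path selection over a hand-built topology (illustrative,
    /// using the default bandwidth estimates):
    ///
    /// ```ignore
    /// let mut topo = GpuTopology::new(3);
    /// // GPUs 0-1 and 1-2 share NVLink; 0-2 has no direct P2P link.
    /// topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
    /// topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
    /// topo.set_connection(GpuConnection::new(0, 2, InterconnectType::None));
    /// // The two-hop NVLink relay beats the host fallback.
    /// assert_eq!(topo.best_path(0, 2), vec![0, 1, 2]);
    /// ```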
    pub fn best_path(&self, source: usize, destination: usize) -> Vec<usize> {
        if source == destination {
            return vec![source];
        }

        // Direct connection available?
        if let Some(conn) = self.get_connection(source, destination) {
            if conn.interconnect != InterconnectType::None {
                return vec![source, destination];
            }
        }

        // Widest-path search over single-intermediate relays (a one-hop
        // approximation rather than full Dijkstra).
        let mut best_path = vec![source, destination]; // Default to direct
        let mut best_bandwidth = 0.0;

        // Check all intermediate nodes
        for intermediate in 0..self.device_count {
            if intermediate == source || intermediate == destination {
                continue;
            }

            if let (Some(c1), Some(c2)) = (
                self.get_connection(source, intermediate),
                self.get_connection(intermediate, destination),
            ) {
                // Bandwidth limited by slowest link
                let path_bandwidth = c1.bandwidth_gbps.min(c2.bandwidth_gbps);
                if path_bandwidth > best_bandwidth {
                    best_bandwidth = path_bandwidth;
                    best_path = vec![source, intermediate, destination];
                }
            }
        }

        best_path
    }

    /// Get all devices directly connected to a device.
    pub fn neighbors(&self, device: usize) -> Vec<usize> {
        if device >= self.device_count {
            return vec![];
        }

        self.connections[device]
            .iter()
            .enumerate()
            .filter_map(|(i, conn)| {
                if i != device
                    && conn
                        .as_ref()
                        .map(|c| c.interconnect.supports_p2p())
                        .unwrap_or(false)
                {
                    Some(i)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Estimate the bisection bandwidth (GB/s) across a fixed front-half/back-half split of the topology.
    pub fn bisection_bandwidth_gbps(&self) -> f64 {
        let half = self.device_count / 2;
        if half == 0 {
            return 0.0;
        }

        let mut total = 0.0;
        for src in 0..half {
            for dst in half..self.device_count {
                if let Some(conn) = self.get_connection(src, dst) {
                    total += conn.bandwidth_gbps;
                }
            }
        }
        total
    }

    /// Check if all devices have P2P connectivity.
    pub fn is_fully_connected(&self) -> bool {
        for src in 0..self.device_count {
            for dst in 0..self.device_count {
                if src != dst {
                    if let Some(conn) = self.get_connection(src, dst) {
                        if !conn.interconnect.supports_p2p() {
                            return false;
                        }
                    } else {
                        return false;
                    }
                }
            }
        }
        true
    }

    /// Get devices in the same NUMA domain.
    pub fn numa_neighbors(&self, device: usize) -> Vec<usize> {
        let target_numa = self.numa_nodes.get(device).copied().flatten();
        if target_numa.is_none() {
            return vec![];
        }

        self.numa_nodes
            .iter()
            .enumerate()
            .filter_map(|(i, numa)| {
                if i != device && *numa == target_numa {
                    Some(i)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Set NUMA node for a device.
    pub fn set_numa_node(&mut self, device: usize, numa_node: u32) {
        if device < self.numa_nodes.len() {
            self.numa_nodes[device] = Some(numa_node);
        }
    }

    /// Mark topology as probed (not estimated).
    pub fn mark_probed(&mut self) {
        self.probed = true;
        self.last_updated = Instant::now();
    }
}

/// Status of a device in the multi-GPU coordinator.
#[derive(Debug, Clone)]
pub struct DeviceStatus {
    /// Device info.
    pub info: DeviceInfo,
    /// Number of kernels running on this device.
    pub kernel_count: usize,
    /// Kernels running on this device.
    pub kernels: Vec<KernelId>,
    /// Whether device is available for new kernels.
    pub available: bool,
    /// Current load estimate (0.0-1.0).
    pub load: f64,
}

/// Result of unregistering a device from the coordinator.
#[derive(Debug, Clone)]
pub struct DeviceUnregisterResult {
    /// Index of the unregistered device.
    pub device_index: usize,
    /// Kernels that were on this device and need migration.
    pub kernels_to_migrate: Vec<KernelMigrationPlan>,
    /// Kernels that could not be migrated (no available target).
    pub orphaned_kernels: Vec<KernelId>,
    /// Whether the device was successfully unregistered.
    pub success: bool,
}

/// Plan for migrating a single kernel during device unregister.
#[derive(Debug, Clone)]
pub struct KernelMigrationPlan {
    /// Kernel to migrate.
    pub kernel_id: KernelId,
    /// Source device (the unregistered device).
    pub source_device: usize,
    /// Target device selected for migration.
    pub target_device: usize,
    /// Estimated migration priority (based on kernel load).
    pub priority: MigrationPriority,
}

/// Priority for kernel migration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationPriority {
    /// Low priority - can be migrated lazily.
    Low,
    /// Normal priority - migrate in reasonable time.
    Normal,
    /// High priority - migrate as soon as possible.
    High,
    /// Critical - must migrate immediately.
    Critical,
}

/// Multi-GPU coordinator for managing kernels across devices.
pub struct MultiGpuCoordinator {
    /// Configuration.
    config: MultiGpuConfig,
    /// Available devices.
    devices: RwLock<Vec<DeviceInfo>>,
    /// Kernel-to-device mapping.
    kernel_device_map: RwLock<HashMap<KernelId, usize>>,
    /// Device kernel counts.
    device_kernel_counts: RwLock<Vec<AtomicUsize>>,
    /// Round-robin counter.
    round_robin_counter: AtomicUsize,
    /// Total kernels launched.
    total_kernels: AtomicU64,
    /// Device selection callbacks (for custom strategy).
    #[allow(clippy::type_complexity)]
    custom_selector:
        RwLock<Option<Arc<dyn Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync>>>,
    /// GPU topology graph.
    topology: RwLock<Option<GpuTopology>>,
}

impl MultiGpuCoordinator {
    /// Create a new multi-GPU coordinator.
    pub fn new(config: MultiGpuConfig) -> Arc<Self> {
        Arc::new(Self {
            config,
            devices: RwLock::new(Vec::new()),
            kernel_device_map: RwLock::new(HashMap::new()),
            device_kernel_counts: RwLock::new(Vec::new()),
            round_robin_counter: AtomicUsize::new(0),
            total_kernels: AtomicU64::new(0),
            custom_selector: RwLock::new(None),
            topology: RwLock::new(None),
        })
    }

    /// Register a device with the coordinator.
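    ///
    /// # Example
    ///
    /// A sketch of manual registration (illustrative device values):
    ///
    /// ```ignore
    /// let mut info = DeviceInfo::new(0, "NVIDIA A100".to_string(), Backend::Cuda);
    /// info.total_memory = 40 * 1024 * 1024 * 1024;
    /// info.available_memory = info.total_memory;
    /// info.compute_capability = Some((8, 0));
    /// info.p2p_capable = true;
    /// coordinator.register_device(info);
    /// ```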
    pub fn register_device(&self, device: DeviceInfo) {
        let index = device.index;
        let mut devices = self.devices.write();
        let mut counts = self.device_kernel_counts.write();

        // Ensure we have enough slots
        let mut current_len = devices.len();
        while current_len <= index {
            devices.push(DeviceInfo::new(
                current_len,
                "Unknown".to_string(),
                Backend::Cpu,
            ));
            counts.push(AtomicUsize::new(0));
            current_len += 1;
        }

        devices[index] = device;
    }

    /// Unregister a device and plan kernel migrations.
    ///
    /// This method:
    /// 1. Identifies all kernels on the device being removed
    /// 2. Finds target devices for each kernel using load balancing
    /// 3. Creates migration plans for kernels that can be moved
    /// 4. Marks orphaned kernels that have no migration target
    /// 5. Updates internal routing tables
    ///
    /// The caller is responsible for executing the actual migrations using
    /// [`KernelMigrator`] with the returned [`KernelMigrationPlan`] entries.
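    ///
    /// # Example
    ///
    /// A sketch of the expected call pattern (`execute_plan` stands in for
    /// caller-side code that drives a [`KernelMigrator`] for one plan):
    ///
    /// ```ignore
    /// let result = coordinator.unregister_device(0);
    /// assert!(result.success);
    /// for plan in &result.kernels_to_migrate {
    ///     execute_plan(&migrator, plan)?;
    /// }
    /// // Kernels with no viable target need caller-side handling.
    /// for orphan in &result.orphaned_kernels {
    ///     eprintln!("orphaned kernel: {}", orphan.as_str());
    /// }
    /// ```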
    pub fn unregister_device(&self, index: usize) -> DeviceUnregisterResult {
        let devices = self.devices.read();

        // Check if device exists
        if index >= devices.len() {
            return DeviceUnregisterResult {
                device_index: index,
                kernels_to_migrate: Vec::new(),
                orphaned_kernels: Vec::new(),
                success: false,
            };
        }

        // Get all kernels on this device
        let kernels_on_device = self.kernels_on_device(index);

        // Find available target devices (excluding the one being unregistered)
        let available_targets: Vec<usize> = devices
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != index)
            .map(|(i, _)| i)
            .collect();

        drop(devices); // Release read lock before acquiring write lock

        let mut kernels_to_migrate = Vec::new();
        let mut orphaned_kernels = Vec::new();

        if available_targets.is_empty() {
            // No other devices available - all kernels are orphaned
            orphaned_kernels = kernels_on_device;
        } else {
            // Plan migrations for each kernel
            for kernel_id in kernels_on_device {
                // Select target based on current load
                if let Some(target) = self.select_migration_target(&available_targets) {
                    let priority = self.calculate_migration_priority(&kernel_id);
                    kernels_to_migrate.push(KernelMigrationPlan {
                        kernel_id,
                        source_device: index,
                        target_device: target,
                        priority,
                    });
                } else {
                    orphaned_kernels.push(kernel_id);
                }
            }
        }

        // Update kernel-device mappings for planned migrations
        {
            let mut kernel_map = self.kernel_device_map.write();
            let counts = self.device_kernel_counts.read();

            for plan in &kernels_to_migrate {
                // Update mapping to target device
                kernel_map.insert(plan.kernel_id.clone(), plan.target_device);

                // Update kernel counts
                if index < counts.len() {
                    counts[index].fetch_sub(1, Ordering::Relaxed);
                }
                if plan.target_device < counts.len() {
                    counts[plan.target_device].fetch_add(1, Ordering::Relaxed);
                }
            }

            // Remove orphaned kernels from mapping
            for kernel_id in &orphaned_kernels {
                kernel_map.remove(kernel_id);
                if index < counts.len() {
                    counts[index].fetch_sub(1, Ordering::Relaxed);
                }
            }
        }

        // Mark device as unavailable (but don't remove it to preserve indices)
        {
            let mut devices = self.devices.write();
            if index < devices.len() {
                devices[index].available_memory = 0;
                devices[index].name = format!("{} (unregistered)", devices[index].name);
            }
        }

        DeviceUnregisterResult {
            device_index: index,
            kernels_to_migrate,
            orphaned_kernels,
            success: true,
        }
    }

    /// Select the best target device for migration.
    fn select_migration_target(&self, candidates: &[usize]) -> Option<usize> {
        if candidates.is_empty() {
            return None;
        }

        let counts = self.device_kernel_counts.read();

        // Find device with lowest kernel count
        candidates
            .iter()
            .filter_map(|&idx| {
                if idx < counts.len() {
                    Some((idx, counts[idx].load(Ordering::Relaxed)))
                } else {
                    None
                }
            })
            .min_by_key(|(_, count)| *count)
            .map(|(idx, _)| idx)
    }

    /// Calculate migration priority for a kernel.
    fn calculate_migration_priority(&self, _kernel_id: &KernelId) -> MigrationPriority {
        // In a real implementation, this would check:
        // - Message queue depth
        // - Time since last activity
        // - Kernel type/importance
        // For now, use normal priority
        MigrationPriority::Normal
    }

    /// Get all registered devices.
    pub fn devices(&self) -> Vec<DeviceInfo> {
        self.devices.read().clone()
    }

    /// Get device info by index.
    pub fn device(&self, index: usize) -> Option<DeviceInfo> {
        self.devices.read().get(index).cloned()
    }

    /// Get number of devices.
    pub fn device_count(&self) -> usize {
        self.devices.read().len()
    }

    /// Select a device for launching a kernel.
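    ///
    /// # Example
    ///
    /// A sketch of the launch-time flow (device registration omitted):
    ///
    /// ```ignore
    /// let device = coordinator.select_device(&LaunchOptions::default())?;
    /// coordinator.assign_kernel(kernel_id, device);
    /// ```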
    pub fn select_device(&self, options: &LaunchOptions) -> Result<usize> {
        if self.devices.read().is_empty() {
            return Err(RingKernelError::BackendUnavailable(
                "No GPU devices available".to_string(),
            ));
        }

        // Get current status
        let status = self.get_all_status();

        // Check for custom selector
        if self.config.load_balancing == LoadBalancingStrategy::Custom {
            if let Some(selector) = &*self.custom_selector.read() {
                return Ok(selector(&status, options));
            }
        }

        // Apply preferred devices filter if specified
        let candidates: Vec<_> = if !self.config.preferred_devices.is_empty() {
            status
                .into_iter()
                .filter(|s| self.config.preferred_devices.contains(&s.info.index))
                .collect()
        } else {
            status
        };

        if candidates.is_empty() {
            return Err(RingKernelError::BackendUnavailable(
                "No suitable GPU device available".to_string(),
            ));
        }

        let selected = match self.config.load_balancing {
            LoadBalancingStrategy::FirstAvailable => candidates
                .iter()
                .find(|s| s.available)
                .or_else(|| candidates.first())
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::LeastLoaded => candidates
                .iter()
                .filter(|s| s.available && s.kernel_count < self.config.max_kernels_per_device)
                .min_by(|a, b| a.kernel_count.cmp(&b.kernel_count))
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::RoundRobin => {
                let available: Vec<_> = candidates.iter().filter(|s| s.available).collect();

                if available.is_empty() {
                    candidates.first().map(|s| s.info.index).unwrap_or(0)
                } else {
                    let idx =
                        self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % available.len();
                    available[idx].info.index
                }
            }
            LoadBalancingStrategy::MemoryBased => candidates
                .iter()
                .filter(|s| s.available)
                .max_by(|a, b| a.info.available_memory.cmp(&b.info.available_memory))
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::ComputeCapability => candidates
                .iter()
                .filter(|s| s.available)
                .max_by(|a, b| {
                    let a_cap = a.info.compute_capability.unwrap_or((0, 0));
                    let b_cap = b.info.compute_capability.unwrap_or((0, 0));
                    a_cap.cmp(&b_cap)
                })
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::Custom => {
                // No custom selector was registered; fall back to device 0.
                0
            }
        };

        Ok(selected)
    }

    /// Assign a kernel to a device.
    pub fn assign_kernel(&self, kernel_id: KernelId, device_index: usize) {
        self.kernel_device_map
            .write()
            .insert(kernel_id, device_index);

        let counts = self.device_kernel_counts.read();
        if device_index < counts.len() {
            counts[device_index].fetch_add(1, Ordering::Relaxed);
        }

        self.total_kernels.fetch_add(1, Ordering::Relaxed);
    }

    /// Remove a kernel assignment.
    pub fn remove_kernel(&self, kernel_id: &KernelId) {
        if let Some(device_index) = self.kernel_device_map.write().remove(kernel_id) {
            let counts = self.device_kernel_counts.read();
            if device_index < counts.len() {
                counts[device_index].fetch_sub(1, Ordering::Relaxed);
            }
        }
    }

    /// Get device for a kernel.
    pub fn get_kernel_device(&self, kernel_id: &KernelId) -> Option<usize> {
        self.kernel_device_map.read().get(kernel_id).copied()
    }

    /// Get all kernels on a device.
    pub fn kernels_on_device(&self, device_index: usize) -> Vec<KernelId> {
        self.kernel_device_map
            .read()
            .iter()
            .filter(|(_, &idx)| idx == device_index)
            .map(|(k, _)| k.clone())
            .collect()
    }

    /// Get status of all devices.
    pub fn get_all_status(&self) -> Vec<DeviceStatus> {
        let devices = self.devices.read();
        let kernel_map = self.kernel_device_map.read();
        let counts = self.device_kernel_counts.read();

        devices
            .iter()
            .enumerate()
            .map(|(idx, info)| {
                let kernel_count = if idx < counts.len() {
                    counts[idx].load(Ordering::Relaxed)
                } else {
                    0
                };

                let kernels: Vec<_> = kernel_map
                    .iter()
                    .filter(|(_, &dev_idx)| dev_idx == idx)
                    .map(|(k, _)| k.clone())
                    .collect();

                let load = kernel_count as f64 / self.config.max_kernels_per_device as f64;
                let available = kernel_count < self.config.max_kernels_per_device;

                DeviceStatus {
                    info: info.clone(),
                    kernel_count,
                    kernels,
                    available,
                    load,
                }
            })
            .collect()
    }

    /// Get status of a specific device.
    pub fn get_device_status(&self, device_index: usize) -> Option<DeviceStatus> {
        self.get_all_status().into_iter().nth(device_index)
    }

858    pub fn set_custom_selector<F>(&self, selector: F)
859    where
860        F: Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync + 'static,
861    {
862        *self.custom_selector.write() = Some(Arc::new(selector));
863    }
864
865    /// Get coordinator statistics.
866    pub fn stats(&self) -> MultiGpuStats {
867        let status = self.get_all_status();
868        let total_kernels: usize = status.iter().map(|s| s.kernel_count).sum();
869        let total_memory: u64 = status.iter().map(|s| s.info.total_memory).sum();
870        let available_memory: u64 = status.iter().map(|s| s.info.available_memory).sum();
871
872        MultiGpuStats {
873            device_count: status.len(),
874            total_kernels,
875            total_memory,
876            available_memory,
877            kernels_launched: self.total_kernels.load(Ordering::Relaxed),
878        }
879    }
880
881    /// Check if P2P is available between two devices.
882    pub fn can_p2p(&self, device_a: usize, device_b: usize) -> bool {
883        if !self.config.enable_p2p {
884            return false;
885        }
886
887        let devices = self.devices.read();
888        if let (Some(a), Some(b)) = (devices.get(device_a), devices.get(device_b)) {
889            a.p2p_capable && b.p2p_capable
890        } else {
891            false
892        }
893    }
894
895    /// Update device memory info.
896    pub fn update_device_memory(&self, device_index: usize, available_memory: u64) {
897        let mut devices = self.devices.write();
898        if let Some(device) = devices.get_mut(device_index) {
899            device.available_memory = available_memory;
900        }
901    }
902
903    // ========================================================================
904    // Topology Discovery
905    // ========================================================================
906
    /// Discover the GPU topology (estimated from registered device info when hardware probing is unavailable).
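    ///
    /// # Example
    ///
    /// ```ignore
    /// let topo = coordinator.discover_topology();
    /// if topo.is_fully_connected() {
    ///     // Every device pair has a P2P-capable link.
    /// }
    /// ```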
    pub fn discover_topology(&self) -> GpuTopology {
        let devices = self.devices.read();
        let device_count = devices.len();

        if device_count == 0 {
            return GpuTopology::new(0);
        }

        let mut topo = GpuTopology::new(device_count);

        // Set up connections based on device info
        for (i, dev_i) in devices.iter().enumerate() {
            for (j, dev_j) in devices.iter().enumerate() {
                if i == j {
                    continue;
                }

                // Determine interconnect type based on device capabilities
                let interconnect = if dev_i.p2p_capable && dev_j.p2p_capable {
                    // Check if same backend (can do P2P)
                    if dev_i.backend == dev_j.backend {
                        match dev_i.backend {
                            Backend::Cuda => {
                                // For CUDA, check compute capability for NVLink
                                let cc_i = dev_i.compute_capability.unwrap_or((0, 0));
                                let cc_j = dev_j.compute_capability.unwrap_or((0, 0));

                                // Ampere+ (SM 80+) likely has NVLink
                                if cc_i.0 >= 8 && cc_j.0 >= 8 {
                                    InterconnectType::NvLink
                                } else {
                                    InterconnectType::Pcie
                                }
                            }
                            _ => InterconnectType::Pcie,
                        }
                    } else {
                        InterconnectType::None
                    }
                } else {
                    InterconnectType::None
                };

                topo.set_connection(GpuConnection::new(i, j, interconnect));
            }
        }

        // Store topology
        *self.topology.write() = Some(topo.clone());

        topo
    }

    /// Get current topology (discovers if not cached).
    pub fn topology(&self) -> GpuTopology {
        {
            let topo = self.topology.read();
            if let Some(ref t) = *topo {
                return t.clone();
            }
        }
        self.discover_topology()
    }

    /// Set custom topology (for testing or manual configuration).
    pub fn set_topology(&self, topology: GpuTopology) {
        *self.topology.write() = Some(topology);
    }

    /// Get best device for communicating with a source kernel.
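    ///
    /// # Example
    ///
    /// A sketch of topology-aware placement (assumes `producer` is already
    /// assigned to a device):
    ///
    /// ```ignore
    /// let device = coordinator.select_device_for_k2k(&producer)?;
    /// coordinator.assign_kernel(consumer, device);
    /// ```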
    pub fn select_device_for_k2k(&self, source_kernel: &KernelId) -> Result<usize> {
        let source_idx = match self.get_kernel_device(source_kernel) {
            Some(idx) => idx,
            None => return self.select_device(&LaunchOptions::default()),
        };
        let topo = self.topology();
        let status = self.get_all_status();

        // Find best device based on connectivity and load
        let neighbors = topo.neighbors(source_idx);

        if neighbors.is_empty() {
            // No P2P neighbors, fall back to normal selection
            return self.select_device(&LaunchOptions::default());
        }

        // Score devices by: connectivity bandwidth / (load + 0.1)
        let best = neighbors
            .iter()
            .filter_map(|&dev_idx| {
                status.iter().find(|s| s.info.index == dev_idx).map(|s| {
                    let conn = topo.get_connection(source_idx, dev_idx);
                    let bandwidth = conn.map(|c| c.bandwidth_gbps).unwrap_or(1.0);
                    let score = bandwidth / (s.load + 0.1);
                    (dev_idx, score)
                })
            })
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx);

        best.ok_or_else(|| {
            RingKernelError::BackendUnavailable("No suitable K2K device found".to_string())
        })
    }

    // ========================================================================
    // Kernel Migration
    // ========================================================================

    /// Request to migrate a kernel to another device.
    pub fn request_migration(
        &self,
        kernel_id: &KernelId,
        target_device: usize,
    ) -> Result<MigrationRequest> {
        let source_device = self
            .get_kernel_device(kernel_id)
            .ok_or_else(|| RingKernelError::KernelNotFound(kernel_id.as_str().to_string()))?;

        if source_device == target_device {
            return Err(RingKernelError::InvalidConfig(
                "Cannot migrate to same device".to_string(),
            ));
        }

        let devices = self.devices.read();
        if target_device >= devices.len() {
            return Err(RingKernelError::DeviceNotAvailable(format!(
                "Device {} not available",
                target_device
            )));
        }

        let topo = self.topology();
        let path = topo.best_path(source_device, target_device);
        let connection = topo.get_connection(source_device, target_device);

        Ok(MigrationRequest {
            kernel_id: kernel_id.clone(),
            source_device,
            target_device,
            path,
            estimated_bandwidth_gbps: connection.map(|c| c.bandwidth_gbps).unwrap_or(16.0),
            estimated_latency_us: connection.map(|c| c.latency_us).unwrap_or(10.0),
            state: MigrationState::Pending,
            started_at: None,
        })
    }

    /// Complete a migration (updates internal mappings).
    pub fn complete_migration(&self, request: &MigrationRequest) -> Result<()> {
        // Update kernel-device mapping
        {
            let mut map = self.kernel_device_map.write();
            if let Some(dev) = map.get_mut(&request.kernel_id) {
                *dev = request.target_device;
            }
        }

        // Update kernel counts
        {
            let counts = self.device_kernel_counts.read();
            if request.source_device < counts.len() {
                counts[request.source_device].fetch_sub(1, Ordering::Relaxed);
            }
            if request.target_device < counts.len() {
                counts[request.target_device].fetch_add(1, Ordering::Relaxed);
            }
        }

        Ok(())
    }
}

// ============================================================================
// Kernel Migration Types
// ============================================================================

/// State of a kernel migration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationState {
    /// Migration is pending, not yet started.
    Pending,
    /// Kernel is being quiesced (draining messages).
    Quiescing,
    /// Checkpoint is being created.
    Checkpointing,
    /// State is being transferred.
    Transferring,
    /// Kernel is being restored on target.
    Restoring,
    /// Migration completed successfully.
    Completed,
    /// Migration failed.
    Failed,
    /// Migration was cancelled.
    Cancelled,
}

/// Request to migrate a kernel between devices.
#[derive(Debug, Clone)]
pub struct MigrationRequest {
    /// Kernel to migrate.
    pub kernel_id: KernelId,
    /// Source device index.
    pub source_device: usize,
    /// Target device index.
    pub target_device: usize,
    /// Path of devices for multi-hop migration.
    pub path: Vec<usize>,
    /// Estimated bandwidth for transfer.
    pub estimated_bandwidth_gbps: f64,
    /// Estimated latency.
    pub estimated_latency_us: f64,
    /// Current state.
    pub state: MigrationState,
    /// When migration started.
    pub started_at: Option<Instant>,
}

impl MigrationRequest {
    /// Estimate the transfer time for a given state size.
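    ///
    /// # Example
    ///
    /// A quick sanity check: 1 GB over a 300 GB/s link with 1 us latency
    /// comes out to roughly 3.3 ms (assuming those values are set on the
    /// request):
    ///
    /// ```ignore
    /// let eta = request.estimate_transfer_time(1_000_000_000);
    /// assert!(eta >= Duration::from_millis(3));
    /// ```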
    pub fn estimate_transfer_time(&self, state_size_bytes: usize) -> Duration {
        // time = size / bandwidth + latency
        let size_gb = state_size_bytes as f64 / 1_000_000_000.0;
        let transfer_time_s = size_gb / self.estimated_bandwidth_gbps;
        let total_us = (transfer_time_s * 1_000_000.0) + self.estimated_latency_us;
        Duration::from_micros(total_us as u64)
    }
}

// ============================================================================
// Cross-GPU K2K Router
// ============================================================================

/// Routes K2K messages across GPU boundaries.
pub struct CrossGpuK2KRouter {
    /// Multi-GPU coordinator.
    coordinator: Arc<MultiGpuCoordinator>,
    /// Message queues for pending cross-device messages.
    pending_queues: RwLock<HashMap<(usize, usize), Vec<PendingK2KMessage>>>,
    /// Statistics.
    stats: CrossGpuRouterStats,
}

/// A pending cross-GPU K2K message.
#[derive(Debug, Clone)]
pub struct PendingK2KMessage {
    /// Source kernel ID.
    pub source_kernel: KernelId,
    /// Destination kernel ID.
    pub dest_kernel: KernelId,
    /// Message payload.
    pub message: K2KMessage,
    /// Timestamp when queued.
    pub queued_at: Instant,
    /// Number of routing hops.
    pub hops: u32,
}

/// Statistics for cross-GPU K2K routing.
#[derive(Debug, Default)]
pub struct CrossGpuRouterStats {
    /// Total messages routed.
    messages_routed: AtomicU64,
    /// Total bytes transferred.
    bytes_transferred: AtomicU64,
    /// Messages currently pending.
    messages_pending: AtomicUsize,
    /// Total routing latency (microseconds).
    total_latency_us: AtomicU64,
    /// Failed routing attempts.
    routing_failures: AtomicU64,
}

impl CrossGpuK2KRouter {
    /// Create a new cross-GPU K2K router.
    pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Arc<Self> {
        Arc::new(Self {
            coordinator,
            pending_queues: RwLock::new(HashMap::new()),
            stats: CrossGpuRouterStats::default(),
        })
    }

    /// Route a message from source kernel to destination kernel.
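    ///
    /// # Example
    ///
    /// A sketch of acting on the returned [`RoutingDecision`]:
    ///
    /// ```ignore
    /// match router.route_message(&src, &dst, message)? {
    ///     RoutingDecision::SameDevice => { /* deliver via the regular K2K path */ }
    ///     RoutingDecision::DirectP2P { .. } => { /* queued for a direct P2P transfer */ }
    ///     RoutingDecision::MultiHop { path, .. } => { /* queued for the first hop in `path` */ }
    ///     RoutingDecision::HostMediated { .. } => { /* queued for a device->host->device copy */ }
    /// }
    /// ```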
    pub fn route_message(
        &self,
        source_kernel: &KernelId,
        dest_kernel: &KernelId,
        message: K2KMessage,
    ) -> Result<RoutingDecision> {
        let source_device = self
            .coordinator
            .get_kernel_device(source_kernel)
            .ok_or_else(|| {
                RingKernelError::K2KDestinationNotFound(source_kernel.as_str().to_string())
            })?;

        let dest_device = self
            .coordinator
            .get_kernel_device(dest_kernel)
            .ok_or_else(|| {
                RingKernelError::K2KDestinationNotFound(dest_kernel.as_str().to_string())
            })?;

        // Same device - use regular K2K
        if source_device == dest_device {
            return Ok(RoutingDecision::SameDevice);
        }

        // Get topology for routing
        let topo = self.coordinator.topology();
        let path = topo.best_path(source_device, dest_device);

        // Check if direct P2P is available
        if let Some(conn) = topo.get_connection(source_device, dest_device) {
            if conn.interconnect.supports_p2p() {
                // Queue for direct P2P transfer
                let pending = PendingK2KMessage {
                    source_kernel: source_kernel.clone(),
                    dest_kernel: dest_kernel.clone(),
                    message,
                    queued_at: Instant::now(),
                    hops: 1,
                };

                self.enqueue_pending(source_device, dest_device, pending);
                self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

                return Ok(RoutingDecision::DirectP2P {
                    source_device,
                    dest_device,
                    bandwidth_gbps: conn.bandwidth_gbps,
                });
            }
        }

        // Multi-hop routing required
        if path.len() > 2 {
            let pending = PendingK2KMessage {
                source_kernel: source_kernel.clone(),
                dest_kernel: dest_kernel.clone(),
                message,
                queued_at: Instant::now(),
                hops: (path.len() - 1) as u32,
            };

            // Queue for first hop
            self.enqueue_pending(source_device, path[1], pending);
            self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

            return Ok(RoutingDecision::MultiHop {
                path: path.clone(),
                total_hops: (path.len() - 1) as u32,
            });
        }

        // Fall back to host-mediated transfer
        let pending = PendingK2KMessage {
            source_kernel: source_kernel.clone(),
            dest_kernel: dest_kernel.clone(),
            message,
            queued_at: Instant::now(),
            hops: 2, // device->host->device
        };

        self.enqueue_pending(source_device, dest_device, pending);
        self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

        Ok(RoutingDecision::HostMediated {
            source_device,
            dest_device,
        })
    }

    /// Get pending messages for a device pair.
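    ///
    /// # Example
    ///
    /// A sketch of a host-side delivery pump (`deliver_p2p` is a hypothetical
    /// transport function, not part of this module):
    ///
    /// ```ignore
    /// for msg in router.drain_pending(0, 1) {
    ///     match deliver_p2p(&msg) {
    ///         Ok(bytes_sent) => router.record_delivery(&msg, bytes_sent),
    ///         Err(_) => router.record_failure(),
    ///     }
    /// }
    /// ```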
    pub fn drain_pending(&self, source: usize, dest: usize) -> Vec<PendingK2KMessage> {
        let mut queues = self.pending_queues.write();
        let messages = queues.remove(&(source, dest)).unwrap_or_default();
        self.stats
            .messages_pending
            .fetch_sub(messages.len(), Ordering::Relaxed);
        messages
    }

    /// Record successful message delivery.
    pub fn record_delivery(&self, message: &PendingK2KMessage, payload_size: usize) {
        self.stats.messages_routed.fetch_add(1, Ordering::Relaxed);
        self.stats
            .bytes_transferred
            .fetch_add(payload_size as u64, Ordering::Relaxed);

        let latency = message.queued_at.elapsed().as_micros() as u64;
        self.stats
            .total_latency_us
            .fetch_add(latency, Ordering::Relaxed);
    }

    /// Record routing failure.
    pub fn record_failure(&self) {
        self.stats.routing_failures.fetch_add(1, Ordering::Relaxed);
    }

    /// Get router statistics.
    pub fn stats(&self) -> CrossGpuRouterStatsSnapshot {
        let messages_routed = self.stats.messages_routed.load(Ordering::Relaxed);
        let total_latency = self.stats.total_latency_us.load(Ordering::Relaxed);

        CrossGpuRouterStatsSnapshot {
            messages_routed,
            bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
            messages_pending: self.stats.messages_pending.load(Ordering::Relaxed),
            avg_latency_us: if messages_routed > 0 {
                total_latency as f64 / messages_routed as f64
            } else {
                0.0
            },
            routing_failures: self.stats.routing_failures.load(Ordering::Relaxed),
        }
    }

    fn enqueue_pending(&self, source: usize, dest: usize, message: PendingK2KMessage) {
        let mut queues = self.pending_queues.write();
        queues.entry((source, dest)).or_default().push(message);
    }
}

/// Snapshot of router statistics.
#[derive(Debug, Clone)]
pub struct CrossGpuRouterStatsSnapshot {
    /// Total messages successfully routed.
    pub messages_routed: u64,
    /// Total bytes transferred.
    pub bytes_transferred: u64,
    /// Messages currently pending.
    pub messages_pending: usize,
    /// Average routing latency in microseconds.
    pub avg_latency_us: f64,
    /// Total routing failures.
    pub routing_failures: u64,
}

/// Decision for how to route a K2K message.
#[derive(Debug, Clone)]
pub enum RoutingDecision {
    /// Source and destination on same device.
    SameDevice,
    /// Direct peer-to-peer transfer.
    DirectP2P {
        /// Source device index.
        source_device: usize,
        /// Destination device index.
        dest_device: usize,
        /// Available bandwidth.
        bandwidth_gbps: f64,
    },
    /// Multi-hop routing through intermediate devices.
    MultiHop {
        /// Device path.
        path: Vec<usize>,
        /// Total number of hops.
        total_hops: u32,
    },
    /// Route through host memory (slowest).
    HostMediated {
        /// Source device index.
        source_device: usize,
        /// Destination device index.
        dest_device: usize,
    },
}

/// Multi-GPU coordinator statistics.
#[derive(Debug, Clone, Default)]
pub struct MultiGpuStats {
    /// Number of registered devices.
    pub device_count: usize,
    /// Total kernels across all devices.
    pub total_kernels: usize,
    /// Total memory across all devices.
    pub total_memory: u64,
    /// Available memory across all devices.
    pub available_memory: u64,
    /// Total kernels launched since start.
    pub kernels_launched: u64,
}

/// Builder for multi-GPU coordinator.
pub struct MultiGpuBuilder {
    config: MultiGpuConfig,
}

impl MultiGpuBuilder {
    /// Create a new builder.
    pub fn new() -> Self {
        Self {
            config: MultiGpuConfig::default(),
        }
    }

    /// Set load balancing strategy.
    pub fn load_balancing(mut self, strategy: LoadBalancingStrategy) -> Self {
        self.config.load_balancing = strategy;
        self
    }

    /// Set auto device selection.
    pub fn auto_select_device(mut self, enable: bool) -> Self {
        self.config.auto_select_device = enable;
        self
    }

    /// Set max kernels per device.
    pub fn max_kernels_per_device(mut self, max: usize) -> Self {
        self.config.max_kernels_per_device = max;
        self
    }

    /// Enable P2P transfers.
    pub fn enable_p2p(mut self, enable: bool) -> Self {
        self.config.enable_p2p = enable;
        self
    }

    /// Set preferred devices.
    pub fn preferred_devices(mut self, devices: Vec<usize>) -> Self {
        self.config.preferred_devices = devices;
        self
    }

    /// Build the coordinator.
    pub fn build(self) -> Arc<MultiGpuCoordinator> {
        MultiGpuCoordinator::new(self.config)
    }
}

impl Default for MultiGpuBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// Helper for cross-device data transfer.
pub struct CrossDeviceTransfer {
    /// Source device index.
    pub source_device: usize,
    /// Destination device index.
    pub dest_device: usize,
    /// Data size in bytes.
    pub size: usize,
    /// Use P2P if available.
    pub use_p2p: bool,
}

impl CrossDeviceTransfer {
    /// Create a new transfer specification.
    pub fn new(source: usize, dest: usize, size: usize) -> Self {
        Self {
            source_device: source,
            dest_device: dest,
            size,
            use_p2p: true,
        }
    }

    /// Disable P2P for this transfer.
    pub fn without_p2p(mut self) -> Self {
        self.use_p2p = false;
        self
    }
}

// ============================================================================
// Kernel Migrator with Checkpoint Integration
// ============================================================================

use crate::checkpoint::{CheckpointStorage, CheckpointableKernel, MemoryStorage};

/// Migrator that uses checkpoints for kernel state transfer between GPUs.
///
/// This integrates the checkpoint infrastructure with the multi-GPU migration
/// system to enable live migration of persistent kernels.
///
/// # Example
///
/// ```ignore
/// use ringkernel_core::multi_gpu::{KernelMigrator, MultiGpuBuilder};
///
/// let coordinator = MultiGpuBuilder::new().build();
/// let migrator = KernelMigrator::new(coordinator.clone());
///
/// // Migrate a kernel from GPU 0 to GPU 1.
/// let mut request = coordinator.request_migration(&kernel_id, 1)?;
/// let result = migrator.migrate_with_checkpoint(&kernel, &mut request)?;
/// ```
1503pub struct KernelMigrator {
1504    /// Multi-GPU coordinator.
1505    coordinator: Arc<MultiGpuCoordinator>,
1506    /// Checkpoint storage for migration state.
1507    storage: Arc<dyn CheckpointStorage>,
1508    /// Statistics.
1509    stats: MigrationStats,
1510}
1511
1512/// Statistics for kernel migrations.
1513#[derive(Debug, Default)]
1514pub struct MigrationStats {
1515    /// Total successful migrations.
1516    pub successful_migrations: AtomicU64,
1517    /// Total failed migrations.
1518    pub failed_migrations: AtomicU64,
1519    /// Total bytes transferred during migrations.
1520    pub bytes_transferred: AtomicU64,
1521    /// Total checkpoint time (microseconds).
1522    pub checkpoint_time_us: AtomicU64,
1523    /// Total restore time (microseconds).
1524    pub restore_time_us: AtomicU64,
1525}
1526
1527/// Result of a completed migration.
1528#[derive(Debug, Clone)]
1529pub struct MigrationResult {
1530    /// Kernel that was migrated.
1531    pub kernel_id: KernelId,
1532    /// Source device.
1533    pub source_device: usize,
1534    /// Target device.
1535    pub target_device: usize,
1536    /// Checkpoint size in bytes.
1537    pub checkpoint_size: usize,
1538    /// Time spent creating checkpoint.
1539    pub checkpoint_duration: Duration,
1540    /// Time spent transferring state.
1541    pub transfer_duration: Duration,
1542    /// Time spent restoring kernel.
1543    pub restore_duration: Duration,
1544    /// Total migration time.
1545    pub total_duration: Duration,
1546}
1547
1548impl KernelMigrator {
1549    /// Create a new kernel migrator with default in-memory storage.
1550    pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Self {
1551        Self {
1552            coordinator,
1553            storage: Arc::new(MemoryStorage::new()),
1554            stats: MigrationStats::default(),
1555        }
1556    }
1557
1558    /// Create a migrator with custom checkpoint storage.
1559    pub fn with_storage(
1560        coordinator: Arc<MultiGpuCoordinator>,
1561        storage: Arc<dyn CheckpointStorage>,
1562    ) -> Self {
1563        Self {
1564            coordinator,
1565            storage,
1566            stats: MigrationStats::default(),
1567        }
1568    }
1569
1570    /// Perform a complete migration using checkpoint-based state transfer.
1571    ///
1572    /// Steps:
1573    /// 1. Quiesce the source kernel (drain pending messages)
1574    /// 2. Create checkpoint of kernel state
1575    /// 3. Transfer checkpoint to target device
1576    /// 4. Restore kernel on target device
1577    /// 5. Update coordinator routing tables
1578    pub fn migrate_with_checkpoint<K: CheckpointableKernel>(
1579        &self,
1580        kernel: &K,
1581        request: &mut MigrationRequest,
1582    ) -> Result<MigrationResult> {
1583        let start_time = Instant::now();
1584        request.started_at = Some(start_time);
1585
1586        // Step 1: Quiesce
1587        request.state = MigrationState::Quiescing;
1588        // In a real implementation, this would drain the kernel's message
1589        // queues; for now we assume the kernel is ready for checkpointing.
1590
1591        // Step 2: Create checkpoint
1592        request.state = MigrationState::Checkpointing;
1593        let checkpoint_start = Instant::now();
1594        let checkpoint = kernel.create_checkpoint().map_err(|e| {
1595            self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
1596            request.state = MigrationState::Failed;
1597            RingKernelError::MigrationFailed(format!("Checkpoint creation failed: {}", e))
1598        })?;
1599        let checkpoint_duration = checkpoint_start.elapsed();
1600        let checkpoint_size = checkpoint.total_size();
1601
1602        self.stats
1603            .checkpoint_time_us
1604            .fetch_add(checkpoint_duration.as_micros() as u64, Ordering::Relaxed);
1605
1606        // Step 3: Transfer
1607        request.state = MigrationState::Transferring;
1608        let transfer_start = Instant::now();
1609
1610        // Store the checkpoint (this simulates the cross-device transfer)
1611        let checkpoint_name = format!(
1612            "migration_{}_{}_{}",
1613            request.kernel_id.as_str(),
1614            request.source_device,
1615            request.target_device
1616        );
1617        self.storage
1618            .save(&checkpoint, &checkpoint_name)
1619            .map_err(|e| {
1620                self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
1621                request.state = MigrationState::Failed;
1622                RingKernelError::MigrationFailed(format!("Checkpoint transfer failed: {}", e))
1623            })?;
1624
1625        let transfer_duration = transfer_start.elapsed();
1626        self.stats
1627            .bytes_transferred
1628            .fetch_add(checkpoint_size as u64, Ordering::Relaxed);
1629
1630        // Step 4: Restore (would be done on target kernel)
1631        request.state = MigrationState::Restoring;
1632        let restore_start = Instant::now();
1633
1634        // Load checkpoint to verify it's valid
1635        let _restored = self.storage.load(&checkpoint_name).map_err(|e| {
1636            self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
1637            request.state = MigrationState::Failed;
1638            RingKernelError::MigrationFailed(format!("Checkpoint restore failed: {}", e))
1639        })?;
1640
1641        let restore_duration = restore_start.elapsed();
1642        self.stats
1643            .restore_time_us
1644            .fetch_add(restore_duration.as_micros() as u64, Ordering::Relaxed);
1645
1646        // Step 5: Update routing
1647        request.state = MigrationState::Completed;
1648        self.coordinator.complete_migration(request).map_err(|e| {
            self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
            request.state = MigrationState::Failed;
            e
        })?;
1649
1650        // Clean up checkpoint
1651        let _ = self.storage.delete(&checkpoint_name);
1652
1653        self.stats
1654            .successful_migrations
1655            .fetch_add(1, Ordering::Relaxed);
1656
1657        Ok(MigrationResult {
1658            kernel_id: request.kernel_id.clone(),
1659            source_device: request.source_device,
1660            target_device: request.target_device,
1661            checkpoint_size,
1662            checkpoint_duration,
1663            transfer_duration,
1664            restore_duration,
1665            total_duration: start_time.elapsed(),
1666        })
1667    }
1668
1669    /// Get a reference to the coordinator.
1670    pub fn coordinator(&self) -> &Arc<MultiGpuCoordinator> {
1671        &self.coordinator
1672    }
1673
1674    /// Get migration statistics snapshot.
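    ///
    /// A minimal sketch of reading the snapshot:
    ///
    /// ```ignore
    /// let snapshot = migrator.stats();
    /// println!(
    ///     "migrations ok={} failed={} bytes={}",
    ///     snapshot.successful_migrations,
    ///     snapshot.failed_migrations,
    ///     snapshot.bytes_transferred,
    /// );
    /// ```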
1675    pub fn stats(&self) -> MigrationStatsSnapshot {
1676        let successful = self.stats.successful_migrations.load(Ordering::Relaxed);
1677        let failed = self.stats.failed_migrations.load(Ordering::Relaxed);
1678        let total = successful + failed;
1679        let checkpoint_us = self.stats.checkpoint_time_us.load(Ordering::Relaxed);
1680        let restore_us = self.stats.restore_time_us.load(Ordering::Relaxed);
1681
1682        MigrationStatsSnapshot {
1683            successful_migrations: successful,
1684            failed_migrations: failed,
1685            bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
1686            avg_checkpoint_time: checkpoint_us
1687                .checked_div(total)
1688                .map(Duration::from_micros)
1689                .unwrap_or(Duration::ZERO),
1690            avg_restore_time: restore_us
1691                .checked_div(total)
1692                .map(Duration::from_micros)
1693                .unwrap_or(Duration::ZERO),
1694        }
1695    }
1696}
1697
1698/// Snapshot of migration statistics.
1699#[derive(Debug, Clone)]
1700pub struct MigrationStatsSnapshot {
1701    /// Total successful migrations.
1702    pub successful_migrations: u64,
1703    /// Total failed migrations.
1704    pub failed_migrations: u64,
1705    /// Total bytes transferred.
1706    pub bytes_transferred: u64,
1707    /// Average checkpoint creation time.
1708    pub avg_checkpoint_time: Duration,
1709    /// Average restore time.
1710    pub avg_restore_time: Duration,
1711}
1712
1713/// Trait for kernels that support live migration.
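///
/// # Example
///
/// A minimal sketch of an implementation; `MyKernel` and its fields are
/// illustrative, and a real type must also implement `CheckpointableKernel`
/// (the supertrait):
///
/// ```ignore
/// struct MyKernel {
///     state: Vec<u8>,
///     pending_messages: usize,
/// }
///
/// impl MigratableKernel for MyKernel {
///     fn prepare_for_migration(&mut self) -> Result<()> {
///         self.pending_messages = 0; // message drain would happen here
///         Ok(())
///     }
///     fn cancel_migration(&mut self) -> Result<()> {
///         Ok(()) // resume normal processing
///     }
///     fn is_quiescent(&self) -> bool {
///         self.pending_messages == 0
///     }
///     fn estimated_state_size(&self) -> usize {
///         self.state.len()
///     }
/// }
/// ```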
1714pub trait MigratableKernel: CheckpointableKernel {
1715    /// Prepare kernel for migration (quiesce, drain messages).
1716    fn prepare_for_migration(&mut self) -> Result<()>;
1717
1718    /// Resume kernel after migration is cancelled.
1719    fn cancel_migration(&mut self) -> Result<()>;
1720
1721    /// Check if kernel is ready to be checkpointed.
1722    fn is_quiescent(&self) -> bool;
1723
1724    /// Get estimated state size for migration planning.
1725    fn estimated_state_size(&self) -> usize;
1726}
1727
1728// ============================================================================
1729// Hot Reload Support
1730// ============================================================================
1731
1732/// Configuration for kernel hot reload operations.
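///
/// # Example
///
/// A minimal sketch of tuning the config through its builder methods
/// (values are illustrative):
///
/// ```ignore
/// use std::time::Duration;
///
/// let config = HotReloadConfig::new()
///     .with_timeout(Duration::from_secs(10))
///     .with_max_retries(1)
///     .with_preserve_state(true);
/// assert!(config.enabled);
/// ```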
1733#[derive(Debug, Clone)]
1734pub struct HotReloadConfig {
1735    /// Enable hot reload functionality.
1736    pub enabled: bool,
1737    /// Timeout for reload operations.
1738    pub reload_timeout: Duration,
1739    /// Whether to preserve kernel state during reload.
1740    pub preserve_state: bool,
1741    /// Maximum retries for failed reloads.
1742    pub max_retries: u32,
1743    /// Backoff duration between retries.
1744    pub retry_backoff: Duration,
1745    /// Whether to validate new code before swapping.
1746    pub validate_before_swap: bool,
1747    /// Keep old code as fallback in case of failure.
1748    pub keep_fallback: bool,
1749    /// Number of compiled-rule versions retained per rule for rollback
1750    /// (FIFO eviction, see `rules::RuleRegistry`). Default: 5.
1751    pub max_rule_history: usize,
1752}
1753
1754impl Default for HotReloadConfig {
1755    fn default() -> Self {
1756        Self {
1757            enabled: true,
1758            reload_timeout: Duration::from_secs(30),
1759            preserve_state: true,
1760            max_retries: 3,
1761            retry_backoff: Duration::from_millis(500),
1762            validate_before_swap: true,
1763            keep_fallback: true,
1764            max_rule_history: 5,
1765        }
1766    }
1767}
1768
1769impl HotReloadConfig {
1770    /// Create a new hot reload configuration.
1771    pub fn new() -> Self {
1772        Self::default()
1773    }
1774
1775    /// Enable or disable hot reload.
1776    pub fn with_enabled(mut self, enabled: bool) -> Self {
1777        self.enabled = enabled;
1778        self
1779    }
1780
1781    /// Set reload timeout.
1782    pub fn with_timeout(mut self, timeout: Duration) -> Self {
1783        self.reload_timeout = timeout;
1784        self
1785    }
1786
1787    /// Enable or disable state preservation.
1788    pub fn with_preserve_state(mut self, preserve: bool) -> Self {
1789        self.preserve_state = preserve;
1790        self
1791    }
1792
1793    /// Set maximum retries.
1794    pub fn with_max_retries(mut self, retries: u32) -> Self {
1795        self.max_retries = retries;
1796        self
1797    }
1798}
1799
1800/// State of a hot reload operation.
1801#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1802pub enum HotReloadState {
1803    /// Reload not started.
1804    Idle,
1805    /// Draining pending messages from kernel.
1806    Draining,
1807    /// Creating checkpoint of kernel state.
1808    Checkpointing,
1809    /// Compiling new kernel code.
1810    Compiling,
1811    /// Validating new kernel code.
1812    Validating,
1813    /// Swapping old kernel with new.
1814    Swapping,
1815    /// Restoring state to new kernel.
1816    Restoring,
1817    /// Hot reload completed successfully.
1818    Completed,
1819    /// Hot reload failed.
1820    Failed,
1821    /// Rolling back to previous version.
1822    RollingBack,
1823}
1824
1825/// Kernel code format.
1826#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1827pub enum KernelCodeFormat {
1828    /// NVIDIA PTX assembly.
1829    Ptx,
1830    /// NVIDIA CUBIN binary.
1831    Cubin,
1832    /// SPIR-V for Vulkan/WebGPU.
1833    SpirV,
1834    /// WGSL shader text.
1835    Wgsl,
1836    /// Metal Shading Language.
1837    Msl,
1838    /// Metal compiled library.
1839    MetalLib,
1840    /// Source code (requires compilation).
1841    Source,
1842}
1843
1844/// Kernel code source for hot reload.
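///
/// # Example
///
/// A minimal sketch building a WGSL source (the shader text and metadata
/// key are illustrative):
///
/// ```ignore
/// let code = KernelCodeSource::from_wgsl(
///     "@compute @workgroup_size(64) fn main() {}",
///     "main",
/// )
/// .with_metadata("opt_level", "3");
///
/// assert_eq!(code.format, KernelCodeFormat::Wgsl);
/// assert_eq!(code.size(), code.as_str().unwrap().len());
/// ```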
1845#[derive(Debug, Clone)]
1846pub struct KernelCodeSource {
1847    /// Unique identifier for this code version.
1848    pub version_id: u64,
1849    /// Code format.
1850    pub format: KernelCodeFormat,
1851    /// Raw code bytes.
1852    pub code: Vec<u8>,
1853    /// Entry point function name.
1854    pub entry_point: String,
1855    /// Optional metadata (compile flags, etc.).
1856    pub metadata: HashMap<String, String>,
1857    /// Timestamp when code was created.
1858    pub created_at: Instant,
1859    /// 256-bit content hash of the code (a fast fingerprint, not SHA-256).
1860    pub hash: [u8; 32],
1861}
1862
1863impl KernelCodeSource {
1864    /// Create a new kernel code source.
1865    pub fn new(format: KernelCodeFormat, code: Vec<u8>, entry_point: impl Into<String>) -> Self {
1866        let hash = Self::compute_hash(&code);
1867        Self {
1868            version_id: 0,
1869            format,
1870            code,
1871            entry_point: entry_point.into(),
1872            metadata: HashMap::new(),
1873            created_at: Instant::now(),
1874            hash,
1875        }
1876    }
1877
1878    /// Create from PTX code.
1879    pub fn from_ptx(ptx: &str, entry_point: impl Into<String>) -> Self {
1880        Self::new(KernelCodeFormat::Ptx, ptx.as_bytes().to_vec(), entry_point)
1881    }
1882
1883    /// Create from WGSL code.
1884    pub fn from_wgsl(wgsl: &str, entry_point: impl Into<String>) -> Self {
1885        Self::new(
1886            KernelCodeFormat::Wgsl,
1887            wgsl.as_bytes().to_vec(),
1888            entry_point,
1889        )
1890    }
1891
1892    /// Create from MSL code.
1893    pub fn from_msl(msl: &str, entry_point: impl Into<String>) -> Self {
1894        Self::new(KernelCodeFormat::Msl, msl.as_bytes().to_vec(), entry_point)
1895    }
1896
1897    /// Set version ID.
1898    pub fn with_version(mut self, version: u64) -> Self {
1899        self.version_id = version;
1900        self
1901    }
1902
1903    /// Add metadata.
1904    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1905        self.metadata.insert(key.into(), value.into());
1906        self
1907    }
1908
1909    fn compute_hash(data: &[u8]) -> [u8; 32] {
1910        use std::hash::{Hash, Hasher};
1911        let mut hasher = std::collections::hash_map::DefaultHasher::new(); // SipHash; not cryptographic
1912        data.hash(&mut hasher);
1913        let h1 = hasher.finish();
1914        h1.hash(&mut hasher);
1915        let h2 = hasher.finish();
1916        h1.hash(&mut hasher);
1917        let h3 = hasher.finish();
1918        h1.hash(&mut hasher);
1919        let h4 = hasher.finish();
1920
1921        let mut hash = [0u8; 32];
1922        hash[0..8].copy_from_slice(&h1.to_le_bytes());
1923        hash[8..16].copy_from_slice(&h2.to_le_bytes());
1924        hash[16..24].copy_from_slice(&h3.to_le_bytes());
1925        hash[24..32].copy_from_slice(&h4.to_le_bytes());
1926        hash
1927    }
1928
1929    /// Get code as string (if text format).
1930    pub fn as_str(&self) -> Option<&str> {
1931        match self.format {
1932            KernelCodeFormat::Ptx
1933            | KernelCodeFormat::Wgsl
1934            | KernelCodeFormat::Msl
1935            | KernelCodeFormat::Source => std::str::from_utf8(&self.code).ok(),
1936            _ => None,
1937        }
1938    }
1939
1940    /// Get code size in bytes.
1941    pub fn size(&self) -> usize {
1942        self.code.len()
1943    }
1944}
1945
1946/// Request to hot reload a kernel.
1947#[derive(Debug)]
1948pub struct HotReloadRequest {
1949    /// Target kernel ID.
1950    pub kernel_id: KernelId,
1951    /// New kernel code.
1952    pub new_code: KernelCodeSource,
1953    /// Current state of the reload operation.
1954    pub state: HotReloadState,
1955    /// When the request was created.
1956    pub created_at: Instant,
1957    /// When the reload started.
1958    pub started_at: Option<Instant>,
1959    /// Retry count.
1960    pub retry_count: u32,
1961    /// Error message if failed.
1962    pub error: Option<String>,
1963    /// Checkpoint data (if preserving state).
1964    checkpoint_data: Option<Vec<u8>>,
1965}
1966
1967impl HotReloadRequest {
1968    /// Create a new hot reload request.
1969    pub fn new(kernel_id: KernelId, new_code: KernelCodeSource) -> Self {
1970        Self {
1971            kernel_id,
1972            new_code,
1973            state: HotReloadState::Idle,
1974            created_at: Instant::now(),
1975            started_at: None,
1976            retry_count: 0,
1977            error: None,
1978            checkpoint_data: None,
1979        }
1980    }
1981
1982    /// Check if reload is in progress.
1983    pub fn is_in_progress(&self) -> bool {
1984        !matches!(
1985            self.state,
1986            HotReloadState::Idle | HotReloadState::Completed | HotReloadState::Failed
1987        )
1988    }
1989
1990    /// Check if reload completed successfully.
1991    pub fn is_completed(&self) -> bool {
1992        self.state == HotReloadState::Completed
1993    }
1994
1995    /// Check if reload failed.
1996    pub fn is_failed(&self) -> bool {
1997        self.state == HotReloadState::Failed
1998    }
1999
2000    /// Get elapsed time since request creation.
2001    pub fn elapsed(&self) -> Duration {
2002        self.created_at.elapsed()
2003    }
2004
2005    /// Get elapsed time since reload started.
2006    pub fn reload_elapsed(&self) -> Option<Duration> {
2007        self.started_at.map(|s| s.elapsed())
2008    }
2009}
2010
2011/// Result of a completed hot reload.
2012#[derive(Debug, Clone)]
2013pub struct HotReloadResult {
2014    /// Target kernel ID.
2015    pub kernel_id: KernelId,
2016    /// Previous code version.
2017    pub old_version: u64,
2018    /// New code version.
2019    pub new_version: u64,
2020    /// Whether state was preserved.
2021    pub state_preserved: bool,
2022    /// Size of checkpoint data (if any).
2023    pub checkpoint_size: usize,
2024    /// Time to drain messages.
2025    pub drain_duration: Duration,
2026    /// Time to create checkpoint.
2027    pub checkpoint_duration: Duration,
2028    /// Time to compile new code.
2029    pub compile_duration: Duration,
2030    /// Time to swap kernels.
2031    pub swap_duration: Duration,
2032    /// Time to restore state.
2033    pub restore_duration: Duration,
2034    /// Total reload duration.
2035    pub total_duration: Duration,
2036}
2037
2038/// Statistics for hot reload operations.
2039#[derive(Debug, Default)]
2040struct HotReloadStats {
2041    successful_reloads: AtomicU64,
2042    failed_reloads: AtomicU64,
2043    rollbacks: AtomicU64,
2044    total_drain_time_us: AtomicU64,
2045    total_compile_time_us: AtomicU64,
2046    total_swap_time_us: AtomicU64,
2047    state_preserved_count: AtomicU64,
2048}
2049
2050/// Snapshot of hot reload statistics.
2051#[derive(Debug, Clone)]
2052pub struct HotReloadStatsSnapshot {
2053    /// Total successful reloads.
2054    pub successful_reloads: u64,
2055    /// Total failed reloads.
2056    pub failed_reloads: u64,
2057    /// Total rollbacks performed.
2058    pub rollbacks: u64,
2059    /// Average drain time.
2060    pub avg_drain_time: Duration,
2061    /// Average compile time.
2062    pub avg_compile_time: Duration,
2063    /// Average swap time.
2064    pub avg_swap_time: Duration,
2065    /// Number of reloads with preserved state.
2066    pub state_preserved_count: u64,
2067}
2068
2069/// Manager for kernel hot reload operations.
2070///
2071/// Provides seamless kernel code updates without stopping the system:
2072///
2073/// 1. Drain pending messages from kernel input queue
2074/// 2. Checkpoint kernel state (if preserving state)
2075/// 3. Compile/validate new kernel code
2076/// 4. Swap old kernel with new kernel
2077/// 5. Restore state to new kernel
2078/// 6. Resume processing
2079///
2080/// # Example
2081///
2082/// ```ignore
2083/// use ringkernel_core::multi_gpu::{HotReloadManager, HotReloadConfig, KernelCodeSource};
2084///
2085/// let manager = HotReloadManager::new(HotReloadConfig::default());
2086///
2087/// // Register a reloadable kernel
2088/// manager.register_kernel(&kernel_id, current_code);
2089///
2090/// // Request hot reload with new PTX
2091/// let new_code = KernelCodeSource::from_ptx(new_ptx, "my_kernel");
2092/// let mut request = manager.request_reload(&kernel_id, new_code)?;
2093///
2094/// // Execute the reload (synchronous; updates the request state in place)
2095/// let result = manager.execute_reload(&mut request, &kernel)?;
2096/// println!("Reload completed in {:?}", result.total_duration);
2097/// ```
2098pub struct HotReloadManager {
2099    /// Configuration.
2100    config: HotReloadConfig,
2101    /// Registered kernels and their current code.
2102    kernels: RwLock<HashMap<KernelId, KernelCodeSource>>,
2103    /// Fallback code for registered kernels.
2104    fallbacks: RwLock<HashMap<KernelId, KernelCodeSource>>,
2105    /// Active reload requests.
2106    active_requests: RwLock<HashMap<KernelId, HotReloadRequest>>,
2107    /// Version counter for code versions.
2108    version_counter: AtomicU64,
2109    /// Statistics.
2110    stats: HotReloadStats,
2111    /// Compiled-rule registry (v1.1, per spec 3.3). Hot-swaps whole
2112    /// inference-rule actors identified by opaque `CompiledRule`
2113    /// artifacts produced by upstream compilers (e.g. VynGraph).
2114    rule_registry: Arc<crate::rules::RuleRegistry>,
2115}
2116
2117impl HotReloadManager {
2118    /// Create a new hot reload manager.
2119    pub fn new(config: HotReloadConfig) -> Arc<Self> {
2120        Self::with_rule_backend(config, Arc::new(crate::rules::NoopSwapBackend))
2121    }
2122
2123    /// Create a new hot reload manager with a custom rule-swap backend.
2124    ///
2125    /// Production code (e.g. `ringkernel-cuda`) should use this to inject
2126    /// a backend that actually performs the GPU-side atomic actor swap.
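    ///
    /// # Example
    ///
    /// A minimal sketch; `CudaRuleSwapBackend` is an illustrative name,
    /// not a type provided by this crate:
    ///
    /// ```ignore
    /// let manager = HotReloadManager::with_rule_backend(
    ///     HotReloadConfig::default(),
    ///     Arc::new(CudaRuleSwapBackend::default()),
    /// );
    /// manager.rule_registry(); // entry point for compiled-rule hot swaps
    /// ```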
2127    pub fn with_rule_backend(
2128        config: HotReloadConfig,
2129        rule_backend: Arc<dyn crate::rules::RuleSwapBackend>,
2130    ) -> Arc<Self> {
2131        let rule_registry = Arc::new(crate::rules::RuleRegistry::new(
2132            config.max_rule_history,
2133            rule_backend,
2134        ));
2135        Arc::new(Self {
2136            config,
2137            kernels: RwLock::new(HashMap::new()),
2138            fallbacks: RwLock::new(HashMap::new()),
2139            active_requests: RwLock::new(HashMap::new()),
2140            version_counter: AtomicU64::new(1),
2141            stats: HotReloadStats::default(),
2142            rule_registry,
2143        })
2144    }
2145
2146    /// Access the compiled-rule registry for hot-swap of inference rules.
2147    pub fn rule_registry(&self) -> &Arc<crate::rules::RuleRegistry> {
2148        &self.rule_registry
2149    }
2150
2151    /// Create with default configuration.
2152    pub fn with_defaults() -> Arc<Self> {
2153        Self::new(HotReloadConfig::default())
2154    }
2155
2156    /// Check if hot reload is enabled.
2157    pub fn is_enabled(&self) -> bool {
2158        self.config.enabled
2159    }
2160
2161    /// Register a kernel for hot reload.
2162    pub fn register_kernel(&self, kernel_id: &KernelId, code: KernelCodeSource) {
2163        let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
2164        let code = code.with_version(version);
2165        self.kernels.write().insert(kernel_id.clone(), code);
2166    }
2167
2168    /// Unregister a kernel from hot reload.
2169    pub fn unregister_kernel(&self, kernel_id: &KernelId) {
2170        self.kernels.write().remove(kernel_id);
2171        self.fallbacks.write().remove(kernel_id);
2172        self.active_requests.write().remove(kernel_id);
2173    }
2174
2175    /// Get current code version for a kernel.
2176    pub fn get_current_version(&self, kernel_id: &KernelId) -> Option<u64> {
2177        self.kernels.read().get(kernel_id).map(|c| c.version_id)
2178    }
2179
2180    /// Get current code for a kernel.
2181    pub fn get_current_code(&self, kernel_id: &KernelId) -> Option<KernelCodeSource> {
2182        self.kernels.read().get(kernel_id).cloned()
2183    }
2184
2185    /// Request a hot reload for a kernel.
2186    pub fn request_reload(
2187        &self,
2188        kernel_id: &KernelId,
2189        new_code: KernelCodeSource,
2190    ) -> Result<HotReloadRequest> {
2191        if !self.config.enabled {
2192            return Err(RingKernelError::ValidationError(
2193                "Hot reload is disabled".to_string(),
2194            ));
2195        }
2196
2197        // Check kernel is registered
2198        if !self.kernels.read().contains_key(kernel_id) {
2199            return Err(RingKernelError::KernelNotFound(
2200                kernel_id.as_str().to_string(),
2201            ));
2202        }
2203
2204        // Check no reload already in progress
2205        {
2206            let active = self.active_requests.read();
2207            if let Some(existing) = active.get(kernel_id) {
2208                if existing.is_in_progress() {
2209                    return Err(RingKernelError::ValidationError(
2210                        "Hot reload already in progress for this kernel".to_string(),
2211                    ));
2212                }
2213            }
2214        }
2215
2216        // Assign version to new code
2217        let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
2218        let new_code = new_code.with_version(version);
2219
2220        let request = HotReloadRequest::new(kernel_id.clone(), new_code);
2221        // Store a mirror marked as in-progress so the guard above rejects
2222        // duplicate requests; an Idle mirror would never trip
2223        // is_in_progress(). The entry is cleared by execute_reload,
2224        // rollback, or unregister_kernel.
        let mut mirror = HotReloadRequest::new(kernel_id.clone(), request.new_code.clone());
        mirror.state = HotReloadState::Draining;
        self.active_requests.write().insert(kernel_id.clone(), mirror);
2225
2226        Ok(request)
2227    }
2228
2229    /// Execute a hot reload operation.
2230    ///
2231    /// This performs the full reload sequence:
2232    /// 1. Drain pending messages
2233    /// 2. Checkpoint state (if enabled)
2234    /// 3. Validate new code
2235    /// 4. Swap kernels
2236    /// 5. Restore state (if enabled)
2237    pub fn execute_reload<K: CheckpointableKernel>(
2238        &self,
2239        request: &mut HotReloadRequest,
2240        kernel: &K,
2241    ) -> Result<HotReloadResult> {
2242        let start_time = Instant::now();
2243        request.started_at = Some(start_time);
2244
2245        // Get old version
2246        let old_version = self
2247            .kernels
2248            .read()
2249            .get(&request.kernel_id)
2250            .map(|c| c.version_id)
2251            .unwrap_or(0);
2252
2253        // Phase 1: Drain (simulated)
2254        request.state = HotReloadState::Draining;
2255        let drain_start = Instant::now();
2256        // In a real implementation, this would wait until the input queue is empty.
2257        std::thread::sleep(Duration::from_micros(10));
2258        let drain_duration = drain_start.elapsed();
2259        self.stats
2260            .total_drain_time_us
2261            .fetch_add(drain_duration.as_micros() as u64, Ordering::Relaxed);
2262
2263        // Phase 2: Checkpoint (if preserving state)
2264        request.state = HotReloadState::Checkpointing;
2265        let checkpoint_start = Instant::now();
2266        let checkpoint_size = if self.config.preserve_state {
2267            let checkpoint = kernel.create_checkpoint().map_err(|e| {
                // Record the failure and free the active-request slot so the
                // reload can be retried.
                request.state = HotReloadState::Failed;
                self.stats.failed_reloads.fetch_add(1, Ordering::Relaxed);
                self.active_requests.write().remove(&request.kernel_id);
                e
            })?;
2268            let data = checkpoint.to_bytes();
2269            let len = data.len();
2270            request.checkpoint_data = Some(data); // store without an extra clone
            len
2271        } else {
2272            0
2273        };
2274        let checkpoint_duration = checkpoint_start.elapsed();
2275
2276        // Phase 3: Validate new code
2277        request.state = HotReloadState::Validating;
2278        if self.config.validate_before_swap {
2279            self.validate_code(&request.new_code).map_err(|e| {
                request.state = HotReloadState::Failed;
                self.stats.failed_reloads.fetch_add(1, Ordering::Relaxed);
                self.active_requests.write().remove(&request.kernel_id);
                e
            })?;
2280        }
2281
2282        // Phase 4: Compile (simulated)
2283        request.state = HotReloadState::Compiling;
2284        let compile_start = Instant::now();
2285        // In a real implementation, compile PTX/WGSL to native code here.
2286        std::thread::sleep(Duration::from_micros(10));
2287        let compile_duration = compile_start.elapsed();
2288        self.stats
2289            .total_compile_time_us
2290            .fetch_add(compile_duration.as_micros() as u64, Ordering::Relaxed);
2291
2292        // Phase 5: Swap
2293        request.state = HotReloadState::Swapping;
2294        let swap_start = Instant::now();
2295
2296        // Save fallback
2297        if self.config.keep_fallback {
2298            if let Some(old_code) = self.kernels.read().get(&request.kernel_id).cloned() {
2299                self.fallbacks
2300                    .write()
2301                    .insert(request.kernel_id.clone(), old_code);
2302            }
2303        }
2304
2305        // Install new code
2306        self.kernels
2307            .write()
2308            .insert(request.kernel_id.clone(), request.new_code.clone());
2309        let swap_duration = swap_start.elapsed();
2310        self.stats
2311            .total_swap_time_us
2312            .fetch_add(swap_duration.as_micros() as u64, Ordering::Relaxed);
2313
2314        // Phase 6: Restore state
2315        request.state = HotReloadState::Restoring;
2316        let restore_start = Instant::now();
2317        // In a real implementation, restore the checkpoint to the new kernel here.
2318        let restore_duration = restore_start.elapsed();
2319
2320        // Mark completed
2321        request.state = HotReloadState::Completed;
2322        self.stats
2323            .successful_reloads
2324            .fetch_add(1, Ordering::Relaxed);
2325        if self.config.preserve_state && checkpoint_size > 0 {
2326            self.stats
2327                .state_preserved_count
2328                .fetch_add(1, Ordering::Relaxed);
2329        }
2330
2331        // Clean up active request
2332        self.active_requests.write().remove(&request.kernel_id);
2333
2334        Ok(HotReloadResult {
2335            kernel_id: request.kernel_id.clone(),
2336            old_version,
2337            new_version: request.new_code.version_id,
2338            state_preserved: self.config.preserve_state && checkpoint_size > 0,
2339            checkpoint_size,
2340            drain_duration,
2341            checkpoint_duration,
2342            compile_duration,
2343            swap_duration,
2344            restore_duration,
2345            total_duration: start_time.elapsed(),
2346        })
2347    }
2348
2349    /// Rollback to previous kernel version.
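    ///
    /// Succeeds only if a fallback version was saved, which requires
    /// `keep_fallback` to have been enabled during the prior reload.
    ///
    /// ```ignore
    /// // Sketch: after a bad reload, restore the previous code version.
    /// manager.rollback(&kernel_id)?;
    /// ```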
2350    pub fn rollback(&self, kernel_id: &KernelId) -> Result<()> {
2351        let fallback =
2352            self.fallbacks.write().remove(kernel_id).ok_or_else(|| {
2353                RingKernelError::ValidationError("No fallback available".to_string())
2354            })?;
2355
2356        self.kernels.write().insert(kernel_id.clone(), fallback);
2357        self.stats.rollbacks.fetch_add(1, Ordering::Relaxed);
2358
2359        // Clear any active request so future reloads for this kernel are
2360        // not blocked: a request left in the RollingBack state would count
2361        // as in-progress forever.
2362        self.active_requests.write().remove(kernel_id);
2363
2364        Ok(())
2365    }
2366
2367    /// Validate kernel code before swap.
2368    fn validate_code(&self, code: &KernelCodeSource) -> Result<()> {
2369        // Basic validation
2370        if code.code.is_empty() {
2371            return Err(RingKernelError::ValidationError(
2372                "Kernel code is empty".to_string(),
2373            ));
2374        }
2375
2376        if code.entry_point.is_empty() {
2377            return Err(RingKernelError::ValidationError(
2378                "Entry point is empty".to_string(),
2379            ));
2380        }
2381
2382        // Format-specific validation
2383        match code.format {
2384            KernelCodeFormat::Ptx => {
2385                // Check for valid PTX header
2386                if let Some(text) = code.as_str() {
2387                    if !text.contains(".version") && !text.contains(".target") {
2388                        return Err(RingKernelError::ValidationError(
2389                            "PTX code missing version/target directive".to_string(),
2390                        ));
2391                    }
2392                }
2393            }
2394            KernelCodeFormat::Wgsl => {
2395                // Check for basic WGSL structure
2396                if let Some(text) = code.as_str() {
2397                    if !text.contains("@compute") && !text.contains("fn ") {
2398                        return Err(RingKernelError::ValidationError(
2399                            "WGSL code missing compute shader or function".to_string(),
2400                        ));
2401                    }
2402                }
2403            }
2404            KernelCodeFormat::Msl => {
2405                // Check for Metal kernel
2406                if let Some(text) = code.as_str() {
2407                    if !text.contains("kernel ") {
2408                        return Err(RingKernelError::ValidationError(
2409                            "MSL code missing kernel function".to_string(),
2410                        ));
2411                    }
2412                }
2413            }
2414            _ => {}
2415        }
2416
2417        Ok(())
2418    }
2419
2420    /// Get statistics snapshot.
2421    pub fn stats(&self) -> HotReloadStatsSnapshot {
2422        let successful = self.stats.successful_reloads.load(Ordering::Relaxed);
2423        let failed = self.stats.failed_reloads.load(Ordering::Relaxed);
2424        let total = (successful + failed).max(1); // avoid division by zero
2425
2426        HotReloadStatsSnapshot {
2427            successful_reloads: successful,
2428            failed_reloads: failed,
2429            rollbacks: self.stats.rollbacks.load(Ordering::Relaxed),
2430            avg_drain_time: Duration::from_micros(
2431                self.stats.total_drain_time_us.load(Ordering::Relaxed) / total,
2432            ),
2433            avg_compile_time: Duration::from_micros(
2434                self.stats.total_compile_time_us.load(Ordering::Relaxed) / total,
2435            ),
2436            avg_swap_time: Duration::from_micros(
2437                self.stats.total_swap_time_us.load(Ordering::Relaxed) / total,
2438            ),
2439            state_preserved_count: self.stats.state_preserved_count.load(Ordering::Relaxed),
2440        }
2441    }
2442
2443    /// List all registered kernels.
2444    pub fn list_kernels(&self) -> Vec<KernelId> {
2445        self.kernels.read().keys().cloned().collect()
2446    }
2447
2448    /// Check if a kernel is registered.
2449    pub fn is_registered(&self, kernel_id: &KernelId) -> bool {
2450        self.kernels.read().contains_key(kernel_id)
2451    }
2452
2453    /// Check if a reload is in progress for a kernel.
2454    pub fn is_reload_in_progress(&self, kernel_id: &KernelId) -> bool {
2455        self.active_requests
2456            .read()
2457            .get(kernel_id)
2458            .map(|r| r.is_in_progress())
2459            .unwrap_or(false)
2460    }
2461
2462    /// Get the configuration.
2463    pub fn config(&self) -> &HotReloadConfig {
2464        &self.config
2465    }
2466}
2467
2468/// Trait for kernels that support hot reload.
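///
/// # Example
///
/// A minimal sketch of the sequence a host would drive; `kernel` and
/// `new_code` are illustrative bindings, and the type must also implement
/// `CheckpointableKernel` (the supertrait):
///
/// ```ignore
/// kernel.prepare_for_reload()?;
/// assert!(kernel.is_ready_for_reload());
/// kernel.apply_code(&new_code)?;
/// kernel.resume_after_reload()?;
/// ```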
2469pub trait HotReloadableKernel: CheckpointableKernel {
2470    /// Prepare kernel for code swap (drain messages, pause processing).
2471    fn prepare_for_reload(&mut self) -> Result<()>;
2472
2473    /// Apply new code to the kernel.
2474    fn apply_code(&mut self, code: &KernelCodeSource) -> Result<()>;
2475
2476    /// Resume processing after reload.
2477    fn resume_after_reload(&mut self) -> Result<()>;
2478
2479    /// Check if kernel is ready for reload.
2480    fn is_ready_for_reload(&self) -> bool;
2481}
2482
2483#[cfg(test)]
2484mod tests {
2485    use super::*;
2486
2487    #[test]
2488    fn test_device_info() {
2489        let info = DeviceInfo::new(0, "Test GPU".to_string(), Backend::Cuda);
2490        assert_eq!(info.index, 0);
2491        assert_eq!(info.name, "Test GPU");
2492        assert_eq!(info.memory_utilization(), 0.0);
2493    }
2494
2495    #[test]
2496    fn test_coordinator_registration() {
2497        let coord = MultiGpuBuilder::new().build();
2498
2499        let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
2500        coord.register_device(device);
2501
2502        assert_eq!(coord.device_count(), 1);
2503        assert!(coord.device(0).is_some());
2504    }
2505
2506    #[test]
2507    fn test_kernel_assignment() {
2508        let coord = MultiGpuBuilder::new().build();
2509
2510        let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
2511        coord.register_device(device);
2512
2513        let kernel_id = KernelId::new("test_kernel");
2514        coord.assign_kernel(kernel_id.clone(), 0);
2515
2516        assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));
2517        assert_eq!(coord.kernels_on_device(0).len(), 1);
2518    }
2519
2520    #[test]
2521    fn test_load_balancing_least_loaded() {
2522        let coord = MultiGpuBuilder::new()
2523            .load_balancing(LoadBalancingStrategy::LeastLoaded)
2524            .build();
2525
2526        // Register two devices
2527        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2528        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
2529
2530        // Assign a kernel to device 0
2531        coord.assign_kernel(KernelId::new("k1"), 0);
2532
2533        // Next kernel should go to device 1 (least loaded)
2534        let selected = coord.select_device(&LaunchOptions::default()).unwrap();
2535        assert_eq!(selected, 1);
2536    }
2537
2538    #[test]
2539    fn test_round_robin() {
2540        let coord = MultiGpuBuilder::new()
2541            .load_balancing(LoadBalancingStrategy::RoundRobin)
2542            .build();
2543
2544        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2545        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
2546
2547        let d1 = coord.select_device(&LaunchOptions::default()).unwrap();
2548        let d2 = coord.select_device(&LaunchOptions::default()).unwrap();
2549        let d3 = coord.select_device(&LaunchOptions::default()).unwrap();
2550
2551        // Should cycle through devices
2552        assert_ne!(d1, d2);
2553        assert_eq!(d1, d3);
2554    }
2555
2556    // ========================================================================
2557    // Topology Tests
2558    // ========================================================================
2559
2560    #[test]
2561    fn test_interconnect_bandwidth() {
2562        assert!(
2563            InterconnectType::NvLink.estimated_bandwidth_gbps()
2564                > InterconnectType::Pcie.estimated_bandwidth_gbps()
2565        );
2566        assert!(
2567            InterconnectType::Pcie.estimated_bandwidth_gbps()
2568                > InterconnectType::None.estimated_bandwidth_gbps()
2569        );
2570        assert!(
2571            InterconnectType::SameDevice.estimated_bandwidth_gbps()
2572                > InterconnectType::NvLink.estimated_bandwidth_gbps()
2573        );
2574    }
2575
2576    #[test]
2577    fn test_interconnect_p2p_support() {
2578        assert!(!InterconnectType::None.supports_p2p());
2579        assert!(InterconnectType::Pcie.supports_p2p());
2580        assert!(InterconnectType::NvLink.supports_p2p());
2581        assert!(InterconnectType::NvSwitch.supports_p2p());
2582    }
2583
2584    #[test]
2585    fn test_gpu_topology_creation() {
2586        let topo = GpuTopology::new(4);
2587        assert_eq!(topo.device_count, 4);
2588
2589        // Self-connections should exist
2590        for i in 0..4 {
2591            let conn = topo.get_connection(i, i);
2592            assert!(conn.is_some());
2593            assert_eq!(conn.unwrap().interconnect, InterconnectType::SameDevice);
2594        }
2595    }
2596
2597    #[test]
2598    fn test_gpu_topology_set_connection() {
2599        let mut topo = GpuTopology::new(4);
2600
2601        // Set NVLink between GPU 0 and 1
2602        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
2603
2604        let conn_01 = topo.get_connection(0, 1);
2605        assert!(conn_01.is_some());
2606        assert_eq!(conn_01.unwrap().interconnect, InterconnectType::NvLink);
2607
2608        // Bidirectional by default
2609        let conn_10 = topo.get_connection(1, 0);
2610        assert!(conn_10.is_some());
2611        assert_eq!(conn_10.unwrap().interconnect, InterconnectType::NvLink);
2612    }
2613
2614    #[test]
2615    fn test_gpu_topology_neighbors() {
2616        let mut topo = GpuTopology::new(4);
2617
2618        // Ring topology: 0-1-2-3-0
2619        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
2620        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
2621        topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
2622        topo.set_connection(GpuConnection::new(3, 0, InterconnectType::NvLink));
2623
2624        let neighbors_0 = topo.neighbors(0);
2625        assert_eq!(neighbors_0.len(), 2);
2626        assert!(neighbors_0.contains(&1));
2627        assert!(neighbors_0.contains(&3));
2628    }
2629
2630    #[test]
2631    fn test_gpu_topology_best_path() {
2632        let mut topo = GpuTopology::new(4);
2633
2634        // Create connections: 0-1, 1-2, 2-3 (no direct 0-3)
2635        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
2636        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
2637        topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
2638        topo.set_connection(GpuConnection::new(0, 3, InterconnectType::None)); // No direct P2P
2639
2640        // Direct path should work for adjacent nodes
2641        let path_01 = topo.best_path(0, 1);
2642        assert_eq!(path_01, vec![0, 1]);
2643
2644        // Same device
2645        let path_00 = topo.best_path(0, 0);
2646        assert_eq!(path_00, vec![0]);
2647    }
2648
2649    #[test]
2650    fn test_gpu_topology_fully_connected() {
2651        let mut topo = GpuTopology::new(3);
2652
2653        // Not fully connected initially
2654        assert!(!topo.is_fully_connected());
2655
2656        // Make fully connected mesh
2657        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
2658        topo.set_connection(GpuConnection::new(0, 2, InterconnectType::NvLink));
2659        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
2660
2661        assert!(topo.is_fully_connected());
2662    }
2663
2664    #[test]
2665    fn test_gpu_topology_numa() {
2666        let mut topo = GpuTopology::new(4);
2667
2668        // GPUs 0,1 on NUMA 0; GPUs 2,3 on NUMA 1
2669        topo.set_numa_node(0, 0);
2670        topo.set_numa_node(1, 0);
2671        topo.set_numa_node(2, 1);
2672        topo.set_numa_node(3, 1);
2673
2674        let numa_neighbors_0 = topo.numa_neighbors(0);
2675        assert_eq!(numa_neighbors_0, vec![1]);
2676
2677        let numa_neighbors_2 = topo.numa_neighbors(2);
2678        assert_eq!(numa_neighbors_2, vec![3]);
2679    }
2680
2681    // ========================================================================
2682    // Topology Discovery Tests
2683    // ========================================================================
2684
2685    #[test]
2686    fn test_coordinator_topology_discovery() {
2687        let coord = MultiGpuBuilder::new().enable_p2p(true).build();
2688
2689        // Register P2P capable devices
2690        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
2691        dev0.p2p_capable = true;
2692        dev0.compute_capability = Some((8, 0)); // Ampere
2693
2694        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
2695        dev1.p2p_capable = true;
2696        dev1.compute_capability = Some((8, 6)); // Ampere
2697
2698        coord.register_device(dev0);
2699        coord.register_device(dev1);
2700
2701        let topo = coord.discover_topology();
2702
2703        assert_eq!(topo.device_count, 2);
2704
2705        // Should detect NVLink for Ampere GPUs
2706        let conn = topo.get_connection(0, 1);
2707        assert!(conn.is_some());
2708        assert_eq!(conn.unwrap().interconnect, InterconnectType::NvLink);
2709    }
2710
2711    // ========================================================================
2712    // Migration Tests
2713    // ========================================================================
2714
2715    #[test]
2716    fn test_migration_request() {
2717        let coord = MultiGpuBuilder::new().build();
2718
2719        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2720        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
2721
2722        let kernel_id = KernelId::new("migrating_kernel");
2723        coord.assign_kernel(kernel_id.clone(), 0);
2724
2725        let request = coord.request_migration(&kernel_id, 1).unwrap();
2726
2727        assert_eq!(request.source_device, 0);
2728        assert_eq!(request.target_device, 1);
2729        assert_eq!(request.state, MigrationState::Pending);
2730    }
2731
2732    #[test]
2733    fn test_migration_same_device_error() {
2734        let coord = MultiGpuBuilder::new().build();
2735
2736        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2737
2738        let kernel_id = KernelId::new("kernel");
2739        coord.assign_kernel(kernel_id.clone(), 0);
2740
2741        let result = coord.request_migration(&kernel_id, 0);
2742        assert!(result.is_err());
2743    }
2744
2745    #[test]
2746    fn test_migration_complete() {
2747        let coord = MultiGpuBuilder::new().build();
2748
2749        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2750        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
2751
2752        let kernel_id = KernelId::new("migrating_kernel");
2753        coord.assign_kernel(kernel_id.clone(), 0);
2754
2755        assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));
2756
2757        let request = coord.request_migration(&kernel_id, 1).unwrap();
2758        coord.complete_migration(&request).unwrap();
2759
2760        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));
2761    }
2762
2763    #[test]
2764    fn test_migration_transfer_time_estimate() {
2765        let request = MigrationRequest {
2766            kernel_id: KernelId::new("test"),
2767            source_device: 0,
2768            target_device: 1,
2769            path: vec![0, 1],
2770            estimated_bandwidth_gbps: 300.0, // NVLink
2771            estimated_latency_us: 1.0,
2772            state: MigrationState::Pending,
2773            started_at: None,
2774        };
2775
2776        // 1GB transfer at 300GB/s = ~3.3ms + 1us latency
2777        let time = request.estimate_transfer_time(1_000_000_000);
2778        assert!(time.as_micros() > 3000);
2779        assert!(time.as_micros() < 4000);
2780    }
2781
2782    // ========================================================================
2783    // Cross-GPU K2K Router Tests
2784    // ========================================================================
2785
2786    use crate::hlc::HlcTimestamp;
2787    use crate::message::MessageEnvelope;
2788
2789    fn make_test_k2k_message(source: &KernelId, dest: &KernelId) -> K2KMessage {
2790        let timestamp = HlcTimestamp::now(42);
2791        let envelope = MessageEnvelope::empty(1, 2, timestamp);
2792        K2KMessage::new(source.clone(), dest.clone(), envelope, timestamp)
2793    }
2794
2795    #[test]
2796    fn test_router_same_device() {
2797        let coord = MultiGpuBuilder::new().build();
2798        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2799
2800        let k1 = KernelId::new("k1");
2801        let k2 = KernelId::new("k2");
2802        coord.assign_kernel(k1.clone(), 0);
2803        coord.assign_kernel(k2.clone(), 0);
2804
2805        let router = CrossGpuK2KRouter::new(coord);
2806
2807        let msg = make_test_k2k_message(&k1, &k2);
2808        let decision = router.route_message(&k1, &k2, msg).unwrap();
2809
2810        assert!(matches!(decision, RoutingDecision::SameDevice));
2811    }
2812
2813    #[test]
2814    fn test_router_cross_device() {
2815        let coord = MultiGpuBuilder::new().enable_p2p(true).build();
2816
2817        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
2818        dev0.p2p_capable = true;
2819        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
2820        dev1.p2p_capable = true;
2821
2822        coord.register_device(dev0);
2823        coord.register_device(dev1);
2824
2825        let k1 = KernelId::new("k1");
2826        let k2 = KernelId::new("k2");
2827        coord.assign_kernel(k1.clone(), 0);
2828        coord.assign_kernel(k2.clone(), 1);
2829
2830        let router = CrossGpuK2KRouter::new(coord);
2831
2832        let msg = make_test_k2k_message(&k1, &k2);
2833        let decision = router.route_message(&k1, &k2, msg).unwrap();
2834
2835        match decision {
2836            RoutingDecision::DirectP2P {
2837                source_device,
2838                dest_device,
2839                ..
2840            } => {
2841                assert_eq!(source_device, 0);
2842                assert_eq!(dest_device, 1);
2843            }
2844            _ => panic!("Expected DirectP2P routing"),
2845        }
2846    }
2847
2848    #[test]
2849    fn test_router_pending_messages() {
2850        let coord = MultiGpuBuilder::new().enable_p2p(true).build();
2851
2852        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
2853        dev0.p2p_capable = true;
2854        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
2855        dev1.p2p_capable = true;
2856
2857        coord.register_device(dev0);
2858        coord.register_device(dev1);
2859
2860        let k1 = KernelId::new("k1");
2861        let k2 = KernelId::new("k2");
2862        coord.assign_kernel(k1.clone(), 0);
2863        coord.assign_kernel(k2.clone(), 1);
2864
2865        let router = CrossGpuK2KRouter::new(coord);
2866
2867        // Route 3 messages
2868        for _ in 0..3 {
2869            let msg = make_test_k2k_message(&k1, &k2);
2870            router.route_message(&k1, &k2, msg).unwrap();
2871        }
2872
2873        assert_eq!(router.stats().messages_pending, 3);
2874
2875        // Drain pending
2876        let pending = router.drain_pending(0, 1);
2877        assert_eq!(pending.len(), 3);
2878        assert_eq!(router.stats().messages_pending, 0);
2879    }
2880
2881    #[test]
2882    fn test_router_stats() {
2883        let coord = MultiGpuBuilder::new().build();
2884        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2885
2886        let k1 = KernelId::new("k1");
2887        let k2 = KernelId::new("k2");
2888        coord.assign_kernel(k1.clone(), 0);
2889        coord.assign_kernel(k2.clone(), 0);
2890
2891        let router = CrossGpuK2KRouter::new(coord);
2892
2893        let stats = router.stats();
2894        assert_eq!(stats.messages_routed, 0);
2895        assert_eq!(stats.bytes_transferred, 0);
2896        assert_eq!(stats.routing_failures, 0);
2897    }
2898
2899    // ========================================================================
2900    // Kernel Migrator Tests
2901    // ========================================================================
2902
2903    use crate::checkpoint::{Checkpoint, CheckpointBuilder};
2904
2905    /// Mock checkpointable kernel for testing.
2906    struct MockCheckpointableKernel {
2907        kernel_id: String,
2908        kernel_type: String,
2909        state_data: Vec<u8>,
2910        step: u64,
2911    }
2912
2913    impl MockCheckpointableKernel {
2914        fn new(kernel_id: &str, state_size: usize) -> Self {
2915            Self {
2916                kernel_id: kernel_id.to_string(),
2917                kernel_type: "mock_kernel".to_string(),
2918                state_data: vec![0xAB; state_size],
2919                step: 1000,
2920            }
2921        }
2922    }
2923
2924    impl CheckpointableKernel for MockCheckpointableKernel {
2925        fn create_checkpoint(&self) -> Result<Checkpoint> {
2926            let checkpoint = CheckpointBuilder::new(&self.kernel_id, &self.kernel_type)
2927                .step(self.step)
2928                .grid_size(64, 64, 64)
2929                .control_block(vec![1, 2, 3, 4])
2930                .device_memory("state", self.state_data.clone())
2931                .build();
2932            Ok(checkpoint)
2933        }
2934
2935        fn restore_from_checkpoint(&mut self, checkpoint: &Checkpoint) -> Result<()> {
2936            self.step = checkpoint.metadata.current_step;
2937            Ok(())
2938        }
2939
2940        fn checkpoint_kernel_id(&self) -> &str {
2941            &self.kernel_id
2942        }
2943
2944        fn checkpoint_kernel_type(&self) -> &str {
2945            &self.kernel_type
2946        }
2947    }
2948
2949    #[test]
2950    fn test_migrator_creation() {
2951        let coord = MultiGpuBuilder::new().build();
2952        let migrator = KernelMigrator::new(coord);
2953
2954        let stats = migrator.stats();
2955        assert_eq!(stats.successful_migrations, 0);
2956        assert_eq!(stats.failed_migrations, 0);
2957        assert_eq!(stats.bytes_transferred, 0);
2958    }
2959
2960    #[test]
2961    fn test_migrator_with_custom_storage() {
2962        let coord = MultiGpuBuilder::new().build();
2963        let storage = Arc::new(MemoryStorage::new());
2964        let migrator = KernelMigrator::with_storage(coord.clone(), storage);
2965
2966        // Verify we can access the coordinator
2967        assert!(Arc::ptr_eq(migrator.coordinator(), &coord));
2968    }
2969
2970    #[test]
2971    fn test_successful_migration() {
2972        let coord = MultiGpuBuilder::new().build();
2973
2974        // Register devices
2975        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2976        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
2977
2978        // Assign kernel to device 0
2979        let kernel_id = KernelId::new("migratable_kernel");
2980        coord.assign_kernel(kernel_id.clone(), 0);
2981
2982        let migrator = KernelMigrator::new(coord.clone());
2983
2984        // Create mock kernel
2985        let kernel = MockCheckpointableKernel::new("migratable_kernel", 1024);
2986
2987        // Request migration
2988        let mut request = coord.request_migration(&kernel_id, 1).unwrap();
2989        assert_eq!(request.state, MigrationState::Pending);
2990
2991        // Perform migration
2992        let result = migrator
2993            .migrate_with_checkpoint(&kernel, &mut request)
2994            .unwrap();
2995
2996        // Verify result
2997        assert_eq!(result.kernel_id.as_str(), "migratable_kernel");
2998        assert_eq!(result.source_device, 0);
2999        assert_eq!(result.target_device, 1);
3000        assert!(result.checkpoint_size > 0);
3001        assert!(result.total_duration > Duration::ZERO);
3002
3003        // Verify kernel was moved
3004        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));
3005
3006        // Verify stats
3007        let stats = migrator.stats();
3008        assert_eq!(stats.successful_migrations, 1);
3009        assert_eq!(stats.failed_migrations, 0);
3010        assert!(stats.bytes_transferred > 0);
3011    }
3012
3013    #[test]
3014    fn test_migration_result_fields() {
3015        let coord = MultiGpuBuilder::new().build();
3016
3017        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
3018        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
3019
3020        let kernel_id = KernelId::new("test_kernel");
3021        coord.assign_kernel(kernel_id.clone(), 0);
3022
3023        let migrator = KernelMigrator::new(coord.clone());
3024        let kernel = MockCheckpointableKernel::new("test_kernel", 4096);
3025        let mut request = coord.request_migration(&kernel_id, 1).unwrap();
3026
3027        let result = migrator
3028            .migrate_with_checkpoint(&kernel, &mut request)
3029            .unwrap();
3030
3031        // All durations should be non-negative
3032        assert!(result.checkpoint_duration >= Duration::ZERO);
3033        assert!(result.transfer_duration >= Duration::ZERO);
3034        assert!(result.restore_duration >= Duration::ZERO);
3035
3036        // Total should be >= sum of parts
3037        assert!(result.total_duration >= result.checkpoint_duration);
3038    }

    #[test]
    fn test_migration_stats_accumulate() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let migrator = KernelMigrator::new(coord.clone());

        // Migrate kernel 1: 0 -> 1
        let k1 = KernelId::new("k1");
        coord.assign_kernel(k1.clone(), 0);
        let kernel1 = MockCheckpointableKernel::new("k1", 1000);
        let mut req1 = coord.request_migration(&k1, 1).unwrap();
        migrator
            .migrate_with_checkpoint(&kernel1, &mut req1)
            .unwrap();

        // Migrate kernel 2: 0 -> 1
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k2.clone(), 0);
        let kernel2 = MockCheckpointableKernel::new("k2", 2000);
        let mut req2 = coord.request_migration(&k2, 1).unwrap();
        migrator
            .migrate_with_checkpoint(&kernel2, &mut req2)
            .unwrap();

        let stats = migrator.stats();
        assert_eq!(stats.successful_migrations, 2);
        assert_eq!(stats.failed_migrations, 0);
        // Both checkpoints should have been transferred
        assert!(stats.bytes_transferred > 0);
    }
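
    // A minimal round-trip sketch, assuming a kernel can be migrated back to
    // its original device through the same request/checkpoint path and that
    // the migrator counts both hops; only APIs exercised above are used.
    #[test]
    fn test_migration_round_trip_sketch() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let migrator = KernelMigrator::new(coord.clone());
        let k = KernelId::new("round_trip");
        coord.assign_kernel(k.clone(), 0);
        let kernel = MockCheckpointableKernel::new("round_trip", 1024);

        // Hop 1: 0 -> 1
        let mut req = coord.request_migration(&k, 1).unwrap();
        migrator.migrate_with_checkpoint(&kernel, &mut req).unwrap();
        assert_eq!(coord.get_kernel_device(&k), Some(1));

        // Hop 2: 1 -> 0 (assumed symmetric with the forward hop)
        let mut back = coord.request_migration(&k, 0).unwrap();
        migrator.migrate_with_checkpoint(&kernel, &mut back).unwrap();
        assert_eq!(coord.get_kernel_device(&k), Some(0));

        assert_eq!(migrator.stats().successful_migrations, 2);
    }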

    // ========================================================================
    // Device Unregister Tests
    // ========================================================================
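    //
    // `unregister_device` has three outcomes, exercised by the tests below:
    // kernels on the removed device are re-planned onto the least-loaded
    // remaining device; if no other device exists, they are orphaned and lose
    // their mapping; and removing an unknown index reports failure.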

    #[test]
    fn test_unregister_device_no_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.device_index, 0);
        assert!(result.kernels_to_migrate.is_empty());
        assert!(result.orphaned_kernels.is_empty());
    }

    #[test]
    fn test_unregister_device_with_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        // Assign kernels to device 0
        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.kernels_to_migrate.len(), 2);
        assert!(result.orphaned_kernels.is_empty());

        // All kernels should migrate to device 1
        for plan in &result.kernels_to_migrate {
            assert_eq!(plan.source_device, 0);
            assert_eq!(plan.target_device, 1);
        }

        // Verify kernel mappings were updated
        assert_eq!(coord.get_kernel_device(&k1), Some(1));
        assert_eq!(coord.get_kernel_device(&k2), Some(1));
    }

    #[test]
    fn test_unregister_single_device_orphans_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        // Assign a kernel to the only device
        let k1 = KernelId::new("k1");
        coord.assign_kernel(k1.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert!(result.kernels_to_migrate.is_empty());
        assert_eq!(result.orphaned_kernels.len(), 1);
        assert_eq!(result.orphaned_kernels[0], k1);

        // Kernel should no longer have a device
        assert!(coord.get_kernel_device(&k1).is_none());
    }

    #[test]
    fn test_unregister_nonexistent_device() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let result = coord.unregister_device(99);

        assert!(!result.success);
        assert_eq!(result.device_index, 99);
    }

    #[test]
    fn test_unregister_distributes_to_least_loaded() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(2, "GPU 2".to_string(), Backend::Cuda));

        // Preload device 1 with kernels
        coord.assign_kernel(KernelId::new("pre1"), 1);
        coord.assign_kernel(KernelId::new("pre2"), 1);
        coord.assign_kernel(KernelId::new("pre3"), 1);

        // Assign kernel to device 0
        let k1 = KernelId::new("migrate_me");
        coord.assign_kernel(k1.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.kernels_to_migrate.len(), 1);

        // Should migrate to device 2 (least loaded)
        let plan = &result.kernels_to_migrate[0];
        assert_eq!(plan.target_device, 2);
    }

    #[test]
    fn test_migration_priority_enum() {
        let low = MigrationPriority::Low;
        let normal = MigrationPriority::Normal;
        let high = MigrationPriority::High;
        let critical = MigrationPriority::Critical;

        assert_ne!(low, normal);
        assert_ne!(normal, high);
        assert_ne!(high, critical);
        assert_eq!(low, MigrationPriority::Low);
    }

    // ========================================================================
    // Hot Reload Tests
    // ========================================================================
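    //
    // The reload state machine covers Idle, Draining, Checkpointing,
    // Compiling, Validating, Swapping, Restoring, and Completed, plus Failed
    // and RollingBack for the error path (see
    // `test_hot_reload_state_transitions` below for the full list).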

    #[test]
    fn test_hot_reload_config_default() {
        let config = HotReloadConfig::default();
        assert!(config.enabled);
        assert!(config.preserve_state);
        assert!(config.validate_before_swap);
        assert!(config.keep_fallback);
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_hot_reload_config_builder() {
        let config = HotReloadConfig::new()
            .with_enabled(false)
            .with_preserve_state(false)
            .with_max_retries(5)
            .with_timeout(Duration::from_secs(60));

        assert!(!config.enabled);
        assert!(!config.preserve_state);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.reload_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_kernel_code_source_ptx() {
        let ptx = ".version 7.0\n.target sm_80\nkernel: ret;";
        let code = KernelCodeSource::from_ptx(ptx, "kernel");

        assert_eq!(code.format, KernelCodeFormat::Ptx);
        assert_eq!(code.entry_point, "kernel");
        assert_eq!(code.as_str(), Some(ptx));
        assert_eq!(code.size(), ptx.len());
    }

    #[test]
    fn test_kernel_code_source_wgsl() {
        let wgsl = "@compute fn main() {}";
        let code = KernelCodeSource::from_wgsl(wgsl, "main");

        assert_eq!(code.format, KernelCodeFormat::Wgsl);
        assert_eq!(code.entry_point, "main");
        assert_eq!(code.as_str(), Some(wgsl));
    }

    #[test]
    fn test_kernel_code_source_msl() {
        let msl = "kernel void my_kernel() {}";
        let code = KernelCodeSource::from_msl(msl, "my_kernel");

        assert_eq!(code.format, KernelCodeFormat::Msl);
        assert_eq!(code.entry_point, "my_kernel");
        assert_eq!(code.as_str(), Some(msl));
    }

    #[test]
    fn test_hot_reload_manager_creation() {
        let manager = HotReloadManager::with_defaults();
        assert!(manager.is_enabled());
        assert!(manager.list_kernels().is_empty());
    }

    #[test]
    fn test_hot_reload_manager_register_kernel() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        manager.register_kernel(&kernel_id, code);

        assert!(manager.is_registered(&kernel_id));
        assert!(!manager.is_reload_in_progress(&kernel_id));
        assert!(manager.get_current_version(&kernel_id).is_some());
    }

    #[test]
    fn test_hot_reload_request_states() {
        let kernel_id = KernelId::new("test");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");
        let request = HotReloadRequest::new(kernel_id, code);

        assert_eq!(request.state, HotReloadState::Idle);
        assert!(!request.is_in_progress());
        assert!(!request.is_completed());
        assert!(!request.is_failed());
    }

    #[test]
    fn test_hot_reload_disabled() {
        let config = HotReloadConfig::new().with_enabled(false);
        let manager = HotReloadManager::new(config);
        let kernel_id = KernelId::new("test");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        manager.register_kernel(&kernel_id, code.clone());
        let result = manager.request_reload(&kernel_id, code);
        assert!(result.is_err());
    }
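
    // A minimal sketch, assuming `request_reload` rejects kernels that were
    // never registered with the manager, rather than accepting them silently.
    #[test]
    fn test_hot_reload_unregistered_kernel_sketch() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("never_registered");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        assert!(!manager.is_registered(&kernel_id));
        assert!(manager.request_reload(&kernel_id, code).is_err());
    }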

    #[test]
    fn test_hot_reload_stats() {
        let manager = HotReloadManager::with_defaults();
        let stats = manager.stats();

        assert_eq!(stats.successful_reloads, 0);
        assert_eq!(stats.failed_reloads, 0);
        assert_eq!(stats.rollbacks, 0);
    }

    #[test]
    fn test_hot_reload_code_formats() {
        let formats = [
            KernelCodeFormat::Ptx,
            KernelCodeFormat::Cubin,
            KernelCodeFormat::SpirV,
            KernelCodeFormat::Wgsl,
            KernelCodeFormat::Msl,
            KernelCodeFormat::MetalLib,
            KernelCodeFormat::Source,
        ];

        // Verify all formats are distinct
        for (i, f1) in formats.iter().enumerate() {
            for (j, f2) in formats.iter().enumerate() {
                if i != j {
                    assert_ne!(f1, f2);
                }
            }
        }
    }

    #[test]
    fn test_hot_reload_state_transitions() {
        let states = [
            HotReloadState::Idle,
            HotReloadState::Draining,
            HotReloadState::Checkpointing,
            HotReloadState::Compiling,
            HotReloadState::Validating,
            HotReloadState::Swapping,
            HotReloadState::Restoring,
            HotReloadState::Completed,
            HotReloadState::Failed,
            HotReloadState::RollingBack,
        ];

        // Verify all states are distinct
        for (i, s1) in states.iter().enumerate() {
            for (j, s2) in states.iter().enumerate() {
                if i != j {
                    assert_ne!(s1, s2);
                }
            }
        }
    }

    #[test]
    fn test_hot_reload_execute() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");

        let initial_code = KernelCodeSource::from_ptx(".version 7.0\n.target sm_80", "kernel");
        manager.register_kernel(&kernel_id, initial_code);

        let new_code = KernelCodeSource::from_ptx(".version 8.0\n.target sm_90", "kernel");
        let mut request = manager.request_reload(&kernel_id, new_code).unwrap();

        // Create mock kernel for checkpoint
        let mock_kernel = MockCheckpointableKernel::new("test_kernel", 512);

        let result = manager.execute_reload(&mut request, &mock_kernel).unwrap();

        assert!(request.is_completed());
        assert_eq!(result.kernel_id.as_str(), "test_kernel");
        assert!(result.state_preserved);
        assert!(result.checkpoint_size > 0);
        assert!(result.total_duration > Duration::ZERO);

        // Stats should be updated
        let stats = manager.stats();
        assert_eq!(stats.successful_reloads, 1);
    }
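
    // A minimal sketch, assuming a kernel can be reloaded again once the
    // previous reload has completed, with the manager's counters accumulating
    // across reloads; only APIs exercised above are used.
    #[test]
    fn test_hot_reload_successive_reloads_sketch() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");
        let mock_kernel = MockCheckpointableKernel::new("test_kernel", 512);

        manager.register_kernel(&kernel_id, KernelCodeSource::from_ptx(".version 7.0", "kernel"));

        // Reload twice, supplying new code each time
        for version in [".version 7.5", ".version 8.0"] {
            let code = KernelCodeSource::from_ptx(version, "kernel");
            let mut request = manager.request_reload(&kernel_id, code).unwrap();
            manager.execute_reload(&mut request, &mock_kernel).unwrap();
            assert!(request.is_completed());
        }

        assert_eq!(manager.stats().successful_reloads, 2);
        assert!(manager.get_current_version(&kernel_id).is_some());
    }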

    #[test]
    fn test_hot_reload_list_kernels() {
        let manager = HotReloadManager::with_defaults();

        let k1 = KernelId::new("kernel1");
        let k2 = KernelId::new("kernel2");
        let k3 = KernelId::new("kernel3");

        manager.register_kernel(&k1, KernelCodeSource::from_ptx(".version 7.0", "k1"));
        manager.register_kernel(&k2, KernelCodeSource::from_ptx(".version 7.0", "k2"));
        manager.register_kernel(&k3, KernelCodeSource::from_ptx(".version 7.0", "k3"));

        let kernels = manager.list_kernels();
        assert_eq!(kernels.len(), 3);
        assert!(kernels.contains(&k1));
        assert!(kernels.contains(&k2));
        assert!(kernels.contains(&k3));
    }
}