use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use crate::error::{Result, RingKernelError};
use crate::k2k::K2KMessage;
use crate::runtime::{Backend, KernelId, LaunchOptions};

#[derive(Debug, Clone)]
pub struct MultiGpuConfig {
    pub load_balancing: LoadBalancingStrategy,
    pub auto_select_device: bool,
    pub max_kernels_per_device: usize,
    pub enable_p2p: bool,
    pub preferred_devices: Vec<usize>,
}

impl Default for MultiGpuConfig {
    fn default() -> Self {
        Self {
            load_balancing: LoadBalancingStrategy::LeastLoaded,
            auto_select_device: true,
            max_kernels_per_device: 64,
            enable_p2p: true,
            preferred_devices: vec![],
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    FirstAvailable,
    LeastLoaded,
    RoundRobin,
    MemoryBased,
    ComputeCapability,
    Custom,
}

#[derive(Debug, Clone)]
pub struct DeviceInfo {
    pub index: usize,
    pub name: String,
    pub backend: Backend,
    pub total_memory: u64,
    pub available_memory: u64,
    pub compute_capability: Option<(u32, u32)>,
    pub max_threads_per_block: u32,
    pub multiprocessor_count: u32,
    pub p2p_capable: bool,
}

impl DeviceInfo {
    pub fn new(index: usize, name: String, backend: Backend) -> Self {
        Self {
            index,
            name,
            backend,
            total_memory: 0,
            available_memory: 0,
            compute_capability: None,
            max_threads_per_block: 1024,
            multiprocessor_count: 1,
            p2p_capable: false,
        }
    }

    pub fn memory_utilization(&self) -> f64 {
        if self.total_memory == 0 {
            0.0
        } else {
            1.0 - (self.available_memory as f64 / self.total_memory as f64)
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterconnectType {
    None,
    Pcie,
    NvLink,
    NvSwitch,
    InfinityFabric,
    XeLink,
    SameDevice,
}

impl InterconnectType {
    pub fn estimated_bandwidth_gbps(&self) -> f64 {
        match self {
            InterconnectType::None => 16.0,
            InterconnectType::Pcie => 32.0,
            InterconnectType::NvLink => 300.0,
            InterconnectType::NvSwitch => 600.0,
            InterconnectType::InfinityFabric => 200.0,
            InterconnectType::XeLink => 100.0,
            InterconnectType::SameDevice => 2000.0,
        }
    }

    pub fn estimated_latency_us(&self) -> f64 {
        match self {
            InterconnectType::None => 10.0,
            InterconnectType::Pcie => 5.0,
            InterconnectType::NvLink => 1.0,
            InterconnectType::NvSwitch => 2.0,
            InterconnectType::InfinityFabric => 1.5,
            InterconnectType::XeLink => 2.0,
            InterconnectType::SameDevice => 0.0,
        }
    }

    pub fn supports_p2p(&self) -> bool {
        !matches!(self, InterconnectType::None)
    }
}

#[derive(Debug, Clone)]
pub struct GpuConnection {
    pub source: usize,
    pub destination: usize,
    pub interconnect: InterconnectType,
    pub bandwidth_gbps: f64,
    pub latency_us: f64,
    pub bidirectional: bool,
    pub hops: u32,
}

impl GpuConnection {
    pub fn new(source: usize, destination: usize, interconnect: InterconnectType) -> Self {
        Self {
            source,
            destination,
            interconnect,
            bandwidth_gbps: interconnect.estimated_bandwidth_gbps(),
            latency_us: interconnect.estimated_latency_us(),
            bidirectional: true,
            hops: if source == destination { 0 } else { 1 },
        }
    }

    pub fn with_bandwidth(mut self, gbps: f64) -> Self {
        self.bandwidth_gbps = gbps;
        self
    }

    pub fn with_latency(mut self, us: f64) -> Self {
        self.latency_us = us;
        self
    }

    pub fn with_hops(mut self, hops: u32) -> Self {
        self.hops = hops;
        self
    }
}

#[derive(Debug, Clone)]
pub struct GpuTopology {
    pub device_count: usize,
    connections: Vec<Vec<Option<GpuConnection>>>,
    pub numa_nodes: Vec<Option<u32>>,
    pub probed: bool,
    pub last_updated: Instant,
}

impl GpuTopology {
    pub fn new(device_count: usize) -> Self {
        let mut connections = vec![vec![None; device_count]; device_count];

        for (i, row) in connections.iter_mut().enumerate().take(device_count) {
            row[i] = Some(GpuConnection::new(i, i, InterconnectType::SameDevice));
        }

        Self {
            device_count,
            connections,
            numa_nodes: vec![None; device_count],
            probed: false,
            last_updated: Instant::now(),
        }
    }

    pub fn set_connection(&mut self, connection: GpuConnection) {
        let src = connection.source;
        let dst = connection.destination;
        if src < self.device_count && dst < self.device_count {
            self.connections[src][dst] = Some(connection.clone());
            if connection.bidirectional && src != dst {
                let reverse = GpuConnection {
                    source: dst,
                    destination: src,
                    ..connection
                };
                self.connections[dst][src] = Some(reverse);
            }
        }
    }

    pub fn get_connection(&self, source: usize, destination: usize) -> Option<&GpuConnection> {
        self.connections
            .get(source)
            .and_then(|row| row.get(destination))
            .and_then(|c| c.as_ref())
    }

    pub fn best_path(&self, source: usize, destination: usize) -> Vec<usize> {
        if source == destination {
            return vec![source];
        }

        if let Some(conn) = self.get_connection(source, destination) {
            if conn.interconnect != InterconnectType::None {
                return vec![source, destination];
            }
        }

        let mut best_path = vec![source, destination];
        let mut best_bandwidth = 0.0;

        for intermediate in 0..self.device_count {
            if intermediate == source || intermediate == destination {
                continue;
            }

            if let (Some(c1), Some(c2)) = (
                self.get_connection(source, intermediate),
                self.get_connection(intermediate, destination),
            ) {
                let path_bandwidth = c1.bandwidth_gbps.min(c2.bandwidth_gbps);
                if path_bandwidth > best_bandwidth {
                    best_bandwidth = path_bandwidth;
                    best_path = vec![source, intermediate, destination];
                }
            }
        }

        best_path
    }

    pub fn neighbors(&self, device: usize) -> Vec<usize> {
        if device >= self.device_count {
            return vec![];
        }

        self.connections[device]
            .iter()
            .enumerate()
            .filter_map(|(i, conn)| {
                if i != device
                    && conn
                        .as_ref()
                        .map(|c| c.interconnect.supports_p2p())
                        .unwrap_or(false)
                {
                    Some(i)
                } else {
                    None
                }
            })
            .collect()
    }

    pub fn bisection_bandwidth_gbps(&self) -> f64 {
        let half = self.device_count / 2;
        if half == 0 {
            return 0.0;
        }

        let mut total = 0.0;
        for src in 0..half {
            for dst in half..self.device_count {
                if let Some(conn) = self.get_connection(src, dst) {
                    total += conn.bandwidth_gbps;
                }
            }
        }
        total
    }

    pub fn is_fully_connected(&self) -> bool {
        for src in 0..self.device_count {
            for dst in 0..self.device_count {
                if src != dst {
                    if let Some(conn) = self.get_connection(src, dst) {
                        if !conn.interconnect.supports_p2p() {
                            return false;
                        }
                    } else {
                        return false;
                    }
                }
            }
        }
        true
    }

    pub fn numa_neighbors(&self, device: usize) -> Vec<usize> {
        let target_numa = self.numa_nodes.get(device).copied().flatten();
        if target_numa.is_none() {
            return vec![];
        }

        self.numa_nodes
            .iter()
            .enumerate()
            .filter_map(|(i, numa)| {
                if i != device && *numa == target_numa {
                    Some(i)
                } else {
                    None
                }
            })
            .collect()
    }

    pub fn set_numa_node(&mut self, device: usize, numa_node: u32) {
        if device < self.numa_nodes.len() {
            self.numa_nodes[device] = Some(numa_node);
        }
    }

    pub fn mark_probed(&mut self) {
        self.probed = true;
        self.last_updated = Instant::now();
    }
}

#[derive(Debug, Clone)]
pub struct DeviceStatus {
    pub info: DeviceInfo,
    pub kernel_count: usize,
    pub kernels: Vec<KernelId>,
    pub available: bool,
    pub load: f64,
}

#[derive(Debug, Clone)]
pub struct DeviceUnregisterResult {
    pub device_index: usize,
    pub kernels_to_migrate: Vec<KernelMigrationPlan>,
    pub orphaned_kernels: Vec<KernelId>,
    pub success: bool,
}

#[derive(Debug, Clone)]
pub struct KernelMigrationPlan {
    pub kernel_id: KernelId,
    pub source_device: usize,
    pub target_device: usize,
    pub priority: MigrationPriority,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationPriority {
    Low,
    Normal,
    High,
    Critical,
}

pub struct MultiGpuCoordinator {
    config: MultiGpuConfig,
    devices: RwLock<Vec<DeviceInfo>>,
    kernel_device_map: RwLock<HashMap<KernelId, usize>>,
    device_kernel_counts: RwLock<Vec<AtomicUsize>>,
    round_robin_counter: AtomicUsize,
    total_kernels: AtomicU64,
    #[allow(clippy::type_complexity)]
    custom_selector:
        RwLock<Option<Arc<dyn Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync>>>,
    topology: RwLock<Option<GpuTopology>>,
}

impl MultiGpuCoordinator {
    pub fn new(config: MultiGpuConfig) -> Arc<Self> {
        Arc::new(Self {
            config,
            devices: RwLock::new(Vec::new()),
            kernel_device_map: RwLock::new(HashMap::new()),
            device_kernel_counts: RwLock::new(Vec::new()),
            round_robin_counter: AtomicUsize::new(0),
            total_kernels: AtomicU64::new(0),
            custom_selector: RwLock::new(None),
            topology: RwLock::new(None),
        })
    }

    pub fn register_device(&self, device: DeviceInfo) {
        let index = device.index;
        let mut devices = self.devices.write();
        let mut counts = self.device_kernel_counts.write();

        let mut current_len = devices.len();
        while current_len <= index {
            devices.push(DeviceInfo::new(
                current_len,
                "Unknown".to_string(),
                Backend::Cpu,
            ));
            counts.push(AtomicUsize::new(0));
            current_len += 1;
        }

        devices[index] = device;
    }

    pub fn unregister_device(&self, index: usize) -> DeviceUnregisterResult {
        let devices = self.devices.read();

        if index >= devices.len() {
            return DeviceUnregisterResult {
                device_index: index,
                kernels_to_migrate: Vec::new(),
                orphaned_kernels: Vec::new(),
                success: false,
            };
        }

        let kernels_on_device = self.kernels_on_device(index);

        let available_targets: Vec<usize> = devices
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != index)
            .map(|(i, _)| i)
            .collect();

        drop(devices);

        let mut kernels_to_migrate = Vec::new();
        let mut orphaned_kernels = Vec::new();

        if available_targets.is_empty() {
            orphaned_kernels = kernels_on_device;
        } else {
            for kernel_id in kernels_on_device {
                if let Some(target) = self.select_migration_target(&available_targets) {
                    let priority = self.calculate_migration_priority(&kernel_id);
                    kernels_to_migrate.push(KernelMigrationPlan {
                        kernel_id,
                        source_device: index,
                        target_device: target,
                        priority,
                    });
                } else {
                    orphaned_kernels.push(kernel_id);
                }
            }
        }

        {
            let mut kernel_map = self.kernel_device_map.write();
            let counts = self.device_kernel_counts.read();

            for plan in &kernels_to_migrate {
                kernel_map.insert(plan.kernel_id.clone(), plan.target_device);

                if index < counts.len() {
                    counts[index].fetch_sub(1, Ordering::Relaxed);
                }
                if plan.target_device < counts.len() {
                    counts[plan.target_device].fetch_add(1, Ordering::Relaxed);
                }
            }

            for kernel_id in &orphaned_kernels {
                kernel_map.remove(kernel_id);
                if index < counts.len() {
                    counts[index].fetch_sub(1, Ordering::Relaxed);
                }
            }
        }

        {
            let mut devices = self.devices.write();
            if index < devices.len() {
                devices[index].available_memory = 0;
                devices[index].name = format!("{} (unregistered)", devices[index].name);
            }
        }

        DeviceUnregisterResult {
            device_index: index,
            kernels_to_migrate,
            orphaned_kernels,
            success: true,
        }
    }

    fn select_migration_target(&self, candidates: &[usize]) -> Option<usize> {
        if candidates.is_empty() {
            return None;
        }

        let counts = self.device_kernel_counts.read();

        candidates
            .iter()
            .filter_map(|&idx| {
                if idx < counts.len() {
                    Some((idx, counts[idx].load(Ordering::Relaxed)))
                } else {
                    None
                }
            })
            .min_by_key(|(_, count)| *count)
            .map(|(idx, _)| idx)
    }

    fn calculate_migration_priority(&self, _kernel_id: &KernelId) -> MigrationPriority {
        MigrationPriority::Normal
    }

    pub fn devices(&self) -> Vec<DeviceInfo> {
        self.devices.read().clone()
    }

    pub fn device(&self, index: usize) -> Option<DeviceInfo> {
        self.devices.read().get(index).cloned()
    }

    pub fn device_count(&self) -> usize {
        self.devices.read().len()
    }

    pub fn select_device(&self, options: &LaunchOptions) -> Result<usize> {
        let devices = self.devices.read();
        if devices.is_empty() {
            return Err(RingKernelError::BackendUnavailable(
                "No GPU devices available".to_string(),
            ));
        }

        let status = self.get_all_status();

        if self.config.load_balancing == LoadBalancingStrategy::Custom {
            if let Some(selector) = &*self.custom_selector.read() {
                return Ok(selector(&status, options));
            }
        }

        let candidates: Vec<_> = if !self.config.preferred_devices.is_empty() {
            status
                .into_iter()
                .filter(|s| self.config.preferred_devices.contains(&s.info.index))
                .collect()
        } else {
            status
        };

        if candidates.is_empty() {
            return Err(RingKernelError::BackendUnavailable(
                "No suitable GPU device available".to_string(),
            ));
        }

        let selected = match self.config.load_balancing {
            LoadBalancingStrategy::FirstAvailable => {
                candidates.first().map(|s| s.info.index).unwrap_or(0)
            }
            LoadBalancingStrategy::LeastLoaded => candidates
                .iter()
                .filter(|s| s.available && s.kernel_count < self.config.max_kernels_per_device)
                .min_by(|a, b| a.kernel_count.cmp(&b.kernel_count))
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::RoundRobin => {
                let available: Vec<_> = candidates.iter().filter(|s| s.available).collect();

                if available.is_empty() {
                    candidates.first().map(|s| s.info.index).unwrap_or(0)
                } else {
                    let idx =
                        self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % available.len();
                    available[idx].info.index
                }
            }
            LoadBalancingStrategy::MemoryBased => candidates
                .iter()
                .filter(|s| s.available)
                .max_by(|a, b| a.info.available_memory.cmp(&b.info.available_memory))
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::ComputeCapability => candidates
                .iter()
                .filter(|s| s.available)
                .max_by(|a, b| {
                    let a_cap = a.info.compute_capability.unwrap_or((0, 0));
                    let b_cap = b.info.compute_capability.unwrap_or((0, 0));
                    a_cap.cmp(&b_cap)
                })
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::Custom => {
                // Custom strategy without an installed selector: fall back to device 0.
                0
            }
        };

        Ok(selected)
    }

    pub fn assign_kernel(&self, kernel_id: KernelId, device_index: usize) {
        self.kernel_device_map
            .write()
            .insert(kernel_id, device_index);

        let counts = self.device_kernel_counts.read();
        if device_index < counts.len() {
            counts[device_index].fetch_add(1, Ordering::Relaxed);
        }

        self.total_kernels.fetch_add(1, Ordering::Relaxed);
    }

    pub fn remove_kernel(&self, kernel_id: &KernelId) {
        if let Some(device_index) = self.kernel_device_map.write().remove(kernel_id) {
            let counts = self.device_kernel_counts.read();
            if device_index < counts.len() {
                counts[device_index].fetch_sub(1, Ordering::Relaxed);
            }
        }
    }

    pub fn get_kernel_device(&self, kernel_id: &KernelId) -> Option<usize> {
        self.kernel_device_map.read().get(kernel_id).copied()
    }

    pub fn kernels_on_device(&self, device_index: usize) -> Vec<KernelId> {
        self.kernel_device_map
            .read()
            .iter()
            .filter(|(_, &idx)| idx == device_index)
            .map(|(k, _)| k.clone())
            .collect()
    }

    pub fn get_all_status(&self) -> Vec<DeviceStatus> {
        let devices = self.devices.read();
        let kernel_map = self.kernel_device_map.read();
        let counts = self.device_kernel_counts.read();

        devices
            .iter()
            .enumerate()
            .map(|(idx, info)| {
                let kernel_count = if idx < counts.len() {
                    counts[idx].load(Ordering::Relaxed)
                } else {
                    0
                };

                let kernels: Vec<_> = kernel_map
                    .iter()
                    .filter(|(_, &dev_idx)| dev_idx == idx)
                    .map(|(k, _)| k.clone())
                    .collect();

                let load = kernel_count as f64 / self.config.max_kernels_per_device as f64;
                let available = kernel_count < self.config.max_kernels_per_device;

                DeviceStatus {
                    info: info.clone(),
                    kernel_count,
                    kernels,
                    available,
                    load,
                }
            })
            .collect()
    }

    pub fn get_device_status(&self, device_index: usize) -> Option<DeviceStatus> {
        self.get_all_status().into_iter().nth(device_index)
    }

    pub fn set_custom_selector<F>(&self, selector: F)
    where
        F: Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync + 'static,
    {
        *self.custom_selector.write() = Some(Arc::new(selector));
    }

    pub fn stats(&self) -> MultiGpuStats {
        let status = self.get_all_status();
        let total_kernels: usize = status.iter().map(|s| s.kernel_count).sum();
        let total_memory: u64 = status.iter().map(|s| s.info.total_memory).sum();
        let available_memory: u64 = status.iter().map(|s| s.info.available_memory).sum();

        MultiGpuStats {
            device_count: status.len(),
            total_kernels,
            total_memory,
            available_memory,
            kernels_launched: self.total_kernels.load(Ordering::Relaxed),
        }
    }

    pub fn can_p2p(&self, device_a: usize, device_b: usize) -> bool {
        if !self.config.enable_p2p {
            return false;
        }

        let devices = self.devices.read();
        if let (Some(a), Some(b)) = (devices.get(device_a), devices.get(device_b)) {
            a.p2p_capable && b.p2p_capable
        } else {
            false
        }
    }

    pub fn update_device_memory(&self, device_index: usize, available_memory: u64) {
        let mut devices = self.devices.write();
        if let Some(device) = devices.get_mut(device_index) {
            device.available_memory = available_memory;
        }
    }

    pub fn discover_topology(&self) -> GpuTopology {
        let devices = self.devices.read();
        let device_count = devices.len();

        if device_count == 0 {
            return GpuTopology::new(0);
        }

        let mut topo = GpuTopology::new(device_count);

        for (i, dev_i) in devices.iter().enumerate() {
            for (j, dev_j) in devices.iter().enumerate() {
                if i == j {
                    continue;
                }

                let interconnect = if dev_i.p2p_capable && dev_j.p2p_capable {
                    if dev_i.backend == dev_j.backend {
                        match dev_i.backend {
                            Backend::Cuda => {
                                let cc_i = dev_i.compute_capability.unwrap_or((0, 0));
                                let cc_j = dev_j.compute_capability.unwrap_or((0, 0));

                                if cc_i.0 >= 8 && cc_j.0 >= 8 {
                                    InterconnectType::NvLink
                                } else {
                                    InterconnectType::Pcie
                                }
                            }
                            _ => InterconnectType::Pcie,
                        }
                    } else {
                        InterconnectType::None
                    }
                } else {
                    InterconnectType::None
                };

                topo.set_connection(GpuConnection::new(i, j, interconnect));
            }
        }

        *self.topology.write() = Some(topo.clone());

        topo
    }

    pub fn topology(&self) -> GpuTopology {
        {
            let topo = self.topology.read();
            if let Some(ref t) = *topo {
                return t.clone();
            }
        }
        self.discover_topology()
    }

    pub fn set_topology(&self, topology: GpuTopology) {
        *self.topology.write() = Some(topology);
    }

    pub fn select_device_for_k2k(&self, source_kernel: &KernelId) -> Result<usize> {
        let source_device = self.get_kernel_device(source_kernel);
        if source_device.is_none() {
            return self.select_device(&LaunchOptions::default());
        }

        let source_idx = source_device.unwrap();
        let topo = self.topology();
        let status = self.get_all_status();

        let neighbors = topo.neighbors(source_idx);

        if neighbors.is_empty() {
            return self.select_device(&LaunchOptions::default());
        }

        let best = neighbors
            .iter()
            .filter_map(|&dev_idx| {
                status.iter().find(|s| s.info.index == dev_idx).map(|s| {
                    let conn = topo.get_connection(source_idx, dev_idx);
                    let bandwidth = conn.map(|c| c.bandwidth_gbps).unwrap_or(1.0);
                    let score = bandwidth / (s.load + 0.1);
                    (dev_idx, score)
                })
            })
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx);

        best.ok_or_else(|| {
            RingKernelError::BackendUnavailable("No suitable K2K device found".to_string())
        })
    }

    pub fn request_migration(
        &self,
        kernel_id: &KernelId,
        target_device: usize,
    ) -> Result<MigrationRequest> {
        let source_device = self
            .get_kernel_device(kernel_id)
            .ok_or_else(|| RingKernelError::KernelNotFound(kernel_id.as_str().to_string()))?;

        if source_device == target_device {
            return Err(RingKernelError::InvalidConfig(
                "Cannot migrate to same device".to_string(),
            ));
        }

        let devices = self.devices.read();
        if target_device >= devices.len() {
            return Err(RingKernelError::DeviceNotAvailable(format!(
                "Device {} not available",
                target_device
            )));
        }

        let topo = self.topology();
        let path = topo.best_path(source_device, target_device);
        let connection = topo.get_connection(source_device, target_device);

        Ok(MigrationRequest {
            kernel_id: kernel_id.clone(),
            source_device,
            target_device,
            path,
            estimated_bandwidth_gbps: connection.map(|c| c.bandwidth_gbps).unwrap_or(16.0),
            estimated_latency_us: connection.map(|c| c.latency_us).unwrap_or(10.0),
            state: MigrationState::Pending,
            started_at: None,
        })
    }

    pub fn complete_migration(&self, request: &MigrationRequest) -> Result<()> {
        {
            let mut map = self.kernel_device_map.write();
            if let Some(dev) = map.get_mut(&request.kernel_id) {
                *dev = request.target_device;
            }
        }

        {
            let counts = self.device_kernel_counts.read();
            if request.source_device < counts.len() {
                counts[request.source_device].fetch_sub(1, Ordering::Relaxed);
            }
            if request.target_device < counts.len() {
                counts[request.target_device].fetch_add(1, Ordering::Relaxed);
            }
        }

        Ok(())
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationState {
    Pending,
    Quiescing,
    Checkpointing,
    Transferring,
    Restoring,
    Completed,
    Failed,
    Cancelled,
}

#[derive(Debug, Clone)]
pub struct MigrationRequest {
    pub kernel_id: KernelId,
    pub source_device: usize,
    pub target_device: usize,
    pub path: Vec<usize>,
    pub estimated_bandwidth_gbps: f64,
    pub estimated_latency_us: f64,
    pub state: MigrationState,
    pub started_at: Option<Instant>,
}

impl MigrationRequest {
    pub fn estimate_transfer_time(&self, state_size_bytes: usize) -> Duration {
        let size_gb = state_size_bytes as f64 / 1_000_000_000.0;
        let transfer_time_s = size_gb / self.estimated_bandwidth_gbps;
        let total_us = (transfer_time_s * 1_000_000.0) + self.estimated_latency_us;
        Duration::from_micros(total_us as u64)
    }
}

pub struct CrossGpuK2KRouter {
    coordinator: Arc<MultiGpuCoordinator>,
    pending_queues: RwLock<HashMap<(usize, usize), Vec<PendingK2KMessage>>>,
    stats: CrossGpuRouterStats,
}

#[derive(Debug, Clone)]
pub struct PendingK2KMessage {
    pub source_kernel: KernelId,
    pub dest_kernel: KernelId,
    pub message: K2KMessage,
    pub queued_at: Instant,
    pub hops: u32,
}

#[derive(Debug, Default)]
pub struct CrossGpuRouterStats {
    messages_routed: AtomicU64,
    bytes_transferred: AtomicU64,
    messages_pending: AtomicUsize,
    total_latency_us: AtomicU64,
    routing_failures: AtomicU64,
}

impl CrossGpuK2KRouter {
    pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Arc<Self> {
        Arc::new(Self {
            coordinator,
            pending_queues: RwLock::new(HashMap::new()),
            stats: CrossGpuRouterStats::default(),
        })
    }

    pub fn route_message(
        &self,
        source_kernel: &KernelId,
        dest_kernel: &KernelId,
        message: K2KMessage,
    ) -> Result<RoutingDecision> {
        let source_device = self
            .coordinator
            .get_kernel_device(source_kernel)
            .ok_or_else(|| {
                RingKernelError::K2KDestinationNotFound(source_kernel.as_str().to_string())
            })?;

        let dest_device = self
            .coordinator
            .get_kernel_device(dest_kernel)
            .ok_or_else(|| {
                RingKernelError::K2KDestinationNotFound(dest_kernel.as_str().to_string())
            })?;

        if source_device == dest_device {
            return Ok(RoutingDecision::SameDevice);
        }

        let topo = self.coordinator.topology();
        let path = topo.best_path(source_device, dest_device);

        if let Some(conn) = topo.get_connection(source_device, dest_device) {
            if conn.interconnect.supports_p2p() {
                let pending = PendingK2KMessage {
                    source_kernel: source_kernel.clone(),
                    dest_kernel: dest_kernel.clone(),
                    message,
                    queued_at: Instant::now(),
                    hops: 1,
                };

                self.enqueue_pending(source_device, dest_device, pending);
                self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

                return Ok(RoutingDecision::DirectP2P {
                    source_device,
                    dest_device,
                    bandwidth_gbps: conn.bandwidth_gbps,
                });
            }
        }

        if path.len() > 2 {
            let pending = PendingK2KMessage {
                source_kernel: source_kernel.clone(),
                dest_kernel: dest_kernel.clone(),
                message,
                queued_at: Instant::now(),
                hops: (path.len() - 1) as u32,
            };

            self.enqueue_pending(source_device, path[1], pending);
            self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

            return Ok(RoutingDecision::MultiHop {
                path: path.clone(),
                total_hops: (path.len() - 1) as u32,
            });
        }

        let pending = PendingK2KMessage {
            source_kernel: source_kernel.clone(),
            dest_kernel: dest_kernel.clone(),
            message,
            queued_at: Instant::now(),
            hops: 2,
        };

        self.enqueue_pending(source_device, dest_device, pending);
        self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

        Ok(RoutingDecision::HostMediated {
            source_device,
            dest_device,
        })
    }

    pub fn drain_pending(&self, source: usize, dest: usize) -> Vec<PendingK2KMessage> {
        let mut queues = self.pending_queues.write();
        let messages = queues.remove(&(source, dest)).unwrap_or_default();
        self.stats
            .messages_pending
            .fetch_sub(messages.len(), Ordering::Relaxed);
        messages
    }

    pub fn record_delivery(&self, message: &PendingK2KMessage, payload_size: usize) {
        self.stats.messages_routed.fetch_add(1, Ordering::Relaxed);
        self.stats
            .bytes_transferred
            .fetch_add(payload_size as u64, Ordering::Relaxed);

        let latency = message.queued_at.elapsed().as_micros() as u64;
        self.stats
            .total_latency_us
            .fetch_add(latency, Ordering::Relaxed);
    }

    pub fn record_failure(&self) {
        self.stats.routing_failures.fetch_add(1, Ordering::Relaxed);
    }

    pub fn stats(&self) -> CrossGpuRouterStatsSnapshot {
        let messages_routed = self.stats.messages_routed.load(Ordering::Relaxed);
        let total_latency = self.stats.total_latency_us.load(Ordering::Relaxed);

        CrossGpuRouterStatsSnapshot {
            messages_routed,
            bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
            messages_pending: self.stats.messages_pending.load(Ordering::Relaxed),
            avg_latency_us: if messages_routed > 0 {
                total_latency as f64 / messages_routed as f64
            } else {
                0.0
            },
            routing_failures: self.stats.routing_failures.load(Ordering::Relaxed),
        }
    }

    fn enqueue_pending(&self, source: usize, dest: usize, message: PendingK2KMessage) {
        let mut queues = self.pending_queues.write();
        queues.entry((source, dest)).or_default().push(message);
    }
}

#[derive(Debug, Clone)]
pub struct CrossGpuRouterStatsSnapshot {
    pub messages_routed: u64,
    pub bytes_transferred: u64,
    pub messages_pending: usize,
    pub avg_latency_us: f64,
    pub routing_failures: u64,
}

#[derive(Debug, Clone)]
pub enum RoutingDecision {
    SameDevice,
    DirectP2P {
        source_device: usize,
        dest_device: usize,
        bandwidth_gbps: f64,
    },
    MultiHop {
        path: Vec<usize>,
        total_hops: u32,
    },
    HostMediated {
        source_device: usize,
        dest_device: usize,
    },
}

#[derive(Debug, Clone, Default)]
pub struct MultiGpuStats {
    pub device_count: usize,
    pub total_kernels: usize,
    pub total_memory: u64,
    pub available_memory: u64,
    pub kernels_launched: u64,
}

pub struct MultiGpuBuilder {
    config: MultiGpuConfig,
}

impl MultiGpuBuilder {
    pub fn new() -> Self {
        Self {
            config: MultiGpuConfig::default(),
        }
    }

    pub fn load_balancing(mut self, strategy: LoadBalancingStrategy) -> Self {
        self.config.load_balancing = strategy;
        self
    }

    pub fn auto_select_device(mut self, enable: bool) -> Self {
        self.config.auto_select_device = enable;
        self
    }

    pub fn max_kernels_per_device(mut self, max: usize) -> Self {
        self.config.max_kernels_per_device = max;
        self
    }

    pub fn enable_p2p(mut self, enable: bool) -> Self {
        self.config.enable_p2p = enable;
        self
    }

    pub fn preferred_devices(mut self, devices: Vec<usize>) -> Self {
        self.config.preferred_devices = devices;
        self
    }

    pub fn build(self) -> Arc<MultiGpuCoordinator> {
        MultiGpuCoordinator::new(self.config)
    }
}

impl Default for MultiGpuBuilder {
    fn default() -> Self {
        Self::new()
    }
}

pub struct CrossDeviceTransfer {
    pub source_device: usize,
    pub dest_device: usize,
    pub size: usize,
    pub use_p2p: bool,
}

impl CrossDeviceTransfer {
    pub fn new(source: usize, dest: usize, size: usize) -> Self {
        Self {
            source_device: source,
            dest_device: dest,
            size,
            use_p2p: true,
        }
    }

    pub fn without_p2p(mut self) -> Self {
        self.use_p2p = false;
        self
    }
}

use crate::checkpoint::{CheckpointStorage, CheckpointableKernel, MemoryStorage};

pub struct KernelMigrator {
    coordinator: Arc<MultiGpuCoordinator>,
    storage: Arc<dyn CheckpointStorage>,
    stats: MigrationStats,
}

#[derive(Debug, Default)]
pub struct MigrationStats {
    pub successful_migrations: AtomicU64,
    pub failed_migrations: AtomicU64,
    pub bytes_transferred: AtomicU64,
    pub checkpoint_time_us: AtomicU64,
    pub restore_time_us: AtomicU64,
}

#[derive(Debug, Clone)]
pub struct MigrationResult {
    pub kernel_id: KernelId,
    pub source_device: usize,
    pub target_device: usize,
    pub checkpoint_size: usize,
    pub checkpoint_duration: Duration,
    pub transfer_duration: Duration,
    pub restore_duration: Duration,
    pub total_duration: Duration,
}

impl KernelMigrator {
    pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Self {
        Self {
            coordinator,
            storage: Arc::new(MemoryStorage::new()),
            stats: MigrationStats::default(),
        }
    }

    pub fn with_storage(
        coordinator: Arc<MultiGpuCoordinator>,
        storage: Arc<dyn CheckpointStorage>,
    ) -> Self {
        Self {
            coordinator,
            storage,
            stats: MigrationStats::default(),
        }
    }

    pub fn migrate_with_checkpoint<K: CheckpointableKernel>(
        &self,
        kernel: &K,
        request: &mut MigrationRequest,
    ) -> Result<MigrationResult> {
        let start_time = Instant::now();
        request.started_at = Some(start_time);

        request.state = MigrationState::Quiescing;

        request.state = MigrationState::Checkpointing;
        let checkpoint_start = Instant::now();
        let checkpoint = kernel.create_checkpoint().map_err(|e| {
            self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
            request.state = MigrationState::Failed;
            RingKernelError::MigrationFailed(format!("Checkpoint creation failed: {}", e))
        })?;
        let checkpoint_duration = checkpoint_start.elapsed();
        let checkpoint_size = checkpoint.total_size();

        self.stats
            .checkpoint_time_us
            .fetch_add(checkpoint_duration.as_micros() as u64, Ordering::Relaxed);

        request.state = MigrationState::Transferring;
        let transfer_start = Instant::now();

        let checkpoint_name = format!(
            "migration_{}_{}_{}",
            request.kernel_id.as_str(),
            request.source_device,
            request.target_device
        );
        self.storage
            .save(&checkpoint, &checkpoint_name)
            .map_err(|e| {
                self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
                request.state = MigrationState::Failed;
                RingKernelError::MigrationFailed(format!("Checkpoint transfer failed: {}", e))
            })?;

        let transfer_duration = transfer_start.elapsed();
        self.stats
            .bytes_transferred
            .fetch_add(checkpoint_size as u64, Ordering::Relaxed);

        request.state = MigrationState::Restoring;
        let restore_start = Instant::now();

        let _restored = self.storage.load(&checkpoint_name).map_err(|e| {
            self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
            request.state = MigrationState::Failed;
            RingKernelError::MigrationFailed(format!("Checkpoint restore failed: {}", e))
        })?;

        let restore_duration = restore_start.elapsed();
        self.stats
            .restore_time_us
            .fetch_add(restore_duration.as_micros() as u64, Ordering::Relaxed);

        request.state = MigrationState::Completed;
        self.coordinator.complete_migration(request)?;

        let _ = self.storage.delete(&checkpoint_name);

        self.stats
            .successful_migrations
            .fetch_add(1, Ordering::Relaxed);

        Ok(MigrationResult {
            kernel_id: request.kernel_id.clone(),
            source_device: request.source_device,
            target_device: request.target_device,
            checkpoint_size,
            checkpoint_duration,
            transfer_duration,
            restore_duration,
            total_duration: start_time.elapsed(),
        })
    }

    pub fn coordinator(&self) -> &Arc<MultiGpuCoordinator> {
        &self.coordinator
    }

    pub fn stats(&self) -> MigrationStatsSnapshot {
        let successful = self.stats.successful_migrations.load(Ordering::Relaxed);
        let failed = self.stats.failed_migrations.load(Ordering::Relaxed);
        let total = successful + failed;
        let checkpoint_us = self.stats.checkpoint_time_us.load(Ordering::Relaxed);
        let restore_us = self.stats.restore_time_us.load(Ordering::Relaxed);

        MigrationStatsSnapshot {
            successful_migrations: successful,
            failed_migrations: failed,
            bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
            avg_checkpoint_time: if total > 0 {
                Duration::from_micros(checkpoint_us / total)
            } else {
                Duration::ZERO
            },
            avg_restore_time: if total > 0 {
                Duration::from_micros(restore_us / total)
            } else {
                Duration::ZERO
            },
        }
    }
}

#[derive(Debug, Clone)]
pub struct MigrationStatsSnapshot {
    pub successful_migrations: u64,
    pub failed_migrations: u64,
    pub bytes_transferred: u64,
    pub avg_checkpoint_time: Duration,
    pub avg_restore_time: Duration,
}

pub trait MigratableKernel: CheckpointableKernel {
    fn prepare_for_migration(&mut self) -> Result<()>;

    fn cancel_migration(&mut self) -> Result<()>;

    fn is_quiescent(&self) -> bool;

    fn estimated_state_size(&self) -> usize;
}

#[derive(Debug, Clone)]
pub struct HotReloadConfig {
    pub enabled: bool,
    pub reload_timeout: Duration,
    pub preserve_state: bool,
    pub max_retries: u32,
    pub retry_backoff: Duration,
    pub validate_before_swap: bool,
    pub keep_fallback: bool,
}

impl Default for HotReloadConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            reload_timeout: Duration::from_secs(30),
            preserve_state: true,
            max_retries: 3,
            retry_backoff: Duration::from_millis(500),
            validate_before_swap: true,
            keep_fallback: true,
        }
    }
}

impl HotReloadConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.reload_timeout = timeout;
        self
    }

    pub fn with_preserve_state(mut self, preserve: bool) -> Self {
        self.preserve_state = preserve;
        self
    }

    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HotReloadState {
    Idle,
    Draining,
    Checkpointing,
    Compiling,
    Validating,
    Swapping,
    Restoring,
    Completed,
    Failed,
    RollingBack,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelCodeFormat {
    Ptx,
    Cubin,
    SpirV,
    Wgsl,
    Msl,
    MetalLib,
    Source,
}

#[derive(Debug, Clone)]
pub struct KernelCodeSource {
    pub version_id: u64,
    pub format: KernelCodeFormat,
    pub code: Vec<u8>,
    pub entry_point: String,
    pub metadata: HashMap<String, String>,
    pub created_at: Instant,
    pub hash: [u8; 32],
}

impl KernelCodeSource {
    pub fn new(format: KernelCodeFormat, code: Vec<u8>, entry_point: impl Into<String>) -> Self {
        let hash = Self::compute_hash(&code);
        Self {
            version_id: 0,
            format,
            code,
            entry_point: entry_point.into(),
            metadata: HashMap::new(),
            created_at: Instant::now(),
            hash,
        }
    }

    pub fn from_ptx(ptx: &str, entry_point: impl Into<String>) -> Self {
        Self::new(KernelCodeFormat::Ptx, ptx.as_bytes().to_vec(), entry_point)
    }

    pub fn from_wgsl(wgsl: &str, entry_point: impl Into<String>) -> Self {
        Self::new(
            KernelCodeFormat::Wgsl,
            wgsl.as_bytes().to_vec(),
            entry_point,
        )
    }

    pub fn from_msl(msl: &str, entry_point: impl Into<String>) -> Self {
        Self::new(KernelCodeFormat::Msl, msl.as_bytes().to_vec(), entry_point)
    }

    pub fn with_version(mut self, version: u64) -> Self {
        self.version_id = version;
        self
    }

    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    // Non-cryptographic content hash: widens a 64-bit `DefaultHasher` digest to 256
    // bits by re-feeding the first digest into the hasher. Suitable for change
    // detection only, not for integrity or security guarantees.
    fn compute_hash(data: &[u8]) -> [u8; 32] {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        data.hash(&mut hasher);
        let h1 = hasher.finish();
        h1.hash(&mut hasher);
        let h2 = hasher.finish();
        h1.hash(&mut hasher);
        let h3 = hasher.finish();
        h1.hash(&mut hasher);
        let h4 = hasher.finish();

        let mut hash = [0u8; 32];
        hash[0..8].copy_from_slice(&h1.to_le_bytes());
        hash[8..16].copy_from_slice(&h2.to_le_bytes());
        hash[16..24].copy_from_slice(&h3.to_le_bytes());
        hash[24..32].copy_from_slice(&h4.to_le_bytes());
        hash
    }

    pub fn as_str(&self) -> Option<&str> {
        match self.format {
            KernelCodeFormat::Ptx
            | KernelCodeFormat::Wgsl
            | KernelCodeFormat::Msl
            | KernelCodeFormat::Source => std::str::from_utf8(&self.code).ok(),
            _ => None,
        }
    }

    pub fn size(&self) -> usize {
        self.code.len()
    }
}

#[derive(Debug)]
pub struct HotReloadRequest {
    pub kernel_id: KernelId,
    pub new_code: KernelCodeSource,
    pub state: HotReloadState,
    pub created_at: Instant,
    pub started_at: Option<Instant>,
    pub retry_count: u32,
    pub error: Option<String>,
    checkpoint_data: Option<Vec<u8>>,
}

impl HotReloadRequest {
    pub fn new(kernel_id: KernelId, new_code: KernelCodeSource) -> Self {
        Self {
            kernel_id,
            new_code,
            state: HotReloadState::Idle,
            created_at: Instant::now(),
            started_at: None,
            retry_count: 0,
            error: None,
            checkpoint_data: None,
        }
    }

    pub fn is_in_progress(&self) -> bool {
        !matches!(
            self.state,
            HotReloadState::Idle | HotReloadState::Completed | HotReloadState::Failed
        )
    }

    pub fn is_completed(&self) -> bool {
        self.state == HotReloadState::Completed
    }

    pub fn is_failed(&self) -> bool {
        self.state == HotReloadState::Failed
    }

    pub fn elapsed(&self) -> Duration {
        self.created_at.elapsed()
    }

    pub fn reload_elapsed(&self) -> Option<Duration> {
        self.started_at.map(|s| s.elapsed())
    }
}

#[derive(Debug, Clone)]
pub struct HotReloadResult {
    pub kernel_id: KernelId,
    pub old_version: u64,
    pub new_version: u64,
    pub state_preserved: bool,
    pub checkpoint_size: usize,
    pub drain_duration: Duration,
    pub checkpoint_duration: Duration,
    pub compile_duration: Duration,
    pub swap_duration: Duration,
    pub restore_duration: Duration,
    pub total_duration: Duration,
}

#[derive(Debug, Default)]
struct HotReloadStats {
    successful_reloads: AtomicU64,
    failed_reloads: AtomicU64,
    rollbacks: AtomicU64,
    total_drain_time_us: AtomicU64,
    total_compile_time_us: AtomicU64,
    total_swap_time_us: AtomicU64,
    state_preserved_count: AtomicU64,
}

#[derive(Debug, Clone)]
pub struct HotReloadStatsSnapshot {
    pub successful_reloads: u64,
    pub failed_reloads: u64,
    pub rollbacks: u64,
    pub avg_drain_time: Duration,
    pub avg_compile_time: Duration,
    pub avg_swap_time: Duration,
    pub state_preserved_count: u64,
}

pub struct HotReloadManager {
    config: HotReloadConfig,
    kernels: RwLock<HashMap<KernelId, KernelCodeSource>>,
    fallbacks: RwLock<HashMap<KernelId, KernelCodeSource>>,
    active_requests: RwLock<HashMap<KernelId, HotReloadRequest>>,
    version_counter: AtomicU64,
    stats: HotReloadStats,
}

impl HotReloadManager {
    pub fn new(config: HotReloadConfig) -> Arc<Self> {
        Arc::new(Self {
            config,
            kernels: RwLock::new(HashMap::new()),
            fallbacks: RwLock::new(HashMap::new()),
            active_requests: RwLock::new(HashMap::new()),
            version_counter: AtomicU64::new(1),
            stats: HotReloadStats::default(),
        })
    }

    pub fn with_defaults() -> Arc<Self> {
        Self::new(HotReloadConfig::default())
    }

    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    pub fn register_kernel(&self, kernel_id: &KernelId, code: KernelCodeSource) {
        let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
        let code = code.with_version(version);
        self.kernels.write().insert(kernel_id.clone(), code);
    }

    pub fn unregister_kernel(&self, kernel_id: &KernelId) {
        self.kernels.write().remove(kernel_id);
        self.fallbacks.write().remove(kernel_id);
        self.active_requests.write().remove(kernel_id);
    }

    pub fn get_current_version(&self, kernel_id: &KernelId) -> Option<u64> {
        self.kernels.read().get(kernel_id).map(|c| c.version_id)
    }

    pub fn get_current_code(&self, kernel_id: &KernelId) -> Option<KernelCodeSource> {
        self.kernels.read().get(kernel_id).cloned()
    }

    pub fn request_reload(
        &self,
        kernel_id: &KernelId,
        new_code: KernelCodeSource,
    ) -> Result<HotReloadRequest> {
        if !self.config.enabled {
            return Err(RingKernelError::ValidationError(
                "Hot reload is disabled".to_string(),
            ));
        }

        if !self.kernels.read().contains_key(kernel_id) {
            return Err(RingKernelError::KernelNotFound(
                kernel_id.as_str().to_string(),
            ));
        }

        {
            let active = self.active_requests.read();
            if let Some(existing) = active.get(kernel_id) {
                if existing.is_in_progress() {
                    return Err(RingKernelError::ValidationError(
                        "Hot reload already in progress for this kernel".to_string(),
                    ));
                }
            }
        }

        let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
        let new_code = new_code.with_version(version);

        let request = HotReloadRequest::new(kernel_id.clone(), new_code);
        self.active_requests.write().insert(
            kernel_id.clone(),
            HotReloadRequest::new(kernel_id.clone(), request.new_code.clone()),
        );

        Ok(request)
    }

    pub fn execute_reload<K: CheckpointableKernel>(
        &self,
        request: &mut HotReloadRequest,
        kernel: &K,
    ) -> Result<HotReloadResult> {
        let start_time = Instant::now();
        request.started_at = Some(start_time);

        let old_version = self
            .kernels
            .read()
            .get(&request.kernel_id)
            .map(|c| c.version_id)
            .unwrap_or(0);

        request.state = HotReloadState::Draining;
        let drain_start = Instant::now();
        // The drain window is simulated with a short sleep in this implementation.
        std::thread::sleep(Duration::from_micros(10));
        let drain_duration = drain_start.elapsed();
        self.stats
            .total_drain_time_us
            .fetch_add(drain_duration.as_micros() as u64, Ordering::Relaxed);

        request.state = HotReloadState::Checkpointing;
        let checkpoint_start = Instant::now();
        let checkpoint_size = if self.config.preserve_state {
            let checkpoint = kernel.create_checkpoint()?;
            let data = checkpoint.to_bytes();
            request.checkpoint_data = Some(data.clone());
            data.len()
        } else {
            0
        };
        let checkpoint_duration = checkpoint_start.elapsed();

        request.state = HotReloadState::Validating;
        if self.config.validate_before_swap {
            self.validate_code(&request.new_code)?;
        }

        request.state = HotReloadState::Compiling;
        let compile_start = Instant::now();
        // Compilation is likewise simulated; only the timing is recorded.
        std::thread::sleep(Duration::from_micros(10));
        let compile_duration = compile_start.elapsed();
        self.stats
            .total_compile_time_us
            .fetch_add(compile_duration.as_micros() as u64, Ordering::Relaxed);

        request.state = HotReloadState::Swapping;
        let swap_start = Instant::now();

        if self.config.keep_fallback {
            if let Some(old_code) = self.kernels.read().get(&request.kernel_id).cloned() {
                self.fallbacks
                    .write()
                    .insert(request.kernel_id.clone(), old_code);
            }
        }

        self.kernels
            .write()
            .insert(request.kernel_id.clone(), request.new_code.clone());
        let swap_duration = swap_start.elapsed();
        self.stats
            .total_swap_time_us
            .fetch_add(swap_duration.as_micros() as u64, Ordering::Relaxed);

        request.state = HotReloadState::Restoring;
        let restore_start = Instant::now();
        let restore_duration = restore_start.elapsed();

        request.state = HotReloadState::Completed;
        self.stats
            .successful_reloads
            .fetch_add(1, Ordering::Relaxed);
        if self.config.preserve_state && checkpoint_size > 0 {
            self.stats
                .state_preserved_count
                .fetch_add(1, Ordering::Relaxed);
        }

        self.active_requests.write().remove(&request.kernel_id);

        Ok(HotReloadResult {
            kernel_id: request.kernel_id.clone(),
            old_version,
            new_version: request.new_code.version_id,
            state_preserved: self.config.preserve_state && checkpoint_size > 0,
            checkpoint_size,
            drain_duration,
            checkpoint_duration,
            compile_duration,
            swap_duration,
            restore_duration,
            total_duration: start_time.elapsed(),
        })
    }

    pub fn rollback(&self, kernel_id: &KernelId) -> Result<()> {
        let fallback =
            self.fallbacks.write().remove(kernel_id).ok_or_else(|| {
                RingKernelError::ValidationError("No fallback available".to_string())
            })?;

        self.kernels.write().insert(kernel_id.clone(), fallback);
        self.stats.rollbacks.fetch_add(1, Ordering::Relaxed);

        if let Some(request) = self.active_requests.write().get_mut(kernel_id) {
            request.state = HotReloadState::RollingBack;
        }

        Ok(())
    }

    fn validate_code(&self, code: &KernelCodeSource) -> Result<()> {
        if code.code.is_empty() {
            return Err(RingKernelError::ValidationError(
                "Kernel code is empty".to_string(),
            ));
        }

        if code.entry_point.is_empty() {
            return Err(RingKernelError::ValidationError(
                "Entry point is empty".to_string(),
            ));
        }

        match code.format {
            KernelCodeFormat::Ptx => {
                if let Some(text) = code.as_str() {
                    if !text.contains(".version") && !text.contains(".target") {
                        return Err(RingKernelError::ValidationError(
                            "PTX code missing version/target directive".to_string(),
                        ));
                    }
                }
            }
            KernelCodeFormat::Wgsl => {
                if let Some(text) = code.as_str() {
                    if !text.contains("@compute") && !text.contains("fn ") {
                        return Err(RingKernelError::ValidationError(
                            "WGSL code missing compute shader or function".to_string(),
                        ));
                    }
                }
            }
            KernelCodeFormat::Msl => {
                if let Some(text) = code.as_str() {
                    if !text.contains("kernel ") {
                        return Err(RingKernelError::ValidationError(
                            "MSL code missing kernel function".to_string(),
                        ));
                    }
                }
            }
            _ => {}
        }

        Ok(())
    }

    pub fn stats(&self) -> HotReloadStatsSnapshot {
        let successful = self.stats.successful_reloads.load(Ordering::Relaxed);
        let failed = self.stats.failed_reloads.load(Ordering::Relaxed);
        let total = successful.max(1);

        HotReloadStatsSnapshot {
            successful_reloads: successful,
            failed_reloads: failed,
            rollbacks: self.stats.rollbacks.load(Ordering::Relaxed),
            avg_drain_time: Duration::from_micros(
                self.stats.total_drain_time_us.load(Ordering::Relaxed) / total,
            ),
            avg_compile_time: Duration::from_micros(
                self.stats.total_compile_time_us.load(Ordering::Relaxed) / total,
            ),
            avg_swap_time: Duration::from_micros(
                self.stats.total_swap_time_us.load(Ordering::Relaxed) / total,
            ),
            state_preserved_count: self.stats.state_preserved_count.load(Ordering::Relaxed),
        }
    }

    pub fn list_kernels(&self) -> Vec<KernelId> {
        self.kernels.read().keys().cloned().collect()
    }

    pub fn is_registered(&self, kernel_id: &KernelId) -> bool {
        self.kernels.read().contains_key(kernel_id)
    }

    pub fn is_reload_in_progress(&self, kernel_id: &KernelId) -> bool {
        self.active_requests
            .read()
            .get(kernel_id)
            .map(|r| r.is_in_progress())
            .unwrap_or(false)
    }

    pub fn config(&self) -> &HotReloadConfig {
        &self.config
    }
}

pub trait HotReloadableKernel: CheckpointableKernel {
    fn prepare_for_reload(&mut self) -> Result<()>;

    fn apply_code(&mut self, code: &KernelCodeSource) -> Result<()>;

    fn resume_after_reload(&mut self) -> Result<()>;

    fn is_ready_for_reload(&self) -> bool;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_info() {
        let info = DeviceInfo::new(0, "Test GPU".to_string(), Backend::Cuda);
        assert_eq!(info.index, 0);
        assert_eq!(info.name, "Test GPU");
        assert_eq!(info.memory_utilization(), 0.0);
    }

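    // Added sketch: exercises `memory_utilization` with non-zero totals, assuming the
    // field semantics shown in `DeviceInfo::new` (total vs. available memory in bytes).
    #[test]
    fn test_memory_utilization_with_usage() {
        let mut info = DeviceInfo::new(0, "Test GPU".to_string(), Backend::Cuda);
        info.total_memory = 8_000_000_000;
        info.available_memory = 6_000_000_000;
        // 2 GB of 8 GB in use -> 25% utilization.
        assert!((info.memory_utilization() - 0.25).abs() < 1e-9);
    }
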
    #[test]
    fn test_coordinator_registration() {
        let coord = MultiGpuBuilder::new().build();

        let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        coord.register_device(device);

        assert_eq!(coord.device_count(), 1);
        assert!(coord.device(0).is_some());
    }

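    // Added sketch: checks that `unregister_device` plans a migration onto the
    // remaining device instead of orphaning the kernel, under the default config.
    #[test]
    fn test_unregister_device_migrates_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("to_migrate");
        coord.assign_kernel(kernel_id.clone(), 0);

        let result = coord.unregister_device(0);
        assert!(result.success);
        assert_eq!(result.kernels_to_migrate.len(), 1);
        assert_eq!(result.kernels_to_migrate[0].target_device, 1);
        assert!(result.orphaned_kernels.is_empty());
        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));
    }
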
    #[test]
    fn test_kernel_assignment() {
        let coord = MultiGpuBuilder::new().build();

        let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        coord.register_device(device);

        let kernel_id = KernelId::new("test_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));
        assert_eq!(coord.kernels_on_device(0).len(), 1);
    }

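    // Added sketch: removing a kernel should clear its device mapping and the
    // per-device kernel listing.
    #[test]
    fn test_remove_kernel() {
        let coord = MultiGpuBuilder::new().build();
        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("ephemeral_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);
        assert_eq!(coord.kernels_on_device(0).len(), 1);

        coord.remove_kernel(&kernel_id);
        assert_eq!(coord.get_kernel_device(&kernel_id), None);
        assert!(coord.kernels_on_device(0).is_empty());
    }
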
    #[test]
    fn test_load_balancing_least_loaded() {
        let coord = MultiGpuBuilder::new()
            .load_balancing(LoadBalancingStrategy::LeastLoaded)
            .build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        coord.assign_kernel(KernelId::new("k1"), 0);

        let selected = coord.select_device(&LaunchOptions::default()).unwrap();
        assert_eq!(selected, 1);
    }

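    // Added sketch: with the MemoryBased strategy the device reporting the most
    // available memory should win; assumes `update_device_memory` is the intended
    // way to refresh that figure.
    #[test]
    fn test_memory_based_selection() {
        let coord = MultiGpuBuilder::new()
            .load_balancing(LoadBalancingStrategy::MemoryBased)
            .build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        coord.update_device_memory(0, 2_000_000_000);
        coord.update_device_memory(1, 6_000_000_000);

        let selected = coord.select_device(&LaunchOptions::default()).unwrap();
        assert_eq!(selected, 1);
    }
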
    #[test]
    fn test_round_robin() {
        let coord = MultiGpuBuilder::new()
            .load_balancing(LoadBalancingStrategy::RoundRobin)
            .build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let d1 = coord.select_device(&LaunchOptions::default()).unwrap();
        let d2 = coord.select_device(&LaunchOptions::default()).unwrap();
        let d3 = coord.select_device(&LaunchOptions::default()).unwrap();

        assert_ne!(d1, d2);
        assert_eq!(d1, d3);
    }

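    // Added sketch: a custom selector combined with the Custom strategy should be
    // consulted directly, bypassing the built-in heuristics; the closure here always
    // picks the last registered device.
    #[test]
    fn test_custom_selector() {
        let coord = MultiGpuBuilder::new()
            .load_balancing(LoadBalancingStrategy::Custom)
            .build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        coord.set_custom_selector(|status: &[DeviceStatus], _options: &LaunchOptions| {
            status.len() - 1
        });

        let selected = coord.select_device(&LaunchOptions::default()).unwrap();
        assert_eq!(selected, 1);
    }
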
    #[test]
    fn test_interconnect_bandwidth() {
        assert!(
            InterconnectType::NvLink.estimated_bandwidth_gbps()
                > InterconnectType::Pcie.estimated_bandwidth_gbps()
        );
        assert!(
            InterconnectType::Pcie.estimated_bandwidth_gbps()
                > InterconnectType::None.estimated_bandwidth_gbps()
        );
        assert!(
            InterconnectType::SameDevice.estimated_bandwidth_gbps()
                > InterconnectType::NvLink.estimated_bandwidth_gbps()
        );
    }

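    // Added sketch: latency estimates should be ordered inversely to link quality,
    // mirroring the bandwidth test above.
    #[test]
    fn test_interconnect_latency() {
        assert_eq!(InterconnectType::SameDevice.estimated_latency_us(), 0.0);
        assert!(
            InterconnectType::NvLink.estimated_latency_us()
                < InterconnectType::Pcie.estimated_latency_us()
        );
        assert!(
            InterconnectType::Pcie.estimated_latency_us()
                < InterconnectType::None.estimated_latency_us()
        );
    }
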
    #[test]
    fn test_interconnect_p2p_support() {
        assert!(!InterconnectType::None.supports_p2p());
        assert!(InterconnectType::Pcie.supports_p2p());
        assert!(InterconnectType::NvLink.supports_p2p());
        assert!(InterconnectType::NvSwitch.supports_p2p());
    }

    #[test]
    fn test_gpu_topology_creation() {
        let topo = GpuTopology::new(4);
        assert_eq!(topo.device_count, 4);

        for i in 0..4 {
            let conn = topo.get_connection(i, i);
            assert!(conn.is_some());
            assert_eq!(conn.unwrap().interconnect, InterconnectType::SameDevice);
        }
    }

    #[test]
    fn test_gpu_topology_set_connection() {
        let mut topo = GpuTopology::new(4);

        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));

        let conn_01 = topo.get_connection(0, 1);
        assert!(conn_01.is_some());
        assert_eq!(conn_01.unwrap().interconnect, InterconnectType::NvLink);

        let conn_10 = topo.get_connection(1, 0);
        assert!(conn_10.is_some());
        assert_eq!(conn_10.unwrap().interconnect, InterconnectType::NvLink);
    }

    #[test]
    fn test_gpu_topology_neighbors() {
        let mut topo = GpuTopology::new(4);

        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(3, 0, InterconnectType::NvLink));

        let neighbors_0 = topo.neighbors(0);
        assert_eq!(neighbors_0.len(), 2);
        assert!(neighbors_0.contains(&1));
        assert!(neighbors_0.contains(&3));
    }

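    // Added sketch: with two NvLink links crossing the 0..2 / 2..4 cut, the bisection
    // bandwidth should be the sum of their per-link estimates.
    #[test]
    fn test_gpu_topology_bisection_bandwidth() {
        let mut topo = GpuTopology::new(4);

        topo.set_connection(GpuConnection::new(0, 2, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(1, 3, InterconnectType::NvLink));

        let expected = 2.0 * InterconnectType::NvLink.estimated_bandwidth_gbps();
        assert!((topo.bisection_bandwidth_gbps() - expected).abs() < 1e-9);
    }
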
2605 #[test]
2606 fn test_gpu_topology_best_path() {
2607 let mut topo = GpuTopology::new(4);
2608
2609 topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
2611 topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
2612 topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
2613 topo.set_connection(GpuConnection::new(0, 3, InterconnectType::None)); let path_01 = topo.best_path(0, 1);
2617 assert_eq!(path_01, vec![0, 1]);
2618
2619 let path_00 = topo.best_path(0, 0);
2621 assert_eq!(path_00, vec![0]);
2622 }
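
    // `best_path` returns the full sequence of device indices, endpoints included,
    // and a self-path collapses to a single element. Presumably the search weights
    // each hop by the connection's bandwidth/latency estimates, so a slow direct
    // link can lose to a multi-hop NvLink route; only the trivial cases are
    // asserted here.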

    #[test]
    fn test_gpu_topology_fully_connected() {
        let mut topo = GpuTopology::new(3);

        assert!(!topo.is_fully_connected());

        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(0, 2, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));

        assert!(topo.is_fully_connected());
    }

    #[test]
    fn test_gpu_topology_numa() {
        let mut topo = GpuTopology::new(4);

        // Devices 0 and 1 share NUMA node 0; devices 2 and 3 share NUMA node 1.
        topo.set_numa_node(0, 0);
        topo.set_numa_node(1, 0);
        topo.set_numa_node(2, 1);
        topo.set_numa_node(3, 1);

        let numa_neighbors_0 = topo.numa_neighbors(0);
        assert_eq!(numa_neighbors_0, vec![1]);

        let numa_neighbors_2 = topo.numa_neighbors(2);
        assert_eq!(numa_neighbors_2, vec![3]);
    }

    #[test]
    fn test_coordinator_topology_discovery() {
        let coord = MultiGpuBuilder::new().enable_p2p(true).build();

        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        dev0.p2p_capable = true;
        dev0.compute_capability = Some((8, 0));

        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
        dev1.p2p_capable = true;
        dev1.compute_capability = Some((8, 6));

        coord.register_device(dev0);
        coord.register_device(dev1);

        let topo = coord.discover_topology();

        assert_eq!(topo.device_count, 2);

        let conn = topo.get_connection(0, 1);
        assert!(conn.is_some());
        assert_eq!(conn.unwrap().interconnect, InterconnectType::NvLink);
    }
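
    // Discovery appears to infer an NvLink connection when both CUDA devices are
    // P2P-capable and report a compute capability of 8.x or newer; the exact
    // heuristic lives in `discover_topology` and is not pinned down by this test.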

    #[test]
    fn test_migration_request() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("migrating_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let request = coord.request_migration(&kernel_id, 1).unwrap();

        assert_eq!(request.source_device, 0);
        assert_eq!(request.target_device, 1);
        assert_eq!(request.state, MigrationState::Pending);
    }

    #[test]
    fn test_migration_same_device_error() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        // Migrating a kernel onto the device it already occupies is rejected.
        let result = coord.request_migration(&kernel_id, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_migration_complete() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("migrating_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));

        let request = coord.request_migration(&kernel_id, 1).unwrap();
        coord.complete_migration(&request).unwrap();

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));
    }

    #[test]
    fn test_migration_transfer_time_estimate() {
        let request = MigrationRequest {
            kernel_id: KernelId::new("test"),
            source_device: 0,
            target_device: 1,
            path: vec![0, 1],
            estimated_bandwidth_gbps: 300.0,
            estimated_latency_us: 1.0,
            state: MigrationState::Pending,
            started_at: None,
        };

        let time = request.estimate_transfer_time(1_000_000_000);
        assert!(time.as_micros() > 3000);
        assert!(time.as_micros() < 4000);
    }
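
    // Worked example for the bounds above: moving 1_000_000_000 bytes at an
    // effective 300 GB/s takes roughly 1e9 / 300e9 s ≈ 3_333 µs, plus ~1 µs of
    // link latency, which lands inside the (3_000, 4_000) µs window. If the
    // "gbps" figure were instead treated as gigabits per second, the estimate
    // would be closer to 26_700 µs, so the estimator evidently works in
    // gigabytes per second here.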

    // --- Cross-GPU K2K routing ---

    use crate::hlc::HlcTimestamp;
    use crate::message::MessageEnvelope;

    fn make_test_k2k_message(source: &KernelId, dest: &KernelId) -> K2KMessage {
        let timestamp = HlcTimestamp::now(42);
        let envelope = MessageEnvelope::empty(1, 2, timestamp);
        K2KMessage::new(source.clone(), dest.clone(), envelope, timestamp)
    }

    #[test]
    fn test_router_same_device() {
        let coord = MultiGpuBuilder::new().build();
        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let router = CrossGpuK2KRouter::new(coord);

        let msg = make_test_k2k_message(&k1, &k2);
        let decision = router.route_message(&k1, &k2, msg).unwrap();

        assert!(matches!(decision, RoutingDecision::SameDevice));
    }

    #[test]
    fn test_router_cross_device() {
        let coord = MultiGpuBuilder::new().enable_p2p(true).build();

        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        dev0.p2p_capable = true;
        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
        dev1.p2p_capable = true;

        coord.register_device(dev0);
        coord.register_device(dev1);

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 1);

        let router = CrossGpuK2KRouter::new(coord);

        let msg = make_test_k2k_message(&k1, &k2);
        let decision = router.route_message(&k1, &k2, msg).unwrap();

        match decision {
            RoutingDecision::DirectP2P {
                source_device,
                dest_device,
                ..
            } => {
                assert_eq!(source_device, 0);
                assert_eq!(dest_device, 1);
            }
            _ => panic!("Expected DirectP2P routing"),
        }
    }

    #[test]
    fn test_router_pending_messages() {
        let coord = MultiGpuBuilder::new().enable_p2p(true).build();

        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        dev0.p2p_capable = true;
        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
        dev1.p2p_capable = true;

        coord.register_device(dev0);
        coord.register_device(dev1);

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 1);

        let router = CrossGpuK2KRouter::new(coord);

        // Queue three cross-device messages without draining them.
        for _ in 0..3 {
            let msg = make_test_k2k_message(&k1, &k2);
            router.route_message(&k1, &k2, msg).unwrap();
        }

        assert_eq!(router.stats().messages_pending, 3);

        let pending = router.drain_pending(0, 1);
        assert_eq!(pending.len(), 3);
        assert_eq!(router.stats().messages_pending, 0);
    }
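
    // `drain_pending(0, 1)` is assumed to hand back the messages queued for the
    // device-0 → device-1 link and reset the pending counter; the host side would
    // then perform the actual P2P (or staged) copy for the drained batch.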

    #[test]
    fn test_router_stats() {
        let coord = MultiGpuBuilder::new().build();
        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let router = CrossGpuK2KRouter::new(coord);

        let stats = router.stats();
        assert_eq!(stats.messages_routed, 0);
        assert_eq!(stats.bytes_transferred, 0);
        assert_eq!(stats.routing_failures, 0);
    }

    // --- Checkpoint-based kernel migration ---

    use crate::checkpoint::{Checkpoint, CheckpointBuilder};

    struct MockCheckpointableKernel {
        kernel_id: String,
        kernel_type: String,
        state_data: Vec<u8>,
        step: u64,
    }

    impl MockCheckpointableKernel {
        fn new(kernel_id: &str, state_size: usize) -> Self {
            Self {
                kernel_id: kernel_id.to_string(),
                kernel_type: "mock_kernel".to_string(),
                state_data: vec![0xAB; state_size],
                step: 1000,
            }
        }
    }

    impl CheckpointableKernel for MockCheckpointableKernel {
        fn create_checkpoint(&self) -> Result<Checkpoint> {
            let checkpoint = CheckpointBuilder::new(&self.kernel_id, &self.kernel_type)
                .step(self.step)
                .grid_size(64, 64, 64)
                .control_block(vec![1, 2, 3, 4])
                .device_memory("state", self.state_data.clone())
                .build();
            Ok(checkpoint)
        }

        fn restore_from_checkpoint(&mut self, checkpoint: &Checkpoint) -> Result<()> {
            self.step = checkpoint.metadata.current_step;
            Ok(())
        }

        fn checkpoint_kernel_id(&self) -> &str {
            &self.kernel_id
        }

        fn checkpoint_kernel_type(&self) -> &str {
            &self.kernel_type
        }
    }
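
    // The mock above stands in for a GPU-resident kernel: `create_checkpoint`
    // emits a fixed 64x64x64 grid, a 4-byte control block, and a 0xAB-filled
    // state blob, while `restore_from_checkpoint` only replays the step counter.
    // That is enough for the migration and hot-reload tests below to exercise the
    // checkpoint → transfer → restore path without touching a real device.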

    #[test]
    fn test_migrator_creation() {
        let coord = MultiGpuBuilder::new().build();
        let migrator = KernelMigrator::new(coord);

        let stats = migrator.stats();
        assert_eq!(stats.successful_migrations, 0);
        assert_eq!(stats.failed_migrations, 0);
        assert_eq!(stats.bytes_transferred, 0);
    }

    #[test]
    fn test_migrator_with_custom_storage() {
        let coord = MultiGpuBuilder::new().build();
        let storage = Arc::new(MemoryStorage::new());
        let migrator = KernelMigrator::with_storage(coord.clone(), storage);

        assert!(Arc::ptr_eq(migrator.coordinator(), &coord));
    }

    #[test]
    fn test_successful_migration() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("migratable_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let migrator = KernelMigrator::new(coord.clone());

        let kernel = MockCheckpointableKernel::new("migratable_kernel", 1024);

        let mut request = coord.request_migration(&kernel_id, 1).unwrap();
        assert_eq!(request.state, MigrationState::Pending);

        let result = migrator
            .migrate_with_checkpoint(&kernel, &mut request)
            .unwrap();

        assert_eq!(result.kernel_id.as_str(), "migratable_kernel");
        assert_eq!(result.source_device, 0);
        assert_eq!(result.target_device, 1);
        assert!(result.checkpoint_size > 0);
        assert!(result.total_duration > Duration::ZERO);

        // The coordinator's kernel-to-device mapping reflects the new placement.
        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));

        let stats = migrator.stats();
        assert_eq!(stats.successful_migrations, 1);
        assert_eq!(stats.failed_migrations, 0);
        assert!(stats.bytes_transferred > 0);
    }
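
    // `migrate_with_checkpoint` appears to drive the whole pipeline from the host:
    // checkpoint the kernel on the source device, move the checkpoint through the
    // migrator's storage backend, restore it on the target, and finally update the
    // coordinator's kernel-to-device mapping. The per-phase durations it reports
    // are asserted on in the next test.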

    #[test]
    fn test_migration_result_fields() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("test_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let migrator = KernelMigrator::new(coord.clone());
        let kernel = MockCheckpointableKernel::new("test_kernel", 4096);
        let mut request = coord.request_migration(&kernel_id, 1).unwrap();

        let result = migrator
            .migrate_with_checkpoint(&kernel, &mut request)
            .unwrap();

        assert!(result.checkpoint_duration >= Duration::ZERO);
        assert!(result.transfer_duration >= Duration::ZERO);
        assert!(result.restore_duration >= Duration::ZERO);

        assert!(result.total_duration >= result.checkpoint_duration);
    }

    #[test]
    fn test_migration_stats_accumulate() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let migrator = KernelMigrator::new(coord.clone());

        // First migration: k1 from device 0 to device 1.
        let k1 = KernelId::new("k1");
        coord.assign_kernel(k1.clone(), 0);
        let kernel1 = MockCheckpointableKernel::new("k1", 1000);
        let mut req1 = coord.request_migration(&k1, 1).unwrap();
        migrator
            .migrate_with_checkpoint(&kernel1, &mut req1)
            .unwrap();

        // Second migration: k2 from device 0 to device 1.
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k2.clone(), 0);
        let kernel2 = MockCheckpointableKernel::new("k2", 2000);
        let mut req2 = coord.request_migration(&k2, 1).unwrap();
        migrator
            .migrate_with_checkpoint(&kernel2, &mut req2)
            .unwrap();

        let stats = migrator.stats();
        assert_eq!(stats.successful_migrations, 2);
        assert_eq!(stats.failed_migrations, 0);
        assert!(stats.bytes_transferred > 0);
    }

    // --- Device unregistration and kernel evacuation ---

    #[test]
    fn test_unregister_device_no_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.device_index, 0);
        assert!(result.kernels_to_migrate.is_empty());
        assert!(result.orphaned_kernels.is_empty());
    }

    #[test]
    fn test_unregister_device_with_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.kernels_to_migrate.len(), 2);
        assert!(result.orphaned_kernels.is_empty());

        // Both displaced kernels are planned onto the remaining device.
        for plan in &result.kernels_to_migrate {
            assert_eq!(plan.source_device, 0);
            assert_eq!(plan.target_device, 1);
        }

        assert_eq!(coord.get_kernel_device(&k1), Some(1));
        assert_eq!(coord.get_kernel_device(&k2), Some(1));
    }

    #[test]
    fn test_unregister_single_device_orphans_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        coord.assign_kernel(k1.clone(), 0);

        // With no other device to absorb it, the kernel is reported as orphaned.
        let result = coord.unregister_device(0);

        assert!(result.success);
        assert!(result.kernels_to_migrate.is_empty());
        assert_eq!(result.orphaned_kernels.len(), 1);
        assert_eq!(result.orphaned_kernels[0], k1);

        assert!(coord.get_kernel_device(&k1).is_none());
    }

    #[test]
    fn test_unregister_nonexistent_device() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let result = coord.unregister_device(99);

        assert!(!result.success);
        assert_eq!(result.device_index, 99);
    }

    #[test]
    fn test_unregister_distributes_to_least_loaded() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(2, "GPU 2".to_string(), Backend::Cuda));

        // Pre-load device 1 so device 2 is the least-loaded target.
        coord.assign_kernel(KernelId::new("pre1"), 1);
        coord.assign_kernel(KernelId::new("pre2"), 1);
        coord.assign_kernel(KernelId::new("pre3"), 1);

        let k1 = KernelId::new("migrate_me");
        coord.assign_kernel(k1.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.kernels_to_migrate.len(), 1);

        let plan = &result.kernels_to_migrate[0];
        assert_eq!(plan.target_device, 2);
    }
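
    // Device 1 already holds three kernels while device 2 holds none, so the
    // evacuation plan targets device 2. The choice presumably reuses the same
    // per-device kernel-count accounting as LoadBalancingStrategy::LeastLoaded
    // rather than memory pressure.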

    #[test]
    fn test_migration_priority_enum() {
        let low = MigrationPriority::Low;
        let normal = MigrationPriority::Normal;
        let high = MigrationPriority::High;
        let critical = MigrationPriority::Critical;

        assert_ne!(low, normal);
        assert_ne!(normal, high);
        assert_ne!(high, critical);
        assert_eq!(low, MigrationPriority::Low);
    }

    // --- Hot reload ---

    #[test]
    fn test_hot_reload_config_default() {
        let config = HotReloadConfig::default();
        assert!(config.enabled);
        assert!(config.preserve_state);
        assert!(config.validate_before_swap);
        assert!(config.keep_fallback);
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_hot_reload_config_builder() {
        let config = HotReloadConfig::new()
            .with_enabled(false)
            .with_preserve_state(false)
            .with_max_retries(5)
            .with_timeout(Duration::from_secs(60));

        assert!(!config.enabled);
        assert!(!config.preserve_state);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.reload_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_kernel_code_source_ptx() {
        let ptx = ".version 7.0\n.target sm_80\nkernel: ret;";
        let code = KernelCodeSource::from_ptx(ptx, "kernel");

        assert_eq!(code.format, KernelCodeFormat::Ptx);
        assert_eq!(code.entry_point, "kernel");
        assert_eq!(code.as_str(), Some(ptx));
        assert_eq!(code.size(), ptx.len());
    }

    #[test]
    fn test_kernel_code_source_wgsl() {
        let wgsl = "@compute fn main() {}";
        let code = KernelCodeSource::from_wgsl(wgsl, "main");

        assert_eq!(code.format, KernelCodeFormat::Wgsl);
        assert_eq!(code.entry_point, "main");
        assert_eq!(code.as_str(), Some(wgsl));
    }

    #[test]
    fn test_kernel_code_source_msl() {
        let msl = "kernel void my_kernel() {}";
        let code = KernelCodeSource::from_msl(msl, "my_kernel");

        assert_eq!(code.format, KernelCodeFormat::Msl);
        assert_eq!(code.entry_point, "my_kernel");
        assert_eq!(code.as_str(), Some(msl));
    }

    #[test]
    fn test_hot_reload_manager_creation() {
        let manager = HotReloadManager::with_defaults();
        assert!(manager.is_enabled());
        assert!(manager.list_kernels().is_empty());
    }

    #[test]
    fn test_hot_reload_manager_register_kernel() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        manager.register_kernel(&kernel_id, code);

        assert!(manager.is_registered(&kernel_id));
        assert!(!manager.is_reload_in_progress(&kernel_id));
        assert!(manager.get_current_version(&kernel_id).is_some());
    }

    #[test]
    fn test_hot_reload_request_states() {
        let kernel_id = KernelId::new("test");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");
        let request = HotReloadRequest::new(kernel_id, code);

        assert_eq!(request.state, HotReloadState::Idle);
        assert!(!request.is_in_progress());
        assert!(!request.is_completed());
        assert!(!request.is_failed());
    }

    #[test]
    fn test_hot_reload_disabled() {
        let config = HotReloadConfig::new().with_enabled(false);
        let manager = HotReloadManager::new(config);
        let kernel_id = KernelId::new("test");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        manager.register_kernel(&kernel_id, code.clone());
        let result = manager.request_reload(&kernel_id, code);
        assert!(result.is_err());
    }

    #[test]
    fn test_hot_reload_stats() {
        let manager = HotReloadManager::with_defaults();
        let stats = manager.stats();

        assert_eq!(stats.successful_reloads, 0);
        assert_eq!(stats.failed_reloads, 0);
        assert_eq!(stats.rollbacks, 0);
    }

    #[test]
    fn test_hot_reload_code_formats() {
        let formats = [
            KernelCodeFormat::Ptx,
            KernelCodeFormat::Cubin,
            KernelCodeFormat::SpirV,
            KernelCodeFormat::Wgsl,
            KernelCodeFormat::Msl,
            KernelCodeFormat::MetalLib,
            KernelCodeFormat::Source,
        ];

        // Every pair of distinct formats must compare unequal.
        for (i, f1) in formats.iter().enumerate() {
            for (j, f2) in formats.iter().enumerate() {
                if i != j {
                    assert_ne!(f1, f2);
                }
            }
        }
    }

    #[test]
    fn test_hot_reload_state_transitions() {
        let states = [
            HotReloadState::Idle,
            HotReloadState::Draining,
            HotReloadState::Checkpointing,
            HotReloadState::Compiling,
            HotReloadState::Validating,
            HotReloadState::Swapping,
            HotReloadState::Restoring,
            HotReloadState::Completed,
            HotReloadState::Failed,
            HotReloadState::RollingBack,
        ];

        // Every pair of distinct states must compare unequal.
        for (i, s1) in states.iter().enumerate() {
            for (j, s2) in states.iter().enumerate() {
                if i != j {
                    assert_ne!(s1, s2);
                }
            }
        }
    }

    #[test]
    fn test_hot_reload_execute() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");

        let initial_code = KernelCodeSource::from_ptx(".version 7.0\n.target sm_80", "kernel");
        manager.register_kernel(&kernel_id, initial_code);

        let new_code = KernelCodeSource::from_ptx(".version 8.0\n.target sm_90", "kernel");
        let mut request = manager.request_reload(&kernel_id, new_code).unwrap();

        let mock_kernel = MockCheckpointableKernel::new("test_kernel", 512);

        let result = manager.execute_reload(&mut request, &mock_kernel).unwrap();

        assert!(request.is_completed());
        assert_eq!(result.kernel_id.as_str(), "test_kernel");
        assert!(result.state_preserved);
        assert!(result.checkpoint_size > 0);
        assert!(result.total_duration > Duration::ZERO);

        let stats = manager.stats();
        assert_eq!(stats.successful_reloads, 1);
    }
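
    // A successful `execute_reload` is expected to walk the request through the
    // intermediate states (Draining → Checkpointing → Compiling → Validating →
    // Swapping → Restoring) before settling on Completed; with `preserve_state`
    // enabled, the mock's checkpoint is what makes `state_preserved` true. Only
    // the terminal state is asserted here.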

    #[test]
    fn test_hot_reload_list_kernels() {
        let manager = HotReloadManager::with_defaults();

        let k1 = KernelId::new("kernel1");
        let k2 = KernelId::new("kernel2");
        let k3 = KernelId::new("kernel3");

        manager.register_kernel(&k1, KernelCodeSource::from_ptx(".version 7.0", "k1"));
        manager.register_kernel(&k2, KernelCodeSource::from_ptx(".version 7.0", "k2"));
        manager.register_kernel(&k3, KernelCodeSource::from_ptx(".version 7.0", "k3"));

        let kernels = manager.list_kernels();
        assert_eq!(kernels.len(), 3);
        assert!(kernels.contains(&k1));
        assert!(kernels.contains(&k2));
        assert!(kernels.contains(&k3));
    }
}