use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use crate::error::{Result, RingKernelError};
use crate::k2k::K2KMessage;
use crate::runtime::{Backend, KernelId, LaunchOptions};
38
39#[derive(Debug, Clone)]
41pub struct MultiGpuConfig {
42 pub load_balancing: LoadBalancingStrategy,
44 pub auto_select_device: bool,
46 pub max_kernels_per_device: usize,
48 pub enable_p2p: bool,
50 pub preferred_devices: Vec<usize>,
52}
53
54impl Default for MultiGpuConfig {
55 fn default() -> Self {
56 Self {
57 load_balancing: LoadBalancingStrategy::LeastLoaded,
58 auto_select_device: true,
59 max_kernels_per_device: 64,
60 enable_p2p: true,
61 preferred_devices: vec![],
62 }
63 }
64}
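
// Note: this config is typically assembled via `MultiGpuBuilder` (defined below),
// e.g. `MultiGpuBuilder::new().load_balancing(LoadBalancingStrategy::RoundRobin).build()`,
// which passes the finished config straight to `MultiGpuCoordinator::new`.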
65
66#[derive(Debug, Clone, Copy, PartialEq, Eq)]
68pub enum LoadBalancingStrategy {
69 FirstAvailable,
71 LeastLoaded,
73 RoundRobin,
75 MemoryBased,
77 ComputeCapability,
79 Custom,
81}
82
83#[derive(Debug, Clone)]
85pub struct DeviceInfo {
86 pub index: usize,
88 pub name: String,
90 pub backend: Backend,
92 pub total_memory: u64,
94 pub available_memory: u64,
96 pub compute_capability: Option<(u32, u32)>,
98 pub max_threads_per_block: u32,
100 pub multiprocessor_count: u32,
102 pub p2p_capable: bool,
104}
105
106impl DeviceInfo {
107 pub fn new(index: usize, name: String, backend: Backend) -> Self {
109 Self {
110 index,
111 name,
112 backend,
113 total_memory: 0,
114 available_memory: 0,
115 compute_capability: None,
116 max_threads_per_block: 1024,
117 multiprocessor_count: 1,
118 p2p_capable: false,
119 }
120 }
121
122 pub fn memory_utilization(&self) -> f64 {
124 if self.total_memory == 0 {
125 0.0
126 } else {
127 1.0 - (self.available_memory as f64 / self.total_memory as f64)
128 }
129 }
130}
131
132#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
138pub enum InterconnectType {
139 None,
141 Pcie,
143 NvLink,
145 NvSwitch,
147 InfinityFabric,
149 XeLink,
151 SameDevice,
153}
154
155impl InterconnectType {
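    // The figures returned below are rough, order-of-magnitude defaults used for
    // path scoring and transfer-time estimates; measured links can override them
    // via `GpuConnection::with_bandwidth` / `with_latency`.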
    pub fn estimated_bandwidth_gbps(&self) -> f64 {
        match self {
            InterconnectType::None => 16.0,
            InterconnectType::Pcie => 32.0,
            InterconnectType::NvLink => 300.0,
            InterconnectType::NvSwitch => 600.0,
            InterconnectType::InfinityFabric => 200.0,
            InterconnectType::XeLink => 100.0,
            InterconnectType::SameDevice => 2000.0,
        }
    }
168
    pub fn estimated_latency_us(&self) -> f64 {
        match self {
            InterconnectType::None => 10.0,
            InterconnectType::Pcie => 5.0,
            InterconnectType::NvLink => 1.0,
            InterconnectType::NvSwitch => 2.0,
            InterconnectType::InfinityFabric => 1.5,
            InterconnectType::XeLink => 2.0,
            InterconnectType::SameDevice => 0.0,
        }
    }
181
182 pub fn supports_p2p(&self) -> bool {
184 !matches!(self, InterconnectType::None)
185 }
186}
187
188#[derive(Debug, Clone)]
190pub struct GpuConnection {
191 pub source: usize,
193 pub destination: usize,
195 pub interconnect: InterconnectType,
197 pub bandwidth_gbps: f64,
199 pub latency_us: f64,
201 pub bidirectional: bool,
203 pub hops: u32,
205}
206
207impl GpuConnection {
208 pub fn new(source: usize, destination: usize, interconnect: InterconnectType) -> Self {
210 Self {
211 source,
212 destination,
213 interconnect,
214 bandwidth_gbps: interconnect.estimated_bandwidth_gbps(),
215 latency_us: interconnect.estimated_latency_us(),
216 bidirectional: true,
217 hops: if source == destination { 0 } else { 1 },
218 }
219 }
220
221 pub fn with_bandwidth(mut self, gbps: f64) -> Self {
223 self.bandwidth_gbps = gbps;
224 self
225 }
226
227 pub fn with_latency(mut self, us: f64) -> Self {
229 self.latency_us = us;
230 self
231 }
232
233 pub fn with_hops(mut self, hops: u32) -> Self {
235 self.hops = hops;
236 self
237 }
238}
239
240#[derive(Debug, Clone)]
242pub struct GpuTopology {
243 pub device_count: usize,
245 connections: Vec<Vec<Option<GpuConnection>>>,
247 pub numa_nodes: Vec<Option<u32>>,
249 pub probed: bool,
251 pub last_updated: Instant,
253}
254
255impl GpuTopology {
256 pub fn new(device_count: usize) -> Self {
258 let mut connections = vec![vec![None; device_count]; device_count];
259
260 for (i, row) in connections.iter_mut().enumerate().take(device_count) {
262 row[i] = Some(GpuConnection::new(i, i, InterconnectType::SameDevice));
263 }
264
265 Self {
266 device_count,
267 connections,
268 numa_nodes: vec![None; device_count],
269 probed: false,
270 last_updated: Instant::now(),
271 }
272 }
273
274 pub fn set_connection(&mut self, connection: GpuConnection) {
276 let src = connection.source;
277 let dst = connection.destination;
278 if src < self.device_count && dst < self.device_count {
279 self.connections[src][dst] = Some(connection.clone());
280 if connection.bidirectional && src != dst {
281 let reverse = GpuConnection {
282 source: dst,
283 destination: src,
284 ..connection
285 };
286 self.connections[dst][src] = Some(reverse);
287 }
288 }
289 }
290
291 pub fn get_connection(&self, source: usize, destination: usize) -> Option<&GpuConnection> {
293 self.connections
294 .get(source)
295 .and_then(|row| row.get(destination))
296 .and_then(|c| c.as_ref())
297 }
298
299 pub fn best_path(&self, source: usize, destination: usize) -> Vec<usize> {
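        // Path selection is a deliberately simple heuristic: use the direct link if
        // one exists, otherwise try every single intermediate device and keep the
        // relay whose bottleneck (minimum of the two hop bandwidths) is largest.
        // Paths longer than two hops are never considered.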
301 if source == destination {
302 return vec![source];
303 }
304
305 if let Some(conn) = self.get_connection(source, destination) {
307 if conn.interconnect != InterconnectType::None {
308 return vec![source, destination];
309 }
310 }
311
        let mut best_path = vec![source, destination];
        let mut best_bandwidth = 0.0;
315
316 for intermediate in 0..self.device_count {
318 if intermediate == source || intermediate == destination {
319 continue;
320 }
321
322 if let (Some(c1), Some(c2)) = (
323 self.get_connection(source, intermediate),
324 self.get_connection(intermediate, destination),
325 ) {
326 let path_bandwidth = c1.bandwidth_gbps.min(c2.bandwidth_gbps);
328 if path_bandwidth > best_bandwidth {
329 best_bandwidth = path_bandwidth;
330 best_path = vec![source, intermediate, destination];
331 }
332 }
333 }
334
335 best_path
336 }
337
338 pub fn neighbors(&self, device: usize) -> Vec<usize> {
340 if device >= self.device_count {
341 return vec![];
342 }
343
344 self.connections[device]
345 .iter()
346 .enumerate()
347 .filter_map(|(i, conn)| {
348 if i != device
349 && conn
350 .as_ref()
351 .map(|c| c.interconnect.supports_p2p())
352 .unwrap_or(false)
353 {
354 Some(i)
355 } else {
356 None
357 }
358 })
359 .collect()
360 }
361
362 pub fn bisection_bandwidth_gbps(&self) -> f64 {
364 let half = self.device_count / 2;
365 if half == 0 {
366 return 0.0;
367 }
368
369 let mut total = 0.0;
370 for src in 0..half {
371 for dst in half..self.device_count {
372 if let Some(conn) = self.get_connection(src, dst) {
373 total += conn.bandwidth_gbps;
374 }
375 }
376 }
377 total
378 }
379
380 pub fn is_fully_connected(&self) -> bool {
382 for src in 0..self.device_count {
383 for dst in 0..self.device_count {
384 if src != dst {
385 if let Some(conn) = self.get_connection(src, dst) {
386 if !conn.interconnect.supports_p2p() {
387 return false;
388 }
389 } else {
390 return false;
391 }
392 }
393 }
394 }
395 true
396 }
397
398 pub fn numa_neighbors(&self, device: usize) -> Vec<usize> {
400 let target_numa = self.numa_nodes.get(device).copied().flatten();
401 if target_numa.is_none() {
402 return vec![];
403 }
404
405 self.numa_nodes
406 .iter()
407 .enumerate()
408 .filter_map(|(i, numa)| {
409 if i != device && *numa == target_numa {
410 Some(i)
411 } else {
412 None
413 }
414 })
415 .collect()
416 }
417
418 pub fn set_numa_node(&mut self, device: usize, numa_node: u32) {
420 if device < self.numa_nodes.len() {
421 self.numa_nodes[device] = Some(numa_node);
422 }
423 }
424
425 pub fn mark_probed(&mut self) {
427 self.probed = true;
428 self.last_updated = Instant::now();
429 }
430}
431
432#[derive(Debug, Clone)]
434pub struct DeviceStatus {
435 pub info: DeviceInfo,
437 pub kernel_count: usize,
439 pub kernels: Vec<KernelId>,
441 pub available: bool,
443 pub load: f64,
445}
446
447#[derive(Debug, Clone)]
449pub struct DeviceUnregisterResult {
450 pub device_index: usize,
452 pub kernels_to_migrate: Vec<KernelMigrationPlan>,
454 pub orphaned_kernels: Vec<KernelId>,
456 pub success: bool,
458}
459
460#[derive(Debug, Clone)]
462pub struct KernelMigrationPlan {
463 pub kernel_id: KernelId,
465 pub source_device: usize,
467 pub target_device: usize,
469 pub priority: MigrationPriority,
471}
472
473#[derive(Debug, Clone, Copy, PartialEq, Eq)]
475pub enum MigrationPriority {
476 Low,
478 Normal,
480 High,
482 Critical,
484}
485
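// Coordinates kernel placement across devices. Shared state is guarded by
// `parking_lot::RwLock`s, while per-device kernel counts and the round-robin
// cursor use atomics so they can be updated under a read lock.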
486pub struct MultiGpuCoordinator {
488 config: MultiGpuConfig,
490 devices: RwLock<Vec<DeviceInfo>>,
492 kernel_device_map: RwLock<HashMap<KernelId, usize>>,
494 device_kernel_counts: RwLock<Vec<AtomicUsize>>,
496 round_robin_counter: AtomicUsize,
498 total_kernels: AtomicU64,
500 #[allow(clippy::type_complexity)]
502 custom_selector:
503 RwLock<Option<Arc<dyn Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync>>>,
504 topology: RwLock<Option<GpuTopology>>,
506}
507
508impl MultiGpuCoordinator {
509 pub fn new(config: MultiGpuConfig) -> Arc<Self> {
511 Arc::new(Self {
512 config,
513 devices: RwLock::new(Vec::new()),
514 kernel_device_map: RwLock::new(HashMap::new()),
515 device_kernel_counts: RwLock::new(Vec::new()),
516 round_robin_counter: AtomicUsize::new(0),
517 total_kernels: AtomicU64::new(0),
518 custom_selector: RwLock::new(None),
519 topology: RwLock::new(None),
520 })
521 }
522
523 pub fn register_device(&self, device: DeviceInfo) {
525 let index = device.index;
526 let mut devices = self.devices.write();
527 let mut counts = self.device_kernel_counts.write();
528
529 let mut current_len = devices.len();
531 while current_len <= index {
532 devices.push(DeviceInfo::new(
533 current_len,
534 "Unknown".to_string(),
535 Backend::Cpu,
536 ));
537 counts.push(AtomicUsize::new(0));
538 current_len += 1;
539 }
540
541 devices[index] = device;
542 }
543
544 pub fn unregister_device(&self, index: usize) -> DeviceUnregisterResult {
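        // Plans the removal of a device: kernels on it are scheduled for migration
        // to the least-loaded remaining device, or reported as orphaned when no
        // other device exists. The device entry itself is kept in place (renamed
        // and zeroed) so device indices remain stable.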
556 let devices = self.devices.read();
557
558 if index >= devices.len() {
560 return DeviceUnregisterResult {
561 device_index: index,
562 kernels_to_migrate: Vec::new(),
563 orphaned_kernels: Vec::new(),
564 success: false,
565 };
566 }
567
568 let kernels_on_device = self.kernels_on_device(index);
570
571 let available_targets: Vec<usize> = devices
573 .iter()
574 .enumerate()
575 .filter(|(i, _)| *i != index)
576 .map(|(i, _)| i)
577 .collect();
578
        drop(devices);

        let mut kernels_to_migrate = Vec::new();
582 let mut orphaned_kernels = Vec::new();
583
584 if available_targets.is_empty() {
585 orphaned_kernels = kernels_on_device;
587 } else {
588 for kernel_id in kernels_on_device {
590 if let Some(target) = self.select_migration_target(&available_targets) {
592 let priority = self.calculate_migration_priority(&kernel_id);
593 kernels_to_migrate.push(KernelMigrationPlan {
594 kernel_id,
595 source_device: index,
596 target_device: target,
597 priority,
598 });
599 } else {
600 orphaned_kernels.push(kernel_id);
601 }
602 }
603 }
604
605 {
607 let mut kernel_map = self.kernel_device_map.write();
608 let counts = self.device_kernel_counts.read();
609
610 for plan in &kernels_to_migrate {
611 kernel_map.insert(plan.kernel_id.clone(), plan.target_device);
613
614 if index < counts.len() {
616 counts[index].fetch_sub(1, Ordering::Relaxed);
617 }
618 if plan.target_device < counts.len() {
619 counts[plan.target_device].fetch_add(1, Ordering::Relaxed);
620 }
621 }
622
623 for kernel_id in &orphaned_kernels {
625 kernel_map.remove(kernel_id);
626 if index < counts.len() {
627 counts[index].fetch_sub(1, Ordering::Relaxed);
628 }
629 }
630 }
631
632 {
634 let mut devices = self.devices.write();
635 if index < devices.len() {
636 devices[index].available_memory = 0;
637 devices[index].name = format!("{} (unregistered)", devices[index].name);
638 }
639 }
640
641 DeviceUnregisterResult {
642 device_index: index,
643 kernels_to_migrate,
644 orphaned_kernels,
645 success: true,
646 }
647 }
648
649 fn select_migration_target(&self, candidates: &[usize]) -> Option<usize> {
651 if candidates.is_empty() {
652 return None;
653 }
654
655 let counts = self.device_kernel_counts.read();
656
657 candidates
659 .iter()
660 .filter_map(|&idx| {
661 if idx < counts.len() {
662 Some((idx, counts[idx].load(Ordering::Relaxed)))
663 } else {
664 None
665 }
666 })
667 .min_by_key(|(_, count)| *count)
668 .map(|(idx, _)| idx)
669 }
670
671 fn calculate_migration_priority(&self, _kernel_id: &KernelId) -> MigrationPriority {
673 MigrationPriority::Normal
679 }
680
681 pub fn devices(&self) -> Vec<DeviceInfo> {
683 self.devices.read().clone()
684 }
685
686 pub fn device(&self, index: usize) -> Option<DeviceInfo> {
688 self.devices.read().get(index).cloned()
689 }
690
691 pub fn device_count(&self) -> usize {
693 self.devices.read().len()
694 }
695
696 pub fn select_device(&self, options: &LaunchOptions) -> Result<usize> {
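        // Selection order: a registered custom selector wins when the strategy is
        // `Custom`; otherwise candidates are narrowed to `preferred_devices` (when
        // set) and the configured strategy picks among them. If `Custom` is chosen
        // but no selector was installed, the method falls back to device 0.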
698 let devices = self.devices.read();
699 if devices.is_empty() {
700 return Err(RingKernelError::BackendUnavailable(
701 "No GPU devices available".to_string(),
702 ));
703 }
704
705 let status = self.get_all_status();
707
708 if self.config.load_balancing == LoadBalancingStrategy::Custom {
710 if let Some(selector) = &*self.custom_selector.read() {
711 return Ok(selector(&status, options));
712 }
713 }
714
715 let candidates: Vec<_> = if !self.config.preferred_devices.is_empty() {
717 status
718 .into_iter()
719 .filter(|s| self.config.preferred_devices.contains(&s.info.index))
720 .collect()
721 } else {
722 status
723 };
724
725 if candidates.is_empty() {
726 return Err(RingKernelError::BackendUnavailable(
727 "No suitable GPU device available".to_string(),
728 ));
729 }
730
731 let selected = match self.config.load_balancing {
732 LoadBalancingStrategy::FirstAvailable => {
733 candidates.first().map(|s| s.info.index).unwrap_or(0)
734 }
735 LoadBalancingStrategy::LeastLoaded => candidates
736 .iter()
737 .filter(|s| s.available && s.kernel_count < self.config.max_kernels_per_device)
738 .min_by(|a, b| a.kernel_count.cmp(&b.kernel_count))
739 .map(|s| s.info.index)
740 .unwrap_or(0),
741 LoadBalancingStrategy::RoundRobin => {
742 let available: Vec<_> = candidates.iter().filter(|s| s.available).collect();
743
744 if available.is_empty() {
745 candidates.first().map(|s| s.info.index).unwrap_or(0)
746 } else {
747 let idx =
748 self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % available.len();
749 available[idx].info.index
750 }
751 }
752 LoadBalancingStrategy::MemoryBased => candidates
753 .iter()
754 .filter(|s| s.available)
755 .max_by(|a, b| a.info.available_memory.cmp(&b.info.available_memory))
756 .map(|s| s.info.index)
757 .unwrap_or(0),
758 LoadBalancingStrategy::ComputeCapability => candidates
759 .iter()
760 .filter(|s| s.available)
761 .max_by(|a, b| {
762 let a_cap = a.info.compute_capability.unwrap_or((0, 0));
763 let b_cap = b.info.compute_capability.unwrap_or((0, 0));
764 a_cap.cmp(&b_cap)
765 })
766 .map(|s| s.info.index)
767 .unwrap_or(0),
768 LoadBalancingStrategy::Custom => {
769 0
771 }
772 };
773
774 Ok(selected)
775 }
776
777 pub fn assign_kernel(&self, kernel_id: KernelId, device_index: usize) {
779 self.kernel_device_map
780 .write()
781 .insert(kernel_id, device_index);
782
783 let counts = self.device_kernel_counts.read();
784 if device_index < counts.len() {
785 counts[device_index].fetch_add(1, Ordering::Relaxed);
786 }
787
788 self.total_kernels.fetch_add(1, Ordering::Relaxed);
789 }
790
791 pub fn remove_kernel(&self, kernel_id: &KernelId) {
793 if let Some(device_index) = self.kernel_device_map.write().remove(kernel_id) {
794 let counts = self.device_kernel_counts.read();
795 if device_index < counts.len() {
796 counts[device_index].fetch_sub(1, Ordering::Relaxed);
797 }
798 }
799 }
800
801 pub fn get_kernel_device(&self, kernel_id: &KernelId) -> Option<usize> {
803 self.kernel_device_map.read().get(kernel_id).copied()
804 }
805
806 pub fn kernels_on_device(&self, device_index: usize) -> Vec<KernelId> {
808 self.kernel_device_map
809 .read()
810 .iter()
811 .filter(|(_, &idx)| idx == device_index)
812 .map(|(k, _)| k.clone())
813 .collect()
814 }
815
816 pub fn get_all_status(&self) -> Vec<DeviceStatus> {
818 let devices = self.devices.read();
819 let kernel_map = self.kernel_device_map.read();
820 let counts = self.device_kernel_counts.read();
821
822 devices
823 .iter()
824 .enumerate()
825 .map(|(idx, info)| {
826 let kernel_count = if idx < counts.len() {
827 counts[idx].load(Ordering::Relaxed)
828 } else {
829 0
830 };
831
832 let kernels: Vec<_> = kernel_map
833 .iter()
834 .filter(|(_, &dev_idx)| dev_idx == idx)
835 .map(|(k, _)| k.clone())
836 .collect();
837
838 let load = kernel_count as f64 / self.config.max_kernels_per_device as f64;
839 let available = kernel_count < self.config.max_kernels_per_device;
840
841 DeviceStatus {
842 info: info.clone(),
843 kernel_count,
844 kernels,
845 available,
846 load,
847 }
848 })
849 .collect()
850 }
851
852 pub fn get_device_status(&self, device_index: usize) -> Option<DeviceStatus> {
854 self.get_all_status().into_iter().nth(device_index)
855 }
856
857 pub fn set_custom_selector<F>(&self, selector: F)
859 where
860 F: Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync + 'static,
861 {
862 *self.custom_selector.write() = Some(Arc::new(selector));
863 }
864
865 pub fn stats(&self) -> MultiGpuStats {
867 let status = self.get_all_status();
868 let total_kernels: usize = status.iter().map(|s| s.kernel_count).sum();
869 let total_memory: u64 = status.iter().map(|s| s.info.total_memory).sum();
870 let available_memory: u64 = status.iter().map(|s| s.info.available_memory).sum();
871
872 MultiGpuStats {
873 device_count: status.len(),
874 total_kernels,
875 total_memory,
876 available_memory,
877 kernels_launched: self.total_kernels.load(Ordering::Relaxed),
878 }
879 }
880
881 pub fn can_p2p(&self, device_a: usize, device_b: usize) -> bool {
883 if !self.config.enable_p2p {
884 return false;
885 }
886
887 let devices = self.devices.read();
888 if let (Some(a), Some(b)) = (devices.get(device_a), devices.get(device_b)) {
889 a.p2p_capable && b.p2p_capable
890 } else {
891 false
892 }
893 }
894
895 pub fn update_device_memory(&self, device_index: usize, available_memory: u64) {
897 let mut devices = self.devices.write();
898 if let Some(device) = devices.get_mut(device_index) {
899 device.available_memory = available_memory;
900 }
901 }
902
903 pub fn discover_topology(&self) -> GpuTopology {
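        // Heuristic discovery only: P2P-capable devices on the same backend are
        // assumed to share at least a PCIe link, and CUDA pairs where both report
        // compute capability 8.x or newer are assumed to be NVLink-connected.
        // Backends that can probe real links can overwrite this via `set_topology`.
        // The result is cached for later `topology()` calls.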
909 let devices = self.devices.read();
910 let device_count = devices.len();
911
912 if device_count == 0 {
913 return GpuTopology::new(0);
914 }
915
916 let mut topo = GpuTopology::new(device_count);
917
918 for (i, dev_i) in devices.iter().enumerate() {
920 for (j, dev_j) in devices.iter().enumerate() {
921 if i == j {
922 continue;
923 }
924
925 let interconnect = if dev_i.p2p_capable && dev_j.p2p_capable {
927 if dev_i.backend == dev_j.backend {
929 match dev_i.backend {
930 Backend::Cuda => {
931 let cc_i = dev_i.compute_capability.unwrap_or((0, 0));
933 let cc_j = dev_j.compute_capability.unwrap_or((0, 0));
934
935 if cc_i.0 >= 8 && cc_j.0 >= 8 {
937 InterconnectType::NvLink
938 } else {
939 InterconnectType::Pcie
940 }
941 }
942 _ => InterconnectType::Pcie,
943 }
944 } else {
945 InterconnectType::None
946 }
947 } else {
948 InterconnectType::None
949 };
950
951 topo.set_connection(GpuConnection::new(i, j, interconnect));
952 }
953 }
954
955 *self.topology.write() = Some(topo.clone());
957
958 topo
959 }
960
961 pub fn topology(&self) -> GpuTopology {
963 {
964 let topo = self.topology.read();
965 if let Some(ref t) = *topo {
966 return t.clone();
967 }
968 }
969 self.discover_topology()
970 }
971
972 pub fn set_topology(&self, topology: GpuTopology) {
974 *self.topology.write() = Some(topology);
975 }
976
977 pub fn select_device_for_k2k(&self, source_kernel: &KernelId) -> Result<usize> {
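        // Prefers placing the peer next to the source kernel: each P2P neighbor of
        // the source's device is scored as link_bandwidth / (load + 0.1), and the
        // highest score wins. When the source is unplaced or has no P2P neighbors,
        // this falls back to the general `select_device` path.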
979 let source_device = self.get_kernel_device(source_kernel);
980 if source_device.is_none() {
981 return self.select_device(&LaunchOptions::default());
982 }
983
984 let source_idx = source_device.unwrap();
985 let topo = self.topology();
986 let status = self.get_all_status();
987
988 let neighbors = topo.neighbors(source_idx);
990
991 if neighbors.is_empty() {
992 return self.select_device(&LaunchOptions::default());
994 }
995
996 let best = neighbors
998 .iter()
999 .filter_map(|&dev_idx| {
1000 status.iter().find(|s| s.info.index == dev_idx).map(|s| {
1001 let conn = topo.get_connection(source_idx, dev_idx);
1002 let bandwidth = conn.map(|c| c.bandwidth_gbps).unwrap_or(1.0);
1003 let score = bandwidth / (s.load + 0.1);
1004 (dev_idx, score)
1005 })
1006 })
1007 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
1008 .map(|(idx, _)| idx);
1009
1010 best.ok_or_else(|| {
1011 RingKernelError::BackendUnavailable("No suitable K2K device found".to_string())
1012 })
1013 }
1014
1015 pub fn request_migration(
1021 &self,
1022 kernel_id: &KernelId,
1023 target_device: usize,
1024 ) -> Result<MigrationRequest> {
1025 let source_device = self
1026 .get_kernel_device(kernel_id)
1027 .ok_or_else(|| RingKernelError::KernelNotFound(kernel_id.as_str().to_string()))?;
1028
1029 if source_device == target_device {
1030 return Err(RingKernelError::InvalidConfig(
1031 "Cannot migrate to same device".to_string(),
1032 ));
1033 }
1034
1035 let devices = self.devices.read();
1036 if target_device >= devices.len() {
1037 return Err(RingKernelError::DeviceNotAvailable(format!(
1038 "Device {} not available",
1039 target_device
1040 )));
1041 }
1042
1043 let topo = self.topology();
1044 let path = topo.best_path(source_device, target_device);
1045 let connection = topo.get_connection(source_device, target_device);
1046
1047 Ok(MigrationRequest {
1048 kernel_id: kernel_id.clone(),
1049 source_device,
1050 target_device,
1051 path,
1052 estimated_bandwidth_gbps: connection.map(|c| c.bandwidth_gbps).unwrap_or(16.0),
1053 estimated_latency_us: connection.map(|c| c.latency_us).unwrap_or(10.0),
1054 state: MigrationState::Pending,
1055 started_at: None,
1056 })
1057 }
1058
1059 pub fn complete_migration(&self, request: &MigrationRequest) -> Result<()> {
1061 {
1063 let mut map = self.kernel_device_map.write();
1064 if let Some(dev) = map.get_mut(&request.kernel_id) {
1065 *dev = request.target_device;
1066 }
1067 }
1068
1069 {
1071 let counts = self.device_kernel_counts.read();
1072 if request.source_device < counts.len() {
1073 counts[request.source_device].fetch_sub(1, Ordering::Relaxed);
1074 }
1075 if request.target_device < counts.len() {
1076 counts[request.target_device].fetch_add(1, Ordering::Relaxed);
1077 }
1078 }
1079
1080 Ok(())
1081 }
1082}
1083
1084#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1090pub enum MigrationState {
1091 Pending,
1093 Quiescing,
1095 Checkpointing,
1097 Transferring,
1099 Restoring,
1101 Completed,
1103 Failed,
1105 Cancelled,
1107}
1108
1109#[derive(Debug, Clone)]
1111pub struct MigrationRequest {
1112 pub kernel_id: KernelId,
1114 pub source_device: usize,
1116 pub target_device: usize,
1118 pub path: Vec<usize>,
1120 pub estimated_bandwidth_gbps: f64,
1122 pub estimated_latency_us: f64,
1124 pub state: MigrationState,
1126 pub started_at: Option<Instant>,
1128}
1129
1130impl MigrationRequest {
1131 pub fn estimate_transfer_time(&self, state_size_bytes: usize) -> Duration {
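        // Back-of-envelope estimate: state size (decimal GB) divided by the link's
        // estimated bandwidth figure, plus the link's one-way latency estimate.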
1133 let size_gb = state_size_bytes as f64 / 1_000_000_000.0;
1135 let transfer_time_s = size_gb / self.estimated_bandwidth_gbps;
1136 let total_us = (transfer_time_s * 1_000_000.0) + self.estimated_latency_us;
1137 Duration::from_micros(total_us as u64)
1138 }
1139}
1140
1141pub struct CrossGpuK2KRouter {
1147 coordinator: Arc<MultiGpuCoordinator>,
1149 pending_queues: RwLock<HashMap<(usize, usize), Vec<PendingK2KMessage>>>,
1151 stats: CrossGpuRouterStats,
1153}
1154
1155#[derive(Debug, Clone)]
1157pub struct PendingK2KMessage {
1158 pub source_kernel: KernelId,
1160 pub dest_kernel: KernelId,
1162 pub message: K2KMessage,
1164 pub queued_at: Instant,
1166 pub hops: u32,
1168}
1169
1170#[derive(Debug, Default)]
1172pub struct CrossGpuRouterStats {
1173 messages_routed: AtomicU64,
1175 bytes_transferred: AtomicU64,
1177 messages_pending: AtomicUsize,
1179 total_latency_us: AtomicU64,
1181 routing_failures: AtomicU64,
1183}
1184
1185impl CrossGpuK2KRouter {
1186 pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Arc<Self> {
1188 Arc::new(Self {
1189 coordinator,
1190 pending_queues: RwLock::new(HashMap::new()),
1191 stats: CrossGpuRouterStats::default(),
1192 })
1193 }
1194
1195 pub fn route_message(
1197 &self,
1198 source_kernel: &KernelId,
1199 dest_kernel: &KernelId,
1200 message: K2KMessage,
1201 ) -> Result<RoutingDecision> {
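        // Routing ladder: same device (nothing to do) -> direct P2P link ->
        // multi-hop relay along `best_path` -> host-mediated staging as the last
        // resort. Everything except the same-device case is queued in
        // `pending_queues` until the backend drains it.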
1202 let source_device = self
1203 .coordinator
1204 .get_kernel_device(source_kernel)
1205 .ok_or_else(|| {
1206 RingKernelError::K2KDestinationNotFound(source_kernel.as_str().to_string())
1207 })?;
1208
1209 let dest_device = self
1210 .coordinator
1211 .get_kernel_device(dest_kernel)
1212 .ok_or_else(|| {
1213 RingKernelError::K2KDestinationNotFound(dest_kernel.as_str().to_string())
1214 })?;
1215
1216 if source_device == dest_device {
1218 return Ok(RoutingDecision::SameDevice);
1219 }
1220
1221 let topo = self.coordinator.topology();
1223 let path = topo.best_path(source_device, dest_device);
1224
1225 if let Some(conn) = topo.get_connection(source_device, dest_device) {
1227 if conn.interconnect.supports_p2p() {
1228 let pending = PendingK2KMessage {
1230 source_kernel: source_kernel.clone(),
1231 dest_kernel: dest_kernel.clone(),
1232 message,
1233 queued_at: Instant::now(),
1234 hops: 1,
1235 };
1236
1237 self.enqueue_pending(source_device, dest_device, pending);
1238 self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);
1239
1240 return Ok(RoutingDecision::DirectP2P {
1241 source_device,
1242 dest_device,
1243 bandwidth_gbps: conn.bandwidth_gbps,
1244 });
1245 }
1246 }
1247
1248 if path.len() > 2 {
1250 let pending = PendingK2KMessage {
1251 source_kernel: source_kernel.clone(),
1252 dest_kernel: dest_kernel.clone(),
1253 message,
1254 queued_at: Instant::now(),
1255 hops: (path.len() - 1) as u32,
1256 };
1257
1258 self.enqueue_pending(source_device, path[1], pending);
1260 self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);
1261
1262 return Ok(RoutingDecision::MultiHop {
1263 path: path.clone(),
1264 total_hops: (path.len() - 1) as u32,
1265 });
1266 }
1267
1268 let pending = PendingK2KMessage {
1270 source_kernel: source_kernel.clone(),
1271 dest_kernel: dest_kernel.clone(),
1272 message,
1273 queued_at: Instant::now(),
            hops: 2,
        };
1276
1277 self.enqueue_pending(source_device, dest_device, pending);
1278 self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);
1279
1280 Ok(RoutingDecision::HostMediated {
1281 source_device,
1282 dest_device,
1283 })
1284 }
1285
1286 pub fn drain_pending(&self, source: usize, dest: usize) -> Vec<PendingK2KMessage> {
1288 let mut queues = self.pending_queues.write();
1289 let messages = queues.remove(&(source, dest)).unwrap_or_default();
1290 self.stats
1291 .messages_pending
1292 .fetch_sub(messages.len(), Ordering::Relaxed);
1293 messages
1294 }
1295
1296 pub fn record_delivery(&self, message: &PendingK2KMessage, payload_size: usize) {
1298 self.stats.messages_routed.fetch_add(1, Ordering::Relaxed);
1299 self.stats
1300 .bytes_transferred
1301 .fetch_add(payload_size as u64, Ordering::Relaxed);
1302
1303 let latency = message.queued_at.elapsed().as_micros() as u64;
1304 self.stats
1305 .total_latency_us
1306 .fetch_add(latency, Ordering::Relaxed);
1307 }
1308
1309 pub fn record_failure(&self) {
1311 self.stats.routing_failures.fetch_add(1, Ordering::Relaxed);
1312 }
1313
1314 pub fn stats(&self) -> CrossGpuRouterStatsSnapshot {
1316 let messages_routed = self.stats.messages_routed.load(Ordering::Relaxed);
1317 let total_latency = self.stats.total_latency_us.load(Ordering::Relaxed);
1318
1319 CrossGpuRouterStatsSnapshot {
1320 messages_routed,
1321 bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
1322 messages_pending: self.stats.messages_pending.load(Ordering::Relaxed),
1323 avg_latency_us: if messages_routed > 0 {
1324 total_latency as f64 / messages_routed as f64
1325 } else {
1326 0.0
1327 },
1328 routing_failures: self.stats.routing_failures.load(Ordering::Relaxed),
1329 }
1330 }
1331
1332 fn enqueue_pending(&self, source: usize, dest: usize, message: PendingK2KMessage) {
1333 let mut queues = self.pending_queues.write();
1334 queues.entry((source, dest)).or_default().push(message);
1335 }
1336}
1337
1338#[derive(Debug, Clone)]
1340pub struct CrossGpuRouterStatsSnapshot {
1341 pub messages_routed: u64,
1343 pub bytes_transferred: u64,
1345 pub messages_pending: usize,
1347 pub avg_latency_us: f64,
1349 pub routing_failures: u64,
1351}
1352
1353#[derive(Debug, Clone)]
1355pub enum RoutingDecision {
1356 SameDevice,
1358 DirectP2P {
1360 source_device: usize,
1362 dest_device: usize,
1364 bandwidth_gbps: f64,
1366 },
1367 MultiHop {
1369 path: Vec<usize>,
1371 total_hops: u32,
1373 },
1374 HostMediated {
1376 source_device: usize,
1378 dest_device: usize,
1380 },
1381}
1382
1383#[derive(Debug, Clone, Default)]
1385pub struct MultiGpuStats {
1386 pub device_count: usize,
1388 pub total_kernels: usize,
1390 pub total_memory: u64,
1392 pub available_memory: u64,
1394 pub kernels_launched: u64,
1396}
1397
1398pub struct MultiGpuBuilder {
1400 config: MultiGpuConfig,
1401}
1402
1403impl MultiGpuBuilder {
1404 pub fn new() -> Self {
1406 Self {
1407 config: MultiGpuConfig::default(),
1408 }
1409 }
1410
1411 pub fn load_balancing(mut self, strategy: LoadBalancingStrategy) -> Self {
1413 self.config.load_balancing = strategy;
1414 self
1415 }
1416
1417 pub fn auto_select_device(mut self, enable: bool) -> Self {
1419 self.config.auto_select_device = enable;
1420 self
1421 }
1422
1423 pub fn max_kernels_per_device(mut self, max: usize) -> Self {
1425 self.config.max_kernels_per_device = max;
1426 self
1427 }
1428
1429 pub fn enable_p2p(mut self, enable: bool) -> Self {
1431 self.config.enable_p2p = enable;
1432 self
1433 }
1434
1435 pub fn preferred_devices(mut self, devices: Vec<usize>) -> Self {
1437 self.config.preferred_devices = devices;
1438 self
1439 }
1440
1441 pub fn build(self) -> Arc<MultiGpuCoordinator> {
1443 MultiGpuCoordinator::new(self.config)
1444 }
1445}
1446
1447impl Default for MultiGpuBuilder {
1448 fn default() -> Self {
1449 Self::new()
1450 }
1451}
1452
1453pub struct CrossDeviceTransfer {
1455 pub source_device: usize,
1457 pub dest_device: usize,
1459 pub size: usize,
1461 pub use_p2p: bool,
1463}
1464
1465impl CrossDeviceTransfer {
1466 pub fn new(source: usize, dest: usize, size: usize) -> Self {
1468 Self {
1469 source_device: source,
1470 dest_device: dest,
1471 size,
1472 use_p2p: true,
1473 }
1474 }
1475
1476 pub fn without_p2p(mut self) -> Self {
1478 self.use_p2p = false;
1479 self
1480 }
1481}
1482
use crate::checkpoint::{CheckpointStorage, CheckpointableKernel, MemoryStorage};
1488
1489pub struct KernelMigrator {
1506 coordinator: Arc<MultiGpuCoordinator>,
1508 storage: Arc<dyn CheckpointStorage>,
1510 stats: MigrationStats,
1512}
1513
1514#[derive(Debug, Default)]
1516pub struct MigrationStats {
1517 pub successful_migrations: AtomicU64,
1519 pub failed_migrations: AtomicU64,
1521 pub bytes_transferred: AtomicU64,
1523 pub checkpoint_time_us: AtomicU64,
1525 pub restore_time_us: AtomicU64,
1527}
1528
1529#[derive(Debug, Clone)]
1531pub struct MigrationResult {
1532 pub kernel_id: KernelId,
1534 pub source_device: usize,
1536 pub target_device: usize,
1538 pub checkpoint_size: usize,
1540 pub checkpoint_duration: Duration,
1542 pub transfer_duration: Duration,
1544 pub restore_duration: Duration,
1546 pub total_duration: Duration,
1548}
1549
1550impl KernelMigrator {
1551 pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Self {
1553 Self {
1554 coordinator,
1555 storage: Arc::new(MemoryStorage::new()),
1556 stats: MigrationStats::default(),
1557 }
1558 }
1559
1560 pub fn with_storage(
1562 coordinator: Arc<MultiGpuCoordinator>,
1563 storage: Arc<dyn CheckpointStorage>,
1564 ) -> Self {
1565 Self {
1566 coordinator,
1567 storage,
1568 stats: MigrationStats::default(),
1569 }
1570 }
1571
1572 pub fn migrate_with_checkpoint<K: CheckpointableKernel>(
1581 &self,
1582 kernel: &K,
1583 request: &mut MigrationRequest,
1584 ) -> Result<MigrationResult> {
1585 let start_time = Instant::now();
1586 request.started_at = Some(start_time);
1587
1588 request.state = MigrationState::Quiescing;
1590 request.state = MigrationState::Checkpointing;
1595 let checkpoint_start = Instant::now();
1596 let checkpoint = kernel.create_checkpoint().map_err(|e| {
1597 self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
1598 request.state = MigrationState::Failed;
1599 RingKernelError::MigrationFailed(format!("Checkpoint creation failed: {}", e))
1600 })?;
1601 let checkpoint_duration = checkpoint_start.elapsed();
1602 let checkpoint_size = checkpoint.total_size();
1603
1604 self.stats
1605 .checkpoint_time_us
1606 .fetch_add(checkpoint_duration.as_micros() as u64, Ordering::Relaxed);
1607
1608 request.state = MigrationState::Transferring;
1610 let transfer_start = Instant::now();
1611
1612 let checkpoint_name = format!(
1614 "migration_{}_{}_{}",
1615 request.kernel_id.as_str(),
1616 request.source_device,
1617 request.target_device
1618 );
1619 self.storage
1620 .save(&checkpoint, &checkpoint_name)
1621 .map_err(|e| {
1622 self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
1623 request.state = MigrationState::Failed;
1624 RingKernelError::MigrationFailed(format!("Checkpoint transfer failed: {}", e))
1625 })?;
1626
1627 let transfer_duration = transfer_start.elapsed();
1628 self.stats
1629 .bytes_transferred
1630 .fetch_add(checkpoint_size as u64, Ordering::Relaxed);
1631
1632 request.state = MigrationState::Restoring;
1634 let restore_start = Instant::now();
1635
1636 let _restored = self.storage.load(&checkpoint_name).map_err(|e| {
1638 self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
1639 request.state = MigrationState::Failed;
1640 RingKernelError::MigrationFailed(format!("Checkpoint restore failed: {}", e))
1641 })?;
1642
1643 let restore_duration = restore_start.elapsed();
1644 self.stats
1645 .restore_time_us
1646 .fetch_add(restore_duration.as_micros() as u64, Ordering::Relaxed);
1647
1648 request.state = MigrationState::Completed;
1650 self.coordinator.complete_migration(request)?;
1651
1652 let _ = self.storage.delete(&checkpoint_name);
1654
1655 self.stats
1656 .successful_migrations
1657 .fetch_add(1, Ordering::Relaxed);
1658
1659 Ok(MigrationResult {
1660 kernel_id: request.kernel_id.clone(),
1661 source_device: request.source_device,
1662 target_device: request.target_device,
1663 checkpoint_size,
1664 checkpoint_duration,
1665 transfer_duration,
1666 restore_duration,
1667 total_duration: start_time.elapsed(),
1668 })
1669 }
1670
1671 pub fn coordinator(&self) -> &Arc<MultiGpuCoordinator> {
1673 &self.coordinator
1674 }
1675
1676 pub fn stats(&self) -> MigrationStatsSnapshot {
1678 let successful = self.stats.successful_migrations.load(Ordering::Relaxed);
1679 let failed = self.stats.failed_migrations.load(Ordering::Relaxed);
1680 let total = successful + failed;
1681 let checkpoint_us = self.stats.checkpoint_time_us.load(Ordering::Relaxed);
1682 let restore_us = self.stats.restore_time_us.load(Ordering::Relaxed);
1683
1684 MigrationStatsSnapshot {
1685 successful_migrations: successful,
1686 failed_migrations: failed,
1687 bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
1688 avg_checkpoint_time: checkpoint_us
1689 .checked_div(total)
1690 .map(Duration::from_micros)
1691 .unwrap_or(Duration::ZERO),
1692 avg_restore_time: restore_us
1693 .checked_div(total)
1694 .map(Duration::from_micros)
1695 .unwrap_or(Duration::ZERO),
1696 }
1697 }
1698}
1699
1700#[derive(Debug, Clone)]
1702pub struct MigrationStatsSnapshot {
1703 pub successful_migrations: u64,
1705 pub failed_migrations: u64,
1707 pub bytes_transferred: u64,
1709 pub avg_checkpoint_time: Duration,
1711 pub avg_restore_time: Duration,
1713}
1714
1715pub trait MigratableKernel: CheckpointableKernel {
1717 fn prepare_for_migration(&mut self) -> Result<()>;
1719
1720 fn cancel_migration(&mut self) -> Result<()>;
1722
1723 fn is_quiescent(&self) -> bool;
1725
1726 fn estimated_state_size(&self) -> usize;
1728}
1729
1730#[derive(Debug, Clone)]
1736pub struct HotReloadConfig {
1737 pub enabled: bool,
1739 pub reload_timeout: Duration,
1741 pub preserve_state: bool,
1743 pub max_retries: u32,
1745 pub retry_backoff: Duration,
1747 pub validate_before_swap: bool,
1749 pub keep_fallback: bool,
1751}
1752
1753impl Default for HotReloadConfig {
1754 fn default() -> Self {
1755 Self {
1756 enabled: true,
1757 reload_timeout: Duration::from_secs(30),
1758 preserve_state: true,
1759 max_retries: 3,
1760 retry_backoff: Duration::from_millis(500),
1761 validate_before_swap: true,
1762 keep_fallback: true,
1763 }
1764 }
1765}
1766
1767impl HotReloadConfig {
1768 pub fn new() -> Self {
1770 Self::default()
1771 }
1772
1773 pub fn with_enabled(mut self, enabled: bool) -> Self {
1775 self.enabled = enabled;
1776 self
1777 }
1778
1779 pub fn with_timeout(mut self, timeout: Duration) -> Self {
1781 self.reload_timeout = timeout;
1782 self
1783 }
1784
1785 pub fn with_preserve_state(mut self, preserve: bool) -> Self {
1787 self.preserve_state = preserve;
1788 self
1789 }
1790
1791 pub fn with_max_retries(mut self, retries: u32) -> Self {
1793 self.max_retries = retries;
1794 self
1795 }
1796}
1797
1798#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1800pub enum HotReloadState {
1801 Idle,
1803 Draining,
1805 Checkpointing,
1807 Compiling,
1809 Validating,
1811 Swapping,
1813 Restoring,
1815 Completed,
1817 Failed,
1819 RollingBack,
1821}
1822
1823#[derive(Debug, Clone, Copy, PartialEq, Eq)]
1825pub enum KernelCodeFormat {
1826 Ptx,
1828 Cubin,
1830 SpirV,
1832 Wgsl,
1834 Msl,
1836 MetalLib,
1838 Source,
1840}
1841
1842#[derive(Debug, Clone)]
1844pub struct KernelCodeSource {
1845 pub version_id: u64,
1847 pub format: KernelCodeFormat,
1849 pub code: Vec<u8>,
1851 pub entry_point: String,
1853 pub metadata: HashMap<String, String>,
1855 pub created_at: Instant,
1857 pub hash: [u8; 32],
1859}
1860
1861impl KernelCodeSource {
1862 pub fn new(format: KernelCodeFormat, code: Vec<u8>, entry_point: impl Into<String>) -> Self {
1864 let hash = Self::compute_hash(&code);
1865 Self {
1866 version_id: 0,
1867 format,
1868 code,
1869 entry_point: entry_point.into(),
1870 metadata: HashMap::new(),
1871 created_at: Instant::now(),
1872 hash,
1873 }
1874 }
1875
1876 pub fn from_ptx(ptx: &str, entry_point: impl Into<String>) -> Self {
1878 Self::new(KernelCodeFormat::Ptx, ptx.as_bytes().to_vec(), entry_point)
1879 }
1880
1881 pub fn from_wgsl(wgsl: &str, entry_point: impl Into<String>) -> Self {
1883 Self::new(
1884 KernelCodeFormat::Wgsl,
1885 wgsl.as_bytes().to_vec(),
1886 entry_point,
1887 )
1888 }
1889
1890 pub fn from_msl(msl: &str, entry_point: impl Into<String>) -> Self {
1892 Self::new(KernelCodeFormat::Msl, msl.as_bytes().to_vec(), entry_point)
1893 }
1894
1895 pub fn with_version(mut self, version: u64) -> Self {
1897 self.version_id = version;
1898 self
1899 }
1900
1901 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1903 self.metadata.insert(key.into(), value.into());
1904 self
1905 }
1906
1907 fn compute_hash(data: &[u8]) -> [u8; 32] {
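        // Not a cryptographic hash: a single 64-bit `DefaultHasher` digest is
        // stretched to 32 bytes by repeatedly feeding the first result back in.
        // Good enough for cheap change detection between code versions, nothing more.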
1908 use std::hash::{Hash, Hasher};
1909 let mut hasher = std::collections::hash_map::DefaultHasher::new();
1910 data.hash(&mut hasher);
1911 let h1 = hasher.finish();
1912 h1.hash(&mut hasher);
1913 let h2 = hasher.finish();
1914 h1.hash(&mut hasher);
1915 let h3 = hasher.finish();
1916 h1.hash(&mut hasher);
1917 let h4 = hasher.finish();
1918
1919 let mut hash = [0u8; 32];
1920 hash[0..8].copy_from_slice(&h1.to_le_bytes());
1921 hash[8..16].copy_from_slice(&h2.to_le_bytes());
1922 hash[16..24].copy_from_slice(&h3.to_le_bytes());
1923 hash[24..32].copy_from_slice(&h4.to_le_bytes());
1924 hash
1925 }
1926
1927 pub fn as_str(&self) -> Option<&str> {
1929 match self.format {
1930 KernelCodeFormat::Ptx
1931 | KernelCodeFormat::Wgsl
1932 | KernelCodeFormat::Msl
1933 | KernelCodeFormat::Source => std::str::from_utf8(&self.code).ok(),
1934 _ => None,
1935 }
1936 }
1937
1938 pub fn size(&self) -> usize {
1940 self.code.len()
1941 }
1942}
1943
1944#[derive(Debug)]
1946pub struct HotReloadRequest {
1947 pub kernel_id: KernelId,
1949 pub new_code: KernelCodeSource,
1951 pub state: HotReloadState,
1953 pub created_at: Instant,
1955 pub started_at: Option<Instant>,
1957 pub retry_count: u32,
1959 pub error: Option<String>,
1961 checkpoint_data: Option<Vec<u8>>,
1963}
1964
1965impl HotReloadRequest {
1966 pub fn new(kernel_id: KernelId, new_code: KernelCodeSource) -> Self {
1968 Self {
1969 kernel_id,
1970 new_code,
1971 state: HotReloadState::Idle,
1972 created_at: Instant::now(),
1973 started_at: None,
1974 retry_count: 0,
1975 error: None,
1976 checkpoint_data: None,
1977 }
1978 }
1979
1980 pub fn is_in_progress(&self) -> bool {
1982 !matches!(
1983 self.state,
1984 HotReloadState::Idle | HotReloadState::Completed | HotReloadState::Failed
1985 )
1986 }
1987
1988 pub fn is_completed(&self) -> bool {
1990 self.state == HotReloadState::Completed
1991 }
1992
1993 pub fn is_failed(&self) -> bool {
1995 self.state == HotReloadState::Failed
1996 }
1997
1998 pub fn elapsed(&self) -> Duration {
2000 self.created_at.elapsed()
2001 }
2002
2003 pub fn reload_elapsed(&self) -> Option<Duration> {
2005 self.started_at.map(|s| s.elapsed())
2006 }
2007}
2008
2009#[derive(Debug, Clone)]
2011pub struct HotReloadResult {
2012 pub kernel_id: KernelId,
2014 pub old_version: u64,
2016 pub new_version: u64,
2018 pub state_preserved: bool,
2020 pub checkpoint_size: usize,
2022 pub drain_duration: Duration,
2024 pub checkpoint_duration: Duration,
2026 pub compile_duration: Duration,
2028 pub swap_duration: Duration,
2030 pub restore_duration: Duration,
2032 pub total_duration: Duration,
2034}
2035
2036#[derive(Debug, Default)]
2038struct HotReloadStats {
2039 successful_reloads: AtomicU64,
2040 failed_reloads: AtomicU64,
2041 rollbacks: AtomicU64,
2042 total_drain_time_us: AtomicU64,
2043 total_compile_time_us: AtomicU64,
2044 total_swap_time_us: AtomicU64,
2045 state_preserved_count: AtomicU64,
2046}
2047
2048#[derive(Debug, Clone)]
2050pub struct HotReloadStatsSnapshot {
2051 pub successful_reloads: u64,
2053 pub failed_reloads: u64,
2055 pub rollbacks: u64,
2057 pub avg_drain_time: Duration,
2059 pub avg_compile_time: Duration,
2061 pub avg_swap_time: Duration,
2063 pub state_preserved_count: u64,
2065}
2066
2067pub struct HotReloadManager {
2097 config: HotReloadConfig,
2099 kernels: RwLock<HashMap<KernelId, KernelCodeSource>>,
2101 fallbacks: RwLock<HashMap<KernelId, KernelCodeSource>>,
2103 active_requests: RwLock<HashMap<KernelId, HotReloadRequest>>,
2105 version_counter: AtomicU64,
2107 stats: HotReloadStats,
2109}
2110
2111impl HotReloadManager {
2112 pub fn new(config: HotReloadConfig) -> Arc<Self> {
2114 Arc::new(Self {
2115 config,
2116 kernels: RwLock::new(HashMap::new()),
2117 fallbacks: RwLock::new(HashMap::new()),
2118 active_requests: RwLock::new(HashMap::new()),
2119 version_counter: AtomicU64::new(1),
2120 stats: HotReloadStats::default(),
2121 })
2122 }
2123
2124 pub fn with_defaults() -> Arc<Self> {
2126 Self::new(HotReloadConfig::default())
2127 }
2128
2129 pub fn is_enabled(&self) -> bool {
2131 self.config.enabled
2132 }
2133
2134 pub fn register_kernel(&self, kernel_id: &KernelId, code: KernelCodeSource) {
2136 let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
2137 let code = code.with_version(version);
2138 self.kernels.write().insert(kernel_id.clone(), code);
2139 }
2140
2141 pub fn unregister_kernel(&self, kernel_id: &KernelId) {
2143 self.kernels.write().remove(kernel_id);
2144 self.fallbacks.write().remove(kernel_id);
2145 self.active_requests.write().remove(kernel_id);
2146 }
2147
2148 pub fn get_current_version(&self, kernel_id: &KernelId) -> Option<u64> {
2150 self.kernels.read().get(kernel_id).map(|c| c.version_id)
2151 }
2152
2153 pub fn get_current_code(&self, kernel_id: &KernelId) -> Option<KernelCodeSource> {
2155 self.kernels.read().get(kernel_id).cloned()
2156 }
2157
2158 pub fn request_reload(
2160 &self,
2161 kernel_id: &KernelId,
2162 new_code: KernelCodeSource,
2163 ) -> Result<HotReloadRequest> {
2164 if !self.config.enabled {
2165 return Err(RingKernelError::ValidationError(
2166 "Hot reload is disabled".to_string(),
2167 ));
2168 }
2169
2170 if !self.kernels.read().contains_key(kernel_id) {
2172 return Err(RingKernelError::KernelNotFound(
2173 kernel_id.as_str().to_string(),
2174 ));
2175 }
2176
2177 {
2179 let active = self.active_requests.read();
2180 if let Some(existing) = active.get(kernel_id) {
2181 if existing.is_in_progress() {
2182 return Err(RingKernelError::ValidationError(
2183 "Hot reload already in progress for this kernel".to_string(),
2184 ));
2185 }
2186 }
2187 }
2188
2189 let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
2191 let new_code = new_code.with_version(version);
2192
2193 let request = HotReloadRequest::new(kernel_id.clone(), new_code);
2194 self.active_requests.write().insert(
2195 kernel_id.clone(),
2196 HotReloadRequest::new(kernel_id.clone(), request.new_code.clone()),
2197 );
2198
2199 Ok(request)
2200 }
2201
2202 pub fn execute_reload<K: CheckpointableKernel>(
2211 &self,
2212 request: &mut HotReloadRequest,
2213 kernel: &K,
2214 ) -> Result<HotReloadResult> {
2215 let start_time = Instant::now();
2216 request.started_at = Some(start_time);
2217
2218 let old_version = self
2220 .kernels
2221 .read()
2222 .get(&request.kernel_id)
2223 .map(|c| c.version_id)
2224 .unwrap_or(0);
2225
2226 request.state = HotReloadState::Draining;
2228 let drain_start = Instant::now();
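        // The brief sleep below stands in for actually draining in-flight work;
        // a real integration would wait for the kernel's queues to empty.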
2229 std::thread::sleep(Duration::from_micros(10));
2231 let drain_duration = drain_start.elapsed();
2232 self.stats
2233 .total_drain_time_us
2234 .fetch_add(drain_duration.as_micros() as u64, Ordering::Relaxed);
2235
2236 request.state = HotReloadState::Checkpointing;
2238 let checkpoint_start = Instant::now();
        let checkpoint_size = if self.config.preserve_state {
            let checkpoint = match kernel.create_checkpoint() {
                Ok(c) => c,
                Err(e) => {
                    self.stats.failed_reloads.fetch_add(1, Ordering::Relaxed);
                    request.state = HotReloadState::Failed;
                    return Err(e.into());
                }
            };
            let data = checkpoint.to_bytes();
            request.checkpoint_data = Some(data.clone());
            data.len()
        } else {
            0
        };
2247 let checkpoint_duration = checkpoint_start.elapsed();
2248
        request.state = HotReloadState::Validating;
        if self.config.validate_before_swap {
            if let Err(e) = self.validate_code(&request.new_code) {
                self.stats.failed_reloads.fetch_add(1, Ordering::Relaxed);
                request.state = HotReloadState::Failed;
                return Err(e);
            }
        }
2254
2255 request.state = HotReloadState::Compiling;
2257 let compile_start = Instant::now();
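        // Likewise, the sleep below is a placeholder for backend compilation of
        // `request.new_code`.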
2258 std::thread::sleep(Duration::from_micros(10));
2260 let compile_duration = compile_start.elapsed();
2261 self.stats
2262 .total_compile_time_us
2263 .fetch_add(compile_duration.as_micros() as u64, Ordering::Relaxed);
2264
2265 request.state = HotReloadState::Swapping;
2267 let swap_start = Instant::now();
2268
2269 if self.config.keep_fallback {
2271 if let Some(old_code) = self.kernels.read().get(&request.kernel_id).cloned() {
2272 self.fallbacks
2273 .write()
2274 .insert(request.kernel_id.clone(), old_code);
2275 }
2276 }
2277
2278 self.kernels
2280 .write()
2281 .insert(request.kernel_id.clone(), request.new_code.clone());
2282 let swap_duration = swap_start.elapsed();
2283 self.stats
2284 .total_swap_time_us
2285 .fetch_add(swap_duration.as_micros() as u64, Ordering::Relaxed);
2286
2287 request.state = HotReloadState::Restoring;
2289 let restore_start = Instant::now();
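        // No restore work happens here; applying `request.checkpoint_data` back to
        // the running kernel is left to the caller, so this duration is effectively zero.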
2290 let restore_duration = restore_start.elapsed();
2292
2293 request.state = HotReloadState::Completed;
2295 self.stats
2296 .successful_reloads
2297 .fetch_add(1, Ordering::Relaxed);
2298 if self.config.preserve_state && checkpoint_size > 0 {
2299 self.stats
2300 .state_preserved_count
2301 .fetch_add(1, Ordering::Relaxed);
2302 }
2303
2304 self.active_requests.write().remove(&request.kernel_id);
2306
2307 Ok(HotReloadResult {
2308 kernel_id: request.kernel_id.clone(),
2309 old_version,
2310 new_version: request.new_code.version_id,
2311 state_preserved: self.config.preserve_state && checkpoint_size > 0,
2312 checkpoint_size,
2313 drain_duration,
2314 checkpoint_duration,
2315 compile_duration,
2316 swap_duration,
2317 restore_duration,
2318 total_duration: start_time.elapsed(),
2319 })
2320 }
2321
2322 pub fn rollback(&self, kernel_id: &KernelId) -> Result<()> {
2324 let fallback =
2325 self.fallbacks.write().remove(kernel_id).ok_or_else(|| {
2326 RingKernelError::ValidationError("No fallback available".to_string())
2327 })?;
2328
2329 self.kernels.write().insert(kernel_id.clone(), fallback);
2330 self.stats.rollbacks.fetch_add(1, Ordering::Relaxed);
2331
2332 if let Some(request) = self.active_requests.write().get_mut(kernel_id) {
2334 request.state = HotReloadState::RollingBack;
2335 }
2336
2337 Ok(())
2338 }
2339
2340 fn validate_code(&self, code: &KernelCodeSource) -> Result<()> {
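        // Lightweight textual sanity checks only (non-empty code and entry point,
        // plus a format-specific marker); this is not a substitute for real
        // compilation or shader validation.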
2342 if code.code.is_empty() {
2344 return Err(RingKernelError::ValidationError(
2345 "Kernel code is empty".to_string(),
2346 ));
2347 }
2348
2349 if code.entry_point.is_empty() {
2350 return Err(RingKernelError::ValidationError(
2351 "Entry point is empty".to_string(),
2352 ));
2353 }
2354
2355 match code.format {
2357 KernelCodeFormat::Ptx => {
2358 if let Some(text) = code.as_str() {
2360 if !text.contains(".version") && !text.contains(".target") {
2361 return Err(RingKernelError::ValidationError(
2362 "PTX code missing version/target directive".to_string(),
2363 ));
2364 }
2365 }
2366 }
2367 KernelCodeFormat::Wgsl => {
2368 if let Some(text) = code.as_str() {
2370 if !text.contains("@compute") && !text.contains("fn ") {
2371 return Err(RingKernelError::ValidationError(
2372 "WGSL code missing compute shader or function".to_string(),
2373 ));
2374 }
2375 }
2376 }
2377 KernelCodeFormat::Msl => {
2378 if let Some(text) = code.as_str() {
2380 if !text.contains("kernel ") {
2381 return Err(RingKernelError::ValidationError(
2382 "MSL code missing kernel function".to_string(),
2383 ));
2384 }
2385 }
2386 }
2387 _ => {}
2388 }
2389
2390 Ok(())
2391 }
2392
2393 pub fn stats(&self) -> HotReloadStatsSnapshot {
2395 let successful = self.stats.successful_reloads.load(Ordering::Relaxed);
2396 let failed = self.stats.failed_reloads.load(Ordering::Relaxed);
2397 let total = successful.max(1);
2398
2399 HotReloadStatsSnapshot {
2400 successful_reloads: successful,
2401 failed_reloads: failed,
2402 rollbacks: self.stats.rollbacks.load(Ordering::Relaxed),
2403 avg_drain_time: Duration::from_micros(
2404 self.stats.total_drain_time_us.load(Ordering::Relaxed) / total,
2405 ),
2406 avg_compile_time: Duration::from_micros(
2407 self.stats.total_compile_time_us.load(Ordering::Relaxed) / total,
2408 ),
2409 avg_swap_time: Duration::from_micros(
2410 self.stats.total_swap_time_us.load(Ordering::Relaxed) / total,
2411 ),
2412 state_preserved_count: self.stats.state_preserved_count.load(Ordering::Relaxed),
2413 }
2414 }
2415
2416 pub fn list_kernels(&self) -> Vec<KernelId> {
2418 self.kernels.read().keys().cloned().collect()
2419 }
2420
2421 pub fn is_registered(&self, kernel_id: &KernelId) -> bool {
2423 self.kernels.read().contains_key(kernel_id)
2424 }
2425
2426 pub fn is_reload_in_progress(&self, kernel_id: &KernelId) -> bool {
2428 self.active_requests
2429 .read()
2430 .get(kernel_id)
2431 .map(|r| r.is_in_progress())
2432 .unwrap_or(false)
2433 }
2434
2435 pub fn config(&self) -> &HotReloadConfig {
2437 &self.config
2438 }
2439}
2440
2441pub trait HotReloadableKernel: CheckpointableKernel {
2443 fn prepare_for_reload(&mut self) -> Result<()>;
2445
2446 fn apply_code(&mut self, code: &KernelCodeSource) -> Result<()>;
2448
2449 fn resume_after_reload(&mut self) -> Result<()>;
2451
2452 fn is_ready_for_reload(&self) -> bool;
2454}
2455
2456#[cfg(test)]
2457mod tests {
2458 use super::*;
2459
2460 #[test]
2461 fn test_device_info() {
2462 let info = DeviceInfo::new(0, "Test GPU".to_string(), Backend::Cuda);
2463 assert_eq!(info.index, 0);
2464 assert_eq!(info.name, "Test GPU");
2465 assert_eq!(info.memory_utilization(), 0.0);
2466 }
2467
2468 #[test]
2469 fn test_coordinator_registration() {
2470 let coord = MultiGpuBuilder::new().build();
2471
2472 let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
2473 coord.register_device(device);
2474
2475 assert_eq!(coord.device_count(), 1);
2476 assert!(coord.device(0).is_some());
2477 }
2478
2479 #[test]
2480 fn test_kernel_assignment() {
2481 let coord = MultiGpuBuilder::new().build();
2482
2483 let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
2484 coord.register_device(device);
2485
2486 let kernel_id = KernelId::new("test_kernel");
2487 coord.assign_kernel(kernel_id.clone(), 0);
2488
2489 assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));
2490 assert_eq!(coord.kernels_on_device(0).len(), 1);
2491 }
2492
2493 #[test]
2494 fn test_load_balancing_least_loaded() {
2495 let coord = MultiGpuBuilder::new()
2496 .load_balancing(LoadBalancingStrategy::LeastLoaded)
2497 .build();
2498
2499 coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2501 coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
2502
2503 coord.assign_kernel(KernelId::new("k1"), 0);
2505
2506 let selected = coord.select_device(&LaunchOptions::default()).unwrap();
2508 assert_eq!(selected, 1);
2509 }
2510
2511 #[test]
2512 fn test_round_robin() {
2513 let coord = MultiGpuBuilder::new()
2514 .load_balancing(LoadBalancingStrategy::RoundRobin)
2515 .build();
2516
2517 coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
2518 coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
2519
2520 let d1 = coord.select_device(&LaunchOptions::default()).unwrap();
2521 let d2 = coord.select_device(&LaunchOptions::default()).unwrap();
2522 let d3 = coord.select_device(&LaunchOptions::default()).unwrap();
2523
2524 assert_ne!(d1, d2);
2526 assert_eq!(d1, d3);
2527 }
2528
2529 #[test]
2534 fn test_interconnect_bandwidth() {
2535 assert!(
2536 InterconnectType::NvLink.estimated_bandwidth_gbps()
2537 > InterconnectType::Pcie.estimated_bandwidth_gbps()
2538 );
2539 assert!(
2540 InterconnectType::Pcie.estimated_bandwidth_gbps()
2541 > InterconnectType::None.estimated_bandwidth_gbps()
2542 );
2543 assert!(
2544 InterconnectType::SameDevice.estimated_bandwidth_gbps()
2545 > InterconnectType::NvLink.estimated_bandwidth_gbps()
2546 );
2547 }
2548
2549 #[test]
2550 fn test_interconnect_p2p_support() {
2551 assert!(!InterconnectType::None.supports_p2p());
2552 assert!(InterconnectType::Pcie.supports_p2p());
2553 assert!(InterconnectType::NvLink.supports_p2p());
2554 assert!(InterconnectType::NvSwitch.supports_p2p());
2555 }
2556
2557 #[test]
2558 fn test_gpu_topology_creation() {
2559 let topo = GpuTopology::new(4);
2560 assert_eq!(topo.device_count, 4);
2561
2562 for i in 0..4 {
2564 let conn = topo.get_connection(i, i);
2565 assert!(conn.is_some());
2566 assert_eq!(conn.unwrap().interconnect, InterconnectType::SameDevice);
2567 }
2568 }
2569
2570 #[test]
2571 fn test_gpu_topology_set_connection() {
2572 let mut topo = GpuTopology::new(4);
2573
2574 topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
2576
2577 let conn_01 = topo.get_connection(0, 1);
2578 assert!(conn_01.is_some());
2579 assert_eq!(conn_01.unwrap().interconnect, InterconnectType::NvLink);
2580
2581 let conn_10 = topo.get_connection(1, 0);
2583 assert!(conn_10.is_some());
2584 assert_eq!(conn_10.unwrap().interconnect, InterconnectType::NvLink);
2585 }
2586
2587 #[test]
2588 fn test_gpu_topology_neighbors() {
2589 let mut topo = GpuTopology::new(4);
2590
2591 topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
2593 topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
2594 topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
2595 topo.set_connection(GpuConnection::new(3, 0, InterconnectType::NvLink));
2596
2597 let neighbors_0 = topo.neighbors(0);
2598 assert_eq!(neighbors_0.len(), 2);
2599 assert!(neighbors_0.contains(&1));
2600 assert!(neighbors_0.contains(&3));
2601 }
2602
    #[test]
    fn test_gpu_topology_best_path() {
        let mut topo = GpuTopology::new(4);

        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(0, 3, InterconnectType::None));

        let path_01 = topo.best_path(0, 1);
        assert_eq!(path_01, vec![0, 1]);

        let path_00 = topo.best_path(0, 0);
        assert_eq!(path_00, vec![0]);
    }

    #[test]
    fn test_gpu_topology_fully_connected() {
        let mut topo = GpuTopology::new(3);

        assert!(!topo.is_fully_connected());

        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(0, 2, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));

        assert!(topo.is_fully_connected());
    }

    #[test]
    fn test_gpu_topology_numa() {
        let mut topo = GpuTopology::new(4);

        topo.set_numa_node(0, 0);
        topo.set_numa_node(1, 0);
        topo.set_numa_node(2, 1);
        topo.set_numa_node(3, 1);

        let numa_neighbors_0 = topo.numa_neighbors(0);
        assert_eq!(numa_neighbors_0, vec![1]);

        let numa_neighbors_2 = topo.numa_neighbors(2);
        assert_eq!(numa_neighbors_2, vec![3]);
    }

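    // ------------------------------------------------------------------
    // Coordinator-driven topology discovery
    // ------------------------------------------------------------------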
    #[test]
    fn test_coordinator_topology_discovery() {
        let coord = MultiGpuBuilder::new().enable_p2p(true).build();

        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        dev0.p2p_capable = true;
        dev0.compute_capability = Some((8, 0));

        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
        dev1.p2p_capable = true;
        dev1.compute_capability = Some((8, 6));

        coord.register_device(dev0);
        coord.register_device(dev1);

        let topo = coord.discover_topology();

        assert_eq!(topo.device_count, 2);

        let conn = topo.get_connection(0, 1);
        assert!(conn.is_some());
        assert_eq!(conn.unwrap().interconnect, InterconnectType::NvLink);
    }

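    // ------------------------------------------------------------------
    // Kernel migration between devices
    // ------------------------------------------------------------------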
    #[test]
    fn test_migration_request() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("migrating_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let request = coord.request_migration(&kernel_id, 1).unwrap();

        assert_eq!(request.source_device, 0);
        assert_eq!(request.target_device, 1);
        assert_eq!(request.state, MigrationState::Pending);
    }

    #[test]
    fn test_migration_same_device_error() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let result = coord.request_migration(&kernel_id, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_migration_complete() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("migrating_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));

        let request = coord.request_migration(&kernel_id, 1).unwrap();
        coord.complete_migration(&request).unwrap();

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));
    }

    #[test]
    fn test_migration_transfer_time_estimate() {
        let request = MigrationRequest {
            kernel_id: KernelId::new("test"),
            source_device: 0,
            target_device: 1,
            path: vec![0, 1],
            estimated_bandwidth_gbps: 300.0,
            estimated_latency_us: 1.0,
            state: MigrationState::Pending,
            started_at: None,
        };

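        // 1 GB over a 300 GB/s link plus 1 µs of latency is roughly 3.3 ms (3333 µs).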
        let time = request.estimate_transfer_time(1_000_000_000);
        assert!(time.as_micros() > 3000);
        assert!(time.as_micros() < 4000);
    }

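    // ------------------------------------------------------------------
    // Cross-GPU K2K message routing
    // ------------------------------------------------------------------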
    use crate::hlc::HlcTimestamp;
    use crate::message::MessageEnvelope;

    fn make_test_k2k_message(source: &KernelId, dest: &KernelId) -> K2KMessage {
        let timestamp = HlcTimestamp::now(42);
        let envelope = MessageEnvelope::empty(1, 2, timestamp);
        K2KMessage::new(source.clone(), dest.clone(), envelope, timestamp)
    }

    #[test]
    fn test_router_same_device() {
        let coord = MultiGpuBuilder::new().build();
        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let router = CrossGpuK2KRouter::new(coord);

        let msg = make_test_k2k_message(&k1, &k2);
        let decision = router.route_message(&k1, &k2, msg).unwrap();

        assert!(matches!(decision, RoutingDecision::SameDevice));
    }

    #[test]
    fn test_router_cross_device() {
        let coord = MultiGpuBuilder::new().enable_p2p(true).build();

        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        dev0.p2p_capable = true;
        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
        dev1.p2p_capable = true;

        coord.register_device(dev0);
        coord.register_device(dev1);

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 1);

        let router = CrossGpuK2KRouter::new(coord);

        let msg = make_test_k2k_message(&k1, &k2);
        let decision = router.route_message(&k1, &k2, msg).unwrap();

        match decision {
            RoutingDecision::DirectP2P {
                source_device,
                dest_device,
                ..
            } => {
                assert_eq!(source_device, 0);
                assert_eq!(dest_device, 1);
            }
            _ => panic!("Expected DirectP2P routing"),
        }
    }

    #[test]
    fn test_router_pending_messages() {
        let coord = MultiGpuBuilder::new().enable_p2p(true).build();

        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        dev0.p2p_capable = true;
        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
        dev1.p2p_capable = true;

        coord.register_device(dev0);
        coord.register_device(dev1);

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 1);

        let router = CrossGpuK2KRouter::new(coord);

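        // Cross-device messages are queued as pending until drained for transfer.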
        for _ in 0..3 {
            let msg = make_test_k2k_message(&k1, &k2);
            router.route_message(&k1, &k2, msg).unwrap();
        }

        assert_eq!(router.stats().messages_pending, 3);

        let pending = router.drain_pending(0, 1);
        assert_eq!(pending.len(), 3);
        assert_eq!(router.stats().messages_pending, 0);
    }

    #[test]
    fn test_router_stats() {
        let coord = MultiGpuBuilder::new().build();
        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let router = CrossGpuK2KRouter::new(coord);

        let stats = router.stats();
        assert_eq!(stats.messages_routed, 0);
        assert_eq!(stats.bytes_transferred, 0);
        assert_eq!(stats.routing_failures, 0);
    }

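    // ------------------------------------------------------------------
    // Checkpoint-based kernel migration (KernelMigrator)
    // ------------------------------------------------------------------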
    use crate::checkpoint::{Checkpoint, CheckpointBuilder};

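    // Minimal in-test kernel implementing `CheckpointableKernel` so the migrator
    // can checkpoint, transfer, and restore its state.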
    struct MockCheckpointableKernel {
        kernel_id: String,
        kernel_type: String,
        state_data: Vec<u8>,
        step: u64,
    }

    impl MockCheckpointableKernel {
        fn new(kernel_id: &str, state_size: usize) -> Self {
            Self {
                kernel_id: kernel_id.to_string(),
                kernel_type: "mock_kernel".to_string(),
                state_data: vec![0xAB; state_size],
                step: 1000,
            }
        }
    }

    impl CheckpointableKernel for MockCheckpointableKernel {
        fn create_checkpoint(&self) -> Result<Checkpoint> {
            let checkpoint = CheckpointBuilder::new(&self.kernel_id, &self.kernel_type)
                .step(self.step)
                .grid_size(64, 64, 64)
                .control_block(vec![1, 2, 3, 4])
                .device_memory("state", self.state_data.clone())
                .build();
            Ok(checkpoint)
        }

        fn restore_from_checkpoint(&mut self, checkpoint: &Checkpoint) -> Result<()> {
            self.step = checkpoint.metadata.current_step;
            Ok(())
        }

        fn checkpoint_kernel_id(&self) -> &str {
            &self.kernel_id
        }

        fn checkpoint_kernel_type(&self) -> &str {
            &self.kernel_type
        }
    }

    #[test]
    fn test_migrator_creation() {
        let coord = MultiGpuBuilder::new().build();
        let migrator = KernelMigrator::new(coord);

        let stats = migrator.stats();
        assert_eq!(stats.successful_migrations, 0);
        assert_eq!(stats.failed_migrations, 0);
        assert_eq!(stats.bytes_transferred, 0);
    }

    #[test]
    fn test_migrator_with_custom_storage() {
        let coord = MultiGpuBuilder::new().build();
        let storage = Arc::new(MemoryStorage::new());
        let migrator = KernelMigrator::with_storage(coord.clone(), storage);

        assert!(Arc::ptr_eq(migrator.coordinator(), &coord));
    }

    #[test]
    fn test_successful_migration() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("migratable_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let migrator = KernelMigrator::new(coord.clone());

        let kernel = MockCheckpointableKernel::new("migratable_kernel", 1024);

        let mut request = coord.request_migration(&kernel_id, 1).unwrap();
        assert_eq!(request.state, MigrationState::Pending);

        let result = migrator
            .migrate_with_checkpoint(&kernel, &mut request)
            .unwrap();

        assert_eq!(result.kernel_id.as_str(), "migratable_kernel");
        assert_eq!(result.source_device, 0);
        assert_eq!(result.target_device, 1);
        assert!(result.checkpoint_size > 0);
        assert!(result.total_duration > Duration::ZERO);

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));

        let stats = migrator.stats();
        assert_eq!(stats.successful_migrations, 1);
        assert_eq!(stats.failed_migrations, 0);
        assert!(stats.bytes_transferred > 0);
    }

    #[test]
    fn test_migration_result_fields() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("test_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let migrator = KernelMigrator::new(coord.clone());
        let kernel = MockCheckpointableKernel::new("test_kernel", 4096);
        let mut request = coord.request_migration(&kernel_id, 1).unwrap();

        let result = migrator
            .migrate_with_checkpoint(&kernel, &mut request)
            .unwrap();

        assert!(result.checkpoint_duration >= Duration::ZERO);
        assert!(result.transfer_duration >= Duration::ZERO);
        assert!(result.restore_duration >= Duration::ZERO);

        assert!(result.total_duration >= result.checkpoint_duration);
    }

    #[test]
    fn test_migration_stats_accumulate() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let migrator = KernelMigrator::new(coord.clone());

        let k1 = KernelId::new("k1");
        coord.assign_kernel(k1.clone(), 0);
        let kernel1 = MockCheckpointableKernel::new("k1", 1000);
        let mut req1 = coord.request_migration(&k1, 1).unwrap();
        migrator
            .migrate_with_checkpoint(&kernel1, &mut req1)
            .unwrap();

        let k2 = KernelId::new("k2");
        coord.assign_kernel(k2.clone(), 0);
        let kernel2 = MockCheckpointableKernel::new("k2", 2000);
        let mut req2 = coord.request_migration(&k2, 1).unwrap();
        migrator
            .migrate_with_checkpoint(&kernel2, &mut req2)
            .unwrap();

        let stats = migrator.stats();
        assert_eq!(stats.successful_migrations, 2);
        assert_eq!(stats.failed_migrations, 0);
        assert!(stats.bytes_transferred > 0);
    }

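    // ------------------------------------------------------------------
    // Device unregistration and kernel redistribution
    // ------------------------------------------------------------------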
    #[test]
    fn test_unregister_device_no_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.device_index, 0);
        assert!(result.kernels_to_migrate.is_empty());
        assert!(result.orphaned_kernels.is_empty());
    }

    #[test]
    fn test_unregister_device_with_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.kernels_to_migrate.len(), 2);
        assert!(result.orphaned_kernels.is_empty());

        for plan in &result.kernels_to_migrate {
            assert_eq!(plan.source_device, 0);
            assert_eq!(plan.target_device, 1);
        }

        assert_eq!(coord.get_kernel_device(&k1), Some(1));
        assert_eq!(coord.get_kernel_device(&k2), Some(1));
    }

    #[test]
    fn test_unregister_single_device_orphans_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        coord.assign_kernel(k1.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert!(result.kernels_to_migrate.is_empty());
        assert_eq!(result.orphaned_kernels.len(), 1);
        assert_eq!(result.orphaned_kernels[0], k1);

        assert!(coord.get_kernel_device(&k1).is_none());
    }

    #[test]
    fn test_unregister_nonexistent_device() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let result = coord.unregister_device(99);

        assert!(!result.success);
        assert_eq!(result.device_index, 99);
    }

    #[test]
    fn test_unregister_distributes_to_least_loaded() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(2, "GPU 2".to_string(), Backend::Cuda));

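        // Pre-load device 1 so that device 2 is the least-loaded migration target.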
        coord.assign_kernel(KernelId::new("pre1"), 1);
        coord.assign_kernel(KernelId::new("pre2"), 1);
        coord.assign_kernel(KernelId::new("pre3"), 1);

        let k1 = KernelId::new("migrate_me");
        coord.assign_kernel(k1.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.kernels_to_migrate.len(), 1);

        let plan = &result.kernels_to_migrate[0];
        assert_eq!(plan.target_device, 2);
    }

    #[test]
    fn test_migration_priority_enum() {
        let low = MigrationPriority::Low;
        let normal = MigrationPriority::Normal;
        let high = MigrationPriority::High;
        let critical = MigrationPriority::Critical;

        assert_ne!(low, normal);
        assert_ne!(normal, high);
        assert_ne!(high, critical);
        assert_eq!(low, MigrationPriority::Low);
    }

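    // ------------------------------------------------------------------
    // Kernel hot-reload
    // ------------------------------------------------------------------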
    #[test]
    fn test_hot_reload_config_default() {
        let config = HotReloadConfig::default();
        assert!(config.enabled);
        assert!(config.preserve_state);
        assert!(config.validate_before_swap);
        assert!(config.keep_fallback);
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_hot_reload_config_builder() {
        let config = HotReloadConfig::new()
            .with_enabled(false)
            .with_preserve_state(false)
            .with_max_retries(5)
            .with_timeout(Duration::from_secs(60));

        assert!(!config.enabled);
        assert!(!config.preserve_state);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.reload_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_kernel_code_source_ptx() {
        let ptx = ".version 7.0\n.target sm_80\nkernel: ret;";
        let code = KernelCodeSource::from_ptx(ptx, "kernel");

        assert_eq!(code.format, KernelCodeFormat::Ptx);
        assert_eq!(code.entry_point, "kernel");
        assert_eq!(code.as_str(), Some(ptx));
        assert_eq!(code.size(), ptx.len());
    }

    #[test]
    fn test_kernel_code_source_wgsl() {
        let wgsl = "@compute fn main() {}";
        let code = KernelCodeSource::from_wgsl(wgsl, "main");

        assert_eq!(code.format, KernelCodeFormat::Wgsl);
        assert_eq!(code.entry_point, "main");
        assert_eq!(code.as_str(), Some(wgsl));
    }

    #[test]
    fn test_kernel_code_source_msl() {
        let msl = "kernel void my_kernel() {}";
        let code = KernelCodeSource::from_msl(msl, "my_kernel");

        assert_eq!(code.format, KernelCodeFormat::Msl);
        assert_eq!(code.entry_point, "my_kernel");
        assert_eq!(code.as_str(), Some(msl));
    }

    #[test]
    fn test_hot_reload_manager_creation() {
        let manager = HotReloadManager::with_defaults();
        assert!(manager.is_enabled());
        assert!(manager.list_kernels().is_empty());
    }

    #[test]
    fn test_hot_reload_manager_register_kernel() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        manager.register_kernel(&kernel_id, code);

        assert!(manager.is_registered(&kernel_id));
        assert!(!manager.is_reload_in_progress(&kernel_id));
        assert!(manager.get_current_version(&kernel_id).is_some());
    }

    #[test]
    fn test_hot_reload_request_states() {
        let kernel_id = KernelId::new("test");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");
        let request = HotReloadRequest::new(kernel_id, code);

        assert_eq!(request.state, HotReloadState::Idle);
        assert!(!request.is_in_progress());
        assert!(!request.is_completed());
        assert!(!request.is_failed());
    }

    #[test]
    fn test_hot_reload_disabled() {
        let config = HotReloadConfig::new().with_enabled(false);
        let manager = HotReloadManager::new(config);
        let kernel_id = KernelId::new("test");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        manager.register_kernel(&kernel_id, code.clone());
        let result = manager.request_reload(&kernel_id, code);
        assert!(result.is_err());
    }

    #[test]
    fn test_hot_reload_stats() {
        let manager = HotReloadManager::with_defaults();
        let stats = manager.stats();

        assert_eq!(stats.successful_reloads, 0);
        assert_eq!(stats.failed_reloads, 0);
        assert_eq!(stats.rollbacks, 0);
    }

    #[test]
    fn test_hot_reload_code_formats() {
        let formats = [
            KernelCodeFormat::Ptx,
            KernelCodeFormat::Cubin,
            KernelCodeFormat::SpirV,
            KernelCodeFormat::Wgsl,
            KernelCodeFormat::Msl,
            KernelCodeFormat::MetalLib,
            KernelCodeFormat::Source,
        ];

        for (i, f1) in formats.iter().enumerate() {
            for (j, f2) in formats.iter().enumerate() {
                if i != j {
                    assert_ne!(f1, f2);
                }
            }
        }
    }

    #[test]
    fn test_hot_reload_state_transitions() {
        let states = [
            HotReloadState::Idle,
            HotReloadState::Draining,
            HotReloadState::Checkpointing,
            HotReloadState::Compiling,
            HotReloadState::Validating,
            HotReloadState::Swapping,
            HotReloadState::Restoring,
            HotReloadState::Completed,
            HotReloadState::Failed,
            HotReloadState::RollingBack,
        ];

        for (i, s1) in states.iter().enumerate() {
            for (j, s2) in states.iter().enumerate() {
                if i != j {
                    assert_ne!(s1, s2);
                }
            }
        }
    }

    #[test]
    fn test_hot_reload_execute() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");

        let initial_code = KernelCodeSource::from_ptx(".version 7.0\n.target sm_80", "kernel");
        manager.register_kernel(&kernel_id, initial_code);

        let new_code = KernelCodeSource::from_ptx(".version 8.0\n.target sm_90", "kernel");
        let mut request = manager.request_reload(&kernel_id, new_code).unwrap();

        let mock_kernel = MockCheckpointableKernel::new("test_kernel", 512);

        let result = manager.execute_reload(&mut request, &mock_kernel).unwrap();

        assert!(request.is_completed());
        assert_eq!(result.kernel_id.as_str(), "test_kernel");
        assert!(result.state_preserved);
        assert!(result.checkpoint_size > 0);
        assert!(result.total_duration > Duration::ZERO);

        let stats = manager.stats();
        assert_eq!(stats.successful_reloads, 1);
    }

    #[test]
    fn test_hot_reload_list_kernels() {
        let manager = HotReloadManager::with_defaults();

        let k1 = KernelId::new("kernel1");
        let k2 = KernelId::new("kernel2");
        let k3 = KernelId::new("kernel3");

        manager.register_kernel(&k1, KernelCodeSource::from_ptx(".version 7.0", "k1"));
        manager.register_kernel(&k2, KernelCodeSource::from_ptx(".version 7.0", "k2"));
        manager.register_kernel(&k3, KernelCodeSource::from_ptx(".version 7.0", "k3"));

        let kernels = manager.list_kernels();
        assert_eq!(kernels.len(), 3);
        assert!(kernels.contains(&k1));
        assert!(kernels.contains(&k2));
        assert!(kernels.contains(&k3));
    }
}