//! Multi-GPU coordination: device selection and load balancing, GPU topology
//! discovery, cross-GPU kernel-to-kernel (K2K) routing, kernel migration, and
//! hot reload of kernel code.

use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use crate::error::{Result, RingKernelError};
use crate::k2k::K2KMessage;
use crate::runtime::{Backend, KernelId, LaunchOptions};

/// Configuration for multi-GPU coordination.
#[derive(Debug, Clone)]
pub struct MultiGpuConfig {
    /// Strategy used to pick a device for a new kernel.
    pub load_balancing: LoadBalancingStrategy,
    /// Automatically select a device when launching.
    pub auto_select_device: bool,
    /// Maximum number of kernels allowed per device.
    pub max_kernels_per_device: usize,
    /// Enable peer-to-peer (P2P) transfers between devices.
    pub enable_p2p: bool,
    /// Device indices to prefer; empty means all devices are candidates.
    pub preferred_devices: Vec<usize>,
}

impl Default for MultiGpuConfig {
    fn default() -> Self {
        Self {
            load_balancing: LoadBalancingStrategy::LeastLoaded,
            auto_select_device: true,
            max_kernels_per_device: 64,
            enable_p2p: true,
            preferred_devices: vec![],
        }
    }
}

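// A minimal sketch of overriding selected `MultiGpuConfig` fields while
// keeping the documented defaults for the rest; the field values here are
// illustrative, not recommendations.
#[cfg(test)]
mod config_example {
    use super::*;

    #[test]
    fn override_defaults() {
        let config = MultiGpuConfig {
            load_balancing: LoadBalancingStrategy::RoundRobin,
            max_kernels_per_device: 16,
            ..Default::default()
        };
        assert_eq!(config.load_balancing, LoadBalancingStrategy::RoundRobin);
        assert!(config.enable_p2p);
    }
}
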
/// Strategy for distributing kernels across devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoadBalancingStrategy {
    /// Use the first device that can accept a kernel.
    FirstAvailable,
    /// Pick the device currently running the fewest kernels.
    LeastLoaded,
    /// Rotate through available devices.
    RoundRobin,
    /// Pick the device with the most available memory.
    MemoryBased,
    /// Pick the device with the highest compute capability.
    ComputeCapability,
    /// Delegate to a user-provided selector (see `set_custom_selector`).
    Custom,
}

/// Static and dynamic information about a GPU device.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Device index within the coordinator.
    pub index: usize,
    /// Human-readable device name.
    pub name: String,
    /// Backend that drives this device.
    pub backend: Backend,
    /// Total device memory in bytes.
    pub total_memory: u64,
    /// Currently available device memory in bytes.
    pub available_memory: u64,
    /// CUDA compute capability (major, minor), if known.
    pub compute_capability: Option<(u32, u32)>,
    /// Maximum threads per block.
    pub max_threads_per_block: u32,
    /// Number of multiprocessors (SMs / compute units).
    pub multiprocessor_count: u32,
    /// Whether this device can participate in P2P transfers.
    pub p2p_capable: bool,
}

impl DeviceInfo {
    /// Creates a `DeviceInfo` with conservative placeholder values.
    pub fn new(index: usize, name: String, backend: Backend) -> Self {
        Self {
            index,
            name,
            backend,
            total_memory: 0,
            available_memory: 0,
            compute_capability: None,
            max_threads_per_block: 1024,
            multiprocessor_count: 1,
            p2p_capable: false,
        }
    }

    /// Fraction of device memory currently in use (0.0 when totals are unknown).
    pub fn memory_utilization(&self) -> f64 {
        if self.total_memory == 0 {
            0.0
        } else {
            1.0 - (self.available_memory as f64 / self.total_memory as f64)
        }
    }
}

/// Type of interconnect between two GPUs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InterconnectType {
    /// No direct connection; traffic must be staged through the host.
    None,
    /// PCI Express.
    Pcie,
    /// NVIDIA NVLink.
    NvLink,
    /// NVIDIA NVSwitch fabric.
    NvSwitch,
    /// AMD Infinity Fabric.
    InfinityFabric,
    /// Intel Xe Link.
    XeLink,
    /// Source and destination are the same device.
    SameDevice,
}

impl InterconnectType {
    /// Rough bandwidth estimate in GB/s for planning purposes.
    pub fn estimated_bandwidth_gbps(&self) -> f64 {
        match self {
            InterconnectType::None => 16.0,
            InterconnectType::Pcie => 32.0,
            InterconnectType::NvLink => 300.0,
            InterconnectType::NvSwitch => 600.0,
            InterconnectType::InfinityFabric => 200.0,
            InterconnectType::XeLink => 100.0,
            InterconnectType::SameDevice => 2000.0,
        }
    }

    /// Rough one-way latency estimate in microseconds.
    pub fn estimated_latency_us(&self) -> f64 {
        match self {
            InterconnectType::None => 10.0,
            InterconnectType::Pcie => 5.0,
            InterconnectType::NvLink => 1.0,
            InterconnectType::NvSwitch => 2.0,
            InterconnectType::InfinityFabric => 1.5,
            InterconnectType::XeLink => 2.0,
            InterconnectType::SameDevice => 0.0,
        }
    }

    /// Whether this interconnect supports direct peer-to-peer transfers.
    pub fn supports_p2p(&self) -> bool {
        !matches!(self, InterconnectType::None)
    }
}

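// A minimal sketch showing how the bandwidth estimates above can be compared
// when choosing a transfer route; the assertion restates the estimate table
// rather than measuring real hardware.
#[cfg(test)]
mod interconnect_example {
    use super::*;

    #[test]
    fn prefer_higher_bandwidth_links() {
        let links = [InterconnectType::Pcie, InterconnectType::NvLink];
        let best = links
            .iter()
            .max_by(|a, b| {
                a.estimated_bandwidth_gbps()
                    .partial_cmp(&b.estimated_bandwidth_gbps())
                    .unwrap()
            })
            .unwrap();
        assert_eq!(*best, InterconnectType::NvLink);
    }
}
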
/// A directed connection between two GPUs in the topology.
#[derive(Debug, Clone)]
pub struct GpuConnection {
    /// Source device index.
    pub source: usize,
    /// Destination device index.
    pub destination: usize,
    /// Interconnect type linking the two devices.
    pub interconnect: InterconnectType,
    /// Estimated bandwidth in GB/s.
    pub bandwidth_gbps: f64,
    /// Estimated latency in microseconds.
    pub latency_us: f64,
    /// Whether the link carries traffic in both directions.
    pub bidirectional: bool,
    /// Number of hops between the devices (0 for the same device).
    pub hops: u32,
}

impl GpuConnection {
    /// Creates a connection with bandwidth/latency defaults taken from the
    /// interconnect type.
    pub fn new(source: usize, destination: usize, interconnect: InterconnectType) -> Self {
        Self {
            source,
            destination,
            interconnect,
            bandwidth_gbps: interconnect.estimated_bandwidth_gbps(),
            latency_us: interconnect.estimated_latency_us(),
            bidirectional: true,
            hops: if source == destination { 0 } else { 1 },
        }
    }

    /// Overrides the estimated bandwidth.
    pub fn with_bandwidth(mut self, gbps: f64) -> Self {
        self.bandwidth_gbps = gbps;
        self
    }

    /// Overrides the estimated latency.
    pub fn with_latency(mut self, us: f64) -> Self {
        self.latency_us = us;
        self
    }

    /// Overrides the hop count.
    pub fn with_hops(mut self, hops: u32) -> Self {
        self.hops = hops;
        self
    }
}

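// A minimal sketch of overriding a link's default estimates with measured
// values; the numbers are illustrative (e.g. a measured PCIe link).
#[cfg(test)]
mod connection_example {
    use super::*;

    #[test]
    fn override_link_estimates() {
        let link = GpuConnection::new(0, 1, InterconnectType::Pcie)
            .with_bandwidth(25.0)
            .with_latency(6.5);
        assert_eq!(link.bandwidth_gbps, 25.0);
        assert_eq!(link.hops, 1);
    }
}
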
/// Describes the interconnect topology of a set of GPUs.
#[derive(Debug, Clone)]
pub struct GpuTopology {
    /// Number of devices in the topology.
    pub device_count: usize,
    /// Adjacency matrix of connections, indexed by [source][destination].
    connections: Vec<Vec<Option<GpuConnection>>>,
    /// NUMA node per device, if known.
    pub numa_nodes: Vec<Option<u32>>,
    /// Whether the topology was probed from hardware.
    pub probed: bool,
    /// When the topology was last updated.
    pub last_updated: Instant,
}

impl GpuTopology {
    /// Creates a topology with self-connections only.
    pub fn new(device_count: usize) -> Self {
        let mut connections = vec![vec![None; device_count]; device_count];

        for (i, row) in connections.iter_mut().enumerate().take(device_count) {
            row[i] = Some(GpuConnection::new(i, i, InterconnectType::SameDevice));
        }

        Self {
            device_count,
            connections,
            numa_nodes: vec![None; device_count],
            probed: false,
            last_updated: Instant::now(),
        }
    }

    /// Records a connection; bidirectional links also populate the reverse
    /// direction.
    pub fn set_connection(&mut self, connection: GpuConnection) {
        let src = connection.source;
        let dst = connection.destination;
        if src < self.device_count && dst < self.device_count {
            self.connections[src][dst] = Some(connection.clone());
            if connection.bidirectional && src != dst {
                let reverse = GpuConnection {
                    source: dst,
                    destination: src,
                    ..connection
                };
                self.connections[dst][src] = Some(reverse);
            }
        }
    }

    /// Looks up the connection between two devices, if any.
    pub fn get_connection(&self, source: usize, destination: usize) -> Option<&GpuConnection> {
        self.connections
            .get(source)
            .and_then(|row| row.get(destination))
            .and_then(|c| c.as_ref())
    }

    /// Finds the best path between two devices: a direct link when one
    /// exists, otherwise the single-intermediate path with the highest
    /// bottleneck bandwidth.
    pub fn best_path(&self, source: usize, destination: usize) -> Vec<usize> {
        if source == destination {
            return vec![source];
        }

        if let Some(conn) = self.get_connection(source, destination) {
            if conn.interconnect != InterconnectType::None {
                return vec![source, destination];
            }
        }

        let mut best_path = vec![source, destination];
        let mut best_bandwidth = 0.0;

        for intermediate in 0..self.device_count {
            if intermediate == source || intermediate == destination {
                continue;
            }

            if let (Some(c1), Some(c2)) = (
                self.get_connection(source, intermediate),
                self.get_connection(intermediate, destination),
            ) {
                let path_bandwidth = c1.bandwidth_gbps.min(c2.bandwidth_gbps);
                if path_bandwidth > best_bandwidth {
                    best_bandwidth = path_bandwidth;
                    best_path = vec![source, intermediate, destination];
                }
            }
        }

        best_path
    }

    /// Devices directly reachable from `device` over a P2P-capable link.
    pub fn neighbors(&self, device: usize) -> Vec<usize> {
        if device >= self.device_count {
            return vec![];
        }

        self.connections[device]
            .iter()
            .enumerate()
            .filter_map(|(i, conn)| {
                if i != device
                    && conn
                        .as_ref()
                        .map(|c| c.interconnect.supports_p2p())
                        .unwrap_or(false)
                {
                    Some(i)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Total bandwidth crossing a naive half/half bisection of the devices.
    pub fn bisection_bandwidth_gbps(&self) -> f64 {
        let half = self.device_count / 2;
        if half == 0 {
            return 0.0;
        }

        let mut total = 0.0;
        for src in 0..half {
            for dst in half..self.device_count {
                if let Some(conn) = self.get_connection(src, dst) {
                    total += conn.bandwidth_gbps;
                }
            }
        }
        total
    }

    /// True when every device pair has a P2P-capable connection.
    pub fn is_fully_connected(&self) -> bool {
        for src in 0..self.device_count {
            for dst in 0..self.device_count {
                if src != dst {
                    if let Some(conn) = self.get_connection(src, dst) {
                        if !conn.interconnect.supports_p2p() {
                            return false;
                        }
                    } else {
                        return false;
                    }
                }
            }
        }
        true
    }

    /// Other devices on the same NUMA node as `device`.
    pub fn numa_neighbors(&self, device: usize) -> Vec<usize> {
        let target_numa = self.numa_nodes.get(device).copied().flatten();
        if target_numa.is_none() {
            return vec![];
        }

        self.numa_nodes
            .iter()
            .enumerate()
            .filter_map(|(i, numa)| {
                if i != device && *numa == target_numa {
                    Some(i)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Assigns a NUMA node to a device.
    pub fn set_numa_node(&mut self, device: usize, numa_node: u32) {
        if device < self.numa_nodes.len() {
            self.numa_nodes[device] = Some(numa_node);
        }
    }

    /// Marks the topology as probed and refreshes its timestamp.
    pub fn mark_probed(&mut self) {
        self.probed = true;
        self.last_updated = Instant::now();
    }
}

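// A minimal sketch of building a topology by hand and routing around a
// missing direct link: here 0 and 2 are only reachable through 1, so
// `best_path` falls back to the single-intermediate route with the highest
// bottleneck bandwidth.
#[cfg(test)]
mod topology_example {
    use super::*;

    #[test]
    fn route_through_intermediate() {
        let mut topo = GpuTopology::new(3);
        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::Pcie));
        topo.set_connection(GpuConnection::new(0, 2, InterconnectType::None));

        assert_eq!(topo.best_path(0, 2), vec![0, 1, 2]);
    }
}
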
/// A point-in-time snapshot of a device and the kernels assigned to it.
#[derive(Debug, Clone)]
pub struct DeviceStatus {
    /// Device information.
    pub info: DeviceInfo,
    /// Number of kernels currently assigned to the device.
    pub kernel_count: usize,
    /// IDs of the kernels assigned to the device.
    pub kernels: Vec<KernelId>,
    /// Whether the device can accept more kernels.
    pub available: bool,
    /// Kernel count as a fraction of `max_kernels_per_device`.
    pub load: f64,
}

/// Outcome of unregistering a device, including what to do with its kernels.
#[derive(Debug, Clone)]
pub struct DeviceUnregisterResult {
    /// Index of the unregistered device.
    pub device_index: usize,
    /// Kernels that can be migrated to other devices.
    pub kernels_to_migrate: Vec<KernelMigrationPlan>,
    /// Kernels with no migration target; they are dropped from tracking.
    pub orphaned_kernels: Vec<KernelId>,
    /// Whether the device index was valid and unregistration proceeded.
    pub success: bool,
}

/// A planned migration of one kernel from a source to a target device.
#[derive(Debug, Clone)]
pub struct KernelMigrationPlan {
    /// Kernel to migrate.
    pub kernel_id: KernelId,
    /// Device the kernel currently runs on.
    pub source_device: usize,
    /// Device the kernel should move to.
    pub target_device: usize,
    /// Urgency of the migration.
    pub priority: MigrationPriority,
}

/// Urgency of a kernel migration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationPriority {
    Low,
    Normal,
    High,
    Critical,
}

/// Coordinates kernel placement, topology, and migrations across devices.
pub struct MultiGpuCoordinator {
    config: MultiGpuConfig,
    devices: RwLock<Vec<DeviceInfo>>,
    kernel_device_map: RwLock<HashMap<KernelId, usize>>,
    device_kernel_counts: RwLock<Vec<AtomicUsize>>,
    round_robin_counter: AtomicUsize,
    total_kernels: AtomicU64,
    #[allow(clippy::type_complexity)]
    custom_selector:
        RwLock<Option<Arc<dyn Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync>>>,
    topology: RwLock<Option<GpuTopology>>,
}

impl MultiGpuCoordinator {
    /// Creates a coordinator with the given configuration.
    pub fn new(config: MultiGpuConfig) -> Arc<Self> {
        Arc::new(Self {
            config,
            devices: RwLock::new(Vec::new()),
            kernel_device_map: RwLock::new(HashMap::new()),
            device_kernel_counts: RwLock::new(Vec::new()),
            round_robin_counter: AtomicUsize::new(0),
            total_kernels: AtomicU64::new(0),
            custom_selector: RwLock::new(None),
            topology: RwLock::new(None),
        })
    }

    /// Registers a device, padding intermediate slots with CPU placeholders
    /// so that `device.index` can be used directly.
    pub fn register_device(&self, device: DeviceInfo) {
        let index = device.index;
        let mut devices = self.devices.write();
        let mut counts = self.device_kernel_counts.write();

        let mut current_len = devices.len();
        while current_len <= index {
            devices.push(DeviceInfo::new(
                current_len,
                "Unknown".to_string(),
                Backend::Cpu,
            ));
            counts.push(AtomicUsize::new(0));
            current_len += 1;
        }

        devices[index] = device;
    }

    /// Unregisters a device and plans migrations for its kernels. Kernels
    /// without a viable target are reported as orphaned and dropped from
    /// tracking.
    pub fn unregister_device(&self, index: usize) -> DeviceUnregisterResult {
        let devices = self.devices.read();

        if index >= devices.len() {
            return DeviceUnregisterResult {
                device_index: index,
                kernels_to_migrate: Vec::new(),
                orphaned_kernels: Vec::new(),
                success: false,
            };
        }

        let kernels_on_device = self.kernels_on_device(index);

        let available_targets: Vec<usize> = devices
            .iter()
            .enumerate()
            .filter(|(i, _)| *i != index)
            .map(|(i, _)| i)
            .collect();

        drop(devices);

        let mut kernels_to_migrate = Vec::new();
        let mut orphaned_kernels = Vec::new();

        if available_targets.is_empty() {
            orphaned_kernels = kernels_on_device;
        } else {
            for kernel_id in kernels_on_device {
                if let Some(target) = self.select_migration_target(&available_targets) {
                    let priority = self.calculate_migration_priority(&kernel_id);
                    kernels_to_migrate.push(KernelMigrationPlan {
                        kernel_id,
                        source_device: index,
                        target_device: target,
                        priority,
                    });
                } else {
                    orphaned_kernels.push(kernel_id);
                }
            }
        }

        {
            let mut kernel_map = self.kernel_device_map.write();
            let counts = self.device_kernel_counts.read();

            for plan in &kernels_to_migrate {
                kernel_map.insert(plan.kernel_id.clone(), plan.target_device);

                if index < counts.len() {
                    counts[index].fetch_sub(1, Ordering::Relaxed);
                }
                if plan.target_device < counts.len() {
                    counts[plan.target_device].fetch_add(1, Ordering::Relaxed);
                }
            }

            for kernel_id in &orphaned_kernels {
                kernel_map.remove(kernel_id);
                if index < counts.len() {
                    counts[index].fetch_sub(1, Ordering::Relaxed);
                }
            }
        }

        {
            let mut devices = self.devices.write();
            if index < devices.len() {
                devices[index].available_memory = 0;
                devices[index].name = format!("{} (unregistered)", devices[index].name);
            }
        }

        DeviceUnregisterResult {
            device_index: index,
            kernels_to_migrate,
            orphaned_kernels,
            success: true,
        }
    }

    /// Picks the least-loaded candidate device as the migration target.
    fn select_migration_target(&self, candidates: &[usize]) -> Option<usize> {
        if candidates.is_empty() {
            return None;
        }

        let counts = self.device_kernel_counts.read();

        candidates
            .iter()
            .filter_map(|&idx| {
                if idx < counts.len() {
                    Some((idx, counts[idx].load(Ordering::Relaxed)))
                } else {
                    None
                }
            })
            .min_by_key(|(_, count)| *count)
            .map(|(idx, _)| idx)
    }

    /// Placeholder priority calculation; all kernels currently migrate at
    /// normal priority.
    fn calculate_migration_priority(&self, _kernel_id: &KernelId) -> MigrationPriority {
        MigrationPriority::Normal
    }

    /// Returns a snapshot of all registered devices.
    pub fn devices(&self) -> Vec<DeviceInfo> {
        self.devices.read().clone()
    }

    /// Returns a snapshot of one device, if registered.
    pub fn device(&self, index: usize) -> Option<DeviceInfo> {
        self.devices.read().get(index).cloned()
    }

    /// Number of registered devices.
    pub fn device_count(&self) -> usize {
        self.devices.read().len()
    }

    /// Selects a device for a new kernel according to the configured
    /// load-balancing strategy and preferred-device filter.
    pub fn select_device(&self, options: &LaunchOptions) -> Result<usize> {
        let devices = self.devices.read();
        if devices.is_empty() {
            return Err(RingKernelError::BackendUnavailable(
                "No GPU devices available".to_string(),
            ));
        }

        let status = self.get_all_status();

        if self.config.load_balancing == LoadBalancingStrategy::Custom {
            if let Some(selector) = &*self.custom_selector.read() {
                return Ok(selector(&status, options));
            }
        }

        let candidates: Vec<_> = if !self.config.preferred_devices.is_empty() {
            status
                .into_iter()
                .filter(|s| self.config.preferred_devices.contains(&s.info.index))
                .collect()
        } else {
            status
        };

        if candidates.is_empty() {
            return Err(RingKernelError::BackendUnavailable(
                "No suitable GPU device available".to_string(),
            ));
        }

        let selected = match self.config.load_balancing {
            LoadBalancingStrategy::FirstAvailable => {
                candidates.first().map(|s| s.info.index).unwrap_or(0)
            }
            LoadBalancingStrategy::LeastLoaded => candidates
                .iter()
                .filter(|s| s.available && s.kernel_count < self.config.max_kernels_per_device)
                .min_by(|a, b| a.kernel_count.cmp(&b.kernel_count))
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::RoundRobin => {
                let available: Vec<_> = candidates.iter().filter(|s| s.available).collect();

                if available.is_empty() {
                    candidates.first().map(|s| s.info.index).unwrap_or(0)
                } else {
                    let idx =
                        self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % available.len();
                    available[idx].info.index
                }
            }
            LoadBalancingStrategy::MemoryBased => candidates
                .iter()
                .filter(|s| s.available)
                .max_by(|a, b| a.info.available_memory.cmp(&b.info.available_memory))
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::ComputeCapability => candidates
                .iter()
                .filter(|s| s.available)
                .max_by(|a, b| {
                    let a_cap = a.info.compute_capability.unwrap_or((0, 0));
                    let b_cap = b.info.compute_capability.unwrap_or((0, 0));
                    a_cap.cmp(&b_cap)
                })
                .map(|s| s.info.index)
                .unwrap_or(0),
            LoadBalancingStrategy::Custom => {
                // No custom selector registered; fall back to device 0.
                0
            }
        };

        Ok(selected)
    }

    /// Records that a kernel was launched on a device.
    pub fn assign_kernel(&self, kernel_id: KernelId, device_index: usize) {
        self.kernel_device_map
            .write()
            .insert(kernel_id, device_index);

        let counts = self.device_kernel_counts.read();
        if device_index < counts.len() {
            counts[device_index].fetch_add(1, Ordering::Relaxed);
        }

        self.total_kernels.fetch_add(1, Ordering::Relaxed);
    }

    /// Removes a kernel from tracking and updates its device's count.
    pub fn remove_kernel(&self, kernel_id: &KernelId) {
        if let Some(device_index) = self.kernel_device_map.write().remove(kernel_id) {
            let counts = self.device_kernel_counts.read();
            if device_index < counts.len() {
                counts[device_index].fetch_sub(1, Ordering::Relaxed);
            }
        }
    }

    /// Returns the device a kernel is assigned to, if known.
    pub fn get_kernel_device(&self, kernel_id: &KernelId) -> Option<usize> {
        self.kernel_device_map.read().get(kernel_id).copied()
    }

    /// Lists the kernels currently assigned to a device.
    pub fn kernels_on_device(&self, device_index: usize) -> Vec<KernelId> {
        self.kernel_device_map
            .read()
            .iter()
            .filter(|(_, &idx)| idx == device_index)
            .map(|(k, _)| k.clone())
            .collect()
    }

    /// Builds a status snapshot for every registered device.
    pub fn get_all_status(&self) -> Vec<DeviceStatus> {
        let devices = self.devices.read();
        let kernel_map = self.kernel_device_map.read();
        let counts = self.device_kernel_counts.read();

        devices
            .iter()
            .enumerate()
            .map(|(idx, info)| {
                let kernel_count = if idx < counts.len() {
                    counts[idx].load(Ordering::Relaxed)
                } else {
                    0
                };

                let kernels: Vec<_> = kernel_map
                    .iter()
                    .filter(|(_, &dev_idx)| dev_idx == idx)
                    .map(|(k, _)| k.clone())
                    .collect();

                let load = kernel_count as f64 / self.config.max_kernels_per_device as f64;
                let available = kernel_count < self.config.max_kernels_per_device;

                DeviceStatus {
                    info: info.clone(),
                    kernel_count,
                    kernels,
                    available,
                    load,
                }
            })
            .collect()
    }

    /// Status snapshot for a single device, if registered.
    pub fn get_device_status(&self, device_index: usize) -> Option<DeviceStatus> {
        self.get_all_status().into_iter().nth(device_index)
    }

    /// Installs a custom device selector used by
    /// `LoadBalancingStrategy::Custom`.
    pub fn set_custom_selector<F>(&self, selector: F)
    where
        F: Fn(&[DeviceStatus], &LaunchOptions) -> usize + Send + Sync + 'static,
    {
        *self.custom_selector.write() = Some(Arc::new(selector));
    }

    /// Aggregates per-device status into coordinator-wide statistics.
    pub fn stats(&self) -> MultiGpuStats {
        let status = self.get_all_status();
        let total_kernels: usize = status.iter().map(|s| s.kernel_count).sum();
        let total_memory: u64 = status.iter().map(|s| s.info.total_memory).sum();
        let available_memory: u64 = status.iter().map(|s| s.info.available_memory).sum();

        MultiGpuStats {
            device_count: status.len(),
            total_kernels,
            total_memory,
            available_memory,
            kernels_launched: self.total_kernels.load(Ordering::Relaxed),
        }
    }

    /// Whether P2P transfers are possible between two devices.
    pub fn can_p2p(&self, device_a: usize, device_b: usize) -> bool {
        if !self.config.enable_p2p {
            return false;
        }

        let devices = self.devices.read();
        if let (Some(a), Some(b)) = (devices.get(device_a), devices.get(device_b)) {
            a.p2p_capable && b.p2p_capable
        } else {
            false
        }
    }

    /// Updates the available-memory figure for a device.
    pub fn update_device_memory(&self, device_index: usize, available_memory: u64) {
        let mut devices = self.devices.write();
        if let Some(device) = devices.get_mut(device_index) {
            device.available_memory = available_memory;
        }
    }

    /// Builds a heuristic topology from the registered devices: P2P-capable
    /// CUDA pairs with compute capability 8.0+ are assumed to be NVLink,
    /// other same-backend P2P pairs PCIe, and everything else unconnected.
    pub fn discover_topology(&self) -> GpuTopology {
        let devices = self.devices.read();
        let device_count = devices.len();

        if device_count == 0 {
            return GpuTopology::new(0);
        }

        let mut topo = GpuTopology::new(device_count);

        for (i, dev_i) in devices.iter().enumerate() {
            for (j, dev_j) in devices.iter().enumerate() {
                if i == j {
                    continue;
                }

                let interconnect = if dev_i.p2p_capable && dev_j.p2p_capable {
                    if dev_i.backend == dev_j.backend {
                        match dev_i.backend {
                            Backend::Cuda => {
                                let cc_i = dev_i.compute_capability.unwrap_or((0, 0));
                                let cc_j = dev_j.compute_capability.unwrap_or((0, 0));

                                if cc_i.0 >= 8 && cc_j.0 >= 8 {
                                    InterconnectType::NvLink
                                } else {
                                    InterconnectType::Pcie
                                }
                            }
                            _ => InterconnectType::Pcie,
                        }
                    } else {
                        InterconnectType::None
                    }
                } else {
                    InterconnectType::None
                };

                topo.set_connection(GpuConnection::new(i, j, interconnect));
            }
        }

        *self.topology.write() = Some(topo.clone());

        topo
    }

    /// Returns the cached topology, discovering it on first use.
    pub fn topology(&self) -> GpuTopology {
        {
            let topo = self.topology.read();
            if let Some(ref t) = *topo {
                return t.clone();
            }
        }
        self.discover_topology()
    }

    /// Replaces the cached topology (for example, with a probed one).
    pub fn set_topology(&self, topology: GpuTopology) {
        *self.topology.write() = Some(topology);
    }

    /// Picks a device for a kernel that communicates with `source_kernel`,
    /// scoring each P2P neighbor by link bandwidth divided by current load.
    pub fn select_device_for_k2k(&self, source_kernel: &KernelId) -> Result<usize> {
        let source_idx = match self.get_kernel_device(source_kernel) {
            Some(idx) => idx,
            None => return self.select_device(&LaunchOptions::default()),
        };
        let topo = self.topology();
        let status = self.get_all_status();

        let neighbors = topo.neighbors(source_idx);

        if neighbors.is_empty() {
            return self.select_device(&LaunchOptions::default());
        }

        let best = neighbors
            .iter()
            .filter_map(|&dev_idx| {
                status.iter().find(|s| s.info.index == dev_idx).map(|s| {
                    let conn = topo.get_connection(source_idx, dev_idx);
                    let bandwidth = conn.map(|c| c.bandwidth_gbps).unwrap_or(1.0);
                    let score = bandwidth / (s.load + 0.1);
                    (dev_idx, score)
                })
            })
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx);

        best.ok_or_else(|| {
            RingKernelError::BackendUnavailable("No suitable K2K device found".to_string())
        })
    }

    /// Validates a migration and builds a request with path and transfer
    /// estimates; the actual move is performed by `KernelMigrator`.
    pub fn request_migration(
        &self,
        kernel_id: &KernelId,
        target_device: usize,
    ) -> Result<MigrationRequest> {
        let source_device = self
            .get_kernel_device(kernel_id)
            .ok_or_else(|| RingKernelError::KernelNotFound(kernel_id.as_str().to_string()))?;

        if source_device == target_device {
            return Err(RingKernelError::InvalidConfig(
                "Cannot migrate to same device".to_string(),
            ));
        }

        let devices = self.devices.read();
        if target_device >= devices.len() {
            return Err(RingKernelError::DeviceNotAvailable(format!(
                "Device {} not available",
                target_device
            )));
        }

        let topo = self.topology();
        let path = topo.best_path(source_device, target_device);
        let connection = topo.get_connection(source_device, target_device);

        Ok(MigrationRequest {
            kernel_id: kernel_id.clone(),
            source_device,
            target_device,
            path,
            estimated_bandwidth_gbps: connection.map(|c| c.bandwidth_gbps).unwrap_or(16.0),
            estimated_latency_us: connection.map(|c| c.latency_us).unwrap_or(10.0),
            state: MigrationState::Pending,
            started_at: None,
        })
    }

    /// Commits a finished migration by updating the kernel-to-device map and
    /// per-device counts.
    pub fn complete_migration(&self, request: &MigrationRequest) -> Result<()> {
        {
            let mut map = self.kernel_device_map.write();
            if let Some(dev) = map.get_mut(&request.kernel_id) {
                *dev = request.target_device;
            }
        }

        {
            let counts = self.device_kernel_counts.read();
            if request.source_device < counts.len() {
                counts[request.source_device].fetch_sub(1, Ordering::Relaxed);
            }
            if request.target_device < counts.len() {
                counts[request.target_device].fetch_add(1, Ordering::Relaxed);
            }
        }

        Ok(())
    }
}

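// A minimal sketch of coordinator usage: register devices, install a custom
// selector, and select a device for a launch. The device data is fabricated
// for the example.
#[cfg(test)]
mod coordinator_example {
    use super::*;

    #[test]
    fn custom_selector_sketch() {
        let coord = MultiGpuCoordinator::new(MultiGpuConfig {
            load_balancing: LoadBalancingStrategy::Custom,
            ..Default::default()
        });
        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        // Always pick the device with the most available memory, falling
        // back to device 0 when the status list is empty.
        coord.set_custom_selector(|status, _options| {
            status
                .iter()
                .max_by_key(|s| s.info.available_memory)
                .map(|s| s.info.index)
                .unwrap_or(0)
        });

        let selected = coord.select_device(&LaunchOptions::default()).unwrap();
        assert!(selected < 2);
    }
}
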
/// States of a kernel migration's lifecycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationState {
    /// Migration requested but not started.
    Pending,
    /// Source kernel is being quiesced.
    Quiescing,
    /// Kernel state is being checkpointed.
    Checkpointing,
    /// Checkpoint is being transferred to the target device.
    Transferring,
    /// Kernel state is being restored on the target device.
    Restoring,
    /// Migration finished successfully.
    Completed,
    /// Migration failed.
    Failed,
    /// Migration was cancelled.
    Cancelled,
}

/// A validated request to migrate a kernel between devices.
#[derive(Debug, Clone)]
pub struct MigrationRequest {
    /// Kernel to migrate.
    pub kernel_id: KernelId,
    /// Device the kernel currently runs on.
    pub source_device: usize,
    /// Device the kernel is moving to.
    pub target_device: usize,
    /// Device path the transfer will follow.
    pub path: Vec<usize>,
    /// Estimated link bandwidth in GB/s.
    pub estimated_bandwidth_gbps: f64,
    /// Estimated link latency in microseconds.
    pub estimated_latency_us: f64,
    /// Current lifecycle state.
    pub state: MigrationState,
    /// When the migration actually started, if it has.
    pub started_at: Option<Instant>,
}

impl MigrationRequest {
    /// Estimates the transfer time for a state blob of the given size:
    /// size / bandwidth plus the link latency.
    pub fn estimate_transfer_time(&self, state_size_bytes: usize) -> Duration {
        let size_gb = state_size_bytes as f64 / 1_000_000_000.0;
        let transfer_time_s = size_gb / self.estimated_bandwidth_gbps;
        let total_us = (transfer_time_s * 1_000_000.0) + self.estimated_latency_us;
        Duration::from_micros(total_us as u64)
    }
}

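// A small worked example of the estimate above: 1 GB over a 300 GB/s NVLink
// link is ~3333 us of transfer plus 1 us of latency. The request is built by
// hand with illustrative values rather than via `request_migration`.
#[cfg(test)]
mod transfer_estimate_example {
    use super::*;

    #[test]
    fn one_gigabyte_over_nvlink() {
        let request = MigrationRequest {
            kernel_id: KernelId::new("example"),
            source_device: 0,
            target_device: 1,
            path: vec![0, 1],
            estimated_bandwidth_gbps: 300.0,
            estimated_latency_us: 1.0,
            state: MigrationState::Pending,
            started_at: None,
        };
        let t = request.estimate_transfer_time(1_000_000_000);
        // 1 GB / 300 GB/s = 3333.3 us; plus 1 us latency, truncated to 3334 us.
        assert_eq!(t, Duration::from_micros(3334));
    }
}
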
/// Routes K2K messages between kernels that live on different GPUs.
pub struct CrossGpuK2KRouter {
    coordinator: Arc<MultiGpuCoordinator>,
    /// Messages queued per (source_device, dest_device) pair.
    pending_queues: RwLock<HashMap<(usize, usize), Vec<PendingK2KMessage>>>,
    stats: CrossGpuRouterStats,
}

/// A K2K message waiting to be delivered across devices.
#[derive(Debug, Clone)]
pub struct PendingK2KMessage {
    /// Kernel that sent the message.
    pub source_kernel: KernelId,
    /// Kernel that should receive the message.
    pub dest_kernel: KernelId,
    /// The message payload.
    pub message: K2KMessage,
    /// When the message entered the queue.
    pub queued_at: Instant,
    /// Number of device hops the message will take.
    pub hops: u32,
}

#[derive(Debug, Default)]
pub struct CrossGpuRouterStats {
    messages_routed: AtomicU64,
    bytes_transferred: AtomicU64,
    messages_pending: AtomicUsize,
    total_latency_us: AtomicU64,
    routing_failures: AtomicU64,
}

impl CrossGpuK2KRouter {
    /// Creates a router backed by the given coordinator.
    pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Arc<Self> {
        Arc::new(Self {
            coordinator,
            pending_queues: RwLock::new(HashMap::new()),
            stats: CrossGpuRouterStats::default(),
        })
    }

    /// Decides how to route a message between two kernels: same-device
    /// delivery, direct P2P, multi-hop through intermediate devices, or
    /// host-mediated as a last resort.
    pub fn route_message(
        &self,
        source_kernel: &KernelId,
        dest_kernel: &KernelId,
        message: K2KMessage,
    ) -> Result<RoutingDecision> {
        let source_device = self
            .coordinator
            .get_kernel_device(source_kernel)
            .ok_or_else(|| {
                RingKernelError::K2KDestinationNotFound(source_kernel.as_str().to_string())
            })?;

        let dest_device = self
            .coordinator
            .get_kernel_device(dest_kernel)
            .ok_or_else(|| {
                RingKernelError::K2KDestinationNotFound(dest_kernel.as_str().to_string())
            })?;

        if source_device == dest_device {
            return Ok(RoutingDecision::SameDevice);
        }

        let topo = self.coordinator.topology();
        let path = topo.best_path(source_device, dest_device);

        if let Some(conn) = topo.get_connection(source_device, dest_device) {
            if conn.interconnect.supports_p2p() {
                let pending = PendingK2KMessage {
                    source_kernel: source_kernel.clone(),
                    dest_kernel: dest_kernel.clone(),
                    message,
                    queued_at: Instant::now(),
                    hops: 1,
                };

                self.enqueue_pending(source_device, dest_device, pending);
                self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

                return Ok(RoutingDecision::DirectP2P {
                    source_device,
                    dest_device,
                    bandwidth_gbps: conn.bandwidth_gbps,
                });
            }
        }

        if path.len() > 2 {
            let pending = PendingK2KMessage {
                source_kernel: source_kernel.clone(),
                dest_kernel: dest_kernel.clone(),
                message,
                queued_at: Instant::now(),
                hops: (path.len() - 1) as u32,
            };

            // Queue toward the first intermediate device on the path.
            self.enqueue_pending(source_device, path[1], pending);
            self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

            return Ok(RoutingDecision::MultiHop {
                path: path.clone(),
                total_hops: (path.len() - 1) as u32,
            });
        }

        // No P2P route: stage through the host (device -> host -> device).
        let pending = PendingK2KMessage {
            source_kernel: source_kernel.clone(),
            dest_kernel: dest_kernel.clone(),
            message,
            queued_at: Instant::now(),
            hops: 2,
        };

        self.enqueue_pending(source_device, dest_device, pending);
        self.stats.messages_pending.fetch_add(1, Ordering::Relaxed);

        Ok(RoutingDecision::HostMediated {
            source_device,
            dest_device,
        })
    }

    /// Removes and returns all messages queued for a device pair.
    pub fn drain_pending(&self, source: usize, dest: usize) -> Vec<PendingK2KMessage> {
        let mut queues = self.pending_queues.write();
        let messages = queues.remove(&(source, dest)).unwrap_or_default();
        self.stats
            .messages_pending
            .fetch_sub(messages.len(), Ordering::Relaxed);
        messages
    }

    /// Records a successful delivery for statistics.
    pub fn record_delivery(&self, message: &PendingK2KMessage, payload_size: usize) {
        self.stats.messages_routed.fetch_add(1, Ordering::Relaxed);
        self.stats
            .bytes_transferred
            .fetch_add(payload_size as u64, Ordering::Relaxed);

        let latency = message.queued_at.elapsed().as_micros() as u64;
        self.stats
            .total_latency_us
            .fetch_add(latency, Ordering::Relaxed);
    }

    /// Records a routing failure for statistics.
    pub fn record_failure(&self) {
        self.stats.routing_failures.fetch_add(1, Ordering::Relaxed);
    }

    /// Snapshot of the router's counters, with average latency derived from
    /// totals.
    pub fn stats(&self) -> CrossGpuRouterStatsSnapshot {
        let messages_routed = self.stats.messages_routed.load(Ordering::Relaxed);
        let total_latency = self.stats.total_latency_us.load(Ordering::Relaxed);

        CrossGpuRouterStatsSnapshot {
            messages_routed,
            bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
            messages_pending: self.stats.messages_pending.load(Ordering::Relaxed),
            avg_latency_us: if messages_routed > 0 {
                total_latency as f64 / messages_routed as f64
            } else {
                0.0
            },
            routing_failures: self.stats.routing_failures.load(Ordering::Relaxed),
        }
    }

    fn enqueue_pending(&self, source: usize, dest: usize, message: PendingK2KMessage) {
        let mut queues = self.pending_queues.write();
        queues.entry((source, dest)).or_default().push(message);
    }
}

/// Point-in-time snapshot of `CrossGpuRouterStats`.
#[derive(Debug, Clone)]
pub struct CrossGpuRouterStatsSnapshot {
    /// Total messages successfully routed.
    pub messages_routed: u64,
    /// Total payload bytes transferred.
    pub bytes_transferred: u64,
    /// Messages currently queued.
    pub messages_pending: usize,
    /// Average delivery latency in microseconds.
    pub avg_latency_us: f64,
    /// Total routing failures.
    pub routing_failures: u64,
}

/// How a cross-GPU K2K message will be delivered.
#[derive(Debug, Clone)]
pub enum RoutingDecision {
    /// Source and destination kernels share a device; deliver locally.
    SameDevice,
    /// Deliver over a direct P2P link.
    DirectP2P {
        source_device: usize,
        dest_device: usize,
        bandwidth_gbps: f64,
    },
    /// Forward through one or more intermediate devices.
    MultiHop {
        path: Vec<usize>,
        total_hops: u32,
    },
    /// Stage through host memory when no P2P route exists.
    HostMediated {
        source_device: usize,
        dest_device: usize,
    },
}

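// A minimal sketch of acting on a `RoutingDecision`; the handler below only
// formats a description, since actual delivery is backend-specific.
#[cfg(test)]
mod routing_decision_example {
    use super::*;

    fn describe(decision: &RoutingDecision) -> String {
        match decision {
            RoutingDecision::SameDevice => "local delivery".to_string(),
            RoutingDecision::DirectP2P { bandwidth_gbps, .. } => {
                format!("direct P2P at {bandwidth_gbps} GB/s")
            }
            RoutingDecision::MultiHop { total_hops, .. } => {
                format!("multi-hop over {total_hops} links")
            }
            RoutingDecision::HostMediated { .. } => "staged through host".to_string(),
        }
    }

    #[test]
    fn describe_decisions() {
        let decision = RoutingDecision::MultiHop {
            path: vec![0, 1, 2],
            total_hops: 2,
        };
        assert_eq!(describe(&decision), "multi-hop over 2 links");
    }
}
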
/// Aggregated statistics across all devices.
#[derive(Debug, Clone, Default)]
pub struct MultiGpuStats {
    /// Number of registered devices.
    pub device_count: usize,
    /// Kernels currently assigned across all devices.
    pub total_kernels: usize,
    /// Total memory across all devices, in bytes.
    pub total_memory: u64,
    /// Available memory across all devices, in bytes.
    pub available_memory: u64,
    /// Total kernels launched since the coordinator was created.
    pub kernels_launched: u64,
}

/// Builder for `MultiGpuCoordinator`.
pub struct MultiGpuBuilder {
    config: MultiGpuConfig,
}

impl MultiGpuBuilder {
    /// Starts from the default configuration.
    pub fn new() -> Self {
        Self {
            config: MultiGpuConfig::default(),
        }
    }

    /// Sets the load-balancing strategy.
    pub fn load_balancing(mut self, strategy: LoadBalancingStrategy) -> Self {
        self.config.load_balancing = strategy;
        self
    }

    /// Enables or disables automatic device selection.
    pub fn auto_select_device(mut self, enable: bool) -> Self {
        self.config.auto_select_device = enable;
        self
    }

    /// Sets the per-device kernel limit.
    pub fn max_kernels_per_device(mut self, max: usize) -> Self {
        self.config.max_kernels_per_device = max;
        self
    }

    /// Enables or disables P2P transfers.
    pub fn enable_p2p(mut self, enable: bool) -> Self {
        self.config.enable_p2p = enable;
        self
    }

    /// Restricts selection to the given device indices.
    pub fn preferred_devices(mut self, devices: Vec<usize>) -> Self {
        self.config.preferred_devices = devices;
        self
    }

    /// Builds the coordinator.
    pub fn build(self) -> Arc<MultiGpuCoordinator> {
        MultiGpuCoordinator::new(self.config)
    }
}

impl Default for MultiGpuBuilder {
    fn default() -> Self {
        Self::new()
    }
}

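// A minimal sketch of the builder: restrict placement to two preferred
// devices and cap each at 8 kernels. The indices and limits are illustrative.
#[cfg(test)]
mod builder_example {
    use super::*;

    #[test]
    fn build_with_preferences() {
        let coord = MultiGpuBuilder::new()
            .load_balancing(LoadBalancingStrategy::MemoryBased)
            .max_kernels_per_device(8)
            .preferred_devices(vec![0, 1])
            .build();
        // No devices have been registered yet.
        assert_eq!(coord.device_count(), 0);
    }
}
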
/// Parameters for a cross-device memory transfer.
pub struct CrossDeviceTransfer {
    /// Source device index.
    pub source_device: usize,
    /// Destination device index.
    pub dest_device: usize,
    /// Transfer size in bytes.
    pub size: usize,
    /// Whether to attempt a direct P2P copy.
    pub use_p2p: bool,
}

impl CrossDeviceTransfer {
    /// Creates a transfer that attempts P2P by default.
    pub fn new(source: usize, dest: usize, size: usize) -> Self {
        Self {
            source_device: source,
            dest_device: dest,
            size,
            use_p2p: true,
        }
    }

    /// Forces the transfer to stage through host memory.
    pub fn without_p2p(mut self) -> Self {
        self.use_p2p = false;
        self
    }
}

use crate::checkpoint::{CheckpointStorage, CheckpointableKernel, MemoryStorage};

/// Performs checkpoint/restore-based kernel migrations between devices.
pub struct KernelMigrator {
    coordinator: Arc<MultiGpuCoordinator>,
    /// Storage used to stage checkpoints during a migration.
    storage: Arc<dyn CheckpointStorage>,
    stats: MigrationStats,
}

/// Counters describing migration activity.
#[derive(Debug, Default)]
pub struct MigrationStats {
    /// Migrations that completed successfully.
    pub successful_migrations: AtomicU64,
    /// Migrations that failed at any stage.
    pub failed_migrations: AtomicU64,
    /// Total checkpoint bytes transferred.
    pub bytes_transferred: AtomicU64,
    /// Cumulative time spent checkpointing, in microseconds.
    pub checkpoint_time_us: AtomicU64,
    /// Cumulative time spent restoring, in microseconds.
    pub restore_time_us: AtomicU64,
}

/// Timing and size details for a completed migration.
#[derive(Debug, Clone)]
pub struct MigrationResult {
    /// Kernel that was migrated.
    pub kernel_id: KernelId,
    /// Device the kernel moved from.
    pub source_device: usize,
    /// Device the kernel moved to.
    pub target_device: usize,
    /// Size of the checkpoint in bytes.
    pub checkpoint_size: usize,
    /// Time spent creating the checkpoint.
    pub checkpoint_duration: Duration,
    /// Time spent transferring the checkpoint.
    pub transfer_duration: Duration,
    /// Time spent restoring on the target device.
    pub restore_duration: Duration,
    /// Total wall-clock time for the migration.
    pub total_duration: Duration,
}

impl KernelMigrator {
    /// Creates a migrator that stages checkpoints in memory.
    pub fn new(coordinator: Arc<MultiGpuCoordinator>) -> Self {
        Self {
            coordinator,
            storage: Arc::new(MemoryStorage::new()),
            stats: MigrationStats::default(),
        }
    }

    /// Creates a migrator with custom checkpoint storage.
    pub fn with_storage(
        coordinator: Arc<MultiGpuCoordinator>,
        storage: Arc<dyn CheckpointStorage>,
    ) -> Self {
        Self {
            coordinator,
            storage,
            stats: MigrationStats::default(),
        }
    }

    /// Migrates a kernel by checkpointing it, staging the checkpoint through
    /// storage, restoring it, and committing the move on the coordinator.
    pub fn migrate_with_checkpoint<K: CheckpointableKernel>(
        &self,
        kernel: &K,
        request: &mut MigrationRequest,
    ) -> Result<MigrationResult> {
        let start_time = Instant::now();
        request.started_at = Some(start_time);

        request.state = MigrationState::Quiescing;

        request.state = MigrationState::Checkpointing;
        let checkpoint_start = Instant::now();
        let checkpoint = kernel.create_checkpoint().map_err(|e| {
            self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
            request.state = MigrationState::Failed;
            RingKernelError::MigrationFailed(format!("Checkpoint creation failed: {}", e))
        })?;
        let checkpoint_duration = checkpoint_start.elapsed();
        let checkpoint_size = checkpoint.total_size();

        self.stats
            .checkpoint_time_us
            .fetch_add(checkpoint_duration.as_micros() as u64, Ordering::Relaxed);

        request.state = MigrationState::Transferring;
        let transfer_start = Instant::now();

        let checkpoint_name = format!(
            "migration_{}_{}_{}",
            request.kernel_id.as_str(),
            request.source_device,
            request.target_device
        );
        self.storage
            .save(&checkpoint, &checkpoint_name)
            .map_err(|e| {
                self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
                request.state = MigrationState::Failed;
                RingKernelError::MigrationFailed(format!("Checkpoint transfer failed: {}", e))
            })?;

        let transfer_duration = transfer_start.elapsed();
        self.stats
            .bytes_transferred
            .fetch_add(checkpoint_size as u64, Ordering::Relaxed);

        request.state = MigrationState::Restoring;
        let restore_start = Instant::now();

        let _restored = self.storage.load(&checkpoint_name).map_err(|e| {
            self.stats.failed_migrations.fetch_add(1, Ordering::Relaxed);
            request.state = MigrationState::Failed;
            RingKernelError::MigrationFailed(format!("Checkpoint restore failed: {}", e))
        })?;

        let restore_duration = restore_start.elapsed();
        self.stats
            .restore_time_us
            .fetch_add(restore_duration.as_micros() as u64, Ordering::Relaxed);

        request.state = MigrationState::Completed;
        self.coordinator.complete_migration(request)?;

        // Best-effort cleanup of the staged checkpoint.
        let _ = self.storage.delete(&checkpoint_name);

        self.stats
            .successful_migrations
            .fetch_add(1, Ordering::Relaxed);

        Ok(MigrationResult {
            kernel_id: request.kernel_id.clone(),
            source_device: request.source_device,
            target_device: request.target_device,
            checkpoint_size,
            checkpoint_duration,
            transfer_duration,
            restore_duration,
            total_duration: start_time.elapsed(),
        })
    }

    /// Returns the coordinator backing this migrator.
    pub fn coordinator(&self) -> &Arc<MultiGpuCoordinator> {
        &self.coordinator
    }

    /// Snapshot of migration statistics, with per-migration averages.
    pub fn stats(&self) -> MigrationStatsSnapshot {
        let successful = self.stats.successful_migrations.load(Ordering::Relaxed);
        let failed = self.stats.failed_migrations.load(Ordering::Relaxed);
        let total = successful + failed;
        let checkpoint_us = self.stats.checkpoint_time_us.load(Ordering::Relaxed);
        let restore_us = self.stats.restore_time_us.load(Ordering::Relaxed);

        MigrationStatsSnapshot {
            successful_migrations: successful,
            failed_migrations: failed,
            bytes_transferred: self.stats.bytes_transferred.load(Ordering::Relaxed),
            avg_checkpoint_time: checkpoint_us
                .checked_div(total)
                .map(Duration::from_micros)
                .unwrap_or(Duration::ZERO),
            avg_restore_time: restore_us
                .checked_div(total)
                .map(Duration::from_micros)
                .unwrap_or(Duration::ZERO),
        }
    }
}

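// A minimal end-to-end sketch of the migration bookkeeping path, skipping
// the checkpoint/restore step (which needs a live `CheckpointableKernel`):
// request a migration, then commit it on the coordinator directly.
#[cfg(test)]
mod migration_flow_example {
    use super::*;

    #[test]
    fn request_then_commit() {
        let coord = MultiGpuBuilder::new().build();
        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("worker");
        coord.assign_kernel(kernel_id.clone(), 0);

        let request = coord.request_migration(&kernel_id, 1).unwrap();
        coord.complete_migration(&request).unwrap();

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));
    }
}
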
/// Point-in-time snapshot of `MigrationStats`.
#[derive(Debug, Clone)]
pub struct MigrationStatsSnapshot {
    /// Migrations that completed successfully.
    pub successful_migrations: u64,
    /// Migrations that failed.
    pub failed_migrations: u64,
    /// Total checkpoint bytes transferred.
    pub bytes_transferred: u64,
    /// Average checkpoint time per migration.
    pub avg_checkpoint_time: Duration,
    /// Average restore time per migration.
    pub avg_restore_time: Duration,
}

/// Kernels that can cooperate with live migration.
pub trait MigratableKernel: CheckpointableKernel {
    /// Quiesces the kernel so a consistent checkpoint can be taken.
    fn prepare_for_migration(&mut self) -> Result<()>;

    /// Resumes normal operation after a cancelled migration.
    fn cancel_migration(&mut self) -> Result<()>;

    /// Whether the kernel has stopped processing and is safe to checkpoint.
    fn is_quiescent(&self) -> bool;

    /// Estimated size of the kernel's migratable state in bytes.
    fn estimated_state_size(&self) -> usize;
}

/// Configuration for hot reloading kernel code.
#[derive(Debug, Clone)]
pub struct HotReloadConfig {
    /// Master switch for hot reload.
    pub enabled: bool,
    /// Maximum time allowed for a reload before it is abandoned.
    pub reload_timeout: Duration,
    /// Checkpoint and restore kernel state across the reload.
    pub preserve_state: bool,
    /// Maximum reload retry attempts.
    pub max_retries: u32,
    /// Delay between retries.
    pub retry_backoff: Duration,
    /// Validate new code before swapping it in.
    pub validate_before_swap: bool,
    /// Keep the previous code as a rollback fallback.
    pub keep_fallback: bool,
    /// Maximum number of rule versions kept for rollback.
    pub max_rule_history: usize,
}

impl Default for HotReloadConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            reload_timeout: Duration::from_secs(30),
            preserve_state: true,
            max_retries: 3,
            retry_backoff: Duration::from_millis(500),
            validate_before_swap: true,
            keep_fallback: true,
            max_rule_history: 5,
        }
    }
}

impl HotReloadConfig {
    /// Alias for `Default::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables or disables hot reload.
    pub fn with_enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }

    /// Sets the reload timeout.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.reload_timeout = timeout;
        self
    }

    /// Enables or disables state preservation across reloads.
    pub fn with_preserve_state(mut self, preserve: bool) -> Self {
        self.preserve_state = preserve;
        self
    }

    /// Sets the maximum retry count.
    pub fn with_max_retries(mut self, retries: u32) -> Self {
        self.max_retries = retries;
        self
    }
}

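// A minimal sketch of customizing the reload policy via the builder-style
// setters; the values are illustrative.
#[cfg(test)]
mod hot_reload_config_example {
    use super::*;

    #[test]
    fn tighten_reload_policy() {
        let config = HotReloadConfig::new()
            .with_timeout(Duration::from_secs(5))
            .with_max_retries(1)
            .with_preserve_state(false);
        assert_eq!(config.max_retries, 1);
        assert!(!config.preserve_state);
    }
}
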
/// States of a hot-reload operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HotReloadState {
    /// No reload in progress.
    Idle,
    /// Draining in-flight work from the kernel.
    Draining,
    /// Checkpointing kernel state.
    Checkpointing,
    /// Compiling the new code.
    Compiling,
    /// Validating the new code before the swap.
    Validating,
    /// Swapping the new code in.
    Swapping,
    /// Restoring preserved state.
    Restoring,
    /// Reload finished successfully.
    Completed,
    /// Reload failed.
    Failed,
    /// Rolling back to the previous code.
    RollingBack,
}

/// Format of a kernel code artifact.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KernelCodeFormat {
    /// NVIDIA PTX assembly.
    Ptx,
    /// NVIDIA CUBIN binary.
    Cubin,
    /// SPIR-V binary.
    SpirV,
    /// WebGPU Shading Language source.
    Wgsl,
    /// Metal Shading Language source.
    Msl,
    /// Compiled Metal library.
    MetalLib,
    /// Generic source text.
    Source,
}

/// A versioned piece of kernel code that can be hot reloaded.
#[derive(Debug, Clone)]
pub struct KernelCodeSource {
    /// Monotonically increasing version assigned by the manager.
    pub version_id: u64,
    /// Format of the code payload.
    pub format: KernelCodeFormat,
    /// The code payload itself.
    pub code: Vec<u8>,
    /// Entry point symbol within the code.
    pub entry_point: String,
    /// Free-form metadata.
    pub metadata: HashMap<String, String>,
    /// When this source was created.
    pub created_at: Instant,
    /// Content hash of the code payload.
    pub hash: [u8; 32],
}

impl KernelCodeSource {
    /// Creates a code source and hashes its payload.
    pub fn new(format: KernelCodeFormat, code: Vec<u8>, entry_point: impl Into<String>) -> Self {
        let hash = Self::compute_hash(&code);
        Self {
            version_id: 0,
            format,
            code,
            entry_point: entry_point.into(),
            metadata: HashMap::new(),
            created_at: Instant::now(),
            hash,
        }
    }

    /// Convenience constructor for PTX text.
    pub fn from_ptx(ptx: &str, entry_point: impl Into<String>) -> Self {
        Self::new(KernelCodeFormat::Ptx, ptx.as_bytes().to_vec(), entry_point)
    }

    /// Convenience constructor for WGSL text.
    pub fn from_wgsl(wgsl: &str, entry_point: impl Into<String>) -> Self {
        Self::new(
            KernelCodeFormat::Wgsl,
            wgsl.as_bytes().to_vec(),
            entry_point,
        )
    }

    /// Convenience constructor for MSL text.
    pub fn from_msl(msl: &str, entry_point: impl Into<String>) -> Self {
        Self::new(KernelCodeFormat::Msl, msl.as_bytes().to_vec(), entry_point)
    }

    /// Sets the version id.
    pub fn with_version(mut self, version: u64) -> Self {
        self.version_id = version;
        self
    }

    /// Adds a metadata entry.
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Derives a 32-byte content identifier by chaining `DefaultHasher`
    /// outputs. This is a change-detection fingerprint, not a cryptographic
    /// hash.
    fn compute_hash(data: &[u8]) -> [u8; 32] {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        data.hash(&mut hasher);
        let h1 = hasher.finish();
        h1.hash(&mut hasher);
        let h2 = hasher.finish();
        h1.hash(&mut hasher);
        let h3 = hasher.finish();
        h1.hash(&mut hasher);
        let h4 = hasher.finish();

        let mut hash = [0u8; 32];
        hash[0..8].copy_from_slice(&h1.to_le_bytes());
        hash[8..16].copy_from_slice(&h2.to_le_bytes());
        hash[16..24].copy_from_slice(&h3.to_le_bytes());
        hash[24..32].copy_from_slice(&h4.to_le_bytes());
        hash
    }

    /// Returns the code as UTF-8 text for text-based formats.
    pub fn as_str(&self) -> Option<&str> {
        match self.format {
            KernelCodeFormat::Ptx
            | KernelCodeFormat::Wgsl
            | KernelCodeFormat::Msl
            | KernelCodeFormat::Source => std::str::from_utf8(&self.code).ok(),
            _ => None,
        }
    }

    /// Size of the code payload in bytes.
    pub fn size(&self) -> usize {
        self.code.len()
    }
}

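// A minimal sketch of wrapping WGSL text as a versioned code source; the
// shader body is a trivial placeholder.
#[cfg(test)]
mod code_source_example {
    use super::*;

    #[test]
    fn wrap_wgsl_source() {
        let wgsl = "@compute @workgroup_size(64) fn main() {}";
        let code = KernelCodeSource::from_wgsl(wgsl, "main").with_metadata("author", "example");
        assert_eq!(code.format, KernelCodeFormat::Wgsl);
        assert_eq!(code.as_str(), Some(wgsl));
        assert_eq!(code.size(), wgsl.len());
    }
}
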
/// Tracks one in-flight hot-reload operation.
#[derive(Debug)]
pub struct HotReloadRequest {
    /// Kernel being reloaded.
    pub kernel_id: KernelId,
    /// The replacement code.
    pub new_code: KernelCodeSource,
    /// Current lifecycle state.
    pub state: HotReloadState,
    /// When the request was created.
    pub created_at: Instant,
    /// When execution of the reload began, if it has.
    pub started_at: Option<Instant>,
    /// Number of retries so far.
    pub retry_count: u32,
    /// Last error message, if any.
    pub error: Option<String>,
    /// Serialized checkpoint captured during the reload, if state is
    /// preserved.
    checkpoint_data: Option<Vec<u8>>,
}

impl HotReloadRequest {
    /// Creates an idle request for the given kernel and code.
    pub fn new(kernel_id: KernelId, new_code: KernelCodeSource) -> Self {
        Self {
            kernel_id,
            new_code,
            state: HotReloadState::Idle,
            created_at: Instant::now(),
            started_at: None,
            retry_count: 0,
            error: None,
            checkpoint_data: None,
        }
    }

    /// Whether the reload is currently executing.
    pub fn is_in_progress(&self) -> bool {
        !matches!(
            self.state,
            HotReloadState::Idle | HotReloadState::Completed | HotReloadState::Failed
        )
    }

    /// Whether the reload finished successfully.
    pub fn is_completed(&self) -> bool {
        self.state == HotReloadState::Completed
    }

    /// Whether the reload failed.
    pub fn is_failed(&self) -> bool {
        self.state == HotReloadState::Failed
    }

    /// Time since the request was created.
    pub fn elapsed(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// Time since the reload started executing, if it has.
    pub fn reload_elapsed(&self) -> Option<Duration> {
        self.started_at.map(|s| s.elapsed())
    }
}

/// Timing and version details for a completed hot reload.
#[derive(Debug, Clone)]
pub struct HotReloadResult {
    /// Kernel that was reloaded.
    pub kernel_id: KernelId,
    /// Version id of the replaced code.
    pub old_version: u64,
    /// Version id of the new code.
    pub new_version: u64,
    /// Whether kernel state was checkpointed and restored.
    pub state_preserved: bool,
    /// Size of the checkpoint in bytes (0 when state was not preserved).
    pub checkpoint_size: usize,
    /// Time spent draining in-flight work.
    pub drain_duration: Duration,
    /// Time spent checkpointing.
    pub checkpoint_duration: Duration,
    /// Time spent compiling the new code.
    pub compile_duration: Duration,
    /// Time spent swapping the code in.
    pub swap_duration: Duration,
    /// Time spent restoring preserved state.
    pub restore_duration: Duration,
    /// Total wall-clock time for the reload.
    pub total_duration: Duration,
}

#[derive(Debug, Default)]
struct HotReloadStats {
    successful_reloads: AtomicU64,
    failed_reloads: AtomicU64,
    rollbacks: AtomicU64,
    total_drain_time_us: AtomicU64,
    total_compile_time_us: AtomicU64,
    total_swap_time_us: AtomicU64,
    state_preserved_count: AtomicU64,
}

/// Point-in-time snapshot of hot-reload statistics.
#[derive(Debug, Clone)]
pub struct HotReloadStatsSnapshot {
    /// Reloads that completed successfully.
    pub successful_reloads: u64,
    /// Reloads that failed.
    pub failed_reloads: u64,
    /// Rollbacks to fallback code.
    pub rollbacks: u64,
    /// Average drain time per successful reload.
    pub avg_drain_time: Duration,
    /// Average compile time per successful reload.
    pub avg_compile_time: Duration,
    /// Average swap time per successful reload.
    pub avg_swap_time: Duration,
    /// Reloads that preserved kernel state.
    pub state_preserved_count: u64,
}

/// Manages kernel code versions and orchestrates hot reloads.
pub struct HotReloadManager {
    config: HotReloadConfig,
    /// Current code per kernel.
    kernels: RwLock<HashMap<KernelId, KernelCodeSource>>,
    /// Previous code kept for rollback.
    fallbacks: RwLock<HashMap<KernelId, KernelCodeSource>>,
    /// In-flight reload requests per kernel.
    active_requests: RwLock<HashMap<KernelId, HotReloadRequest>>,
    version_counter: AtomicU64,
    stats: HotReloadStats,
    /// Registry for hot-swappable rules.
    rule_registry: Arc<crate::rules::RuleRegistry>,
}

impl HotReloadManager {
    /// Creates a manager with a no-op rule swap backend.
    pub fn new(config: HotReloadConfig) -> Arc<Self> {
        Self::with_rule_backend(config, Arc::new(crate::rules::NoopSwapBackend))
    }

    /// Creates a manager with a specific rule swap backend.
    pub fn with_rule_backend(
        config: HotReloadConfig,
        rule_backend: Arc<dyn crate::rules::RuleSwapBackend>,
    ) -> Arc<Self> {
        let rule_registry = Arc::new(crate::rules::RuleRegistry::new(
            config.max_rule_history,
            rule_backend,
        ));
        Arc::new(Self {
            config,
            kernels: RwLock::new(HashMap::new()),
            fallbacks: RwLock::new(HashMap::new()),
            active_requests: RwLock::new(HashMap::new()),
            version_counter: AtomicU64::new(1),
            stats: HotReloadStats::default(),
            rule_registry,
        })
    }

    /// Returns the rule registry.
    pub fn rule_registry(&self) -> &Arc<crate::rules::RuleRegistry> {
        &self.rule_registry
    }

    /// Creates a manager with the default configuration.
    pub fn with_defaults() -> Arc<Self> {
        Self::new(HotReloadConfig::default())
    }

    /// Whether hot reload is enabled.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Registers a kernel's current code, assigning it a fresh version.
    pub fn register_kernel(&self, kernel_id: &KernelId, code: KernelCodeSource) {
        let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
        let code = code.with_version(version);
        self.kernels.write().insert(kernel_id.clone(), code);
    }

    /// Removes a kernel and any fallback or in-flight request for it.
    pub fn unregister_kernel(&self, kernel_id: &KernelId) {
        self.kernels.write().remove(kernel_id);
        self.fallbacks.write().remove(kernel_id);
        self.active_requests.write().remove(kernel_id);
    }

    /// Current code version for a kernel, if registered.
    pub fn get_current_version(&self, kernel_id: &KernelId) -> Option<u64> {
        self.kernels.read().get(kernel_id).map(|c| c.version_id)
    }

    /// Current code for a kernel, if registered.
    pub fn get_current_code(&self, kernel_id: &KernelId) -> Option<KernelCodeSource> {
        self.kernels.read().get(kernel_id).cloned()
    }

    /// Validates and registers a reload request, assigning the new code a
    /// fresh version. Fails if reload is disabled, the kernel is unknown, or
    /// a reload is already in progress.
    pub fn request_reload(
        &self,
        kernel_id: &KernelId,
        new_code: KernelCodeSource,
    ) -> Result<HotReloadRequest> {
        if !self.config.enabled {
            return Err(RingKernelError::ValidationError(
                "Hot reload is disabled".to_string(),
            ));
        }

        if !self.kernels.read().contains_key(kernel_id) {
            return Err(RingKernelError::KernelNotFound(
                kernel_id.as_str().to_string(),
            ));
        }

        {
            let active = self.active_requests.read();
            if let Some(existing) = active.get(kernel_id) {
                if existing.is_in_progress() {
                    return Err(RingKernelError::ValidationError(
                        "Hot reload already in progress for this kernel".to_string(),
                    ));
                }
            }
        }

        let version = self.version_counter.fetch_add(1, Ordering::Relaxed);
        let new_code = new_code.with_version(version);

        let request = HotReloadRequest::new(kernel_id.clone(), new_code);
        self.active_requests.write().insert(
            kernel_id.clone(),
            HotReloadRequest::new(kernel_id.clone(), request.new_code.clone()),
        );

        Ok(request)
    }

    /// Executes a reload against a live kernel: drain, checkpoint (optional),
    /// validate, compile, swap, restore. The drain and compile phases are
    /// currently simulated with short sleeps.
    pub fn execute_reload<K: CheckpointableKernel>(
        &self,
        request: &mut HotReloadRequest,
        kernel: &K,
    ) -> Result<HotReloadResult> {
        let start_time = Instant::now();
        request.started_at = Some(start_time);

        let old_version = self
            .kernels
            .read()
            .get(&request.kernel_id)
            .map(|c| c.version_id)
            .unwrap_or(0);

        request.state = HotReloadState::Draining;
        let drain_start = Instant::now();
        std::thread::sleep(Duration::from_micros(10));
        let drain_duration = drain_start.elapsed();
        self.stats
            .total_drain_time_us
            .fetch_add(drain_duration.as_micros() as u64, Ordering::Relaxed);

        request.state = HotReloadState::Checkpointing;
        let checkpoint_start = Instant::now();
        let checkpoint_size = if self.config.preserve_state {
            let checkpoint = kernel.create_checkpoint()?;
            let data = checkpoint.to_bytes();
            request.checkpoint_data = Some(data.clone());
            data.len()
        } else {
            0
        };
        let checkpoint_duration = checkpoint_start.elapsed();

        request.state = HotReloadState::Validating;
        if self.config.validate_before_swap {
            self.validate_code(&request.new_code)?;
        }

        request.state = HotReloadState::Compiling;
        let compile_start = Instant::now();
        std::thread::sleep(Duration::from_micros(10));
        let compile_duration = compile_start.elapsed();
        self.stats
            .total_compile_time_us
            .fetch_add(compile_duration.as_micros() as u64, Ordering::Relaxed);

        request.state = HotReloadState::Swapping;
        let swap_start = Instant::now();

        if self.config.keep_fallback {
            if let Some(old_code) = self.kernels.read().get(&request.kernel_id).cloned() {
                self.fallbacks
                    .write()
                    .insert(request.kernel_id.clone(), old_code);
            }
        }

        self.kernels
            .write()
            .insert(request.kernel_id.clone(), request.new_code.clone());
        let swap_duration = swap_start.elapsed();
        self.stats
            .total_swap_time_us
            .fetch_add(swap_duration.as_micros() as u64, Ordering::Relaxed);

        request.state = HotReloadState::Restoring;
        let restore_start = Instant::now();
        let restore_duration = restore_start.elapsed();

        request.state = HotReloadState::Completed;
        self.stats
            .successful_reloads
            .fetch_add(1, Ordering::Relaxed);
        if self.config.preserve_state && checkpoint_size > 0 {
            self.stats
                .state_preserved_count
                .fetch_add(1, Ordering::Relaxed);
        }

        self.active_requests.write().remove(&request.kernel_id);

        Ok(HotReloadResult {
            kernel_id: request.kernel_id.clone(),
            old_version,
            new_version: request.new_code.version_id,
            state_preserved: self.config.preserve_state && checkpoint_size > 0,
            checkpoint_size,
            drain_duration,
            checkpoint_duration,
            compile_duration,
            swap_duration,
            restore_duration,
            total_duration: start_time.elapsed(),
        })
    }

    /// Restores the fallback code for a kernel, if one was kept.
    pub fn rollback(&self, kernel_id: &KernelId) -> Result<()> {
        let fallback =
            self.fallbacks.write().remove(kernel_id).ok_or_else(|| {
                RingKernelError::ValidationError("No fallback available".to_string())
            })?;

        self.kernels.write().insert(kernel_id.clone(), fallback);
        self.stats.rollbacks.fetch_add(1, Ordering::Relaxed);

        if let Some(request) = self.active_requests.write().get_mut(kernel_id) {
            request.state = HotReloadState::RollingBack;
        }

        Ok(())
    }

    /// Cheap structural validation of new code before a swap; this is a
    /// sanity check, not a full compile.
    fn validate_code(&self, code: &KernelCodeSource) -> Result<()> {
        if code.code.is_empty() {
            return Err(RingKernelError::ValidationError(
                "Kernel code is empty".to_string(),
            ));
        }

        if code.entry_point.is_empty() {
            return Err(RingKernelError::ValidationError(
                "Entry point is empty".to_string(),
            ));
        }

        match code.format {
            KernelCodeFormat::Ptx => {
                if let Some(text) = code.as_str() {
                    if !text.contains(".version") && !text.contains(".target") {
                        return Err(RingKernelError::ValidationError(
                            "PTX code missing version/target directive".to_string(),
                        ));
                    }
                }
            }
            KernelCodeFormat::Wgsl => {
                if let Some(text) = code.as_str() {
                    if !text.contains("@compute") && !text.contains("fn ") {
                        return Err(RingKernelError::ValidationError(
                            "WGSL code missing compute shader or function".to_string(),
                        ));
                    }
                }
            }
            KernelCodeFormat::Msl => {
                if let Some(text) = code.as_str() {
                    if !text.contains("kernel ") {
                        return Err(RingKernelError::ValidationError(
                            "MSL code missing kernel function".to_string(),
                        ));
                    }
                }
            }
            _ => {}
        }

        Ok(())
    }

    /// Snapshot of reload statistics; averages are computed over successful
    /// reloads (a divisor of 1 is used when there are none).
    pub fn stats(&self) -> HotReloadStatsSnapshot {
        let successful = self.stats.successful_reloads.load(Ordering::Relaxed);
        let failed = self.stats.failed_reloads.load(Ordering::Relaxed);
        let total = successful.max(1);

        HotReloadStatsSnapshot {
            successful_reloads: successful,
            failed_reloads: failed,
            rollbacks: self.stats.rollbacks.load(Ordering::Relaxed),
            avg_drain_time: Duration::from_micros(
                self.stats.total_drain_time_us.load(Ordering::Relaxed) / total,
            ),
            avg_compile_time: Duration::from_micros(
                self.stats.total_compile_time_us.load(Ordering::Relaxed) / total,
            ),
            avg_swap_time: Duration::from_micros(
                self.stats.total_swap_time_us.load(Ordering::Relaxed) / total,
            ),
            state_preserved_count: self.stats.state_preserved_count.load(Ordering::Relaxed),
        }
    }

    /// Lists all registered kernels.
    pub fn list_kernels(&self) -> Vec<KernelId> {
        self.kernels.read().keys().cloned().collect()
    }

    /// Whether a kernel is registered.
    pub fn is_registered(&self, kernel_id: &KernelId) -> bool {
        self.kernels.read().contains_key(kernel_id)
    }

    /// Whether a reload is currently executing for a kernel.
    pub fn is_reload_in_progress(&self, kernel_id: &KernelId) -> bool {
        self.active_requests
            .read()
            .get(kernel_id)
            .map(|r| r.is_in_progress())
            .unwrap_or(false)
    }

    /// Returns the manager's configuration.
    pub fn config(&self) -> &HotReloadConfig {
        &self.config
    }
}

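// A minimal sketch of the registration/request flow (stopping short of
// `execute_reload`, which needs a live `CheckpointableKernel`).
#[cfg(test)]
mod hot_reload_example {
    use super::*;

    #[test]
    fn register_and_request() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("shader");

        let v1 = KernelCodeSource::from_wgsl("@compute fn main() {}", "main");
        manager.register_kernel(&kernel_id, v1);
        let old_version = manager.get_current_version(&kernel_id).unwrap();

        let v2 = KernelCodeSource::from_wgsl("@compute fn main() { /* new */ }", "main");
        let request = manager.request_reload(&kernel_id, v2).unwrap();

        assert!(request.new_code.version_id > old_version);
        assert_eq!(request.state, HotReloadState::Idle);
    }
}
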
/// Kernels whose code can be swapped while they remain resident.
pub trait HotReloadableKernel: CheckpointableKernel {
    /// Quiesces the kernel ahead of a code swap.
    fn prepare_for_reload(&mut self) -> Result<()>;

    /// Applies the new code to the kernel.
    fn apply_code(&mut self, code: &KernelCodeSource) -> Result<()>;

    /// Resumes execution after a successful swap.
    fn resume_after_reload(&mut self) -> Result<()>;

    /// Whether the kernel is in a state where a reload can safely begin.
    fn is_ready_for_reload(&self) -> bool;
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_info() {
        let info = DeviceInfo::new(0, "Test GPU".to_string(), Backend::Cuda);
        assert_eq!(info.index, 0);
        assert_eq!(info.name, "Test GPU");
        assert_eq!(info.memory_utilization(), 0.0);
    }

    #[test]
    fn test_coordinator_registration() {
        let coord = MultiGpuBuilder::new().build();

        let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        coord.register_device(device);

        assert_eq!(coord.device_count(), 1);
        assert!(coord.device(0).is_some());
    }

    #[test]
    fn test_kernel_assignment() {
        let coord = MultiGpuBuilder::new().build();

        let device = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        coord.register_device(device);

        let kernel_id = KernelId::new("test_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));
        assert_eq!(coord.kernels_on_device(0).len(), 1);
    }

    #[test]
    fn test_load_balancing_least_loaded() {
        let coord = MultiGpuBuilder::new()
            .load_balancing(LoadBalancingStrategy::LeastLoaded)
            .build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        coord.assign_kernel(KernelId::new("k1"), 0);

        let selected = coord.select_device(&LaunchOptions::default()).unwrap();
        assert_eq!(selected, 1);
    }

    #[test]
    fn test_round_robin() {
        let coord = MultiGpuBuilder::new()
            .load_balancing(LoadBalancingStrategy::RoundRobin)
            .build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let d1 = coord.select_device(&LaunchOptions::default()).unwrap();
        let d2 = coord.select_device(&LaunchOptions::default()).unwrap();
        let d3 = coord.select_device(&LaunchOptions::default()).unwrap();

        // Consecutive selections alternate between the two devices.
        assert_ne!(d1, d2);
        assert_eq!(d1, d3);
    }

    #[test]
    fn test_interconnect_bandwidth() {
        assert!(
            InterconnectType::NvLink.estimated_bandwidth_gbps()
                > InterconnectType::Pcie.estimated_bandwidth_gbps()
        );
        assert!(
            InterconnectType::Pcie.estimated_bandwidth_gbps()
                > InterconnectType::None.estimated_bandwidth_gbps()
        );
        assert!(
            InterconnectType::SameDevice.estimated_bandwidth_gbps()
                > InterconnectType::NvLink.estimated_bandwidth_gbps()
        );
    }

    #[test]
    fn test_interconnect_p2p_support() {
        assert!(!InterconnectType::None.supports_p2p());
        assert!(InterconnectType::Pcie.supports_p2p());
        assert!(InterconnectType::NvLink.supports_p2p());
        assert!(InterconnectType::NvSwitch.supports_p2p());
    }

    #[test]
    fn test_gpu_topology_creation() {
        let topo = GpuTopology::new(4);
        assert_eq!(topo.device_count, 4);

        for i in 0..4 {
            let conn = topo.get_connection(i, i);
            assert!(conn.is_some());
            assert_eq!(conn.unwrap().interconnect, InterconnectType::SameDevice);
        }
    }

    #[test]
    fn test_gpu_topology_set_connection() {
        let mut topo = GpuTopology::new(4);

        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));

        let conn_01 = topo.get_connection(0, 1);
        assert!(conn_01.is_some());
        assert_eq!(conn_01.unwrap().interconnect, InterconnectType::NvLink);

        // Bidirectional links populate the reverse direction too.
        let conn_10 = topo.get_connection(1, 0);
        assert!(conn_10.is_some());
        assert_eq!(conn_10.unwrap().interconnect, InterconnectType::NvLink);
    }

    #[test]
    fn test_gpu_topology_neighbors() {
        let mut topo = GpuTopology::new(4);

        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(3, 0, InterconnectType::NvLink));

        let neighbors_0 = topo.neighbors(0);
        assert_eq!(neighbors_0.len(), 2);
        assert!(neighbors_0.contains(&1));
        assert!(neighbors_0.contains(&3));
    }

    #[test]
    fn test_gpu_topology_best_path() {
        let mut topo = GpuTopology::new(4);

        topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(2, 3, InterconnectType::NvLink));
        topo.set_connection(GpuConnection::new(0, 3, InterconnectType::None));

        let path_01 = topo.best_path(0, 1);
        assert_eq!(path_01, vec![0, 1]);

        let path_00 = topo.best_path(0, 0);
        assert_eq!(path_00, vec![0]);
    }

2649 #[test]
2650 fn test_gpu_topology_fully_connected() {
2651 let mut topo = GpuTopology::new(3);
2652
2653 assert!(!topo.is_fully_connected());
2655
2656 topo.set_connection(GpuConnection::new(0, 1, InterconnectType::NvLink));
2658 topo.set_connection(GpuConnection::new(0, 2, InterconnectType::NvLink));
2659 topo.set_connection(GpuConnection::new(1, 2, InterconnectType::NvLink));
2660
2661 assert!(topo.is_fully_connected());
2662 }
2663
2664 #[test]
2665 fn test_gpu_topology_numa() {
2666 let mut topo = GpuTopology::new(4);
2667
2668 topo.set_numa_node(0, 0);
2670 topo.set_numa_node(1, 0);
2671 topo.set_numa_node(2, 1);
2672 topo.set_numa_node(3, 1);
2673
2674 let numa_neighbors_0 = topo.numa_neighbors(0);
2675 assert_eq!(numa_neighbors_0, vec![1]);
2676
2677 let numa_neighbors_2 = topo.numa_neighbors(2);
2678 assert_eq!(numa_neighbors_2, vec![3]);
2679 }
2680
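    // ----- Coordinator topology discovery -----
    //
    // With P2P enabled and both registered CUDA devices flagged `p2p_capable`,
    // the discovered topology is expected to report an NvLink-class connection
    // between them.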
    #[test]
    fn test_coordinator_topology_discovery() {
        let coord = MultiGpuBuilder::new().enable_p2p(true).build();

        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        dev0.p2p_capable = true;
        dev0.compute_capability = Some((8, 0));

        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
        dev1.p2p_capable = true;
        dev1.compute_capability = Some((8, 6));

        coord.register_device(dev0);
        coord.register_device(dev1);

        let topo = coord.discover_topology();

        assert_eq!(topo.device_count, 2);

        let conn = topo.get_connection(0, 1);
        assert!(conn.is_some());
        assert_eq!(conn.unwrap().interconnect, InterconnectType::NvLink);
    }

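    // ----- Kernel migration requests -----
    //
    // A migration request records source and target devices and starts in
    // `MigrationState::Pending`; requesting a migration onto the kernel's
    // current device is rejected.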
    #[test]
    fn test_migration_request() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("migrating_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let request = coord.request_migration(&kernel_id, 1).unwrap();

        assert_eq!(request.source_device, 0);
        assert_eq!(request.target_device, 1);
        assert_eq!(request.state, MigrationState::Pending);
    }

    #[test]
    fn test_migration_same_device_error() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let result = coord.request_migration(&kernel_id, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_migration_complete() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("migrating_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(0));

        let request = coord.request_migration(&kernel_id, 1).unwrap();
        coord.complete_migration(&request).unwrap();

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));
    }

    #[test]
    fn test_migration_transfer_time_estimate() {
        let request = MigrationRequest {
            kernel_id: KernelId::new("test"),
            source_device: 0,
            target_device: 1,
            path: vec![0, 1],
            estimated_bandwidth_gbps: 300.0,
            estimated_latency_us: 1.0,
            state: MigrationState::Pending,
            started_at: None,
        };

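        // Expected window: 1 GB over a 300 GB/s link is roughly
        // 1e9 / 300e9 s ≈ 3333 µs, plus ~1 µs of link latency, so the
        // estimate should land strictly between 3000 µs and 4000 µs. (This
        // reading assumes the estimator interprets `estimated_bandwidth_gbps`
        // as gigabytes per second.)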
        let time = request.estimate_transfer_time(1_000_000_000);
        assert!(time.as_micros() > 3000);
        assert!(time.as_micros() < 4000);
    }

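    // ----- Cross-GPU K2K routing -----
    //
    // The router resolves both endpoint kernels to devices: if they share a
    // device the message is handled locally (`RoutingDecision::SameDevice`);
    // otherwise it is queued for a peer-to-peer hop and surfaces via
    // `drain_pending(source, dest)`.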
    use crate::hlc::HlcTimestamp;
    use crate::message::MessageEnvelope;

    /// Builds a minimal K2K message between two kernels for the routing tests.
    fn make_test_k2k_message(source: &KernelId, dest: &KernelId) -> K2KMessage {
        let timestamp = HlcTimestamp::now(42);
        let envelope = MessageEnvelope::empty(1, 2, timestamp);
        K2KMessage::new(source.clone(), dest.clone(), envelope, timestamp)
    }

    #[test]
    fn test_router_same_device() {
        let coord = MultiGpuBuilder::new().build();
        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let router = CrossGpuK2KRouter::new(coord);

        let msg = make_test_k2k_message(&k1, &k2);
        let decision = router.route_message(&k1, &k2, msg).unwrap();

        assert!(matches!(decision, RoutingDecision::SameDevice));
    }

    #[test]
    fn test_router_cross_device() {
        let coord = MultiGpuBuilder::new().enable_p2p(true).build();

        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        dev0.p2p_capable = true;
        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
        dev1.p2p_capable = true;

        coord.register_device(dev0);
        coord.register_device(dev1);

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 1);

        let router = CrossGpuK2KRouter::new(coord);

        let msg = make_test_k2k_message(&k1, &k2);
        let decision = router.route_message(&k1, &k2, msg).unwrap();

        match decision {
            RoutingDecision::DirectP2P {
                source_device,
                dest_device,
                ..
            } => {
                assert_eq!(source_device, 0);
                assert_eq!(dest_device, 1);
            }
            _ => panic!("Expected DirectP2P routing"),
        }
    }

    #[test]
    fn test_router_pending_messages() {
        let coord = MultiGpuBuilder::new().enable_p2p(true).build();

        let mut dev0 = DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda);
        dev0.p2p_capable = true;
        let mut dev1 = DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda);
        dev1.p2p_capable = true;

        coord.register_device(dev0);
        coord.register_device(dev1);

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 1);

        let router = CrossGpuK2KRouter::new(coord);

        for _ in 0..3 {
            let msg = make_test_k2k_message(&k1, &k2);
            router.route_message(&k1, &k2, msg).unwrap();
        }

        assert_eq!(router.stats().messages_pending, 3);

        let pending = router.drain_pending(0, 1);
        assert_eq!(pending.len(), 3);
        assert_eq!(router.stats().messages_pending, 0);
    }

    #[test]
    fn test_router_stats() {
        let coord = MultiGpuBuilder::new().build();
        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let router = CrossGpuK2KRouter::new(coord);

        let stats = router.stats();
        assert_eq!(stats.messages_routed, 0);
        assert_eq!(stats.bytes_transferred, 0);
        assert_eq!(stats.routing_failures, 0);
    }

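    // ----- Checkpoint-based kernel migration -----
    //
    // `MockCheckpointableKernel` below is a minimal `CheckpointableKernel`
    // implementation: it emits a fixed-size state blob on checkpoint and, on
    // restore, only recovers the step counter, which is all these tests need.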
    use crate::checkpoint::{Checkpoint, CheckpointBuilder};

    struct MockCheckpointableKernel {
        kernel_id: String,
        kernel_type: String,
        state_data: Vec<u8>,
        step: u64,
    }

    impl MockCheckpointableKernel {
        fn new(kernel_id: &str, state_size: usize) -> Self {
            Self {
                kernel_id: kernel_id.to_string(),
                kernel_type: "mock_kernel".to_string(),
                state_data: vec![0xAB; state_size],
                step: 1000,
            }
        }
    }

    impl CheckpointableKernel for MockCheckpointableKernel {
        fn create_checkpoint(&self) -> Result<Checkpoint> {
            let checkpoint = CheckpointBuilder::new(&self.kernel_id, &self.kernel_type)
                .step(self.step)
                .grid_size(64, 64, 64)
                .control_block(vec![1, 2, 3, 4])
                .device_memory("state", self.state_data.clone())
                .build();
            Ok(checkpoint)
        }

        fn restore_from_checkpoint(&mut self, checkpoint: &Checkpoint) -> Result<()> {
            self.step = checkpoint.metadata.current_step;
            Ok(())
        }

        fn checkpoint_kernel_id(&self) -> &str {
            &self.kernel_id
        }

        fn checkpoint_kernel_type(&self) -> &str {
            &self.kernel_type
        }
    }

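    // A full migration runs three phases through `migrate_with_checkpoint`:
    // checkpoint on the source device, transfer, then restore on the target.
    // A minimal sketch of the call sequence the tests below follow (assuming a
    // coordinator with devices 0 and 1 registered and the kernel assigned to
    // device 0):
    //
    //     let migrator = KernelMigrator::new(coord.clone());
    //     let mut request = coord.request_migration(&kernel_id, 1)?;
    //     let result = migrator.migrate_with_checkpoint(&kernel, &mut request)?;
    //     assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));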
    #[test]
    fn test_migrator_creation() {
        let coord = MultiGpuBuilder::new().build();
        let migrator = KernelMigrator::new(coord);

        let stats = migrator.stats();
        assert_eq!(stats.successful_migrations, 0);
        assert_eq!(stats.failed_migrations, 0);
        assert_eq!(stats.bytes_transferred, 0);
    }

    #[test]
    fn test_migrator_with_custom_storage() {
        let coord = MultiGpuBuilder::new().build();
        let storage = Arc::new(MemoryStorage::new());
        let migrator = KernelMigrator::with_storage(coord.clone(), storage);

        assert!(Arc::ptr_eq(migrator.coordinator(), &coord));
    }

    #[test]
    fn test_successful_migration() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("migratable_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let migrator = KernelMigrator::new(coord.clone());

        let kernel = MockCheckpointableKernel::new("migratable_kernel", 1024);

        let mut request = coord.request_migration(&kernel_id, 1).unwrap();
        assert_eq!(request.state, MigrationState::Pending);

        let result = migrator
            .migrate_with_checkpoint(&kernel, &mut request)
            .unwrap();

        assert_eq!(result.kernel_id.as_str(), "migratable_kernel");
        assert_eq!(result.source_device, 0);
        assert_eq!(result.target_device, 1);
        assert!(result.checkpoint_size > 0);
        assert!(result.total_duration > Duration::ZERO);

        assert_eq!(coord.get_kernel_device(&kernel_id), Some(1));

        let stats = migrator.stats();
        assert_eq!(stats.successful_migrations, 1);
        assert_eq!(stats.failed_migrations, 0);
        assert!(stats.bytes_transferred > 0);
    }

    #[test]
    fn test_migration_result_fields() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let kernel_id = KernelId::new("test_kernel");
        coord.assign_kernel(kernel_id.clone(), 0);

        let migrator = KernelMigrator::new(coord.clone());
        let kernel = MockCheckpointableKernel::new("test_kernel", 4096);
        let mut request = coord.request_migration(&kernel_id, 1).unwrap();

        let result = migrator
            .migrate_with_checkpoint(&kernel, &mut request)
            .unwrap();

        assert!(result.checkpoint_duration >= Duration::ZERO);
        assert!(result.transfer_duration >= Duration::ZERO);
        assert!(result.restore_duration >= Duration::ZERO);

        assert!(result.total_duration >= result.checkpoint_duration);
    }

    #[test]
    fn test_migration_stats_accumulate() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let migrator = KernelMigrator::new(coord.clone());

        let k1 = KernelId::new("k1");
        coord.assign_kernel(k1.clone(), 0);
        let kernel1 = MockCheckpointableKernel::new("k1", 1000);
        let mut req1 = coord.request_migration(&k1, 1).unwrap();
        migrator
            .migrate_with_checkpoint(&kernel1, &mut req1)
            .unwrap();

        let k2 = KernelId::new("k2");
        coord.assign_kernel(k2.clone(), 0);
        let kernel2 = MockCheckpointableKernel::new("k2", 2000);
        let mut req2 = coord.request_migration(&k2, 1).unwrap();
        migrator
            .migrate_with_checkpoint(&kernel2, &mut req2)
            .unwrap();

        let stats = migrator.stats();
        assert_eq!(stats.successful_migrations, 2);
        assert_eq!(stats.failed_migrations, 0);
        assert!(stats.bytes_transferred > 0);
    }

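    // ----- Device unregistration -----
    //
    // Unregistering a device produces a result describing what happened to its
    // kernels: with other devices available they are planned onto the least
    // loaded one; with no device left they are reported as orphaned and lose
    // their assignment.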
    #[test]
    fn test_unregister_device_no_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.device_index, 0);
        assert!(result.kernels_to_migrate.is_empty());
        assert!(result.orphaned_kernels.is_empty());
    }

    #[test]
    fn test_unregister_device_with_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        let k2 = KernelId::new("k2");
        coord.assign_kernel(k1.clone(), 0);
        coord.assign_kernel(k2.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.kernels_to_migrate.len(), 2);
        assert!(result.orphaned_kernels.is_empty());

        for plan in &result.kernels_to_migrate {
            assert_eq!(plan.source_device, 0);
            assert_eq!(plan.target_device, 1);
        }

        assert_eq!(coord.get_kernel_device(&k1), Some(1));
        assert_eq!(coord.get_kernel_device(&k2), Some(1));
    }

    #[test]
    fn test_unregister_single_device_orphans_kernels() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let k1 = KernelId::new("k1");
        coord.assign_kernel(k1.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert!(result.kernels_to_migrate.is_empty());
        assert_eq!(result.orphaned_kernels.len(), 1);
        assert_eq!(result.orphaned_kernels[0], k1);

        assert!(coord.get_kernel_device(&k1).is_none());
    }

    #[test]
    fn test_unregister_nonexistent_device() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));

        let result = coord.unregister_device(99);

        assert!(!result.success);
        assert_eq!(result.device_index, 99);
    }

    #[test]
    fn test_unregister_distributes_to_least_loaded() {
        let coord = MultiGpuBuilder::new().build();

        coord.register_device(DeviceInfo::new(0, "GPU 0".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(1, "GPU 1".to_string(), Backend::Cuda));
        coord.register_device(DeviceInfo::new(2, "GPU 2".to_string(), Backend::Cuda));

        coord.assign_kernel(KernelId::new("pre1"), 1);
        coord.assign_kernel(KernelId::new("pre2"), 1);
        coord.assign_kernel(KernelId::new("pre3"), 1);

        let k1 = KernelId::new("migrate_me");
        coord.assign_kernel(k1.clone(), 0);

        let result = coord.unregister_device(0);

        assert!(result.success);
        assert_eq!(result.kernels_to_migrate.len(), 1);

        let plan = &result.kernels_to_migrate[0];
        assert_eq!(plan.target_device, 2);
    }

    #[test]
    fn test_migration_priority_enum() {
        let low = MigrationPriority::Low;
        let normal = MigrationPriority::Normal;
        let high = MigrationPriority::High;
        let critical = MigrationPriority::Critical;

        assert_ne!(low, normal);
        assert_ne!(normal, high);
        assert_ne!(high, critical);
        assert_eq!(low, MigrationPriority::Low);
    }

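    // ----- Hot reload -----
    //
    // A minimal configuration sketch using the builder methods exercised
    // below (the values here are illustrative, not defaults):
    //
    //     let config = HotReloadConfig::new()
    //         .with_preserve_state(true)
    //         .with_max_retries(2)
    //         .with_timeout(Duration::from_secs(30));
    //     let manager = HotReloadManager::new(config);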
    #[test]
    fn test_hot_reload_config_default() {
        let config = HotReloadConfig::default();
        assert!(config.enabled);
        assert!(config.preserve_state);
        assert!(config.validate_before_swap);
        assert!(config.keep_fallback);
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_hot_reload_config_builder() {
        let config = HotReloadConfig::new()
            .with_enabled(false)
            .with_preserve_state(false)
            .with_max_retries(5)
            .with_timeout(Duration::from_secs(60));

        assert!(!config.enabled);
        assert!(!config.preserve_state);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.reload_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_kernel_code_source_ptx() {
        let ptx = ".version 7.0\n.target sm_80\nkernel: ret;";
        let code = KernelCodeSource::from_ptx(ptx, "kernel");

        assert_eq!(code.format, KernelCodeFormat::Ptx);
        assert_eq!(code.entry_point, "kernel");
        assert_eq!(code.as_str(), Some(ptx));
        assert_eq!(code.size(), ptx.len());
    }

    #[test]
    fn test_kernel_code_source_wgsl() {
        let wgsl = "@compute fn main() {}";
        let code = KernelCodeSource::from_wgsl(wgsl, "main");

        assert_eq!(code.format, KernelCodeFormat::Wgsl);
        assert_eq!(code.entry_point, "main");
        assert_eq!(code.as_str(), Some(wgsl));
    }

    #[test]
    fn test_kernel_code_source_msl() {
        let msl = "kernel void my_kernel() {}";
        let code = KernelCodeSource::from_msl(msl, "my_kernel");

        assert_eq!(code.format, KernelCodeFormat::Msl);
        assert_eq!(code.entry_point, "my_kernel");
        assert_eq!(code.as_str(), Some(msl));
    }

    #[test]
    fn test_hot_reload_manager_creation() {
        let manager = HotReloadManager::with_defaults();
        assert!(manager.is_enabled());
        assert!(manager.list_kernels().is_empty());
    }

    #[test]
    fn test_hot_reload_manager_register_kernel() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        manager.register_kernel(&kernel_id, code);

        assert!(manager.is_registered(&kernel_id));
        assert!(!manager.is_reload_in_progress(&kernel_id));
        assert!(manager.get_current_version(&kernel_id).is_some());
    }

    #[test]
    fn test_hot_reload_request_states() {
        let kernel_id = KernelId::new("test");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");
        let request = HotReloadRequest::new(kernel_id, code);

        assert_eq!(request.state, HotReloadState::Idle);
        assert!(!request.is_in_progress());
        assert!(!request.is_completed());
        assert!(!request.is_failed());
    }

    #[test]
    fn test_hot_reload_disabled() {
        let config = HotReloadConfig::new().with_enabled(false);
        let manager = HotReloadManager::new(config);
        let kernel_id = KernelId::new("test");
        let code = KernelCodeSource::from_ptx(".version 7.0", "kernel");

        manager.register_kernel(&kernel_id, code.clone());
        let result = manager.request_reload(&kernel_id, code);
        assert!(result.is_err());
    }

    #[test]
    fn test_hot_reload_stats() {
        let manager = HotReloadManager::with_defaults();
        let stats = manager.stats();

        assert_eq!(stats.successful_reloads, 0);
        assert_eq!(stats.failed_reloads, 0);
        assert_eq!(stats.rollbacks, 0);
    }

    #[test]
    fn test_hot_reload_code_formats() {
        let formats = [
            KernelCodeFormat::Ptx,
            KernelCodeFormat::Cubin,
            KernelCodeFormat::SpirV,
            KernelCodeFormat::Wgsl,
            KernelCodeFormat::Msl,
            KernelCodeFormat::MetalLib,
            KernelCodeFormat::Source,
        ];

        for (i, f1) in formats.iter().enumerate() {
            for (j, f2) in formats.iter().enumerate() {
                if i != j {
                    assert_ne!(f1, f2);
                }
            }
        }
    }

    #[test]
    fn test_hot_reload_state_transitions() {
        let states = [
            HotReloadState::Idle,
            HotReloadState::Draining,
            HotReloadState::Checkpointing,
            HotReloadState::Compiling,
            HotReloadState::Validating,
            HotReloadState::Swapping,
            HotReloadState::Restoring,
            HotReloadState::Completed,
            HotReloadState::Failed,
            HotReloadState::RollingBack,
        ];

        for (i, s1) in states.iter().enumerate() {
            for (j, s2) in states.iter().enumerate() {
                if i != j {
                    assert_ne!(s1, s2);
                }
            }
        }
    }

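    // `execute_reload` drives a request from `Idle` to `Completed`, preserving
    // kernel state through the same checkpoint path that migration uses. The
    // intermediate states enumerated in `test_hot_reload_state_transitions`
    // (Draining, Checkpointing, Compiling, Validating, Swapping, Restoring)
    // are internal to the reload; the test below only pins the endpoints and
    // the reported result fields.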
    #[test]
    fn test_hot_reload_execute() {
        let manager = HotReloadManager::with_defaults();
        let kernel_id = KernelId::new("test_kernel");

        let initial_code = KernelCodeSource::from_ptx(".version 7.0\n.target sm_80", "kernel");
        manager.register_kernel(&kernel_id, initial_code);

        let new_code = KernelCodeSource::from_ptx(".version 8.0\n.target sm_90", "kernel");
        let mut request = manager.request_reload(&kernel_id, new_code).unwrap();

        let mock_kernel = MockCheckpointableKernel::new("test_kernel", 512);

        let result = manager.execute_reload(&mut request, &mock_kernel).unwrap();

        assert!(request.is_completed());
        assert_eq!(result.kernel_id.as_str(), "test_kernel");
        assert!(result.state_preserved);
        assert!(result.checkpoint_size > 0);
        assert!(result.total_duration > Duration::ZERO);

        let stats = manager.stats();
        assert_eq!(stats.successful_reloads, 1);
    }

    #[test]
    fn test_hot_reload_list_kernels() {
        let manager = HotReloadManager::with_defaults();

        let k1 = KernelId::new("kernel1");
        let k2 = KernelId::new("kernel2");
        let k3 = KernelId::new("kernel3");

        manager.register_kernel(&k1, KernelCodeSource::from_ptx(".version 7.0", "k1"));
        manager.register_kernel(&k2, KernelCodeSource::from_ptx(".version 7.0", "k2"));
        manager.register_kernel(&k3, KernelCodeSource::from_ptx(".version 7.0", "k3"));

        let kernels = manager.list_kernels();
        assert_eq!(kernels.len(), 3);
        assert!(kernels.contains(&k1));
        assert!(kernels.contains(&k2));
        assert!(kernels.contains(&k3));
    }
}