// oxicuda_driver — nvlink_topology.rs

1//! NVLink/NVSwitch topology-aware communication.
2//!
3//! This module provides topology discovery and routing primitives for
4//! multi-GPU systems connected via NVLink, NVSwitch, or PCIe.  It enables
5//! topology-aware collective communication scheduling (ring AllReduce,
6//! tree broadcast/reduce) and task placement that minimises inter-GPU
7//! communication cost.
8//!
9//! On macOS, where NVIDIA GPUs are not available, this module returns
10//! synthetic topology data for a 4-GPU NVLink mesh so that algorithm
11//! logic can be tested without hardware.
12//!
13//! # Example
14//!
15//! ```rust,no_run
16//! use oxicuda_driver::nvlink_topology::GpuTopology;
17//!
18//! let topo = GpuTopology::discover()?;
19//! println!("topology type: {:?}", topo.topology_type());
20//! let ring = topo.optimal_ring_order()?;
21//! println!("ring order: {ring:?}");
22//! # Ok::<(), oxicuda_driver::error::CudaError>(())
23//! ```
24
25use std::cmp::Reverse;
26use std::collections::{BinaryHeap, HashMap};
27
28use crate::error::{CudaError, CudaResult};
29
30// ---------------------------------------------------------------------------
31// NVLink version & status enums
32// ---------------------------------------------------------------------------
33
/// NVLink generation/version.
///
/// The generation determines the per-link bandwidth reported by
/// [`NvLinkVersion::per_link_bandwidth_gbps`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NvLinkVersion {
    /// NVLink 1.0 — Pascal architecture (P100), 20 GB/s per link.
    V1,
    /// NVLink 2.0 — Volta architecture (V100), 25 GB/s per link.
    V2,
    /// NVLink 3.0 — Ampere architecture (A100), 25 GB/s per link (wider).
    V3,
    /// NVLink 4.0 — Hopper architecture (H100), 25 GB/s per sub-link.
    V4,
    /// NVSwitch-mediated all-to-all fabric.
    NvSwitch,
}
48
49impl NvLinkVersion {
50    /// Per-link bandwidth in GB/s for this NVLink generation.
51    #[inline]
52    pub fn per_link_bandwidth_gbps(self) -> f64 {
53        match self {
54            Self::V1 => 20.0,
55            Self::V2 => 25.0,
56            Self::V3 => 25.0,
57            Self::V4 => 25.0,
58            Self::NvSwitch => 25.0,
59        }
60    }
61}
62
/// Status of an NVLink connection between two devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NvLinkStatus {
    /// Link is active and usable.
    Active,
    /// Link exists but is currently inactive.
    Inactive,
    /// NVLink not supported between these devices.
    Unsupported,
}
73
74// ---------------------------------------------------------------------------
75// NVLink info per peer
76// ---------------------------------------------------------------------------
77
/// Information about the NVLink connections to a specific peer device,
/// i.e. one row of a device's peer table.
#[derive(Debug, Clone)]
pub struct NvLinkInfo {
    /// NVLink generation.
    pub version: NvLinkVersion,
    /// Aggregate bidirectional bandwidth in GB/s across all active links.
    pub bandwidth_gbps: f64,
    /// Number of active NVLink connections.
    pub link_count: u32,
    /// Ordinal of the peer device.
    pub peer_device_id: i32,
}
90
91// ---------------------------------------------------------------------------
92// Link type between devices
93// ---------------------------------------------------------------------------
94
/// The physical interconnect type between two devices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LinkType {
    /// Direct NVLink connection.
    NvLink,
    /// PCIe connection (possibly through a switch).
    PCIe,
    /// NVSwitch-mediated connection (full bisection bandwidth).
    NvSwitch,
}
105
106// ---------------------------------------------------------------------------
107// Topology link
108// ---------------------------------------------------------------------------
109
/// A directed link between two GPU devices in the topology graph.
///
/// Links are stored per direction, so a bidirectional connection appears
/// as two entries (A→B and B→A).
#[derive(Debug, Clone)]
pub struct TopologyLink {
    /// Source device ordinal.
    pub from_device: i32,
    /// Destination device ordinal.
    pub to_device: i32,
    /// Physical interconnect type.
    pub link_type: LinkType,
    /// Aggregate bandwidth in GB/s.
    pub bandwidth_gbps: f64,
    /// Estimated one-way latency in nanoseconds.
    pub latency_ns: f64,
    /// Number of hops (1 for direct NVLink, 1 for NVSwitch, 2+ for PCIe).
    pub hop_count: u32,
}
126
127// ---------------------------------------------------------------------------
128// Topology classification
129// ---------------------------------------------------------------------------
130
/// High-level classification of the detected GPU topology.
///
/// Produced by [`GpuTopology::topology_type`].  Note that `NvLinkMesh` is
/// also used as the fallback for partial (non-ring, non-full-mesh) NVLink
/// connectivity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TopologyType {
    /// Only one GPU detected.
    SingleGpu,
    /// Two GPUs connected by NVLink.
    NvLinkPair,
    /// GPUs form a ring via NVLink.
    NvLinkRing,
    /// All GPUs connected to all others via NVLink (full mesh).
    NvLinkMesh,
    /// NVSwitch fabric connecting all GPUs.
    NvSwitchFabric,
    /// GPUs connected only via PCIe.
    PcieOnly,
}
147
148// ---------------------------------------------------------------------------
149// Topology tree (for broadcast / reduce)
150// ---------------------------------------------------------------------------
151
/// A tree structure rooted at a single device, used for broadcast/reduce
/// collectives.
///
/// Every device in the tree has an entry in `children` (possibly empty);
/// every device except the root has an entry in `parent`.
#[derive(Debug, Clone)]
pub struct TopologyTree {
    /// Root device ordinal.
    pub root: i32,
    /// Maps each device to its list of child devices.
    pub children: HashMap<i32, Vec<i32>>,
    /// Maps each device to its parent (root has no entry).
    pub parent: HashMap<i32, i32>,
}
163
164impl TopologyTree {
165    /// Returns all devices in the tree.
166    pub fn devices(&self) -> Vec<i32> {
167        let mut devs: Vec<i32> = self.children.keys().copied().collect();
168        devs.sort_unstable();
169        devs
170    }
171
172    /// Returns the depth of the tree.
173    pub fn depth(&self) -> usize {
174        fn walk_depth(node: i32, children: &HashMap<i32, Vec<i32>>) -> usize {
175            match children.get(&node) {
176                Some(kids) if !kids.is_empty() => {
177                    1 + kids
178                        .iter()
179                        .map(|&c| walk_depth(c, children))
180                        .max()
181                        .unwrap_or(0)
182                }
183                _ => 1,
184            }
185        }
186        walk_depth(self.root, &self.children)
187    }
188}
189
190// ---------------------------------------------------------------------------
191// Communication schedule
192// ---------------------------------------------------------------------------
193
194/// A single data transfer in a communication schedule.
195#[derive(Debug, Clone)]
196pub struct Transfer {
197    /// Source device ordinal.
198    pub src: i32,
199    /// Destination device ordinal.
200    pub dst: i32,
201    /// Data size in bytes.
202    pub data_size: u64,
203}
204
205/// An ordered list of transfers optimised for the topology.
206#[derive(Debug, Clone)]
207pub struct CommunicationSchedule {
208    /// Ordered transfers.
209    pub transfers: Vec<Transfer>,
210    /// Estimated total time in microseconds.
211    pub estimated_time_us: f64,
212}
213
214// ---------------------------------------------------------------------------
215// Task placement
216// ---------------------------------------------------------------------------
217
218/// A communication demand between two logical tasks.
219#[derive(Debug, Clone)]
220pub struct TaskCommunication {
221    /// First task index.
222    pub task_a: usize,
223    /// Second task index.
224    pub task_b: usize,
225    /// Communication volume in bytes.
226    pub volume_bytes: u64,
227}
228
229/// Result of topology-aware task placement.
230#[derive(Debug, Clone)]
231pub struct TopologyAwarePlacement {
232    /// Maps task index to device ordinal.
233    pub assignment: HashMap<usize, i32>,
234    /// Estimated total communication cost (lower is better).
235    pub total_cost: f64,
236}
237
238// ---------------------------------------------------------------------------
239// GpuTopology
240// ---------------------------------------------------------------------------
241
/// Complete GPU topology graph with adjacency information.
///
/// Holds all devices, inter-device links, and a dense adjacency matrix for
/// fast bandwidth/latency lookups.  Matrix rows/columns are indexed by a
/// device's position in `devices` (resolved via `device_index`), not by
/// the device ordinal itself.
#[derive(Debug, Clone)]
pub struct GpuTopology {
    /// Device ordinals present in the topology.
    pub devices: Vec<i32>,
    /// All directed links between devices.
    pub links: Vec<TopologyLink>,
    /// Dense adjacency matrix: `adj[i][j]` = bandwidth in GB/s from
    /// device `devices[i]` to device `devices[j]`.  Diagonal is `f64::INFINITY`;
    /// unconnected pairs are 0.0.
    adj_bandwidth: Vec<Vec<f64>>,
    /// Dense latency matrix in nanoseconds.  Diagonal is 0.0; unconnected
    /// pairs are `f64::MAX`.
    adj_latency: Vec<Vec<f64>>,
}
258
259impl GpuTopology {
260    // -- Discovery ----------------------------------------------------------
261
    /// Discover the GPU topology of the current system.
    ///
    /// On Linux/Windows with NVIDIA drivers, this queries the driver for
    /// device count and peer-access capabilities.  On macOS, a synthetic
    /// 4-GPU NVLink mesh is returned for testing purposes.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::NoDevice`] if no devices are found.
    pub fn discover() -> CudaResult<Self> {
        // Compile-time dispatch: exactly one of these two blocks is built,
        // selected by target OS.
        #[cfg(target_os = "macos")]
        {
            Self::synthetic_mesh(4)
        }
        #[cfg(not(target_os = "macos"))]
        {
            Self::discover_real()
        }
    }
281
282    /// Build a synthetic N-GPU NVLink mesh (used on macOS and in tests).
283    #[cfg(any(target_os = "macos", test))]
284    fn synthetic_mesh(n: usize) -> CudaResult<Self> {
285        if n == 0 {
286            return Err(CudaError::NoDevice);
287        }
288        let devices: Vec<i32> = (0..n as i32).collect();
289        let mut links = Vec::new();
290        let mut adj_bandwidth = vec![vec![0.0; n]; n];
291        let mut adj_latency = vec![vec![f64::MAX; n]; n];
292
293        for i in 0..n {
294            adj_bandwidth[i][i] = f64::INFINITY;
295            adj_latency[i][i] = 0.0;
296        }
297
298        for i in 0..n {
299            for j in 0..n {
300                if i == j {
301                    continue;
302                }
303                let bw = 100.0; // Synthetic 100 GB/s NVLink
304                let lat = 500.0; // 500 ns latency
305                links.push(TopologyLink {
306                    from_device: i as i32,
307                    to_device: j as i32,
308                    link_type: LinkType::NvLink,
309                    bandwidth_gbps: bw,
310                    latency_ns: lat,
311                    hop_count: 1,
312                });
313                adj_bandwidth[i][j] = bw;
314                adj_latency[i][j] = lat;
315            }
316        }
317
318        Ok(Self {
319            devices,
320            links,
321            adj_bandwidth,
322            adj_latency,
323        })
324    }
325
326    /// Real topology discovery using the CUDA driver.
327    #[cfg(not(target_os = "macos"))]
328    fn discover_real() -> CudaResult<Self> {
329        use crate::device::Device;
330
331        crate::init()?;
332        let count = Device::count()?;
333        if count <= 0 {
334            return Err(CudaError::NoDevice);
335        }
336        let n = count as usize;
337        let devices: Vec<i32> = (0..count).collect();
338        let mut links = Vec::new();
339        let mut adj_bandwidth = vec![vec![0.0; n]; n];
340        let mut adj_latency = vec![vec![f64::MAX; n]; n];
341
342        for i in 0..n {
343            adj_bandwidth[i][i] = f64::INFINITY;
344            adj_latency[i][i] = 0.0;
345        }
346
347        for i in 0..n {
348            for j in 0..n {
349                if i == j {
350                    continue;
351                }
352                // Check peer access to determine link type
353                let can_peer = crate::device::can_access_peer(
354                    &Device::get(i as i32)?,
355                    &Device::get(j as i32)?,
356                )?;
357
358                let (lt, bw, lat, hops) = if can_peer {
359                    // Peer access available — likely NVLink
360                    (LinkType::NvLink, 50.0, 800.0, 1)
361                } else {
362                    // Fallback to PCIe
363                    (LinkType::PCIe, 16.0, 2000.0, 2)
364                };
365
366                links.push(TopologyLink {
367                    from_device: i as i32,
368                    to_device: j as i32,
369                    link_type: lt,
370                    bandwidth_gbps: bw,
371                    latency_ns: lat,
372                    hop_count: hops,
373                });
374                adj_bandwidth[i][j] = bw;
375                adj_latency[i][j] = lat;
376            }
377        }
378
379        Ok(Self {
380            devices,
381            links,
382            adj_bandwidth,
383            adj_latency,
384        })
385    }
386
387    // -- Topology classification --------------------------------------------
388
    /// Classify the topology into a high-level category.
    ///
    /// Classification order matters: NVSwitch wins over everything, then
    /// PCIe-only, then full mesh, pair, ring, and finally a partial-NVLink
    /// fallback which is reported as [`TopologyType::NvLinkMesh`].
    pub fn topology_type(&self) -> TopologyType {
        let n = self.devices.len();
        // Zero devices is degenerate; treated the same as a single GPU.
        if n == 0 {
            return TopologyType::SingleGpu;
        }
        if n == 1 {
            return TopologyType::SingleGpu;
        }

        // Any NVSwitch link implies a switch fabric for the whole system.
        let has_nvswitch = self.links.iter().any(|l| l.link_type == LinkType::NvSwitch);
        if has_nvswitch {
            return TopologyType::NvSwitchFabric;
        }

        let nvlink_links: Vec<&TopologyLink> = self
            .links
            .iter()
            .filter(|l| l.link_type == LinkType::NvLink)
            .collect();

        if nvlink_links.is_empty() {
            return TopologyType::PcieOnly;
        }

        // Build a directed NVLink neighbor list (from_device -> to_device).
        let mut nvlink_neighbors: HashMap<i32, Vec<i32>> = HashMap::new();
        for link in &nvlink_links {
            nvlink_neighbors
                .entry(link.from_device)
                .or_default()
                .push(link.to_device);
        }

        // Check if all devices have NVLink neighbors.
        let all_have_nvlink = self
            .devices
            .iter()
            .all(|d| nvlink_neighbors.contains_key(d));

        if !all_have_nvlink {
            // Some devices only on PCIe.
            // NOTE(review): a 2-GPU system where only one side has outgoing
            // NVLink is still classified as a pair — confirm this is intended.
            if n == 2 {
                return TopologyType::NvLinkPair;
            }
            return TopologyType::PcieOnly;
        }

        // Check for full mesh: every device connected to every other.
        let is_full_mesh = self.devices.iter().all(|d| {
            let neighbors = nvlink_neighbors.get(d).map(|v| v.len()).unwrap_or(0);
            neighbors == n - 1
        });

        if is_full_mesh {
            return TopologyType::NvLinkMesh;
        }

        if n == 2 {
            return TopologyType::NvLinkPair;
        }

        // Check for ring: every device has exactly 2 NVLink neighbors
        // and they form a single cycle (verified by walking it).
        let is_ring = self
            .devices
            .iter()
            .all(|d| nvlink_neighbors.get(d).map(|v| v.len()).unwrap_or(0) == 2);

        if is_ring && self.verify_ring(&nvlink_neighbors) {
            return TopologyType::NvLinkRing;
        }

        // Default: partial NVLink connectivity.
        TopologyType::NvLinkMesh
    }
465
    /// Verify that the given adjacency forms a single Hamiltonian cycle.
    ///
    /// Walks the cycle starting at `devices[0]`, at each step moving to a
    /// neighbor other than the node we just came from.  Succeeds only if
    /// the walk visits every device exactly once and the final node links
    /// back to the start.  The caller (`topology_type`) has already checked
    /// that every device has exactly two neighbors.
    fn verify_ring(&self, neighbors: &HashMap<i32, Vec<i32>>) -> bool {
        if self.devices.is_empty() {
            return false;
        }
        let start = self.devices[0];
        let mut visited = vec![false; self.devices.len()];
        let mut current = start;
        // Sentinel: no previous node on the first step (ordinals are >= 0).
        let mut prev = -1_i32;

        for step in 0..self.devices.len() {
            let idx = match self.device_index(current) {
                Some(i) => i,
                None => return false,
            };
            // Revisiting a node before the walk completes means the graph
            // splits into more than one cycle.
            if visited[idx] {
                return false;
            }
            visited[idx] = true;

            let nbrs = match neighbors.get(&current) {
                Some(v) => v,
                None => return false,
            };

            if step < self.devices.len() - 1 {
                // Move to the neighbor that isn't `prev`
                let next = nbrs.iter().find(|&&n| n != prev);
                match next {
                    Some(&n) => {
                        prev = current;
                        current = n;
                    }
                    None => return false,
                }
            } else {
                // Last step: must connect back to start
                return nbrs.contains(&start);
            }
        }
        // Unreachable: the final loop iteration always returns above.
        false
    }
508
509    // -- Path finding -------------------------------------------------------
510
    /// Find a cheap path between two devices.
    ///
    /// Runs Dijkstra with edge weight `1 / bandwidth`, i.e. it minimises
    /// the sum of inverse bandwidths along the path (total per-byte
    /// transfer cost).  Note this is *not* the classic widest-path
    /// (max-bottleneck) criterion, although it strongly favours
    /// high-bandwidth edges.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::InvalidDevice`] if either device is not in
    /// the topology, or [`CudaError::InvalidValue`] if no path exists.
    pub fn best_path(&self, src: i32, dst: i32) -> CudaResult<Vec<i32>> {
        let src_idx = self.device_index(src).ok_or(CudaError::InvalidDevice)?;
        let dst_idx = self.device_index(dst).ok_or(CudaError::InvalidDevice)?;

        // Trivial path: a device reaches itself.
        if src_idx == dst_idx {
            return Ok(vec![src]);
        }

        let n = self.devices.len();
        // Use inverse bandwidth as weight for Dijkstra
        let mut dist = vec![f64::INFINITY; n];
        let mut prev: Vec<Option<usize>> = vec![None; n];
        dist[src_idx] = 0.0;

        // Min-heap via `Reverse`; `OrderedF64` is the total-order f64
        // wrapper defined elsewhere in this module.
        let mut heap: BinaryHeap<Reverse<(OrderedF64, usize)>> = BinaryHeap::new();
        heap.push(Reverse((OrderedF64(0.0), src_idx)));

        while let Some(Reverse((OrderedF64(cost), u))) = heap.pop() {
            if u == dst_idx {
                break;
            }
            // Stale heap entry — a shorter distance was already settled.
            if cost > dist[u] {
                continue;
            }
            for v in 0..n {
                if u == v {
                    continue;
                }
                let bw = self.adj_bandwidth[u][v];
                // 0.0 marks "no direct connection".
                if bw <= 0.0 {
                    continue;
                }
                let edge_cost = 1.0 / bw;
                let new_dist = dist[u] + edge_cost;
                if new_dist < dist[v] {
                    dist[v] = new_dist;
                    prev[v] = Some(u);
                    heap.push(Reverse((OrderedF64(new_dist), v)));
                }
            }
        }

        // Reconstruct path; no predecessor at dst means dst was unreachable.
        if prev[dst_idx].is_none() {
            return Err(CudaError::InvalidValue);
        }

        let mut path = Vec::new();
        let mut cur = dst_idx;
        while let Some(p) = prev[cur] {
            path.push(self.devices[cur]);
            cur = p;
        }
        path.push(self.devices[src_idx]);
        path.reverse();
        Ok(path)
    }
577
578    // -- Bandwidth & latency queries ----------------------------------------
579
580    /// Returns the direct bandwidth between two devices in GB/s.
581    ///
582    /// Returns 0.0 if the devices are not directly connected, and
583    /// `f64::INFINITY` for same-device queries.
584    pub fn bandwidth_between(&self, src: i32, dst: i32) -> f64 {
585        let src_idx = match self.device_index(src) {
586            Some(i) => i,
587            None => return 0.0,
588        };
589        let dst_idx = match self.device_index(dst) {
590            Some(i) => i,
591            None => return 0.0,
592        };
593        self.adj_bandwidth[src_idx][dst_idx]
594    }
595
596    /// Returns the estimated latency between two devices in microseconds.
597    ///
598    /// Returns 0.0 for same-device and `f64::MAX` (converted to us) if
599    /// no connection exists.
600    pub fn latency_between(&self, src: i32, dst: i32) -> f64 {
601        let src_idx = match self.device_index(src) {
602            Some(i) => i,
603            None => return f64::MAX / 1000.0,
604        };
605        let dst_idx = match self.device_index(dst) {
606            Some(i) => i,
607            None => return f64::MAX / 1000.0,
608        };
609        self.adj_latency[src_idx][dst_idx] / 1000.0 // ns -> us
610    }
611
612    // -- Ring order for AllReduce -------------------------------------------
613
614    /// Find an optimal ring ordering for AllReduce collective.
615    ///
616    /// Uses a greedy nearest-bandwidth-neighbor heuristic: starting from
617    /// device 0, always pick the unvisited neighbor with the highest
618    /// bandwidth.
619    ///
620    /// # Errors
621    ///
622    /// Returns [`CudaError::NoDevice`] if the topology has no devices.
623    pub fn optimal_ring_order(&self) -> CudaResult<Vec<i32>> {
624        let n = self.devices.len();
625        if n == 0 {
626            return Err(CudaError::NoDevice);
627        }
628        if n == 1 {
629            return Ok(vec![self.devices[0]]);
630        }
631
632        let mut visited = vec![false; n];
633        let mut ring = Vec::with_capacity(n);
634
635        // Start from device 0
636        let mut current = 0_usize;
637        visited[current] = true;
638        ring.push(self.devices[current]);
639
640        for _ in 1..n {
641            // Pick unvisited neighbor with highest bandwidth
642            let mut best_bw = -1.0_f64;
643            let mut best_idx = None;
644            for (j, &is_visited) in visited.iter().enumerate() {
645                if is_visited {
646                    continue;
647                }
648                let bw = self.adj_bandwidth[current][j];
649                if bw > best_bw {
650                    best_bw = bw;
651                    best_idx = Some(j);
652                }
653            }
654            match best_idx {
655                Some(idx) => {
656                    visited[idx] = true;
657                    ring.push(self.devices[idx]);
658                    current = idx;
659                }
660                None => {
661                    // No reachable unvisited device — shouldn't happen in connected graph
662                    return Err(CudaError::InvalidValue);
663                }
664            }
665        }
666
667        Ok(ring)
668    }
669
670    // -- Optimal tree for broadcast/reduce ----------------------------------
671
    /// Build an optimal spanning tree for broadcast/reduce collectives.
    ///
    /// Constructs a maximum-bandwidth spanning tree rooted at the device
    /// with the best aggregate bandwidth to all others (Prim-style).
    /// Infinite (same-device) bandwidths are treated as 0 when summing so
    /// the diagonal never dominates root selection.
    ///
    /// # Errors
    ///
    /// Returns [`CudaError::NoDevice`] if the topology has no devices.
    pub fn optimal_tree(&self) -> CudaResult<TopologyTree> {
        let n = self.devices.len();
        if n == 0 {
            return Err(CudaError::NoDevice);
        }

        // Pick root: device with highest total outgoing bandwidth
        let root_idx = (0..n)
            .max_by(|&a, &b| {
                let sum_a: f64 = (0..n)
                    .filter(|&j| j != a)
                    .map(|j| {
                        let bw = self.adj_bandwidth[a][j];
                        if bw.is_infinite() { 0.0 } else { bw }
                    })
                    .sum();
                let sum_b: f64 = (0..n)
                    .filter(|&j| j != b)
                    .map(|j| {
                        let bw = self.adj_bandwidth[b][j];
                        if bw.is_infinite() { 0.0 } else { bw }
                    })
                    .sum();
                sum_a
                    .partial_cmp(&sum_b)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .unwrap_or(0);

        let root = self.devices[root_idx];
        let mut children: HashMap<i32, Vec<i32>> = HashMap::new();
        let mut parent: HashMap<i32, i32> = HashMap::new();

        // Initialize children for all devices so every node has an entry.
        for &d in &self.devices {
            children.insert(d, Vec::new());
        }

        if n == 1 {
            return Ok(TopologyTree {
                root,
                children,
                parent,
            });
        }

        // Prim's algorithm for maximum spanning tree.
        // best_edge[j] = (best bandwidth into j, the in-tree node providing it).
        let mut in_tree = vec![false; n];
        let mut best_edge: Vec<(f64, Option<usize>)> = vec![(0.0, None); n];
        in_tree[root_idx] = true;

        // Initialize edges from root
        for (j, edge) in best_edge.iter_mut().enumerate() {
            if j != root_idx {
                let bw = self.adj_bandwidth[root_idx][j];
                let bw_val = if bw.is_infinite() { 0.0 } else { bw };
                *edge = (bw_val, Some(root_idx));
            }
        }

        for _ in 1..n {
            // Find the not-in-tree node with maximum bandwidth edge.
            // Starting at -1.0 means even a 0-bandwidth (disconnected) node
            // is eventually attached, keeping the tree spanning.
            let mut best_bw = -1.0_f64;
            let mut best_node = None;
            for j in 0..n {
                if !in_tree[j] && best_edge[j].0 > best_bw {
                    best_bw = best_edge[j].0;
                    best_node = Some(j);
                }
            }

            let node = match best_node {
                Some(j) => j,
                None => break,
            };

            in_tree[node] = true;
            let parent_idx = match best_edge[node].1 {
                Some(p) => p,
                None => continue,
            };

            let parent_dev = self.devices[parent_idx];
            let child_dev = self.devices[node];
            parent.insert(child_dev, parent_dev);
            children.entry(parent_dev).or_default().push(child_dev);

            // Update best edges for remaining nodes
            for j in 0..n {
                if !in_tree[j] {
                    let bw = self.adj_bandwidth[node][j];
                    let bw_val = if bw.is_infinite() { 0.0 } else { bw };
                    if bw_val > best_edge[j].0 {
                        best_edge[j] = (bw_val, Some(node));
                    }
                }
            }
        }

        Ok(TopologyTree {
            root,
            children,
            parent,
        })
    }
785
786    // -- Task placement -----------------------------------------------------
787
788    /// Assign tasks to devices minimising total communication cost.
789    ///
790    /// For small task counts (<=8), tries all permutations. For larger counts,
791    /// uses a greedy assignment heuristic.
792    ///
793    /// # Errors
794    ///
795    /// Returns [`CudaError::InvalidValue`] if there are more tasks than devices,
796    /// or if the topology has no devices.
797    pub fn optimal_placement(
798        &self,
799        num_tasks: usize,
800        communications: &[TaskCommunication],
801    ) -> CudaResult<TopologyAwarePlacement> {
802        let n = self.devices.len();
803        if num_tasks == 0 || n == 0 {
804            return Err(CudaError::InvalidValue);
805        }
806        if num_tasks > n {
807            return Err(CudaError::InvalidValue);
808        }
809
810        if num_tasks <= 8 && n <= 8 {
811            self.placement_exhaustive(num_tasks, communications)
812        } else {
813            self.placement_greedy(num_tasks, communications)
814        }
815    }
816
    /// Exhaustive (permutation-based) placement for small instances.
    ///
    /// Enumerates every permutation of all `n` device indices and scores
    /// the assignment of the first `num_tasks` positions.
    /// NOTE(review): permuting all `n` devices revisits the same
    /// `num_tasks`-prefix many times when `num_tasks < n` — redundant but
    /// harmless at n <= 8; confirm before raising the size limit.
    fn placement_exhaustive(
        &self,
        num_tasks: usize,
        communications: &[TaskCommunication],
    ) -> CudaResult<TopologyAwarePlacement> {
        let n = self.devices.len();
        let mut best_cost = f64::INFINITY;
        let mut best_assignment: HashMap<usize, i32> = HashMap::new();

        // Generate permutations of device indices (up to num_tasks)
        let mut perm: Vec<usize> = (0..n).collect();
        let mut found = false;

        // Heap's algorithm for permutations
        Self::for_each_permutation(&mut perm, n, &mut |p| {
            // Task t runs on the device at permutation position t.
            let assignment: HashMap<usize, i32> =
                (0..num_tasks).map(|t| (t, self.devices[p[t]])).collect();
            let cost = Self::compute_placement_cost(
                &assignment,
                communications,
                &self.adj_bandwidth,
                &self.devices,
            );
            if cost < best_cost {
                best_cost = cost;
                best_assignment = assignment;
                found = true;
            }
        });

        if !found {
            // Fallback: identity mapping (only reachable if every candidate
            // cost was non-finite).
            best_assignment = (0..num_tasks).map(|t| (t, self.devices[t])).collect();
            best_cost = Self::compute_placement_cost(
                &best_assignment,
                communications,
                &self.adj_bandwidth,
                &self.devices,
            );
        }

        Ok(TopologyAwarePlacement {
            assignment: best_assignment,
            total_cost: best_cost,
        })
    }
864
865    /// Heap's algorithm to iterate over all permutations.
866    fn for_each_permutation(arr: &mut [usize], k: usize, callback: &mut impl FnMut(&[usize])) {
867        if k == 1 {
868            callback(arr);
869            return;
870        }
871        Self::for_each_permutation(arr, k - 1, callback);
872        for i in 0..k - 1 {
873            if k % 2 == 0 {
874                arr.swap(i, k - 1);
875            } else {
876                arr.swap(0, k - 1);
877            }
878            Self::for_each_permutation(arr, k - 1, callback);
879        }
880    }
881
882    /// Greedy placement for larger instances.
883    fn placement_greedy(
884        &self,
885        num_tasks: usize,
886        communications: &[TaskCommunication],
887    ) -> CudaResult<TopologyAwarePlacement> {
888        // Sort tasks by total communication volume (descending)
889        let mut task_volume: Vec<(usize, u64)> = (0..num_tasks).map(|t| (t, 0u64)).collect();
890        for comm in communications {
891            if comm.task_a < num_tasks {
892                task_volume[comm.task_a].1 += comm.volume_bytes;
893            }
894            if comm.task_b < num_tasks {
895                task_volume[comm.task_b].1 += comm.volume_bytes;
896            }
897        }
898        task_volume.sort_by_key(|entry| std::cmp::Reverse(entry.1));
899
900        let mut assignment: HashMap<usize, i32> = HashMap::new();
901        let mut used_devices: Vec<bool> = vec![false; self.devices.len()];
902
903        for &(task, _) in &task_volume {
904            // Find best device for this task given current assignments
905            let mut best_cost = f64::INFINITY;
906            let mut best_dev_idx = None;
907
908            for (dev_idx, &used) in used_devices.iter().enumerate() {
909                if used {
910                    continue;
911                }
912                // Compute cost of placing this task on this device
913                let mut cost = 0.0_f64;
914                for comm in communications {
915                    let (peer_task, is_relevant) = if comm.task_a == task {
916                        (comm.task_b, true)
917                    } else if comm.task_b == task {
918                        (comm.task_a, true)
919                    } else {
920                        (0, false)
921                    };
922                    if !is_relevant {
923                        continue;
924                    }
925                    if let Some(&peer_dev) = assignment.get(&peer_task) {
926                        let peer_dev_idx = match self.device_index(peer_dev) {
927                            Some(i) => i,
928                            None => continue,
929                        };
930                        let bw = self.adj_bandwidth[dev_idx][peer_dev_idx];
931                        let bw_val = if bw.is_infinite() || bw <= 0.0 {
932                            1e-9
933                        } else {
934                            bw
935                        };
936                        cost += comm.volume_bytes as f64 / bw_val;
937                    }
938                }
939                if cost < best_cost {
940                    best_cost = cost;
941                    best_dev_idx = Some(dev_idx);
942                }
943            }
944
945            let dev_idx =
946                best_dev_idx.unwrap_or_else(|| used_devices.iter().position(|&u| !u).unwrap_or(0));
947            used_devices[dev_idx] = true;
948            assignment.insert(task, self.devices[dev_idx]);
949        }
950
951        let total_cost = Self::compute_placement_cost(
952            &assignment,
953            communications,
954            &self.adj_bandwidth,
955            &self.devices,
956        );
957
958        Ok(TopologyAwarePlacement {
959            assignment,
960            total_cost,
961        })
962    }
963
964    /// Compute total communication cost for a given placement.
965    fn compute_placement_cost(
966        assignment: &HashMap<usize, i32>,
967        communications: &[TaskCommunication],
968        adj_bandwidth: &[Vec<f64>],
969        devices: &[i32],
970    ) -> f64 {
971        let dev_to_idx: HashMap<i32, usize> =
972            devices.iter().enumerate().map(|(i, &d)| (d, i)).collect();
973
974        let mut cost = 0.0_f64;
975        for comm in communications {
976            let dev_a = match assignment.get(&comm.task_a) {
977                Some(&d) => d,
978                None => continue,
979            };
980            let dev_b = match assignment.get(&comm.task_b) {
981                Some(&d) => d,
982                None => continue,
983            };
984            if dev_a == dev_b {
985                continue; // Same device, no cost
986            }
987            let idx_a = match dev_to_idx.get(&dev_a) {
988                Some(&i) => i,
989                None => continue,
990            };
991            let idx_b = match dev_to_idx.get(&dev_b) {
992                Some(&i) => i,
993                None => continue,
994            };
995            let bw = adj_bandwidth[idx_a][idx_b];
996            let bw_val = if bw.is_infinite() || bw <= 0.0 {
997                1e-9
998            } else {
999                bw
1000            };
1001            cost += comm.volume_bytes as f64 / bw_val;
1002        }
1003        cost
1004    }
1005
1006    // -- Communication schedule ---------------------------------------------
1007
1008    /// Build an optimised communication schedule for a set of transfers.
1009    ///
1010    /// Sorts transfers by bandwidth (highest-bandwidth links first) so that
1011    /// high-bandwidth transfers are scheduled early.
1012    pub fn build_schedule(&self, transfers: &[(i32, i32, u64)]) -> CommunicationSchedule {
1013        let mut entries: Vec<(f64, Transfer)> = transfers
1014            .iter()
1015            .map(|&(src, dst, size)| {
1016                let bw = self.bandwidth_between(src, dst);
1017                let bw_val = if bw.is_infinite() || bw <= 0.0 {
1018                    1e-9
1019                } else {
1020                    bw
1021                };
1022                (
1023                    bw_val,
1024                    Transfer {
1025                        src,
1026                        dst,
1027                        data_size: size,
1028                    },
1029                )
1030            })
1031            .collect();
1032
1033        // Sort by bandwidth descending (schedule high-bandwidth transfers first)
1034        entries.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
1035
1036        let mut total_time_us = 0.0_f64;
1037        let transfers_out: Vec<Transfer> = entries
1038            .into_iter()
1039            .map(|(bw, t)| {
1040                // time = size_bytes / (bw_gbps * 1e9) * 1e6 = size / (bw * 1000)
1041                total_time_us += t.data_size as f64 / (bw * 1000.0);
1042                t
1043            })
1044            .collect();
1045
1046        CommunicationSchedule {
1047            transfers: transfers_out,
1048            estimated_time_us: total_time_us,
1049        }
1050    }
1051
1052    // -- NVLink info query --------------------------------------------------
1053
1054    /// Query NVLink information for a specific device pair.
1055    pub fn nvlink_info(&self, device: i32, peer: i32) -> Option<NvLinkInfo> {
1056        let link = self
1057            .links
1058            .iter()
1059            .find(|l| l.from_device == device && l.to_device == peer)?;
1060
1061        if link.link_type == LinkType::PCIe {
1062            return None;
1063        }
1064
1065        let version = match link.link_type {
1066            LinkType::NvSwitch => NvLinkVersion::NvSwitch,
1067            _ => NvLinkVersion::V3, // Default assumption
1068        };
1069
1070        Some(NvLinkInfo {
1071            version,
1072            bandwidth_gbps: link.bandwidth_gbps,
1073            link_count: if link.bandwidth_gbps > 0.0 {
1074                (link.bandwidth_gbps / version.per_link_bandwidth_gbps()).ceil() as u32
1075            } else {
1076                0
1077            },
1078            peer_device_id: peer,
1079        })
1080    }
1081
1082    // -- Helpers ------------------------------------------------------------
1083
1084    /// Find the index of a device ordinal in `self.devices`.
1085    fn device_index(&self, device_id: i32) -> Option<usize> {
1086        self.devices.iter().position(|&d| d == device_id)
1087    }
1088
1089    /// Return the adjacency bandwidth matrix (read-only).
1090    pub fn adjacency_matrix(&self) -> &Vec<Vec<f64>> {
1091        &self.adj_bandwidth
1092    }
1093}
1094
1095// ---------------------------------------------------------------------------
1096// OrderedF64 — wrapper for use in BinaryHeap
1097// ---------------------------------------------------------------------------
1098
/// Wrapper around f64 that implements Ord for use in priority queues.
///
/// Ordering uses [`f64::total_cmp`] (IEEE-754 totalOrder), so every value —
/// including NaN — participates in a genuine total order: positive NaN sorts
/// above `+INFINITY` and all finite values. `PartialEq`/`Eq` are derived
/// from the same comparison so equality is consistent with `Ord`, as the
/// `Ord` contract requires.
#[derive(Debug, Clone, Copy)]
struct OrderedF64(f64);

impl PartialEq for OrderedF64 {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so `a == b` iff `a.cmp(&b) == Equal`. The
        // previous `#[derive(PartialEq)]` made NaN != NaN while `cmp`
        // reported Equal, violating the `Eq`/`Ord` consistency contract.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrderedF64 {}

impl PartialOrd for OrderedF64 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedF64 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `total_cmp` gives a true total order. The previous
        // `partial_cmp().unwrap_or(Equal)` made NaN compare Equal to every
        // value, which breaks Ord transitivity and contradicted the doc
        // comment's "NaN greater than all finite values" claim.
        self.0.total_cmp(&other.0)
    }
}
1120
1121// ---------------------------------------------------------------------------
1122// Tests
1123// ---------------------------------------------------------------------------
1124
#[cfg(test)]
mod tests {
    use super::*;

    // All tests run against synthetic topologies so they need no GPU
    // hardware: `synthetic_mesh(n)` builds an n-device fully-connected
    // NVLink mesh (100 GB/s, 500 ns per link, per the assertions below).

    fn make_synthetic_4gpu() -> GpuTopology {
        GpuTopology::synthetic_mesh(4).expect("synthetic mesh should not fail")
    }

    fn make_synthetic_1gpu() -> GpuTopology {
        GpuTopology::synthetic_mesh(1).expect("single GPU mesh should not fail")
    }

    #[test]
    fn test_topology_discovery_macos_synthetic() {
        let topo = GpuTopology::discover();
        // On macOS this should succeed with synthetic data
        #[cfg(target_os = "macos")]
        {
            let topo = topo.expect("discover should succeed on macOS");
            assert_eq!(topo.devices.len(), 4);
            // 4 devices * 3 peers = 12 links
            assert_eq!(topo.links.len(), 12);
        }
        // On other platforms, result depends on hardware
        #[cfg(not(target_os = "macos"))]
        {
            let _ = topo; // May fail if no GPU
        }
    }

    #[test]
    fn test_topology_type_single_gpu() {
        let topo = make_synthetic_1gpu();
        assert_eq!(topo.topology_type(), TopologyType::SingleGpu);
    }

    #[test]
    fn test_topology_type_nvlink_mesh() {
        let topo = make_synthetic_4gpu();
        assert_eq!(topo.topology_type(), TopologyType::NvLinkMesh);
    }

    #[test]
    fn test_topology_type_pcie_only() {
        let mut topo = make_synthetic_4gpu();
        // Replace all NVLink with PCIe
        for link in &mut topo.links {
            link.link_type = LinkType::PCIe;
        }
        assert_eq!(topo.topology_type(), TopologyType::PcieOnly);
    }

    #[test]
    fn test_topology_type_nvlink_pair() {
        // A full 2-GPU mesh classifies as NvLinkMesh ...
        let topo = GpuTopology::synthetic_mesh(2).expect("pair mesh");
        assert_eq!(topo.topology_type(), TopologyType::NvLinkMesh);
        // But with only 2 and partial connectivity:
        let mut partial = GpuTopology {
            devices: vec![0, 1],
            links: vec![TopologyLink {
                from_device: 0,
                to_device: 1,
                link_type: LinkType::NvLink,
                bandwidth_gbps: 50.0,
                latency_ns: 500.0,
                hop_count: 1,
            }],
            adj_bandwidth: vec![vec![f64::INFINITY, 50.0], vec![0.0, f64::INFINITY]],
            adj_latency: vec![vec![0.0, 500.0], vec![f64::MAX, 0.0]],
        };
        // Only one direction — not full mesh
        assert_eq!(partial.topology_type(), TopologyType::NvLinkPair);

        // Add reverse link to make it full mesh
        partial.links.push(TopologyLink {
            from_device: 1,
            to_device: 0,
            link_type: LinkType::NvLink,
            bandwidth_gbps: 50.0,
            latency_ns: 500.0,
            hop_count: 1,
        });
        partial.adj_bandwidth[1][0] = 50.0;
        partial.adj_latency[1][0] = 500.0;
        assert_eq!(partial.topology_type(), TopologyType::NvLinkMesh);
    }

    #[test]
    fn test_best_path_direct() {
        let topo = make_synthetic_4gpu();
        let path = topo.best_path(0, 3).expect("path should exist");
        // In a full mesh, direct path is optimal
        assert_eq!(path.len(), 2);
        assert_eq!(path[0], 0);
        assert_eq!(path[1], 3);
    }

    #[test]
    fn test_best_path_same_device() {
        // Source == destination degenerates to a single-node path.
        let topo = make_synthetic_4gpu();
        let path = topo.best_path(1, 1).expect("same device path");
        assert_eq!(path, vec![1]);
    }

    #[test]
    fn test_best_path_invalid_device() {
        // Device 99 does not exist, so path-finding must error.
        let topo = make_synthetic_4gpu();
        let result = topo.best_path(0, 99);
        assert!(result.is_err());
    }

    #[test]
    fn test_bandwidth_between() {
        // Synthetic mesh links carry 100 GB/s.
        let topo = make_synthetic_4gpu();
        let bw = topo.bandwidth_between(0, 1);
        assert!((bw - 100.0).abs() < 1e-6);

        // Same device
        let bw_self = topo.bandwidth_between(0, 0);
        assert!(bw_self.is_infinite());

        // Invalid device
        let bw_invalid = topo.bandwidth_between(0, 99);
        assert!((bw_invalid - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_latency_between() {
        let topo = make_synthetic_4gpu();
        // 500 ns = 0.5 us
        let lat = topo.latency_between(0, 1);
        assert!((lat - 0.5).abs() < 1e-6);

        // Same device: 0.0
        let lat_self = topo.latency_between(0, 0);
        assert!((lat_self - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_ring_order() {
        // The ring must visit every device exactly once.
        let topo = make_synthetic_4gpu();
        let ring = topo.optimal_ring_order().expect("ring order");
        // All devices present
        assert_eq!(ring.len(), 4);
        let mut sorted = ring.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_ring_order_single() {
        // Degenerate ring: one device is its own ring.
        let topo = make_synthetic_1gpu();
        let ring = topo.optimal_ring_order().expect("ring order");
        assert_eq!(ring, vec![0]);
    }

    #[test]
    fn test_tree_construction() {
        let topo = make_synthetic_4gpu();
        let tree = topo.optimal_tree().expect("tree");
        assert!(topo.devices.contains(&tree.root));
        // All non-root devices have a parent
        assert_eq!(tree.parent.len(), 3);
        // All devices are in the tree
        let devs = tree.devices();
        assert_eq!(devs.len(), 4);
        // Tree depth is at least 1
        assert!(tree.depth() >= 1);
    }

    #[test]
    fn test_adjacency_matrix_correctness() {
        // Diagonal is infinite (self-bandwidth), off-diagonal is the
        // synthetic mesh's uniform 100 GB/s.
        let topo = make_synthetic_4gpu();
        let adj = topo.adjacency_matrix();
        assert_eq!(adj.len(), 4);
        for (i, adj_row) in adj.iter().enumerate().take(4) {
            assert!(adj_row[i].is_infinite()); // Self-bandwidth is infinity
            for (j, adj_val) in adj_row.iter().enumerate().take(4) {
                if i != j {
                    assert!((*adj_val - 100.0).abs() < 1e-6);
                }
            }
        }
    }

    #[test]
    fn test_path_symmetry() {
        // The synthetic mesh is symmetric: bandwidth must match in both
        // directions for every device pair.
        let topo = make_synthetic_4gpu();
        for i in 0..4_i32 {
            for j in 0..4_i32 {
                if i == j {
                    continue;
                }
                let bw_ij = topo.bandwidth_between(i, j);
                let bw_ji = topo.bandwidth_between(j, i);
                assert!(
                    (bw_ij - bw_ji).abs() < 1e-6,
                    "bandwidth {i}->{j} ({bw_ij}) != {j}->{i} ({bw_ji})"
                );
            }
        }
    }

    #[test]
    fn test_communication_schedule() {
        let topo = make_synthetic_4gpu();
        let transfers = vec![(0, 1, 1_000_000u64), (2, 3, 2_000_000), (0, 3, 500_000)];
        let schedule = topo.build_schedule(&transfers);
        assert_eq!(schedule.transfers.len(), 3);
        assert!(schedule.estimated_time_us > 0.0);
    }

    #[test]
    fn test_placement_optimization() {
        // 3 tasks on 4 devices with a chain of communications 0-1-2; each
        // task must land on a distinct device.
        let topo = make_synthetic_4gpu();
        let comms = vec![
            TaskCommunication {
                task_a: 0,
                task_b: 1,
                volume_bytes: 1_000_000,
            },
            TaskCommunication {
                task_a: 1,
                task_b: 2,
                volume_bytes: 2_000_000,
            },
        ];
        let placement = topo
            .optimal_placement(3, &comms)
            .expect("placement should succeed");
        assert_eq!(placement.assignment.len(), 3);
        // All tasks assigned to different devices
        let mut devs: Vec<i32> = placement.assignment.values().copied().collect();
        devs.sort_unstable();
        devs.dedup();
        assert_eq!(devs.len(), 3);
    }

    #[test]
    fn test_nvswitch_fabric_detection() {
        let mut topo = make_synthetic_4gpu();
        for link in &mut topo.links {
            link.link_type = LinkType::NvSwitch;
        }
        assert_eq!(topo.topology_type(), TopologyType::NvSwitchFabric);
    }

    #[test]
    fn test_empty_topology_errors() {
        // A topology with zero devices is invalid input.
        let result = GpuTopology::synthetic_mesh(0);
        assert!(result.is_err());
    }

    #[test]
    fn test_nvlink_info_link_count_correct() {
        // Synthetic 4-GPU mesh has 100 GB/s links using NvLink V3 (25 GB/s per link).
        // Expected link_count = ceil(100 / 25) = 4.
        let topo = make_synthetic_4gpu();
        let info = topo
            .nvlink_info(0, 1)
            .expect("NVLink info should exist for NvLink topology");
        assert_eq!(info.peer_device_id, 1);
        assert!((info.bandwidth_gbps - 100.0).abs() < 1e-6);
        let per_link = info.version.per_link_bandwidth_gbps();
        let expected_count = (100.0_f64 / per_link).ceil() as u32;
        assert_eq!(
            info.link_count, expected_count,
            "link count should match bandwidth / per-link rate"
        );
    }

    #[test]
    fn test_nvlink_info_returns_none_for_pcie_link() {
        let mut topo = make_synthetic_4gpu();
        // Force all links to PCIe
        for link in &mut topo.links {
            link.link_type = LinkType::PCIe;
        }
        // nvlink_info returns None for PCIe links
        let info = topo.nvlink_info(0, 1);
        assert!(info.is_none(), "should return None for PCIe links");
    }

    #[test]
    fn test_nvlink_info_nonexistent_peer() {
        let topo = make_synthetic_4gpu();
        // Device 99 doesn't exist
        let info = topo.nvlink_info(0, 99);
        assert!(info.is_none(), "should return None for nonexistent peer");
    }

    #[test]
    fn test_best_path_finds_intermediate_hop() {
        // Build a chain topology: 0 -- 1 -- 2 (no direct 0->2 link)
        let bw_direct = 100.0_f64;
        let lat = 500.0_f64;
        let topo = GpuTopology {
            devices: vec![0, 1, 2],
            links: vec![
                TopologyLink {
                    from_device: 0,
                    to_device: 1,
                    link_type: LinkType::NvLink,
                    bandwidth_gbps: bw_direct,
                    latency_ns: lat,
                    hop_count: 1,
                },
                TopologyLink {
                    from_device: 1,
                    to_device: 0,
                    link_type: LinkType::NvLink,
                    bandwidth_gbps: bw_direct,
                    latency_ns: lat,
                    hop_count: 1,
                },
                TopologyLink {
                    from_device: 1,
                    to_device: 2,
                    link_type: LinkType::NvLink,
                    bandwidth_gbps: bw_direct,
                    latency_ns: lat,
                    hop_count: 1,
                },
                TopologyLink {
                    from_device: 2,
                    to_device: 1,
                    link_type: LinkType::NvLink,
                    bandwidth_gbps: bw_direct,
                    latency_ns: lat,
                    hop_count: 1,
                },
            ],
            adj_bandwidth: vec![
                vec![f64::INFINITY, bw_direct, 0.0],
                vec![bw_direct, f64::INFINITY, bw_direct],
                vec![0.0, bw_direct, f64::INFINITY],
            ],
            adj_latency: vec![
                vec![0.0, lat, f64::MAX],
                vec![lat, 0.0, lat],
                vec![f64::MAX, lat, 0.0],
            ],
        };

        // No direct 0->2 link: path must go through 1
        let path = topo.best_path(0, 2).expect("path should exist via hop");
        assert_eq!(path.len(), 3, "chain topology requires intermediate hop");
        assert_eq!(path[0], 0);
        assert_eq!(path[1], 1);
        assert_eq!(path[2], 2);
    }

    #[test]
    fn test_build_schedule_sorted_by_bandwidth_desc() {
        let topo = make_synthetic_4gpu();
        // All links have equal 100 GB/s bandwidth in synthetic mesh,
        // so verify the schedule is built for all transfers and time is positive.
        let transfers = vec![(0, 1, 1_000_000_u64), (1, 2, 500_000), (2, 3, 2_000_000)];
        let schedule = topo.build_schedule(&transfers);
        assert_eq!(
            schedule.transfers.len(),
            3,
            "all transfers should appear in schedule"
        );
        assert!(
            schedule.estimated_time_us > 0.0,
            "total time should be positive"
        );
        // The 2 MB transfer should take longer than the 0.5 MB one at the same BW
        let time_2mb = 2_000_000.0_f64 / (100.0 * 1000.0);
        let time_1mb = 1_000_000.0_f64 / (100.0 * 1000.0);
        let time_0_5mb = 500_000.0_f64 / (100.0 * 1000.0);
        let expected_total = time_2mb + time_1mb + time_0_5mb;
        assert!(
            (schedule.estimated_time_us - expected_total).abs() < 1e-6,
            "estimated time should match sum of individual transfer times"
        );
    }

    #[test]
    fn test_adjacency_matrix_chain_topology() {
        // 3-node chain: 0-1-2 with no direct 0-2 link
        let topo = GpuTopology {
            devices: vec![0, 1, 2],
            links: vec![],
            adj_bandwidth: vec![
                vec![f64::INFINITY, 50.0, 0.0],
                vec![50.0, f64::INFINITY, 50.0],
                vec![0.0, 50.0, f64::INFINITY],
            ],
            adj_latency: vec![
                vec![0.0, 500.0, f64::MAX],
                vec![500.0, 0.0, 500.0],
                vec![f64::MAX, 500.0, 0.0],
            ],
        };
        let adj = topo.adjacency_matrix();
        assert!(
            (adj[0][1] - 50.0).abs() < 1e-6,
            "0->1 bandwidth should be 50 GB/s"
        );
        assert!((adj[0][2] - 0.0).abs() < 1e-6, "0->2 has no direct link");
        assert!(adj[1][1].is_infinite(), "self-bandwidth is infinite");
    }

    #[test]
    fn test_optimal_placement_all_same_device_no_cost() {
        // With only 1 task, placement should succeed with no communication cost.
        let topo = make_synthetic_4gpu();
        let comms = vec![];
        let placement = topo
            .optimal_placement(1, &comms)
            .expect("single-task placement should succeed");
        assert_eq!(placement.assignment.len(), 1);
        assert_eq!(placement.total_cost, 0.0, "no comms means no cost");
    }

    #[test]
    fn test_topology_tree_single_device() {
        let topo = make_synthetic_1gpu();
        let tree = topo.optimal_tree().expect("single-device tree");
        assert_eq!(tree.root, 0);
        assert_eq!(tree.parent.len(), 0, "root has no parent");
        assert_eq!(tree.depth(), 1, "single-node tree has depth 1");
    }
}
1550}