use std::fmt;
use crate::{Communicator, Error, Mpi, Result, ThreadLevel};
/// Summary of an MPI job: library and standard versions, thread support
/// level, process count and the placement of ranks on hosts.
///
/// Instances are built by `gather_topology`; the fields are private and are
/// read through the accessor methods on this type. `Debug` and `Clone` are
/// derived so the type is consistent with [`HostEntry`] and can be logged or
/// duplicated by callers.
#[derive(Debug, Clone)]
pub struct TopologyInfo {
    /// MPI library version string.
    library_version: String,
    /// MPI standard version string.
    standard_version: String,
    /// Thread support level reported by the MPI handle.
    thread_level: ThreadLevel,
    /// Number of processes in the communicator the topology describes.
    size: i32,
    /// Hosts in the job, each with the ranks placed on it.
    hosts: Vec<HostEntry>,
    /// SLURM job details; `None` when no SLURM job was detected.
    #[cfg(feature = "numa")]
    slurm: Option<SlurmInfo>,
}
/// One host of the job together with the ranks placed on it.
#[derive(Debug, Clone)]
pub struct HostEntry {
/// Name of the host, as reported by the ranks listed in `ranks`.
pub hostname: String,
/// Ranks that reported `hostname` as their processor name.
pub ranks: Vec<i32>,
}
/// SLURM job details attached to a `TopologyInfo`.
///
/// Only compiled when the `numa` feature is enabled.
#[cfg(feature = "numa")]
#[derive(Debug, Clone)]
pub struct SlurmInfo {
/// Job identifier. `gather_topology` stores an empty string here when
/// `crate::slurm::job_id()` yields no value.
pub job_id: String,
/// Node list as returned by `crate::slurm::node_list()`, if any.
pub node_list: Option<String>,
/// CPUs per task as returned by `crate::slurm::cpus_per_task()`, if any.
pub cpus_per_task: Option<i32>,
}
impl TopologyInfo {
    /// Hosts in the job, each with the ranks placed on it.
    pub fn hosts(&self) -> &[HostEntry] {
        self.hosts.as_slice()
    }

    /// MPI library version string.
    pub fn library_version(&self) -> &str {
        self.library_version.as_str()
    }

    /// MPI standard version string.
    pub fn standard_version(&self) -> &str {
        self.standard_version.as_str()
    }

    /// Thread support level recorded for this topology.
    pub fn thread_level(&self) -> ThreadLevel {
        self.thread_level
    }

    /// Number of processes recorded for this topology.
    pub fn size(&self) -> i32 {
        self.size
    }

    /// Number of distinct host entries.
    pub fn num_hosts(&self) -> usize {
        self.hosts().len()
    }

    /// SLURM job details, or `None` when none were recorded.
    #[cfg(feature = "numa")]
    pub fn slurm(&self) -> Option<&SlurmInfo> {
        self.slurm.as_ref()
    }
}
/// Width in bytes of the fixed, zero-padded slot each rank contributes when
/// hostnames are exchanged in `gather_topology`; longer names are truncated
/// to fit.
const HOSTNAME_BUF_LEN: usize = 256;
pub(crate) fn gather_topology(comm: &Communicator, mpi: &Mpi) -> Result<TopologyInfo> {
let size = comm.size();
let rank = comm.rank();
let name = comm.processor_name()?;
let mut local_buf = [0u8; HOSTNAME_BUF_LEN];
let name_bytes = name.as_bytes();
let copy_len = name_bytes.len().min(HOSTNAME_BUF_LEN);
local_buf[..copy_len].copy_from_slice(&name_bytes[..copy_len]);
let mut all_bufs = vec![0u8; HOSTNAME_BUF_LEN * size as usize];
comm.allgather(&local_buf, &mut all_bufs)?;
let mut host_map: Vec<(String, Vec<i32>)> = Vec::new();
for r in 0..size {
let start = r as usize * HOSTNAME_BUF_LEN;
let end = start + HOSTNAME_BUF_LEN;
let raw = &all_bufs[start..end];
let nul_pos = raw.iter().position(|&b| b == 0).unwrap_or(HOSTNAME_BUF_LEN);
let hostname = std::str::from_utf8(&raw[..nul_pos])
.map_err(|_| Error::Internal("Invalid UTF-8 in gathered hostname".into()))?
.to_string();
if let Some(entry) = host_map.iter_mut().find(|(h, _)| *h == hostname) {
entry.1.push(r);
} else {
host_map.push((hostname, vec![r]));
}
}
let hosts: Vec<HostEntry> = host_map
.into_iter()
.map(|(hostname, ranks)| HostEntry { hostname, ranks })
.collect();
let library_version = if rank == 0 {
Mpi::library_version()?
} else {
String::new()
};
let standard_version = if rank == 0 {
Mpi::version()?
} else {
String::new()
};
let library_version = broadcast_string(comm, &library_version, 0)?;
let standard_version = broadcast_string(comm, &standard_version, 0)?;
let thread_level = mpi.thread_level();
#[cfg(feature = "numa")]
let slurm = if crate::slurm::is_slurm_job() {
Some(SlurmInfo {
job_id: crate::slurm::job_id().unwrap_or_default(),
node_list: crate::slurm::node_list(),
cpus_per_task: crate::slurm::cpus_per_task(),
})
} else {
None
};
Ok(TopologyInfo {
library_version,
standard_version,
thread_level,
size,
hosts,
#[cfg(feature = "numa")]
slurm,
})
}
/// Sends `s` from `root` to every rank of `comm` and returns the received text.
///
/// The string travels in a fixed 512-byte, zero-padded buffer. Only the value
/// of `s` on `root` is placed in the buffer; other ranks pass in a value that
/// is ignored. A string longer than the buffer is truncated to the longest
/// prefix that fits and ends on a UTF-8 character boundary, so the received
/// bytes always decode. The result ends at the first zero byte in the buffer.
///
/// # Errors
///
/// Propagates a failure from `comm.broadcast`, and returns `Error::Internal`
/// when the received bytes are not valid UTF-8.
fn broadcast_string(comm: &Communicator, s: &str, root: i32) -> Result<String> {
    const BUF_LEN: usize = 512;
    let mut buf = [0u8; BUF_LEN];
    if comm.rank() == root {
        // Never cut through a multi-byte character: a split sequence would
        // make the UTF-8 check below fail on every rank.
        let mut copy_len = s.len().min(BUF_LEN);
        while !s.is_char_boundary(copy_len) {
            copy_len -= 1;
        }
        buf[..copy_len].copy_from_slice(&s.as_bytes()[..copy_len]);
    }
    comm.broadcast(&mut buf, root)?;
    // No zero byte means the text fills the whole buffer.
    let nul_pos = buf.iter().position(|&b| b == 0).unwrap_or(BUF_LEN);
    let text = std::str::from_utf8(&buf[..nul_pos])
        .map_err(|_| Error::Internal("Invalid UTF-8 in broadcast string".into()))?;
    Ok(text.to_string())
}
impl fmt::Display for TopologyInfo {
    /// Writes a multi-line report: a banner, the library and standard
    /// versions, the thread level, the process and node counts, optional
    /// SLURM details, a blank line, one line per host, and a closing rule
    /// without a trailing newline.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let host_count = self.hosts.len();
        let node_word = match host_count {
            1 => "node",
            _ => "nodes",
        };

        f.write_str("================ MPI Topology ================\n")?;
        writeln!(f, "Library: {}", self.library_version)?;
        writeln!(f, "Standard: {}", self.standard_version)?;
        writeln!(f, "Threads: {:?}", self.thread_level)?;
        writeln!(
            f,
            "Processes: {} across {} {}",
            self.size, host_count, node_word
        )?;

        #[cfg(feature = "numa")]
        if let Some(job) = self.slurm.as_ref() {
            writeln!(f, "SLURM Job: {}", job.job_id)?;
            if let Some(nodes) = job.node_list.as_ref() {
                writeln!(f, "Nodes: {}", nodes)?;
            }
            if let Some(cpus) = job.cpus_per_task {
                writeln!(f, "CPUs/Task: {}", cpus)?;
            }
        }

        // Blank separator between the summary and the per-host listing.
        f.write_str("\n")?;

        for entry in self.hosts.iter() {
            let rank_count = entry.ranks.len();
            let proc_word = match rank_count {
                1 => "process",
                _ => "processes",
            };
            let rank_list = entry
                .ranks
                .iter()
                .map(ToString::to_string)
                .collect::<Vec<_>>()
                .join(", ");
            writeln!(
                f,
                " {}: ranks {} ({} {})",
                entry.hostname, rank_list, rank_count, proc_word
            )?;
        }

        // The closing rule deliberately has no trailing newline.
        f.write_str("==============================================")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a [`HostEntry`] with the given name and ranks.
    fn host(hostname: &str, ranks: &[i32]) -> HostEntry {
        HostEntry {
            hostname: hostname.to_owned(),
            ranks: ranks.to_vec(),
        }
    }

    /// Builds a topology that reports "MPI 4.0" and carries no SLURM details.
    fn topology(
        library_version: &str,
        thread_level: ThreadLevel,
        size: i32,
        hosts: Vec<HostEntry>,
    ) -> TopologyInfo {
        TopologyInfo {
            library_version: library_version.to_owned(),
            standard_version: "MPI 4.0".to_owned(),
            thread_level,
            size,
            hosts,
            #[cfg(feature = "numa")]
            slurm: None,
        }
    }

    /// Two hosts with four consecutive ranks each.
    fn two_node_hosts() -> Vec<HostEntry> {
        vec![
            host("compute-01", &[0, 1, 2, 3]),
            host("compute-02", &[4, 5, 6, 7]),
        ]
    }

    /// Eight ranks spread over two hosts, used by most tests.
    fn sample_topology() -> TopologyInfo {
        topology(
            "Open MPI v4.1.6",
            ThreadLevel::Funneled,
            8,
            two_node_hosts(),
        )
    }

    #[test]
    fn display_contains_library_version() {
        let rendered = sample_topology().to_string();
        assert!(rendered.contains("Open MPI v4.1.6"));
    }

    #[test]
    fn display_contains_standard_version() {
        let rendered = sample_topology().to_string();
        assert!(rendered.contains("MPI 4.0"));
    }

    #[test]
    fn display_contains_thread_level() {
        let rendered = sample_topology().to_string();
        assert!(rendered.contains("Funneled"));
    }

    #[test]
    fn display_contains_process_count() {
        let rendered = sample_topology().to_string();
        assert!(rendered.contains("8 across 2 nodes"));
    }

    #[test]
    fn display_contains_host_entries() {
        let rendered = sample_topology().to_string();
        for expected in [
            "compute-01: ranks 0, 1, 2, 3 (4 processes)",
            "compute-02: ranks 4, 5, 6, 7 (4 processes)",
        ]
        .iter()
        {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn display_single_node() {
        let topo = topology(
            "MPICH v4.1",
            ThreadLevel::Single,
            4,
            vec![host("localhost", &[0, 1, 2, 3])],
        );
        let rendered = topo.to_string();
        assert!(rendered.contains("4 across 1 node"));
        assert!(!rendered.contains("nodes"));
    }

    #[test]
    fn display_single_process() {
        let topo = topology(
            "MPICH v4.1",
            ThreadLevel::Single,
            1,
            vec![host("localhost", &[0])],
        );
        let rendered = topo.to_string();
        assert!(rendered.contains("1 process)"));
        assert!(!rendered.contains("processes"));
    }

    #[test]
    fn accessors_return_expected_values() {
        let topo = sample_topology();
        assert_eq!(topo.library_version(), "Open MPI v4.1.6");
        assert_eq!(topo.standard_version(), "MPI 4.0");
        assert_eq!(topo.thread_level(), ThreadLevel::Funneled);
        assert_eq!(topo.size(), 8);
        assert_eq!(topo.num_hosts(), 2);

        let hosts = topo.hosts();
        assert_eq!(hosts.len(), 2);
        assert_eq!(hosts[0].hostname, "compute-01");
        assert_eq!(hosts[0].ranks, vec![0, 1, 2, 3]);
    }

    #[cfg(feature = "numa")]
    #[test]
    fn display_with_slurm_info() {
        let topo = TopologyInfo {
            slurm: Some(SlurmInfo {
                job_id: "123456".to_owned(),
                node_list: Some("compute-[01-02]".to_owned()),
                cpus_per_task: Some(4),
            }),
            ..topology(
                "Open MPI v4.1.6",
                ThreadLevel::Multiple,
                8,
                two_node_hosts(),
            )
        };
        let rendered = topo.to_string();
        assert!(rendered.contains("SLURM Job: 123456"));
        assert!(rendered.contains("Nodes: compute-[01-02]"));
        assert!(rendered.contains("CPUs/Task: 4"));
    }

    #[cfg(feature = "numa")]
    #[test]
    fn slurm_accessor_none_when_absent() {
        assert!(sample_topology().slurm().is_none());
    }
}