use ferrompi::Mpi;
/// Element types exercised by the point-to-point tests.
///
/// Every implementation derives values deterministically from a rank (and, for buffers, an
/// element index), so a receiver can recompute exactly what its partner sent.
trait TestValue: Copy + PartialEq + std::fmt::Debug {
    /// Short human-readable type name used in assertion and `PASS` messages.
    fn type_name() -> &'static str;
    /// A single scalar value derived from `rank`.
    fn from_rank(rank: i32) -> Self;
    /// The `index`-th element of the buffer sent by `rank`.
    fn from_rank_indexed(rank: i32, index: usize) -> Self;
}
impl TestValue for i32 {
    fn type_name() -> &'static str {
        "i32"
    }

    /// Per-rank scalar: `11 * rank + 3`.
    fn from_rank(rank: i32) -> Self {
        3 + 11 * rank
    }

    /// Element `index` of the buffer sent by `rank`: `100 * rank + index`.
    fn from_rank_indexed(rank: i32, index: usize) -> Self {
        let offset = index as i32;
        offset + 100 * rank
    }
}
impl TestValue for f32 {
    fn type_name() -> &'static str {
        "f32"
    }

    /// Per-rank scalar: `7.5 * rank + 1`.
    fn from_rank(rank: i32) -> Self {
        let r = rank as f32;
        1.0 + 7.5 * r
    }

    /// Element `index` of the buffer sent by `rank`: `100 * rank + index`.
    fn from_rank_indexed(rank: i32, index: usize) -> Self {
        (index as f32) + 100.0 * (rank as f32)
    }
}
impl TestValue for u8 {
    fn type_name() -> &'static str {
        "u8"
    }

    /// Per-rank scalar: `13 * rank + 5`, truncated to its low 8 bits by the `as` cast.
    fn from_rank(rank: i32) -> Self {
        (13 * rank + 5) as u8
    }

    /// Element `index` of the buffer sent by `rank`: `(10 * rank + index) mod 256`.
    fn from_rank_indexed(rank: i32, index: usize) -> Self {
        (10 * rank + index as i32).rem_euclid(256) as u8
    }
}
impl TestValue for u32 {
    fn type_name() -> &'static str {
        "u32"
    }

    /// Per-rank scalar: `17 * rank + 7`.
    fn from_rank(rank: i32) -> Self {
        (7 + 17 * rank) as u32
    }

    /// Element `index` of the buffer sent by `rank`: `1000 * rank + index`.
    fn from_rank_indexed(rank: i32, index: usize) -> Self {
        let offset = index as i32;
        (offset + 1000 * rank) as u32
    }
}
impl TestValue for u64 {
    fn type_name() -> &'static str {
        "u64"
    }

    /// Per-rank scalar: `1_000_001 * rank + 42`.
    fn from_rank(rank: i32) -> Self {
        let r = rank as u64;
        42 + 1_000_001 * r
    }

    /// Element `index` of the buffer sent by `rank`: `10_000 * rank + index`.
    fn from_rank_indexed(rank: i32, index: usize) -> Self {
        let r = rank as u64;
        index as u64 + 10_000 * r
    }
}
impl TestValue for i64 {
    fn type_name() -> &'static str {
        "i64"
    }

    /// Per-rank scalar: `99 - 500_003 * rank` (negative for every rank >= 1).
    fn from_rank(rank: i32) -> Self {
        99 - 500_003 * i64::from(rank)
    }

    /// Element `index` of the buffer sent by `rank`: `10_000 * rank + index`.
    fn from_rank_indexed(rank: i32, index: usize) -> Self {
        index as i64 + 10_000 * i64::from(rank)
    }
}
/// Asserts that `recv_data[i] == T::from_rank_indexed(partner, i)` for every element.
///
/// `rank` (the receiving rank) and `op_name` only appear in the panic message. The message
/// is prefixed with the element type name, which identifies the failing case.
fn verify_data<T: TestValue>(recv_data: &[T], partner: i32, rank: i32, op_name: &str) {
    recv_data
        .iter()
        .copied()
        .enumerate()
        .for_each(|(idx, got)| {
            let expected = T::from_rank_indexed(partner, idx);
            assert_eq!(
                got,
                expected,
                "{}: {op_name} data mismatch at index {idx} on rank {rank}",
                T::type_name()
            );
        });
}
/// Blocking `send`/`recv` ping-pong between neighbouring rank pairs.
///
/// Ranks are paired as `(0, 1)`, `(2, 3)`, …. The even member of a pair sends first and then
/// receives; the odd member receives first and then replies, so the blocking calls on both
/// sides always match up. When `size` is odd, the highest rank has no partner and skips the
/// exchange. Wrapping it around to rank 0 would deadlock, because rank 0 only ever receives
/// from rank 1.
///
/// # Panics
///
/// Panics if any MPI call fails, or if the received source, tag, element count or payload
/// differs from what the partner sent.
fn test_send_recv<T: ferrompi::MpiDatatype + TestValue>(
    world: &ferrompi::Communicator,
    rank: i32,
    size: i32,
    tag_base: i32,
) {
    let tag = tag_base;
    let buf_len = 4;
    let sends_first = rank % 2 == 0;
    let partner = if sends_first { rank + 1 } else { rank - 1 };
    if partner >= size {
        // Odd-sized communicator: the last (even) rank has nobody to exchange with.
        return;
    }

    let send_data: Vec<T> = (0..buf_len)
        .map(|i| T::from_rank_indexed(rank, i))
        .collect();
    // Pre-fill with the partner's values shifted by `buf_len` indices. For every `TestValue`
    // impl in this file these differ from the expected payload at each index, so an element
    // the receive failed to write is caught by `verify_data` instead of matching by accident.
    let mut recv_data: Vec<T> = (0..buf_len)
        .map(|i| T::from_rank_indexed(partner, i + buf_len))
        .collect();

    if sends_first {
        world
            .send(&send_data, partner, tag)
            .expect("send failed in send/recv test");
    }

    let (src, actual_tag, count) = world
        .recv(&mut recv_data, partner, tag)
        .expect("recv failed in send/recv test");
    assert_eq!(
        src,
        partner,
        "{}: send/recv source mismatch: got {src}, expected {partner}",
        T::type_name()
    );
    assert_eq!(
        actual_tag,
        tag,
        "{}: send/recv tag mismatch",
        T::type_name()
    );
    assert_eq!(
        count,
        buf_len as i64,
        "{}: send/recv count mismatch",
        T::type_name()
    );
    verify_data(&recv_data, partner, rank, "send/recv");

    if !sends_first {
        // Reply only after receiving, mirroring the even rank's send-then-receive order.
        world
            .send(&send_data, partner, tag)
            .expect("send failed in send/recv test");
    }
}
/// Ring shift with a single combined `sendrecv` call.
///
/// Every rank sends its buffer to `(rank + 1) % size` and receives from
/// `(rank + size - 1) % size`. The shift is done in one call, so no send/receive ordering is
/// needed between ranks.
///
/// # Panics
///
/// Panics if the call fails, or if the reported source, tag, element count or the received
/// payload differs from what the previous rank sent.
fn test_sendrecv<T: ferrompi::MpiDatatype + TestValue>(
    world: &ferrompi::Communicator,
    rank: i32,
    size: i32,
    tag_base: i32,
) {
    let next = (rank + 1) % size;
    let prev = (rank + size - 1) % size;
    let tag = tag_base;
    let buf_len = 5;
    let send_buf: Vec<T> = (0..buf_len)
        .map(|i| T::from_rank_indexed(rank, i))
        .collect();
    // Pre-fill with `prev`'s values shifted by `buf_len` indices. These differ from the
    // expected payload at each index, so an unwritten slot cannot pass `verify_data`. A fixed
    // `T::from_rank(0)` filler coincides with the data when `prev == 0` (e.g. i32 index 3).
    let mut recv_buf: Vec<T> = (0..buf_len)
        .map(|i| T::from_rank_indexed(prev, i + buf_len))
        .collect();
    let (source, actual_tag, count) = world
        .sendrecv(&send_buf, next, tag, &mut recv_buf, prev, tag)
        .expect("sendrecv failed");
    assert_eq!(
        source,
        prev,
        "{}: sendrecv source mismatch: got {source}, expected {prev}",
        T::type_name()
    );
    assert_eq!(actual_tag, tag, "{}: sendrecv tag mismatch", T::type_name());
    assert_eq!(
        count,
        buf_len as i64,
        "{}: sendrecv count mismatch",
        T::type_name()
    );
    verify_data(&recv_buf, prev, rank, "sendrecv");
}
/// Ring shift with non-blocking `isend`/`irecv`.
///
/// The receive from `(rank + size - 1) % size` is posted before the send to
/// `(rank + 1) % size`. Both requests are then waited on before the payload is checked.
///
/// # Panics
///
/// Panics if any call or wait fails, or if the received payload differs from what the
/// previous rank sent.
fn test_isend_irecv<T: ferrompi::MpiDatatype + TestValue>(
    world: &ferrompi::Communicator,
    rank: i32,
    size: i32,
    tag_base: i32,
) {
    let next = (rank + 1) % size;
    let prev = (rank + size - 1) % size;
    let tag = tag_base;
    let buf_len = 6;
    let send_data: Vec<T> = (0..buf_len)
        .map(|i| T::from_rank_indexed(rank, i))
        .collect();
    // Pre-fill with `prev`'s values shifted by `buf_len` indices. These differ from the
    // expected payload at each index, so an unwritten slot cannot pass `verify_data`. A fixed
    // `T::from_rank(0)` filler coincides with the data when `prev == 0` (e.g. u8 index 5).
    let mut recv_data: Vec<T> = (0..buf_len)
        .map(|i| T::from_rank_indexed(prev, i + buf_len))
        .collect();
    let recv_req = world
        .irecv(&mut recv_data, prev, tag)
        .expect("irecv failed");
    let send_req = world.isend(&send_data, next, tag).expect("isend failed");
    send_req.wait().expect("isend wait failed");
    recv_req.wait().expect("irecv wait failed");
    verify_data(&recv_data, prev, rank, "isend/irecv");
}
fn run_type_tests<T: ferrompi::MpiDatatype + TestValue>(
world: &ferrompi::Communicator,
rank: i32,
size: i32,
tag_offset: i32,
test_counter: &mut i32,
) {
test_send_recv::<T>(world, rank, size, tag_offset + 10);
world.barrier().expect("barrier after send/recv failed");
*test_counter += 1;
if rank == 0 {
println!("PASS: send/recv <{}>", T::type_name());
}
test_sendrecv::<T>(world, rank, size, tag_offset + 20);
world.barrier().expect("barrier after sendrecv failed");
*test_counter += 1;
if rank == 0 {
println!("PASS: sendrecv <{}>", T::type_name());
}
test_isend_irecv::<T>(world, rank, size, tag_offset + 30);
world.barrier().expect("barrier after isend/irecv failed");
*test_counter += 1;
if rank == 0 {
println!("PASS: isend/irecv <{}>", T::type_name());
}
}
fn main() {
let mpi = Mpi::init().expect("MPI init failed");
let world = mpi.world();
let rank = world.rank();
let size = world.size();
assert!(
size >= 2,
"test_p2p_extra requires at least 2 processes, got {size}"
);
let mut test_count: i32 = 0;
run_type_tests::<i32>(&world, rank, size, 1000, &mut test_count);
run_type_tests::<f32>(&world, rank, size, 1100, &mut test_count);
run_type_tests::<u8>(&world, rank, size, 1200, &mut test_count);
run_type_tests::<u32>(&world, rank, size, 1300, &mut test_count);
run_type_tests::<u64>(&world, rank, size, 1400, &mut test_count);
run_type_tests::<i64>(&world, rank, size, 1500, &mut test_count);
{
let tag = 2000;
if rank == 0 && size >= 2 {
let data = vec![42i32; 7];
world.send(&data, 1, tag).expect("probe test: send failed");
} else if rank == 1 {
let status = world.probe::<i32>(0, tag).expect("probe::<i32> failed");
assert_eq!(status.source, 0, "probe::<i32> source mismatch");
assert_eq!(status.tag, tag, "probe::<i32> tag mismatch");
assert_eq!(status.count, 7, "probe::<i32> count mismatch");
let mut buf = vec![0i32; status.count as usize];
world
.recv(&mut buf, 0, tag)
.expect("recv after probe::<i32> failed");
assert!(
buf.iter().all(|&v| v == 42),
"probe::<i32> + recv data mismatch"
);
}
world.barrier().expect("barrier after probe::<i32> failed");
test_count += 1;
if rank == 0 {
println!("PASS: probe::<i32>");
}
}
{
let tag = 2100;
if rank == 0 && size >= 2 {
let data = vec![99i32; 3];
world.send(&data, 1, tag).expect("iprobe test: send failed");
} else if rank == 1 {
let mut status = None;
for _ in 0..100_000 {
if let Some(s) = world.iprobe::<i32>(0, tag).expect("iprobe::<i32> failed") {
status = Some(s);
break;
}
std::hint::spin_loop();
}
let status = status.expect("iprobe::<i32>: message never arrived");
assert_eq!(status.source, 0, "iprobe::<i32> source mismatch");
assert_eq!(status.tag, tag, "iprobe::<i32> tag mismatch");
assert_eq!(status.count, 3, "iprobe::<i32> count mismatch");
let mut buf = vec![0i32; status.count as usize];
world
.recv(&mut buf, 0, tag)
.expect("recv after iprobe::<i32> failed");
assert!(
buf.iter().all(|&v| v == 99),
"iprobe::<i32> + recv data mismatch"
);
}
world.barrier().expect("barrier after iprobe::<i32> failed");
test_count += 1;
if rank == 0 {
println!("PASS: iprobe::<i32>");
}
}
world.barrier().expect("final barrier failed");
if rank == 0 {
println!("\n========================================");
println!("All multi-type P2P tests passed! ({test_count} tests)");
println!("========================================");
}
}