/// Fallback entry point used when the crate is built with the `complex` feature.
///
/// The demo functions in this file are compiled only without `complex`, so this
/// build just prints a notice to stderr and exits normally.
#[cfg(feature = "complex")]
fn main() {
    let notice = "mpi_parallel_demo.rs is unavailable when built with --features complex";
    eprintln!("{}", notice);
}
use kryst::parallel::Comm;
use kryst::*;
#[cfg(not(feature = "complex"))]
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("Kryst MPI-Parallel Communication Demo");
println!("=====================================");
example_basic_comm_operations()?;
example_parallel_inner_products()?;
example_communicator_splitting()?;
Ok(())
}
/// Shows the basic communicator API: rank/size queries, a global sum, and a barrier.
///
/// The backend is chosen at compile time:
/// - `MpiComm` with the `mpi` feature,
/// - `RayonComm` with `rayon` (and not `mpi`),
/// - `NoComm` otherwise.
///
/// Every rank contributes `1.0` to the sum. The printed "Expected" line is
/// `1.0 * size` so the reader can compare it with the reduced value by eye.
#[cfg(not(feature = "complex"))]
fn example_basic_comm_operations() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n1. Basic Communicator Operations");
    println!("---------------------------------");

    // Exactly one of these bindings survives conditional compilation.
    #[cfg(feature = "mpi")]
    let comm = parallel::MpiComm::new();
    #[cfg(all(feature = "rayon", not(feature = "mpi")))]
    let comm = parallel::RayonComm::new();
    #[cfg(not(any(feature = "mpi", feature = "rayon")))]
    let comm = parallel::NoComm;

    let (rank, size) = (comm.rank(), comm.size());
    println!("Communicator rank: {}", rank);
    println!("Communicator size: {}", size);

    let contribution = 1.0;
    let reduced = comm.all_reduce_f64(contribution);
    println!("Local value: {}, Global sum: {}", contribution, reduced);
    println!(
        "Expected: {} (local_value * comm_size)",
        contribution * size as f64
    );

    comm.barrier();
    println!("✓ Barrier synchronization completed");
    Ok(())
}
/// Demonstrates a global dot product.
///
/// Each rank computes the dot product of two constant vectors locally, and the
/// communicator sums those partial results with `all_reduce_f64`.
///
/// The expected per-process and global values are derived from the vectors
/// themselves rather than hard-coded, so they stay correct if `n`, `x` or `y`
/// change. The success marker is printed only when the reduced value actually
/// matches the expectation; otherwise a mismatch line is printed instead.
///
/// # Errors
///
/// Currently never returns `Err`. The `Result` return type matches the other
/// demos so `main` can chain them with `?`.
#[cfg(not(feature = "complex"))]
fn example_parallel_inner_products() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n2. Parallel Inner Products");
    println!("---------------------------");
    let n = 5;
    let x = vec![1.0; n];
    let y = vec![2.0; n];
    #[cfg(not(any(feature = "mpi", feature = "rayon")))]
    let comm = parallel::NoComm;
    #[cfg(all(feature = "rayon", not(feature = "mpi")))]
    let comm = parallel::RayonComm::new();
    #[cfg(feature = "mpi")]
    let comm = parallel::MpiComm::new();

    // Every rank builds identical `x` and `y`, so each local dot product is the
    // same and the reduced sum should be `size` times it.
    let local_dot: f64 = x.iter().zip(y.iter()).map(|(xi, yi)| xi * yi).sum();
    let expected_global = local_dot * comm.size() as f64;

    println!("Vector x: {:?}", x);
    println!("Vector y: {:?}", y);
    println!("Expected dot product: {} (per process)", local_dot);
    println!(
        "Expected global dot product: {} (total across {} processes)",
        expected_global,
        comm.size()
    );

    let global_dot = comm.all_reduce_f64(local_dot);
    println!("Manual calculation:");
    println!(" Local dot product: {}", local_dot);
    println!(" Global dot product: {}", global_dot);

    // Use a relative tolerance rather than `==`: a floating-point sum can depend
    // on the order in which a backend combines the partial results.
    let tolerance = 1e-12 * expected_global.abs().max(1.0);
    if (global_dot - expected_global).abs() <= tolerance {
        println!("✓ Parallel reduction working correctly");
    } else {
        println!(
            "✗ Parallel reduction mismatch: expected {}, got {}",
            expected_global, global_dot
        );
    }
    Ok(())
}
/// Demonstrates splitting a communicator into groups by rank parity.
///
/// The main communicator is wrapped in the `UniverseComm` variant matching the
/// enabled backend feature. It is split with `color = rank % 2` and
/// `key = rank`. NOTE(review): `key` presumably orders ranks within each group,
/// as in `MPI_Comm_split`; confirm against kryst's `split`.
///
/// Each rank then contributes `rank + 1`, and the sum over its sub-communicator
/// is printed next to the sum over the main communicator.
#[cfg(not(feature = "complex"))]
fn example_communicator_splitting() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n3. Communicator Splitting");
    println!("--------------------------");

    // Exactly one of these bindings survives conditional compilation.
    #[cfg(feature = "mpi")]
    let main_comm =
        parallel::UniverseComm::Mpi(std::sync::Arc::new(parallel::MpiComm::new()));
    #[cfg(all(feature = "rayon", not(feature = "mpi")))]
    let main_comm = parallel::UniverseComm::Rayon(parallel::RayonComm::new());
    #[cfg(not(any(feature = "mpi", feature = "rayon")))]
    let main_comm = parallel::UniverseComm::Serial;

    let my_rank = main_comm.rank();
    println!(
        "Main communicator - Rank: {}, Size: {}",
        my_rank,
        main_comm.size()
    );

    // `as` binds tighter than `%`, so this is (rank as i32) % 2:
    // even ranks get color 0 and odd ranks get color 1.
    let color = my_rank as i32 % 2;
    let key = my_rank as i32;
    let sub_comm = main_comm.split(color, key);
    println!(
        "Sub-communicator (color {}) - Rank: {}, Size: {}",
        color,
        sub_comm.rank(),
        sub_comm.size()
    );

    // Rank r contributes r + 1, so rank 0 still adds a nonzero value.
    // Both reductions run in the same order on every rank.
    let contribution = my_rank as f64 + 1.0;
    let sub_sum = sub_comm.all_reduce_f64(contribution);
    let main_sum = main_comm.all_reduce_f64(contribution);

    println!("Local value: {}", contribution);
    println!("Sub-communicator sum: {}", sub_sum);
    println!("Main communicator sum: {}", main_sum);
    println!("✓ Communicator splitting demonstration completed");
    Ok(())
}