use mpi::traits::*;
use std::fs;
use std::io::Write;
use std::mem;
/// Writes the CP2K input deck for a single H2O molecule to `h2o_mpi_rust.inp`
/// in the current working directory and returns that file name.
///
/// The deck requests an `ENERGY_FORCE` Quickstep run with the PBE functional,
/// DZVP-MOLOPT-GTH basis sets and GTH-PBE pseudopotentials, in an 8.0 x 8.0 x 8.0
/// non-periodic cell. An existing file with the same name is overwritten.
///
/// # Errors
/// Returns the I/O error if the file cannot be created or written.
fn create_input_file() -> Result<String, Box<dyn std::error::Error>> {
    const FILE_NAME: &str = "h2o_mpi_rust.inp";
    // Kept flush-left on purpose: the text is written to disk verbatim.
    const DECK: &str = r#"
&GLOBAL
PROJECT H2O_MPI_RUST
RUN_TYPE ENERGY_FORCE
PRINT_LEVEL MEDIUM
&END GLOBAL
&FORCE_EVAL
METHOD Quickstep
&DFT
BASIS_SET_FILE_NAME BASIS_MOLOPT
POTENTIAL_FILE_NAME GTH_POTENTIALS
&MGRID
CUTOFF 400
REL_CUTOFF 50
&END MGRID
&QS
EPS_DEFAULT 1.0E-10
&END QS
&SCF
SCF_GUESS ATOMIC
EPS_SCF 1.0E-6
MAX_SCF 50
ADDED_MOS 10
&DIAGONALIZATION
ALGORITHM STANDARD
&END DIAGONALIZATION
&SMEAR OFF
&END SMEAR
&PRINT
&RESTART OFF
&END RESTART
&END PRINT
&END SCF
&XC
&XC_FUNCTIONAL PBE
&END XC_FUNCTIONAL
&END XC
&END DFT
&SUBSYS
&CELL
ABC 8.0 8.0 8.0
PERIODIC NONE
&END CELL
&COORD
O 0.000000 0.000000 0.000000
H 0.000000 0.000000 1.000000
H 0.942809 0.000000 -0.333333
&END COORD
&KIND H
BASIS_SET DZVP-MOLOPT-GTH
POTENTIAL GTH-PBE-q1
&END KIND
&KIND O
BASIS_SET DZVP-MOLOPT-GTH
POTENTIAL GTH-PBE-q6
&END KIND
&END SUBSYS
&END FORCE_EVAL
"#;

    // `fs::write` creates (or truncates) the file and writes every byte.
    fs::write(FILE_NAME, DECK)?;
    Ok(String::from(FILE_NAME))
}
/// Width of the `=` separator rules printed in the console report.
const RULE_WIDTH: usize = 70;

/// Returns a horizontal separator made of `RULE_WIDTH` `=` characters.
fn rule() -> String {
    "=".repeat(RULE_WIDTH)
}

/// Prints the MPI/OpenMP configuration banner (called on rank 0 only).
///
/// `size` is the number of ranks in the world communicator.
fn print_configuration(size: i32) {
    println!("{}", rule());
    println!("CP2K-RS MPI Example (Rust)");
    println!("{}", rule());
    println!("MPI Configuration:");
    println!(" Total ranks: {}", size);
    match std::env::var("OMP_NUM_THREADS") {
        Ok(omp_threads) => {
            println!(" OMP_NUM_THREADS: {}", omp_threads);
            println!(
                " Total parallelism: {} ranks × {} threads",
                size, omp_threads
            );
        }
        Err(_) => println!(" OMP_NUM_THREADS: not set (using default)"),
    }
    println!("{}", rule());
}

/// Prints one ` Atom i: [x, y, z]` row per consecutive triple of `flat`.
///
/// Uses `chunks_exact(3)` so that a trailing partial triple is skipped.
/// Without it, a length that is not a multiple of 3 would panic on an
/// out-of-bounds index.
fn print_atom_rows(flat: &[f64]) {
    for (i, v) in flat.chunks_exact(3).enumerate() {
        println!(
            " Atom {}: [{:12.6}, {:12.6}, {:12.6}]",
            i, v[0], v[1], v[2]
        );
    }
}

/// Returns the largest Euclidean norm over consecutive triples of `flat`.
/// Returns `0.0` when `flat` holds no complete triple.
fn max_vector_norm(flat: &[f64]) -> f64 {
    flat.chunks_exact(3)
        .map(|v| (v[0].powi(2) + v[1].powi(2) + v[2].powi(2)).sqrt())
        .fold(0.0f64, f64::max)
}

/// Prints the closing summary (called on rank 0 only).
fn print_completion_notes(size: i32, output_file: &str) {
    println!("\n{}", rule());
    println!("Calculation complete!");
    println!("{}", rule());
    println!("\nParallel Efficiency Notes:");
    println!(
        " - The calculation was distributed across {} MPI ranks",
        size
    );
    println!(" - Each rank used OpenMP threads (OMP_NUM_THREADS)");
    println!(" - Total output files: {} (one per rank)", size);
    println!(" - Main output: {}", output_file);
}

/// Runs one CP2K energy + force calculation for a water molecule on every MPI rank.
///
/// Every rank initialises CP2K and builds a force environment over the world
/// communicator. Rank 0 writes the input deck, writes to `h2o_mpi_rust.out` and
/// prints the report. Every other rank writes to its own
/// `h2o_mpi_rust_rank{N}.out`. Fatal errors call `world.abort(1)`, which tears
/// down the whole job instead of leaving the other ranks waiting in a barrier.
fn main() {
    let universe = mpi::initialize().expect("Failed to initialize MPI");
    let world = universe.world();
    let rank = world.rank();
    let size = world.size();

    if rank == 0 {
        print_configuration(size);
    }
    world.barrier();

    if let Err(e) = cp2k_rs::init() {
        eprintln!("✗ Rank {} failed to initialize CP2K: {}", rank, e);
        world.abort(1);
    }
    if rank == 0 {
        match cp2k_rs::get_version() {
            Ok(version) => {
                println!("✓ CP2K initialized with MPI support");
                println!(" Version: {}", version);
            }
            Err(e) => eprintln!("✗ Failed to get CP2K version: {}", e),
        }
    }
    world.barrier();

    // Only rank 0 writes the deck. The barrier that follows guarantees the file
    // is complete before any rank opens it. The other ranks rely on the same
    // hard-coded name. Assumes all ranks share a filesystem (TODO confirm for
    // multi-node runs).
    let input_file = if rank == 0 {
        match create_input_file() {
            Ok(filename) => {
                println!("✓ Created input file: {}", filename);
                filename
            }
            Err(e) => {
                eprintln!("✗ Failed to create input file: {}", e);
                world.abort(1);
            }
        }
    } else {
        "h2o_mpi_rust.inp".to_string()
    };
    world.barrier();

    let output_file = if rank == 0 {
        "h2o_mpi_rust.out".to_string()
    } else {
        format!("h2o_mpi_rust_rank{}.out", rank)
    };
    let mut force_env = match cp2k_rs::ForceEnv::new_with_mpi(&input_file, &output_file, &world) {
        Ok(env) => {
            if rank == 0 {
                println!("✓ Force environment created on all {} ranks", size);
            }
            env
        }
        Err(e) => {
            eprintln!("✗ Rank {} failed to create force environment: {}", rank, e);
            world.abort(1);
        }
    };
    world.barrier();

    // Query each value separately so the real CP2K error is reported.
    // A single catch-all arm would discard it.
    let natom = match force_env.get_natom() {
        Ok(natom) => natom,
        Err(e) => {
            eprintln!("✗ Rank {} failed to get system info (natom): {}", rank, e);
            world.abort(1);
        }
    };
    let nparticle = match force_env.get_nparticle() {
        Ok(nparticle) => nparticle,
        Err(e) => {
            eprintln!("✗ Rank {} failed to get system info (nparticle): {}", rank, e);
            world.abort(1);
        }
    };
    if rank == 0 {
        println!("\nSystem Information:");
        println!(" Number of atoms: {}", natom);
        println!(" Number of particles: {}", nparticle);
    }
    world.barrier();

    if rank == 0 {
        println!("\n{}", rule());
        println!("Running parallel DFT calculation on {} ranks...", size);
        println!("{}", rule());
    }
    world.barrier();

    // Every rank takes part in the calculation.
    if let Err(e) = force_env.calc_energy_force() {
        eprintln!("✗ Rank {} calculation failed: {}", rank, e);
        world.abort(1);
    }
    if rank == 0 {
        println!("✓ Energy and forces calculated");
    }
    world.barrier();

    if rank == 0 {
        match (
            force_env.get_potential_energy(),
            force_env.get_forces(),
            force_env.get_positions(),
        ) {
            (Ok(energy), Ok(forces), Ok(positions)) => {
                println!("\nResults:");
                println!(" Total Energy: {:.10} Ha", energy);
                println!(" Forces (Ha/Bohr):");
                let forces_slice = forces
                    .as_slice()
                    .expect("force array should be stored contiguously");
                print_atom_rows(forces_slice);
                println!("\n Positions (Bohr):");
                let positions_slice = positions
                    .as_slice()
                    .expect("position array should be stored contiguously");
                print_atom_rows(positions_slice);
                println!(
                    "\n Maximum force magnitude: {:.6} Ha/Bohr",
                    max_vector_norm(forces_slice)
                );
            }
            (energy, forces, positions) => {
                // Report every getter that failed rather than one generic line.
                if let Err(e) = energy {
                    eprintln!("✗ Rank {} failed to retrieve results (energy): {}", rank, e);
                }
                if let Err(e) = forces {
                    eprintln!("✗ Rank {} failed to retrieve results (forces): {}", rank, e);
                }
                if let Err(e) = positions {
                    eprintln!("✗ Rank {} failed to retrieve results (positions): {}", rank, e);
                }
            }
        }
    }
    world.barrier();

    #[cfg(feature = "extended")]
    {
        if rank == 0 {
            // Left untyped so it adopts whatever float type cp2k_rs returns.
            let hartree_to_ev = 27.211386;
            println!("\n{}", rule());
            println!("Extended DFT properties:");
            println!("{}", rule());
            match force_env.get_homo_lumo(1) {
                Ok((homo, lumo, homo_idx, lumo_idx)) => {
                    let band_gap = lumo - homo;
                    println!(
                        " HOMO: {:.6} Ha ({:.3} eV) [MO #{}]",
                        homo,
                        homo * hartree_to_ev,
                        homo_idx
                    );
                    println!(
                        " LUMO: {:.6} Ha ({:.3} eV) [MO #{}]",
                        lumo,
                        lumo * hartree_to_ev,
                        lumo_idx
                    );
                    println!(
                        " Band gap: {:.6} Ha ({:.3} eV)",
                        band_gap,
                        band_gap * hartree_to_ev
                    );
                }
                Err(e) => {
                    println!(" Note: Could not retrieve HOMO/LUMO: {}", e);
                }
            }
            match force_env.get_scf_info() {
                Ok((nsteps, converged, energy_change)) => {
                    println!("\n SCF Convergence:");
                    println!(" Converged: {}", if converged { "Yes" } else { "No" });
                    println!(" Steps: {}", nsteps);
                    println!(" Final energy change: {:.2e} Ha", energy_change);
                }
                Err(e) => {
                    println!(" Note: Could not retrieve SCF info: {}", e);
                }
            }
        }
    }
    world.barrier();

    if rank == 0 {
        print_completion_notes(size, &output_file);
    }
    world.barrier();

    // Release the force environment before the CP2K library is finalised.
    drop(force_env);
    if rank == 0 {
        println!("\n✓ Force environment cleaned up");
    }
    world.barrier();
    if let Err(e) = cp2k_rs::finalize() {
        eprintln!("✗ Rank {} failed to finalize CP2K: {}", rank, e);
        world.abort(1);
    }
    if rank == 0 {
        println!("✓ CP2K finalized successfully");
        println!("\nAll ranks completed successfully!");
        println!("{}", rule());
    }

    // SAFETY: MPI is finalised exactly once, here, after `force_env` and CP2K have
    // been torn down, and no MPI call is issued afterwards. rsmpi's `Universe`
    // finalises MPI in its destructor, so it is forgotten to avoid a second
    // MPI_Finalize.
    // NOTE(review): `world` is still alive until the end of `main`. Confirm its
    // destructor makes no MPI call after finalisation.
    unsafe { mpi::ffi::MPI_Finalize() };
    mem::forget(universe);
}