cp2k-rs 0.2.0

Rust bindings for CP2K with Python interface
Documentation
//! MPI-parallel CP2K example in Rust
//!
//! This example demonstrates how to run CP2K calculations in parallel using MPI
//! from Rust. It shows proper initialization, communicator management, and
//! finalization.
//!
//! # Usage
//!
//! ```bash
//! # Build with MPI and extended support
//! module load mpi/openmpi-x86_64
//! cargo build --release --example rust_mpi_example --features mpi,extended,build-cp2k
//!
//! # Run with mpirun - pass CP2K_DATA_DIR to all ranks with -x
//! export CP2K_DATA_DIR=$(find target/release/build -name "data" -type d | head -1)
//! export OMP_NUM_THREADS=2
//! mpirun -np 4 -x CP2K_DATA_DIR -x OMP_NUM_THREADS ./target/release/examples/rust_mpi_example
//! ```
//!
//! # Requirements
//!
//! - MPI implementation (OpenMPI, MPICH, Intel MPI, etc.)
//! - CP2K built with MPI support (PSMP)
//! - Rust `mpi` crate

use mpi::traits::*;
use std::fs;
use std::io::Write;
use std::mem;

// Note: cp2k-rs is built without its `mpi` feature by default. This example
// needs that feature enabled (see the `--features mpi,...` build command above).

/// Writes the CP2K input deck for a single water molecule to
/// `h2o_mpi_rust.inp` in the current working directory.
///
/// Any existing file of that name is truncated and overwritten.
///
/// # Returns
///
/// The name of the file that was written, as an owned `String`.
///
/// # Errors
///
/// Returns the underlying I/O error (boxed) if the file cannot be created or
/// the contents cannot be written in full.
fn create_input_file() -> Result<String, Box<dyn std::error::Error>> {
    // Name of the generated input file, relative to the working directory.
    const INPUT_PATH: &str = "h2o_mpi_rust.inp";

    // Complete CP2K input deck; written verbatim, including the leading newline.
    const INPUT_DECK: &str = r#"
&GLOBAL
  PROJECT H2O_MPI_RUST
  RUN_TYPE ENERGY_FORCE
  PRINT_LEVEL MEDIUM
&END GLOBAL

&FORCE_EVAL
  METHOD Quickstep
  &DFT
    BASIS_SET_FILE_NAME BASIS_MOLOPT
    POTENTIAL_FILE_NAME GTH_POTENTIALS

    &MGRID
      CUTOFF 400
      REL_CUTOFF 50
    &END MGRID

    &QS
      EPS_DEFAULT 1.0E-10
    &END QS

    &SCF
      SCF_GUESS ATOMIC
      EPS_SCF 1.0E-6
      MAX_SCF 50
      ADDED_MOS 10
      &DIAGONALIZATION
        ALGORITHM STANDARD
      &END DIAGONALIZATION
      &SMEAR OFF
      &END SMEAR
      &PRINT
        &RESTART OFF
        &END RESTART
      &END PRINT
    &END SCF

    &XC
      &XC_FUNCTIONAL PBE
      &END XC_FUNCTIONAL
    &END XC
  &END DFT

  &SUBSYS
    &CELL
      ABC 8.0 8.0 8.0
      PERIODIC NONE
    &END CELL

    &COORD
      O   0.000000    0.000000    0.000000
      H   0.000000    0.000000    1.000000
      H   0.942809    0.000000   -0.333333
    &END COORD

    &KIND H
      BASIS_SET DZVP-MOLOPT-GTH
      POTENTIAL GTH-PBE-q1
    &END KIND

    &KIND O
      BASIS_SET DZVP-MOLOPT-GTH
      POTENTIAL GTH-PBE-q6
    &END KIND
  &END SUBSYS
&END FORCE_EVAL
"#;

    // Create (or truncate) the file, then write the whole deck in one call;
    // either step's I/O error is propagated to the caller.
    fs::File::create(INPUT_PATH)
        .and_then(|mut handle| handle.write_all(INPUT_DECK.as_bytes()))?;

    Ok(INPUT_PATH.to_owned())
}

/// Runs a single-point energy/force calculation on a water molecule with CP2K,
/// distributed over all ranks of the MPI world communicator.
///
/// The sequence is: initialize MPI, initialize CP2K, have rank 0 write the
/// input file, create a `ForceEnv` on every rank, run the calculation, print
/// results from rank 0, then tear down in the order force environment →
/// CP2K → MPI.  Barriers separate the phases so that all ranks move through
/// them together.
///
/// Failures in CP2K initialization, input-file creation, force-environment
/// creation, the system-info query, the calculation itself, or CP2K
/// finalization terminate the whole job via `world.abort(1)`.  A failure to
/// retrieve the results on rank 0 is only reported on stderr.
///
/// # Panics
///
/// Panics if MPI initialization fails, and on rank 0 if `as_slice()` on the
/// returned forces/positions cannot produce a slice.
fn main() {
    // Step 1: Initialize MPI
    // CRITICAL: This must be done before CP2K initialization
    let universe = mpi::initialize().expect("Failed to initialize MPI");
    let world = universe.world();
    let rank = world.rank();
    let size = world.size();

    // Print MPI configuration (rank 0 only)
    if rank == 0 {
        println!("{}", "=".repeat(70));
        println!("CP2K-RS MPI Example (Rust)");
        println!("{}", "=".repeat(70));
        println!("MPI Configuration:");
        println!("  Total ranks: {}", size);
        // OMP_NUM_THREADS is only echoed here; this program does not parse or
        // apply it itself.
        if let Ok(omp_threads) = std::env::var("OMP_NUM_THREADS") {
            println!("  OMP_NUM_THREADS: {}", omp_threads);
            println!(
                "  Total parallelism: {} ranks × {} threads",
                size, omp_threads
            );
        } else {
            println!("  OMP_NUM_THREADS: not set (using default)");
        }
        println!("{}", "=".repeat(70));
    }

    // Synchronize all ranks
    world.barrier();

    // Step 2: Initialize CP2K with MPI support
    // All ranks must call this
    if let Err(e) = cp2k_rs::init() {
        eprintln!("✗ Rank {} failed to initialize CP2K: {}", rank, e);
        world.abort(1);
    }

    // The version query is informational: a failure is reported but does not
    // stop the run.
    if rank == 0 {
        match cp2k_rs::get_version() {
            Ok(version) => {
                println!("✓ CP2K initialized with MPI support");
                println!("  Version: {}", version);
            }
            Err(e) => {
                eprintln!("✗ Failed to get CP2K version: {}", e);
            }
        }
    }

    world.barrier();

    // Step 3: Create input file (rank 0 only, all ranks will read)
    // The other ranks use a hard-coded name, which must stay in sync with the
    // name returned by `create_input_file()`.
    let input_file = if rank == 0 {
        match create_input_file() {
            Ok(filename) => {
                println!("✓ Created input file: {}", filename);
                filename
            }
            Err(e) => {
                eprintln!("✗ Failed to create input file: {}", e);
                // `abort` does not return, so this arm needs no value.
                world.abort(1);
            }
        }
    } else {
        "h2o_mpi_rust.inp".to_string()
    };

    // Wait for file to be written
    // NOTE(review): assumes all ranks see the same working directory (shared
    // filesystem or single node) — confirm for multi-node runs.
    world.barrier();

    // Step 4: Create force environment on all ranks using the MPI communicator
    // Each rank participates in the parallel calculation
    // Rank 0 gets the main output file name; every other rank gets its own
    // rank-suffixed output file name.
    let output_file = if rank == 0 {
        "h2o_mpi_rust.out".to_string()
    } else {
        format!("h2o_mpi_rust_rank{}.out", rank)
    };

    let mut force_env = match cp2k_rs::ForceEnv::new_with_mpi(&input_file, &output_file, &world) {
        Ok(env) => {
            if rank == 0 {
                println!("✓ Force environment created on all {} ranks", size);
            }
            env
        }
        Err(e) => {
            eprintln!("✗ Rank {} failed to create force environment: {}", rank, e);
            world.abort(1);
        }
    };

    world.barrier();

    // Step 5: Get system information
    // Both getters are called on every rank; only rank 0 prints the values.
    // If either call fails the specific error is discarded by the `_` arm and
    // the job is aborted.
    match (force_env.get_natom(), force_env.get_nparticle()) {
        (Ok(natom), Ok(nparticle)) => {
            if rank == 0 {
                println!("\nSystem Information:");
                println!("  Number of atoms: {}", natom);
                println!("  Number of particles: {}", nparticle);
            }
        }
        _ => {
            eprintln!("✗ Rank {} failed to get system info", rank);
            world.abort(1);
        }
    }

    world.barrier();

    // Step 6: Run the parallel calculation
    if rank == 0 {
        println!("\n{}", "=".repeat(70));
        println!("Running parallel DFT calculation on {} ranks...", size);
        println!("{}", "=".repeat(70));
    }

    world.barrier();

    // Every rank calls into the calculation; a failure on any rank aborts the
    // whole job.
    if let Err(e) = force_env.calc_energy_force() {
        eprintln!("✗ Rank {} calculation failed: {}", rank, e);
        world.abort(1);
    }

    if rank == 0 {
        println!("✓ Energy and forces calculated");
    }

    world.barrier();

    // Step 7: Retrieve results (typically rank 0 only, but all ranks can access)
    if rank == 0 {
        match (
            force_env.get_potential_energy(),
            force_env.get_forces(),
            force_env.get_positions(),
        ) {
            (Ok(energy), Ok(forces), Ok(positions)) => {
                println!("\nResults:");
                println!("  Total Energy: {:.10} Ha", energy);
                println!("  Forces (Ha/Bohr):");
                // Presumably a flat [x0, y0, z0, x1, ...] layout, one triple
                // per atom — TODO confirm against cp2k_rs.  `unwrap` panics if
                // `as_slice()` yields no slice, and the `[0..=2]` indexing
                // below panics if the length is not a multiple of 3 (the last
                // chunk from `chunks(3)` would then be short).
                let forces_slice = forces.as_slice().unwrap();
                for (i, force) in forces_slice.chunks(3).enumerate() {
                    println!(
                        "    Atom {}: [{:12.6}, {:12.6}, {:12.6}]",
                        i, force[0], force[1], force[2]
                    );
                }

                println!("\n  Positions (Bohr):");
                // Same layout assumption and panic conditions as for the forces.
                let positions_slice = positions.as_slice().unwrap();
                for (i, pos) in positions_slice.chunks(3).enumerate() {
                    println!(
                        "    Atom {}: [{:12.6}, {:12.6}, {:12.6}]",
                        i, pos[0], pos[1], pos[2]
                    );
                }

                // Calculate maximum force magnitude
                // Euclidean norm of each 3-component chunk, reduced with
                // `f64::max` starting from 0.0.
                let max_force = forces_slice
                    .chunks(3)
                    .map(|f| (f[0].powi(2) + f[1].powi(2) + f[2].powi(2)).sqrt())
                    .fold(0.0f64, f64::max);
                println!("\n  Maximum force magnitude: {:.6} Ha/Bohr", max_force);
            }
            _ => {
                // Reported but not fatal: the individual errors are discarded
                // and the run continues to cleanup.
                eprintln!("✗ Rank {} failed to retrieve results", rank);
            }
        }
    }

    world.barrier();

    // Step 8: Extended properties (if available with 'extended' feature)
    // This block is compiled only when the `extended` feature is enabled.
    // Both queries are best-effort: errors are printed as notes and ignored.
    #[cfg(feature = "extended")]
    {
        if rank == 0 {
            println!("\n{}", "=".repeat(70));
            println!("Extended DFT properties:");
            println!("{}", "=".repeat(70));

            // NOTE(review): the meaning of the argument `1` is not visible
            // here (presumably a spin index) — confirm against cp2k_rs.
            match force_env.get_homo_lumo(1) {
                Ok((homo, lumo, homo_idx, lumo_idx)) => {
                    let band_gap = lumo - homo;
                    // 27.211386 is the Hartree -> eV factor used for the
                    // values printed with the "eV" label below.
                    println!(
                        "  HOMO: {:.6} Ha ({:.3} eV) [MO #{}]",
                        homo,
                        homo * 27.211386,
                        homo_idx
                    );
                    println!(
                        "  LUMO: {:.6} Ha ({:.3} eV) [MO #{}]",
                        lumo,
                        lumo * 27.211386,
                        lumo_idx
                    );
                    println!(
                        "  Band gap: {:.6} Ha ({:.3} eV)",
                        band_gap,
                        band_gap * 27.211386
                    );
                }
                Err(e) => {
                    println!("  Note: Could not retrieve HOMO/LUMO: {}", e);
                }
            }

            match force_env.get_scf_info() {
                Ok((nsteps, converged, energy_change)) => {
                    println!("\n  SCF Convergence:");
                    println!("    Converged: {}", if converged { "Yes" } else { "No" });
                    println!("    Steps: {}", nsteps);
                    println!("    Final energy change: {:.2e} Ha", energy_change);
                }
                Err(e) => {
                    println!("  Note: Could not retrieve SCF info: {}", e);
                }
            }
        }
    }

    world.barrier();

    // Step 9: Performance summary
    // Purely informational text; no timing is measured in this program.
    if rank == 0 {
        println!("\n{}", "=".repeat(70));
        println!("Calculation complete!");
        println!("{}", "=".repeat(70));
        println!("\nParallel Efficiency Notes:");
        println!(
            "  - The calculation was distributed across {} MPI ranks",
            size
        );
        println!("  - Each rank used OpenMP threads (OMP_NUM_THREADS)");
        println!("  - Total output files: {} (one per rank)", size);
        println!("  - Main output: {}", output_file);
    }

    world.barrier();

    // Step 10: Cleanup
    // CRITICAL: Drop force environment before finalizing CP2K
    drop(force_env);

    if rank == 0 {
        println!("\n✓ Force environment cleaned up");
    }

    world.barrier();

    // Step 11: Finalize CP2K
    // CRITICAL: Must be called before MPI finalization.
    // cp2k_rs::finalize() calls finalize_cp2k(.FALSE.) which tears down CP2K's
    // internal state but does NOT call MPI_Finalize.
    if let Err(e) = cp2k_rs::finalize() {
        eprintln!("✗ Rank {} failed to finalize CP2K: {}", rank, e);
        world.abort(1);
    }

    if rank == 0 {
        println!("✓ CP2K finalized successfully");
        println!("\nAll ranks completed successfully!");
        println!("{}", "=".repeat(70));
    }

    // Step 12: MPI finalization
    // MPI_Finalize must be called exactly once.  The mpi::Universe RAII guard
    // calls MPI_Finalize on drop, but CP2K's Fortran runtime may also register
    // atexit handlers that touch MPI state after the drop, which can race with
    // the MPI / UCX shutdown and cause a SIGSEGV in libucs.
    // So we call MPI_Finalize explicitly here, after all CP2K cleanup and
    // before any atexit handlers can interfere, and then forget the universe
    // so that its destructor cannot finalize MPI a second time.
    //
    // SAFETY: MPI was initialized by `mpi::initialize()` at the top of `main`,
    // and the only statement after this call is `mem::forget(universe)`, which
    // makes no MPI call and prevents the guard's destructor from running.
    // The return code of MPI_Finalize is ignored.
    // NOTE(review): `world` is still in scope and is dropped at the end of
    // `main`; this assumes dropping it performs no MPI calls — confirm.
    unsafe { mpi::ffi::MPI_Finalize() };
    mem::forget(universe);
}