use std::io::{Read, Write};
use std::os::unix::net::UnixListener;
use mpi::traits::*;
use cp2k_rs::worker_protocol::{Command, Payload, Request, Response};
use cp2k_rs::{finalize, init, ForceEnv};
/// Reads one length-prefixed frame from `stream`.
///
/// The frame starts with a 4-byte little-endian `u32` giving the payload
/// size, followed by exactly that many payload bytes.
///
/// # Errors
///
/// Propagates any I/O error from `read_exact`, including the end-of-file
/// error raised when the stream ends before the header or payload is complete.
fn read_msg<R: Read>(stream: &mut R) -> std::io::Result<Vec<u8>> {
    let mut header = [0u8; 4];
    stream.read_exact(&mut header)?;

    // Size the payload buffer from the decoded header, then fill it completely.
    let mut payload = vec![0u8; u32::from_le_bytes(header) as usize];
    stream.read_exact(&mut payload).map(|()| payload)
}
/// Writes `data` to `stream` as one length-prefixed frame and flushes.
///
/// The frame is a 4-byte little-endian `u32` payload length followed by the
/// payload bytes.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` (without writing anything) when `data`
/// is longer than `u32::MAX` bytes, because such a length cannot be encoded in
/// the header. Any I/O error from writing or flushing is propagated.
fn write_msg<W: Write>(stream: &mut W, data: &[u8]) -> std::io::Result<()> {
    // `TryFrom` is not in the prelude before edition 2021; import it locally.
    use std::convert::TryFrom;

    // A plain `as u32` cast would wrap silently and emit a header that does
    // not match the payload, desynchronising the reader.
    let len = u32::try_from(data.len()).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "message of {} bytes exceeds the u32 frame length limit",
                data.len()
            ),
        )
    })?;

    stream.write_all(&len.to_le_bytes())?;
    stream.write_all(data)?;
    stream.flush()
}
/// MPI tag used for the message that carries the byte length of a serialized
/// command (sent by `broadcast_command`, received by `receive_command`).
const TAG_LEN: mpi::Tag = 1;
/// MPI tag used for the message that carries the serialized command bytes.
const TAG_CMD: mpi::Tag = 2;
/// Sends a serialized command from rank 0 to every other rank.
///
/// Each peer receives two tagged messages: first the payload length as a
/// single `u32` (`TAG_LEN`), then the payload bytes (`TAG_CMD`). When called
/// on any rank other than 0 this does nothing.
fn broadcast_command(world: &mpi::topology::SimpleCommunicator, bytes: &[u8]) {
    if world.rank() != 0 {
        return;
    }

    // NOTE(review): this cast wraps for payloads larger than `u32::MAX` bytes.
    let header = [bytes.len() as u32];
    for peer in 1..world.size() {
        let dest = world.process_at_rank(peer);
        dest.send_with_tag(&header, TAG_LEN);
        dest.send_with_tag(bytes, TAG_CMD);
    }
}
/// Receives one serialized command from rank 0.
///
/// Mirrors `broadcast_command`: a `TAG_LEN` message with the payload length is
/// received first, then the `TAG_CMD` payload. The returned vector holds
/// exactly the announced number of bytes.
///
/// # Panics
///
/// Panics if the length message is empty or if the payload is shorter than
/// the announced length.
fn receive_command(world: &mpi::topology::SimpleCommunicator) -> Vec<u8> {
    let root = world.process_at_rank(0);

    let (len_msg, _) = root.receive_vec_with_tag::<u32>(TAG_LEN);
    let len = *len_msg
        .first()
        .expect("length message from rank 0 was empty") as usize;

    // The received vector is already owned, so trim it in place instead of
    // allocating a second zero-filled buffer and copying into it.
    let (mut data, _) = root.receive_vec_with_tag::<u8>(TAG_CMD);
    assert!(
        data.len() >= len,
        "command payload from rank 0 is {} bytes but {} were announced",
        data.len(),
        len
    );
    data.truncate(len);
    data
}
fn dispatch(req: &Request, force_env: &mut Option<ForceEnv>) -> Response {
let id = req.request_id;
macro_rules! need_env {
($fe:expr) => {
match $fe {
Some(fe) => fe,
None => return Response::error(id, "No force environment initialized"),
}
};
}
match &req.command {
Command::InitForceEnv { input, output } => match ForceEnv::new(input, output) {
Ok(fe) => {
*force_env = Some(fe);
Response::ok(id, Payload::Empty)
}
Err(e) => Response::error(id, format!("{e}")),
},
Command::CalcEnergyForce => {
let fe = need_env!(force_env);
match fe.calc_energy_force() {
Ok(()) => Response::ok(id, Payload::Empty),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::CalcEnergy => {
let fe = need_env!(force_env);
match fe.calc_energy() {
Ok(()) => Response::ok(id, Payload::Empty),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetNatom => {
let fe = need_env!(force_env);
match fe.get_natom() {
Ok(n) => Response::ok(id, Payload::UInt(n as u64)),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetNparticle => {
let fe = need_env!(force_env);
match fe.get_nparticle() {
Ok(n) => Response::ok(id, Payload::UInt(n as u64)),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetPositions => {
let fe = need_env!(force_env);
match fe.get_positions() {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetForces => {
let fe = need_env!(force_env);
match fe.get_forces() {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetPotentialEnergy => {
let fe = need_env!(force_env);
match fe.get_potential_energy() {
Ok(e) => Response::ok(id, Payload::Float(e)),
Err(err) => Response::error(id, format!("{err}")),
}
}
Command::GetCell => {
let fe = need_env!(force_env);
match fe.get_cell() {
Ok(arr) => {
let data: Vec<f64> = arr.into_raw_vec_and_offset().0;
Response::ok(
id,
Payload::Array2 {
rows: 3,
cols: 3,
data,
},
)
}
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetQmmmCell => {
let fe = need_env!(force_env);
match fe.get_qmmm_cell() {
Ok(arr) => {
let data: Vec<f64> = arr.into_raw_vec_and_offset().0;
Response::ok(
id,
Payload::Array2 {
rows: 3,
cols: 3,
data,
},
)
}
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::SetPositions { data } => {
let fe = need_env!(force_env);
match fe.set_positions(data) {
Ok(()) => Response::ok(id, Payload::Empty),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::SetVelocities { data } => {
let fe = need_env!(force_env);
match fe.set_velocities(data) {
Ok(()) => Response::ok(id, Payload::Empty),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::SetCell { data } => {
let fe = need_env!(force_env);
if data.len() != 9 {
return Response::error(id, "SetCell: expected 9 floats");
}
let mut cell = [[0.0f64; 3]; 3];
for i in 0..3 {
for j in 0..3 {
cell[i][j] = data[i * 3 + j];
}
}
match fe.set_cell(&cell) {
Ok(()) => Response::ok(id, Payload::Empty),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetMoCount => {
let fe = need_env!(force_env);
match fe.get_mo_count() {
Ok(n) => Response::ok(id, Payload::Int(n as i64)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::IsQuickstep => {
let fe = need_env!(force_env);
Response::ok(id, Payload::Bool(fe.is_quickstep()))
}
#[cfg(feature = "extended")]
Command::GetStressTensor => {
let fe = need_env!(force_env);
match fe.get_stress_tensor() {
Ok(arr) => {
let data: Vec<f64> = arr.into_raw_vec_and_offset().0;
Response::ok(
id,
Payload::Array2 {
rows: 3,
cols: 3,
data,
},
)
}
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetVirialTensor => {
let fe = need_env!(force_env);
match fe.get_virial_tensor() {
Ok(arr) => {
let data: Vec<f64> = arr.into_raw_vec_and_offset().0;
Response::ok(
id,
Payload::Array2 {
rows: 3,
cols: 3,
data,
},
)
}
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetNmo { spin } => {
let fe = need_env!(force_env);
match fe.get_nmo(*spin) {
Ok(n) => Response::ok(id, Payload::UInt(n as u64)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetEigenvalues { spin } => {
let fe = need_env!(force_env);
match fe.get_eigenvalues(*spin) {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetOccupationNumbers { spin } => {
let fe = need_env!(force_env);
match fe.get_occupation_numbers(*spin) {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetHomoLumo { spin } => {
let fe = need_env!(force_env);
match fe.get_homo_lumo(*spin) {
Ok((homo, lumo, homo_idx, lumo_idx)) => Response::ok(
id,
Payload::HomoLumo {
homo,
lumo,
homo_idx,
lumo_idx,
},
),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetMullikenCharges => {
let fe = need_env!(force_env);
match fe.get_mulliken_charges() {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetHirshfeldCharges => {
let fe = need_env!(force_env);
match fe.get_hirshfeld_charges() {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetDipoleMoment => {
let fe = need_env!(force_env);
match fe.get_dipole_moment() {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetScfInfo => {
let fe = need_env!(force_env);
match fe.get_scf_info() {
Ok((nsteps, converged, energy_change)) => Response::ok(
id,
Payload::ScfInfo {
nsteps,
converged,
energy_change,
},
),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetEnergyComponents => {
let fe = need_env!(force_env);
match fe.get_energy_components() {
Ok((e_kin, e_hartree, e_xc, e_core, e_total)) => Response::ok(
id,
Payload::EnergyComponents {
e_kin,
e_hartree,
e_xc,
e_core,
e_total,
},
),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetNelectron => {
let fe = need_env!(force_env);
match fe.get_nelectron() {
Ok(n) => Response::ok(id, Payload::Int(n as i64)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetFermiEnergy => {
let fe = need_env!(force_env);
match fe.get_fermi_energy() {
Ok(e) => Response::ok(id, Payload::Float(e)),
Err(err) => Response::error(id, format!("{err}")),
}
}
#[cfg(feature = "extended")]
Command::GetTotalSpin => {
let fe = need_env!(force_env);
match fe.get_total_spin() {
Ok(s) => Response::ok(id, Payload::Float(s)),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetVersion => match cp2k_rs::get_version() {
Ok(v) => Response::ok(id, Payload::String(v)),
Err(e) => Response::error(id, format!("{e}")),
},
Command::Shutdown => Response::ok(id, Payload::Empty),
}
}
/// Worker entry point.
///
/// Initializes MPI and CP2K, then runs the socket server on rank 0 and the
/// command-receiving loop on every other rank. Once the rank-specific loop
/// returns, CP2K is finalized. A CP2K initialization failure is reported by
/// rank 0 only and makes every rank return early.
fn main() {
    let universe = mpi::initialize().expect("MPI initialization failed");
    let world = universe.world();
    let rank = world.rank();

    match init() {
        Ok(_) => {}
        Err(e) => {
            if rank == 0 {
                eprintln!("[worker] CP2K init failed: {e}");
            }
            return;
        }
    }

    match rank {
        0 => rank0_server(&world),
        _ => other_rank_loop(&world),
    }

    if let Some(e) = finalize().err() {
        eprintln!("[worker rank {rank}] CP2K finalize failed: {e}");
    }
}
/// Runs the rank-0 side of the worker: accepts one client on a Unix socket
/// and serves its requests until shutdown.
///
/// The socket path comes from `CP2K_WORKER_SOCKET_FILE`, falling back to
/// `/tmp/cp2k_worker_<pid>.sock`. After the socket is bound, a
/// `<socket>.ready` file containing the socket path is written. Every request
/// is forwarded to the other ranks with `broadcast_command` before it is
/// dispatched locally, and the response is written back to the client.
///
/// Whatever ends the session (client `Shutdown`, I/O error, malformed
/// request, or a failure while setting up the socket), the other ranks are
/// sent exactly one `Shutdown` command so they can leave their receive loop,
/// and the socket and ready files are removed. Setup failures are logged to
/// stderr instead of panicking, so the shutdown broadcast is still sent.
fn rank0_server(world: &mpi::topology::SimpleCommunicator) {
    let socket_file = std::env::var("CP2K_WORKER_SOCKET_FILE")
        .unwrap_or_else(|_| format!("/tmp/cp2k_worker_{}.sock", std::process::id()));
    let ready_file = format!("{socket_file}.ready");

    // Kept in a local so the listener and stream stay open until this function returns.
    let mut channel = open_control_channel(&socket_file, &ready_file);

    // NOTE(review): declared here (not inside the serving helper) so it is
    // dropped only after the shutdown broadcast below — presumably dropping a
    // `ForceEnv` needs the other ranks to take part; confirm.
    let mut force_env: Option<ForceEnv> = None;

    let shutdown = match &mut channel {
        Ok((_listener, stream)) => serve_requests(world, stream, &mut force_env),
        Err(e) => {
            eprintln!("[worker] {e}");
            false
        }
    };

    if !shutdown {
        let shutdown_req = Request {
            request_id: u64::MAX,
            command: Command::Shutdown,
        };
        let shutdown_bytes = bincode::serialize(&shutdown_req).expect("serialize shutdown");
        broadcast_command(world, &shutdown_bytes);
    }

    let _ = std::fs::remove_file(&socket_file);
    let _ = std::fs::remove_file(&ready_file);
}

/// Binds the control socket, restricts it to mode `0o600`, writes the ready
/// file and blocks until a client connects.
///
/// Returns the listener together with the accepted stream. Errors carry the
/// name of the step that failed.
fn open_control_channel(
    socket_file: &str,
    ready_file: &str,
) -> std::io::Result<(UnixListener, std::os::unix::net::UnixStream)> {
    use std::os::unix::fs::PermissionsExt;

    // Prefixes an I/O error with the step that produced it, keeping its kind.
    fn with_context(what: &'static str) -> impl Fn(std::io::Error) -> std::io::Error {
        move |e| std::io::Error::new(e.kind(), format!("{what}: {e}"))
    }

    // Best effort: remove any existing file at the socket path before binding.
    let _ = std::fs::remove_file(socket_file);

    let listener =
        UnixListener::bind(socket_file).map_err(with_context("Failed to bind Unix socket"))?;
    std::fs::set_permissions(socket_file, std::fs::Permissions::from_mode(0o600))
        .map_err(with_context("Failed to set socket permissions"))?;
    std::fs::write(ready_file, socket_file).map_err(with_context("Failed to write ready file"))?;
    eprintln!("[worker] Listening on {socket_file}");

    let (stream, _) = listener
        .accept()
        .map_err(with_context("Failed to accept connection"))?;
    Ok((listener, stream))
}

/// Serves requests from `stream` until the session ends.
///
/// Returns `true` when a `Shutdown` command has been broadcast to the other
/// ranks, and `false` when the session ended without one, so the caller knows
/// whether it still has to send it.
fn serve_requests(
    world: &mpi::topology::SimpleCommunicator,
    stream: &mut std::os::unix::net::UnixStream,
    force_env: &mut Option<ForceEnv>,
) -> bool {
    loop {
        let raw = match read_msg(stream) {
            Ok(b) => b,
            Err(e) => {
                eprintln!("[worker] Read error: {e}");
                return false;
            }
        };
        let req: Request = match bincode::deserialize(&raw) {
            Ok(r) => r,
            Err(e) => {
                eprintln!("[worker] Deserialize error: {e}");
                return false;
            }
        };

        let is_shutdown = matches!(req.command, Command::Shutdown);
        broadcast_command(world, &raw);
        let resp = dispatch(&req, force_env);

        // From here on the other ranks have already seen this command, so a
        // failure while answering a `Shutdown` must still be reported as
        // "shutdown broadcast" to avoid sending it a second time.
        let resp_bytes = match bincode::serialize(&resp) {
            Ok(b) => b,
            Err(e) => {
                eprintln!("[worker] Serialize response failed: {e}");
                return is_shutdown;
            }
        };
        if let Err(e) = write_msg(stream, &resp_bytes) {
            eprintln!("[worker] Write error: {e}");
            return is_shutdown;
        }
        if is_shutdown {
            return true;
        }
    }
}
/// Command loop for every rank other than 0.
///
/// Repeatedly receives a serialized request from rank 0, dispatches it against
/// this rank's own force environment and discards the response. The loop ends
/// after a `Shutdown` command has been dispatched, or when a received request
/// cannot be deserialized (which is logged to stderr).
fn other_rank_loop(world: &mpi::topology::SimpleCommunicator) {
    let mut force_env: Option<ForceEnv> = None;

    loop {
        let frame = receive_command(world);
        let request: Request = match bincode::deserialize(&frame) {
            Err(e) => {
                eprintln!("[worker rank {}] Deserialize error: {e}", world.rank());
                return;
            }
            Ok(parsed) => parsed,
        };

        // Only rank 0 answers the client, so the response is dropped here.
        let _ = dispatch(&request, &mut force_env);

        if matches!(request.command, Command::Shutdown) {
            return;
        }
    }
}