use std::io::{Read, Write};
use std::os::unix::net::UnixListener;
use mpi::traits::*;
use cp2k_rs::worker_protocol::{Command, Payload, Request, Response};
use cp2k_rs::{finalize, init, ForceEnv};
/// Reads one length-prefixed frame from `stream`.
///
/// The frame is a 4-byte little-endian `u32` byte count followed by exactly
/// that many payload bytes. The payload is returned as an owned buffer.
///
/// # Errors
/// Propagates any I/O error from the underlying reader, including
/// `UnexpectedEof` when the stream ends in the middle of a header or payload.
fn read_msg<R: Read>(stream: &mut R) -> std::io::Result<Vec<u8>> {
    let frame_len = {
        let mut header = [0u8; 4];
        stream.read_exact(&mut header)?;
        u32::from_le_bytes(header) as usize
    };
    let mut payload = vec![0u8; frame_len];
    stream.read_exact(&mut payload)?;
    Ok(payload)
}
/// Writes `data` to `stream` as one length-prefixed frame and flushes.
///
/// The frame is a 4-byte little-endian `u32` byte count followed by the
/// payload, which is the format `read_msg` expects.
///
/// # Errors
/// Returns `ErrorKind::InvalidInput` if `data` is longer than `u32::MAX` bytes.
/// Such a length cannot be encoded in the prefix, and truncating it would
/// desynchronize the stream. Otherwise any I/O error from the writer is
/// propagated.
fn write_msg<W: Write>(stream: &mut W, data: &[u8]) -> std::io::Result<()> {
    let len = u32::try_from(data.len()).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "message of {} bytes does not fit in the u32 length prefix",
                data.len()
            ),
        )
    })?;
    stream.write_all(&len.to_le_bytes())?;
    stream.write_all(data)?;
    stream.flush()
}
// Point-to-point MPI tags used between rank 0 and the other ranks.
/// Tag of the one-element `u32` length header sent before each forwarded request.
const TAG_LEN: mpi::Tag = 1;
/// Tag of the serialized request bytes forwarded by `broadcast_command`.
const TAG_CMD: mpi::Tag = 2;
/// Tag of the one-element `u32` length of the patched input path; rank 0 sends 0
/// when patching the input failed.
const TAG_PATCH_LEN: mpi::Tag = 3;
/// Tag of the bytes of the patched input path.
const TAG_PATCH_PATH: mpi::Tag = 4;
/// Forwards a serialized request from rank 0 to every other rank.
///
/// For each peer, a one-element `u32` length is sent on `TAG_LEN`, followed by
/// the raw bytes on `TAG_CMD`; `receive_command` is the matching receiver.
/// Calling this on any rank other than 0 does nothing.
fn broadcast_command(world: &mpi::topology::SimpleCommunicator, bytes: &[u8]) {
    if world.rank() != 0 {
        return;
    }
    let header = [bytes.len() as u32];
    for peer_rank in 1..world.size() {
        let peer = world.process_at_rank(peer_rank);
        peer.send_with_tag(&header, TAG_LEN);
        peer.send_with_tag(bytes, TAG_CMD);
    }
}
/// Receives one serialized request forwarded by `broadcast_command` on rank 0.
///
/// The length header (tag `TAG_LEN`) trims the payload (tag `TAG_CMD`) to the
/// announced size. A missing header or a payload shorter than announced no
/// longer panics. Whatever arrived is returned, so the caller's
/// deserialization fails and the rank leaves its loop instead of aborting.
fn receive_command(world: &mpi::topology::SimpleCommunicator) -> Vec<u8> {
    let root = world.process_at_rank(0);
    let (len_buf, _) = root.receive_vec_with_tag::<u32>(TAG_LEN);
    let (mut data, _) = root.receive_vec_with_tag::<u8>(TAG_CMD);
    if let Some(&len) = len_buf.first() {
        // `truncate` is a no-op when the payload is already shorter.
        data.truncate(len as usize);
    }
    data
}
/// Bohr radius in ångström (the CODATA 2014 value). `SetPositions` and `SetCell`
/// divide incoming values by this to convert ångström to bohr.
const BOHR: f64 = 0.52917721067;
/// Writes a copy of the CP2K input at `input_path` whose `&SUBSYS` geometry is
/// replaced by the given one, and returns the path of that temporary copy.
///
/// * `symbols` — one element symbol per atom.
/// * `positions_angstrom` — flat `[x0, y0, z0, x1, ...]` of length
///   `3 * symbols.len()`, written verbatim (no unit conversion) into `&COORD`.
/// * `cell_angstrom` — 9 values; elements 0..3, 3..6 and 6..9 become the `A`,
///   `B` and `C` vectors of `&CELL`.
/// * `periodic` — copied verbatim into the `PERIODIC` keyword.
///
/// Section nesting is tracked explicitly, so sections closed by a bare `&END`
/// are handled; for example, `&KIND H ... &END` inside `&SUBSYS` no longer
/// ends SUBSYS early. Every `&CELL` / `&COORD` that is a direct child of a
/// `&SUBSYS` is dropped together with its subsections.
///
/// The new sections are inserted right after the first `&SUBSYS` line.
/// NOTE(review): with several SUBSYS sections, all of them are stripped but only
/// the first receives the new geometry. Without any `&SUBSYS`, a new one is
/// inserted just before the first `&FORCE_EVAL` closes (by `&END FORCE_EVAL`
/// or a bare `&END`), or appended at the end of the file if no FORCE_EVAL closes.
///
/// # Errors
/// Returns a message if the array lengths are inconsistent, the input cannot
/// be read, or the temporary file cannot be created.
fn patch_input_with_geometry(
    input_path: &str,
    symbols: &[String],
    positions_angstrom: &[f64],
    cell_angstrom: &[f64],
    periodic: &str,
) -> Result<std::path::PathBuf, String> {
    if positions_angstrom.len() != symbols.len() * 3 {
        return Err(format!(
            "positions_angstrom length {} != symbols.len() * 3 = {}",
            positions_angstrom.len(),
            symbols.len() * 3
        ));
    }
    if cell_angstrom.len() != 9 {
        return Err(format!("cell_angstrom length {} != 9", cell_angstrom.len()));
    }
    let content = std::fs::read_to_string(input_path)
        .map_err(|e| format!("Failed to read input file '{}': {e}", input_path))?;

    let injection =
        render_geometry_sections(symbols, positions_angstrom, cell_angstrom, periodic);
    let output = splice_geometry(&content, &injection);

    // The counter keeps names unique even if two calls land in the same
    // nanosecond. `create_new` refuses to reuse an existing path (e.g. a file
    // or symlink planted in the world-writable /tmp) instead of overwriting it.
    static PATCH_COUNTER: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
    let seq = PATCH_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.subsec_nanos())
        .unwrap_or(0);
    let temp_name = format!(
        "cp2k_worker_patched_{}_{}_{}.inp",
        std::process::id(),
        nanos,
        seq
    );
    let temp_path = std::path::Path::new("/tmp").join(temp_name);
    std::fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(&temp_path)
        .and_then(|mut file| file.write_all(output.as_bytes()))
        .map_err(|e| {
            format!(
                "Failed to write patched input file '{}': {e}",
                temp_path.display()
            )
        })?;
    Ok(temp_path)
}

/// Renders the replacement `&CELL` section followed by the `&COORD` section.
/// `&COORD` has one line per atom. Lengths are validated by the caller.
fn render_geometry_sections(
    symbols: &[String],
    positions_angstrom: &[f64],
    cell_angstrom: &[f64],
    periodic: &str,
) -> String {
    let c = cell_angstrom;
    let mut text = format!(
        "&CELL\n A {:.10} {:.10} {:.10}\n B {:.10} {:.10} {:.10}\n C {:.10} {:.10} {:.10}\n PERIODIC {}\n&END CELL\n",
        c[0], c[1], c[2],
        c[3], c[4], c[5],
        c[6], c[7], c[8],
        periodic,
    );
    text.push_str("&COORD\n");
    for (sym, xyz) in symbols.iter().zip(positions_angstrom.chunks_exact(3)) {
        text.push_str(&format!(
            " {} {:.10} {:.10} {:.10}\n",
            sym, xyz[0], xyz[1], xyz[2]
        ));
    }
    text.push_str("&END COORD\n");
    text
}

/// How one trimmed, upper-cased input line affects section nesting.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SectionLine<'a> {
    /// `&NAME ...` opens section `NAME`.
    Open(&'a str),
    /// `&END` or `&END NAME` closes the innermost open section.
    Close,
    /// Keywords, comments, blank lines, directives.
    Other,
}

/// Classifies a trimmed, upper-cased line as section open/close/other.
fn classify_section_line(upper: &str) -> SectionLine<'_> {
    match upper.strip_prefix('&') {
        None => SectionLine::Other,
        Some(rest) if rest.starts_with("END") => SectionLine::Close,
        Some(rest) => SectionLine::Open(rest.split_whitespace().next().unwrap_or("")),
    }
}

/// Drops `&CELL`/`&COORD` sections that are direct children of `&SUBSYS`, then
/// inserts `injection` at the position described on `patch_input_with_geometry`.
fn splice_geometry(content: &str, injection: &str) -> String {
    let mut open_sections: Vec<String> = Vec::new();
    let mut kept: Vec<&str> = Vec::new();
    // While dropping a section: the nesting depth at which dropping ends.
    let mut skip_until_depth: Option<usize> = None;
    // Index in `kept` of the first `&SUBSYS` line.
    let mut first_subsys: Option<usize> = None;
    // Index in `kept` of the line that closes the first `&FORCE_EVAL`.
    let mut first_force_eval_close: Option<usize> = None;

    for line in content.lines() {
        let upper = line.trim().to_uppercase();
        let event = classify_section_line(&upper);

        if let Some(depth) = skip_until_depth {
            match event {
                SectionLine::Open(name) => open_sections.push(name.to_string()),
                SectionLine::Close => {
                    open_sections.pop();
                    if open_sections.len() <= depth {
                        skip_until_depth = None;
                    }
                }
                SectionLine::Other => {}
            }
            continue;
        }

        match event {
            SectionLine::Open(name) => {
                let parent_is_subsys =
                    open_sections.last().map(String::as_str) == Some("SUBSYS");
                if parent_is_subsys && (name == "CELL" || name == "COORD") {
                    skip_until_depth = Some(open_sections.len());
                    open_sections.push(name.to_string());
                    continue;
                }
                if name == "SUBSYS" && first_subsys.is_none() {
                    first_subsys = Some(kept.len());
                }
                open_sections.push(name.to_string());
            }
            SectionLine::Close => {
                let closed = open_sections.pop();
                if closed.as_deref() == Some("FORCE_EVAL") && first_force_eval_close.is_none() {
                    first_force_eval_close = Some(kept.len());
                }
            }
            SectionLine::Other => {}
        }
        kept.push(line);
    }

    // `insert_at` is the index in `kept` before which the block is written.
    let (insert_at, block) = match (first_subsys, first_force_eval_close) {
        (Some(subsys), _) => (subsys + 1, injection.to_string()),
        (None, Some(close)) => (close, format!("&SUBSYS\n{}&END SUBSYS\n", injection)),
        (None, None) => (kept.len(), format!("&SUBSYS\n{}&END SUBSYS\n", injection)),
    };
    let mut output = String::with_capacity(content.len() + block.len());
    for (i, line) in kept.iter().enumerate() {
        if i == insert_at {
            output.push_str(&block);
        }
        output.push_str(line);
        output.push('\n');
    }
    if insert_at >= kept.len() {
        output.push_str(&block);
    }
    output
}
/// Executes one client request on this rank and returns the response.
///
/// Rank 0 calls this for every request it reads from the client, and the other
/// ranks call it for every request forwarded to them, so all ranks run the same
/// command sequence. Commands that need a force environment return an error
/// response while `force_env` is `None`.
///
/// `InitForceEnvWithGeometry` is the only command with extra MPI traffic:
/// 1. Rank 0 patches the input file and sends the resulting path to the other
///    ranks (`send_patched_path` / `receive_patched_path`).
/// 2. If patching fails, rank 0 returns an error, and the other ranks skip
///    force-environment creation as well.
///
/// `SetPositions` and `SetCell` divide incoming values by `BOHR`, while the
/// getters return CP2K's values unconverted. NOTE(review): presumably the
/// client converts getter results — confirm.
fn dispatch(
    req: &Request,
    force_env: &mut Option<ForceEnv>,
    world: &mpi::topology::SimpleCommunicator,
) -> Response {
    let id = req.request_id;
    // Yields the initialized force environment or returns an error response.
    macro_rules! need_env {
        ($fe:expr) => {
            match $fe {
                Some(fe) => fe,
                None => {
                    return Response::error(
                        id,
                        "No force environment is initialized. \
                         Call calc.calculate(atoms) at least once before \
                         querying properties such as HOMO/LUMO or SCF info.",
                    )
                }
            }
        };
    }
    match &req.command {
        Command::InitForceEnv { input, output } => match ForceEnv::new(input, output) {
            Ok(fe) => {
                *force_env = Some(fe);
                Response::ok(id, Payload::Empty)
            }
            Err(e) => Response::error(id, format!("{e}")),
        },
        Command::InitForceEnvWithGeometry {
            input,
            output,
            symbols,
            positions_angstrom,
            cell_angstrom,
            periodic,
        } => {
            let rank = world.rank();
            // Only rank 0 touches the filesystem; the other ranks learn the
            // patched path (or that patching failed) over MPI.
            let patched_path_str: String = if rank == 0 {
                match patch_input_with_geometry(
                    input,
                    symbols,
                    positions_angstrom,
                    cell_angstrom,
                    periodic,
                ) {
                    Ok(p) => {
                        let path_str = p.to_string_lossy().into_owned();
                        send_patched_path(world, Some(path_str.as_str()));
                        path_str
                    }
                    Err(e) => {
                        send_patched_path(world, None);
                        return Response::error(id, format!("Failed to patch input file: {e}"));
                    }
                }
            } else {
                match receive_patched_path(world) {
                    Some(path) => path,
                    // Rank 0 failed to patch and will not create a ForceEnv.
                    None => return Response::ok(id, Payload::Empty),
                }
            };
            if rank == 0 && !patched_path_str.is_empty() {
                // Keep a copy of the patched input next to the output file.
                let output_path = std::path::Path::new(output.as_str());
                let input_copy_path = match (output_path.parent(), output_path.file_stem()) {
                    (Some(dir), Some(stem)) => dir.join(format!("{}.inp", stem.to_string_lossy())),
                    _ => output_path.with_extension("inp"),
                };
                if let Err(e) = std::fs::copy(&patched_path_str, &input_copy_path) {
                    eprintln!(
                        "[worker] Warning: could not write input copy to '{}': {e}",
                        input_copy_path.display()
                    );
                }
            }
            let result = ForceEnv::new(&patched_path_str, output);
            if rank == 0 {
                let _ = std::fs::remove_file(&patched_path_str);
            }
            match result {
                Ok(fe) => {
                    *force_env = Some(fe);
                    Response::ok(id, Payload::Empty)
                }
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        Command::CalcEnergyForce => {
            let fe = need_env!(force_env);
            match fe.calc_energy_force() {
                Ok(()) => Response::ok(id, Payload::Empty),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        Command::CalcEnergy => {
            let fe = need_env!(force_env);
            match fe.calc_energy() {
                Ok(()) => Response::ok(id, Payload::Empty),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        Command::GetNatom => {
            let fe = need_env!(force_env);
            match fe.get_natom() {
                Ok(n) => Response::ok(id, Payload::UInt(n as u64)),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        Command::GetNparticle => {
            let fe = need_env!(force_env);
            match fe.get_nparticle() {
                Ok(n) => Response::ok(id, Payload::UInt(n as u64)),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        Command::GetPositions => {
            let fe = need_env!(force_env);
            match fe.get_positions() {
                Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        Command::GetForces => {
            let fe = need_env!(force_env);
            match fe.get_forces() {
                Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        Command::GetPotentialEnergy => {
            let fe = need_env!(force_env);
            match fe.get_potential_energy() {
                Ok(e) => Response::ok(id, Payload::Float(e)),
                Err(err) => Response::error(id, format!("{err}")),
            }
        }
        Command::GetCell => {
            let fe = need_env!(force_env);
            match fe.get_cell() {
                Ok(arr) => {
                    let data: Vec<f64> = arr.into_raw_vec_and_offset().0;
                    Response::ok(
                        id,
                        Payload::Array2 {
                            rows: 3,
                            cols: 3,
                            data,
                        },
                    )
                }
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        Command::GetQmmmCell => {
            let fe = need_env!(force_env);
            match fe.get_qmmm_cell() {
                Ok(arr) => {
                    let data: Vec<f64> = arr.into_raw_vec_and_offset().0;
                    Response::ok(
                        id,
                        Payload::Array2 {
                            rows: 3,
                            cols: 3,
                            data,
                        },
                    )
                }
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        Command::SetPositions { data } => {
            let fe = need_env!(force_env);
            // Incoming positions are converted ångström -> bohr.
            let data_bohr: Vec<f64> = data.iter().map(|x| x / BOHR).collect();
            match fe.set_positions(&data_bohr) {
                Ok(()) => Response::ok(id, Payload::Empty),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        Command::SetVelocities { data } => {
            let fe = need_env!(force_env);
            // Passed through without unit conversion.
            // NOTE(review): confirm the client already sends CP2K units.
            match fe.set_velocities(data) {
                Ok(()) => Response::ok(id, Payload::Empty),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        Command::SetCell { data } => {
            let fe = need_env!(force_env);
            if data.len() != 9 {
                return Response::error(id, "SetCell: expected 9 floats");
            }
            // Row-major 3x3, converted ångström -> bohr.
            let mut cell = [[0.0f64; 3]; 3];
            for i in 0..3 {
                for j in 0..3 {
                    cell[i][j] = data[i * 3 + j] / BOHR;
                }
            }
            match fe.set_cell(&cell) {
                Ok(()) => Response::ok(id, Payload::Empty),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        Command::GetMoCount => {
            let fe = need_env!(force_env);
            match fe.get_mo_count() {
                Ok(n) => Response::ok(id, Payload::Int(n as i64)),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::IsQuickstep => {
            let fe = need_env!(force_env);
            Response::ok(id, Payload::Bool(fe.is_quickstep()))
        }
        #[cfg(feature = "extended")]
        Command::GetStressTensor => {
            let fe = need_env!(force_env);
            match fe.get_stress_tensor() {
                Ok(arr) => {
                    let data: Vec<f64> = arr.into_raw_vec_and_offset().0;
                    Response::ok(
                        id,
                        Payload::Array2 {
                            rows: 3,
                            cols: 3,
                            data,
                        },
                    )
                }
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::GetVirialTensor => {
            let fe = need_env!(force_env);
            match fe.get_virial_tensor() {
                Ok(arr) => {
                    let data: Vec<f64> = arr.into_raw_vec_and_offset().0;
                    Response::ok(
                        id,
                        Payload::Array2 {
                            rows: 3,
                            cols: 3,
                            data,
                        },
                    )
                }
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::GetNmo { spin } => {
            let fe = need_env!(force_env);
            match fe.get_nmo(*spin) {
                Ok(n) => Response::ok(id, Payload::UInt(n as u64)),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::GetEigenvalues { spin } => {
            let fe = need_env!(force_env);
            match fe.get_eigenvalues(*spin) {
                Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::GetOccupationNumbers { spin } => {
            let fe = need_env!(force_env);
            match fe.get_occupation_numbers(*spin) {
                Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::GetHomoLumo { spin } => {
            let fe = need_env!(force_env);
            match fe.get_homo_lumo(*spin) {
                Ok((homo, lumo, homo_idx, lumo_idx)) => Response::ok(
                    id,
                    Payload::HomoLumo {
                        homo,
                        lumo,
                        homo_idx,
                        lumo_idx,
                    },
                ),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::GetMullikenCharges => {
            let fe = need_env!(force_env);
            match fe.get_mulliken_charges() {
                Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::GetHirshfeldCharges => {
            let fe = need_env!(force_env);
            match fe.get_hirshfeld_charges() {
                Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::GetDipoleMoment => {
            let fe = need_env!(force_env);
            match fe.get_dipole_moment() {
                Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::GetScfInfo => {
            let fe = need_env!(force_env);
            match fe.get_scf_info() {
                Ok((nsteps, converged, energy_change)) => Response::ok(
                    id,
                    Payload::ScfInfo {
                        nsteps,
                        converged,
                        energy_change,
                    },
                ),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::GetEnergyComponents => {
            let fe = need_env!(force_env);
            match fe.get_energy_components() {
                Ok((e_kin, e_hartree, e_xc, e_core, e_total)) => Response::ok(
                    id,
                    Payload::EnergyComponents {
                        e_kin,
                        e_hartree,
                        e_xc,
                        e_core,
                        e_total,
                    },
                ),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::GetNelectron => {
            let fe = need_env!(force_env);
            match fe.get_nelectron() {
                Ok(n) => Response::ok(id, Payload::Int(n as i64)),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::GetFermiEnergy => {
            let fe = need_env!(force_env);
            match fe.get_fermi_energy() {
                Ok(e) => Response::ok(id, Payload::Float(e)),
                Err(err) => Response::error(id, format!("{err}")),
            }
        }
        #[cfg(feature = "extended")]
        Command::GetTotalSpin => {
            let fe = need_env!(force_env);
            match fe.get_total_spin() {
                Ok(s) => Response::ok(id, Payload::Float(s)),
                Err(e) => Response::error(id, format!("{e}")),
            }
        }
        Command::GetVersion => match cp2k_rs::get_version() {
            Ok(v) => Response::ok(id, Payload::String(v)),
            Err(e) => Response::error(id, format!("{e}")),
        },
        Command::Shutdown => Response::ok(id, Payload::Empty),
    }
}

/// Sends the patched-input path from rank 0 to every other rank.
///
/// Both messages are always sent: the `u32` length on `TAG_PATCH_LEN`, then the
/// bytes on `TAG_PATCH_PATH`. `None` (patching failed) is encoded as length 0
/// with an empty payload.
fn send_patched_path(world: &mpi::topology::SimpleCommunicator, path: Option<&str>) {
    let bytes: &[u8] = path.map(str::as_bytes).unwrap_or_default();
    let len = bytes.len() as u32;
    for r in 1..world.size() {
        let peer = world.process_at_rank(r);
        peer.send_with_tag(&[len], TAG_PATCH_LEN);
        peer.send_with_tag(bytes, TAG_PATCH_PATH);
    }
}

/// Receives the patched-input path sent by `send_patched_path`.
/// Returns `None` when rank 0 reported a patching failure (length 0).
fn receive_patched_path(world: &mpi::topology::SimpleCommunicator) -> Option<String> {
    let root = world.process_at_rank(0);
    let (len_buf, _) = root.receive_vec_with_tag::<u32>(TAG_PATCH_LEN);
    // Always consume the path message, even for length 0. Rank 0 sends it
    // unconditionally; leaving it queued would make the next
    // InitForceEnvWithGeometry read this stale empty payload.
    let (mut path_bytes, _) = root.receive_vec_with_tag::<u8>(TAG_PATCH_PATH);
    let len = len_buf.first().copied().unwrap_or(0) as usize;
    if len == 0 {
        return None;
    }
    path_bytes.truncate(len);
    Some(String::from_utf8_lossy(&path_bytes).into_owned())
}
fn main() {
let universe = mpi::initialize().expect("MPI initialization failed");
let world = universe.world();
let rank = world.rank();
if let Err(e) = init() {
if rank == 0 {
eprintln!(
"[worker] CP2K init failed: {e}\n\
Hint: ensure CP2K_DATA_DIR is set and points to the CP2K data/ \
directory, e.g.:\n\
\texport CP2K_DATA_DIR=/path/to/cp2k/data\n\
When using mpirun, pass it to all ranks:\n\
\tmpirun -n 4 -x CP2K_DATA_DIR ./cp2k_rs_worker"
);
}
return;
}
if rank == 0 {
rank0_server(&world);
} else {
other_rank_loop(&world);
}
if let Err(e) = finalize() {
eprintln!("[worker rank {rank}] CP2K finalize failed: {e}");
}
}
/// Rank-0 front end: accepts one client on a Unix socket and relays its
/// requests to all MPI ranks.
///
/// Socket setup:
/// * The path comes from `CP2K_WORKER_SOCKET_FILE`, defaulting to
///   `/tmp/cp2k_worker_<pid>.sock`.
/// * The socket is chmod-ed to 0600.
/// * Once listening, a `<socket>.ready` file containing the socket path is
///   written.
///
/// Each request is read with `read_msg`, forwarded with `broadcast_command`,
/// executed locally with `dispatch`, and answered with `write_msg`.
///
/// However the session ends, the other ranks receive exactly one `Shutdown`.
/// Possible endings are a `Shutdown` request, client disconnect, an I/O or
/// decode error, or a failure while setting up the socket. This keeps the
/// peers from blocking forever in `receive_command`. Setup failures are
/// reported on stderr instead of panicking. Both files are removed on exit.
fn rank0_server(world: &mpi::topology::SimpleCommunicator) {
    let socket_file = std::env::var("CP2K_WORKER_SOCKET_FILE")
        .unwrap_or_else(|_| format!("/tmp/cp2k_worker_{}.sock", std::process::id()));
    let ready_file = format!("{socket_file}.ready");

    let mut client = accept_client(&socket_file, &ready_file);
    // Declared after `client` so it is dropped first, and only after the peers
    // have been told to shut down (the same order as on the other ranks).
    // NOTE(review): this assumes ForceEnv teardown may be collective.
    let mut force_env: Option<ForceEnv> = None;

    let peers_stopped = match client.as_mut() {
        Ok(stream) => serve_client(world, stream, &mut force_env),
        Err(msg) => {
            eprintln!("[worker] {msg}");
            false
        }
    };
    if !peers_stopped {
        let shutdown_req = Request {
            request_id: u64::MAX,
            command: Command::Shutdown,
        };
        let shutdown_bytes = bincode::serialize(&shutdown_req).expect("serialize shutdown");
        broadcast_command(world, &shutdown_bytes);
    }
    let _ = std::fs::remove_file(&socket_file);
    let _ = std::fs::remove_file(&ready_file);
}

/// Binds the socket, restricts it to the owner, writes the ready file and
/// waits for the single client connection.
fn accept_client(
    socket_file: &str,
    ready_file: &str,
) -> Result<std::os::unix::net::UnixStream, String> {
    use std::os::unix::fs::PermissionsExt;
    let _ = std::fs::remove_file(socket_file);
    let listener = UnixListener::bind(socket_file)
        .map_err(|e| format!("Failed to bind Unix socket: {e}"))?;
    std::fs::set_permissions(socket_file, std::fs::Permissions::from_mode(0o600))
        .map_err(|e| format!("Failed to set socket permissions: {e}"))?;
    std::fs::write(ready_file, socket_file)
        .map_err(|e| format!("Failed to write ready file: {e}"))?;
    eprintln!("[worker] Listening on {socket_file}");
    let (stream, _) = listener
        .accept()
        .map_err(|e| format!("Failed to accept connection: {e}"))?;
    Ok(stream)
}

/// Runs the request/response loop for the connected client.
///
/// Returns `true` if a `Shutdown` request has already been broadcast to the
/// other ranks, so the caller must not send another; otherwise `false`.
fn serve_client(
    world: &mpi::topology::SimpleCommunicator,
    stream: &mut std::os::unix::net::UnixStream,
    force_env: &mut Option<ForceEnv>,
) -> bool {
    loop {
        let raw = match read_msg(stream) {
            Ok(b) => b,
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                eprintln!(
                    "[worker] Client disconnected without calling stop().\n\
                     \n\
                     Always shut down the calculator gracefully to avoid this message.\n\
                     \n\
                     Python (fastatomstruct / cp2k_rs):\n\
                     \n\
                     \twith fs.CP2kCalculator(\"input.inp\", \"out.txt\") as calc:\n\
                     \t result = calc.calculate(atoms)\n\
                     \n\
                     \t# or explicitly:\n\
                     \tcalc = fs.CP2kCalculator(\"input.inp\", \"out.txt\")\n\
                     \ttry:\n\
                     \t result = calc.calculate(atoms)\n\
                     \tfinally:\n\
                     \t calc.stop()\n\
                     \n\
                     Rust:\n\
                     \n\
                     \tcp2k_rs::worker::stop_worker()?;"
                );
                return false;
            }
            Err(e) => {
                eprintln!("[worker] Read error: {e}");
                return false;
            }
        };
        let req: Request = match bincode::deserialize(&raw) {
            Ok(r) => r,
            Err(e) => {
                eprintln!(
                    "[worker] Deserialize error: {e}\n\
                     This usually means the client and worker binaries were built \
                     from different versions of cp2k-rs. Reinstall both the cp2k_rs \
                     Python package and fastatomstruct from the same source."
                );
                return false;
            }
        };
        let is_shutdown = matches!(req.command, Command::Shutdown);
        broadcast_command(world, &raw);
        let resp = dispatch(&req, force_env, world);
        let resp_bytes = bincode::serialize(&resp).expect("Serialize response failed");
        if let Err(e) = write_msg(stream, &resp_bytes) {
            eprintln!("[worker] Write error: {e}");
            // The request was already broadcast. If it was Shutdown, the peers
            // have left their loop and must not be sent another one.
            return is_shutdown;
        }
        if is_shutdown {
            return true;
        }
    }
}
/// Follower loop for every MPI rank other than 0.
///
/// Receives each request forwarded by rank 0 and executes it with `dispatch`
/// in lockstep with rank 0. The response is discarded because only rank 0
/// talks to the client. Returns after handling a `Shutdown` request, or as
/// soon as a request fails to deserialize.
fn other_rank_loop(world: &mpi::topology::SimpleCommunicator) {
    let mut force_env: Option<ForceEnv> = None;
    let mut keep_running = true;
    while keep_running {
        let raw = receive_command(world);
        let req: Request = match bincode::deserialize(&raw) {
            Ok(parsed) => parsed,
            Err(e) => {
                eprintln!("[worker rank {}] Deserialize error: {e}", world.rank());
                return;
            }
        };
        keep_running = !matches!(req.command, Command::Shutdown);
        let _ = dispatch(&req, &mut force_env, world);
    }
}