use std::collections::{BTreeSet, HashMap};
use std::io::{Read, Write};
use std::os::unix::net::UnixListener;
use std::path::Path;
use mpi::traits::*;
use cp2k_rs::shm::ShmRegion;
use cp2k_rs::worker_protocol::{Command, Cp2kInputSpec, Payload, Request, Response};
use cp2k_rs::{ForceEnv, finalize, init};
/// Reads one length-prefixed frame from `stream`.
///
/// A frame is a 4-byte little-endian `u32` payload length followed by exactly
/// that many payload bytes. The payload is returned without the header.
///
/// # Errors
///
/// Propagates any I/O error from `stream`. Returns
/// [`std::io::ErrorKind::UnexpectedEof`] when the stream ends before the
/// header or the announced payload has been fully received.
fn read_msg<R: Read>(stream: &mut R) -> std::io::Result<Vec<u8>> {
    let mut len_buf = [0u8; 4];
    stream.read_exact(&mut len_buf)?;
    let len = u32::from_le_bytes(len_buf) as usize;

    // Do not pre-allocate `len` bytes: the length comes from the peer, and a
    // corrupt or hostile header could otherwise request a multi-GiB buffer
    // before a single payload byte has arrived. `take` caps the read at `len`
    // and the buffer grows only as data is actually received.
    let mut buf = Vec::new();
    let received = stream.by_ref().take(len as u64).read_to_end(&mut buf)?;
    if received != len {
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            "failed to fill whole buffer",
        ));
    }
    Ok(buf)
}
/// Writes `data` to `stream` as one length-prefixed frame and flushes.
///
/// The frame is a 4-byte little-endian `u32` length followed by the payload,
/// which is the layout `read_msg` expects.
///
/// # Errors
///
/// Returns [`std::io::ErrorKind::InvalidInput`] when `data` is longer than
/// `u32::MAX` bytes (the header cannot represent it), and propagates any I/O
/// error from `stream`.
fn write_msg<W: Write>(stream: &mut W, data: &[u8]) -> std::io::Result<()> {
    // A plain `as u32` cast would silently wrap for oversized payloads and
    // desynchronise the framing on the receiving side.
    let len = u32::try_from(data.len()).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "message of {} bytes exceeds the u32 frame length limit",
                data.len()
            ),
        )
    })?;
    stream.write_all(&len.to_le_bytes())?;
    stream.write_all(data)?;
    stream.flush()
}
#[cfg(feature = "worker-debug")]
fn log_rank(world: &mpi::topology::SimpleCommunicator, msg: &str) {
if let Ok(dir) = std::env::var("CP2K_WORKER_LOG_DIR")
&& !dir.is_empty()
{
let _ = std::fs::create_dir_all(&dir);
let path =
std::path::Path::new(&dir).join(format!("cp2k_rs_worker_rank{}.log", world.rank()));
if let Ok(mut f) = std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(path)
{
let _ = writeln!(f, "{msg}");
return;
}
}
eprintln!("{msg}");
}
#[cfg(not(feature = "worker-debug"))]
fn log_rank(_world: &mpi::topology::SimpleCommunicator, _msg: &str) {}
/// Builds an element → potential-names table from `<data_dir>/GTH_POTENTIALS`.
///
/// Only lines whose second whitespace-separated token starts with `GTH-` are
/// treated as header lines; the first token is the element and every further
/// token on the line is recorded as a potential name for it. Blank lines,
/// `#` comments and all other lines are ignored. Names keep first-seen order
/// and are de-duplicated per element.
///
/// # Errors
///
/// Returns a message naming the file when it cannot be read.
fn read_gth_potentials_map(data_dir: &Path) -> Result<HashMap<String, Vec<String>>, String> {
    let path = data_dir.join("GTH_POTENTIALS");
    let content = match std::fs::read_to_string(&path) {
        Ok(text) => text,
        Err(e) => {
            return Err(format!(
                "Failed to read CP2K potentials file '{}': {e}",
                path.display()
            ));
        }
    };

    let mut map = HashMap::new();
    let candidate_lines = content
        .lines()
        .map(str::trim)
        .filter(|l| !l.is_empty() && !l.starts_with('#'));

    for line in candidate_lines {
        let mut tokens = line.split_whitespace();
        let (Some(element), Some(first_name)) = (tokens.next(), tokens.next()) else {
            continue;
        };
        if !first_name.starts_with("GTH-") {
            continue;
        }

        let names: &mut Vec<String> = map.entry(element.to_string()).or_default();
        for name in std::iter::once(first_name).chain(tokens) {
            let already_known = names.iter().any(|known| known.as_str() == name);
            if !already_known {
                names.push(name.to_string());
            }
        }
    }

    Ok(map)
}
/// Returns the GTH potential name prefixes to try for an XC functional, in
/// order of preference.
///
/// The functional name is matched case-insensitively by substring, checking
/// `BLYP`, then `PBE`, then `PADE`; `None` is treated as `PBE`. `GTH-LDA` is
/// always the last entry, and the only one for unrecognised functionals.
fn xc_potential_prefixes(functional: Option<&str>) -> Vec<&'static str> {
    let name = functional.unwrap_or("PBE").to_uppercase();

    let specific = if name.contains("BLYP") {
        Some("GTH-BLYP")
    } else if name.contains("PBE") {
        Some("GTH-PBE")
    } else if name.contains("PADE") {
        Some("GTH-PADE")
    } else {
        None
    };

    let mut prefixes = Vec::with_capacity(2);
    prefixes.extend(specific);
    prefixes.push("GTH-LDA");
    prefixes
}
/// Chooses one potential name from `potentials` for the given prefix list.
///
/// Prefixes are tried in order. For each prefix three progressively looser
/// filters are applied, and the first name that passes wins:
/// 1. starts with the prefix, contains `-q`, and does not end in `_old`;
/// 2. starts with the prefix and contains `-q`;
/// 3. starts with the prefix.
///
/// Returns `None` when no name matches any prefix.
fn pick_gth_potential(potentials: &[String], prefixes: &[&str]) -> Option<String> {
    fn charged_and_current(name: &str) -> bool {
        name.contains("-q") && !name.ends_with("_old")
    }
    fn charged(name: &str) -> bool {
        name.contains("-q")
    }
    fn anything(_name: &str) -> bool {
        true
    }

    let tiers: [fn(&str) -> bool; 3] = [charged_and_current, charged, anything];

    for &prefix in prefixes {
        for accept in tiers {
            let hit = potentials
                .iter()
                .find(|name| name.starts_with(prefix) && accept(name.as_str()));
            if let Some(name) = hit {
                return Some(name.clone());
            }
        }
    }
    None
}
/// Populates `force_eval.subsys.kinds` when the caller supplied none.
///
/// A `subsys` entry is created if absent. If it already lists kinds, or
/// `symbols` is empty, nothing else happens. Otherwise one kind is added per
/// distinct symbol (in sorted order) with basis set `DZVP-MOLOPT-SR-GTH` and a
/// GTH potential chosen from `$CP2K_DATA_DIR/GTH_POTENTIALS` according to the
/// configured XC functional.
///
/// # Errors
///
/// Fails when `CP2K_DATA_DIR` is unset, the potentials file cannot be read, an
/// element has no entry in it, or no entry matches the functional.
fn auto_fill_kinds_if_missing(
    settings: &mut cp2k_rs::worker_protocol::Cp2kSettings,
    symbols: &[String],
) -> Result<(), String> {
    let subsys = settings
        .force_eval
        .subsys
        .get_or_insert_with(Default::default);
    if !subsys.kinds.is_empty() {
        return Ok(());
    }

    // Sorted, de-duplicated element list.
    let elements: BTreeSet<String> = symbols.iter().cloned().collect();
    if elements.is_empty() {
        return Ok(());
    }

    let data_dir = match std::env::var("CP2K_DATA_DIR") {
        Ok(dir) => dir,
        Err(_) => {
            return Err(
                "CP2K_DATA_DIR is not set; cannot auto-generate &KIND blocks. \
Please set force_eval.subsys.kinds explicitly."
                    .to_string(),
            );
        }
    };
    let pot_map = read_gth_potentials_map(Path::new(&data_dir))?;

    let functional = settings
        .force_eval
        .dft
        .as_ref()
        .and_then(|d| d.xc.as_ref())
        .and_then(|x| x.functional.as_deref());
    let prefixes = xc_potential_prefixes(functional);

    for element in elements {
        let Some(candidates) = pot_map.get(&element) else {
            return Err(format!(
                "No GTH potentials found for element '{}' in {}/GTH_POTENTIALS. \
Please set force_eval.subsys.kinds explicitly.",
                element, data_dir
            ));
        };
        let Some(potential) = pick_gth_potential(candidates, &prefixes) else {
            return Err(format!(
                "No GTH potential found for element '{}' matching XC functional '{:?}' in {}/GTH_POTENTIALS. \
Please set force_eval.subsys.kinds explicitly.",
                element, functional, data_dir
            ));
        };

        let kind = cp2k_rs::worker_protocol::Cp2kKindSettings {
            element,
            basis_set: Some("DZVP-MOLOPT-SR-GTH".to_string()),
            potential: Some(potential),
            extra_keywords: Vec::new(),
            extra_sections: Vec::new(),
        };
        subsys.kinds.push(kind);
    }

    Ok(())
}
fn resolve_input_spec_to_path(
input: &Cp2kInputSpec,
symbols: &[String],
) -> Result<std::path::PathBuf, String> {
match input {
Cp2kInputSpec::FromFile { path } => Ok(std::path::PathBuf::from(path)),
Cp2kInputSpec::Generated { settings } => {
let mut settings = settings.clone();
auto_fill_kinds_if_missing(&mut settings, symbols)?;
let content = render_cp2k_input_from_settings(&settings)?;
let nanos = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.subsec_nanos())
.unwrap_or(0);
let temp_name = format!("cp2k_worker_generated_{}_{}.inp", std::process::id(), nanos);
let temp_path = std::path::Path::new("/tmp").join(temp_name);
std::fs::write(&temp_path, &content).map_err(|e| {
format!(
"Failed to write generated input file '{}': {e}",
temp_path.display()
)
})?;
Ok(temp_path)
}
}
}
/// Renders structured settings into the text of a CP2K input file.
///
/// The output consists of a `&GLOBAL` section followed by a `&FORCE_EVAL`
/// section containing `METHOD`, an optional `STRESS_TENSOR`, an optional
/// `&DFT` block and a `&SUBSYS` block (an empty one is emitted when the
/// settings carry no subsys).
///
/// No code path in this function produces `Err`; the `Result` return type is
/// part of the signature callers use with `?`.
fn render_cp2k_input_from_settings(
settings: &cp2k_rs::worker_protocol::Cp2kSettings,
) -> Result<String, String> {
use cp2k_rs::worker_protocol::{
Cp2kDftSettings, Cp2kKeyword, Cp2kKindSettings, Cp2kSection, Cp2kSettings,
Cp2kSubsysSettings, Cp2kXcSettings,
};
// Appends one `NAME VALUE` line (or `NAME [UNIT] VALUE` when a unit is set),
// left-padded with `indent` spaces.
fn push_kw(out: &mut String, indent: usize, kw: &Cp2kKeyword) {
let pad = " ".repeat(indent);
if let Some(unit) = kw.unit.as_deref() {
out.push_str(&format!("{pad}{} [{}] {}\n", kw.name, unit, kw.value));
} else {
out.push_str(&format!("{pad}{} {}\n", kw.name, kw.value));
}
}
// Appends a generic `&NAME [PARAM] ... &END NAME` section, recursing into
// nested sections; children are indented two spaces deeper than the header.
fn push_section(out: &mut String, indent: usize, sec: &Cp2kSection) {
let pad = " ".repeat(indent);
out.push_str(&pad);
out.push('&');
out.push_str(&sec.name);
if let Some(param) = sec.parameter.as_deref() {
out.push(' ');
out.push_str(param);
}
out.push('\n');
let inner_indent = indent + 2;
for kw in &sec.keywords {
push_kw(out, inner_indent, kw);
}
for s in &sec.sections {
push_section(out, inner_indent, s);
}
out.push_str(&pad);
out.push_str("&END ");
out.push_str(&sec.name);
out.push('\n');
}
// Fills in the project name "generated" when none was provided.
fn ensure_global_defaults(gs: &mut cp2k_rs::worker_protocol::Cp2kGlobalSettings) {
if gs.project.is_none() {
gs.project = Some("generated".to_string());
}
}
// Emits the `&GLOBAL` section from a defaulted copy of `s.global`.
fn render_global(out: &mut String, s: &Cp2kSettings) {
let mut global = s.global.clone().unwrap_or_default();
ensure_global_defaults(&mut global);
out.push_str("&GLOBAL\n");
if let Some(p) = global.project.as_deref() {
out.push_str(&format!(" PROJECT {p}\n"));
}
if let Some(pl) = global.print_level.as_deref() {
out.push_str(&format!(" PRINT_LEVEL {pl}\n"));
}
if let Some(rt) = global.run_type.as_deref() {
out.push_str(&format!(" RUN_TYPE {rt}\n"));
}
if let Some(wt) = global.walltime.as_deref() {
out.push_str(&format!(" WALLTIME {wt}\n"));
}
if let Some(seed) = global.seed.as_deref() {
// A present-but-empty seed list is written as the literal 2000;
// otherwise the values are written space-separated.
if seed.is_empty() {
out.push_str(" SEED 2000\n");
} else if seed.len() == 1 {
out.push_str(&format!(" SEED {}\n", seed[0]));
} else {
let seed_str = seed
.iter()
.map(|x| x.to_string())
.collect::<Vec<_>>()
.join(" ");
out.push_str(&format!(" SEED {seed_str}\n"));
}
}
// ECHO_INPUT is only written when explicitly enabled.
if let Some(true) = global.echo_input {
out.push_str(" ECHO_INPUT YES\n");
}
out.push_str("&END GLOBAL\n\n");
}
// Emits the `&DFT` block: file names (defaulting to BASIS_MOLOPT and
// GTH_POTENTIALS), then optional QS, MGRID, SCF and XC subsections, then
// the caller's extra keywords and sections.
fn render_dft_block(out: &mut String, dft: &Cp2kDftSettings) {
out.push_str(" &DFT\n");
let basis = dft.basis_set_file_name.as_deref().unwrap_or("BASIS_MOLOPT");
out.push_str(&format!(" BASIS_SET_FILE_NAME {basis}\n"));
let pot = dft
.potential_file_name
.as_deref()
.unwrap_or("GTH_POTENTIALS");
out.push_str(&format!(" POTENTIAL_FILE_NAME {pot}\n"));
if let Some(qs) = dft.qs.as_ref() {
out.push_str(" &QS\n");
if let Some(eps) = qs.eps_default {
out.push_str(&format!(" EPS_DEFAULT {eps:.6E}\n"));
}
out.push_str(" &END QS\n");
}
if let Some(mg) = dft.mgrid.as_ref() {
out.push_str(" &MGRID\n");
if let Some(cut) = mg.cutoff {
out.push_str(&format!(" CUTOFF {cut}\n"));
}
if let Some(rel) = mg.rel_cutoff {
out.push_str(&format!(" REL_CUTOFF {rel}\n"));
}
out.push_str(" &END MGRID\n");
}
if let Some(scf) = dft.scf.as_ref() {
out.push_str(" &SCF\n");
if let Some(guess) = scf.scf_guess.as_deref() {
out.push_str(&format!(" SCF_GUESS {guess}\n"));
}
if let Some(ms) = scf.max_scf {
out.push_str(&format!(" MAX_SCF {ms}\n"));
}
if let Some(eps) = scf.eps_scf {
out.push_str(&format!(" EPS_SCF {eps:.6E}\n"));
}
if let Some(added) = scf.added_mos {
out.push_str(&format!(" ADDED_MOS {added}\n"));
}
// Caller-supplied SCF keywords/sections are written before the
// built-in SMEAR, DIAGONALIZATION and MIXING subsections.
for kw in &scf.extra_keywords {
push_kw(out, 6, kw); }
for sec in &scf.extra_sections {
push_section(out, 6, sec);
}
// SMEAR and DIAGONALIZATION are emitted only when their `enabled` flag is set.
if let Some(smear) = scf.smear.as_ref()
&& smear.enabled
{
out.push_str(" &SMEAR ON\n");
if let Some(method) = smear.method.as_deref() {
out.push_str(&format!(" METHOD {method}\n"));
}
if let Some(tk) = smear.electronic_temperature_k {
out.push_str(&format!(" ELECTRONIC_TEMPERATURE [K] {tk}\n"));
}
out.push_str(" &END SMEAR\n");
}
if let Some(diag) = scf.diagonalization.as_ref()
&& diag.enabled
{
out.push_str(" &DIAGONALIZATION ON\n");
if let Some(algo) = diag.algorithm.as_deref() {
out.push_str(&format!(" ALGORITHM {algo}\n"));
}
out.push_str(" &END DIAGONALIZATION\n");
}
if let Some(mix) = scf.mixing.as_ref() {
out.push_str(" &MIXING\n");
if let Some(m) = mix.method.as_deref() {
out.push_str(&format!(" METHOD {m}\n"));
}
if let Some(a) = mix.alpha {
out.push_str(&format!(" ALPHA {a}\n"));
}
if let Some(n) = mix.nbroyden {
out.push_str(&format!(" NBROYDEN {n}\n"));
}
out.push_str(" &END MIXING\n");
}
out.push_str(" &END SCF\n");
}
// Emits `&XC` with an `&XC_FUNCTIONAL` subsection (functional defaults to
// PBE) followed by the caller's extra keywords and sections.
fn render_xc(out: &mut String, xc: &Cp2kXcSettings) {
out.push_str(" &XC\n");
let fnctl = xc.functional.as_deref().unwrap_or("PBE");
out.push_str(&format!(" &XC_FUNCTIONAL {fnctl}\n"));
out.push_str(" &END XC_FUNCTIONAL\n");
for kw in &xc.extra_keywords {
push_kw(out, 6, kw);
}
for sec in &xc.extra_sections {
push_section(out, 6, sec);
}
out.push_str(" &END XC\n");
}
if let Some(xc) = dft.xc.as_ref() {
render_xc(out, xc);
}
for kw in &dft.extra_keywords {
push_kw(out, 4, kw); }
for sec in &dft.extra_sections {
push_section(out, 4, sec); }
out.push_str(" &END DFT\n");
}
// Emits one `&KIND <element>` block with optional BASIS_SET / POTENTIAL
// lines and the caller's extra keywords and sections.
fn render_kind_block(out: &mut String, kind: &Cp2kKindSettings) {
out.push_str(&format!(" &KIND {}\n", kind.element));
if let Some(basis) = kind.basis_set.as_deref() {
out.push_str(&format!(" BASIS_SET {basis}\n"));
}
if let Some(pot) = kind.potential.as_deref() {
out.push_str(&format!(" POTENTIAL {pot}\n"));
}
for kw in &kind.extra_keywords {
push_kw(out, 6, kw);
}
for sec in &kind.extra_sections {
push_section(out, 6, sec);
}
out.push_str(" &END KIND\n");
}
// Emits `&SUBSYS`: extra keywords, then every kind, then extra sections.
fn render_subsys_block(out: &mut String, subsys: &Cp2kSubsysSettings) {
out.push_str(" &SUBSYS\n");
for kw in &subsys.extra_keywords {
push_kw(out, 4, kw);
}
for kind in &subsys.kinds {
render_kind_block(out, kind);
}
for sec in &subsys.extra_sections {
push_section(out, 4, sec);
}
out.push_str(" &END SUBSYS\n");
}
// Top-level assembly: &GLOBAL first, then &FORCE_EVAL.
let mut out = String::new();
render_global(&mut out, settings);
out.push_str("&FORCE_EVAL\n");
out.push_str(&format!(" METHOD {}\n", settings.force_eval.method));
if let Some(st) = settings.force_eval.stress_tensor.as_deref() {
out.push_str(&format!(" STRESS_TENSOR {st}\n"));
}
if let Some(dft) = settings.force_eval.dft.as_ref() {
render_dft_block(&mut out, dft);
}
if let Some(subsys) = settings.force_eval.subsys.as_ref() {
render_subsys_block(&mut out, subsys);
} else {
// Always emit a SUBSYS section, even when there is nothing to put in it.
out.push_str(" &SUBSYS\n");
out.push_str(" &END SUBSYS\n");
}
out.push_str("&END FORCE_EVAL\n");
Ok(out)
}
// MPI message tags used for point-to-point traffic from rank 0 to the other ranks.
/// Tag of the one-element `u32` message carrying a serialized command's length.
const TAG_LEN: mpi::Tag = 1;
/// Tag of the message carrying the serialized command bytes.
const TAG_CMD: mpi::Tag = 2;
/// Tag of the one-element `u32` message carrying the patched input path's
/// byte length; `dispatch` sends 0 here to signal that rank 0 failed to
/// produce the file.
const TAG_PATCH_LEN: mpi::Tag = 3;
/// Tag of the message carrying the patched input path bytes.
const TAG_PATCH_PATH: mpi::Tag = 4;
/// Sends a serialized command from rank 0 to every other rank.
///
/// Each destination receives two point-to-point messages: the payload length
/// as a single `u32` (tag `TAG_LEN`), then the payload itself (tag `TAG_CMD`).
/// Calling this on any rank other than 0 does nothing.
fn broadcast_command(world: &mpi::topology::SimpleCommunicator, bytes: &[u8]) {
    if world.rank() != 0 {
        return;
    }

    let len = bytes.len() as u32;
    for dest in 1..world.size() {
        let peer = world.process_at_rank(dest);
        peer.send_with_tag(&[len], TAG_LEN);
        peer.send_with_tag(bytes, TAG_CMD);
    }
}
/// Receives one serialized command from rank 0.
///
/// Reads the length message (tag `TAG_LEN`) and then the payload message
/// (tag `TAG_CMD`), and returns the payload cut down to the announced length.
fn receive_command(world: &mpi::topology::SimpleCommunicator) -> Vec<u8> {
    let root = world.process_at_rank(0);

    let (header, _) = root.receive_vec_with_tag::<u32>(TAG_LEN);
    let announced = header[0] as usize;

    let (mut payload, _) = root.receive_vec_with_tag::<u8>(TAG_CMD);
    payload.truncate(announced);
    payload
}
/// Length conversion factor: the `SetPositions` and `SetCell` handlers in
/// `dispatch` divide incoming values by this constant before passing them to
/// the force environment.
// NOTE(review): presumably the Bohr radius in ångström (Å → bohr conversion);
// confirm the digits against the constant used by the linked CP2K build.
const BOHR: f64 = 0.5291772105828846;
/// Injects `&GLOBAL` keyword overrides into CP2K input text.
///
/// If no override is requested (`echo_input` only counts when `Some(true)`),
/// `content` is returned unchanged. Otherwise the override keywords are
/// written, in the order PROJECT, PRINT_LEVEL, RUN_TYPE, WALLTIME, SEED,
/// ECHO_INPUT, immediately before the line that closes the first `&GLOBAL`
/// section. Existing keywords in that section are left in place.
///
/// When the input has no `&GLOBAL` section, a new one is inserted in front of
/// the first `&FORCE_EVAL` line, or at the very top if there is none.
///
/// An empty `seed` slice is written as `SEED 2000`. The result is joined with
/// `\n` and carries no trailing newline.
fn patch_global_block(
    content: &str,
    project_label: Option<&str>,
    print_level: Option<&str>,
    run_type: Option<&str>,
    walltime: Option<&str>,
    seed: Option<&[i32]>,
    echo_input: Option<bool>,
) -> String {
    let has_overrides = project_label.is_some()
        || print_level.is_some()
        || run_type.is_some()
        || walltime.is_some()
        || seed.is_some()
        || matches!(echo_input, Some(true));
    if !has_overrides {
        return content.to_string();
    }

    // Build the keyword lines once; they are identical whether they go into
    // an existing section or a freshly created one.
    let mut overrides: Vec<String> = Vec::new();
    if let Some(label) = project_label {
        overrides.push(format!(" PROJECT {}", label));
    }
    if let Some(level) = print_level {
        overrides.push(format!(" PRINT_LEVEL {}", level));
    }
    if let Some(rt) = run_type {
        overrides.push(format!(" RUN_TYPE {}", rt));
    }
    if let Some(wt) = walltime {
        overrides.push(format!(" WALLTIME {}", wt));
    }
    if let Some(values) = seed {
        if values.is_empty() {
            overrides.push(" SEED 2000".to_string());
        } else {
            let joined = values
                .iter()
                .map(|v| v.to_string())
                .collect::<Vec<_>>()
                .join(" ");
            overrides.push(format!(" SEED {}", joined));
        }
    }
    if let Some(true) = echo_input {
        overrides.push(" ECHO_INPUT YES".to_string());
    }

    let mut result: Vec<String> = Vec::new();
    let mut found_global = false;
    let mut inside_global = false;
    // Number of nested sections currently open inside &GLOBAL.
    let mut depth: usize = 0;

    for line in content.lines() {
        let upper = line.trim_start().to_uppercase();

        if !found_global && (upper == "&GLOBAL" || upper.starts_with("&GLOBAL ")) {
            found_global = true;
            inside_global = true;
            depth = 0;
            result.push(line.to_string());
            continue;
        }

        if inside_global {
            let closes = upper.starts_with("&END");
            let opens = upper.starts_with('&') && !closes;
            if opens {
                depth += 1;
            } else if closes {
                if depth == 0 || upper.starts_with("&END GLOBAL") {
                    // This line closes &GLOBAL itself: inject just before it.
                    result.extend(overrides.iter().cloned());
                    inside_global = false;
                } else {
                    // Closes a nested section. Decrementing here is what lets
                    // a later bare `&END` be recognised as the end of &GLOBAL.
                    depth -= 1;
                }
            }
        }
        result.push(line.to_string());
    }

    if !found_global {
        let mut block: Vec<String> = Vec::with_capacity(overrides.len() + 2);
        block.push("&GLOBAL".to_string());
        block.extend(overrides);
        block.push("&END GLOBAL".to_string());

        // Prefer the position right before &FORCE_EVAL; fall back to the top
        // of the file so the overrides are never silently dropped.
        let insert_at = result
            .iter()
            .position(|l| l.trim_start().to_uppercase().starts_with("&FORCE_EVAL"))
            .unwrap_or(0);
        let tail = result.split_off(insert_at);
        result.extend(block);
        result.extend(tail);
    }

    result.join("\n")
}
/// Injects `&DFT` keyword overrides into CP2K input text.
///
/// If every argument is `None`, `content` is returned unchanged. Otherwise
/// CHARGE, MULTIPLICITY, `UKS YES`, `ROKS YES` and WFN_RESTART_FILE_NAME lines
/// (only for the values provided; `uks`/`roks` only when `Some(true)`) are
/// written immediately before the line closing the first `&DFT` section found
/// at nesting depth 1. Existing keywords in that section are left in place.
///
/// When no such `&DFT` section exists, a new `&DFT … &END DFT` block holding
/// the overrides is inserted right after the first `&FORCE_EVAL` line; if
/// there is no `&FORCE_EVAL` line either, the text is returned without the
/// overrides.
///
/// The result is joined with `\n` and carries no trailing newline.
fn patch_dft_block(
    content: &str,
    charge: Option<i32>,
    multiplicity: Option<i32>,
    uks: Option<bool>,
    roks: Option<bool>,
    wfn_restart_file: Option<&str>,
) -> String {
    if charge.is_none()
        && multiplicity.is_none()
        && uks.is_none()
        && roks.is_none()
        && wfn_restart_file.is_none()
    {
        return content.to_string();
    }

    // Keyword lines shared by the "patch existing" and "create new" paths.
    let mut keywords: Vec<String> = Vec::new();
    if let Some(q) = charge {
        keywords.push(format!(" CHARGE {}", q));
    }
    if let Some(m) = multiplicity {
        keywords.push(format!(" MULTIPLICITY {}", m));
    }
    if let Some(true) = uks {
        keywords.push(" UKS YES".to_string());
    }
    if let Some(true) = roks {
        keywords.push(" ROKS YES".to_string());
    }
    if let Some(file) = wfn_restart_file {
        keywords.push(format!(" WFN_RESTART_FILE_NAME {}", file));
    }

    let mut result: Vec<String> = Vec::new();
    let mut found_dft = false;
    let mut inside_dft = false;
    // Nested sections currently open inside &DFT.
    let mut dft_depth: usize = 0;
    // Sections currently open outside &DFT. Unsigned, so a stray `&END`
    // clamps at 0 instead of going negative and hiding the real &DFT.
    let mut block_depth: usize = 0;

    for line in content.lines() {
        let upper = line.trim_start().to_uppercase();

        if !found_dft && block_depth == 1 && (upper == "&DFT" || upper.starts_with("&DFT ")) {
            found_dft = true;
            inside_dft = true;
            dft_depth = 0;
            result.push(line.to_string());
            continue;
        }

        let closes = upper.starts_with("&END");
        let opens = upper.starts_with('&') && !closes;

        if inside_dft {
            if opens {
                dft_depth += 1;
            } else if closes {
                if dft_depth == 0 || upper.starts_with("&END DFT") {
                    // This line closes &DFT itself: inject just before it.
                    result.extend(keywords.iter().cloned());
                    inside_dft = false;
                } else {
                    // Closes a nested section (&SCF, &XC, ...). Decrementing
                    // is what lets a later bare `&END` close &DFT.
                    dft_depth -= 1;
                }
            }
        } else if opens {
            block_depth += 1;
        } else if closes {
            block_depth = block_depth.saturating_sub(1);
        }
        result.push(line.to_string());
    }

    if !found_dft {
        let force_eval_at = result.iter().position(|l| {
            let upper = l.trim_start().to_uppercase();
            upper == "&FORCE_EVAL" || upper.starts_with("&FORCE_EVAL ")
        });
        if let Some(idx) = force_eval_at {
            let tail = result.split_off(idx + 1);
            result.push("&DFT".to_string());
            result.extend(keywords);
            result.push("&END DFT".to_string());
            result.extend(tail);
        }
    }

    result.join("\n")
}
/// Replaces or inserts the `&KPOINTS` section inside `&DFT`.
///
/// With `kpoint_config == None` the text is returned unchanged. Otherwise:
///
/// * an existing `&KPOINTS` section inside `&DFT` keeps its opening and
///   closing lines, while everything between them is discarded and replaced
///   by the lines from `generate_kpoints_lines`;
/// * if a `&DFT` section ends without any `&KPOINTS` having been seen, a new
///   `&KPOINTS … &END KPOINTS` block is inserted just before its closing line.
///
/// Input without a `&DFT` section is returned unchanged (re-joined with `\n`,
/// so a trailing newline is lost).
fn patch_kpoints_block(
    content: &str,
    kpoint_config: Option<&cp2k_rs::worker_protocol::KPointConfig>,
) -> String {
    let Some(config) = kpoint_config else {
        return content.to_string();
    };

    let mut result: Vec<String> = Vec::new();
    let mut inside_dft = false;
    // Nested sections currently open inside &DFT (not counting &KPOINTS).
    let mut dft_depth: usize = 0;
    let mut inside_kpoints = false;
    // Nested sections currently open inside the &KPOINTS being replaced.
    let mut kpoints_depth: usize = 0;
    let mut found_kpoints = false;

    for line in content.lines() {
        let upper = line.trim_start().to_uppercase();
        let closes = upper.starts_with("&END");
        let opens = upper.starts_with('&') && !closes;

        if !inside_dft {
            if upper == "&DFT" || upper.starts_with("&DFT ") {
                inside_dft = true;
                dft_depth = 0;
            }
            result.push(line.to_string());
            continue;
        }

        if inside_kpoints {
            // Old section content is dropped. The section ends at
            // `&END KPOINTS` or at a bare `&END` balancing its opener, so an
            // unnamed terminator no longer swallows the rest of the file.
            if opens {
                kpoints_depth += 1;
            } else if closes {
                if kpoints_depth == 0 || upper.starts_with("&END KPOINTS") {
                    generate_kpoints_lines(config, &mut result);
                    result.push(line.to_string());
                    inside_kpoints = false;
                } else {
                    kpoints_depth -= 1;
                }
            }
            continue;
        }

        if upper == "&KPOINTS" || upper.starts_with("&KPOINTS ") {
            inside_kpoints = true;
            kpoints_depth = 0;
            found_kpoints = true;
            result.push(line.to_string());
            continue;
        }

        if opens {
            dft_depth += 1;
        } else if closes {
            if dft_depth == 0 || upper.starts_with("&END DFT") {
                // End of &DFT (named or bare): add the section if it was absent.
                if !found_kpoints {
                    result.push(" &KPOINTS".to_string());
                    generate_kpoints_lines(config, &mut result);
                    result.push(" &END KPOINTS".to_string());
                }
                inside_dft = false;
            } else {
                dft_depth -= 1;
            }
        }
        result.push(line.to_string());
    }

    result.join("\n")
}
/// Appends the body lines of a `&KPOINTS` section to `result`.
///
/// First the `SCHEME` line for the configured scheme (plus one `KPOINT` line
/// per entry for an explicit list), then `SYMMETRY TRUE`, `VERBOSE TRUE` and
/// `FULL_GRID TRUE` for each flag that is set. The enclosing `&KPOINTS` /
/// `&END KPOINTS` lines are the caller's responsibility.
fn generate_kpoints_lines(
    config: &cp2k_rs::worker_protocol::KPointConfig,
    result: &mut Vec<String>,
) {
    use cp2k_rs::worker_protocol::KPointScheme;

    match &config.scheme {
        KPointScheme::None => result.push(" SCHEME NONE".to_string()),
        KPointScheme::Gamma => result.push(" SCHEME GAMMA".to_string()),
        KPointScheme::MonkhorstPack(n1, n2, n3) => {
            result.push(format!(" SCHEME MONKHORST-PACK {} {} {}", n1, n2, n3));
        }
        KPointScheme::ExplicitList(kpoints) => {
            result.push(" SCHEME GENERAL".to_string());
            for kpt in kpoints {
                let entry = format!(
                    " KPOINT {:.8} {:.8} {:.8} {:.8}",
                    kpt[0], kpt[1], kpt[2], kpt[3]
                );
                result.push(entry);
            }
        }
    }

    let flags = [
        (config.use_symmetry, " SYMMETRY TRUE"),
        (config.verbose, " VERBOSE TRUE"),
        (config.full_grid, " FULL_GRID TRUE"),
    ];
    for (enabled, text) in flags {
        if enabled {
            result.push(text.to_string());
        }
    }
}
/// Inserts a `&PRINT / &DOS` block in front of the last `&END DFT`.
///
/// The text is returned unchanged when `dos_config` is `None`, when it is not
/// enabled, or when the input contains no literal `&END DFT`. `EMIN`/`EMAX`
/// are written only when an energy range is configured, and `NPOINTS` only
/// when both the range and the point count are present.
fn patch_dos_block(
    content: &str,
    dos_config: Option<&cp2k_rs::worker_protocol::DosConfig>,
) -> String {
    let Some(config) = dos_config.filter(|c| c.enabled) else {
        return content.to_string();
    };

    let mut dos_inner = String::new();
    if let Some((e_min, e_max)) = config.erange_ev {
        dos_inner.push_str(&format!(" EMIN {e_min:.6}\n EMAX {e_max:.6}\n"));
        if let Some(npts) = config.npoints {
            dos_inner.push_str(&format!(" NPOINTS {npts}\n"));
        }
    }
    let dos_block = format!(" &PRINT\n &DOS\n{dos_inner} &END DOS\n &END PRINT\n");

    match content.rfind("&END DFT") {
        Some(pos) => {
            let (head, tail) = content.split_at(pos);
            [head, dos_block.as_str(), tail].concat()
        }
        None => content.to_string(),
    }
}
/// Optional per-calculation overrides applied by `patch_input_with_geometry`
/// on top of a base CP2K input file. Every field is optional; `None` leaves
/// the corresponding part of the input untouched.
struct PatchInputOverrides<'a> {
// Passed to `patch_global_block` (written into the &GLOBAL section).
project_label: Option<&'a str>,
print_level: Option<&'a str>,
run_type: Option<&'a str>,
// Passed to `patch_dft_block` (written into the &DFT section).
charge: Option<i32>,
multiplicity: Option<i32>,
// Passed to `patch_global_block`.
walltime: Option<&'a str>,
seed: Option<&'a [i32]>,
echo_input: Option<bool>,
// Passed to `patch_dft_block`.
uks: Option<bool>,
roks: Option<bool>,
wfn_restart_file: Option<&'a str>,
// Passed to `patch_kpoints_block`.
kpoint_config: Option<&'a cp2k_rs::worker_protocol::KPointConfig>,
// Passed to `patch_dos_block`.
dos_config: Option<&'a cp2k_rs::worker_protocol::DosConfig>,
}
fn patch_input_with_geometry(
input_path: &str,
symbols: &[String],
positions_angstrom: &[f64],
cell_angstrom: &[f64],
periodic: &str,
overrides: PatchInputOverrides<'_>,
) -> Result<std::path::PathBuf, String> {
if positions_angstrom.len() != symbols.len() * 3 {
return Err(format!(
"positions_angstrom length {} != symbols.len() * 3 = {}",
positions_angstrom.len(),
symbols.len() * 3
));
}
if cell_angstrom.len() != 9 {
return Err(format!("cell_angstrom length {} != 9", cell_angstrom.len()));
}
let content = std::fs::read_to_string(input_path)
.map_err(|e| format!("Failed to read input file '{}': {e}", input_path))?;
let content = patch_global_block(
&content,
overrides.project_label,
overrides.print_level,
overrides.run_type,
overrides.walltime,
overrides.seed,
overrides.echo_input,
);
let content = patch_dft_block(
&content,
overrides.charge,
overrides.multiplicity,
overrides.uks,
overrides.roks,
overrides.wfn_restart_file,
);
let content = patch_kpoints_block(&content, overrides.kpoint_config);
let content = patch_dos_block(&content, overrides.dos_config);
let cell_block = format!(
"&CELL\n A {:.10} {:.10} {:.10}\n B {:.10} {:.10} {:.10}\n C {:.10} {:.10} {:.10}\n PERIODIC {}\n&END CELL\n",
cell_angstrom[0],
cell_angstrom[1],
cell_angstrom[2],
cell_angstrom[3],
cell_angstrom[4],
cell_angstrom[5],
cell_angstrom[6],
cell_angstrom[7],
cell_angstrom[8],
periodic,
);
let mut coord_block = "&COORD\n".to_string();
for (i, sym) in symbols.iter().enumerate() {
coord_block.push_str(&format!(
" {} {:.10} {:.10} {:.10}\n",
sym,
positions_angstrom[i * 3],
positions_angstrom[i * 3 + 1],
positions_angstrom[i * 3 + 2],
));
}
coord_block.push_str("&END COORD\n");
let mut kept_lines: Vec<&str> = Vec::new();
let mut inside_subsys = false;
let mut skipping = false;
let mut skip_depth: u32 = 0;
for line in content.lines() {
let trimmed_upper = line.trim().to_uppercase();
if skipping {
if trimmed_upper.starts_with('&') && !trimmed_upper.starts_with("&END") {
skip_depth += 1;
} else if trimmed_upper.starts_with("&END") {
if skip_depth == 0 {
skipping = false;
} else {
skip_depth -= 1;
}
}
} else if inside_subsys {
let is_cell = trimmed_upper == "&CELL"
|| trimmed_upper.starts_with("&CELL ")
|| trimmed_upper.starts_with("&CELL\t");
let is_coord = trimmed_upper == "&COORD"
|| trimmed_upper.starts_with("&COORD ")
|| trimmed_upper.starts_with("&COORD\t");
let is_end_subsys = trimmed_upper == "&END SUBSYS" || trimmed_upper == "&END";
if is_cell || is_coord {
skipping = true;
skip_depth = 0;
} else if is_end_subsys {
inside_subsys = false;
kept_lines.push(line);
} else {
kept_lines.push(line);
}
} else {
let is_subsys = trimmed_upper == "&SUBSYS"
|| trimmed_upper.starts_with("&SUBSYS ")
|| trimmed_upper.starts_with("&SUBSYS\t");
if is_subsys {
inside_subsys = true;
}
kept_lines.push(line);
}
}
let injection = format!("{}{}", cell_block, coord_block);
let mut output = String::new();
let subsys_idx = kept_lines.iter().position(|l| {
let u = l.trim().to_uppercase();
u == "&SUBSYS" || u.starts_with("&SUBSYS ") || u.starts_with("&SUBSYS\t")
});
if let Some(idx) = subsys_idx {
for (i, line) in kept_lines.iter().enumerate() {
output.push_str(line);
output.push('\n');
if i == idx {
output.push_str(&injection);
}
}
} else {
let new_subsys = format!("&SUBSYS\n{}&END SUBSYS\n", injection);
let force_eval_end_idx = kept_lines.iter().position(|l| {
let u = l.trim().to_uppercase();
u == "&END FORCE_EVAL" || u.starts_with("&END FORCE_EVAL")
});
if let Some(idx) = force_eval_end_idx {
for (i, line) in kept_lines.iter().enumerate() {
if i == idx {
output.push_str(&new_subsys);
}
output.push_str(line);
output.push('\n');
}
} else {
for line in &kept_lines {
output.push_str(line);
output.push('\n');
}
output.push_str(&new_subsys);
}
}
let nanos = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.subsec_nanos())
.unwrap_or(0);
let temp_name = format!("cp2k_worker_patched_{}_{}.inp", std::process::id(), nanos);
let temp_path = std::path::Path::new("/tmp").join(temp_name);
std::fs::write(&temp_path, &output).map_err(|e| {
format!(
"Failed to write patched input file '{}': {e}",
temp_path.display()
)
})?;
Ok(temp_path)
}
fn dispatch(
req: &Request,
force_env: &mut Option<ForceEnv>,
world: &mpi::topology::SimpleCommunicator,
) -> Response {
let id = req.request_id;
macro_rules! need_env {
($fe:expr) => {
match $fe {
Some(fe) => fe,
None => {
return Response::error(
id,
"No force environment is initialized. \
Call calc.calculate(atoms) at least once before \
querying properties such as HOMO/LUMO or SCF info.",
)
}
}
};
}
match &req.command {
Command::InitForceEnv { input, output } => {
let result = ForceEnv::new_with_mpi(input, output, world);
match result {
Ok(fe) => {
*force_env = Some(fe);
Response::ok(id, Payload::Empty)
}
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::InitForceEnvWithGeometry {
input,
output,
symbols,
positions_angstrom,
cell_angstrom,
periodic,
project_label,
print_level,
run_type,
charge,
multiplicity,
walltime,
seed,
echo_input,
uks,
roks,
wfn_restart_file,
kpoint_config,
dos_config,
} => {
let rank = world.rank();
log_rank(
world,
&format!(
"[worker rank {}] InitForceEnvWithGeometry: start (input spec + geometry)",
rank
),
);
let patched_path_str: String = if rank == 0 {
let base_input_path = match resolve_input_spec_to_path(input, symbols) {
Ok(p) => p,
Err(e) => {
let empty: &[u8] = &[];
let zero: u32 = 0;
for r in 1..world.size() {
world
.process_at_rank(r)
.send_with_tag(&[zero], TAG_PATCH_LEN);
world
.process_at_rank(r)
.send_with_tag(empty, TAG_PATCH_PATH);
}
return Response::error(id, e);
}
};
let base_input_str = match base_input_path.to_str() {
Some(s) => s,
None => {
let empty: &[u8] = &[];
let zero: u32 = 0;
for r in 1..world.size() {
world
.process_at_rank(r)
.send_with_tag(&[zero], TAG_PATCH_LEN);
world
.process_at_rank(r)
.send_with_tag(empty, TAG_PATCH_PATH);
}
return Response::error(
id,
format!(
"Resolved input path is not valid UTF-8: {}",
base_input_path.display()
),
);
}
};
let patched_path = match patch_input_with_geometry(
base_input_str,
symbols,
positions_angstrom,
cell_angstrom,
periodic,
PatchInputOverrides {
project_label: project_label.as_deref(),
print_level: print_level.as_deref(),
run_type: run_type.as_deref(),
charge: *charge,
multiplicity: *multiplicity,
walltime: walltime.as_deref(),
seed: seed.as_deref(),
echo_input: *echo_input,
uks: *uks,
roks: *roks,
wfn_restart_file: wfn_restart_file.as_deref(),
kpoint_config: kpoint_config.as_ref(),
dos_config: dos_config.as_ref(),
},
) {
Ok(p) => p,
Err(e) => {
let empty: &[u8] = &[];
let zero: u32 = 0;
for r in 1..world.size() {
world
.process_at_rank(r)
.send_with_tag(&[zero], TAG_PATCH_LEN);
world
.process_at_rank(r)
.send_with_tag(empty, TAG_PATCH_PATH);
}
return Response::error(id, format!("Failed to patch input file: {e}"));
}
};
let s = patched_path.to_string_lossy().to_string();
let bytes = s.as_bytes();
let len = bytes.len() as u32;
for r in 1..world.size() {
world
.process_at_rank(r)
.send_with_tag(&[len], TAG_PATCH_LEN);
world
.process_at_rank(r)
.send_with_tag(bytes, TAG_PATCH_PATH);
}
s
} else {
let (len_buf, _) = world
.process_at_rank(0)
.receive_vec_with_tag::<u32>(TAG_PATCH_LEN);
let len = len_buf[0] as usize;
let (data, _) = world
.process_at_rank(0)
.receive_vec_with_tag::<u8>(TAG_PATCH_PATH);
if len == 0 {
return Response::error(
id,
"Failed to create patched input file on rank 0".to_string(),
);
}
String::from_utf8_lossy(&data[..len]).to_string()
};
log_rank(
world,
&format!(
"[worker rank {}] InitForceEnvWithGeometry: calling ForceEnv::new (input='{}', output='{}')",
world.rank(),
patched_path_str,
output
),
);
let result = ForceEnv::new_with_mpi(&patched_path_str, output, world);
log_rank(
world,
&format!(
"[worker rank {}] InitForceEnvWithGeometry: ForceEnv::new returned",
world.rank()
),
);
if world.rank() == 0 {
let _ = std::fs::remove_file(&patched_path_str);
if patched_path_str.contains("cp2k_worker_patched_") {
}
}
match result {
Ok(fe) => {
*force_env = Some(fe);
Response::ok(id, Payload::Empty)
}
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::CalcEnergyForce => {
let fe = need_env!(force_env.as_mut());
match fe.calc_energy_force() {
Ok(()) => Response::ok(id, Payload::Empty),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::CalcEnergy => {
let fe = need_env!(force_env.as_mut());
match fe.calc_energy() {
Ok(()) => Response::ok(id, Payload::Empty),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetNatom => {
let fe = need_env!(force_env.as_ref());
match fe.get_natom() {
Ok(n) => Response::ok(id, Payload::Int(n as i64)),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetNparticle => {
let fe = need_env!(force_env.as_ref());
match fe.get_nparticle() {
Ok(n) => Response::ok(id, Payload::Int(n as i64)),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetPositions => {
let fe = need_env!(force_env.as_ref());
match fe.get_positions() {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetForces => {
let fe = need_env!(force_env.as_ref());
match fe.get_forces() {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetPotentialEnergy => {
let fe = need_env!(force_env.as_ref());
match fe.get_potential_energy() {
Ok(e) => Response::ok(id, Payload::Float(e)),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetCell => {
let fe = need_env!(force_env.as_ref());
match fe.get_cell() {
Ok(arr) => Response::ok(
id,
Payload::Array2 {
rows: 3,
cols: 3,
data: arr.into_raw_vec_and_offset().0,
},
),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetQmmmCell => {
let fe = need_env!(force_env.as_ref());
match fe.get_qmmm_cell() {
Ok(arr) => Response::ok(
id,
Payload::Array2 {
rows: 3,
cols: 3,
data: arr.into_raw_vec_and_offset().0,
},
),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::SetPositions { data } => {
let fe = need_env!(force_env.as_mut());
if data.len() % 3 != 0 {
return Response::error(id, "SetPositions expects flat 3N array".to_string());
}
let mut pos_bohr = data.clone();
for x in &mut pos_bohr {
*x /= BOHR;
}
match fe.set_positions(&pos_bohr) {
Ok(()) => Response::ok(id, Payload::Empty),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::SetVelocities { data } => {
let fe = need_env!(force_env.as_mut());
match fe.set_velocities(data) {
Ok(()) => Response::ok(id, Payload::Empty),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::SetCell { data } => {
let fe = need_env!(force_env.as_mut());
if data.len() != 9 {
return Response::error(id, "SetCell expects 9 floats (row-major)".to_string());
}
let mut cell = [[0.0f64; 3]; 3];
for i in 0..3 {
for j in 0..3 {
cell[i][j] = data[i * 3 + j] / BOHR;
}
}
match fe.set_cell(&cell) {
Ok(()) => Response::ok(id, Payload::Empty),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetMoCount => {
let fe = need_env!(force_env.as_ref());
match fe.get_mo_count() {
Ok(n) => Response::ok(id, Payload::Int(n as i64)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::IsQuickstep => {
let fe = need_env!(force_env.as_ref());
Response::ok(id, Payload::Bool(fe.is_quickstep()))
}
#[cfg(feature = "extended")]
Command::GetStressTensor => {
let fe = need_env!(force_env.as_ref());
match fe.get_stress_tensor() {
Ok(arr) => Response::ok(
id,
Payload::Array2 {
rows: 3,
cols: 3,
data: arr.into_raw_vec_and_offset().0,
},
),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetVirialTensor => {
let fe = need_env!(force_env.as_ref());
match fe.get_virial_tensor() {
Ok(arr) => Response::ok(
id,
Payload::Array2 {
rows: 3,
cols: 3,
data: arr.into_raw_vec_and_offset().0,
},
),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetNmo { spin } => {
let fe = need_env!(force_env.as_ref());
match fe.get_nmo(*spin) {
Ok(n) => Response::ok(id, Payload::Int(n as i64)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetEigenvalues { spin } => {
let fe = need_env!(force_env.as_ref());
match fe.get_eigenvalues(*spin) {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetOccupationNumbers { spin } => {
let fe = need_env!(force_env.as_ref());
match fe.get_occupation_numbers(*spin) {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetHomoLumo { spin } => {
let fe = need_env!(force_env.as_ref());
match fe.get_homo_lumo(*spin) {
Ok((h, l, hi, li)) => Response::ok(
id,
Payload::HomoLumo {
homo: h,
lumo: l,
homo_idx: hi,
lumo_idx: li,
},
),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetMullikenCharges => {
let fe = need_env!(force_env.as_ref());
match fe.get_mulliken_charges() {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetHirshfeldCharges => {
let fe = need_env!(force_env.as_ref());
match fe.get_hirshfeld_charges() {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetDipoleMoment => {
let fe = need_env!(force_env.as_ref());
match fe.get_dipole_moment() {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetScfInfo => {
let fe = need_env!(force_env.as_ref());
match fe.get_scf_info() {
Ok((nsteps, converged, ediff)) => Response::ok(
id,
Payload::ScfInfo {
nsteps,
converged,
energy_change: ediff,
},
),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetEnergyComponents => {
let fe = need_env!(force_env.as_ref());
match fe.get_energy_components() {
Ok((e_kin, e_hartree, e_xc, e_core, e_total)) => Response::ok(
id,
Payload::EnergyComponents {
e_kin,
e_hartree,
e_xc,
e_core,
e_total,
},
),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetNelectron => {
let fe = need_env!(force_env.as_ref());
match fe.get_nelectron() {
Ok(n) => Response::ok(id, Payload::Int(n as i64)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetFermiEnergy => {
let fe = need_env!(force_env.as_ref());
match fe.get_fermi_energy() {
Ok(e) => Response::ok(id, Payload::Float(e)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetTotalSpin => {
let fe = need_env!(force_env.as_ref());
match fe.get_total_spin() {
Ok(s) => Response::ok(id, Payload::Float(s)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetGridInfo { spin } => {
let fe = need_env!(force_env.as_ref());
match fe.get_grid_info(*spin) {
Ok(info) => Response::ok(
id,
Payload::GridInfo {
npts: info.npts,
origin: info.origin,
dh: info.dh,
},
),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetElectronDensity { spin } => {
let fe = need_env!(force_env.as_ref());
match fe.get_electron_density(*spin) {
Ok((info, density)) => {
let n1 = info.npts[0] as usize;
let n2 = info.npts[1] as usize;
let n3 = info.npts[2] as usize;
let n_total = n1 * n2 * n3;
let byte_size = n_total * std::mem::size_of::<f64>();
let data = match density.as_slice_memory_order() {
Some(slice) => slice,
None => {
return Response::error(
id,
"Electron density buffer is not contiguous in memory".to_string(),
);
}
};
if world.rank() == 0 {
match ShmRegion::create(byte_size) {
Ok(mut shm) => {
let ptr = shm.as_mut_ptr() as *mut f64;
let slice = unsafe { std::slice::from_raw_parts_mut(ptr, n_total) };
slice.copy_from_slice(data);
let shm_name = shm.name().to_string();
Response::ok(
id,
Payload::SharedArray3 {
shm_name,
dims: [n1, n2, n3],
byte_size,
},
)
}
Err(e) => Response::error(id, format!("shm create failed: {e}")),
}
} else {
Response::ok(
id,
Payload::SharedArray3 {
shm_name: String::new(),
dims: [0, 0, 0],
byte_size: 0,
},
)
}
}
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetMoCoeffInfo { spin } => {
let fe = need_env!(force_env.as_ref());
match fe.get_mo_coeff_info(*spin) {
Ok((nao, nmo)) => Response::ok(id, Payload::MoCoeffInfo { nao, nmo }),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetMoCoefficients { spin } => {
let fe = need_env!(force_env.as_ref());
match fe.get_mo_coefficients(*spin) {
Ok(arr) => {
let rows = arr.nrows();
let cols = arr.ncols();
let n_total = rows * cols;
let byte_size = n_total * std::mem::size_of::<f64>();
let data = match arr.as_slice_memory_order() {
Some(slice) => slice,
None => {
return Response::error(
id,
"MO coefficient buffer is not contiguous in memory".to_string(),
);
}
};
if world.rank() == 0 {
match ShmRegion::create(byte_size) {
Ok(mut shm) => {
let ptr = shm.as_mut_ptr() as *mut f64;
let slice = unsafe { std::slice::from_raw_parts_mut(ptr, n_total) };
slice.copy_from_slice(data);
let shm_name = shm.name().to_string();
Response::ok(
id,
Payload::SharedArray2 {
shm_name,
rows,
cols,
byte_size,
},
)
}
Err(e) => Response::error(id, format!("shm create failed: {e}")),
}
} else {
Response::ok(
id,
Payload::SharedArray2 {
shm_name: String::new(),
rows: 0,
cols: 0,
byte_size: 0,
},
)
}
}
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetNkpoints => {
let fe = need_env!(force_env.as_ref());
match fe.get_nkpoints() {
Ok(n) => Response::ok(id, Payload::Int(n as i64)),
Err(e) => Response::error(id, format!("{e}")),
}
}
#[cfg(feature = "extended")]
Command::GetKpointEigenvalues { kpt_idx, spin } => {
let fe = need_env!(force_env.as_ref());
match fe.get_kpoint_eigenvalues(*kpt_idx, *spin) {
Ok(arr) => Response::ok(id, Payload::Array1(arr.into_raw_vec_and_offset().0)),
Err(e) => Response::error(id, format!("{e}")),
}
}
Command::GetVersion => match cp2k_rs::get_version() {
Ok(v) => Response::ok(id, Payload::String(v)),
Err(e) => Response::error(id, format!("{e}")),
},
Command::Shutdown => {
*force_env = None;
Response::ok(id, Payload::Empty)
}
}
}
/// Entry point of the CP2K worker process.
///
/// Start-up sequence (all ranks):
/// 1. Default `OMP_WAIT_POLICY` to `passive` unless the environment already sets it.
/// 2. Initialize MPI, then CP2K via [`init`]; a CP2K failure terminates the process
///    with exit code 1.
///
/// Rank 0 then acts as the IPC front end:
/// * reads the socket path from the file named by `CP2K_WORKER_SOCKET_FILE`
///   (falling back to `/tmp/cp2k_rs_worker.sock` if that file cannot be read),
/// * binds a Unix socket there, restricts its mode to `0o700`, and writes the
///   socket path into the ready file (`CP2K_WORKER_READY_FILE`, default
///   `/tmp/cp2k_rs_worker.ready`),
/// * accepts a single connection and serves length-prefixed, bincode-encoded
///   requests until a `Shutdown` command or an I/O error.
///
/// Every other rank loops on `receive_command`, runs the same `dispatch`, and
/// discards the response; only rank 0 talks to the client.
///
/// # Panics
///
/// Panics if MPI cannot be initialized, if `CP2K_WORKER_SOCKET_FILE` is unset,
/// if the socket cannot be bound or accepted on, or if a response cannot be
/// serialized.
fn main() {
    // Respect an explicit user choice; otherwise avoid busy-waiting OpenMP threads.
    if std::env::var("OMP_WAIT_POLICY").is_err() {
        // SAFETY: this is the first statement of `main`, executed before MPI and
        // CP2K are initialized; presumably no other thread exists yet that could
        // read or write the environment concurrently.
        unsafe { std::env::set_var("OMP_WAIT_POLICY", "passive") };
    }

    let universe = mpi::initialize().expect("Failed to initialize MPI");
    let world = universe.world();

    if let Err(e) = init() {
        eprintln!("CP2K initialization failed: {e}");
        std::process::exit(1);
    }

    let rank = world.rank();
    let mut force_env: Option<ForceEnv> = None;

    if rank == 0 {
        let socket_file = std::env::var("CP2K_WORKER_SOCKET_FILE")
            .expect("CP2K_WORKER_SOCKET_FILE env-var must be set");
        let socket_path = std::fs::read_to_string(&socket_file)
            .unwrap_or_else(|_| "/tmp/cp2k_rs_worker.sock".to_string())
            .trim()
            .to_string();

        // A stale socket from a previous run would make `bind` fail; removal is best-effort.
        let _ = std::fs::remove_file(&socket_path);
        let listener = UnixListener::bind(&socket_path).expect("Failed to bind unix socket");

        use std::os::unix::fs::PermissionsExt;
        let _ = std::fs::set_permissions(&socket_path, std::fs::Permissions::from_mode(0o700));

        // The ready file is written only after the socket is bound and contains its path.
        let ready_file = std::env::var("CP2K_WORKER_READY_FILE")
            .unwrap_or_else(|_| "/tmp/cp2k_rs_worker.ready".to_string());
        let _ = std::fs::write(&ready_file, &socket_path);

        let (mut stream, _) = listener.accept().expect("Failed to accept connection");

        // NOTE(review): the two I/O-error `break`s below leave the loop without
        // telling the other ranks, which are still waiting in `receive_command`.
        // That looks like it can hang a multi-rank run — confirm and consider
        // broadcasting a shutdown request on those paths.
        loop {
            let msg = match read_msg(&mut stream) {
                Ok(m) => m,
                Err(e) => {
                    eprintln!("IPC read error: {e}");
                    break;
                }
            };

            // Validate the request BEFORE broadcasting it. The other ranks exit
            // the process when they cannot deserialize a message, so forwarding
            // a malformed one would kill them while rank 0 keeps serving.
            let req: Request = match bincode::deserialize(&msg) {
                Ok(r) => r,
                Err(e) => {
                    let resp = Response::error(0, format!("Failed to deserialize request: {e}"));
                    let bytes = bincode::serialize(&resp).unwrap();
                    let _ = write_msg(&mut stream, &bytes);
                    continue;
                }
            };

            // Forward the (now known-good) raw bytes so every rank dispatches the same request.
            broadcast_command(&world, &msg);

            let resp = dispatch(&req, &mut force_env, &world);
            let bytes = bincode::serialize(&resp).unwrap();
            if let Err(e) = write_msg(&mut stream, &bytes) {
                eprintln!("IPC write error: {e}");
                break;
            }

            if matches!(req.command, Command::Shutdown) {
                break;
            }
        }

        // Best-effort cleanup of the rendezvous files.
        let _ = std::fs::remove_file(&socket_path);
        let _ = std::fs::remove_file(&ready_file);
    } else {
        loop {
            let msg = receive_command(&world);
            let req: Request = match bincode::deserialize(&msg) {
                Ok(r) => r,
                Err(e) => {
                    eprintln!("Failed to deserialize request: {e}");
                    std::process::exit(1);
                }
            };

            // Non-root ranks run the command for its effect on `force_env`;
            // the response itself is only ever sent by rank 0.
            let resp = dispatch(&req, &mut force_env, &world);
            drop(resp);

            if matches!(req.command, Command::Shutdown) {
                break;
            }
        }
    }

    let _ = finalize();
}