use std::collections::HashMap;
use std::io::{Read, Write};
use std::os::unix::net::UnixStream;
use std::process::{Child, Command as StdCommand, Stdio};
use std::sync::Mutex;
use std::time::Duration;
use thiserror::Error;
use crate::worker_protocol::{Command, Payload, Request, Response, Status};
/// Errors reported by the CP2K worker client functions in this module.
#[derive(Debug, Error)]
pub enum WorkerError {
/// Locking the global worker mutex failed because it was poisoned.
#[error("worker mutex poisoned")]
MutexPoisoned,
/// An operation needed a running worker but the global slot was empty.
#[error("CP2K worker is not running; call start_worker() first")]
NotRunning,
/// `start_worker` was called while the global slot already held a worker.
#[error("a CP2K worker is already running; call stop_worker() first")]
AlreadyRunning,
/// The worker executable could not be located.
/// NOTE(review): not constructed in the code visible here; presumably used by
/// callers when `find_worker_binary` returns `None` — confirm.
#[error(
"cp2k_rs_worker binary not found; \
set CP2K_WORKER_BIN or ensure the binary is on PATH"
)]
BinaryNotFound,
/// An I/O error on the IPC stream (converted automatically via `?`).
#[error("IPC I/O error: {0}")]
Io(#[from] std::io::Error),
/// Encoding a request or decoding a response failed; holds the codec's message.
#[error("serialization error: {0}")]
Serialize(String),
/// The worker answered with an error status; holds the message it sent.
#[error("CP2K error: {0}")]
Cp2kError(String),
/// Any other failure, described by a free-form message.
#[error("{0}")]
Other(String),
}
/// Everything the client keeps about the one running worker.
struct WorkerState {
// Handle of the spawned launcher/worker process; polled and killed on shutdown.
child: Child,
// Connected Unix socket used for length-prefixed request/response messages.
stream: UnixStream,
// Request id assigned to the next outgoing request; starts at 0 and is
// incremented once per call.
next_id: u64,
// Filesystem path of the Unix socket the client connects to.
socket_path: String,
// Path of the file that contains `socket_path`; handed to the worker via
// the CP2K_WORKER_SOCKET_FILE environment variable.
socket_file_path: String,
// Path whose existence is polled to detect that the worker is ready; handed
// to the worker via the CP2K_WORKER_READY_FILE environment variable.
ready_path: String,
}
/// Process-wide slot holding the single running worker, or `None` when no
/// worker has been started. All access goes through the mutex.
static WORKER: Mutex<Option<WorkerState>> = Mutex::new(None);
/// Reads one length-prefixed message from `stream`.
///
/// The frame is a 4-byte little-endian `u32` byte count followed by exactly
/// that many payload bytes.
///
/// # Errors
///
/// Returns the underlying I/O error; a stream that ends before the header or
/// the payload is complete yields `ErrorKind::UnexpectedEof`.
fn read_msg(stream: &mut UnixStream) -> std::io::Result<Vec<u8>> {
    let mut header = [0u8; 4];
    stream.read_exact(&mut header)?;

    let payload_len = u32::from_le_bytes(header) as usize;
    let mut payload = vec![0u8; payload_len];
    stream.read_exact(&mut payload).map(|()| payload)
}
/// Writes one length-prefixed message to `stream`.
///
/// The frame is a 4-byte little-endian `u32` byte count followed by the
/// payload, then the stream is flushed.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` when `data` is longer than `u32::MAX`
/// bytes (the length would not fit the header and the frame would be
/// corrupt), or the underlying I/O error if writing fails.
fn write_msg(stream: &mut UnixStream, data: &[u8]) -> std::io::Result<()> {
    // A plain `as u32` cast would silently wrap and desynchronise the stream.
    let len = u32::try_from(data.len()).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "IPC message of {} bytes exceeds the maximum frame size of {} bytes",
                data.len(),
                u32::MAX
            ),
        )
    })?;
    stream.write_all(&len.to_le_bytes())?;
    stream.write_all(data)?;
    stream.flush()
}
/// Sends `command` to the running worker and blocks until its reply arrives.
///
/// The global worker lock is held for the whole round trip, so calls are
/// serialised. Each request carries the next sequential request id.
///
/// # Errors
///
/// * [`WorkerError::MutexPoisoned`] if the global lock is poisoned.
/// * [`WorkerError::NotRunning`] if no worker has been started.
/// * [`WorkerError::Serialize`] if encoding the request or decoding the reply fails.
/// * [`WorkerError::Io`] for I/O failures on the socket.
/// * [`WorkerError::Other`] if the connection closes while the reply is being read.
/// * [`WorkerError::Cp2kError`] if the worker replies with an error status.
pub fn ipc_call(command: Command) -> Result<Payload, WorkerError> {
    let mut slot = match WORKER.lock() {
        Ok(locked) => locked,
        Err(_) => return Err(WorkerError::MutexPoisoned),
    };
    let Some(worker) = slot.as_mut() else {
        return Err(WorkerError::NotRunning);
    };

    // Claim an id for this request before anything can fail.
    let request_id = worker.next_id;
    worker.next_id += 1;
    let request = Request {
        request_id,
        command,
    };

    let encoded = match bincode::serialize(&request) {
        Ok(bytes) => bytes,
        Err(e) => return Err(WorkerError::Serialize(e.to_string())),
    };
    write_msg(&mut worker.stream, &encoded)?;

    let reply = match read_msg(&mut worker.stream) {
        Ok(bytes) => bytes,
        // EOF while a reply is pending means the peer closed the socket.
        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
            return Err(WorkerError::Other(
                "CP2K worker process died unexpectedly during a request \
                 (connection closed mid-read). Check the worker's stderr output \
                 for CP2K error messages."
                    .into(),
            ));
        }
        Err(e) => return Err(WorkerError::Io(e)),
    };

    let resp: Response = match bincode::deserialize(&reply) {
        Ok(decoded) => decoded,
        Err(e) => return Err(WorkerError::Serialize(e.to_string())),
    };
    match resp.status {
        Status::Error(msg) => Err(WorkerError::Cp2kError(msg)),
        Status::Ok => Ok(resp.payload),
    }
}
/// Locates the `cp2k_rs_worker` executable.
///
/// Candidates are tried in this order and the first one that is a regular
/// file (symlinks are followed) wins:
///
/// 1. the path in the `CP2K_WORKER_BIN` environment variable,
/// 2. `cp2k_rs_worker` next to the current executable,
/// 3. `cp2k_rs_worker` in each directory listed in `PATH`.
///
/// Environment values are read as OS strings, so paths that are not valid
/// UTF-8 are still considered. Directories that happen to carry the expected
/// name are skipped because they could never be executed.
///
/// Returns `None` when no candidate qualifies.
pub fn find_worker_binary() -> Option<std::path::PathBuf> {
    const WORKER_NAME: &str = "cp2k_rs_worker";

    std::env::var_os("CP2K_WORKER_BIN")
        .map(std::path::PathBuf::from)
        .filter(|path| path.is_file())
        .or_else(|| {
            std::env::current_exe()
                .ok()
                .and_then(|exe| exe.parent().map(|dir| dir.join(WORKER_NAME)))
                .filter(|candidate| candidate.is_file())
        })
        .or_else(|| {
            let path_var = std::env::var_os("PATH")?;
            std::env::split_paths(&path_var)
                .map(|dir| dir.join(WORKER_NAME))
                .find(|candidate| candidate.is_file())
        })
}
/// Resolves `var`, preferring an entry in `extra_env` over the process
/// environment.
///
/// Returns `None` when neither source has the variable (or the process value
/// is not valid Unicode).
fn lookup_env(var: &str, extra_env: Option<&HashMap<String, String>>) -> Option<String> {
    if let Some(overridden) = extra_env.and_then(|overrides| overrides.get(var)) {
        return Some(overridden.clone());
    }
    std::env::var(var).ok()
}
/// Builds the launcher command prefix (everything before the worker binary).
///
/// * When `SLURM_JOB_ID` is set in the process environment the result is
///   `srun`, plus `-n <nproc>` if `nproc` is given.
/// * Otherwise the result is `mpirun`, plus `-n <nproc>` if given, followed by
///   `-x NAME=value` pairs:
///   - each forwarded variable that resolves through `lookup_env`;
///   - then either the transport variables that resolve through `lookup_env`
///     (if at least one does), or — when none does and `has_infiniband()` is
///     false — a fixed set of default transport settings.
pub fn default_launcher(
    nproc: Option<u32>,
    extra_env: Option<&HashMap<String, String>>,
) -> Vec<String> {
    const FORWARDED: [&str; 6] = [
        "OMPI_ALLOW_RUN_AS_ROOT",
        "OMPI_ALLOW_RUN_AS_ROOT_CONFIRM",
        "CP2K_DATA_DIR",
        "OMP_NUM_THREADS",
        "OMP_WAIT_POLICY",
        "RUST_BACKTRACE",
    ];
    const TRANSPORT: [&str; 5] = [
        "OMPI_MCA_pml",
        "OMPI_MCA_btl",
        "OMPI_MCA_osc",
        "OMPI_MCA_mtl",
        "UCX_ERROR_SIGNALS",
    ];
    const TRANSPORT_DEFAULTS: [(&str, &str); 5] = [
        ("OMPI_MCA_pml", "ob1"),
        ("OMPI_MCA_btl", "tcp,self"),
        ("OMPI_MCA_osc", "pt2pt"),
        ("OMPI_MCA_mtl", "^ofi,psm,psm2"),
        ("UCX_ERROR_SIGNALS", ""),
    ];

    let under_slurm = std::env::var("SLURM_JOB_ID").is_ok();
    let program = if under_slurm { "srun" } else { "mpirun" };

    let mut argv = vec![program.to_string()];
    if let Some(ranks) = nproc {
        argv.push("-n".to_string());
        argv.push(ranks.to_string());
    }
    if under_slurm {
        // The srun form carries no `-x` exports.
        return argv;
    }

    // `NAME=value` for every listed variable that resolves, in list order.
    let resolved = |names: &[&str]| -> Vec<String> {
        names
            .iter()
            .filter_map(|name| {
                lookup_env(name, extra_env).map(|value| format!("{name}={value}"))
            })
            .collect()
    };

    let mut exports = resolved(&FORWARDED);
    let transport = resolved(&TRANSPORT);
    if !transport.is_empty() {
        exports.extend(transport);
    } else if !has_infiniband() {
        exports.extend(
            TRANSPORT_DEFAULTS
                .iter()
                .map(|(name, value)| format!("{name}={value}")),
        );
    }

    for assignment in exports {
        argv.push("-x".to_string());
        argv.push(assignment);
    }
    argv
}
/// Reports whether `/sys/class/infiniband` can be listed and the listing
/// yields at least one item.
///
/// Any failure to open the directory counts as "no InfiniBand".
fn has_infiniband() -> bool {
    match std::fs::read_dir("/sys/class/infiniband") {
        Ok(mut entries) => entries.next().is_some(),
        Err(_) => false,
    }
}
pub fn start_worker(
worker_bin: std::path::PathBuf,
nproc: Option<u32>,
launcher_cmd: Option<Vec<String>>,
env: Option<HashMap<String, String>>,
working_dir: Option<String>,
connect_timeout: f64,
) -> Result<(), WorkerError> {
{
let guard = WORKER.lock().map_err(|_| WorkerError::MutexPoisoned)?;
if guard.is_some() {
return Err(WorkerError::AlreadyRunning);
}
}
let mut extra_env = env.unwrap_or_default();
extra_env
.entry("OMP_WAIT_POLICY".to_string())
.or_insert_with(|| "passive".to_string());
if let Some(parent) = worker_bin.parent() {
let candidate = parent.join("data");
if candidate.exists() {
extra_env.insert(
"CP2K_DATA_DIR".to_string(),
candidate.to_string_lossy().into_owned(),
);
}
}
let mut cmd_parts = launcher_cmd.unwrap_or_else(|| default_launcher(nproc, Some(&extra_env)));
cmd_parts.push(worker_bin.to_string_lossy().into_owned());
let socket_path = format!("/tmp/cp2k_worker_{}_{}.sock", std::process::id(), {
use std::time::{SystemTime, UNIX_EPOCH};
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.subsec_nanos()
});
let socket_file_path = format!("{socket_path}.path");
let ready_path = format!("{socket_path}.ready");
let _ = std::fs::remove_file(&socket_path);
let _ = std::fs::remove_file(&socket_file_path);
let _ = std::fs::remove_file(&ready_path);
if let Err(e) = std::fs::write(&socket_file_path, &socket_path) {
return Err(WorkerError::Other(format!(
"Failed to write CP2K socket path file '{socket_file_path}': {e}"
)));
}
let mut cmd = StdCommand::new(&cmd_parts[0]);
cmd.args(&cmd_parts[1..]);
cmd.env("CP2K_WORKER_SOCKET_FILE", &socket_file_path);
cmd.env("CP2K_WORKER_READY_FILE", &ready_path);
cmd.stdout(Stdio::inherit());
cmd.stderr(Stdio::inherit());
if !extra_env.is_empty() {
for (k, v) in extra_env {
cmd.env(k, v);
}
}
if let Some(dir) = working_dir {
cmd.current_dir(dir);
}
let mut child = cmd
.spawn()
.map_err(|e| WorkerError::Other(format!("Failed to spawn cp2k_rs_worker: {e}")))?;
let timeout = Duration::from_secs_f64(connect_timeout);
let start = std::time::Instant::now();
loop {
if std::path::Path::new(&ready_path).exists() {
break;
}
if start.elapsed() > timeout {
let default_ready = "/tmp/cp2k_rs_worker.ready";
let default_ready_exists = std::path::Path::new(default_ready).exists();
let child_status = match child.try_wait() {
Ok(Some(status)) => format!("worker exited with status {status}"),
Ok(None) => "worker still running".to_string(),
Err(e) => format!("worker status check failed: {e}"),
};
let extra = if default_ready_exists {
format!(
" Default ready file exists at {default_ready} (possible handshake mismatch)."
)
} else {
String::new()
};
return Err(WorkerError::Other(format!(
"Timed out waiting for cp2k_rs_worker to become ready ({connect_timeout}s). \
The worker process may have crashed during CP2K initialization. \
Check that CP2K_DATA_DIR is set (it should point to the CP2K data/ directory) \
and that the MPI launcher is available on PATH. Socket: {socket_path}. {child_status}.{extra}"
)));
}
std::thread::sleep(Duration::from_millis(50));
}
let deadline = std::time::Instant::now() + timeout;
let stream = loop {
match UnixStream::connect(&socket_path) {
Ok(s) => break s,
Err(_) if std::time::Instant::now() < deadline => {
std::thread::sleep(Duration::from_millis(50));
}
Err(e) => return Err(WorkerError::Other(format!("Connect failed: {e}"))),
}
};
stream.set_read_timeout(None)?;
let mut guard = WORKER.lock().map_err(|_| WorkerError::MutexPoisoned)?;
*guard = Some(WorkerState {
child,
stream,
next_id: 0,
socket_path,
socket_file_path,
ready_path,
});
Ok(())
}
/// Reads `f64` values from the shared-memory object `shm_name` by delegating
/// to `crate::shm::read_shared_array3` with the same arguments.
///
/// # Errors
///
/// Any error from the underlying call is reported as [`WorkerError::Other`]
/// with the text `shm open failed: <cause>`.
pub fn read_shared_array3(
    shm_name: &str,
    _dims: [usize; 3],
    byte_size: usize,
) -> Result<Vec<f64>, WorkerError> {
    match crate::shm::read_shared_array3(shm_name, _dims, byte_size) {
        Ok(values) => Ok(values),
        Err(cause) => Err(WorkerError::Other(format!("shm open failed: {cause}"))),
    }
}
/// Reads `f64` values from the shared-memory object `shm_name` by delegating
/// to `crate::shm::read_shared_array2` with the same arguments.
///
/// # Errors
///
/// Any error from the underlying call is reported as [`WorkerError::Other`]
/// with the text `shm open failed: <cause>`.
pub fn read_shared_array2(shm_name: &str, byte_size: usize) -> Result<Vec<f64>, WorkerError> {
    match crate::shm::read_shared_array2(shm_name, byte_size) {
        Ok(values) => Ok(values),
        Err(cause) => Err(WorkerError::Other(format!("shm open failed: {cause}"))),
    }
}
/// Shuts the worker down and releases everything associated with it.
///
/// A `Shutdown` command is sent first (its outcome is ignored, so a worker
/// that already died does not prevent cleanup). The process then gets up to
/// ten seconds to exit on its own; after that, or if its status cannot be
/// queried, it is killed and reaped. Finally the socket, socket-path and
/// ready files are removed. Calling this with no worker running is a no-op.
///
/// # Errors
///
/// Returns [`WorkerError::MutexPoisoned`] if the global lock is poisoned.
pub fn stop_worker() -> Result<(), WorkerError> {
    // Best effort: the worker may already be gone.
    let _ = ipc_call(Command::Shutdown);

    let mut guard = WORKER.lock().map_err(|_| WorkerError::MutexPoisoned)?;
    let Some(mut state) = guard.take() else {
        return Ok(());
    };

    let grace = Duration::from_secs(10);
    let start = std::time::Instant::now();
    loop {
        match state.child.try_wait() {
            Ok(Some(_)) => break,
            Ok(None) if start.elapsed() < grace => {
                std::thread::sleep(Duration::from_millis(100));
            }
            _ => {
                let _ = state.child.kill();
                // Reap the killed process; without this it stays a zombie
                // for as long as this process lives.
                let _ = state.child.wait();
                break;
            }
        }
    }

    let _ = std::fs::remove_file(&state.socket_path);
    let _ = std::fs::remove_file(&state.socket_file_path);
    let _ = std::fs::remove_file(&state.ready_path);
    Ok(())
}