cp2k-rs 0.1.1

Rust bindings for CP2K with Python interface
Documentation
//! Pure-Rust worker management: spawning, IPC, and lifecycle.
//!
//! All public items in this module are GIL-free. When calling from Python,
//! the caller should release the GIL (e.g., with `py.detach(...)`) around
//! any blocking calls (`start_worker`, `stop_worker`, `ipc_call`).

use std::collections::HashMap;
use std::io::{Read, Write};
use std::os::unix::net::UnixStream;
use std::process::{Child, Command as StdCommand, Stdio};
use std::sync::Mutex;
use std::time::Duration;

use thiserror::Error;

use crate::worker_protocol::{Command, Payload, Request, Response, Status};

// ─── error type ──────────────────────────────────────────────────────────────

/// Errors that can occur during worker operations.
///
/// The user-facing message of each variant comes from its `#[error(...)]`
/// attribute.
#[derive(Debug, Error)]
pub enum WorkerError {
    /// Locking the global worker mutex failed because it was poisoned.
    #[error("worker mutex poisoned")]
    MutexPoisoned,
    /// An IPC call was attempted while no worker is registered.
    #[error("CP2K worker is not running; call start_worker() first")]
    NotRunning,
    /// `start_worker` was called while a worker is already registered.
    #[error("a CP2K worker is already running; call stop_worker() first")]
    AlreadyRunning,
    /// The worker binary could not be located.
    // NOTE(review): not constructed anywhere in the visible part of this
    // module; presumably raised by callers when `find_worker_binary`
    // returns `None` — confirm.
    #[error(
        "cp2k_rs_worker binary not found; \
             set CP2K_WORKER_BIN or ensure the binary is on PATH"
    )]
    BinaryNotFound,
    /// An I/O error on the worker socket; converts from `std::io::Error` via `?`.
    #[error("IPC I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// Encoding a request or decoding a response failed; holds the error text.
    #[error("serialization error: {0}")]
    Serialize(String),
    /// The worker answered with an error status; holds the message it sent.
    #[error("CP2K error: {0}")]
    Cp2kError(String),
    /// Any other failure (spawn, readiness timeout, connect, worker death),
    /// carrying a fully formatted message.
    #[error("{0}")]
    Other(String),
}

// ─── global worker state ─────────────────────────────────────────────────────

/// Everything tracked for the one live worker.
struct WorkerState {
    /// Handle of the spawned process (the launcher command that runs the worker binary).
    child: Child,
    /// Connected Unix socket carrying the length-prefixed request/response frames.
    stream: UnixStream,
    /// Id assigned to the next request; incremented after each `ipc_call`.
    next_id: u64,
    /// Filesystem path of the worker's socket; removed by `stop_worker`.
    socket_path: String,
    /// Path of the `<socket_path>.ready` marker file; removed by `stop_worker`.
    ready_path: String,
}

/// Process-wide slot for the single worker: `None` until `start_worker`
/// succeeds, and emptied again by `stop_worker`.
static WORKER: Mutex<Option<WorkerState>> = Mutex::new(None);

// ─── IPC helpers (private) ───────────────────────────────────────────────────

/// Read one frame from `stream`: a 4-byte little-endian length header
/// followed by exactly that many payload bytes.
///
/// Returns the payload. A connection that closes before the header or the
/// payload is complete surfaces as `ErrorKind::UnexpectedEof`.
fn read_msg(stream: &mut UnixStream) -> std::io::Result<Vec<u8>> {
    let mut header = [0u8; 4];
    stream.read_exact(&mut header)?;

    let body_len = u32::from_le_bytes(header) as usize;
    let mut body = vec![0u8; body_len];
    stream.read_exact(&mut body).map(|()| body)
}

/// Write one frame to `stream`: a 4-byte little-endian length header
/// followed by `data`, then flush.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` without writing anything when `data` is
/// longer than `u32::MAX` bytes, because the header cannot represent that
/// length; a truncated header would desynchronise the stream. Any I/O error
/// from the socket is returned unchanged.
fn write_msg(stream: &mut UnixStream, data: &[u8]) -> std::io::Result<()> {
    // Fully qualified so this compiles regardless of whether `TryFrom` is in
    // the edition's prelude.
    let len = <u32 as std::convert::TryFrom<usize>>::try_from(data.len()).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "IPC message of {} bytes does not fit the 4-byte length prefix",
                data.len()
            ),
        )
    })?;
    stream.write_all(&len.to_le_bytes())?;
    stream.write_all(data)?;
    stream.flush()
}

// ─── public IPC entry point ───────────────────────────────────────────────────

/// Send a command and receive the response.
///
/// Blocks until the response arrives. GIL-free: safe to call inside
/// `py.detach(...)`.
pub fn ipc_call(command: Command) -> Result<Payload, WorkerError> {
    let mut guard = WORKER.lock().map_err(|_| WorkerError::MutexPoisoned)?;
    let state = guard.as_mut().ok_or(WorkerError::NotRunning)?;

    let req = Request {
        request_id: state.next_id,
        command,
    };
    state.next_id += 1;

    let bytes = bincode::serialize(&req).map_err(|e| WorkerError::Serialize(e.to_string()))?;
    write_msg(&mut state.stream, &bytes)?;

    let raw = read_msg(&mut state.stream).map_err(|e| {
        if e.kind() == std::io::ErrorKind::UnexpectedEof {
            WorkerError::Other(
                "CP2K worker process died unexpectedly during a request \
                 (connection closed mid-read). Check the worker's stderr output \
                 for CP2K error messages."
                    .into(),
            )
        } else {
            WorkerError::Io(e)
        }
    })?;
    let resp: Response =
        bincode::deserialize(&raw).map_err(|e| WorkerError::Serialize(e.to_string()))?;

    match resp.status {
        Status::Ok => Ok(resp.payload),
        Status::Error(msg) => Err(WorkerError::Cp2kError(msg)),
    }
}

// ─── binary discovery ─────────────────────────────────────────────────────────

/// Find the `cp2k_rs_worker` binary.
///
/// Search order:
///   1. `CP2K_WORKER_BIN` environment variable
///   2. Same directory as the current executable
///   3. `PATH`
///
/// A candidate is accepted only if it resolves to a regular file (symlinks
/// are followed); a directory that happens to carry the right name is
/// skipped. Environment values are read as raw OS strings, so non-UTF-8
/// paths in `CP2K_WORKER_BIN` or `PATH` are honoured rather than ignored.
///
/// Returns `None` when no candidate matches.
///
/// Python callers may additionally check the installed `cp2k_rs` Python
/// package directory and pass the found path directly to [`start_worker`].
pub fn find_worker_binary() -> Option<std::path::PathBuf> {
    const WORKER_NAME: &str = "cp2k_rs_worker";

    // 1. Explicit override
    if let Some(path) = std::env::var_os("CP2K_WORKER_BIN").map(std::path::PathBuf::from) {
        if path.is_file() {
            return Some(path);
        }
    }

    // 2. Next to the calling executable
    let beside_exe = std::env::current_exe()
        .ok()
        .and_then(|exe| exe.parent().map(|dir| dir.join(WORKER_NAME)));
    if let Some(candidate) = beside_exe {
        if candidate.is_file() {
            return Some(candidate);
        }
    }

    // 3. PATH lookup: first directory containing the binary wins.
    std::env::var_os("PATH").and_then(|paths| {
        std::env::split_paths(&paths)
            .map(|dir| dir.join(WORKER_NAME))
            .find(|candidate| candidate.is_file())
    })
}

/// Detect the default MPI launcher based on the current environment.
///
/// * Inside a SLURM job (`SLURM_JOB_ID` is set): `["srun"]`. No rank count
///   is passed, so `nproc` is not used in this case.
/// * Everywhere else, including PBS (`PBS_JOBID`) and LSF (`LSB_JOBID`)
///   jobs: `["mpirun", "-n", "<nproc>"]`.
///
/// The returned vector is a command prefix; the worker binary is appended by
/// the caller.
pub fn default_launcher(nproc: u32) -> Vec<String> {
    if std::env::var("SLURM_JOB_ID").is_ok() {
        return vec!["srun".to_string()];
    }
    // PBS, LSF and plain hosts all use the same launcher line, so no
    // further scheduler detection is needed.
    vec!["mpirun".to_string(), "-n".to_string(), nproc.to_string()]
}

// ─── worker lifecycle ─────────────────────────────────────────────────────────

/// Start the MPI worker process and wait until its socket is ready.
///
/// * `worker_bin`       – Path to the `cp2k_rs_worker` binary.
/// * `nproc`            – MPI rank count (ignored when `launcher_cmd` is given).
/// * `launcher_cmd`     – Custom launcher prefix, e.g. `["srun", "-n", "8"]`.
/// * `env`              – Extra environment variables for the worker.
/// * `working_dir`      – Working directory for the worker process.
/// * `connect_timeout`  – Seconds to wait for the socket to become ready.
///
/// GIL-free: safe to call inside `py.detach(...)`.
pub fn start_worker(
    worker_bin: std::path::PathBuf,
    nproc: u32,
    launcher_cmd: Option<Vec<String>>,
    env: Option<HashMap<String, String>>,
    working_dir: Option<String>,
    connect_timeout: f64,
) -> Result<(), WorkerError> {
    {
        let guard = WORKER.lock().map_err(|_| WorkerError::MutexPoisoned)?;
        if guard.is_some() {
            return Err(WorkerError::AlreadyRunning);
        }
    }

    let mut cmd_parts = launcher_cmd.unwrap_or_else(|| default_launcher(nproc));
    cmd_parts.push(worker_bin.to_string_lossy().into_owned());

    let socket_path = format!("/tmp/cp2k_worker_{}_{}.sock", std::process::id(), {
        use std::time::{SystemTime, UNIX_EPOCH};
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .subsec_nanos()
    });
    let ready_path = format!("{socket_path}.ready");

    let _ = std::fs::remove_file(&socket_path);
    let _ = std::fs::remove_file(&ready_path);

    let mut cmd = StdCommand::new(&cmd_parts[0]);
    cmd.args(&cmd_parts[1..]);
    cmd.env("CP2K_WORKER_SOCKET_FILE", &socket_path);
    cmd.stdout(Stdio::inherit());
    cmd.stderr(Stdio::inherit());

    if let Some(extra_env) = env {
        for (k, v) in extra_env {
            cmd.env(k, v);
        }
    }
    if let Some(dir) = working_dir {
        cmd.current_dir(dir);
    }

    let child = cmd
        .spawn()
        .map_err(|e| WorkerError::Other(format!("Failed to spawn cp2k_rs_worker: {e}")))?;

    let timeout = Duration::from_secs_f64(connect_timeout);
    let start = std::time::Instant::now();
    loop {
        if std::path::Path::new(&ready_path).exists() {
            break;
        }
        if start.elapsed() > timeout {
            return Err(WorkerError::Other(format!(
                "Timed out waiting for cp2k_rs_worker to become ready ({connect_timeout}s). \
                 The worker process may have crashed during CP2K initialization. \
                 Check that CP2K_DATA_DIR is set (it should point to the CP2K data/ directory) \
                 and that the MPI launcher is available on PATH. Socket: {socket_path}"
            )));
        }
        std::thread::sleep(Duration::from_millis(50));
    }

    let deadline = std::time::Instant::now() + timeout;
    let stream = loop {
        match UnixStream::connect(&socket_path) {
            Ok(s) => break s,
            Err(_) if std::time::Instant::now() < deadline => {
                std::thread::sleep(Duration::from_millis(50));
            }
            Err(e) => return Err(WorkerError::Other(format!("Connect failed: {e}"))),
        }
    };

    stream.set_read_timeout(None)?;

    let mut guard = WORKER.lock().map_err(|_| WorkerError::MutexPoisoned)?;
    *guard = Some(WorkerState {
        child,
        stream,
        next_id: 0,
        socket_path,
        ready_path,
    });

    Ok(())
}

/// Shut down the worker process gracefully.
///
/// Sends a best-effort `Shutdown` command, then waits up to 10 seconds for
/// the process to exit. If it has not exited by then, or its status cannot
/// be queried, it is killed and reaped. The socket and ready-marker files
/// are removed afterwards. Calling this when no worker is running is a
/// no-op that returns `Ok(())`.
///
/// GIL-free: safe to call inside `py.detach(...)`.
///
/// # Errors
///
/// Returns [`WorkerError::MutexPoisoned`] if the global worker lock is
/// poisoned.
pub fn stop_worker() -> Result<(), WorkerError> {
    // Best-effort shutdown: ignore errors (worker may already be gone).
    let _ = ipc_call(Command::Shutdown);

    let mut guard = WORKER.lock().map_err(|_| WorkerError::MutexPoisoned)?;
    if let Some(mut state) = guard.take() {
        let grace = Duration::from_secs(10);
        let start = std::time::Instant::now();
        loop {
            match state.child.try_wait() {
                // Exited and already reaped by `try_wait`.
                Ok(Some(_)) => break,
                Ok(None) if start.elapsed() < grace => {
                    std::thread::sleep(Duration::from_millis(100));
                }
                _ => {
                    // Grace period over or status unavailable: force-kill,
                    // then wait so the process does not linger as a zombie.
                    // Skip the wait if the kill failed, so it cannot block
                    // indefinitely.
                    if state.child.kill().is_ok() {
                        let _ = state.child.wait();
                    }
                    break;
                }
            }
        }
        let _ = std::fs::remove_file(&state.socket_path);
        let _ = std::fs::remove_file(&state.ready_path);
    }

    Ok(())
}