cp2k-rs 0.1.3

Rust bindings for CP2K with Python interface
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
//! Pure-Rust worker management: spawning, IPC, and lifecycle.
//!
//! All public items in this module are GIL-free. When calling from Python,
//! the caller should release the GIL (e.g., with `py.detach(...)`) around
//! any blocking calls (`start_worker`, `stop_worker`, `ipc_call`).

use std::collections::HashMap;
use std::io::{Read, Write};
use std::os::unix::net::UnixStream;
use std::process::{Child, Command as StdCommand, Stdio};
use std::sync::Mutex;
use std::time::Duration;

use thiserror::Error;

use crate::worker_protocol::{Command, Payload, Request, Response, Status};

// ─── error type ──────────────────────────────────────────────────────────────

/// Errors that can occur during worker operations.
#[derive(Debug, Error)]
pub enum WorkerError {
    #[error("worker mutex poisoned")]
    MutexPoisoned,
    #[error("CP2K worker is not running; call start_worker() first")]
    NotRunning,
    #[error("a CP2K worker is already running; call stop_worker() first")]
    AlreadyRunning,
    #[error(
        "cp2k_rs_worker binary not found; \
             set CP2K_WORKER_BIN or ensure the binary is on PATH"
    )]
    BinaryNotFound,
    #[error("IPC I/O error: {0}")]
    Io(#[from] std::io::Error),
    #[error("serialization error: {0}")]
    Serialize(String),
    #[error("CP2K error: {0}")]
    Cp2kError(String),
    #[error("{0}")]
    Other(String),
}

// ─── global worker state ─────────────────────────────────────────────────────

/// Book-keeping for the single worker process managed by this module.
struct WorkerState {
    /// Handle to the spawned process (the first element of the launch command).
    child: Child,
    /// Connected Unix-domain socket carrying length-prefixed request/response frames.
    stream: UnixStream,
    /// `request_id` to assign to the next outgoing request.
    next_id: u64,
    /// Path of the worker's socket file; removed on shutdown.
    socket_path: String,
    /// Marker file whose appearance signals the worker is ready for a
    /// connection; removed on shutdown.
    ready_path: String,
}

/// Process-wide worker slot: `None` while no worker is running.
///
/// Every access goes through this mutex, which also serializes IPC round trips.
static WORKER: Mutex<Option<WorkerState>> = Mutex::new(None);

// ─── IPC helpers (private) ───────────────────────────────────────────────────

/// Read one frame from `stream`.
///
/// A frame is a 4-byte little-endian `u32` length header followed by exactly
/// that many payload bytes; the payload is returned. A stream that ends early
/// yields an `UnexpectedEof` error from `read_exact`.
fn read_msg(stream: &mut UnixStream) -> std::io::Result<Vec<u8>> {
    let mut header = [0u8; 4];
    stream.read_exact(&mut header)?;

    let mut payload = vec![0u8; u32::from_le_bytes(header) as usize];
    stream.read_exact(&mut payload)?;
    Ok(payload)
}

/// Write one frame to `stream` and flush it.
///
/// A frame is the payload length as a 4-byte little-endian `u32`, followed
/// by the payload bytes.
///
/// # Errors
/// Returns an `InvalidInput` error, without writing anything, if `data` is
/// longer than `u32::MAX` bytes. The length header cannot represent such a
/// payload; truncating it would desynchronize the peer. Otherwise any error
/// from the underlying writes/flush is propagated.
fn write_msg(stream: &mut UnixStream, data: &[u8]) -> std::io::Result<()> {
    let len = u32::try_from(data.len()).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "IPC message of {} bytes exceeds the u32 length-prefix limit",
                data.len()
            ),
        )
    })?;
    stream.write_all(&len.to_le_bytes())?;
    stream.write_all(data)?;
    stream.flush()
}

// ─── public IPC entry point ───────────────────────────────────────────────────

/// Send a command and receive the response.
///
/// Blocks until the response arrives. GIL-free: safe to call inside
/// `py.detach(...)`.
pub fn ipc_call(command: Command) -> Result<Payload, WorkerError> {
    let mut guard = WORKER.lock().map_err(|_| WorkerError::MutexPoisoned)?;
    let state = guard.as_mut().ok_or(WorkerError::NotRunning)?;

    let req = Request {
        request_id: state.next_id,
        command,
    };
    state.next_id += 1;

    let bytes = bincode::serialize(&req).map_err(|e| WorkerError::Serialize(e.to_string()))?;
    write_msg(&mut state.stream, &bytes)?;

    let raw = read_msg(&mut state.stream).map_err(|e| {
        if e.kind() == std::io::ErrorKind::UnexpectedEof {
            WorkerError::Other(
                "CP2K worker process died unexpectedly during a request \
                 (connection closed mid-read). Check the worker's stderr output \
                 for CP2K error messages."
                    .into(),
            )
        } else {
            WorkerError::Io(e)
        }
    })?;
    let resp: Response =
        bincode::deserialize(&raw).map_err(|e| WorkerError::Serialize(e.to_string()))?;

    match resp.status {
        Status::Ok => Ok(resp.payload),
        Status::Error(msg) => Err(WorkerError::Cp2kError(msg)),
    }
}

// ─── binary discovery ─────────────────────────────────────────────────────────

/// Find the `cp2k_rs_worker` binary.
///
/// Search order:
///   1. `CP2K_WORKER_BIN` environment variable
///   2. Same directory as the current executable
///   3. `PATH`
///
/// A candidate is accepted only if it is a regular file (symlinks are
/// followed) with at least one execute permission bit set. Directories and
/// non-executable files are skipped, so for example a directory named
/// `cp2k_rs_worker` does not shadow the real binary later in the search.
/// Environment values are read as raw OS strings, so non-UTF-8 paths work.
///
/// Python callers may additionally check the installed `cp2k_rs` Python
/// package directory and pass the found path directly to [`start_worker`].
pub fn find_worker_binary() -> Option<std::path::PathBuf> {
    const BIN_NAME: &str = "cp2k_rs_worker";

    // Regular file with any of the owner/group/other execute bits set.
    fn is_executable_file(path: &std::path::Path) -> bool {
        use std::os::unix::fs::PermissionsExt;
        std::fs::metadata(path)
            .map(|meta| meta.is_file() && (meta.permissions().mode() & 0o111) != 0)
            .unwrap_or(false)
    }

    // 1. Explicit override
    if let Some(raw) = std::env::var_os("CP2K_WORKER_BIN") {
        let path = std::path::PathBuf::from(raw);
        if is_executable_file(&path) {
            return Some(path);
        }
    }

    // 2. Next to the calling executable
    let beside_exe = std::env::current_exe()
        .ok()
        .and_then(|exe| exe.parent().map(|dir| dir.join(BIN_NAME)));
    if let Some(candidate) = beside_exe {
        if is_executable_file(&candidate) {
            return Some(candidate);
        }
    }

    // 3. PATH lookup
    std::env::var_os("PATH").and_then(|paths| {
        std::env::split_paths(&paths)
            .map(|dir| dir.join(BIN_NAME))
            .find(|candidate| is_executable_file(candidate))
    })
}

/// Detect the default MPI launcher based on the current environment.
///
/// When `nproc` is `None` the launcher is returned without a `-n` flag,
/// letting the scheduler or MPI implementation decide the process count.
///
/// Inside a SLURM job (`SLURM_JOB_ID` set) the launcher is `srun` and no
/// environment flags are added. Otherwise it is `mpirun`, and critical
/// environment variables are forwarded explicitly via `-x VAR=value` flags.
/// OpenMPI does not guarantee that every variable present in the parent
/// shell is propagated to remote ranks. In particular the `OMPI_MCA_*`
/// transport-selection variables must be forwarded explicitly; otherwise UCX
/// may be selected and crash with SIGILL on CI runners that lack the required
/// hardware (InfiniBand, AVX-512 instructions used by PSM3, …).
///
/// NOTE(review): `-x` is OpenMPI syntax. An `mpirun` from a different MPI
/// implementation may reject it; pass an explicit launcher prefix to
/// `start_worker` in that case.
pub fn default_launcher(nproc: Option<u32>) -> Vec<String> {
    // Append `-x VAR=value` to the launcher arguments.
    fn push_export(args: &mut Vec<String>, var: &str, value: &str) {
        args.push("-x".to_string());
        args.push(format!("{var}={value}"));
    }

    let rank_flags: Vec<String> = nproc
        .map(|n| vec!["-n".to_string(), n.to_string()])
        .unwrap_or_default();

    if std::env::var("SLURM_JOB_ID").is_ok() {
        let mut args = vec!["srun".to_string()];
        args.extend(rank_flags);
        return args;
    }

    let mut args = vec!["mpirun".to_string()];
    args.extend(rank_flags);

    // ── Transport / UCX suppression ───────────────────────────────────────
    // On AlmaLinux/RHEL 8 the system OpenMPI is built with UCX support.
    // UCX probes hardware at initialisation time using CPU instructions
    // (e.g. CPUID-gated AVX-512 paths in PSM3/UD verbs) that are absent on
    // plain VMs and CI runners, causing SIGILL inside libucs.so.  This
    // happens even when pml=ob1 is set because UCX is *also* loaded as the
    // one-sided communication (OSC) component, independently of the PML.
    //
    // Safe-default injection policy:
    //   • If the user has set any of the transport-selection variables
    //     (OMPI_MCA_pml / _btl / _osc / _mtl), we assume they know their MPI
    //     environment and forward exactly those they set; no defaults are
    //     injected. This preserves full performance on real HPC clusters
    //     with InfiniBand / RDMA hardware. Other OMPI_MCA_* variables are not
    //     consulted for this decision and are not forwarded via `-x`.
    //   • If none of them is set AND no InfiniBand hardware is detected
    //     (no /sys/class/infiniband entries), the conservative TCP-only
    //     defaults that prevent the UCX SIGILL are injected.
    //   • UCX_ERROR_SIGNALS is not a transport selector. The parent's value
    //     is always forwarded when set, and setting it does not suppress the
    //     defaults above. When unset, it is injected as "" only together
    //     with those defaults.
    //   • OMPI_ALLOW_RUN_AS_ROOT is forwarded if the parent set it but is
    //     never injected as a default. It is a security-relevant setting
    //     that the user should opt into explicitly.

    // Always-forward variables (only when the parent has exported them).
    const PASSTHROUGH_VARS: [&str; 5] = [
        "OMPI_ALLOW_RUN_AS_ROOT",
        "OMPI_ALLOW_RUN_AS_ROOT_CONFIRM",
        "CP2K_DATA_DIR",
        "OMP_NUM_THREADS",
        "RUST_BACKTRACE",
    ];
    for var in PASSTHROUGH_VARS.iter() {
        if let Ok(val) = std::env::var(var) {
            push_export(&mut args, var, &val);
        }
    }

    // Transport-selection variables paired with their conservative defaults.
    const TRANSPORT_DEFAULTS: [(&str, &str); 4] = [
        ("OMPI_MCA_pml", "ob1"),
        ("OMPI_MCA_btl", "tcp,self"),
        ("OMPI_MCA_osc", "pt2pt"),
        ("OMPI_MCA_mtl", "^ofi,psm,psm2"),
    ];
    let user_set_transport = TRANSPORT_DEFAULTS
        .iter()
        .any(|(var, _)| std::env::var(var).is_ok());
    // `has_infiniband` is only probed when the user expressed no preference.
    let inject_defaults = !user_set_transport && !has_infiniband();

    if user_set_transport {
        for (var, _) in TRANSPORT_DEFAULTS.iter() {
            if let Ok(val) = std::env::var(var) {
                push_export(&mut args, var, &val);
            }
        }
    } else if inject_defaults {
        for (var, default) in TRANSPORT_DEFAULTS.iter() {
            push_export(&mut args, var, default);
        }
    }
    // else: IB hardware present and no explicit transport config — let
    // OpenMPI auto-detect; don't constrain the transport choice.

    match std::env::var("UCX_ERROR_SIGNALS") {
        Ok(val) => push_export(&mut args, "UCX_ERROR_SIGNALS", &val),
        Err(_) if inject_defaults => push_export(&mut args, "UCX_ERROR_SIGNALS", ""),
        Err(_) => {}
    }

    args
}

/// Return `true` if at least one InfiniBand port is visible in sysfs.
///
/// Presence of `/sys/class/infiniband` with at least one entry is a reliable
/// indicator that UCX/RDMA transports are meaningful on this host.  The check
/// is intentionally conservative: if sysfs is unavailable or the read fails
/// for any reason we return `false` (i.e. assume no IB, inject safe defaults).
fn has_infiniband() -> bool {
    std::fs::read_dir("/sys/class/infiniband")
        .map(|mut d| d.next().is_some())
        .unwrap_or(false)
}

// ─── worker lifecycle ─────────────────────────────────────────────────────────

/// Start the MPI worker process and wait until its socket is ready.
///
/// * `worker_bin`       – Path to the `cp2k_rs_worker` binary.
/// * `nproc`            – MPI rank count. `None` omits `-n` and lets the
///                        scheduler decide. Ignored when `launcher_cmd` is given.
/// * `launcher_cmd`     – Custom launcher prefix, e.g. `["srun", "-n", "8"]`.
/// * `env`              – Extra environment variables for the worker.
/// * `working_dir`      – Working directory for the worker process.
/// * `connect_timeout`  – Seconds to wait for the socket to become ready.
///
/// GIL-free: safe to call inside `py.detach(...)`.
pub fn start_worker(
    worker_bin: std::path::PathBuf,
    nproc: Option<u32>,
    launcher_cmd: Option<Vec<String>>,
    env: Option<HashMap<String, String>>,
    working_dir: Option<String>,
    connect_timeout: f64,
) -> Result<(), WorkerError> {
    {
        let guard = WORKER.lock().map_err(|_| WorkerError::MutexPoisoned)?;
        if guard.is_some() {
            return Err(WorkerError::AlreadyRunning);
        }
    }

    let mut cmd_parts = launcher_cmd.unwrap_or_else(|| default_launcher(nproc));
    cmd_parts.push(worker_bin.to_string_lossy().into_owned());

    let socket_path = format!("/tmp/cp2k_worker_{}_{}.sock", std::process::id(), {
        use std::time::{SystemTime, UNIX_EPOCH};
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .subsec_nanos()
    });
    let ready_path = format!("{socket_path}.ready");

    let _ = std::fs::remove_file(&socket_path);
    let _ = std::fs::remove_file(&ready_path);

    let mut cmd = StdCommand::new(&cmd_parts[0]);
    cmd.args(&cmd_parts[1..]);
    cmd.env("CP2K_WORKER_SOCKET_FILE", &socket_path);
    cmd.stdout(Stdio::inherit());
    cmd.stderr(Stdio::inherit());

    if let Some(extra_env) = env {
        for (k, v) in extra_env {
            cmd.env(k, v);
        }
    }
    if let Some(dir) = working_dir {
        cmd.current_dir(dir);
    }

    let child = cmd
        .spawn()
        .map_err(|e| WorkerError::Other(format!("Failed to spawn cp2k_rs_worker: {e}")))?;

    let timeout = Duration::from_secs_f64(connect_timeout);
    let start = std::time::Instant::now();
    loop {
        if std::path::Path::new(&ready_path).exists() {
            break;
        }
        if start.elapsed() > timeout {
            return Err(WorkerError::Other(format!(
                "Timed out waiting for cp2k_rs_worker to become ready ({connect_timeout}s). \
                 The worker process may have crashed during CP2K initialization. \
                 Check that CP2K_DATA_DIR is set (it should point to the CP2K data/ directory) \
                 and that the MPI launcher is available on PATH. Socket: {socket_path}"
            )));
        }
        std::thread::sleep(Duration::from_millis(50));
    }

    let deadline = std::time::Instant::now() + timeout;
    let stream = loop {
        match UnixStream::connect(&socket_path) {
            Ok(s) => break s,
            Err(_) if std::time::Instant::now() < deadline => {
                std::thread::sleep(Duration::from_millis(50));
            }
            Err(e) => return Err(WorkerError::Other(format!("Connect failed: {e}"))),
        }
    };

    stream.set_read_timeout(None)?;

    let mut guard = WORKER.lock().map_err(|_| WorkerError::MutexPoisoned)?;
    *guard = Some(WorkerState {
        child,
        stream,
        next_id: 0,
        socket_path,
        ready_path,
    });

    Ok(())
}

/// Shut down the worker process gracefully.
///
/// Steps:
/// 1. Sends a best-effort `Shutdown` command; errors are ignored because the
///    worker may already be gone. This goes through `ipc_call`, which waits
///    for the response without a read timeout.
/// 2. Waits up to 10 s for the spawned process to exit, polling every 100 ms.
///    If it is still running, or its status cannot be queried, it is killed.
///    In every case it is reaped, so no zombie process remains.
/// 3. Removes the socket and ready-marker files.
///
/// Calling this with no worker running does nothing and returns `Ok(())`.
/// The global worker lock is held during the grace period.
///
/// # Errors
/// [`WorkerError::MutexPoisoned`] if the global worker mutex is poisoned.
///
/// GIL-free: safe to call inside `py.detach(...)`.
pub fn stop_worker() -> Result<(), WorkerError> {
    // Best-effort shutdown: ignore errors (worker may already be gone).
    let _ = ipc_call(Command::Shutdown);

    let mut guard = WORKER.lock().map_err(|_| WorkerError::MutexPoisoned)?;
    if let Some(mut state) = guard.take() {
        let grace = Duration::from_secs(10);
        let start = std::time::Instant::now();
        loop {
            match state.child.try_wait() {
                Ok(Some(_)) => break,
                Ok(None) if start.elapsed() < grace => {
                    std::thread::sleep(Duration::from_millis(100));
                }
                _ => {
                    let _ = state.child.kill();
                    // `Child` has no reaping `Drop`; without `wait` the killed
                    // process would linger as a zombie.
                    let _ = state.child.wait();
                    break;
                }
            }
        }
        let _ = std::fs::remove_file(&state.socket_path);
        let _ = std::fs::remove_file(&state.ready_path);
    }

    Ok(())
}