tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
//! Persistent kernel server for HIP backend kernels.
//!
//! ## Why this exists
//!
//! The HIP backend kernels (fp16 GEMM, LayerNorm, Softmax, etc.) are
//! shipped as separate binaries built by `hipcc`. The Rust wrappers
//! historically invoked each kernel with
//! `Command::new(&executable).spawn()` on every call. On the RX 7800
//! XT that costs ~400-500 ms of pure process-spawn + HIP-runtime-init
//! overhead, even when the kernel binary is hot in the page cache.
//! A 5-step Nano training (195K params, ~266 kernel calls per step)
//! therefore spent ~30 s/step in process-spawn overhead alone, almost
//! the same as the 85M-param Tiny run.
//!
//! ## What this module does
//!
//! It maintains a process-wide pool of long-lived child processes,
//! one per `kernel_type` (e.g. `"hip-gemm-f16-fwd"`,
//! `"hip-layernorm-fwd-bwd"`, `"hip-softmax-fwd-grad"`). The first
//! call to `run` for a given `kernel_type` spawns the child with the
//! `--server` flag; subsequent calls reuse the same child via a
//! length-prefixed binary protocol on the child's stdin/stdout.
//!
//! ## Protocol
//!
//! Per call, the client writes:
//! ```text
//! [u32: payload_len (little-endian)]
//! [payload_len bytes: text payload (the same payload the legacy
//!                      one-shot text mode would have read from stdin)]
//! ```
//! and reads back:
//! ```text
//! [u32: response_len (little-endian)]
//! [response_len bytes: text response (the same response the legacy
//!                       one-shot text mode would have written to stdout)]
//! ```
//!
//! This protocol is robust against line-buffering issues because it
//! uses exact byte counts, not line delimiters.
//!
//! ## Backward compatibility
//!
//! The kernel binaries still support the legacy one-shot text mode
//! (when invoked without `--server`). The `KernelServer::oneshot`
//! helper drives that path for tests that want full process
//! isolation.

use std::collections::HashMap;
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command, Stdio};
use std::sync::{Arc, Mutex, OnceLock};

use crate::{Error, Result};

/// Process-wide singleton kernel server. The pool is shared across
/// all HIP backend calls in the running process; the first call for a
/// kernel type pays the ~400 ms spawn cost, every subsequent call
/// costs only the I/O round-trip.
static GLOBAL_KERNEL_SERVER: OnceLock<KernelServer> = OnceLock::new();

/// Handle to one persistent kernel child. Holds the child process
/// and the stdin/stdout pipes; the pipes are behind a `Mutex` so the
/// server can serialize concurrent callers (the existing callers
/// share a global server, but the per-kernel pipe is single-threaded
/// by design — the per-call round-trip is ~1 ms, so the contention
/// overhead is negligible).
struct KernelHandle {
    child: Child,
    stdin: ChildStdin,
    stdout: ChildStdout,
    // Stderr is intentionally not drained in the hot path: the
    // kernel writes to stderr only on internal errors (which we
    // already detect via the response markers), and the 64KB
    // Linux pipe buffer is plenty for typical workloads. Draining
    // on every call would require either a background thread or
    // a non-blocking fd, both of which add complexity for
    // negligible benefit. The field is captured at spawn time
    // only to drop the pipe in `Child`'s destructor when the
    // persistent child is torn down.
    _stderr_dropper: Option<ChildStderr>,
}

/// Pool of long-lived kernel children, keyed by `kernel_type`.
pub struct KernelServer {
    pool: Mutex<HashMap<String, Arc<Mutex<KernelHandle>>>>,
}

impl KernelServer {
    /// Returns the process-wide singleton kernel server. The first
    /// call lazily constructs the pool; subsequent calls return the
    /// same instance.
    pub fn global() -> &'static Self {
        GLOBAL_KERNEL_SERVER.get_or_init(|| Self {
            pool: Mutex::new(HashMap::new()),
        })
    }

    /// Send a text `payload` to the persistent child for the given
    /// `kernel_type` and return its text `response`. Spawns the child
    /// on first use for that `kernel_type`. The `executable` is used
    /// only on the first call for each `kernel_type`; later calls
    /// reuse the same child.
    ///
    /// ## Crash recovery
    ///
    /// If the persistent child has died (e.g. the kernel's `check()`
    /// helper called `std::exit(10)` on a HIP error), the broken
    /// handle is detected via `Child::try_wait()` and **evicted** from
    /// the pool. A fresh child is spawned in its place. Without this,
    /// a single kernel crash would wedge the entire training run —
    /// every subsequent call to the same `kernel_type` would
    /// repeatedly hit the dead child and return the same
    /// `UnexpectedEof` error. The review (agent afb0561c1cd6e90af,
    /// session 206321a2) flagged this as a real [BLOCKER] for the
    /// 0.7B MoE training on RX 7800 XT.
    pub fn run(&self, kernel_type: &str, executable: &Path, payload: &str) -> Result<String> {
        // Step 1: get-or-respawn the child for this kernel type. We
        // hold the pool lock across spawn to keep the
        // get-or-insert critical section atomic, but the per-call
        // hot path (existing child) only briefly locks the pool to
        // clone the Arc, then drops the lock before taking the
        // per-kernel handle lock.
        let handle = {
            let mut pool = self
                .pool
                .lock()
                .map_err(|err| Error::backend(format!("kernel server pool poisoned: {err}")))?;
            match pool.get(kernel_type) {
                Some(existing) if handle_is_alive(existing) => Arc::clone(existing),
                Some(dead) => {
                    // The previous child has died. Best-effort:
                    // take the per-handle lock briefly (so the
                    // previous caller, if any, is forced to
                    // observe the eviction) and kill the
                    // lingering child to release the file
                    // descriptor. Then respawn.
                    if let Ok(mut h) = dead.lock() {
                        let _ = h.child.kill();
                        let _ = h.child.wait();
                    }
                    pool.remove(kernel_type);
                    spawn_and_insert(&mut pool, kernel_type, executable)?
                }
                None => spawn_and_insert(&mut pool, kernel_type, executable)?,
            }
        };

        // Step 2: serialize the per-kernel pipe I/O. We hold the
        // per-kernel mutex for the whole round-trip; this is fine
        // because the round-trip is ~1 ms and concurrent callers
        // would not be faster even with parallelism (the kernel
        // itself is GPU-bound and cannot overlap with another
        // call on the same kernel type anyway).
        let mut h = handle
            .lock()
            .map_err(|err| Error::backend(format!("kernel handle poisoned: {err}")))?;
        send_payload(&mut h.stdin, payload.as_bytes())?;
        let response = read_response(&mut h.stdout)?;
        // We do not drain stderr in the hot path: the kernel writes
        // to stderr only on internal errors (which we already detect
        // via the response markers), and the 64KB Linux pipe buffer
        // is plenty for typical workloads. Draining on every call
        // would require either a background thread or a non-blocking
        // fd, both of which add complexity for negligible benefit.
        Ok(response)
    }

    /// Spawn a fresh child for the given executable, run a single
    /// call, and tear it down. This is the legacy one-shot path,
    /// kept for tests that want full process isolation (e.g. to
    /// verify a kernel doesn't leak state between runs).
    pub fn oneshot(executable: &Path, payload: &str) -> Result<String> {
        let mut child = Command::new(executable)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .map_err(|err| Error::backend(format!("failed to spawn HIP kernel one-shot: {err}")))?;
        if let Some(stdin) = child.stdin.as_mut() {
            stdin.write_all(payload.as_bytes()).map_err(|err| {
                Error::backend(format!("failed to write HIP kernel one-shot stdin: {err}"))
            })?;
        }
        let run = child
            .wait_with_output()
            .map_err(|err| Error::backend(format!("failed to run HIP kernel one-shot: {err}")))?;
        if !run.status.success() {
            return Err(Error::backend(format!(
                "HIP kernel one-shot failed: {}{}",
                String::from_utf8_lossy(&run.stderr),
                String::from_utf8_lossy(&run.stdout)
            )));
        }
        Ok(String::from_utf8_lossy(&run.stdout).into_owned())
    }

    /// Binary-protocol counterpart of [`run`]. The wire format is
    /// identical (length-prefixed little-endian) but the payload
    /// and response are raw bytes, so the kernel-side parser sees
    /// the bytes the host wrote (no UTF-8 round-trip). The kernel
    /// itself is responsible for dispatching to a binary or text
    /// parser based on the first 4 bytes of the payload.
    ///
    /// The pool key (`kernel_type`) is shared with the text
    /// `run` method: a kernel that speaks both protocols uses the
    /// same persistent child, and the kernel decides per-request
    /// which path to take.
    pub fn run_binary(
        &self,
        kernel_type: &str,
        executable: &Path,
        payload: &[u8],
    ) -> Result<Vec<u8>> {
        let handle = {
            let mut pool = self
                .pool
                .lock()
                .map_err(|err| Error::backend(format!("kernel server pool poisoned: {err}")))?;
            match pool.get(kernel_type) {
                Some(existing) if handle_is_alive(existing) => Arc::clone(existing),
                Some(dead) => {
                    if let Ok(mut h) = dead.lock() {
                        let _ = h.child.kill();
                        let _ = h.child.wait();
                    }
                    pool.remove(kernel_type);
                    spawn_and_insert(&mut pool, kernel_type, executable)?
                }
                None => spawn_and_insert(&mut pool, kernel_type, executable)?,
            }
        };
        let mut h = handle
            .lock()
            .map_err(|err| Error::backend(format!("kernel handle poisoned: {err}")))?;
        send_payload(&mut h.stdin, payload)?;
        read_response_bytes(&mut h.stdout)
    }
}

/// Check whether the kernel child behind a pool entry is still
/// running. Returns `false` if the child has exited (e.g. the
/// kernel crashed and called `std::exit(10)` on a HIP error) or
/// if the per-handle mutex is poisoned (treat as dead — the
/// training run is broken regardless). The pool lock MUST NOT be
/// held when this is called.
fn handle_is_alive(handle: &Arc<Mutex<KernelHandle>>) -> bool {
    let Ok(mut h) = handle.lock() else {
        return false;
    };
    // `try_wait` returns `Ok(None)` if the child is still running,
    // `Ok(Some(status))` if it has exited, and `Err(e)` on a kernel
    // query failure. Treat any non-running case as "dead" so the
    // caller evicts and respawns.
    match h.child.try_wait() {
        Ok(None) => true,
        Ok(Some(_status)) => false,
        Err(_err) => false,
    }
}

/// Spawn a new persistent kernel child and insert it into `pool`
/// under `kernel_type`. Caller MUST hold the pool lock.
fn spawn_and_insert(
    pool: &mut HashMap<String, Arc<Mutex<KernelHandle>>>,
    kernel_type: &str,
    executable: &Path,
) -> Result<Arc<Mutex<KernelHandle>>> {
    let h = spawn_persistent_kernel(executable)?;
    let arc = Arc::new(Mutex::new(h));
    pool.insert(kernel_type.to_string(), Arc::clone(&arc));
    Ok(arc)
}

/// Spawn a persistent child with `--server` and capture the
/// three pipes. Fails the same way as the legacy one-shot path so
/// the host can surface a clear error.
fn spawn_persistent_kernel(executable: &Path) -> Result<KernelHandle> {
    let mut child = Command::new(executable)
        .arg("--server")
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
        .map_err(|err| {
            Error::backend(format!(
                "failed to spawn persistent HIP kernel {}: {err}",
                executable.display()
            ))
        })?;
    let stdin = child.stdin.take().ok_or_else(|| {
        Error::backend(format!(
            "persistent HIP kernel {} had no stdin pipe",
            executable.display()
        ))
    })?;
    let stdout = child.stdout.take().ok_or_else(|| {
        Error::backend(format!(
            "persistent HIP kernel {} had no stdout pipe",
            executable.display()
        ))
    })?;
    let stderr = child.stderr.take();
    Ok(KernelHandle {
        child,
        stdin,
        stdout,
        _stderr_dropper: stderr,
    })
}

/// Write a length-prefixed binary payload. Uses exact byte counts
/// so the kernel's `std::cin.read()` on the C++ side can demux
/// requests without depending on line-buffering behavior.
fn send_payload(stdin: &mut ChildStdin, payload: &[u8]) -> Result<()> {
    let len = u32::try_from(payload.len())
        .map_err(|_| Error::backend("kernel payload too large (>4 GiB)"))?;
    let len_bytes = len.to_le_bytes();
    stdin
        .write_all(&len_bytes)
        .map_err(|err| Error::backend(format!("failed to write kernel payload length: {err}")))?;
    if !payload.is_empty() {
        stdin
            .write_all(payload)
            .map_err(|err| Error::backend(format!("failed to write kernel payload body: {err}")))?;
    }
    stdin
        .flush()
        .map_err(|err| Error::backend(format!("failed to flush kernel payload: {err}")))?;
    Ok(())
}

/// Read a length-prefixed binary response. Returns the response
/// bytes as a `String` (the kernel's text output is always valid
/// UTF-8 in our kernels, so we accept the lossy round-trip just in
/// case).
fn read_response(stdout: &mut ChildStdout) -> Result<String> {
    let mut len_bytes = [0u8; 4];
    stdout
        .read_exact(&mut len_bytes)
        .map_err(|err| Error::backend(format!("failed to read kernel response length: {err}")))?;
    let len = u32::from_le_bytes(len_bytes) as usize;
    let mut buf = vec![0u8; len];
    if len > 0 {
        stdout
            .read_exact(&mut buf)
            .map_err(|err| Error::backend(format!("failed to read kernel response body: {err}")))?;
    }
    Ok(String::from_utf8_lossy(&buf).into_owned())
}

/// Read a length-prefixed binary response, returning the raw
/// bytes. Used by [`KernelServer::run_binary`] for kernels that
/// speak a binary wire format (e.g. the AdamW step v2 protocol).
fn read_response_bytes(stdout: &mut ChildStdout) -> Result<Vec<u8>> {
    let mut len_bytes = [0u8; 4];
    stdout
        .read_exact(&mut len_bytes)
        .map_err(|err| Error::backend(format!("failed to read kernel response length: {err}")))?;
    let len = u32::from_le_bytes(len_bytes) as usize;
    let mut buf = vec![0u8; len];
    if len > 0 {
        stdout
            .read_exact(&mut buf)
            .map_err(|err| Error::backend(format!("failed to read kernel response body: {err}")))?;
    }
    Ok(buf)
}

/// Best-effort shutdown of all persistent children. Called from
/// `Drop` on the global server when the process exits. We use
/// `child.kill()` to terminate immediately; the OS reaps the
/// processes.
impl Drop for KernelServer {
    fn drop(&mut self) {
        if let Ok(mut pool) = self.pool.lock() {
            for (_kernel_type, handle) in pool.drain() {
                if let Ok(mut h) = handle.lock() {
                    let _ = h.child.kill();
                    let _ = h.child.wait();
                }
            }
        }
    }
}

/// Convenience wrapper used by the per-kernel Rust wrappers. The
/// wrapper passes its own `kernel_type` label and the path to the
/// compiled executable; this forwards to the global
/// `KernelServer::run` so each call site is a single line.
pub fn run_persistent(kernel_type: &str, executable: &Path, payload: &str) -> Result<String> {
    KernelServer::global().run(kernel_type, executable, payload)
}

/// Binary-protocol convenience wrapper. Same shape as
/// [`run_persistent`] but the payload is raw bytes and the response
/// is returned as raw bytes (no UTF-8 lossy round-trip). Used by
/// the AdamW step kernel after the v2 binary I/O upgrade; the
/// gemm/layernorm/softmax paths still use the text protocol.
pub fn run_persistent_binary(
    kernel_type: &str,
    executable: &Path,
    payload: &[u8],
) -> Result<Vec<u8>> {
    KernelServer::global().run_binary(kernel_type, executable, payload)
}

/// Compute the cached executable path for a kernel. This is just
/// a thin helper so the per-kernel Rust wrappers can stay short.
/// Returns the path the kernel's compile helper would have written
/// to, so callers can `hipcc_recheck_artifact` it before passing it
/// to the server.
pub fn cached_executable_for(cache_dir: &Path, source_fingerprint: &str, suffix: &str) -> PathBuf {
    cache_dir.join(format!("{source_fingerprint}-{suffix}"))
}

#[cfg(test)]
mod tests {

    /// The protocol is byte-exact little-endian. This test pins the
    /// byte order so a future refactor that accidentally switches
    /// to big-endian fails the test rather than silently corrupting
    /// the protocol on real hardware.
    #[test]
    fn protocol_is_little_endian() {
        let len: u32 = 0x0102_0304;
        assert_eq!(len.to_le_bytes(), [0x04, 0x03, 0x02, 0x01]);
    }
}