use std::collections::HashMap;
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command, Stdio};
use std::sync::{Arc, Mutex, OnceLock};
use crate::{Error, Result};
/// Process-wide [`KernelServer`] instance, created lazily on the first call to
/// [`KernelServer::global`].
static GLOBAL_KERNEL_SERVER: OnceLock<KernelServer> = OnceLock::new();
/// A spawned kernel child process together with the pipe ends used to talk to it.
struct KernelHandle {
/// The child process itself; polled with `try_wait` and killed/reaped via this field.
child: Child,
/// Write end of the child's stdin; length-prefixed request frames are written here.
stdin: ChildStdin,
/// Read end of the child's stdout; length-prefixed response frames are read from here.
stdout: ChildStdout,
/// The child's stderr pipe, if one was captured. Nothing visible in this file reads
/// from it; it is only stored so it is released together with the handle.
_stderr_dropper: Option<ChildStderr>,
}
/// Pool of persistent kernel processes, keyed by kernel type.
pub struct KernelServer {
/// Maps a kernel type name to its shared handle. The outer mutex guards the map;
/// each handle carries its own mutex so one request at a time talks to a given child.
pool: Mutex<HashMap<String, Arc<Mutex<KernelHandle>>>>,
}
impl KernelServer {
pub fn global() -> &'static Self {
GLOBAL_KERNEL_SERVER.get_or_init(|| Self {
pool: Mutex::new(HashMap::new()),
})
}
pub fn run(&self, kernel_type: &str, executable: &Path, payload: &str) -> Result<String> {
let handle = {
let mut pool = self
.pool
.lock()
.map_err(|err| Error::backend(format!("kernel server pool poisoned: {err}")))?;
match pool.get(kernel_type) {
Some(existing) if handle_is_alive(existing) => Arc::clone(existing),
Some(dead) => {
if let Ok(mut h) = dead.lock() {
let _ = h.child.kill();
let _ = h.child.wait();
}
pool.remove(kernel_type);
spawn_and_insert(&mut pool, kernel_type, executable)?
}
None => spawn_and_insert(&mut pool, kernel_type, executable)?,
}
};
let mut h = handle
.lock()
.map_err(|err| Error::backend(format!("kernel handle poisoned: {err}")))?;
send_payload(&mut h.stdin, payload.as_bytes())?;
let response = read_response(&mut h.stdout)?;
Ok(response)
}
pub fn oneshot(executable: &Path, payload: &str) -> Result<String> {
let mut child = Command::new(executable)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|err| Error::backend(format!("failed to spawn HIP kernel one-shot: {err}")))?;
if let Some(stdin) = child.stdin.as_mut() {
stdin.write_all(payload.as_bytes()).map_err(|err| {
Error::backend(format!("failed to write HIP kernel one-shot stdin: {err}"))
})?;
}
let run = child
.wait_with_output()
.map_err(|err| Error::backend(format!("failed to run HIP kernel one-shot: {err}")))?;
if !run.status.success() {
return Err(Error::backend(format!(
"HIP kernel one-shot failed: {}{}",
String::from_utf8_lossy(&run.stderr),
String::from_utf8_lossy(&run.stdout)
)));
}
Ok(String::from_utf8_lossy(&run.stdout).into_owned())
}
pub fn run_binary(
&self,
kernel_type: &str,
executable: &Path,
payload: &[u8],
) -> Result<Vec<u8>> {
let handle = {
let mut pool = self
.pool
.lock()
.map_err(|err| Error::backend(format!("kernel server pool poisoned: {err}")))?;
match pool.get(kernel_type) {
Some(existing) if handle_is_alive(existing) => Arc::clone(existing),
Some(dead) => {
if let Ok(mut h) = dead.lock() {
let _ = h.child.kill();
let _ = h.child.wait();
}
pool.remove(kernel_type);
spawn_and_insert(&mut pool, kernel_type, executable)?
}
None => spawn_and_insert(&mut pool, kernel_type, executable)?,
}
};
let mut h = handle
.lock()
.map_err(|err| Error::backend(format!("kernel handle poisoned: {err}")))?;
send_payload(&mut h.stdin, payload)?;
read_response_bytes(&mut h.stdout)
}
}
/// Reports whether the kernel process behind `handle` is still running.
///
/// Yields `false` when the handle's mutex is poisoned, when the child has
/// already exited, or when its status cannot be queried.
fn handle_is_alive(handle: &Arc<Mutex<KernelHandle>>) -> bool {
    match handle.lock() {
        // `try_wait` returning `Ok(None)` is the only outcome meaning "still running".
        Ok(mut kernel) => matches!(kernel.child.try_wait(), Ok(None)),
        Err(_poisoned) => false,
    }
}
/// Spawns a persistent kernel for `executable`, registers it in `pool` under
/// `kernel_type` (replacing any entry already stored under that key), and
/// returns the shared handle.
fn spawn_and_insert(
    pool: &mut HashMap<String, Arc<Mutex<KernelHandle>>>,
    kernel_type: &str,
    executable: &Path,
) -> Result<Arc<Mutex<KernelHandle>>> {
    let shared = Arc::new(Mutex::new(spawn_persistent_kernel(executable)?));
    pool.insert(kernel_type.to_owned(), Arc::clone(&shared));
    Ok(shared)
}
/// Spawns `executable --server` with piped stdin/stdout and wraps the child
/// and its pipe ends in a [`KernelHandle`].
///
/// stderr is routed to the null device instead of a pipe. Nothing visible in
/// this file ever reads the child's stderr, and an undrained pipe would block
/// the child as soon as the OS pipe buffer filled up, hanging every request
/// sent to that kernel.
///
/// # Errors
///
/// Returns a backend error if the process cannot be spawned or if either the
/// stdin or stdout pipe is missing from the spawned child.
fn spawn_persistent_kernel(executable: &Path) -> Result<KernelHandle> {
    let mut child = Command::new(executable)
        .arg("--server")
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::null())
        .spawn()
        .map_err(|err| {
            Error::backend(format!(
                "failed to spawn persistent HIP kernel {}: {err}",
                executable.display()
            ))
        })?;
    let stdin = child.stdin.take().ok_or_else(|| {
        Error::backend(format!(
            "persistent HIP kernel {} had no stdin pipe",
            executable.display()
        ))
    })?;
    let stdout = child.stdout.take().ok_or_else(|| {
        Error::backend(format!(
            "persistent HIP kernel {} had no stdout pipe",
            executable.display()
        ))
    })?;
    // With `Stdio::null()` no stderr pipe exists, so this is `None`; the field
    // is still populated from the child to keep the handle layout unchanged.
    let stderr = child.stderr.take();
    Ok(KernelHandle {
        child,
        stdin,
        stdout,
        _stderr_dropper: stderr,
    })
}
fn send_payload(stdin: &mut ChildStdin, payload: &[u8]) -> Result<()> {
let len = u32::try_from(payload.len())
.map_err(|_| Error::backend("kernel payload too large (>4 GiB)"))?;
let len_bytes = len.to_le_bytes();
stdin
.write_all(&len_bytes)
.map_err(|err| Error::backend(format!("failed to write kernel payload length: {err}")))?;
if !payload.is_empty() {
stdin
.write_all(payload)
.map_err(|err| Error::backend(format!("failed to write kernel payload body: {err}")))?;
}
stdin
.flush()
.map_err(|err| Error::backend(format!("failed to flush kernel payload: {err}")))?;
Ok(())
}
fn read_response(stdout: &mut ChildStdout) -> Result<String> {
let mut len_bytes = [0u8; 4];
stdout
.read_exact(&mut len_bytes)
.map_err(|err| Error::backend(format!("failed to read kernel response length: {err}")))?;
let len = u32::from_le_bytes(len_bytes) as usize;
let mut buf = vec![0u8; len];
if len > 0 {
stdout
.read_exact(&mut buf)
.map_err(|err| Error::backend(format!("failed to read kernel response body: {err}")))?;
}
Ok(String::from_utf8_lossy(&buf).into_owned())
}
fn read_response_bytes(stdout: &mut ChildStdout) -> Result<Vec<u8>> {
let mut len_bytes = [0u8; 4];
stdout
.read_exact(&mut len_bytes)
.map_err(|err| Error::backend(format!("failed to read kernel response length: {err}")))?;
let len = u32::from_le_bytes(len_bytes) as usize;
let mut buf = vec![0u8; len];
if len > 0 {
stdout
.read_exact(&mut buf)
.map_err(|err| Error::backend(format!("failed to read kernel response body: {err}")))?;
}
Ok(buf)
}
impl Drop for KernelServer {
    /// Kills and reaps every pooled kernel process on a best-effort basis.
    ///
    /// A poisoned pool mutex skips the cleanup entirely; a poisoned handle
    /// mutex skips just that kernel. Kill/wait errors are ignored.
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so no locking is needed here.
        let Ok(pool) = self.pool.get_mut() else {
            return;
        };
        for (_, handle) in pool.drain() {
            let Ok(mut kernel) = handle.lock() else {
                continue;
            };
            let _ = kernel.child.kill();
            let _ = kernel.child.wait();
        }
    }
}
/// Convenience wrapper: runs a text request against the process-wide
/// [`KernelServer`] and returns its text response.
pub fn run_persistent(kernel_type: &str, executable: &Path, payload: &str) -> Result<String> {
    let server = KernelServer::global();
    server.run(kernel_type, executable, payload)
}
/// Convenience wrapper: runs a binary request against the process-wide
/// [`KernelServer`] and returns the raw response bytes.
pub fn run_persistent_binary(
    kernel_type: &str,
    executable: &Path,
    payload: &[u8],
) -> Result<Vec<u8>> {
    let server = KernelServer::global();
    server.run_binary(kernel_type, executable, payload)
}
/// Builds the path of a cached kernel executable: `cache_dir` joined with a
/// file name of the form `<source_fingerprint>-<suffix>`.
pub fn cached_executable_for(cache_dir: &Path, source_fingerprint: &str, suffix: &str) -> PathBuf {
    let file_name = format!("{source_fingerprint}-{suffix}");
    let mut location = cache_dir.to_path_buf();
    location.push(file_name);
    location
}
#[cfg(test)]
mod tests {
    /// Pins the byte order used for frame length headers: least-significant
    /// byte first.
    #[test]
    fn protocol_is_little_endian() {
        let encoded = 0x0102_0304_u32.to_le_bytes();
        assert_eq!(encoded, [0x04, 0x03, 0x02, 0x01]);
    }
}