use std::collections::BTreeMap;
use std::path::PathBuf;
use std::process::{Child, Stdio};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, LazyLock, Mutex, OnceLock};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use harn_vm::VmValue;
use harn_vm::process_sandbox;
use crate::error::HostlibError;
/// Source of the numeric suffix in handle ids (`hto-<host pid>-<n>`); bumped once
/// per spawned process.
static HANDLE_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Cancellation flag shared between a handle's store entry and its waiter thread.
struct CancelState {
/// Set to `true` by `do_kill`; the waiter thread checks it and skips pushing
/// completion feedback for a cancelled handle.
cancelled: AtomicBool,
}
/// Book-keeping for one live long-running process.
struct HandleEntry {
/// The process handle. `Some` until it is taken either by the waiter thread
/// (which then blocks on it) or by a cancel that reaches the entry first.
child: Option<Child>,
/// Raw OS pid, kept so a cancel can still signal the process via `kill_pid`
/// after the waiter thread has taken `child`.
pid: u32,
/// Owning session; `cancel_session_handles` selects entries by this value.
session_id: String,
/// Shared with the waiter thread for this handle.
cancel_state: Arc<CancelState>,
}
/// All live handles, keyed by handle id.
#[derive(Default)]
struct HandleStore {
entries: BTreeMap<String, HandleEntry>,
}
/// Process-wide handle registry. Entries are inserted by `spawn_long_running`
/// and removed by the waiter thread when the process finishes, or by a cancel.
static HANDLE_STORE: LazyLock<Mutex<HandleStore>> =
LazyLock::new(|| Mutex::new(HandleStore::default()));
/// What a caller learns about a freshly spawned long-running process.
pub struct LongRunningHandleInfo {
    /// Opaque id of the form `hto-<host pid in hex>-<counter>`; pass it to
    /// `cancel_handle` to stop the process.
    pub handle_id: String,
    /// Spawn time in milliseconds since the Unix epoch (0 if the system clock
    /// reports a time before the epoch).
    pub started_at_ms: u64,
    /// The program and its arguments joined with single spaces, for display.
    pub command_display: String,
}

impl LongRunningHandleInfo {
    /// Converts this info into the VM response value with the keys
    /// `handle_id`, `started_at` and `command`.
    pub fn into_handle_response(self) -> VmValue {
        let Self {
            handle_id,
            started_at_ms,
            command_display,
        } = self;
        let started_at = started_at_ms as i64;
        super::response::ResponseBuilder::new()
            .str("handle_id", handle_id)
            .int("started_at", started_at)
            .str("command", command_display)
            .build()
    }
}
pub fn spawn_long_running(
builtin: &'static str,
program: String,
args: Vec<String>,
cwd: Option<PathBuf>,
env: BTreeMap<String, String>,
session_id: String,
) -> Result<LongRunningHandleInfo, HostlibError> {
if program.is_empty() {
return Err(HostlibError::InvalidParameter {
builtin,
param: "argv",
message: "first element of argv must be a non-empty program name".to_string(),
});
}
let mut command =
process_sandbox::std_command_for(&program, &args).map_err(|e| HostlibError::Backend {
builtin,
message: format!("sandbox setup failed: {e:?}"),
})?;
if let Some(cwd_path) = cwd.as_ref() {
process_sandbox::enforce_process_cwd(cwd_path).map_err(|e| HostlibError::Backend {
builtin,
message: format!("sandbox cwd rejected: {e:?}"),
})?;
command.current_dir(cwd_path);
}
if !env.is_empty() {
command.env_clear();
for (key, value) in &env {
command.env(key, value);
}
}
command.stdout(Stdio::piped());
command.stderr(Stdio::piped());
command.stdin(Stdio::null());
let child = command.spawn().map_err(|e| {
if let Some(violation) = process_sandbox::process_spawn_error(&e) {
return HostlibError::Backend {
builtin,
message: format!("sandbox rejected spawn: {violation:?}"),
};
}
HostlibError::Backend {
builtin,
message: format!("spawn failed: {e}"),
}
})?;
let pid = child.id();
let id = HANDLE_COUNTER.fetch_add(1, Ordering::SeqCst);
let handle_id = format!("hto-{:x}-{id}", std::process::id());
let started_at_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
let mut all_argv = vec![program.clone()];
all_argv.extend(args.iter().cloned());
let command_display = all_argv.join(" ");
let cancel_state = Arc::new(CancelState {
cancelled: AtomicBool::new(false),
});
{
let mut store = HANDLE_STORE
.lock()
.expect("long-running handle store poisoned");
store.entries.insert(
handle_id.clone(),
HandleEntry {
child: Some(child),
pid,
session_id: session_id.clone(),
cancel_state: cancel_state.clone(),
},
);
}
let waiter_handle_id = handle_id.clone();
let waiter_session_id = session_id;
std::thread::Builder::new()
.name(format!("hto-waiter-{waiter_handle_id}"))
.spawn(move || {
waiter_thread(waiter_handle_id, waiter_session_id, cancel_state);
})
.map_err(|e| HostlibError::Backend {
builtin,
message: format!("failed to spawn waiter thread: {e}"),
})?;
Ok(LongRunningHandleInfo {
handle_id,
started_at_ms,
command_display,
})
}
/// Body of the per-handle waiter thread started by `spawn_long_running`.
///
/// Takes the child out of the handle store and waits for it to exit, then
/// deregisters the handle. Unless the handle was cancelled, it then drains the
/// piped stdout/stderr (waiting at most five seconds for each) and pushes a
/// JSON `"tool_result"` feedback item for `session_id`. The item has the keys
/// `handle_id`, `exit_code`, `stdout`, `stderr`, `duration_ms` and `signal`.
///
/// Returns immediately if the handle is no longer registered, which means a
/// cancel reached it (and killed the child) before this thread did.
fn waiter_thread(handle_id: String, session_id: String, cancel_state: Arc<CancelState>) {
    /// Reads `pipe` to EOF on a helper thread and hands the bytes back through
    /// the returned receiver. With no pipe the sender is dropped right away, so
    /// the receiver reports disconnection instead of stalling until a timeout.
    fn drain<R: std::io::Read + Send + 'static>(
        pipe: Option<R>,
    ) -> std::sync::mpsc::Receiver<Vec<u8>> {
        let (tx, rx) = std::sync::mpsc::channel();
        if let Some(mut pipe) = pipe {
            std::thread::spawn(move || {
                let mut bytes = Vec::new();
                let _ = pipe.read_to_end(&mut bytes);
                let _ = tx.send(bytes);
            });
        }
        rx
    }

    let waiter_start = std::time::Instant::now();
    let mut child = {
        let mut store = HANDLE_STORE
            .lock()
            .expect("long-running handle store poisoned");
        match store
            .entries
            .get_mut(&handle_id)
            .and_then(|entry| entry.child.take())
        {
            Some(c) => c,
            None => return,
        }
    };

    // Start draining before waiting so a chatty child cannot block on a full pipe.
    let stdout_rx = drain(child.stdout.take());
    let stderr_rx = drain(child.stderr.take());
    let status = child.wait().ok();

    // Deregister as soon as the child has been reaped. While the entry exists, a
    // cancel falls back to `kill_pid(pid)`. Once the child is reaped that pid may
    // be recycled by an unrelated process, so the entry must not outlive the
    // child (e.g. while output is drained below). A tiny window between `wait`
    // returning and this lock remains.
    let still_registered = HANDLE_STORE
        .lock()
        .expect("long-running handle store poisoned")
        .entries
        .remove(&handle_id)
        .is_some();
    // A missing entry means a cancel removed it. That cancel may not have set the
    // flag yet, so both conditions are checked.
    if !still_registered || cancel_state.cancelled.load(Ordering::Acquire) {
        return;
    }

    let stdout = stdout_rx
        .recv_timeout(Duration::from_secs(5))
        .unwrap_or_default();
    let stderr = stderr_rx
        .recv_timeout(Duration::from_secs(5))
        .unwrap_or_default();

    let (exit_code, signal_name) = match status {
        Some(s) => decode_exit_status(s),
        // `wait` itself failed; report it like a forced kill.
        None => (-1, Some("SIGKILL".to_string())),
    };
    let duration_ms = waiter_start.elapsed().as_millis() as i64;
    let payload = serde_json::json!({
        "handle_id": handle_id,
        "exit_code": exit_code,
        "stdout": String::from_utf8_lossy(&stdout),
        "stderr": String::from_utf8_lossy(&stderr),
        "duration_ms": duration_ms,
        "signal": signal_name,
    });
    let content = serde_json::to_string(&payload).unwrap_or_default();
    harn_vm::push_pending_feedback_global(&session_id, "tool_result", &content);
}
pub fn cancel_handle(handle_id: &str) -> bool {
let (pid, child, cancel_state) = {
let mut store = HANDLE_STORE
.lock()
.expect("long-running handle store poisoned");
match store.entries.remove(handle_id) {
None => return false,
Some(mut entry) => (entry.pid, entry.child.take(), entry.cancel_state.clone()),
}
};
do_kill(pid, child, cancel_state);
true
}
pub fn cancel_session_handles(session_id: &str) {
let to_kill: Vec<(u32, Option<Child>, Arc<CancelState>)> = {
let mut store = HANDLE_STORE
.lock()
.expect("long-running handle store poisoned");
let matching: Vec<String> = store
.entries
.iter()
.filter(|(_, e)| e.session_id == session_id)
.map(|(id, _)| id.clone())
.collect();
matching
.into_iter()
.filter_map(|id| {
store.entries.remove(&id).map(|mut e| {
let child = e.child.take();
(e.pid, child, e.cancel_state.clone())
})
})
.collect()
};
for (pid, child, cancel_state) in to_kill {
do_kill(pid, child, cancel_state);
}
}
/// Marks a handle as cancelled and terminates its process.
///
/// The flag is raised first so the waiter thread suppresses its completion
/// feedback. If the `Child` is still available it is killed and reaped;
/// otherwise the waiter thread owns it and only the raw `pid` can be signalled.
fn do_kill(pid: u32, child: Option<Child>, cancel_state: Arc<CancelState>) {
    cancel_state.cancelled.store(true, Ordering::Release);
    match child {
        Some(mut owned) => kill_child(&mut owned),
        None => kill_pid(pid),
    }
}
pub(crate) fn register_cleanup_hook() {
static REGISTERED: OnceLock<()> = OnceLock::new();
REGISTERED.get_or_init(|| {
let hook: Arc<dyn Fn(&str) + Send + Sync> = Arc::new(|session_id: &str| {
cancel_session_handles(session_id);
});
harn_vm::register_session_end_hook(hook);
});
}
/// Kills `child` and then waits on it so its exit status is collected.
///
/// Both steps are best effort: `kill` can fail if the process already exited,
/// and the `wait` result is not needed by callers.
fn kill_child(child: &mut Child) {
    let (_kill_outcome, _wait_outcome) = (child.kill(), child.wait());
}
/// Sends `SIGKILL` to the process `pid` on Unix. On other platforms this is a
/// no-op.
///
/// `do_kill` uses this when the `Child` handle is no longer available (the
/// waiter thread has taken it), so only the raw pid remains. Values that do not
/// map to a single positive `pid_t` are ignored. `kill(2)` interprets `0` and
/// negative values as process-group targets, and `-1` as "every process we are
/// allowed to signal"; an unchecked `u32 as i32` cast could produce either.
fn kill_pid(pid: u32) {
    #[cfg(unix)]
    {
        extern "C" {
            fn kill(pid: i32, sig: i32) -> i32;
        }
        // SIGKILL's number is fixed at 9 by POSIX.
        const SIGKILL: i32 = 9;
        let Ok(target) = i32::try_from(pid) else {
            return;
        };
        if target <= 0 {
            return;
        }
        // SAFETY: `kill` takes two plain integers and reads or writes no memory
        // of this process. The guard above ensures exactly one positive pid is
        // targeted. The return value is ignored because the process may already
        // have exited.
        unsafe {
            kill(target, SIGKILL);
        }
    }
    #[cfg(not(unix))]
    {
        // NOTE(review): without a `Child` handle there is no portable way to
        // terminate the process here, so once the waiter thread owns the child,
        // cancellation does not stop it on non-Unix targets.
        let _ = pid;
    }
}
/// Splits a finished process's exit status into `(exit_code, signal_name)`.
///
/// * Normal exit: `(code, None)`.
/// * Unix, terminated by a signal: `(-1, Some(name))`. `name` is the
///   conventional name (`"SIGKILL"`, `"SIGTERM"`, …) for signals whose numbers
///   are shared by Linux and the BSD-derived systems including macOS, and
///   `"SIG<n>"` for any other signal number.
/// * Otherwise (no exit code and no signal): `(-1, None)`.
fn decode_exit_status(status: std::process::ExitStatus) -> (i32, Option<String>) {
    #[cfg(unix)]
    {
        use std::os::unix::process::ExitStatusExt;

        /// Conventional name for `sig`. Only numbers that are identical across
        /// common Unix platforms are named; everything else stays numeric.
        fn signal_name(sig: i32) -> String {
            let name = match sig {
                1 => "SIGHUP",
                2 => "SIGINT",
                3 => "SIGQUIT",
                4 => "SIGILL",
                5 => "SIGTRAP",
                6 => "SIGABRT",
                8 => "SIGFPE",
                9 => "SIGKILL",
                11 => "SIGSEGV",
                13 => "SIGPIPE",
                14 => "SIGALRM",
                15 => "SIGTERM",
                _ => return format!("SIG{sig}"),
            };
            name.to_string()
        }

        if let Some(code) = status.code() {
            return (code, None);
        }
        match status.signal() {
            Some(sig) => (-1, Some(signal_name(sig))),
            None => (-1, None),
        }
    }
    #[cfg(not(unix))]
    {
        (status.code().unwrap_or(-1), None)
    }
}