harn_hostlib/tools/long_running.rs

//! Long-running tool handle machinery.
//!
//! When a caller passes `long_running: true` to `run_command`, `run_test`, or
//! `run_build_command`, the builtin spawns the child process without waiting,
//! registers it here, and returns a handle dict immediately:
//!
//! ```json
//! {
//!   "handle_id": "hto-<pid-hex>-<n>",
//!   "started_at": "...",
//!   "command_or_op_descriptor": "..."
//! }
//! ```
//!
//! A background thread waits for the child and, when it exits, calls
//! `harn_vm::push_pending_feedback_global(session_id, "tool_result", json)`
//! so the agent-loop's next turn-preflight picks it up.
//!
//! ### Cancellation
//!
//! `cancel_handle(handle_id)` kills the spawned process (SIGKILL) within
//! 2 seconds. The session-end hook registered on startup kills every
//! in-flight handle associated with the ending session.
//!
//! #### PID-based signaling
//!
//! The waiter thread takes ownership of the `Child` object to drain
//! stdout/stderr and call `wait()`. To keep cancellation possible even
//! after the waiter has taken the `Child`, we store the raw OS process ID
//! in the entry and kill by PID when needed. On Unix we call `kill(2)`
//! directly via an `extern "C"` declaration (no `libc` crate required).
//! A shared `cancelled` flag suppresses the feedback push when the waiter
//! sees an exit caused by cancellation.
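//!
//! As an illustrative sketch only (the actual declaration lives in the
//! `process` module and may differ in detail), the raw call looks like:
//!
//! ```rust,ignore
//! extern "C" {
//!     // kill(2); pid_t is i32 on common Unix platforms.
//!     fn kill(pid: i32, sig: i32) -> i32;
//! }
//! const SIGKILL: i32 = 9;
//! // `child_pid` is the raw PID stored in the handle entry.
//! let rc = unsafe { kill(child_pid as i32, SIGKILL) };
//! ```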

use std::collections::BTreeMap;
use std::io::{Read, Write};
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, LazyLock, Mutex, OnceLock};
use std::time::Duration;

use harn_vm::VmValue;

use crate::error::HostlibError;
use crate::process::{self as process_handle, ProcessHandle, ProcessKiller, SpawnSpec};
use crate::tools::proc::{self, CaptureConfig, CommandStatus, EnvMode};

/// Atomic counter for generating unique handle IDs within this process.
static HANDLE_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Shared cancellation state between the store entry and its waiter thread.
struct CancelState {
    /// Set to `true` when `cancel_handle` / `cancel_session_handles` runs.
    /// The waiter checks this before pushing feedback.
    cancelled: AtomicBool,
}

#[derive(Default)]
struct OutputState {
    stdout: Vec<u8>,
    stderr: Vec<u8>,
}

/// Shared state for a single in-flight child process.
struct HandleEntry {
    /// The process handle. `None` after the waiter thread takes ownership.
    handle: Option<Box<dyn ProcessHandle>>,
    /// Killer that works even after the waiter took `handle`.
    killer: Arc<dyn ProcessKiller>,
    session_id: String,
    /// Shared with the waiter thread.
    cancel_state: Arc<CancelState>,
    /// Sender used by the waiter thread to signal that the post-exit
    /// feedback push is complete. `None` if the test side hasn't asked
    /// to be notified.
    completion_tx: Option<std::sync::mpsc::SyncSender<()>>,
}

#[derive(Default)]
struct HandleStore {
    entries: BTreeMap<String, HandleEntry>,
}

static HANDLE_STORE: LazyLock<Mutex<HandleStore>> =
    LazyLock::new(|| Mutex::new(HandleStore::default()));

/// Metadata returned to the caller immediately when a long-running spawn
/// succeeds. Serialised as a response dict by the calling builtin.
pub struct LongRunningHandleInfo {
    /// Command identifier shared with foreground command responses.
    pub command_id: String,
    /// Opaque handle identifier, e.g. `"hto-<pid-hex>-<n>"`.
    pub handle_id: String,
    /// RFC 3339 timestamp of the spawn.
    pub started_at: String,
    /// Raw child process id reported by the platform.
    pub pid: u32,
    /// Child process group id when the platform exposes it.
    pub process_group_id: Option<u32>,
    /// Human-readable display form of the argv (space-joined).
    pub command_display: String,
}

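/// Options for a long-running spawn: environment handling, output-capture
/// limits, the owning session, and optional periodic progress reporting.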
pub(crate) struct LongRunningSpawnOptions {
    pub(crate) env_mode: EnvMode,
    pub(crate) capture: CaptureConfig,
    pub(crate) session_id: String,
    pub(crate) progress_interval: Option<Duration>,
    pub(crate) progress_max_inline_bytes: usize,
}

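/// Everything the waiter thread needs to build and push the final
/// `tool_result` feedback payload once the child exits.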
struct WaiterContext {
    command_id: String,
    handle_id: String,
    session_id: String,
    started_at: String,
    process_group_id: Option<u32>,
    command_display: String,
    progress_interval: Option<Duration>,
    progress_max_inline_bytes: usize,
}

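/// State handed to the periodic progress-reporting thread, including the
/// shared output buffers and the flags that tell it when to stop.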
struct ProgressThreadContext {
    command_id: String,
    handle_id: String,
    session_id: String,
    started_at: String,
    command_display: String,
    process_group_id: Option<u32>,
    output_path: PathBuf,
    stdout_path: PathBuf,
    stderr_path: PathBuf,
    output_state: Arc<Mutex<OutputState>>,
    cancel_state: Arc<CancelState>,
    done: Arc<AtomicBool>,
    started: std::time::Instant,
    interval: Duration,
    max_inline_bytes: usize,
}

impl LongRunningHandleInfo {
    /// Convert into the standard handle response dict returned to the agent.
    pub fn into_handle_response(self) -> VmValue {
        proc::running_response(
            self.command_id,
            self.handle_id,
            self.pid,
            self.process_group_id,
            self.started_at,
            self.command_display,
        )
    }
}

/// Spawn the argv as a long-running child process and return a handle.
///
/// The background waiter calls `push_pending_feedback_global` when the
/// process exits so the next agent-loop turn sees the result.
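///
/// A hedged usage sketch (argument values are illustrative):
///
/// ```rust,ignore
/// let info = spawn_long_running(
///     "run_command",
///     "sleep".to_string(),
///     vec!["30".to_string()],
///     None, // no explicit working directory
///     std::collections::BTreeMap::new(),
///     "session-1".to_string(),
/// )?;
/// assert!(info.handle_id.starts_with("hto-"));
/// ```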
pub fn spawn_long_running(
    builtin: &'static str,
    program: String,
    args: Vec<String>,
    cwd: Option<PathBuf>,
    env: BTreeMap<String, String>,
    session_id: String,
) -> Result<LongRunningHandleInfo, HostlibError> {
    spawn_long_running_with_options(
        builtin,
        program,
        args,
        cwd,
        env,
        LongRunningSpawnOptions {
            env_mode: EnvMode::InheritClean,
            capture: CaptureConfig::default(),
            session_id,
            progress_interval: None,
            progress_max_inline_bytes: CaptureConfig::default().max_inline_bytes,
        },
    )
}

pub(crate) fn spawn_long_running_with_options(
    builtin: &'static str,
    program: String,
    args: Vec<String>,
    cwd: Option<PathBuf>,
    env: BTreeMap<String, String>,
    options: LongRunningSpawnOptions,
) -> Result<LongRunningHandleInfo, HostlibError> {
    let spec = SpawnSpec {
        builtin,
        program: program.clone(),
        args: args.clone(),
        cwd,
        env,
        env_mode: options.env_mode,
        use_stdin: false,
        configure_process_group: true,
    };
    let handle = process_handle::spawn_process(spec)
        .map_err(|e| proc::process_error_to_hostlib(builtin, e))?;

    let pid = handle.pid().unwrap_or(0);
    let process_group_id = handle.process_group_id();
    let killer = handle.killer();
    let id = HANDLE_COUNTER.fetch_add(1, Ordering::SeqCst);
    let handle_id = format!("hto-{:x}-{id}", std::process::id());
    let command_id = proc::next_command_id();
    let started_at = proc::now_rfc3339();

    let mut all_argv = vec![program.clone()];
    all_argv.extend(args.iter().cloned());
    let command_display = all_argv.join(" ");

    let cancel_state = Arc::new(CancelState {
        cancelled: AtomicBool::new(false),
    });

    {
        let mut store = HANDLE_STORE
            .lock()
            .expect("long-running handle store poisoned");
        store.entries.insert(
            handle_id.clone(),
            HandleEntry {
                handle: Some(handle),
                killer,
                session_id: options.session_id.clone(),
                cancel_state: cancel_state.clone(),
                completion_tx: None,
            },
        );
    }

    let waiter_context = WaiterContext {
        command_id: command_id.clone(),
        handle_id: handle_id.clone(),
        session_id: options.session_id,
        started_at: started_at.clone(),
        process_group_id,
        command_display: command_display.clone(),
        progress_interval: options.progress_interval,
        progress_max_inline_bytes: options.progress_max_inline_bytes,
    };
    let waiter_thread_name = waiter_context.handle_id.clone();
    let capture = options.capture;
    std::thread::Builder::new()
        .name(format!("hto-waiter-{waiter_thread_name}"))
        .spawn(move || {
            waiter_thread(waiter_context, cancel_state, capture);
        })
        .map_err(|e| HostlibError::Backend {
            builtin,
            message: format!("failed to spawn waiter thread: {e}"),
        })?;

    Ok(LongRunningHandleInfo {
        command_id,
        handle_id,
        started_at,
        pid,
        process_group_id,
        command_display,
    })
}

/// Background thread that waits for a child process and fires feedback.
fn waiter_thread(context: WaiterContext, cancel_state: Arc<CancelState>, capture: CaptureConfig) {
    let waiter_start = std::time::Instant::now();

    // Take the handle out of the store. If the entry is already gone (i.e.
    // cancel_handle ran and removed it before us), exit without action.
    let mut handle = {
        let mut store = HANDLE_STORE
            .lock()
            .expect("long-running handle store poisoned");
        match store.entries.get_mut(&context.handle_id) {
            Some(entry) => match entry.handle.take() {
                Some(h) => h,
                None => return, // already cancelled before we ran
            },
            None => return, // entry removed (cancelled before the store insert; shouldn't happen)
        }
    };

    let output_state = Arc::new(Mutex::new(OutputState::default()));
    let done = Arc::new(AtomicBool::new(false));
    let planned = proc::planned_artifact_paths(&context.command_id);
    if let Some(parent) = planned.output_path.parent() {
        let _ = std::fs::create_dir_all(parent);
    }
    let _ = std::fs::File::create(&planned.stdout_path);
    let _ = std::fs::File::create(&planned.stderr_path);
    let combined_file = std::fs::File::create(&planned.output_path)
        .ok()
        .map(|file| Arc::new(Mutex::new(file)));

    let stdout_thread = handle.take_stdout().map(|out| {
        spawn_output_drain(
            out,
            output_state.clone(),
            planned.stdout_path.clone(),
            combined_file.clone(),
            true,
        )
    });
    let stderr_thread = handle.take_stderr().map(|err| {
        spawn_output_drain(
            err,
            output_state.clone(),
            planned.stderr_path.clone(),
            combined_file.clone(),
            false,
        )
    });

    let progress_thread = context
        .progress_interval
        .filter(|interval| !interval.is_zero())
        .map(|interval| {
            spawn_progress_thread(ProgressThreadContext {
                command_id: context.command_id.clone(),
                handle_id: context.handle_id.clone(),
                session_id: context.session_id.clone(),
                started_at: context.started_at.clone(),
                command_display: context.command_display.clone(),
                process_group_id: context.process_group_id,
                output_path: planned.output_path.clone(),
                stdout_path: planned.stdout_path.clone(),
                stderr_path: planned.stderr_path.clone(),
                output_state: output_state.clone(),
                cancel_state: cancel_state.clone(),
                done: done.clone(),
                started: waiter_start,
                interval,
                max_inline_bytes: context.progress_max_inline_bytes,
            })
        });

    let status = handle.wait().ok();

    if let Some(thread) = stdout_thread {
        let _ = thread.join();
    }
    if let Some(thread) = stderr_thread {
        let _ = thread.join();
    }
    done.store(true, Ordering::Release);
    // Detach rather than join: the progress thread checks `done` after each
    // sleep and exits on its own within one interval.
    drop(progress_thread);
    let (stdout, stderr) = {
        let state = output_state
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        (state.stdout.clone(), state.stderr.clone())
    };

    // Remove our entry from the store, taking the completion notifier on
    // the way out so we can signal it after the feedback push completes.
    let completion_tx = {
        let mut store = HANDLE_STORE
            .lock()
            .expect("long-running handle store poisoned");
        store
            .entries
            .remove(&context.handle_id)
            .and_then(|mut e| e.completion_tx.take())
    };

    let signal_done = move || {
        if let Some(tx) = completion_tx {
            let _ = tx.try_send(());
        }
    };

    // If cancellation was requested, don't push feedback: the caller that
    // cancelled doesn't want to receive a spurious tool_result.
    if cancel_state.cancelled.load(Ordering::Acquire) {
        signal_done();
        return;
    }

    let (exit_code, signal_name) = match status {
        Some(s) => decode_exit_status(s),
        // wait() itself failed; treat as killed (extremely unusual).
        None => (-1, Some("SIGKILL".to_string())),
    };
    let duration = waiter_start.elapsed();
    let duration_ms = duration.as_millis() as i64;
    let artifacts = match proc::persist_artifacts(
        &context.command_id,
        &stdout,
        &stderr,
        Some(&context.handle_id),
    ) {
        Ok(artifacts) => artifacts,
        Err(_) => {
            // Persisting artifacts failed; still unblock any registered
            // completion notifier so a waiting test doesn't hang.
            signal_done();
            return;
        }
    };
    let (inline_stdout, inline_stderr) = proc::inline_output(&stdout, &stderr, capture);

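    // Illustrative shape of the tool_result payload assembled below (values
    // are examples; `status` comes from `CommandStatus::Completed.as_str()`):
    //
    // {
    //   "command_id": "...", "handle_id": "hto-...", "status": "...",
    //   "exit_code": 0, "signal": null, "stdout": "...", "stderr": "...",
    //   "output_path": "...", "line_count": 12, "byte_count": 345, ...
    // }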
    let mut payload = serde_json::Map::new();
    payload.insert(
        "command_id".into(),
        serde_json::Value::String(context.command_id.clone()),
    );
    payload.insert(
        "status".into(),
        serde_json::Value::String(CommandStatus::Completed.as_str().to_string()),
    );
    payload.insert(
        "handle_id".into(),
        serde_json::Value::String(context.handle_id),
    );
    payload.insert(
        "command_or_op_descriptor".into(),
        serde_json::Value::String(context.command_display),
    );
    payload.insert(
        "started_at".into(),
        serde_json::Value::String(context.started_at),
    );
    payload.insert(
        "ended_at".into(),
        serde_json::Value::String(proc::now_rfc3339()),
    );
    payload.insert(
        "duration_ms".into(),
        serde_json::Value::Number(duration_ms.into()),
    );
    payload.insert(
        "exit_code".into(),
        serde_json::Value::Number(exit_code.into()),
    );
    payload.insert("stdout".into(), serde_json::Value::String(inline_stdout));
    payload.insert("stderr".into(), serde_json::Value::String(inline_stderr));
    payload.insert(
        "output_path".into(),
        serde_json::Value::String(artifacts.output_path.display().to_string()),
    );
    payload.insert(
        "stdout_path".into(),
        serde_json::Value::String(artifacts.stdout_path.display().to_string()),
    );
    payload.insert(
        "stderr_path".into(),
        serde_json::Value::String(artifacts.stderr_path.display().to_string()),
    );
    payload.insert(
        "line_count".into(),
        serde_json::Value::Number(artifacts.line_count.into()),
    );
    payload.insert(
        "byte_count".into(),
        serde_json::Value::Number(artifacts.byte_count.into()),
    );
    payload.insert(
        "output_sha256".into(),
        serde_json::Value::String(artifacts.output_sha256),
    );
    if let Some(pgid) = context.process_group_id {
        payload.insert(
            "process_group_id".into(),
            serde_json::Value::Number((pgid as u64).into()),
        );
    }
    if let Some(sig) = signal_name {
        payload.insert("signal".into(), serde_json::Value::String(sig));
    } else {
        payload.insert("signal".into(), serde_json::Value::Null);
    }

    let content = serde_json::to_string(&payload).unwrap_or_default();
    harn_vm::push_pending_feedback_global(&context.session_id, "tool_result", &content);
    signal_done();
}

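/// Drain one child output pipe to EOF, appending each chunk to the
/// per-stream artifact file, the shared combined-output file, and the
/// in-memory `OutputState` buffer.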
fn spawn_output_drain(
    mut reader: Box<dyn Read + Send>,
    state: Arc<Mutex<OutputState>>,
    path: PathBuf,
    combined_file: Option<Arc<Mutex<std::fs::File>>>,
    stdout: bool,
) -> std::thread::JoinHandle<()> {
    std::thread::spawn(move || {
        let mut file = std::fs::File::create(path).ok();
        let mut buf = [0_u8; 8192];
        loop {
            let read = match reader.read(&mut buf) {
                Ok(0) => break,
                Ok(read) => read,
                Err(_) => break,
            };
            let chunk = &buf[..read];
            if let Some(file) = file.as_mut() {
                let _ = file.write_all(chunk);
            }
            if let Some(combined) = combined_file.as_ref() {
                if let Ok(mut combined) = combined.lock() {
                    let _ = combined.write_all(chunk);
                }
            }
            if let Ok(mut state) = state.lock() {
                if stdout {
                    state.stdout.extend_from_slice(chunk);
                } else {
                    state.stderr.extend_from_slice(chunk);
                }
            }
        }
    })
}

/// Periodically snapshot the output buffers and push a `tool_progress`
/// feedback item until the waiter marks `done` or cancellation is set.
fn spawn_progress_thread(context: ProgressThreadContext) -> std::thread::JoinHandle<()> {
    std::thread::spawn(move || {
        while !context.done.load(Ordering::Acquire)
            && !context.cancel_state.cancelled.load(Ordering::Acquire)
        {
            std::thread::sleep(context.interval);
            if context.done.load(Ordering::Acquire)
                || context.cancel_state.cancelled.load(Ordering::Acquire)
            {
                break;
            }
            let (stdout, stderr) = {
                let state = context
                    .output_state
                    .lock()
                    .unwrap_or_else(|poison| poison.into_inner());
                (state.stdout.clone(), state.stderr.clone())
            };
            let capture = CaptureConfig {
                max_inline_bytes: context.max_inline_bytes,
                ..CaptureConfig::default()
            };
            let (inline_stdout, inline_stderr) = proc::inline_output(&stdout, &stderr, capture);
            let byte_count = stdout.len().saturating_add(stderr.len());
            let line_count = stdout
                .iter()
                .chain(stderr.iter())
                .filter(|byte| **byte == b'\n')
                .count() as i64;
            let payload = serde_json::json!({
                "command_id": &context.command_id,
                "handle_id": &context.handle_id,
                "status": CommandStatus::Running.as_str(),
                "command_or_op_descriptor": &context.command_display,
                "started_at": &context.started_at,
                "ended_at": null,
                "duration_ms": context.started.elapsed().as_millis() as i64,
                "exit_code": null,
                "signal": null,
                "stdout": inline_stdout,
                "stderr": inline_stderr,
                "output_path": context.output_path.display().to_string(),
                "stdout_path": context.stdout_path.display().to_string(),
                "stderr_path": context.stderr_path.display().to_string(),
                "byte_count": byte_count as i64,
                "line_count": line_count,
                "process_group_id": context.process_group_id,
            });
            harn_vm::push_pending_feedback_global(
                &context.session_id,
                "tool_progress",
                &payload.to_string(),
            );
        }
    })
}

/// Cancel a specific in-flight long-running handle. Kills the process and
/// removes the entry. Returns `true` if the handle was found and cancelled.
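///
/// A hedged usage sketch (`info` from an earlier `spawn_long_running` call):
///
/// ```rust,ignore
/// if cancel_handle(&info.handle_id) {
///     // Process killed and entry removed; the waiter skips feedback.
/// }
/// ```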
pub fn cancel_handle(handle_id: &str) -> bool {
    let (handle_owned, killer, cancel_state, completion_tx) = {
        let mut store = HANDLE_STORE
            .lock()
            .expect("long-running handle store poisoned");
        match store.entries.remove(handle_id) {
            None => return false,
            Some(mut entry) => (
                entry.handle.take(),
                entry.killer.clone(),
                entry.cancel_state.clone(),
                entry.completion_tx.take(),
            ),
        }
    };
    do_kill(handle_owned, killer, cancel_state);
    // If a test registered a completion notifier, signal it now; the waiter
    // can't (its entry is already removed), and we know it will skip
    // feedback because cancellation is set.
    if let Some(tx) = completion_tx {
        let _ = tx.try_send(());
    }
    true
}

/// Tuple shape used by `cancel_session_handles` to drain entries while
/// holding the store lock as briefly as possible. The boxed-trait fields
/// make it too noisy to inline as an unnamed type.
type SessionKillEntry = (
    Option<Box<dyn ProcessHandle>>,
    Arc<dyn ProcessKiller>,
    Arc<CancelState>,
);

/// Cancel all in-flight handles for a given session. Called by the
/// session-end hook to avoid orphaned processes.
pub fn cancel_session_handles(session_id: &str) {
    let to_kill: Vec<SessionKillEntry> = {
        let mut store = HANDLE_STORE
            .lock()
            .expect("long-running handle store poisoned");
        let matching: Vec<String> = store
            .entries
            .iter()
            .filter(|(_, e)| e.session_id == session_id)
            .map(|(id, _)| id.clone())
            .collect();
        matching
            .into_iter()
            .filter_map(|id| {
                store.entries.remove(&id).map(|mut e| {
                    let handle = e.handle.take();
                    (handle, e.killer.clone(), e.cancel_state.clone())
                })
            })
            .collect()
    };
    for (handle, killer, cancel_state) in to_kill {
        do_kill(handle, killer, cancel_state);
    }
}

/// Set the cancellation flag and kill the process. Used by both `cancel_handle`
/// and `cancel_session_handles`.
fn do_kill(
    handle: Option<Box<dyn ProcessHandle>>,
    killer: Arc<dyn ProcessKiller>,
    cancel_state: Arc<CancelState>,
) {
    // Signal cancellation so the waiter (if still running) skips feedback.
    cancel_state.cancelled.store(true, Ordering::Release);
    // Kill via the handle's killer (works whether or not we still own
    // the handle). If we still hold the handle, drop it after the kill so
    // the OS can reap the child.
    killer.kill();
    drop(handle);
}

/// Register the session-cleanup hook with harn-vm. Uses a `OnceLock` so the
/// hook is registered exactly once even if `register_builtins` is called
/// multiple times (e.g. in tests).
pub(crate) fn register_cleanup_hook() {
    static REGISTERED: OnceLock<()> = OnceLock::new();
    REGISTERED.get_or_init(|| {
        let hook: Arc<dyn Fn(&str) + Send + Sync> = Arc::new(|session_id: &str| {
            cancel_session_handles(session_id);
        });
        harn_vm::register_session_end_hook(hook);
    });
}

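/// Decode the platform exit status into `(exit_code, signal_name)`: a normal
/// exit yields its code with no signal; a signal termination yields `-1`
/// plus a `"SIG..."` name.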
fn decode_exit_status(status: process_handle::ExitStatus) -> (i32, Option<String>) {
    if let Some(code) = status.code {
        return (code, None);
    }
    if let Some(sig) = status.signal {
        return (-1, Some(format!("SIG{sig}")));
    }
    (-1, None)
}

/// Register a completion notifier for `handle_id`. The waiter thread sends
/// `()` on the returned receiver after it pushes the feedback item to the
/// global queue. Returns `None` if the handle is no longer in the store
/// (e.g. already cancelled or completed). Used by tests to await waiter
/// completion deterministically, with no polling and no `thread::sleep`.
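///
/// A hedged test-side sketch:
///
/// ```rust,ignore
/// let rx = register_completion_notifier(&info.handle_id).expect("handle gone");
/// rx.recv_timeout(std::time::Duration::from_secs(30))
///     .expect("waiter never completed");
/// ```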
pub fn register_completion_notifier(handle_id: &str) -> Option<std::sync::mpsc::Receiver<()>> {
    let (tx, rx) = std::sync::mpsc::sync_channel::<()>(1);
    let mut store = HANDLE_STORE
        .lock()
        .expect("long-running handle store poisoned");
    let entry = store.entries.get_mut(handle_id)?;
    entry.completion_tx = Some(tx);
    Some(rx)
}