1use std::collections::BTreeMap;
39use std::io::{Read, Write};
40use std::path::PathBuf;
41use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
42use std::sync::{Arc, LazyLock, Mutex, OnceLock};
43use std::time::Duration;
44
45use harn_vm::VmValue;
46
47use crate::error::HostlibError;
48use crate::process::{self as process_handle, ProcessHandle, ProcessKiller, SpawnSpec};
49use crate::tools::proc::{self, CaptureConfig, CommandStatus, EnvMode};
50
/// Process-wide counter used to build handle ids; each spawn takes the next
/// value and embeds it in an id of the form `hto-<own pid in hex>-<counter>`.
static HANDLE_COUNTER: AtomicU64 = AtomicU64::new(1);
53
/// Cancellation flags shared between a store entry, its waiter thread and the
/// optional progress thread.
struct CancelState {
    /// Set when a cancellation has been requested for the handle. The waiter
    /// reads it to classify the exit and the progress thread reads it to stop.
    cancelled: AtomicBool,
    /// Records whether the cancellation was requested as a timeout. The waiter
    /// only honours it when `cancelled` is also set.
    timed_out: AtomicBool,
}
64
/// In-memory copy of everything the child has written so far, filled in by the
/// output drain threads.
#[derive(Default)]
struct OutputState {
    /// Bytes read from the child's stdout.
    stdout: Vec<u8>,
    /// Bytes read from the child's stderr.
    stderr: Vec<u8>,
}
70
/// Book-keeping for one live long-running process, stored in `HANDLE_STORE`
/// under its handle id until the waiter thread removes it.
struct HandleEntry {
    /// The process handle. Inserted as `Some` at spawn time and taken (left as
    /// `None`) by the waiter thread when it starts.
    handle: Option<Box<dyn ProcessHandle>>,
    /// Killer cloned out by the cancel paths so the process can be killed
    /// without holding the store lock.
    killer: Arc<dyn ProcessKiller>,
    /// Session that spawned the process; matched by `cancel_session_handles`.
    session_id: String,
    /// Cancellation flags shared with the waiter thread.
    cancel_state: Arc<CancelState>,
    /// Installed by `register_completion_notifier`; the waiter sends `()` on
    /// it as its last step.
    completion_tx: Option<std::sync::mpsc::SyncSender<()>>,
    /// Installed by `cancel_handle_with_options` when the caller asked to wait
    /// for the result; the waiter sends the final payload on it.
    result_tx: Option<std::sync::mpsc::SyncSender<VmValue>>,
}
88
/// Registry of live long-running handles.
#[derive(Default)]
struct HandleStore {
    /// Live entries keyed by handle id.
    entries: BTreeMap<String, HandleEntry>,
}
93
/// Global, lazily created handle registry. Every access in this file goes
/// through `lock()` and panics if the mutex is poisoned.
static HANDLE_STORE: LazyLock<Mutex<HandleStore>> =
    LazyLock::new(|| Mutex::new(HandleStore::default()));
96
/// Description of a freshly spawned long-running process, returned to the
/// caller of `spawn_long_running` while the process keeps running.
#[derive(Debug, Clone)]
pub struct LongRunningHandleInfo {
    /// Command id allocated for this invocation.
    pub command_id: String,
    /// Handle id under which the process is registered (`hto-<pid hex>-<n>`).
    pub handle_id: String,
    /// Start timestamp string captured right after the process was spawned.
    pub started_at: String,
    /// Child process id, or `0` when the handle could not report one.
    pub pid: u32,
    /// Process group id of the child, when the handle reports one.
    pub process_group_id: Option<u32>,
    /// Program and arguments joined with single spaces, for display.
    pub command_display: String,
}
113
/// Tunables for `spawn_long_running_with_options`.
pub(crate) struct LongRunningSpawnOptions {
    /// Environment mode, applied to the toolchain path setup and the spawn spec.
    pub(crate) env_mode: EnvMode,
    /// Capture settings used when inlining output into the final payload.
    pub(crate) capture: CaptureConfig,
    /// Session that owns the process and receives its inbox messages.
    pub(crate) session_id: String,
    /// Interval between progress messages; `None` or a zero duration disables
    /// the progress thread.
    pub(crate) progress_interval: Option<Duration>,
    /// `max_inline_bytes` used when inlining output into progress messages.
    pub(crate) progress_max_inline_bytes: usize,
}
121
/// Everything the waiter thread needs that is known at spawn time.
struct WaiterContext {
    /// Command id; used for artifact paths and reported in payloads.
    command_id: String,
    /// Handle id; key of the store entry the waiter takes over and removes.
    handle_id: String,
    /// Session whose inbox receives the result and progress messages.
    session_id: String,
    /// Start timestamp string reported in payloads.
    started_at: String,
    /// Process group id, reported in payloads when present.
    process_group_id: Option<u32>,
    /// Display form of the command, reported in payloads.
    command_display: String,
    /// Progress interval; `None` or zero means no progress thread is started.
    progress_interval: Option<Duration>,
    /// Inline-output limit forwarded to the progress thread.
    progress_max_inline_bytes: usize,
}
132
/// State moved into the progress thread started by the waiter.
struct ProgressThreadContext {
    /// Command id reported in each progress payload.
    command_id: String,
    /// Handle id reported in each progress payload.
    handle_id: String,
    /// Session whose inbox receives the progress messages.
    session_id: String,
    /// Start timestamp string reported in each progress payload.
    started_at: String,
    /// Display form of the command reported in each progress payload.
    command_display: String,
    /// Process group id reported in each progress payload (null when absent).
    process_group_id: Option<u32>,
    /// Path of the combined output file, reported as `output_path`.
    output_path: PathBuf,
    /// Path of the stdout file, reported as `stdout_path`.
    stdout_path: PathBuf,
    /// Path of the stderr file, reported as `stderr_path`.
    stderr_path: PathBuf,
    /// Output captured so far, shared with the drain threads.
    output_state: Arc<Mutex<OutputState>>,
    /// Cancellation flags; the thread stops once `cancelled` is set.
    cancel_state: Arc<CancelState>,
    /// Set by the waiter after the process has exited and output is drained.
    done: Arc<AtomicBool>,
    /// Instant the waiter started; base for the reported `duration_ms`.
    started: std::time::Instant,
    /// Sleep time between two progress messages.
    interval: Duration,
    /// `max_inline_bytes` used when inlining output into progress payloads.
    max_inline_bytes: usize,
}
150
151impl LongRunningHandleInfo {
152 pub fn into_handle_response(self) -> VmValue {
154 proc::running_response(
155 self.command_id,
156 self.handle_id,
157 self.pid,
158 self.process_group_id,
159 self.started_at,
160 self.command_display,
161 )
162 }
163}
164
165pub fn spawn_long_running(
171 builtin: &'static str,
172 program: String,
173 args: Vec<String>,
174 cwd: Option<PathBuf>,
175 env: BTreeMap<String, String>,
176 session_id: String,
177) -> Result<LongRunningHandleInfo, HostlibError> {
178 spawn_long_running_with_options(
179 builtin,
180 program,
181 args,
182 cwd,
183 env,
184 LongRunningSpawnOptions {
185 env_mode: EnvMode::InheritClean,
186 capture: CaptureConfig::default(),
187 session_id,
188 progress_interval: None,
189 progress_max_inline_bytes: CaptureConfig::default().max_inline_bytes,
190 },
191 )
192}
193
194pub(crate) fn spawn_long_running_with_options(
195 builtin: &'static str,
196 program: String,
197 args: Vec<String>,
198 cwd: Option<PathBuf>,
199 env: BTreeMap<String, String>,
200 options: LongRunningSpawnOptions,
201) -> Result<LongRunningHandleInfo, HostlibError> {
202 let mut env = env;
203 proc::apply_toolchain_path(cwd.as_deref(), &mut env, options.env_mode);
204 let spec = SpawnSpec {
205 builtin,
206 program: program.clone(),
207 args: args.clone(),
208 cwd,
209 env,
210 env_mode: options.env_mode,
211 use_stdin: false,
212 configure_process_group: true,
213 };
214 let handle = process_handle::spawn_process(spec)
215 .map_err(|e| proc::process_error_to_hostlib(builtin, e))?;
216
217 let pid = handle.pid().unwrap_or(0);
218 let process_group_id = handle.process_group_id();
219 let killer = handle.killer();
220 let id = HANDLE_COUNTER.fetch_add(1, Ordering::SeqCst);
221 let handle_id = format!("hto-{:x}-{id}", std::process::id());
222 let command_id = proc::next_command_id();
223 let started_at = proc::now_rfc3339();
224 let _artifacts = proc::register_live_artifacts(&command_id, Some(&handle_id))?;
225
226 let mut all_argv = vec![program];
227 all_argv.extend(args.iter().cloned());
228 let command_display = all_argv.join(" ");
229
230 let cancel_state = Arc::new(CancelState {
231 cancelled: AtomicBool::new(false),
232 timed_out: AtomicBool::new(false),
233 });
234
235 {
236 let mut store = HANDLE_STORE
237 .lock()
238 .expect("long-running handle store poisoned");
239 store.entries.insert(
240 handle_id.clone(),
241 HandleEntry {
242 handle: Some(handle),
243 killer,
244 session_id: options.session_id.clone(),
245 cancel_state: cancel_state.clone(),
246 completion_tx: None,
247 result_tx: None,
248 },
249 );
250 }
251
252 let waiter_context = WaiterContext {
253 command_id: command_id.clone(),
254 handle_id: handle_id.clone(),
255 session_id: options.session_id,
256 started_at: started_at.clone(),
257 process_group_id,
258 command_display: command_display.clone(),
259 progress_interval: options.progress_interval,
260 progress_max_inline_bytes: options.progress_max_inline_bytes,
261 };
262 let waiter_thread_name = waiter_context.handle_id.clone();
263 let capture = options.capture;
264 std::thread::Builder::new()
265 .name(format!("hto-waiter-{waiter_thread_name}"))
266 .spawn(move || {
267 waiter_thread(waiter_context, cancel_state, capture);
268 })
269 .map_err(|e| HostlibError::Backend {
270 builtin,
271 message: format!("failed to spawn waiter thread: {e}"),
272 })?;
273
274 Ok(LongRunningHandleInfo {
275 command_id,
276 handle_id,
277 started_at,
278 pid,
279 process_group_id,
280 command_display,
281 })
282}
283
/// Body of the per-handle waiter thread.
///
/// Takes the process handle out of the store, drains stdout/stderr into
/// memory and files, optionally starts a progress thread, waits for the
/// process to exit, removes the store entry and finally publishes the result:
/// to a waiting canceller via `result_tx`, to the session inbox when the
/// process was not cancelled, and to the completion notifier.
///
/// `context` carries the identifiers and reporting settings, `cancel_state`
/// is the flag pair shared with the cancel paths, and `capture` is forwarded
/// to `proc::inline_output` for the final payload.
fn waiter_thread(context: WaiterContext, cancel_state: Arc<CancelState>, capture: CaptureConfig) {
    // Base for the `duration_ms` values reported by this thread and by the
    // progress thread.
    let waiter_start = std::time::Instant::now();

    // Take ownership of the process handle. If the entry is gone or its
    // handle was already taken there is nothing to wait on.
    let mut handle = {
        let mut store = HANDLE_STORE
            .lock()
            .expect("long-running handle store poisoned");
        match store.entries.get_mut(&context.handle_id) {
            Some(entry) => match entry.handle.take() {
                Some(h) => h,
                None => return,
            },
            None => return,
        }
    };

    let output_state = Arc::new(Mutex::new(OutputState::default()));
    let done = Arc::new(AtomicBool::new(false));
    let planned = proc::planned_artifact_paths(&context.command_id);
    // Best effort: create the output directory and the three output files up
    // front; failures are ignored here.
    if let Some(parent) = planned.output_path.parent() {
        let _ = std::fs::create_dir_all(parent);
    }
    let _ = std::fs::File::create(&planned.stdout_path);
    let _ = std::fs::File::create(&planned.stderr_path);
    // The combined file is shared by both drain threads; `None` when it could
    // not be created.
    let combined_file = std::fs::File::create(&planned.output_path)
        .ok()
        .map(|file| Arc::new(Mutex::new(file)));

    // One drain thread per stream the handle exposes.
    let stdout_thread = handle.take_stdout().map(|out| {
        spawn_output_drain(
            out,
            output_state.clone(),
            planned.stdout_path.clone(),
            combined_file.clone(),
            true,
        )
    });
    let stderr_thread = handle.take_stderr().map(|err| {
        spawn_output_drain(
            err,
            output_state.clone(),
            planned.stderr_path.clone(),
            combined_file.clone(),
            false,
        )
    });

    // Progress reporting is only started for a non-zero interval.
    let progress_thread = context
        .progress_interval
        .filter(|interval| !interval.is_zero())
        .map(|interval| {
            spawn_progress_thread(ProgressThreadContext {
                command_id: context.command_id.clone(),
                handle_id: context.handle_id.clone(),
                session_id: context.session_id.clone(),
                started_at: context.started_at.clone(),
                command_display: context.command_display.clone(),
                process_group_id: context.process_group_id,
                output_path: planned.output_path.clone(),
                stdout_path: planned.stdout_path.clone(),
                stderr_path: planned.stderr_path.clone(),
                output_state: output_state.clone(),
                cancel_state: cancel_state.clone(),
                done: done.clone(),
                started: waiter_start,
                interval,
                max_inline_bytes: context.progress_max_inline_bytes,
            })
        });

    // A failed wait is mapped to `None` and reported below as SIGKILL.
    let status = handle.wait().ok();

    // Join the drain threads before reading the captured output.
    if let Some(thread) = stdout_thread {
        let _ = thread.join();
    }
    if let Some(thread) = stderr_thread {
        let _ = thread.join();
    }
    done.store(true, Ordering::Release);
    // Dropping the join handle detaches the progress thread instead of
    // joining it; it stops by itself once it observes `done`.
    drop(progress_thread);
    let (stdout, stderr) = {
        let state = output_state
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        (state.stdout.clone(), state.stderr.clone())
    };

    // Remove the entry and pick up whatever channels were installed on it.
    // After this point the handle id is no longer known to the cancel paths.
    let (completion_tx, result_tx) = {
        let mut store = HANDLE_STORE
            .lock()
            .expect("long-running handle store poisoned");
        let entry = store
            .entries
            .remove(&context.handle_id)
            .map(|mut e| (e.completion_tx.take(), e.result_tx.take()));
        entry.unwrap_or((None, None))
    };

    let signal_done = move || {
        if let Some(tx) = completion_tx {
            let _ = tx.try_send(());
        }
    };

    // `timed_out` only counts when a cancellation was actually requested.
    let cancelled = cancel_state.cancelled.load(Ordering::Acquire);
    let timed_out = cancelled && cancel_state.timed_out.load(Ordering::Acquire);

    let (exit_code, signal_name) = match status {
        Some(s) => decode_exit_status(s),
        None => (-1, Some("SIGKILL".to_string())),
    };
    let command_status = if timed_out {
        CommandStatus::TimedOut
    } else if cancelled {
        CommandStatus::Killed
    } else {
        CommandStatus::Completed
    };
    let duration = waiter_start.elapsed();
    let duration_ms = duration.as_millis() as i64;
    // NOTE(review): returning here skips the `result_tx` send, the inbox push
    // and `signal_done()`; both senders are simply dropped. Confirm that
    // waiters treating a dropped sender as "no result" is the intended
    // behaviour when persisting fails.
    let artifacts = match proc::persist_artifacts(
        &context.command_id,
        &stdout,
        &stderr,
        Some(&context.handle_id),
    ) {
        Ok(artifacts) => artifacts,
        Err(_) => return,
    };
    let (inline_stdout, inline_stderr) = proc::inline_output(&stdout, &stderr, capture);

    // Final result payload.
    let mut payload = serde_json::Map::new();
    payload.insert(
        "command_id".into(),
        serde_json::Value::String(context.command_id.clone()),
    );
    payload.insert(
        "status".into(),
        serde_json::Value::String(command_status.as_str().to_string()),
    );
    payload.insert(
        "handle_id".into(),
        serde_json::Value::String(context.handle_id),
    );
    payload.insert(
        "command_or_op_descriptor".into(),
        serde_json::Value::String(context.command_display),
    );
    payload.insert(
        "started_at".into(),
        serde_json::Value::String(context.started_at),
    );
    payload.insert(
        "ended_at".into(),
        serde_json::Value::String(proc::now_rfc3339()),
    );
    payload.insert(
        "duration_ms".into(),
        serde_json::Value::Number(duration_ms.into()),
    );
    payload.insert(
        "exit_code".into(),
        serde_json::Value::Number(exit_code.into()),
    );
    payload.insert("timed_out".into(), serde_json::Value::Bool(timed_out));
    payload.insert("stdout".into(), serde_json::Value::String(inline_stdout));
    payload.insert("stderr".into(), serde_json::Value::String(inline_stderr));
    payload.insert(
        "output_path".into(),
        serde_json::Value::String(artifacts.output_path.display().to_string()),
    );
    payload.insert(
        "stdout_path".into(),
        serde_json::Value::String(artifacts.stdout_path.display().to_string()),
    );
    payload.insert(
        "stderr_path".into(),
        serde_json::Value::String(artifacts.stderr_path.display().to_string()),
    );
    payload.insert(
        "line_count".into(),
        serde_json::Value::Number(artifacts.line_count.into()),
    );
    payload.insert(
        "byte_count".into(),
        serde_json::Value::Number(artifacts.byte_count.into()),
    );
    payload.insert(
        "output_sha256".into(),
        serde_json::Value::String(artifacts.output_sha256),
    );
    // `process_group_id` is omitted entirely when unknown, whereas `signal`
    // is always present (null when there was no signal).
    if let Some(pgid) = context.process_group_id {
        payload.insert(
            "process_group_id".into(),
            serde_json::Value::Number((pgid as u64).into()),
        );
    }
    if let Some(sig) = signal_name {
        payload.insert("signal".into(), serde_json::Value::String(sig));
    } else {
        payload.insert("signal".into(), serde_json::Value::Null);
    }

    // A canceller that asked to wait gets the payload directly.
    if let Some(tx) = result_tx {
        let value = serde_json::Value::Object(payload.clone());
        let _ = tx.try_send(harn_vm::json_to_vm_value(&value));
    }
    // Cancelled processes do not produce an inbox message.
    if !cancelled {
        let content = serde_json::to_string(&payload).unwrap_or_default();
        harn_vm::orchestration::agent_inbox::push(
            &context.session_id,
            "tool_result",
            &content,
            "hostlib.long_running.exit",
        );
    }
    signal_done();
}
508
509fn spawn_output_drain(
510 mut reader: Box<dyn Read + Send>,
511 state: Arc<Mutex<OutputState>>,
512 path: std::path::PathBuf,
513 combined_file: Option<Arc<Mutex<std::fs::File>>>,
514 stdout: bool,
515) -> std::thread::JoinHandle<()> {
516 std::thread::spawn(move || {
517 let mut file = std::fs::File::create(path).ok();
518 let mut buf = [0_u8; 8192];
519 loop {
520 let read = match reader.read(&mut buf) {
521 Ok(0) => break,
522 Ok(read) => read,
523 Err(_) => break,
524 };
525 let chunk = &buf[..read];
526 if let Some(file) = file.as_mut() {
527 let _ = file.write_all(chunk);
528 }
529 if let Some(combined) = combined_file.as_ref() {
530 if let Ok(mut combined) = combined.lock() {
531 let _ = combined.write_all(chunk);
532 }
533 }
534 if let Ok(mut state) = state.lock() {
535 if stdout {
536 state.stdout.extend_from_slice(chunk);
537 } else {
538 state.stderr.extend_from_slice(chunk);
539 }
540 }
541 }
542 })
543}
544
545fn spawn_progress_thread(context: ProgressThreadContext) -> std::thread::JoinHandle<()> {
546 std::thread::spawn(move || {
547 while !context.done.load(Ordering::Acquire)
548 && !context.cancel_state.cancelled.load(Ordering::Acquire)
549 {
550 std::thread::sleep(context.interval);
551 if context.done.load(Ordering::Acquire)
552 || context.cancel_state.cancelled.load(Ordering::Acquire)
553 {
554 break;
555 }
556 let (stdout, stderr) = {
557 let state = context
558 .output_state
559 .lock()
560 .unwrap_or_else(|poison| poison.into_inner());
561 (state.stdout.clone(), state.stderr.clone())
562 };
563 let capture = CaptureConfig {
564 max_inline_bytes: context.max_inline_bytes,
565 ..CaptureConfig::default()
566 };
567 let (inline_stdout, inline_stderr) = proc::inline_output(&stdout, &stderr, capture);
568 let byte_count = stdout.len().saturating_add(stderr.len());
569 let payload = serde_json::json!({
570 "command_id": &context.command_id,
571 "handle_id": &context.handle_id,
572 "status": CommandStatus::Running.as_str(),
573 "command_or_op_descriptor": &context.command_display,
574 "started_at": &context.started_at,
575 "ended_at": null,
576 "duration_ms": context.started.elapsed().as_millis() as i64,
577 "exit_code": null,
578 "signal": null,
579 "stdout": inline_stdout,
580 "stderr": inline_stderr,
581 "output_path": context.output_path.display().to_string(),
582 "stdout_path": context.stdout_path.display().to_string(),
583 "stderr_path": context.stderr_path.display().to_string(),
584 "byte_count": byte_count as i64,
585 "line_count": stdout.iter().chain(stderr.iter()).filter(|byte| **byte == b'\n').count() as i64,
586 "process_group_id": context.process_group_id,
587 });
588 harn_vm::orchestration::agent_inbox::push(
589 &context.session_id,
590 "tool_progress",
591 &payload.to_string(),
592 "hostlib.long_running.progress",
593 );
594 }
595 })
596}
597
/// Options for `cancel_handle_with_options`.
pub(crate) struct CancelOptions {
    /// Value stored into the handle's `timed_out` flag; when true the waiter
    /// reports the process as timed out instead of killed.
    pub(crate) timed_out: bool,
    /// When set, the call blocks up to this long for the waiter's final
    /// result after killing the process.
    pub(crate) wait_result: Option<Duration>,
}
602
/// Result of `cancel_handle_with_options`.
pub(crate) struct CancelOutcome {
    /// True when this call initiated the cancellation; false when the handle
    /// is unknown or was already cancelled.
    pub(crate) cancelled: bool,
    /// Final result from the waiter, if one was requested and arrived within
    /// the timeout.
    pub(crate) result: Option<VmValue>,
}
607
608pub fn cancel_handle(handle_id: &str) -> bool {
612 cancel_handle_with_options(
613 handle_id,
614 CancelOptions {
615 timed_out: false,
616 wait_result: None,
617 },
618 )
619 .cancelled
620}
621
622pub(crate) fn cancel_handle_with_options(handle_id: &str, options: CancelOptions) -> CancelOutcome {
623 let (killer, cancel_state, result_rx) = {
624 let mut store = HANDLE_STORE
625 .lock()
626 .expect("long-running handle store poisoned");
627 let Some(entry) = store.entries.get_mut(handle_id) else {
628 return CancelOutcome {
629 cancelled: false,
630 result: None,
631 };
632 };
633 if entry.cancel_state.cancelled.swap(true, Ordering::AcqRel) {
634 return CancelOutcome {
635 cancelled: false,
636 result: None,
637 };
638 }
639 entry
640 .cancel_state
641 .timed_out
642 .store(options.timed_out, Ordering::Release);
643 let result_rx = options.wait_result.map(|_| {
644 let (tx, rx) = std::sync::mpsc::sync_channel::<VmValue>(1);
645 entry.result_tx = Some(tx);
646 rx
647 });
648 (entry.killer.clone(), entry.cancel_state.clone(), result_rx)
649 };
650 do_kill(killer, cancel_state);
651 let result = match (options.wait_result, result_rx) {
652 (Some(timeout), Some(rx)) => rx.recv_timeout(timeout).ok(),
653 _ => None,
654 };
655 CancelOutcome {
656 cancelled: true,
657 result,
658 }
659}
660
/// Killer and cancellation flags of one handle selected by
/// `cancel_session_handles`, collected so the kill can run outside the lock.
type SessionKillEntry = (Arc<dyn ProcessKiller>, Arc<CancelState>);
665
666pub fn cancel_session_handles(session_id: &str) {
669 let to_kill: Vec<SessionKillEntry> = {
670 let store = HANDLE_STORE
671 .lock()
672 .expect("long-running handle store poisoned");
673 let matching: Vec<String> = store
674 .entries
675 .iter()
676 .filter(|(_, e)| e.session_id == session_id)
677 .map(|(id, _)| id.clone())
678 .collect();
679 matching
680 .into_iter()
681 .filter_map(|id| {
682 let entry = store.entries.get(&id)?;
683 if entry.cancel_state.cancelled.swap(true, Ordering::AcqRel) {
684 return None;
685 }
686 entry.cancel_state.timed_out.store(false, Ordering::Release);
687 Some((entry.killer.clone(), entry.cancel_state.clone()))
688 })
689 .collect()
690 };
691 for (killer, cancel_state) in to_kill {
692 do_kill(killer, cancel_state);
693 }
694}
695
696fn do_kill(killer: Arc<dyn ProcessKiller>, cancel_state: Arc<CancelState>) {
699 killer.kill();
702 cancel_state.cancelled.store(true, Ordering::Release);
703}
704
705pub(crate) fn register_cleanup_hook() {
709 static REGISTERED: OnceLock<()> = OnceLock::new();
710 REGISTERED.get_or_init(|| {
711 let hook: Arc<dyn Fn(&str) + Send + Sync> = Arc::new(|session_id: &str| {
712 cancel_session_handles(session_id);
713 });
714 harn_vm::register_session_end_hook(hook);
715 });
716}
717
718fn decode_exit_status(status: process_handle::ExitStatus) -> (i32, Option<String>) {
719 if let Some(code) = status.code {
720 return (code, None);
721 }
722 if let Some(sig) = status.signal {
723 return (-1, Some(format!("SIG{sig}")));
724 }
725 (-1, None)
726}
727
728pub fn register_completion_notifier(handle_id: &str) -> Option<std::sync::mpsc::Receiver<()>> {
734 let (tx, rx) = std::sync::mpsc::sync_channel::<()>(1);
735 let mut store = HANDLE_STORE
736 .lock()
737 .expect("long-running handle store poisoned");
738 let entry = store.entries.get_mut(handle_id)?;
739 entry.completion_tx = Some(tx);
740 Some(rx)
741}