1use std::collections::BTreeMap;
39use std::io::{Read, Write};
40use std::path::PathBuf;
41use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
42use std::sync::{Arc, LazyLock, Mutex, OnceLock};
43use std::time::Duration;
44
45use harn_vm::VmValue;
46
47use crate::error::HostlibError;
48use crate::process::{self as process_handle, ProcessHandle, ProcessKiller, SpawnSpec};
49use crate::tools::args::to_agent_path;
50use crate::tools::proc::{self, CaptureConfig, CommandStatus, EnvMode};
51
/// Process-wide counter that supplies the numeric suffix of every generated
/// handle id. It starts at 1 and is bumped once per spawned handle.
static HANDLE_COUNTER: AtomicU64 = AtomicU64::new(1);
54
/// Cancellation flags shared between the handle store, the waiter thread and
/// the progress thread of one long-running handle.
struct CancelState {
    /// Set as soon as a cancel has been requested for the handle.
    cancelled: AtomicBool,
    /// Records whether the cancel was caused by a timeout. The waiter thread
    /// only honours it when `cancelled` is set as well.
    timed_out: AtomicBool,
}
65
/// Bytes captured so far from the child's output pipes. Both buffers start
/// empty and are appended to by the drain threads.
#[derive(Default)]
struct OutputState {
    /// Everything read from the child's stdout so far.
    stdout: Vec<u8>,
    /// Everything read from the child's stderr so far.
    stderr: Vec<u8>,
}
71
72struct HandleEntry {
74 handle: Option<Box<dyn ProcessHandle>>,
76 killer: Arc<dyn ProcessKiller>,
78 session_id: String,
79 cancel_state: Arc<CancelState>,
81 completion_tx: Option<std::sync::mpsc::SyncSender<()>>,
85 result_tx: Option<std::sync::mpsc::SyncSender<VmValue>>,
88}
89
90#[derive(Default)]
91struct HandleStore {
92 entries: BTreeMap<String, HandleEntry>,
93}
94
95static HANDLE_STORE: LazyLock<Mutex<HandleStore>> =
96 LazyLock::new(|| Mutex::new(HandleStore::default()));
97
/// Description of a freshly spawned long-running process, returned to the
/// caller of the spawn functions.
pub struct LongRunningHandleInfo {
    /// Command id allocated for this invocation.
    pub command_id: String,
    /// Handle id under which the process is tracked; pass it to the cancel
    /// and notifier functions.
    pub handle_id: String,
    /// Timestamp string recorded right after the process was spawned.
    pub started_at: String,
    /// OS process id, or `0` when the process handle did not report one.
    pub pid: u32,
    /// Process group id, when the process handle reports one.
    pub process_group_id: Option<u32>,
    /// The program and its arguments joined with single spaces.
    pub command_display: String,
}
114
115pub(crate) struct LongRunningSpawnOptions {
116 pub(crate) env_mode: EnvMode,
117 pub(crate) capture: CaptureConfig,
118 pub(crate) session_id: String,
119 pub(crate) progress_interval: Option<Duration>,
120 pub(crate) progress_max_inline_bytes: usize,
121}
122
/// Per-handle data moved into the waiter thread.
struct WaiterContext {
    /// Command id used for artifact paths and in the result payload.
    command_id: String,
    /// Key of this handle's entry in the handle store.
    handle_id: String,
    /// Session whose inbox receives the exit event.
    session_id: String,
    /// Start timestamp echoed in the result payload.
    started_at: String,
    /// Process group id echoed in the result payload when present.
    process_group_id: Option<u32>,
    /// Display form of the command echoed in the result payload.
    command_display: String,
    /// Interval between progress events; `None` or zero means no progress
    /// thread is started.
    progress_interval: Option<Duration>,
    /// Inline-output limit handed to the progress thread.
    progress_max_inline_bytes: usize,
}
133
134struct ProgressThreadContext {
135 command_id: String,
136 handle_id: String,
137 session_id: String,
138 started_at: String,
139 command_display: String,
140 process_group_id: Option<u32>,
141 output_path: PathBuf,
142 stdout_path: PathBuf,
143 stderr_path: PathBuf,
144 output_state: Arc<Mutex<OutputState>>,
145 cancel_state: Arc<CancelState>,
146 done: Arc<AtomicBool>,
147 started: std::time::Instant,
148 interval: Duration,
149 max_inline_bytes: usize,
150}
151
152impl LongRunningHandleInfo {
153 pub fn into_handle_response(self) -> VmValue {
155 proc::running_response(
156 self.command_id,
157 self.handle_id,
158 self.pid,
159 self.process_group_id,
160 self.started_at,
161 self.command_display,
162 )
163 }
164}
165
166pub fn spawn_long_running(
172 builtin: &'static str,
173 program: String,
174 args: Vec<String>,
175 cwd: Option<PathBuf>,
176 env: BTreeMap<String, String>,
177 session_id: String,
178) -> Result<LongRunningHandleInfo, HostlibError> {
179 spawn_long_running_with_options(
180 builtin,
181 program,
182 args,
183 cwd,
184 env,
185 LongRunningSpawnOptions {
186 env_mode: EnvMode::InheritClean,
187 capture: CaptureConfig::default(),
188 session_id,
189 progress_interval: None,
190 progress_max_inline_bytes: CaptureConfig::default().max_inline_bytes,
191 },
192 )
193}
194
195pub(crate) fn spawn_long_running_with_options(
196 builtin: &'static str,
197 program: String,
198 args: Vec<String>,
199 cwd: Option<PathBuf>,
200 env: BTreeMap<String, String>,
201 options: LongRunningSpawnOptions,
202) -> Result<LongRunningHandleInfo, HostlibError> {
203 let mut env = env;
204 proc::apply_toolchain_path(cwd.as_deref(), &mut env, options.env_mode);
205 let spec = SpawnSpec {
206 builtin,
207 program: program.clone(),
208 args: args.clone(),
209 cwd,
210 env,
211 env_mode: options.env_mode,
212 use_stdin: false,
213 configure_process_group: true,
214 };
215 let handle = process_handle::spawn_process(spec)
216 .map_err(|e| proc::process_error_to_hostlib(builtin, e))?;
217
218 let pid = handle.pid().unwrap_or(0);
219 let process_group_id = handle.process_group_id();
220 let killer = handle.killer();
221 let id = HANDLE_COUNTER.fetch_add(1, Ordering::SeqCst);
222 let handle_id = format!("hto-{:x}-{id}", std::process::id());
223 let command_id = proc::next_command_id();
224 let started_at = proc::now_rfc3339();
225 let _artifacts = proc::register_live_artifacts(&command_id, Some(&handle_id))?;
226
227 let mut all_argv = vec![program];
228 all_argv.extend(args.iter().cloned());
229 let command_display = all_argv.join(" ");
230
231 let cancel_state = Arc::new(CancelState {
232 cancelled: AtomicBool::new(false),
233 timed_out: AtomicBool::new(false),
234 });
235
236 {
237 let mut store = HANDLE_STORE
238 .lock()
239 .expect("long-running handle store poisoned");
240 store.entries.insert(
241 handle_id.clone(),
242 HandleEntry {
243 handle: Some(handle),
244 killer,
245 session_id: options.session_id.clone(),
246 cancel_state: cancel_state.clone(),
247 completion_tx: None,
248 result_tx: None,
249 },
250 );
251 }
252
253 let waiter_context = WaiterContext {
254 command_id: command_id.clone(),
255 handle_id: handle_id.clone(),
256 session_id: options.session_id,
257 started_at: started_at.clone(),
258 process_group_id,
259 command_display: command_display.clone(),
260 progress_interval: options.progress_interval,
261 progress_max_inline_bytes: options.progress_max_inline_bytes,
262 };
263 let waiter_thread_name = waiter_context.handle_id.clone();
264 let capture = options.capture;
265 std::thread::Builder::new()
266 .name(format!("hto-waiter-{waiter_thread_name}"))
267 .spawn(move || {
268 waiter_thread(waiter_context, cancel_state, capture);
269 })
270 .map_err(|e| HostlibError::Backend {
271 builtin,
272 message: format!("failed to spawn waiter thread: {e}"),
273 })?;
274
275 Ok(LongRunningHandleInfo {
276 command_id,
277 handle_id,
278 started_at,
279 pid,
280 process_group_id,
281 command_display,
282 })
283}
284
285fn waiter_thread(context: WaiterContext, cancel_state: Arc<CancelState>, capture: CaptureConfig) {
287 let waiter_start = std::time::Instant::now();
288
289 let mut handle = {
292 let mut store = HANDLE_STORE
293 .lock()
294 .expect("long-running handle store poisoned");
295 match store.entries.get_mut(&context.handle_id) {
296 Some(entry) => match entry.handle.take() {
297 Some(h) => h,
298 None => return, },
300 None => return, }
302 };
303
304 let output_state = Arc::new(Mutex::new(OutputState::default()));
305 let done = Arc::new(AtomicBool::new(false));
306 let planned = proc::planned_artifact_paths(&context.command_id);
307 if let Some(parent) = planned.output_path.parent() {
308 let _ = std::fs::create_dir_all(parent);
309 }
310 let _ = std::fs::File::create(&planned.stdout_path);
311 let _ = std::fs::File::create(&planned.stderr_path);
312 let combined_file = std::fs::File::create(&planned.output_path)
313 .ok()
314 .map(|file| Arc::new(Mutex::new(file)));
315
316 let stdout_thread = handle.take_stdout().map(|out| {
317 spawn_output_drain(
318 out,
319 output_state.clone(),
320 planned.stdout_path.clone(),
321 combined_file.clone(),
322 true,
323 )
324 });
325 let stderr_thread = handle.take_stderr().map(|err| {
326 spawn_output_drain(
327 err,
328 output_state.clone(),
329 planned.stderr_path.clone(),
330 combined_file.clone(),
331 false,
332 )
333 });
334
335 let progress_thread = context
336 .progress_interval
337 .filter(|interval| !interval.is_zero())
338 .map(|interval| {
339 spawn_progress_thread(ProgressThreadContext {
340 command_id: context.command_id.clone(),
341 handle_id: context.handle_id.clone(),
342 session_id: context.session_id.clone(),
343 started_at: context.started_at.clone(),
344 command_display: context.command_display.clone(),
345 process_group_id: context.process_group_id,
346 output_path: planned.output_path.clone(),
347 stdout_path: planned.stdout_path.clone(),
348 stderr_path: planned.stderr_path.clone(),
349 output_state: output_state.clone(),
350 cancel_state: cancel_state.clone(),
351 done: done.clone(),
352 started: waiter_start,
353 interval,
354 max_inline_bytes: context.progress_max_inline_bytes,
355 })
356 });
357
358 let status = handle.wait().ok();
359
360 if let Some(thread) = stdout_thread {
361 let _ = thread.join();
362 }
363 if let Some(thread) = stderr_thread {
364 let _ = thread.join();
365 }
366 done.store(true, Ordering::Release);
367 drop(progress_thread);
368 let (stdout, stderr) = {
369 let state = output_state
370 .lock()
371 .unwrap_or_else(|poison| poison.into_inner());
372 (state.stdout.clone(), state.stderr.clone())
373 };
374
375 let (completion_tx, result_tx) = {
378 let mut store = HANDLE_STORE
379 .lock()
380 .expect("long-running handle store poisoned");
381 let entry = store
382 .entries
383 .remove(&context.handle_id)
384 .map(|mut e| (e.completion_tx.take(), e.result_tx.take()));
385 entry.unwrap_or((None, None))
386 };
387
388 let signal_done = move || {
389 if let Some(tx) = completion_tx {
390 let _ = tx.try_send(());
391 }
392 };
393
394 let cancelled = cancel_state.cancelled.load(Ordering::Acquire);
395 let timed_out = cancelled && cancel_state.timed_out.load(Ordering::Acquire);
396
397 let (exit_code, signal_name) = match status {
398 Some(s) => decode_exit_status(s),
399 None => (-1, Some("SIGKILL".to_string())),
401 };
402 let command_status = if timed_out {
403 CommandStatus::TimedOut
404 } else if cancelled {
405 CommandStatus::Killed
406 } else {
407 CommandStatus::Completed
408 };
409 let duration = waiter_start.elapsed();
410 let duration_ms = duration.as_millis() as i64;
411 let artifacts = match proc::persist_artifacts(
412 &context.command_id,
413 &stdout,
414 &stderr,
415 Some(&context.handle_id),
416 ) {
417 Ok(artifacts) => artifacts,
418 Err(_) => return,
419 };
420 let (inline_stdout, inline_stderr) = proc::inline_output(&stdout, &stderr, capture);
421
422 let mut payload = serde_json::Map::new();
423 payload.insert(
424 "command_id".into(),
425 serde_json::Value::String(context.command_id.clone()),
426 );
427 payload.insert(
428 "status".into(),
429 serde_json::Value::String(command_status.as_str().to_string()),
430 );
431 payload.insert(
432 "handle_id".into(),
433 serde_json::Value::String(context.handle_id),
434 );
435 payload.insert(
436 "command_or_op_descriptor".into(),
437 serde_json::Value::String(context.command_display),
438 );
439 payload.insert(
440 "started_at".into(),
441 serde_json::Value::String(context.started_at),
442 );
443 payload.insert(
444 "ended_at".into(),
445 serde_json::Value::String(proc::now_rfc3339()),
446 );
447 payload.insert(
448 "duration_ms".into(),
449 serde_json::Value::Number(duration_ms.into()),
450 );
451 payload.insert(
452 "exit_code".into(),
453 serde_json::Value::Number(exit_code.into()),
454 );
455 payload.insert("timed_out".into(), serde_json::Value::Bool(timed_out));
456 payload.insert("stdout".into(), serde_json::Value::String(inline_stdout));
457 payload.insert("stderr".into(), serde_json::Value::String(inline_stderr));
458 payload.insert(
459 "output_path".into(),
460 serde_json::Value::String(to_agent_path(&artifacts.output_path)),
461 );
462 payload.insert(
463 "stdout_path".into(),
464 serde_json::Value::String(to_agent_path(&artifacts.stdout_path)),
465 );
466 payload.insert(
467 "stderr_path".into(),
468 serde_json::Value::String(to_agent_path(&artifacts.stderr_path)),
469 );
470 payload.insert(
471 "line_count".into(),
472 serde_json::Value::Number(artifacts.line_count.into()),
473 );
474 payload.insert(
475 "byte_count".into(),
476 serde_json::Value::Number(artifacts.byte_count.into()),
477 );
478 payload.insert(
479 "output_sha256".into(),
480 serde_json::Value::String(artifacts.output_sha256),
481 );
482 if let Some(pgid) = context.process_group_id {
483 payload.insert(
484 "process_group_id".into(),
485 serde_json::Value::Number((pgid as u64).into()),
486 );
487 }
488 if let Some(sig) = signal_name {
489 payload.insert("signal".into(), serde_json::Value::String(sig));
490 } else {
491 payload.insert("signal".into(), serde_json::Value::Null);
492 }
493
494 if let Some(tx) = result_tx {
495 let value = serde_json::Value::Object(payload.clone());
496 let _ = tx.try_send(harn_vm::json_to_vm_value(&value));
497 }
498 if !cancelled {
499 let content = serde_json::to_string(&payload).unwrap_or_default();
500 harn_vm::orchestration::agent_inbox::push(
501 &context.session_id,
502 "tool_result",
503 &content,
504 "hostlib.long_running.exit",
505 );
506 }
507 signal_done();
508}
509
510fn spawn_output_drain(
511 mut reader: Box<dyn Read + Send>,
512 state: Arc<Mutex<OutputState>>,
513 path: std::path::PathBuf,
514 combined_file: Option<Arc<Mutex<std::fs::File>>>,
515 stdout: bool,
516) -> std::thread::JoinHandle<()> {
517 std::thread::spawn(move || {
518 let mut file = std::fs::File::create(path).ok();
519 let mut buf = [0_u8; 8192];
520 loop {
521 let read = match reader.read(&mut buf) {
522 Ok(0) => break,
523 Ok(read) => read,
524 Err(_) => break,
525 };
526 let chunk = &buf[..read];
527 if let Some(file) = file.as_mut() {
528 let _ = file.write_all(chunk);
529 }
530 if let Some(combined) = combined_file.as_ref() {
531 if let Ok(mut combined) = combined.lock() {
532 let _ = combined.write_all(chunk);
533 }
534 }
535 if let Ok(mut state) = state.lock() {
536 if stdout {
537 state.stdout.extend_from_slice(chunk);
538 } else {
539 state.stderr.extend_from_slice(chunk);
540 }
541 }
542 }
543 })
544}
545
546fn spawn_progress_thread(context: ProgressThreadContext) -> std::thread::JoinHandle<()> {
547 std::thread::spawn(move || {
548 while !context.done.load(Ordering::Acquire)
549 && !context.cancel_state.cancelled.load(Ordering::Acquire)
550 {
551 std::thread::sleep(context.interval);
552 if context.done.load(Ordering::Acquire)
553 || context.cancel_state.cancelled.load(Ordering::Acquire)
554 {
555 break;
556 }
557 let (stdout, stderr) = {
558 let state = context
559 .output_state
560 .lock()
561 .unwrap_or_else(|poison| poison.into_inner());
562 (state.stdout.clone(), state.stderr.clone())
563 };
564 let capture = CaptureConfig {
565 max_inline_bytes: context.max_inline_bytes,
566 ..CaptureConfig::default()
567 };
568 let (inline_stdout, inline_stderr) = proc::inline_output(&stdout, &stderr, capture);
569 let byte_count = stdout.len().saturating_add(stderr.len());
570 let payload = serde_json::json!({
571 "command_id": &context.command_id,
572 "handle_id": &context.handle_id,
573 "status": CommandStatus::Running.as_str(),
574 "command_or_op_descriptor": &context.command_display,
575 "started_at": &context.started_at,
576 "ended_at": null,
577 "duration_ms": context.started.elapsed().as_millis() as i64,
578 "exit_code": null,
579 "signal": null,
580 "stdout": inline_stdout,
581 "stderr": inline_stderr,
582 "output_path": to_agent_path(&context.output_path),
583 "stdout_path": to_agent_path(&context.stdout_path),
584 "stderr_path": to_agent_path(&context.stderr_path),
585 "byte_count": byte_count as i64,
586 "line_count": stdout.iter().chain(stderr.iter()).filter(|byte| **byte == b'\n').count() as i64,
587 "process_group_id": context.process_group_id,
588 });
589 harn_vm::orchestration::agent_inbox::push(
590 &context.session_id,
591 "tool_progress",
592 &payload.to_string(),
593 "hostlib.long_running.progress",
594 );
595 }
596 })
597}
598
/// Options for `cancel_handle_with_options`.
pub(crate) struct CancelOptions {
    /// Whether the cancel is caused by a timeout; stored in the handle's
    /// `timed_out` flag.
    pub(crate) timed_out: bool,
    /// When `Some`, the canceller waits up to this long for the waiter
    /// thread's final result; when `None`, it does not wait.
    pub(crate) wait_result: Option<Duration>,
}
603
604pub(crate) struct CancelOutcome {
605 pub(crate) cancelled: bool,
606 pub(crate) result: Option<VmValue>,
607}
608
609pub fn cancel_handle(handle_id: &str) -> bool {
613 cancel_handle_with_options(
614 handle_id,
615 CancelOptions {
616 timed_out: false,
617 wait_result: None,
618 },
619 )
620 .cancelled
621}
622
623pub(crate) fn cancel_handle_with_options(handle_id: &str, options: CancelOptions) -> CancelOutcome {
624 let (killer, cancel_state, result_rx) = {
625 let mut store = HANDLE_STORE
626 .lock()
627 .expect("long-running handle store poisoned");
628 let Some(entry) = store.entries.get_mut(handle_id) else {
629 return CancelOutcome {
630 cancelled: false,
631 result: None,
632 };
633 };
634 if entry.cancel_state.cancelled.swap(true, Ordering::AcqRel) {
635 return CancelOutcome {
636 cancelled: false,
637 result: None,
638 };
639 }
640 entry
641 .cancel_state
642 .timed_out
643 .store(options.timed_out, Ordering::Release);
644 let result_rx = options.wait_result.map(|_| {
645 let (tx, rx) = std::sync::mpsc::sync_channel::<VmValue>(1);
646 entry.result_tx = Some(tx);
647 rx
648 });
649 (entry.killer.clone(), entry.cancel_state.clone(), result_rx)
650 };
651 do_kill(killer, cancel_state);
652 let result = match (options.wait_result, result_rx) {
653 (Some(timeout), Some(rx)) => rx.recv_timeout(timeout).ok(),
654 _ => None,
655 };
656 CancelOutcome {
657 cancelled: true,
658 result,
659 }
660}
661
662type SessionKillEntry = (Arc<dyn ProcessKiller>, Arc<CancelState>);
666
667pub fn cancel_session_handles(session_id: &str) {
670 let to_kill: Vec<SessionKillEntry> = {
671 let store = HANDLE_STORE
672 .lock()
673 .expect("long-running handle store poisoned");
674 let matching: Vec<String> = store
675 .entries
676 .iter()
677 .filter(|(_, e)| e.session_id == session_id)
678 .map(|(id, _)| id.clone())
679 .collect();
680 matching
681 .into_iter()
682 .filter_map(|id| {
683 let entry = store.entries.get(&id)?;
684 if entry.cancel_state.cancelled.swap(true, Ordering::AcqRel) {
685 return None;
686 }
687 entry.cancel_state.timed_out.store(false, Ordering::Release);
688 Some((entry.killer.clone(), entry.cancel_state.clone()))
689 })
690 .collect()
691 };
692 for (killer, cancel_state) in to_kill {
693 do_kill(killer, cancel_state);
694 }
695}
696
697fn do_kill(killer: Arc<dyn ProcessKiller>, cancel_state: Arc<CancelState>) {
700 killer.kill();
703 cancel_state.cancelled.store(true, Ordering::Release);
704}
705
706pub(crate) fn register_cleanup_hook() {
710 static REGISTERED: OnceLock<()> = OnceLock::new();
711 REGISTERED.get_or_init(|| {
712 let hook: Arc<dyn Fn(&str) + Send + Sync> = Arc::new(|session_id: &str| {
713 cancel_session_handles(session_id);
714 });
715 harn_vm::register_session_end_hook(hook);
716 });
717}
718
719fn decode_exit_status(status: process_handle::ExitStatus) -> (i32, Option<String>) {
720 if let Some(code) = status.code {
721 return (code, None);
722 }
723 if let Some(sig) = status.signal {
724 return (-1, Some(format!("SIG{sig}")));
725 }
726 (-1, None)
727}
728
729pub fn register_completion_notifier(handle_id: &str) -> Option<std::sync::mpsc::Receiver<()>> {
735 let (tx, rx) = std::sync::mpsc::sync_channel::<()>(1);
736 let mut store = HANDLE_STORE
737 .lock()
738 .expect("long-running handle store poisoned");
739 let entry = store.entries.get_mut(handle_id)?;
740 entry.completion_tx = Some(tx);
741 Some(rx)
742}