1use std::collections::BTreeMap;
37use std::io::{Read, Write};
38use std::path::PathBuf;
39use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
40use std::sync::{Arc, LazyLock, Mutex, OnceLock};
41use std::time::Duration;
42
43use harn_vm::VmValue;
44
45use crate::error::HostlibError;
46use crate::process::{self as process_handle, ProcessHandle, ProcessKiller, SpawnSpec};
47use crate::tools::proc::{self, CaptureConfig, CommandStatus, EnvMode};
48
/// Process-wide counter used by `spawn_long_running_with_options` to make each
/// handle id unique; it starts at 1 and is bumped once per spawned handle.
static HANDLE_COUNTER: AtomicU64 = AtomicU64::new(1);
51
/// Cancellation flag shared between a handle's store entry and the background
/// threads that service it.
struct CancelState {
    /// Set to `true` by `do_kill`. `waiter_thread` reads it after the process
    /// has exited and returns without pushing a result when it is set; the
    /// progress thread stops reporting once it observes it.
    cancelled: AtomicBool,
}
58
/// In-memory copy of the child's output, appended to by the threads started
/// in `spawn_output_drain` and read by `waiter_thread` and the progress thread.
#[derive(Default)]
struct OutputState {
    /// Bytes read from the child's stdout so far.
    stdout: Vec<u8>,
    /// Bytes read from the child's stderr so far.
    stderr: Vec<u8>,
}
64
/// Per-handle bookkeeping kept in `HANDLE_STORE` while a long-running process
/// is alive.
struct HandleEntry {
    /// The spawned process. `Some` when the entry is inserted; `waiter_thread`
    /// takes it when it starts, and the cancel paths take whatever is left.
    handle: Option<Box<dyn ProcessHandle>>,
    /// Killer obtained from the handle at spawn time; `do_kill` calls
    /// `kill()` on it.
    killer: Arc<dyn ProcessKiller>,
    /// Session that spawned the process; `cancel_session_handles` matches on it.
    session_id: String,
    /// Cancellation flag shared with the waiter and progress threads.
    cancel_state: Arc<CancelState>,
    /// Installed by `register_completion_notifier`. A single `()` is sent with
    /// `try_send` when `waiter_thread` finishes or `cancel_handle` runs.
    completion_tx: Option<std::sync::mpsc::SyncSender<()>>,
}
79
/// Registry of live long-running handles.
#[derive(Default)]
struct HandleStore {
    /// Live entries keyed by handle id (the `hto-…` string built in
    /// `spawn_long_running_with_options`).
    entries: BTreeMap<String, HandleEntry>,
}
84
/// Global, lazily initialised handle registry. Every access in this file goes
/// through `lock()`; the spawn, waiter, cancel and notifier functions all use it.
static HANDLE_STORE: LazyLock<Mutex<HandleStore>> =
    LazyLock::new(|| Mutex::new(HandleStore::default()));
87
/// Description of a freshly spawned long-running process, returned to the
/// caller of `spawn_long_running`.
///
/// Derives `Debug` and `Clone` so callers can log and duplicate the value;
/// every field is plain owned data.
#[derive(Debug, Clone)]
pub struct LongRunningHandleInfo {
    /// Command id produced by `proc::next_command_id()` for this invocation.
    pub command_id: String,
    /// Handle id of the form `hto-<host pid in hex>-<counter>`; the key used by
    /// `cancel_handle` and `register_completion_notifier`.
    pub handle_id: String,
    /// Start timestamp string produced by `proc::now_rfc3339()`.
    pub started_at: String,
    /// OS process id of the child, or `0` when the handle did not report one.
    pub pid: u32,
    /// Process group id reported by the handle, if any.
    pub process_group_id: Option<u32>,
    /// Program and arguments joined with single spaces, for display only.
    pub command_display: String,
}
104
/// Tunables accepted by `spawn_long_running_with_options`.
pub(crate) struct LongRunningSpawnOptions {
    /// Forwarded unchanged to `SpawnSpec::env_mode`.
    pub(crate) env_mode: EnvMode,
    /// Capture settings handed to `waiter_thread`, which passes them to
    /// `proc::inline_output` when building the final result.
    pub(crate) capture: CaptureConfig,
    /// Session that owns the process; stored on the handle entry and used as
    /// the inbox target for progress and exit messages.
    pub(crate) session_id: String,
    /// Interval between progress reports. `None` or a zero duration means no
    /// progress thread is started.
    pub(crate) progress_interval: Option<Duration>,
    /// `max_inline_bytes` used for the capture config of progress reports.
    pub(crate) progress_max_inline_bytes: usize,
}
112
/// Metadata moved into `waiter_thread`. It is used to look up the handle, to
/// configure the optional progress thread, and to fill in the exit payload.
struct WaiterContext {
    /// Command id; also selects the artifact paths.
    command_id: String,
    /// Key of this handle's entry in `HANDLE_STORE`.
    handle_id: String,
    /// Inbox target for the exit message.
    session_id: String,
    /// Start timestamp string copied into the payload.
    started_at: String,
    /// Process group id, included in the payload only when present.
    process_group_id: Option<u32>,
    /// Space-joined command line, reported as `command_or_op_descriptor`.
    command_display: String,
    /// Progress report interval; `None` or zero disables the progress thread.
    progress_interval: Option<Duration>,
    /// Inline byte limit forwarded to the progress thread.
    progress_max_inline_bytes: usize,
}
123
/// Everything the thread started by `spawn_progress_thread` needs to build
/// and push periodic `tool_progress` messages.
struct ProgressThreadContext {
    /// Command id reported in each progress payload.
    command_id: String,
    /// Handle id reported in each progress payload.
    handle_id: String,
    /// Inbox target for the progress messages.
    session_id: String,
    /// Start timestamp string reported in each payload.
    started_at: String,
    /// Space-joined command line, reported as `command_or_op_descriptor`.
    command_display: String,
    /// Process group id, reported as a number or `null`.
    process_group_id: Option<u32>,
    /// Planned combined-output artifact path, reported as `output_path`.
    output_path: PathBuf,
    /// Planned stdout artifact path, reported as `stdout_path`.
    stdout_path: PathBuf,
    /// Planned stderr artifact path, reported as `stderr_path`.
    stderr_path: PathBuf,
    /// Shared output buffers, cloned for every report.
    output_state: Arc<Mutex<OutputState>>,
    /// Cancellation flag; reporting stops once it is set.
    cancel_state: Arc<CancelState>,
    /// Set by `waiter_thread` after the process has exited and its output
    /// drains were joined; reporting stops once it is set.
    done: Arc<AtomicBool>,
    /// Instant from which `duration_ms` is measured.
    started: std::time::Instant,
    /// Sleep time between reports.
    interval: Duration,
    /// `max_inline_bytes` of the capture config used for inline output.
    max_inline_bytes: usize,
}
141
142impl LongRunningHandleInfo {
143 pub fn into_handle_response(self) -> VmValue {
145 proc::running_response(
146 self.command_id,
147 self.handle_id,
148 self.pid,
149 self.process_group_id,
150 self.started_at,
151 self.command_display,
152 )
153 }
154}
155
156pub fn spawn_long_running(
162 builtin: &'static str,
163 program: String,
164 args: Vec<String>,
165 cwd: Option<PathBuf>,
166 env: BTreeMap<String, String>,
167 session_id: String,
168) -> Result<LongRunningHandleInfo, HostlibError> {
169 spawn_long_running_with_options(
170 builtin,
171 program,
172 args,
173 cwd,
174 env,
175 LongRunningSpawnOptions {
176 env_mode: EnvMode::InheritClean,
177 capture: CaptureConfig::default(),
178 session_id,
179 progress_interval: None,
180 progress_max_inline_bytes: CaptureConfig::default().max_inline_bytes,
181 },
182 )
183}
184
185pub(crate) fn spawn_long_running_with_options(
186 builtin: &'static str,
187 program: String,
188 args: Vec<String>,
189 cwd: Option<PathBuf>,
190 env: BTreeMap<String, String>,
191 options: LongRunningSpawnOptions,
192) -> Result<LongRunningHandleInfo, HostlibError> {
193 let spec = SpawnSpec {
194 builtin,
195 program: program.clone(),
196 args: args.clone(),
197 cwd,
198 env,
199 env_mode: options.env_mode,
200 use_stdin: false,
201 configure_process_group: true,
202 };
203 let handle = process_handle::spawn_process(spec)
204 .map_err(|e| proc::process_error_to_hostlib(builtin, e))?;
205
206 let pid = handle.pid().unwrap_or(0);
207 let process_group_id = handle.process_group_id();
208 let killer = handle.killer();
209 let id = HANDLE_COUNTER.fetch_add(1, Ordering::SeqCst);
210 let handle_id = format!("hto-{:x}-{id}", std::process::id());
211 let command_id = proc::next_command_id();
212 let started_at = proc::now_rfc3339();
213
214 let mut all_argv = vec![program.clone()];
215 all_argv.extend(args.iter().cloned());
216 let command_display = all_argv.join(" ");
217
218 let cancel_state = Arc::new(CancelState {
219 cancelled: AtomicBool::new(false),
220 });
221
222 {
223 let mut store = HANDLE_STORE
224 .lock()
225 .expect("long-running handle store poisoned");
226 store.entries.insert(
227 handle_id.clone(),
228 HandleEntry {
229 handle: Some(handle),
230 killer,
231 session_id: options.session_id.clone(),
232 cancel_state: cancel_state.clone(),
233 completion_tx: None,
234 },
235 );
236 }
237
238 let waiter_context = WaiterContext {
239 command_id: command_id.clone(),
240 handle_id: handle_id.clone(),
241 session_id: options.session_id,
242 started_at: started_at.clone(),
243 process_group_id,
244 command_display: command_display.clone(),
245 progress_interval: options.progress_interval,
246 progress_max_inline_bytes: options.progress_max_inline_bytes,
247 };
248 let waiter_thread_name = waiter_context.handle_id.clone();
249 let capture = options.capture;
250 std::thread::Builder::new()
251 .name(format!("hto-waiter-{waiter_thread_name}"))
252 .spawn(move || {
253 waiter_thread(waiter_context, cancel_state, capture);
254 })
255 .map_err(|e| HostlibError::Backend {
256 builtin,
257 message: format!("failed to spawn waiter thread: {e}"),
258 })?;
259
260 Ok(LongRunningHandleInfo {
261 command_id,
262 handle_id,
263 started_at,
264 pid,
265 process_group_id,
266 command_display,
267 })
268}
269
270fn waiter_thread(context: WaiterContext, cancel_state: Arc<CancelState>, capture: CaptureConfig) {
272 let waiter_start = std::time::Instant::now();
273
274 let mut handle = {
277 let mut store = HANDLE_STORE
278 .lock()
279 .expect("long-running handle store poisoned");
280 match store.entries.get_mut(&context.handle_id) {
281 Some(entry) => match entry.handle.take() {
282 Some(h) => h,
283 None => return, },
285 None => return, }
287 };
288
289 let output_state = Arc::new(Mutex::new(OutputState::default()));
290 let done = Arc::new(AtomicBool::new(false));
291 let planned = proc::planned_artifact_paths(&context.command_id);
292 if let Some(parent) = planned.output_path.parent() {
293 let _ = std::fs::create_dir_all(parent);
294 }
295 let _ = std::fs::File::create(&planned.stdout_path);
296 let _ = std::fs::File::create(&planned.stderr_path);
297 let combined_file = std::fs::File::create(&planned.output_path)
298 .ok()
299 .map(|file| Arc::new(Mutex::new(file)));
300
301 let stdout_thread = handle.take_stdout().map(|out| {
302 spawn_output_drain(
303 out,
304 output_state.clone(),
305 planned.stdout_path.clone(),
306 combined_file.clone(),
307 true,
308 )
309 });
310 let stderr_thread = handle.take_stderr().map(|err| {
311 spawn_output_drain(
312 err,
313 output_state.clone(),
314 planned.stderr_path.clone(),
315 combined_file.clone(),
316 false,
317 )
318 });
319
320 let progress_thread = context
321 .progress_interval
322 .filter(|interval| !interval.is_zero())
323 .map(|interval| {
324 spawn_progress_thread(ProgressThreadContext {
325 command_id: context.command_id.clone(),
326 handle_id: context.handle_id.clone(),
327 session_id: context.session_id.clone(),
328 started_at: context.started_at.clone(),
329 command_display: context.command_display.clone(),
330 process_group_id: context.process_group_id,
331 output_path: planned.output_path.clone(),
332 stdout_path: planned.stdout_path.clone(),
333 stderr_path: planned.stderr_path.clone(),
334 output_state: output_state.clone(),
335 cancel_state: cancel_state.clone(),
336 done: done.clone(),
337 started: waiter_start,
338 interval,
339 max_inline_bytes: context.progress_max_inline_bytes,
340 })
341 });
342
343 let status = handle.wait().ok();
344
345 if let Some(thread) = stdout_thread {
346 let _ = thread.join();
347 }
348 if let Some(thread) = stderr_thread {
349 let _ = thread.join();
350 }
351 done.store(true, Ordering::Release);
352 drop(progress_thread);
353 let (stdout, stderr) = {
354 let state = output_state
355 .lock()
356 .unwrap_or_else(|poison| poison.into_inner());
357 (state.stdout.clone(), state.stderr.clone())
358 };
359
360 let completion_tx = {
363 let mut store = HANDLE_STORE
364 .lock()
365 .expect("long-running handle store poisoned");
366 store
367 .entries
368 .remove(&context.handle_id)
369 .and_then(|mut e| e.completion_tx.take())
370 };
371
372 let signal_done = move || {
373 if let Some(tx) = completion_tx {
374 let _ = tx.try_send(());
375 }
376 };
377
378 if cancel_state.cancelled.load(Ordering::Acquire) {
381 signal_done();
382 return;
383 }
384
385 let (exit_code, signal_name) = match status {
386 Some(s) => decode_exit_status(s),
387 None => (-1, Some("SIGKILL".to_string())),
389 };
390 let duration = waiter_start.elapsed();
391 let duration_ms = duration.as_millis() as i64;
392 let artifacts = match proc::persist_artifacts(
393 &context.command_id,
394 &stdout,
395 &stderr,
396 Some(&context.handle_id),
397 ) {
398 Ok(artifacts) => artifacts,
399 Err(_) => return,
400 };
401 let (inline_stdout, inline_stderr) = proc::inline_output(&stdout, &stderr, capture);
402
403 let mut payload = serde_json::Map::new();
404 payload.insert(
405 "command_id".into(),
406 serde_json::Value::String(context.command_id.clone()),
407 );
408 payload.insert(
409 "status".into(),
410 serde_json::Value::String(CommandStatus::Completed.as_str().to_string()),
411 );
412 payload.insert(
413 "handle_id".into(),
414 serde_json::Value::String(context.handle_id),
415 );
416 payload.insert(
417 "command_or_op_descriptor".into(),
418 serde_json::Value::String(context.command_display),
419 );
420 payload.insert(
421 "started_at".into(),
422 serde_json::Value::String(context.started_at),
423 );
424 payload.insert(
425 "ended_at".into(),
426 serde_json::Value::String(proc::now_rfc3339()),
427 );
428 payload.insert(
429 "duration_ms".into(),
430 serde_json::Value::Number(duration_ms.into()),
431 );
432 payload.insert(
433 "exit_code".into(),
434 serde_json::Value::Number(exit_code.into()),
435 );
436 payload.insert("stdout".into(), serde_json::Value::String(inline_stdout));
437 payload.insert("stderr".into(), serde_json::Value::String(inline_stderr));
438 payload.insert(
439 "output_path".into(),
440 serde_json::Value::String(artifacts.output_path.display().to_string()),
441 );
442 payload.insert(
443 "stdout_path".into(),
444 serde_json::Value::String(artifacts.stdout_path.display().to_string()),
445 );
446 payload.insert(
447 "stderr_path".into(),
448 serde_json::Value::String(artifacts.stderr_path.display().to_string()),
449 );
450 payload.insert(
451 "line_count".into(),
452 serde_json::Value::Number(artifacts.line_count.into()),
453 );
454 payload.insert(
455 "byte_count".into(),
456 serde_json::Value::Number(artifacts.byte_count.into()),
457 );
458 payload.insert(
459 "output_sha256".into(),
460 serde_json::Value::String(artifacts.output_sha256),
461 );
462 if let Some(pgid) = context.process_group_id {
463 payload.insert(
464 "process_group_id".into(),
465 serde_json::Value::Number((pgid as u64).into()),
466 );
467 }
468 if let Some(sig) = signal_name {
469 payload.insert("signal".into(), serde_json::Value::String(sig));
470 } else {
471 payload.insert("signal".into(), serde_json::Value::Null);
472 }
473
474 let content = serde_json::to_string(&payload).unwrap_or_default();
475 harn_vm::orchestration::agent_inbox::push(
476 &context.session_id,
477 "tool_result",
478 &content,
479 "hostlib.long_running.exit",
480 );
481 signal_done();
482}
483
484fn spawn_output_drain(
485 mut reader: Box<dyn Read + Send>,
486 state: Arc<Mutex<OutputState>>,
487 path: std::path::PathBuf,
488 combined_file: Option<Arc<Mutex<std::fs::File>>>,
489 stdout: bool,
490) -> std::thread::JoinHandle<()> {
491 std::thread::spawn(move || {
492 let mut file = std::fs::File::create(path).ok();
493 let mut buf = [0_u8; 8192];
494 loop {
495 let read = match reader.read(&mut buf) {
496 Ok(0) => break,
497 Ok(read) => read,
498 Err(_) => break,
499 };
500 let chunk = &buf[..read];
501 if let Some(file) = file.as_mut() {
502 let _ = file.write_all(chunk);
503 }
504 if let Some(combined) = combined_file.as_ref() {
505 if let Ok(mut combined) = combined.lock() {
506 let _ = combined.write_all(chunk);
507 }
508 }
509 if let Ok(mut state) = state.lock() {
510 if stdout {
511 state.stdout.extend_from_slice(chunk);
512 } else {
513 state.stderr.extend_from_slice(chunk);
514 }
515 }
516 }
517 })
518}
519
520fn spawn_progress_thread(context: ProgressThreadContext) -> std::thread::JoinHandle<()> {
521 std::thread::spawn(move || {
522 while !context.done.load(Ordering::Acquire)
523 && !context.cancel_state.cancelled.load(Ordering::Acquire)
524 {
525 std::thread::sleep(context.interval);
526 if context.done.load(Ordering::Acquire)
527 || context.cancel_state.cancelled.load(Ordering::Acquire)
528 {
529 break;
530 }
531 let (stdout, stderr) = {
532 let state = context
533 .output_state
534 .lock()
535 .unwrap_or_else(|poison| poison.into_inner());
536 (state.stdout.clone(), state.stderr.clone())
537 };
538 let capture = CaptureConfig {
539 max_inline_bytes: context.max_inline_bytes,
540 ..CaptureConfig::default()
541 };
542 let (inline_stdout, inline_stderr) = proc::inline_output(&stdout, &stderr, capture);
543 let byte_count = stdout.len().saturating_add(stderr.len());
544 let payload = serde_json::json!({
545 "command_id": &context.command_id,
546 "handle_id": &context.handle_id,
547 "status": CommandStatus::Running.as_str(),
548 "command_or_op_descriptor": &context.command_display,
549 "started_at": &context.started_at,
550 "ended_at": null,
551 "duration_ms": context.started.elapsed().as_millis() as i64,
552 "exit_code": null,
553 "signal": null,
554 "stdout": inline_stdout,
555 "stderr": inline_stderr,
556 "output_path": context.output_path.display().to_string(),
557 "stdout_path": context.stdout_path.display().to_string(),
558 "stderr_path": context.stderr_path.display().to_string(),
559 "byte_count": byte_count as i64,
560 "line_count": stdout.iter().chain(stderr.iter()).filter(|byte| **byte == b'\n').count() as i64,
561 "process_group_id": context.process_group_id,
562 });
563 harn_vm::orchestration::agent_inbox::push(
564 &context.session_id,
565 "tool_progress",
566 &payload.to_string(),
567 "hostlib.long_running.progress",
568 );
569 }
570 })
571}
572
573pub fn cancel_handle(handle_id: &str) -> bool {
576 let (handle_owned, killer, cancel_state, completion_tx) = {
577 let mut store = HANDLE_STORE
578 .lock()
579 .expect("long-running handle store poisoned");
580 match store.entries.remove(handle_id) {
581 None => return false,
582 Some(mut entry) => (
583 entry.handle.take(),
584 entry.killer.clone(),
585 entry.cancel_state.clone(),
586 entry.completion_tx.take(),
587 ),
588 }
589 };
590 do_kill(handle_owned, killer, cancel_state);
591 if let Some(tx) = completion_tx {
595 let _ = tx.try_send(());
596 }
597 true
598}
599
/// What `cancel_session_handles` extracts from a removed entry so the kill can
/// happen after the store lock is released: the process handle (if it was
/// still in the entry), the killer, and the shared cancellation flag.
type SessionKillEntry = (
    Option<Box<dyn ProcessHandle>>,
    Arc<dyn ProcessKiller>,
    Arc<CancelState>,
);
608
609pub fn cancel_session_handles(session_id: &str) {
612 let to_kill: Vec<SessionKillEntry> = {
613 let mut store = HANDLE_STORE
614 .lock()
615 .expect("long-running handle store poisoned");
616 let matching: Vec<String> = store
617 .entries
618 .iter()
619 .filter(|(_, e)| e.session_id == session_id)
620 .map(|(id, _)| id.clone())
621 .collect();
622 matching
623 .into_iter()
624 .filter_map(|id| {
625 store.entries.remove(&id).map(|mut e| {
626 let handle = e.handle.take();
627 (handle, e.killer.clone(), e.cancel_state.clone())
628 })
629 })
630 .collect()
631 };
632 for (handle, killer, cancel_state) in to_kill {
633 do_kill(handle, killer, cancel_state);
634 }
635}
636
/// Marks a handle as cancelled and kills its process.
///
/// `handle` is whatever process handle was still stored in the entry (it is
/// `None` once `waiter_thread` has taken it); it is dropped last.
fn do_kill(
    handle: Option<Box<dyn ProcessHandle>>,
    killer: Arc<dyn ProcessKiller>,
    cancel_state: Arc<CancelState>,
) {
    // The flag is stored before `kill()` is called. `waiter_thread` loads it
    // after the process has exited and, when set, returns without pushing a
    // result.
    cancel_state.cancelled.store(true, Ordering::Release);
    killer.kill();
    // Drop the handle only after the kill request has been issued.
    drop(handle);
}
652
653pub(crate) fn register_cleanup_hook() {
657 static REGISTERED: OnceLock<()> = OnceLock::new();
658 REGISTERED.get_or_init(|| {
659 let hook: Arc<dyn Fn(&str) + Send + Sync> = Arc::new(|session_id: &str| {
660 cancel_session_handles(session_id);
661 });
662 harn_vm::register_session_end_hook(hook);
663 });
664}
665
666fn decode_exit_status(status: process_handle::ExitStatus) -> (i32, Option<String>) {
667 if let Some(code) = status.code {
668 return (code, None);
669 }
670 if let Some(sig) = status.signal {
671 return (-1, Some(format!("SIG{sig}")));
672 }
673 (-1, None)
674}
675
676pub fn register_completion_notifier(handle_id: &str) -> Option<std::sync::mpsc::Receiver<()>> {
682 let (tx, rx) = std::sync::mpsc::sync_channel::<()>(1);
683 let mut store = HANDLE_STORE
684 .lock()
685 .expect("long-running handle store poisoned");
686 let entry = store.entries.get_mut(handle_id)?;
687 entry.completion_tx = Some(tx);
688 Some(rx)
689}