1use std::collections::BTreeMap;
36use std::io::{Read, Write};
37use std::path::PathBuf;
38use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
39use std::sync::{Arc, LazyLock, Mutex, OnceLock};
40use std::time::Duration;
41
42use harn_vm::VmValue;
43
44use crate::error::HostlibError;
45use crate::process::{self as process_handle, ProcessHandle, ProcessKiller, SpawnSpec};
46use crate::tools::proc::{self, CaptureConfig, CommandStatus, EnvMode};
47
48static HANDLE_COUNTER: AtomicU64 = AtomicU64::new(1);
50
51struct CancelState {
53 cancelled: AtomicBool,
56}
57
58#[derive(Default)]
59struct OutputState {
60 stdout: Vec<u8>,
61 stderr: Vec<u8>,
62}
63
64struct HandleEntry {
66 handle: Option<Box<dyn ProcessHandle>>,
68 killer: Arc<dyn ProcessKiller>,
70 session_id: String,
71 cancel_state: Arc<CancelState>,
73 completion_tx: Option<std::sync::mpsc::SyncSender<()>>,
77}
78
79#[derive(Default)]
80struct HandleStore {
81 entries: BTreeMap<String, HandleEntry>,
82}
83
84static HANDLE_STORE: LazyLock<Mutex<HandleStore>> =
85 LazyLock::new(|| Mutex::new(HandleStore::default()));
86
87pub struct LongRunningHandleInfo {
90 pub command_id: String,
92 pub handle_id: String,
94 pub started_at: String,
96 pub pid: u32,
98 pub process_group_id: Option<u32>,
100 pub command_display: String,
102}
103
104pub(crate) struct LongRunningSpawnOptions {
105 pub(crate) env_mode: EnvMode,
106 pub(crate) capture: CaptureConfig,
107 pub(crate) session_id: String,
108 pub(crate) progress_interval: Option<Duration>,
109 pub(crate) progress_max_inline_bytes: usize,
110}
111
112struct WaiterContext {
113 command_id: String,
114 handle_id: String,
115 session_id: String,
116 started_at: String,
117 process_group_id: Option<u32>,
118 command_display: String,
119 progress_interval: Option<Duration>,
120 progress_max_inline_bytes: usize,
121}
122
123struct ProgressThreadContext {
124 command_id: String,
125 handle_id: String,
126 session_id: String,
127 started_at: String,
128 command_display: String,
129 process_group_id: Option<u32>,
130 output_path: PathBuf,
131 stdout_path: PathBuf,
132 stderr_path: PathBuf,
133 output_state: Arc<Mutex<OutputState>>,
134 cancel_state: Arc<CancelState>,
135 done: Arc<AtomicBool>,
136 started: std::time::Instant,
137 interval: Duration,
138 max_inline_bytes: usize,
139}
140
141impl LongRunningHandleInfo {
142 pub fn into_handle_response(self) -> VmValue {
144 proc::running_response(
145 self.command_id,
146 self.handle_id,
147 self.pid,
148 self.process_group_id,
149 self.started_at,
150 self.command_display,
151 )
152 }
153}
154
155pub fn spawn_long_running(
160 builtin: &'static str,
161 program: String,
162 args: Vec<String>,
163 cwd: Option<PathBuf>,
164 env: BTreeMap<String, String>,
165 session_id: String,
166) -> Result<LongRunningHandleInfo, HostlibError> {
167 spawn_long_running_with_options(
168 builtin,
169 program,
170 args,
171 cwd,
172 env,
173 LongRunningSpawnOptions {
174 env_mode: EnvMode::InheritClean,
175 capture: CaptureConfig::default(),
176 session_id,
177 progress_interval: None,
178 progress_max_inline_bytes: CaptureConfig::default().max_inline_bytes,
179 },
180 )
181}
182
183pub(crate) fn spawn_long_running_with_options(
184 builtin: &'static str,
185 program: String,
186 args: Vec<String>,
187 cwd: Option<PathBuf>,
188 env: BTreeMap<String, String>,
189 options: LongRunningSpawnOptions,
190) -> Result<LongRunningHandleInfo, HostlibError> {
191 let spec = SpawnSpec {
192 builtin,
193 program: program.clone(),
194 args: args.clone(),
195 cwd,
196 env,
197 env_mode: options.env_mode,
198 use_stdin: false,
199 configure_process_group: true,
200 };
201 let handle = process_handle::spawn_process(spec)
202 .map_err(|e| proc::process_error_to_hostlib(builtin, e))?;
203
204 let pid = handle.pid().unwrap_or(0);
205 let process_group_id = handle.process_group_id();
206 let killer = handle.killer();
207 let id = HANDLE_COUNTER.fetch_add(1, Ordering::SeqCst);
208 let handle_id = format!("hto-{:x}-{id}", std::process::id());
209 let command_id = proc::next_command_id();
210 let started_at = proc::now_rfc3339();
211
212 let mut all_argv = vec![program.clone()];
213 all_argv.extend(args.iter().cloned());
214 let command_display = all_argv.join(" ");
215
216 let cancel_state = Arc::new(CancelState {
217 cancelled: AtomicBool::new(false),
218 });
219
220 {
221 let mut store = HANDLE_STORE
222 .lock()
223 .expect("long-running handle store poisoned");
224 store.entries.insert(
225 handle_id.clone(),
226 HandleEntry {
227 handle: Some(handle),
228 killer,
229 session_id: options.session_id.clone(),
230 cancel_state: cancel_state.clone(),
231 completion_tx: None,
232 },
233 );
234 }
235
236 let waiter_context = WaiterContext {
237 command_id: command_id.clone(),
238 handle_id: handle_id.clone(),
239 session_id: options.session_id,
240 started_at: started_at.clone(),
241 process_group_id,
242 command_display: command_display.clone(),
243 progress_interval: options.progress_interval,
244 progress_max_inline_bytes: options.progress_max_inline_bytes,
245 };
246 let waiter_thread_name = waiter_context.handle_id.clone();
247 let capture = options.capture;
248 std::thread::Builder::new()
249 .name(format!("hto-waiter-{waiter_thread_name}"))
250 .spawn(move || {
251 waiter_thread(waiter_context, cancel_state, capture);
252 })
253 .map_err(|e| HostlibError::Backend {
254 builtin,
255 message: format!("failed to spawn waiter thread: {e}"),
256 })?;
257
258 Ok(LongRunningHandleInfo {
259 command_id,
260 handle_id,
261 started_at,
262 pid,
263 process_group_id,
264 command_display,
265 })
266}
267
268fn waiter_thread(context: WaiterContext, cancel_state: Arc<CancelState>, capture: CaptureConfig) {
270 let waiter_start = std::time::Instant::now();
271
272 let mut handle = {
275 let mut store = HANDLE_STORE
276 .lock()
277 .expect("long-running handle store poisoned");
278 match store.entries.get_mut(&context.handle_id) {
279 Some(entry) => match entry.handle.take() {
280 Some(h) => h,
281 None => return, },
283 None => return, }
285 };
286
287 let output_state = Arc::new(Mutex::new(OutputState::default()));
288 let done = Arc::new(AtomicBool::new(false));
289 let planned = proc::planned_artifact_paths(&context.command_id);
290 if let Some(parent) = planned.output_path.parent() {
291 let _ = std::fs::create_dir_all(parent);
292 }
293 let _ = std::fs::File::create(&planned.stdout_path);
294 let _ = std::fs::File::create(&planned.stderr_path);
295 let combined_file = std::fs::File::create(&planned.output_path)
296 .ok()
297 .map(|file| Arc::new(Mutex::new(file)));
298
299 let stdout_thread = handle.take_stdout().map(|out| {
300 spawn_output_drain(
301 out,
302 output_state.clone(),
303 planned.stdout_path.clone(),
304 combined_file.clone(),
305 true,
306 )
307 });
308 let stderr_thread = handle.take_stderr().map(|err| {
309 spawn_output_drain(
310 err,
311 output_state.clone(),
312 planned.stderr_path.clone(),
313 combined_file.clone(),
314 false,
315 )
316 });
317
318 let progress_thread = context
319 .progress_interval
320 .filter(|interval| !interval.is_zero())
321 .map(|interval| {
322 spawn_progress_thread(ProgressThreadContext {
323 command_id: context.command_id.clone(),
324 handle_id: context.handle_id.clone(),
325 session_id: context.session_id.clone(),
326 started_at: context.started_at.clone(),
327 command_display: context.command_display.clone(),
328 process_group_id: context.process_group_id,
329 output_path: planned.output_path.clone(),
330 stdout_path: planned.stdout_path.clone(),
331 stderr_path: planned.stderr_path.clone(),
332 output_state: output_state.clone(),
333 cancel_state: cancel_state.clone(),
334 done: done.clone(),
335 started: waiter_start,
336 interval,
337 max_inline_bytes: context.progress_max_inline_bytes,
338 })
339 });
340
341 let status = handle.wait().ok();
342
343 if let Some(thread) = stdout_thread {
344 let _ = thread.join();
345 }
346 if let Some(thread) = stderr_thread {
347 let _ = thread.join();
348 }
349 done.store(true, Ordering::Release);
350 drop(progress_thread);
351 let (stdout, stderr) = {
352 let state = output_state
353 .lock()
354 .unwrap_or_else(|poison| poison.into_inner());
355 (state.stdout.clone(), state.stderr.clone())
356 };
357
358 let completion_tx = {
361 let mut store = HANDLE_STORE
362 .lock()
363 .expect("long-running handle store poisoned");
364 store
365 .entries
366 .remove(&context.handle_id)
367 .and_then(|mut e| e.completion_tx.take())
368 };
369
370 let signal_done = move || {
371 if let Some(tx) = completion_tx {
372 let _ = tx.try_send(());
373 }
374 };
375
376 if cancel_state.cancelled.load(Ordering::Acquire) {
379 signal_done();
380 return;
381 }
382
383 let (exit_code, signal_name) = match status {
384 Some(s) => decode_exit_status(s),
385 None => (-1, Some("SIGKILL".to_string())),
387 };
388 let duration = waiter_start.elapsed();
389 let duration_ms = duration.as_millis() as i64;
390 let artifacts = match proc::persist_artifacts(
391 &context.command_id,
392 &stdout,
393 &stderr,
394 Some(&context.handle_id),
395 ) {
396 Ok(artifacts) => artifacts,
397 Err(_) => return,
398 };
399 let (inline_stdout, inline_stderr) = proc::inline_output(&stdout, &stderr, capture);
400
401 let mut payload = serde_json::Map::new();
402 payload.insert(
403 "command_id".into(),
404 serde_json::Value::String(context.command_id.clone()),
405 );
406 payload.insert(
407 "status".into(),
408 serde_json::Value::String(CommandStatus::Completed.as_str().to_string()),
409 );
410 payload.insert(
411 "handle_id".into(),
412 serde_json::Value::String(context.handle_id),
413 );
414 payload.insert(
415 "command_or_op_descriptor".into(),
416 serde_json::Value::String(context.command_display),
417 );
418 payload.insert(
419 "started_at".into(),
420 serde_json::Value::String(context.started_at),
421 );
422 payload.insert(
423 "ended_at".into(),
424 serde_json::Value::String(proc::now_rfc3339()),
425 );
426 payload.insert(
427 "duration_ms".into(),
428 serde_json::Value::Number(duration_ms.into()),
429 );
430 payload.insert(
431 "exit_code".into(),
432 serde_json::Value::Number(exit_code.into()),
433 );
434 payload.insert("stdout".into(), serde_json::Value::String(inline_stdout));
435 payload.insert("stderr".into(), serde_json::Value::String(inline_stderr));
436 payload.insert(
437 "output_path".into(),
438 serde_json::Value::String(artifacts.output_path.display().to_string()),
439 );
440 payload.insert(
441 "stdout_path".into(),
442 serde_json::Value::String(artifacts.stdout_path.display().to_string()),
443 );
444 payload.insert(
445 "stderr_path".into(),
446 serde_json::Value::String(artifacts.stderr_path.display().to_string()),
447 );
448 payload.insert(
449 "line_count".into(),
450 serde_json::Value::Number(artifacts.line_count.into()),
451 );
452 payload.insert(
453 "byte_count".into(),
454 serde_json::Value::Number(artifacts.byte_count.into()),
455 );
456 payload.insert(
457 "output_sha256".into(),
458 serde_json::Value::String(artifacts.output_sha256),
459 );
460 if let Some(pgid) = context.process_group_id {
461 payload.insert(
462 "process_group_id".into(),
463 serde_json::Value::Number((pgid as u64).into()),
464 );
465 }
466 if let Some(sig) = signal_name {
467 payload.insert("signal".into(), serde_json::Value::String(sig));
468 } else {
469 payload.insert("signal".into(), serde_json::Value::Null);
470 }
471
472 let content = serde_json::to_string(&payload).unwrap_or_default();
473 harn_vm::push_pending_feedback_global(&context.session_id, "tool_result", &content);
474 signal_done();
475}
476
477fn spawn_output_drain(
478 mut reader: Box<dyn Read + Send>,
479 state: Arc<Mutex<OutputState>>,
480 path: std::path::PathBuf,
481 combined_file: Option<Arc<Mutex<std::fs::File>>>,
482 stdout: bool,
483) -> std::thread::JoinHandle<()> {
484 std::thread::spawn(move || {
485 let mut file = std::fs::File::create(path).ok();
486 let mut buf = [0_u8; 8192];
487 loop {
488 let read = match reader.read(&mut buf) {
489 Ok(0) => break,
490 Ok(read) => read,
491 Err(_) => break,
492 };
493 let chunk = &buf[..read];
494 if let Some(file) = file.as_mut() {
495 let _ = file.write_all(chunk);
496 }
497 if let Some(combined) = combined_file.as_ref() {
498 if let Ok(mut combined) = combined.lock() {
499 let _ = combined.write_all(chunk);
500 }
501 }
502 if let Ok(mut state) = state.lock() {
503 if stdout {
504 state.stdout.extend_from_slice(chunk);
505 } else {
506 state.stderr.extend_from_slice(chunk);
507 }
508 }
509 }
510 })
511}
512
513fn spawn_progress_thread(context: ProgressThreadContext) -> std::thread::JoinHandle<()> {
514 std::thread::spawn(move || {
515 while !context.done.load(Ordering::Acquire)
516 && !context.cancel_state.cancelled.load(Ordering::Acquire)
517 {
518 std::thread::sleep(context.interval);
519 if context.done.load(Ordering::Acquire)
520 || context.cancel_state.cancelled.load(Ordering::Acquire)
521 {
522 break;
523 }
524 let (stdout, stderr) = {
525 let state = context
526 .output_state
527 .lock()
528 .unwrap_or_else(|poison| poison.into_inner());
529 (state.stdout.clone(), state.stderr.clone())
530 };
531 let capture = CaptureConfig {
532 max_inline_bytes: context.max_inline_bytes,
533 ..CaptureConfig::default()
534 };
535 let (inline_stdout, inline_stderr) = proc::inline_output(&stdout, &stderr, capture);
536 let byte_count = stdout.len().saturating_add(stderr.len());
537 let payload = serde_json::json!({
538 "command_id": &context.command_id,
539 "handle_id": &context.handle_id,
540 "status": CommandStatus::Running.as_str(),
541 "command_or_op_descriptor": &context.command_display,
542 "started_at": &context.started_at,
543 "ended_at": null,
544 "duration_ms": context.started.elapsed().as_millis() as i64,
545 "exit_code": null,
546 "signal": null,
547 "stdout": inline_stdout,
548 "stderr": inline_stderr,
549 "output_path": context.output_path.display().to_string(),
550 "stdout_path": context.stdout_path.display().to_string(),
551 "stderr_path": context.stderr_path.display().to_string(),
552 "byte_count": byte_count as i64,
553 "line_count": stdout.iter().chain(stderr.iter()).filter(|byte| **byte == b'\n').count() as i64,
554 "process_group_id": context.process_group_id,
555 });
556 harn_vm::push_pending_feedback_global(
557 &context.session_id,
558 "tool_progress",
559 &payload.to_string(),
560 );
561 }
562 })
563}
564
565pub fn cancel_handle(handle_id: &str) -> bool {
568 let (handle_owned, killer, cancel_state, completion_tx) = {
569 let mut store = HANDLE_STORE
570 .lock()
571 .expect("long-running handle store poisoned");
572 match store.entries.remove(handle_id) {
573 None => return false,
574 Some(mut entry) => (
575 entry.handle.take(),
576 entry.killer.clone(),
577 entry.cancel_state.clone(),
578 entry.completion_tx.take(),
579 ),
580 }
581 };
582 do_kill(handle_owned, killer, cancel_state);
583 if let Some(tx) = completion_tx {
587 let _ = tx.try_send(());
588 }
589 true
590}
591
592type SessionKillEntry = (
596 Option<Box<dyn ProcessHandle>>,
597 Arc<dyn ProcessKiller>,
598 Arc<CancelState>,
599);
600
601pub fn cancel_session_handles(session_id: &str) {
604 let to_kill: Vec<SessionKillEntry> = {
605 let mut store = HANDLE_STORE
606 .lock()
607 .expect("long-running handle store poisoned");
608 let matching: Vec<String> = store
609 .entries
610 .iter()
611 .filter(|(_, e)| e.session_id == session_id)
612 .map(|(id, _)| id.clone())
613 .collect();
614 matching
615 .into_iter()
616 .filter_map(|id| {
617 store.entries.remove(&id).map(|mut e| {
618 let handle = e.handle.take();
619 (handle, e.killer.clone(), e.cancel_state.clone())
620 })
621 })
622 .collect()
623 };
624 for (handle, killer, cancel_state) in to_kill {
625 do_kill(handle, killer, cancel_state);
626 }
627}
628
629fn do_kill(
632 handle: Option<Box<dyn ProcessHandle>>,
633 killer: Arc<dyn ProcessKiller>,
634 cancel_state: Arc<CancelState>,
635) {
636 cancel_state.cancelled.store(true, Ordering::Release);
638 killer.kill();
642 drop(handle);
643}
644
645pub(crate) fn register_cleanup_hook() {
649 static REGISTERED: OnceLock<()> = OnceLock::new();
650 REGISTERED.get_or_init(|| {
651 let hook: Arc<dyn Fn(&str) + Send + Sync> = Arc::new(|session_id: &str| {
652 cancel_session_handles(session_id);
653 });
654 harn_vm::register_session_end_hook(hook);
655 });
656}
657
658fn decode_exit_status(status: process_handle::ExitStatus) -> (i32, Option<String>) {
659 if let Some(code) = status.code {
660 return (code, None);
661 }
662 if let Some(sig) = status.signal {
663 return (-1, Some(format!("SIG{sig}")));
664 }
665 (-1, None)
666}
667
668pub fn register_completion_notifier(handle_id: &str) -> Option<std::sync::mpsc::Receiver<()>> {
674 let (tx, rx) = std::sync::mpsc::sync_channel::<()>(1);
675 let mut store = HANDLE_STORE
676 .lock()
677 .expect("long-running handle store poisoned");
678 let entry = store.entries.get_mut(handle_id)?;
679 entry.completion_tx = Some(tx);
680 Some(rx)
681}