1use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
8use std::future::Future;
9use std::io::Write;
10use std::path::{Path, PathBuf};
11use std::pin::Pin;
12use std::rc::Rc;
13use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
14use std::sync::Arc;
15use std::time::Duration;
16
17use tokio::io::AsyncBufReadExt;
18use tokio::sync::{oneshot, Mutex, Notify};
19
20use harn_parser::diagnostic_codes::Code;
21
22use crate::orchestration::MutationSessionRecord;
23use crate::value::{ErrorCategory, VmClosure, VmError, VmValue};
24use crate::visible_text::VisibleTextState;
25use crate::vm::Vm;
26
/// How long `HostBridge::call` waits for the host to answer a request before
/// giving up with a timeout error (five minutes).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5 * 60);

/// Sink that emits one outgoing bridge line at a time; on failure it returns a
/// human-readable error message.
pub type HostBridgeWriter = Arc<dyn Fn(&str) -> Result<(), String> + Send + Sync>;

/// Build a [`HostBridgeWriter`] that writes each line, plus a trailing newline,
/// to process stdout and flushes immediately.
///
/// All writers sharing `stdout_lock` are serialized so their lines cannot
/// interleave. A poisoned lock is recovered rather than propagated, because the
/// guarded data is just `()`.
fn stdout_writer(stdout_lock: Arc<std::sync::Mutex<()>>) -> HostBridgeWriter {
    Arc::new(move |line: &str| {
        let _serialize = stdout_lock
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let mut out = std::io::stdout().lock();
        let write_err = |e: std::io::Error| format!("Bridge write error: {e}");
        out.write_all(line.as_bytes()).map_err(write_err)?;
        out.write_all(b"\n").map_err(write_err)?;
        out.flush().map_err(|e| format!("Bridge flush error: {e}"))
    })
}
48
49pub struct HostBridge {
56 next_id: AtomicU64,
57 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
59 cancelled: Arc<AtomicBool>,
61 cancel_notify: Arc<Notify>,
63 writer: HostBridgeWriter,
65 session_id: std::sync::Mutex<String>,
67 script_name: std::sync::Mutex<String>,
69 queued_transcript_injections: HostBridgeInjectionState,
71 resume_requested: Arc<AtomicBool>,
73 skills_reload_requested: Arc<AtomicBool>,
78 daemon_idle: Arc<AtomicBool>,
80 prompt_stop_reason: std::sync::Mutex<Option<String>>,
86 visible_call_states: std::sync::Mutex<HashMap<String, VisibleTextState>>,
88 visible_call_streams: std::sync::Mutex<HashMap<String, bool>>,
90 in_process: Option<InProcessHost>,
92}
93
94struct InProcessHost {
95 module_path: PathBuf,
96 exported_functions: BTreeMap<String, Rc<VmClosure>>,
97 vm: Vm,
98}
99
100impl InProcessHost {
101 fn dispatch<'a>(
107 &'a self,
108 method: &'a str,
109 params: serde_json::Value,
110 ) -> Pin<Box<dyn Future<Output = Result<serde_json::Value, VmError>> + 'a>> {
111 Box::pin(async move {
112 match method {
113 "builtin_call" => {
114 let name = params
115 .get("name")
116 .and_then(|value| value.as_str())
117 .unwrap_or_default();
118 let args = params
119 .get("args")
120 .and_then(|value| value.as_array())
121 .cloned()
122 .unwrap_or_default()
123 .into_iter()
124 .map(|value| json_result_to_vm_value(&value))
125 .collect::<Vec<_>>();
126 self.invoke_export(name, &args).await
127 }
128 "host/tools/list" => self
129 .invoke_optional_export("host_tools_list", &[])
130 .await
131 .map(|value| value.unwrap_or_else(|| serde_json::json!({ "tools": [] }))),
132 "session/request_permission" => self.request_permission(params).await,
133 other => Err(VmError::Runtime(format!(
134 "playground host backend does not implement bridge method '{other}'"
135 ))),
136 }
137 })
138 }
139
140 async fn invoke_export(
141 &self,
142 name: &str,
143 args: &[VmValue],
144 ) -> Result<serde_json::Value, VmError> {
145 let Some(closure) = self.exported_functions.get(name) else {
146 return Err(VmError::Runtime(format!(
147 "Playground host is missing capability '{name}'. Define `pub fn {name}(...)` in {}",
148 self.module_path.display()
149 )));
150 };
151
152 let mut vm = self.vm.child_vm_for_host();
153 let result = vm.call_closure_pub(closure, args).await?;
154 Ok(crate::llm::vm_value_to_json(&result))
155 }
156
157 async fn invoke_optional_export(
158 &self,
159 name: &str,
160 args: &[VmValue],
161 ) -> Result<Option<serde_json::Value>, VmError> {
162 if !self.exported_functions.contains_key(name) {
163 return Ok(None);
164 }
165 self.invoke_export(name, args).await.map(Some)
166 }
167
168 async fn request_permission(
169 &self,
170 params: serde_json::Value,
171 ) -> Result<serde_json::Value, VmError> {
172 let Some(closure) = self.exported_functions.get("request_permission") else {
173 return Ok(serde_json::json!({ "granted": true }));
174 };
175
176 let tool_name = params
177 .get("toolCall")
178 .and_then(|tool_call| tool_call.get("toolName"))
179 .and_then(|value| value.as_str())
180 .unwrap_or_default();
181 let tool_args = params
182 .get("toolCall")
183 .and_then(|tool_call| tool_call.get("rawInput"))
184 .map(json_result_to_vm_value)
185 .unwrap_or(VmValue::Nil);
186 let full_payload = json_result_to_vm_value(¶ms);
187
188 let arg_count = closure.func.params.len();
189 let args = if arg_count >= 3 {
190 vec![
191 VmValue::String(Rc::from(tool_name.to_string())),
192 tool_args,
193 full_payload,
194 ]
195 } else if arg_count == 2 {
196 vec![VmValue::String(Rc::from(tool_name.to_string())), tool_args]
197 } else if arg_count == 1 {
198 vec![full_payload]
199 } else {
200 Vec::new()
201 };
202
203 let mut vm = self.vm.child_vm_for_host();
204 let result = vm.call_closure_pub(closure, &args).await?;
205 let payload = match result {
206 VmValue::Bool(granted) => serde_json::json!({ "granted": granted }),
207 VmValue::String(reason) if !reason.is_empty() => {
208 serde_json::json!({ "granted": false, "reason": reason.to_string() })
209 }
210 other => {
211 let json = crate::llm::vm_value_to_json(&other);
212 if json
213 .get("granted")
214 .and_then(|value| value.as_bool())
215 .is_some()
216 || json.get("outcome").is_some()
217 {
218 json
219 } else {
220 serde_json::json!({ "granted": other.is_truthy() })
221 }
222 }
223 };
224 Ok(payload)
225 }
226}
227
/// Delivery mode of a queued transcript injection: which drain checkpoint may
/// pick it up.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QueuedUserMessageMode {
    /// Drained at `DeliveryCheckpoint::InterruptImmediate`.
    InterruptImmediate,
    /// Drained at `DeliveryCheckpoint::AfterCurrentOperation`.
    FinishStep,
    /// Drained at `DeliveryCheckpoint::EndOfInteraction`. This is also the
    /// fallback for unrecognized mode names.
    WaitForCompletion,
}

/// Point in an agent run at which queued injections are drained. Each checkpoint
/// selects exactly one [`QueuedUserMessageMode`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DeliveryCheckpoint {
    InterruptImmediate,
    AfterCurrentOperation,
    EndOfInteraction,
}

impl QueuedUserMessageMode {
    /// Parse a wire-level mode name.
    ///
    /// * `interrupt_immediate` and `interrupt` map to `InterruptImmediate`.
    /// * `finish_step` and `after_current_operation` map to `FinishStep`.
    /// * Any other string maps to `WaitForCompletion`.
    fn from_str(value: &str) -> Self {
        const INTERRUPT_NAMES: [&str; 2] = ["interrupt_immediate", "interrupt"];
        const FINISH_NAMES: [&str; 2] = ["finish_step", "after_current_operation"];

        if INTERRUPT_NAMES.contains(&value) {
            Self::InterruptImmediate
        } else if FINISH_NAMES.contains(&value) {
            Self::FinishStep
        } else {
            Self::WaitForCompletion
        }
    }
}
251
252#[derive(Clone, Debug, PartialEq, Eq)]
253pub struct QueuedUserMessage {
254 pub message_id: String,
255 pub content: String,
256 pub transcript_content: serde_json::Value,
257 pub mode: QueuedUserMessageMode,
258}
259
260#[derive(Clone, Debug, PartialEq, Eq)]
261pub struct QueuedReminder {
262 pub reminder: crate::llm::helpers::SystemReminder,
263 pub mode: QueuedUserMessageMode,
264}
265
266#[derive(Clone, Debug, PartialEq, Eq)]
267pub enum QueuedTranscriptInjection {
268 User(QueuedUserMessage),
269 Reminder(QueuedReminder),
270}
271
272#[derive(Debug, Default)]
273struct QueuedTranscriptInjections {
274 queue: VecDeque<QueuedTranscriptInjection>,
275 revoked_user_message_ids: HashSet<String>,
276 delivered_user_message_ids: HashSet<String>,
277}
278
279#[derive(Clone, Debug, Default)]
280pub struct HostBridgeInjectionState {
281 inner: Arc<Mutex<QueuedTranscriptInjections>>,
282}
283
/// Outcome of revoking or replacing a queued user message by id.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PendingUserMessageMutationResult {
    /// The message was still queued and has been changed or removed.
    Mutated,
    /// The message had already been revoked earlier.
    AlreadyRevoked,
    /// The message had already been taken for delivery.
    AlreadyDelivered,
    /// No message with this id was ever seen by this queue.
    UnknownMessageId,
}
291
292impl QueuedTranscriptInjection {
293 fn mode(&self) -> QueuedUserMessageMode {
294 match self {
295 Self::User(message) => message.mode,
296 Self::Reminder(reminder) => reminder.mode,
297 }
298 }
299}
300
301fn new_inject_message_id() -> String {
302 format!("msg_inj_{}", uuid::Uuid::now_v7().simple())
303}
304
305impl HostBridgeInjectionState {
306 pub fn new() -> Self {
307 Self::default()
308 }
309
310 pub async fn push_pending_user_message(
311 &self,
312 content: String,
313 transcript_content: serde_json::Value,
314 mode: &str,
315 ) -> String {
316 let message_id = new_inject_message_id();
317 self.inner
318 .lock()
319 .await
320 .queue
321 .push_back(QueuedTranscriptInjection::User(QueuedUserMessage {
322 message_id: message_id.clone(),
323 content,
324 transcript_content,
325 mode: QueuedUserMessageMode::from_str(mode),
326 }));
327 message_id
328 }
329
330 pub async fn revoke_pending_user_message(
331 &self,
332 message_id: &str,
333 ) -> PendingUserMessageMutationResult {
334 let mut state = self.inner.lock().await;
335 let mut retained = VecDeque::new();
336 let mut revoked = false;
337 while let Some(injection) = state.queue.pop_front() {
338 match &injection {
339 QueuedTranscriptInjection::User(message) if message.message_id == message_id => {
340 revoked = true;
341 }
342 _ => retained.push_back(injection),
343 }
344 }
345 state.queue = retained;
346 if revoked {
347 state
348 .revoked_user_message_ids
349 .insert(message_id.to_string());
350 return PendingUserMessageMutationResult::Mutated;
351 }
352 if state.revoked_user_message_ids.contains(message_id) {
353 PendingUserMessageMutationResult::AlreadyRevoked
354 } else if state.delivered_user_message_ids.contains(message_id) {
355 PendingUserMessageMutationResult::AlreadyDelivered
356 } else {
357 PendingUserMessageMutationResult::UnknownMessageId
358 }
359 }
360
361 pub async fn replace_pending_user_message(
362 &self,
363 message_id: &str,
364 content: String,
365 transcript_content: serde_json::Value,
366 ) -> PendingUserMessageMutationResult {
367 let mut state = self.inner.lock().await;
368 for injection in &mut state.queue {
369 if let QueuedTranscriptInjection::User(message) = injection {
370 if message.message_id == message_id {
371 message.content = content;
372 message.transcript_content = transcript_content;
373 return PendingUserMessageMutationResult::Mutated;
374 }
375 }
376 }
377 if state.revoked_user_message_ids.contains(message_id) {
378 PendingUserMessageMutationResult::AlreadyRevoked
379 } else if state.delivered_user_message_ids.contains(message_id) {
380 PendingUserMessageMutationResult::AlreadyDelivered
381 } else {
382 PendingUserMessageMutationResult::UnknownMessageId
383 }
384 }
385
386 async fn push_session_reminder(&self, reminder: QueuedReminder) {
387 self.inner
388 .lock()
389 .await
390 .queue
391 .push_back(QueuedTranscriptInjection::Reminder(reminder));
392 }
393}
394
395fn reminder_unknown_option_error(message: impl AsRef<str>) -> String {
396 format!(
397 "{}: {}",
398 Code::ReminderUnknownOption.as_str(),
399 message.as_ref()
400 )
401}
402
403fn session_remind_shape_error(message: impl AsRef<str>) -> String {
404 format!(
405 "{}: {}",
406 Code::ReminderInvalidShape.as_str(),
407 message.as_ref()
408 )
409}
410
411fn reminder_unknown_propagate_error(message: impl AsRef<str>) -> String {
412 format!(
413 "{}: {}",
414 Code::ReminderUnknownPropagate.as_str(),
415 message.as_ref()
416 )
417}
418
419fn string_field(
420 map: &serde_json::Map<String, serde_json::Value>,
421 key: &str,
422 required: bool,
423) -> Result<Option<String>, String> {
424 match map.get(key) {
425 None | Some(serde_json::Value::Null) if required => Err(session_remind_shape_error(
426 format!("`{key}` must be a non-empty string"),
427 )),
428 None | Some(serde_json::Value::Null) => Ok(None),
429 Some(serde_json::Value::String(value)) if required && value.trim().is_empty() => Err(
430 session_remind_shape_error(format!("`{key}` must be a non-empty string")),
431 ),
432 Some(serde_json::Value::String(value)) => {
433 let trimmed = value.trim();
434 if trimmed.is_empty() {
435 Ok(None)
436 } else {
437 Ok(Some(trimmed.to_string()))
438 }
439 }
440 Some(other) => Err(session_remind_shape_error(format!(
441 "`{key}` must be a string, got {other}"
442 ))),
443 }
444}
445
446fn bool_field(
447 map: &serde_json::Map<String, serde_json::Value>,
448 key: &str,
449) -> Result<Option<bool>, String> {
450 match map.get(key) {
451 None | Some(serde_json::Value::Null) => Ok(None),
452 Some(serde_json::Value::Bool(value)) => Ok(Some(*value)),
453 Some(other) => Err(session_remind_shape_error(format!(
454 "`{key}` must be a bool, got {other}"
455 ))),
456 }
457}
458
459fn int_field(
460 map: &serde_json::Map<String, serde_json::Value>,
461 key: &str,
462) -> Result<Option<i64>, String> {
463 match map.get(key) {
464 None | Some(serde_json::Value::Null) => Ok(None),
465 Some(serde_json::Value::Number(value)) => {
466 let Some(value) = value.as_i64() else {
467 return Err(session_remind_shape_error(format!(
468 "`{key}` must be an integer"
469 )));
470 };
471 Ok(Some(value))
472 }
473 Some(other) => Err(session_remind_shape_error(format!(
474 "`{key}` must be an int, got {other}"
475 ))),
476 }
477}
478
479fn tags_field(map: &serde_json::Map<String, serde_json::Value>) -> Result<Vec<String>, String> {
480 let Some(value) = map.get("tags") else {
481 return Ok(Vec::new());
482 };
483 if value.is_null() {
484 return Ok(Vec::new());
485 }
486 let Some(values) = value.as_array() else {
487 return Err(session_remind_shape_error("`tags` must be a list"));
488 };
489 let mut tags = Vec::new();
490 for value in values {
491 let Some(tag) = value.as_str() else {
492 return Err(session_remind_shape_error(format!(
493 "`tags` entries must be strings, got {value}"
494 )));
495 };
496 let tag = tag.trim();
497 if tag.is_empty() {
498 return Err(session_remind_shape_error(
499 "`tags` entries must be non-empty strings",
500 ));
501 }
502 if !tags.iter().any(|existing| existing == tag) {
503 tags.push(tag.to_string());
504 }
505 }
506 Ok(tags)
507}
508
509fn session_remind_payload_from_value(
510 value: &serde_json::Value,
511) -> Result<crate::llm::helpers::SystemReminder, String> {
512 let Some(map) = value.as_object() else {
513 return Err(session_remind_shape_error(
514 "session/remind payload must be a reminder object",
515 ));
516 };
517 const ALLOWED: &[&str] = &[
518 "_meta",
519 "body",
520 "dedupe_key",
521 "fired_at_turn",
522 "id",
523 "preserve_on_compact",
524 "propagate",
525 "role_hint",
526 "source",
527 "tags",
528 "ttl_turns",
529 ];
530 let unknown = map
531 .keys()
532 .filter(|key| !ALLOWED.contains(&key.as_str()))
533 .map(String::as_str)
534 .collect::<Vec<_>>();
535 if !unknown.is_empty() {
536 if unknown.contains(&"content") {
537 return Err(session_remind_shape_error(
538 "session/remind expects reminder `body`, not user-message `content`",
539 ));
540 }
541 return Err(reminder_unknown_option_error(format!(
542 "unknown reminder option(s): {}",
543 unknown.join(", ")
544 )));
545 }
546 if let Some(meta) = map.get("_meta") {
547 if !meta.is_null() && !meta.is_object() {
548 return Err(session_remind_shape_error("`_meta` must be an object"));
549 }
550 }
551 let ttl_turns = int_field(map, "ttl_turns")?;
552 if let Some(value) = ttl_turns {
553 if value <= 0 {
554 return Err(session_remind_shape_error("`ttl_turns` must be > 0"));
555 }
556 }
557 let fired_at_turn = int_field(map, "fired_at_turn")?.unwrap_or(0);
558 if fired_at_turn < 0 {
559 return Err(session_remind_shape_error(
560 "`fired_at_turn` must be >= 0 when provided",
561 ));
562 }
563 match string_field(map, "source", false)?.as_deref() {
564 None | Some("bridge") => {}
565 Some(_) => {
566 return Err(session_remind_shape_error(
567 "`source` for session/remind must be bridge when provided",
568 ))
569 }
570 }
571 let propagate = match string_field(map, "propagate", false)?.as_deref() {
572 None => crate::llm::helpers::ReminderPropagate::Session,
573 Some("all") => crate::llm::helpers::ReminderPropagate::All,
574 Some("session") => crate::llm::helpers::ReminderPropagate::Session,
575 Some("none") => crate::llm::helpers::ReminderPropagate::None,
576 Some(_) => {
577 return Err(reminder_unknown_propagate_error(
578 "`propagate` must be one of all, session, or none",
579 ))
580 }
581 };
582 let role_hint = match string_field(map, "role_hint", false)?.as_deref() {
583 None => crate::llm::helpers::ReminderRoleHint::System,
584 Some("system") => crate::llm::helpers::ReminderRoleHint::System,
585 Some("developer") => crate::llm::helpers::ReminderRoleHint::Developer,
586 Some("user_block") => crate::llm::helpers::ReminderRoleHint::UserBlock,
587 Some("ephemeral_cache") => crate::llm::helpers::ReminderRoleHint::EphemeralCache,
588 Some(_) => {
589 return Err(session_remind_shape_error(
590 "`role_hint` must be one of system, developer, user_block, or ephemeral_cache",
591 ))
592 }
593 };
594 Ok(crate::llm::helpers::SystemReminder {
595 id: string_field(map, "id", false)?.unwrap_or_else(|| uuid::Uuid::now_v7().to_string()),
596 tags: tags_field(map)?,
597 dedupe_key: string_field(map, "dedupe_key", false)?,
598 ttl_turns,
599 preserve_on_compact: bool_field(map, "preserve_on_compact")?.unwrap_or(false),
600 propagate,
601 role_hint,
602 source: crate::llm::helpers::ReminderSource::Bridge,
603 body: string_field(map, "body", true)?.unwrap_or_default(),
604 fired_at_turn,
605 originating_agent_id: None,
606 })
607}
608
609fn queued_session_remind_from_params(params: &serde_json::Value) -> Result<QueuedReminder, String> {
610 let mode = QueuedUserMessageMode::from_str(
611 params
612 .get("mode")
613 .and_then(|value| value.as_str())
614 .unwrap_or("wait_for_completion"),
615 );
616 let reminder_value = if let Some(reminder) = params.get("reminder") {
617 reminder.clone()
618 } else {
619 let Some(params) = params.as_object() else {
620 return Err(session_remind_shape_error(
621 "session/remind params must be an object",
622 ));
623 };
624 let mut reminder = params.clone();
625 reminder.remove("mode");
626 reminder.remove("sessionId");
627 reminder.remove("session_id");
628 serde_json::Value::Object(reminder)
629 };
630 Ok(QueuedReminder {
631 reminder: session_remind_payload_from_value(&reminder_value)?,
632 mode,
633 })
634}
635
636#[allow(clippy::new_without_default)]
638impl HostBridge {
639 pub fn new() -> Self {
644 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
645 Arc::new(Mutex::new(HashMap::new()));
646 let cancelled = Arc::new(AtomicBool::new(false));
647 let cancel_notify = Arc::new(Notify::new());
648 let queued_transcript_injections = HostBridgeInjectionState::default();
649 let resume_requested = Arc::new(AtomicBool::new(false));
650 let skills_reload_requested = Arc::new(AtomicBool::new(false));
651 let daemon_idle = Arc::new(AtomicBool::new(false));
652
653 let pending_clone = pending.clone();
655 let cancelled_clone = cancelled.clone();
656 let cancel_notify_clone = cancel_notify.clone();
657 let queued_clone = queued_transcript_injections.clone();
658 let resume_clone = resume_requested.clone();
659 let skills_reload_clone = skills_reload_requested.clone();
660 tokio::task::spawn_local(async move {
661 let stdin = tokio::io::stdin();
662 let reader = tokio::io::BufReader::new(stdin);
663 let mut lines = reader.lines();
664
665 while let Ok(Some(line)) = lines.next_line().await {
666 let line = line.trim().to_string();
667 if line.is_empty() {
668 continue;
669 }
670
671 let msg: serde_json::Value = match serde_json::from_str(&line) {
672 Ok(v) => v,
673 Err(_) => continue,
674 };
675
676 if msg.get("id").is_none() {
678 if let Some(method) = msg["method"].as_str() {
679 if method == "cancel" {
680 cancelled_clone.store(true, Ordering::SeqCst);
681 cancel_notify_clone.notify_waiters();
682 } else if method == "agent/resume" {
683 resume_clone.store(true, Ordering::SeqCst);
684 } else if method == "skills/update" {
685 skills_reload_clone.store(true, Ordering::SeqCst);
686 } else if method == "session/remind" {
687 let params = &msg["params"];
688 if let Ok(reminder) = queued_session_remind_from_params(params) {
689 queued_clone.push_session_reminder(reminder).await;
690 }
691 }
692 }
693 continue;
694 }
695
696 if let Some(id) = msg["id"].as_u64() {
697 let mut pending = pending_clone.lock().await;
698 if let Some(sender) = pending.remove(&id) {
699 let _ = sender.send(msg);
700 }
701 }
702 }
703
704 let mut pending = pending_clone.lock().await;
706 pending.clear();
707 });
708
709 Self {
710 next_id: AtomicU64::new(1),
711 pending,
712 cancelled,
713 cancel_notify,
714 writer: stdout_writer(Arc::new(std::sync::Mutex::new(()))),
715 session_id: std::sync::Mutex::new(String::new()),
716 script_name: std::sync::Mutex::new(String::new()),
717 queued_transcript_injections,
718 resume_requested,
719 skills_reload_requested,
720 daemon_idle,
721 prompt_stop_reason: std::sync::Mutex::new(None),
722 visible_call_states: std::sync::Mutex::new(HashMap::new()),
723 visible_call_streams: std::sync::Mutex::new(HashMap::new()),
724 in_process: None,
725 }
726 }
727
728 pub fn from_parts(
734 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
735 cancelled: Arc<AtomicBool>,
736 stdout_lock: Arc<std::sync::Mutex<()>>,
737 start_id: u64,
738 ) -> Self {
739 Self::from_parts_with_writer(pending, cancelled, stdout_writer(stdout_lock), start_id)
740 }
741
742 pub fn from_parts_with_writer(
743 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
744 cancelled: Arc<AtomicBool>,
745 writer: HostBridgeWriter,
746 start_id: u64,
747 ) -> Self {
748 Self::from_parts_with_writer_and_cancel_notify(
749 pending,
750 cancelled,
751 Arc::new(Notify::new()),
752 writer,
753 start_id,
754 )
755 }
756
757 pub fn from_parts_with_writer_and_cancel_notify(
758 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
759 cancelled: Arc<AtomicBool>,
760 cancel_notify: Arc<Notify>,
761 writer: HostBridgeWriter,
762 start_id: u64,
763 ) -> Self {
764 Self::from_parts_with_writer_cancel_notify_and_injection_state(
765 pending,
766 cancelled,
767 cancel_notify,
768 writer,
769 start_id,
770 None,
771 )
772 }
773
774 pub fn from_parts_with_writer_cancel_notify_and_injection_state(
775 pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
776 cancelled: Arc<AtomicBool>,
777 cancel_notify: Arc<Notify>,
778 writer: HostBridgeWriter,
779 start_id: u64,
780 injection_state: Option<HostBridgeInjectionState>,
781 ) -> Self {
782 Self {
783 next_id: AtomicU64::new(start_id),
784 pending,
785 cancelled,
786 cancel_notify,
787 writer,
788 session_id: std::sync::Mutex::new(String::new()),
789 script_name: std::sync::Mutex::new(String::new()),
790 queued_transcript_injections: injection_state.unwrap_or_default(),
791 resume_requested: Arc::new(AtomicBool::new(false)),
792 skills_reload_requested: Arc::new(AtomicBool::new(false)),
793 daemon_idle: Arc::new(AtomicBool::new(false)),
794 prompt_stop_reason: std::sync::Mutex::new(None),
795 visible_call_states: std::sync::Mutex::new(HashMap::new()),
796 visible_call_streams: std::sync::Mutex::new(HashMap::new()),
797 in_process: None,
798 }
799 }
800
801 pub async fn from_harn_module(mut vm: Vm, module_path: &Path) -> Result<Self, VmError> {
804 let exported_functions = vm.load_module_exports(module_path).await?;
805 Ok(Self {
806 next_id: AtomicU64::new(1),
807 pending: Arc::new(Mutex::new(HashMap::new())),
808 cancelled: Arc::new(AtomicBool::new(false)),
809 cancel_notify: Arc::new(Notify::new()),
810 writer: stdout_writer(Arc::new(std::sync::Mutex::new(()))),
811 session_id: std::sync::Mutex::new(String::new()),
812 script_name: std::sync::Mutex::new(String::new()),
813 queued_transcript_injections: HostBridgeInjectionState::default(),
814 resume_requested: Arc::new(AtomicBool::new(false)),
815 skills_reload_requested: Arc::new(AtomicBool::new(false)),
816 daemon_idle: Arc::new(AtomicBool::new(false)),
817 prompt_stop_reason: std::sync::Mutex::new(None),
818 visible_call_states: std::sync::Mutex::new(HashMap::new()),
819 visible_call_streams: std::sync::Mutex::new(HashMap::new()),
820 in_process: Some(InProcessHost {
821 module_path: module_path.to_path_buf(),
822 exported_functions,
823 vm,
824 }),
825 })
826 }
827
828 pub fn set_session_id(&self, id: &str) {
830 *self.session_id.lock().unwrap_or_else(|e| e.into_inner()) = id.to_string();
831 }
832
833 pub fn set_script_name(&self, name: &str) {
835 *self.script_name.lock().unwrap_or_else(|e| e.into_inner()) = name.to_string();
836 }
837
838 fn get_script_name(&self) -> String {
840 self.script_name
841 .lock()
842 .unwrap_or_else(|e| e.into_inner())
843 .clone()
844 }
845
846 pub fn get_session_id(&self) -> String {
848 self.session_id
849 .lock()
850 .unwrap_or_else(|e| e.into_inner())
851 .clone()
852 }
853
854 fn write_line(&self, line: &str) -> Result<(), VmError> {
856 (self.writer)(line).map_err(VmError::Runtime)
857 }
858
859 pub async fn call(
862 &self,
863 method: &str,
864 params: serde_json::Value,
865 ) -> Result<serde_json::Value, VmError> {
866 if let Some(in_process) = &self.in_process {
867 return in_process.dispatch(method, params).await;
868 }
869
870 if self.is_cancelled() {
871 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
872 }
873
874 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
875 let cancel_wait = self.cancel_notify.notified();
876 tokio::pin!(cancel_wait);
877
878 let request = crate::jsonrpc::request(id, method, params);
879
880 let (tx, rx) = oneshot::channel();
881 {
882 let mut pending = self.pending.lock().await;
883 pending.insert(id, tx);
884 }
885
886 let line = serde_json::to_string(&request)
887 .map_err(|e| VmError::Runtime(format!("Bridge serialization error: {e}")))?;
888 if let Err(e) = self.write_line(&line) {
889 let mut pending = self.pending.lock().await;
890 pending.remove(&id);
891 return Err(e);
892 }
893
894 if self.is_cancelled() {
895 let mut pending = self.pending.lock().await;
896 pending.remove(&id);
897 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
898 }
899
900 let response = tokio::select! {
901 result = rx => match result {
902 Ok(msg) => msg,
903 Err(_) => {
904 return Err(VmError::Runtime(
906 "Bridge: host closed connection before responding".into(),
907 ));
908 }
909 },
910 _ = &mut cancel_wait => {
911 let mut pending = self.pending.lock().await;
912 pending.remove(&id);
913 return Err(VmError::Runtime("Bridge: operation cancelled".into()));
914 }
915 _ = tokio::time::sleep(DEFAULT_TIMEOUT) => {
916 let mut pending = self.pending.lock().await;
917 pending.remove(&id);
918 return Err(VmError::Runtime(format!(
919 "Bridge: host did not respond to '{method}' within {}s",
920 DEFAULT_TIMEOUT.as_secs()
921 )));
922 }
923 };
924
925 if let Some(error) = response.get("error") {
926 let message = error["message"].as_str().unwrap_or("Unknown host error");
927 let code = error["code"].as_i64().unwrap_or(-1);
928 if code == -32001 {
930 return Err(VmError::CategorizedError {
931 message: message.to_string(),
932 category: ErrorCategory::ToolRejected,
933 });
934 }
935 return Err(VmError::Runtime(format!("Host error ({code}): {message}")));
936 }
937
938 Ok(response["result"].clone())
939 }
940
941 pub fn notify(&self, method: &str, params: serde_json::Value) {
944 let notification = crate::jsonrpc::notification(method, params);
945 if self.in_process.is_some() {
946 return;
947 }
948 if let Ok(line) = serde_json::to_string(¬ification) {
949 let _ = self.write_line(&line);
950 }
951 }
952
953 pub fn is_cancelled(&self) -> bool {
955 self.cancelled.load(Ordering::SeqCst)
956 }
957
958 pub fn take_resume_signal(&self) -> bool {
959 self.resume_requested.swap(false, Ordering::SeqCst)
960 }
961
962 pub fn signal_resume(&self) {
963 self.resume_requested.store(true, Ordering::SeqCst);
964 }
965
966 pub fn set_daemon_idle(&self, idle: bool) {
967 self.daemon_idle.store(idle, Ordering::SeqCst);
968 }
969
970 pub fn is_daemon_idle(&self) -> bool {
971 self.daemon_idle.load(Ordering::SeqCst)
972 }
973
974 pub fn set_prompt_stop_reason(&self, reason: &str) {
979 *self
980 .prompt_stop_reason
981 .lock()
982 .unwrap_or_else(|e| e.into_inner()) = Some(reason.to_string());
983 }
984
985 pub fn take_prompt_stop_reason(&self) -> Option<String> {
990 self.prompt_stop_reason
991 .lock()
992 .unwrap_or_else(|e| e.into_inner())
993 .take()
994 }
995
996 pub fn take_skills_reload_signal(&self) -> bool {
1001 self.skills_reload_requested.swap(false, Ordering::SeqCst)
1002 }
1003
1004 pub fn signal_skills_reload(&self) {
1008 self.skills_reload_requested.store(true, Ordering::SeqCst);
1009 }
1010
1011 pub async fn list_host_skills(&self) -> Result<Vec<serde_json::Value>, VmError> {
1017 let result = self.call("skills/list", serde_json::json!({})).await?;
1018 match result {
1019 serde_json::Value::Array(items) => Ok(items),
1020 serde_json::Value::Object(map) => match map.get("skills") {
1021 Some(serde_json::Value::Array(items)) => Ok(items.clone()),
1022 _ => Err(VmError::Runtime(
1023 "skills/list: host response must be an array or { skills: [...] }".into(),
1024 )),
1025 },
1026 _ => Err(VmError::Runtime(
1027 "skills/list: unexpected response shape".into(),
1028 )),
1029 }
1030 }
1031
1032 pub async fn list_host_tools(&self) -> Result<Vec<serde_json::Value>, VmError> {
1038 let result = self.call("host/tools/list", serde_json::json!({})).await?;
1039 parse_host_tools_list_response(result)
1040 }
1041
1042 pub async fn fetch_host_skill(&self, id: &str) -> Result<serde_json::Value, VmError> {
1046 self.call("skills/fetch", serde_json::json!({ "id": id }))
1047 .await
1048 }
1049
1050 pub fn injection_state(&self) -> HostBridgeInjectionState {
1051 self.queued_transcript_injections.clone()
1052 }
1053
1054 pub async fn push_pending_user_message(
1055 &self,
1056 content: String,
1057 transcript_content: serde_json::Value,
1058 mode: &str,
1059 ) -> String {
1060 self.queued_transcript_injections
1061 .push_pending_user_message(content, transcript_content, mode)
1062 .await
1063 }
1064
1065 pub async fn push_queued_user_message(&self, content: String, mode: &str) -> String {
1066 self.push_pending_user_message(content.clone(), serde_json::Value::String(content), mode)
1067 .await
1068 }
1069
1070 pub async fn revoke_pending_user_message(
1071 &self,
1072 message_id: &str,
1073 ) -> PendingUserMessageMutationResult {
1074 self.queued_transcript_injections
1075 .revoke_pending_user_message(message_id)
1076 .await
1077 }
1078
1079 pub async fn replace_pending_user_message(
1080 &self,
1081 message_id: &str,
1082 content: String,
1083 transcript_content: serde_json::Value,
1084 ) -> PendingUserMessageMutationResult {
1085 self.queued_transcript_injections
1086 .replace_pending_user_message(message_id, content, transcript_content)
1087 .await
1088 }
1089
1090 pub async fn push_queued_session_remind_from_params(
1091 &self,
1092 params: &serde_json::Value,
1093 ) -> Result<String, String> {
1094 let reminder = queued_session_remind_from_params(params)?;
1095 let reminder_id = reminder.reminder.id.clone();
1096 self.queued_transcript_injections
1097 .push_session_reminder(reminder)
1098 .await;
1099 Ok(reminder_id)
1100 }
1101
1102 pub async fn take_queued_user_messages(
1103 &self,
1104 include_interrupt_immediate: bool,
1105 include_finish_step: bool,
1106 include_wait_for_completion: bool,
1107 ) -> Vec<QueuedUserMessage> {
1108 let mut state = self.queued_transcript_injections.inner.lock().await;
1109 let mut selected = Vec::new();
1110 let mut retained = VecDeque::new();
1111 while let Some(injection) = state.queue.pop_front() {
1112 let should_take = match injection.mode() {
1113 QueuedUserMessageMode::InterruptImmediate => include_interrupt_immediate,
1114 QueuedUserMessageMode::FinishStep => include_finish_step,
1115 QueuedUserMessageMode::WaitForCompletion => include_wait_for_completion,
1116 };
1117 match (should_take, injection) {
1118 (true, QueuedTranscriptInjection::User(message)) => {
1119 state
1120 .delivered_user_message_ids
1121 .insert(message.message_id.clone());
1122 selected.push(message);
1123 }
1124 (_, injection) => retained.push_back(injection),
1125 }
1126 }
1127 state.queue = retained;
1128 selected
1129 }
1130
1131 pub async fn take_queued_transcript_injections(
1132 &self,
1133 include_interrupt_immediate: bool,
1134 include_finish_step: bool,
1135 include_wait_for_completion: bool,
1136 ) -> Vec<QueuedTranscriptInjection> {
1137 let mut state = self.queued_transcript_injections.inner.lock().await;
1138 let mut selected = Vec::new();
1139 let mut retained = VecDeque::new();
1140 while let Some(injection) = state.queue.pop_front() {
1141 let should_take = match injection.mode() {
1142 QueuedUserMessageMode::InterruptImmediate => include_interrupt_immediate,
1143 QueuedUserMessageMode::FinishStep => include_finish_step,
1144 QueuedUserMessageMode::WaitForCompletion => include_wait_for_completion,
1145 };
1146 if should_take {
1147 if let QueuedTranscriptInjection::User(message) = &injection {
1148 state
1149 .delivered_user_message_ids
1150 .insert(message.message_id.clone());
1151 }
1152 selected.push(injection);
1153 } else {
1154 retained.push_back(injection);
1155 }
1156 }
1157 state.queue = retained;
1158 selected
1159 }
1160
1161 pub async fn take_queued_user_messages_for(
1162 &self,
1163 checkpoint: DeliveryCheckpoint,
1164 ) -> Vec<QueuedUserMessage> {
1165 match checkpoint {
1166 DeliveryCheckpoint::InterruptImmediate => {
1167 self.take_queued_user_messages(true, false, false).await
1168 }
1169 DeliveryCheckpoint::AfterCurrentOperation => {
1170 self.take_queued_user_messages(false, true, false).await
1171 }
1172 DeliveryCheckpoint::EndOfInteraction => {
1173 self.take_queued_user_messages(false, false, true).await
1174 }
1175 }
1176 }
1177
1178 pub async fn take_queued_transcript_injections_for(
1179 &self,
1180 checkpoint: DeliveryCheckpoint,
1181 ) -> Vec<QueuedTranscriptInjection> {
1182 match checkpoint {
1183 DeliveryCheckpoint::InterruptImmediate => {
1184 self.take_queued_transcript_injections(true, false, false)
1185 .await
1186 }
1187 DeliveryCheckpoint::AfterCurrentOperation => {
1188 self.take_queued_transcript_injections(false, true, false)
1189 .await
1190 }
1191 DeliveryCheckpoint::EndOfInteraction => {
1192 self.take_queued_transcript_injections(false, false, true)
1193 .await
1194 }
1195 }
1196 }
1197
1198 pub fn send_output(&self, text: &str) {
1200 self.notify("output", serde_json::json!({"text": text}));
1201 }
1202
1203 pub fn send_progress(
1205 &self,
1206 phase: &str,
1207 message: &str,
1208 progress: Option<i64>,
1209 total: Option<i64>,
1210 data: Option<serde_json::Value>,
1211 ) {
1212 let mut payload = serde_json::json!({"phase": phase, "message": message});
1213 if let Some(p) = progress {
1214 payload["progress"] = serde_json::json!(p);
1215 }
1216 if let Some(t) = total {
1217 payload["total"] = serde_json::json!(t);
1218 }
1219 if let Some(d) = data {
1220 payload["data"] = d;
1221 }
1222 self.notify("progress", payload);
1223 }
1224
1225 pub fn send_log(&self, level: &str, message: &str, fields: Option<serde_json::Value>) {
1227 let mut payload = serde_json::json!({"level": level, "message": message});
1228 if let Some(f) = fields {
1229 payload["fields"] = f;
1230 }
1231 self.notify("log", payload);
1232 }
1233
1234 pub fn send_call_start(
1237 &self,
1238 call_id: &str,
1239 call_type: &str,
1240 name: &str,
1241 metadata: serde_json::Value,
1242 ) {
1243 let session_id = self.get_session_id();
1244 let script = self.get_script_name();
1245 let stream_publicly = metadata
1246 .get("stream_publicly")
1247 .and_then(|value| value.as_bool())
1248 .unwrap_or(true);
1249 self.visible_call_streams
1250 .lock()
1251 .unwrap_or_else(|e| e.into_inner())
1252 .insert(call_id.to_string(), stream_publicly);
1253 self.notify(
1254 "session/update",
1255 serde_json::json!({
1256 "sessionId": session_id,
1257 "update": {
1258 "sessionUpdate": "call_start",
1259 "content": {
1260 "toolCallId": call_id,
1261 "call_type": call_type,
1262 "name": name,
1263 "script": script,
1264 "metadata": metadata,
1265 },
1266 },
1267 }),
1268 );
1269 }
1270
1271 pub fn send_call_progress(
1274 &self,
1275 call_id: &str,
1276 delta: &str,
1277 accumulated_tokens: u64,
1278 user_visible: bool,
1279 ) {
1280 let session_id = self.get_session_id();
1281 let (visible_text, visible_delta) = {
1282 let stream_publicly = self
1283 .visible_call_streams
1284 .lock()
1285 .unwrap_or_else(|e| e.into_inner())
1286 .get(call_id)
1287 .copied()
1288 .unwrap_or(true);
1289 let mut states = self
1290 .visible_call_states
1291 .lock()
1292 .unwrap_or_else(|e| e.into_inner());
1293 let state = states.entry(call_id.to_string()).or_default();
1294 state.push(delta, stream_publicly)
1295 };
1296 self.notify(
1297 "session/update",
1298 serde_json::json!({
1299 "sessionId": session_id,
1300 "update": {
1301 "sessionUpdate": "call_progress",
1302 "content": {
1303 "toolCallId": call_id,
1304 "delta": delta,
1305 "accumulated_tokens": accumulated_tokens,
1306 "visible_text": visible_text,
1307 "visible_delta": visible_delta,
1308 "user_visible": user_visible,
1309 },
1310 },
1311 }),
1312 );
1313 }
1314
1315 pub fn send_call_end(
1317 &self,
1318 call_id: &str,
1319 call_type: &str,
1320 name: &str,
1321 duration_ms: u64,
1322 status: &str,
1323 metadata: serde_json::Value,
1324 ) {
1325 let session_id = self.get_session_id();
1326 let script = self.get_script_name();
1327 self.visible_call_states
1328 .lock()
1329 .unwrap_or_else(|e| e.into_inner())
1330 .remove(call_id);
1331 self.visible_call_streams
1332 .lock()
1333 .unwrap_or_else(|e| e.into_inner())
1334 .remove(call_id);
1335 self.notify(
1336 "session/update",
1337 serde_json::json!({
1338 "sessionId": session_id,
1339 "update": {
1340 "sessionUpdate": "call_end",
1341 "content": {
1342 "toolCallId": call_id,
1343 "call_type": call_type,
1344 "name": name,
1345 "script": script,
1346 "duration_ms": duration_ms,
1347 "status": status,
1348 "metadata": metadata,
1349 },
1350 },
1351 }),
1352 );
1353 }
1354
1355 pub fn send_worker_update(
1357 &self,
1358 worker_id: &str,
1359 worker_name: &str,
1360 status: &str,
1361 metadata: serde_json::Value,
1362 audit: Option<&MutationSessionRecord>,
1363 ) {
1364 let session_id = self.get_session_id();
1365 let script = self.get_script_name();
1366 let started_at = metadata.get("started_at").cloned().unwrap_or_default();
1367 let finished_at = metadata.get("finished_at").cloned().unwrap_or_default();
1368 let snapshot_path = metadata.get("snapshot_path").cloned().unwrap_or_default();
1369 let run_id = metadata.get("child_run_id").cloned().unwrap_or_default();
1370 let run_path = metadata.get("child_run_path").cloned().unwrap_or_default();
1371 let lifecycle = serde_json::json!({
1372 "event": status,
1373 "worker_id": worker_id,
1374 "worker_name": worker_name,
1375 "started_at": started_at,
1376 "finished_at": finished_at,
1377 });
1378 self.notify(
1379 "session/update",
1380 serde_json::json!({
1381 "sessionId": session_id,
1382 "update": {
1383 "sessionUpdate": "worker_update",
1384 "content": {
1385 "worker_id": worker_id,
1386 "worker_name": worker_name,
1387 "status": status,
1388 "script": script,
1389 "started_at": started_at,
1390 "finished_at": finished_at,
1391 "snapshot_path": snapshot_path,
1392 "run_id": run_id,
1393 "run_path": run_path,
1394 "lifecycle": lifecycle,
1395 "audit": audit,
1396 "metadata": metadata,
1397 },
1398 },
1399 }),
1400 );
1401 }
1402}
1403
1404pub fn json_result_to_vm_value(val: &serde_json::Value) -> VmValue {
1406 crate::stdlib::json_to_vm_value(val)
1407}
1408
1409fn parse_host_tools_list_response(
1410 result: serde_json::Value,
1411) -> Result<Vec<serde_json::Value>, VmError> {
1412 let tools = match result {
1413 serde_json::Value::Array(items) => items,
1414 serde_json::Value::Object(map) => match map.get("tools").cloned().or_else(|| {
1415 map.get("result")
1416 .and_then(|value| value.get("tools"))
1417 .cloned()
1418 }) {
1419 Some(serde_json::Value::Array(items)) => items,
1420 _ => {
1421 return Err(VmError::Runtime(
1422 "host/tools/list: host response must be an array or { tools: [...] }".into(),
1423 ));
1424 }
1425 },
1426 _ => {
1427 return Err(VmError::Runtime(
1428 "host/tools/list: unexpected response shape".into(),
1429 ));
1430 }
1431 };
1432
1433 let mut normalized = Vec::with_capacity(tools.len());
1434 for tool in tools {
1435 let serde_json::Value::Object(map) = tool else {
1436 return Err(VmError::Runtime(
1437 "host/tools/list: every tool must be an object".into(),
1438 ));
1439 };
1440 let Some(name) = map.get("name").and_then(|value| value.as_str()) else {
1441 return Err(VmError::Runtime(
1442 "host/tools/list: every tool must include a string `name`".into(),
1443 ));
1444 };
1445 let description = map
1446 .get("description")
1447 .and_then(|value| value.as_str())
1448 .or_else(|| {
1449 map.get("short_description")
1450 .and_then(|value| value.as_str())
1451 })
1452 .unwrap_or_default();
1453 let schema = map
1454 .get("schema")
1455 .cloned()
1456 .or_else(|| map.get("parameters").cloned())
1457 .or_else(|| map.get("input_schema").cloned())
1458 .unwrap_or(serde_json::Value::Null);
1459 let deprecated = map
1460 .get("deprecated")
1461 .and_then(|value| value.as_bool())
1462 .unwrap_or(false);
1463 normalized.push(serde_json::json!({
1464 "name": name,
1465 "description": description,
1466 "schema": schema,
1467 "deprecated": deprecated,
1468 }));
1469 }
1470 Ok(normalized)
1471}
1472
#[cfg(test)]
mod tests {
    use super::*;

    // Bridge with an empty pending-call map and an un-cancelled flag, used by
    // tests that only exercise the queued-injection bookkeeping.
    fn test_bridge() -> HostBridge {
        HostBridge::from_parts(
            Arc::new(Mutex::new(HashMap::new())),
            Arc::new(AtomicBool::new(false)),
            Arc::new(std::sync::Mutex::new(())),
            1,
        )
    }

    // Second bridge that shares `owner`'s injection state, modelling a bridge
    // being replaced while queued messages are still outstanding.
    fn test_bridge_sharing_injection_state(owner: &HostBridge) -> HostBridge {
        HostBridge::from_parts_with_writer_cancel_notify_and_injection_state(
            Arc::new(Mutex::new(HashMap::new())),
            Arc::new(AtomicBool::new(false)),
            Arc::new(Notify::new()),
            Arc::new(|_| Ok(())),
            100,
            Some(owner.injection_state()),
        )
    }

    #[test]
    fn test_json_rpc_request_format() {
        let request = crate::jsonrpc::request(
            1,
            "llm_call",
            serde_json::json!({
                "prompt": "Hello",
                "system": "Be helpful",
            }),
        );
        let s = serde_json::to_string(&request).unwrap();
        assert!(s.contains("\"jsonrpc\":\"2.0\""));
        assert!(s.contains("\"id\":1"));
        assert!(s.contains("\"method\":\"llm_call\""));
    }

    #[test]
    fn test_json_rpc_notification_format() {
        let notification =
            crate::jsonrpc::notification("output", serde_json::json!({"text": "[harn] hello\n"}));
        let s = serde_json::to_string(&notification).unwrap();
        assert!(s.contains("\"method\":\"output\""));
        // Notifications must not carry an id.
        assert!(!s.contains("\"id\""));
    }

    #[test]
    fn test_json_rpc_error_response_parsing() {
        let response = crate::jsonrpc::error_response(1, -32600, "Invalid request");
        assert!(response.get("error").is_some());
        assert_eq!(
            response["error"]["message"].as_str().unwrap(),
            "Invalid request"
        );
    }

    #[test]
    fn test_json_rpc_success_response_parsing() {
        let response = crate::jsonrpc::response(
            1,
            serde_json::json!({
                "text": "Hello world",
                "input_tokens": 10,
                "output_tokens": 5,
            }),
        );
        assert!(response.get("result").is_some());
        assert_eq!(response["result"]["text"].as_str().unwrap(), "Hello world");
    }

    #[test]
    fn test_cancelled_flag() {
        let cancelled = Arc::new(AtomicBool::new(false));
        assert!(!cancelled.load(Ordering::SeqCst));
        cancelled.store(true, Ordering::SeqCst);
        assert!(cancelled.load(Ordering::SeqCst));
    }

    #[test]
    fn pending_host_calls_return_when_cancellation_arrives() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(async {
            let pending = Arc::new(Mutex::new(HashMap::new()));
            let cancelled = Arc::new(AtomicBool::new(false));
            let bridge = HostBridge::from_parts_with_writer(
                pending.clone(),
                cancelled.clone(),
                Arc::new(|_| Ok(())),
                1,
            );

            let call = bridge.call("host/work", serde_json::json!({}));
            tokio::pin!(call);

            // Drive the call until it has registered itself as pending.
            loop {
                tokio::select! {
                    result = &mut call => panic!("call completed before cancellation: {result:?}"),
                    _ = tokio::task::yield_now() => {}
                }
                if !pending.lock().await.is_empty() {
                    break;
                }
            }

            cancelled.store(true, Ordering::SeqCst);
            bridge.cancel_notify.notify_waiters();

            let result = tokio::time::timeout(Duration::from_secs(1), call)
                .await
                .expect("pending call should observe cancellation promptly");
            assert!(
                matches!(result, Err(VmError::Runtime(message)) if message.contains("cancelled"))
            );
            assert!(pending.lock().await.is_empty());
        });
    }

    #[test]
    fn queued_messages_are_filtered_by_delivery_mode() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(async {
            let bridge = test_bridge();
            bridge
                .push_queued_user_message("first".to_string(), "finish_step")
                .await;
            bridge
                .push_queued_user_message("second".to_string(), "wait_for_completion")
                .await;

            let finish_step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(finish_step.len(), 1);
            assert_eq!(finish_step[0].content, "first");

            let turn_end = bridge.take_queued_user_messages(false, false, true).await;
            assert_eq!(turn_end.len(), 1);
            assert_eq!(turn_end[0].content, "second");
        });
    }

    #[test]
    fn pending_user_messages_support_revoke_replace_and_delivery_states() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(async {
            let bridge = test_bridge();
            let first_id = bridge
                .push_pending_user_message(
                    "first".to_string(),
                    serde_json::json!("first"),
                    "wait_for_completion",
                )
                .await;
            let second_id = bridge
                .push_pending_user_message(
                    "second".to_string(),
                    serde_json::json!("second"),
                    "wait_for_completion",
                )
                .await;

            assert_eq!(
                bridge
                    .replace_pending_user_message(
                        &second_id,
                        "second edited".to_string(),
                        serde_json::json!("second edited"),
                    )
                    .await,
                PendingUserMessageMutationResult::Mutated
            );
            assert_eq!(
                bridge.revoke_pending_user_message(&first_id).await,
                PendingUserMessageMutationResult::Mutated
            );
            assert_eq!(
                bridge.revoke_pending_user_message(&first_id).await,
                PendingUserMessageMutationResult::AlreadyRevoked
            );

            let delivered = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::EndOfInteraction)
                .await;
            assert_eq!(delivered.len(), 1);
            assert_eq!(delivered[0].message_id, second_id);
            assert_eq!(delivered[0].content, "second edited");

            // Once delivered, a message can no longer be revoked or replaced.
            assert_eq!(
                bridge.revoke_pending_user_message(&second_id).await,
                PendingUserMessageMutationResult::AlreadyDelivered
            );
            assert_eq!(
                bridge
                    .replace_pending_user_message(
                        &second_id,
                        "too late".to_string(),
                        serde_json::json!("too late"),
                    )
                    .await,
                PendingUserMessageMutationResult::AlreadyDelivered
            );
            assert_eq!(
                bridge.revoke_pending_user_message("missing").await,
                PendingUserMessageMutationResult::UnknownMessageId
            );
        });
    }

    #[test]
    fn pending_user_message_replace_preserves_fifo_position_and_mode() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(async {
            let bridge = test_bridge();
            let first_id = bridge
                .push_pending_user_message(
                    "first".to_string(),
                    serde_json::json!("first"),
                    "finish_step",
                )
                .await;
            let second_id = bridge
                .push_pending_user_message(
                    "second".to_string(),
                    serde_json::json!("second"),
                    "finish_step",
                )
                .await;
            assert_eq!(
                bridge
                    .replace_pending_user_message(
                        &first_id,
                        "first edited".to_string(),
                        serde_json::json!("first edited"),
                    )
                    .await,
                PendingUserMessageMutationResult::Mutated
            );

            let delivered = bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::AfterCurrentOperation)
                .await;
            assert_eq!(
                delivered
                    .iter()
                    .map(|message| (&message.message_id, message.content.as_str(), message.mode))
                    .collect::<Vec<_>>(),
                vec![
                    (&first_id, "first edited", QueuedUserMessageMode::FinishStep,),
                    (&second_id, "second", QueuedUserMessageMode::FinishStep),
                ]
            );
        });
    }

    #[test]
    fn pending_user_message_state_survives_bridge_replacement() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(async {
            let bridge = test_bridge();
            let revoked_id = bridge
                .push_pending_user_message(
                    "revoke me".to_string(),
                    serde_json::json!("revoke me"),
                    "wait_for_completion",
                )
                .await;
            let delivered_id = bridge
                .push_pending_user_message(
                    "deliver me".to_string(),
                    serde_json::json!("deliver me"),
                    "wait_for_completion",
                )
                .await;
            assert_eq!(
                bridge.revoke_pending_user_message(&revoked_id).await,
                PendingUserMessageMutationResult::Mutated
            );
            bridge.cancelled.store(true, Ordering::SeqCst);

            // The replacement sees the revocation and delivers the survivor;
            // the original bridge then observes that delivery.
            let replacement_bridge = test_bridge_sharing_injection_state(&bridge);
            assert_eq!(
                replacement_bridge
                    .revoke_pending_user_message(&revoked_id)
                    .await,
                PendingUserMessageMutationResult::AlreadyRevoked
            );
            let delivered = replacement_bridge
                .take_queued_user_messages_for(DeliveryCheckpoint::EndOfInteraction)
                .await;
            assert_eq!(delivered.len(), 1);
            assert_eq!(delivered[0].message_id, delivered_id);
            assert_eq!(delivered[0].content, "deliver me");
            assert_eq!(
                bridge.revoke_pending_user_message(&delivered_id).await,
                PendingUserMessageMutationResult::AlreadyDelivered
            );
        });
    }

    #[test]
    fn queued_transcript_injections_preserve_user_reminder_separation() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(async {
            let bridge = test_bridge();
            bridge
                .push_queued_user_message("human follow-up".to_string(), "finish_step")
                .await;
            let reminder_id = bridge
                .push_queued_session_remind_from_params(&serde_json::json!({
                    "body": "Host-provided ambient context.",
                    "tags": ["host"],
                    "dedupe_key": "host-context",
                    "ttl_turns": 2,
                    "mode": "wait_for_completion",
                    "_meta": {"harn": {"source": "test"}},
                }))
                .await
                .expect("valid reminder");

            let finish_step = bridge.take_queued_user_messages(false, true, false).await;
            assert_eq!(finish_step.len(), 1);
            assert_eq!(finish_step[0].content, "human follow-up");

            // The user-message view never hands out reminders.
            let no_user_messages = bridge.take_queued_user_messages(false, false, true).await;
            assert!(no_user_messages.is_empty());

            let injections = bridge
                .take_queued_transcript_injections_for(DeliveryCheckpoint::EndOfInteraction)
                .await;
            assert_eq!(injections.len(), 1);
            let QueuedTranscriptInjection::Reminder(reminder) = &injections[0] else {
                panic!("expected queued reminder");
            };
            assert_eq!(reminder.reminder.id, reminder_id);
            assert_eq!(reminder.reminder.body, "Host-provided ambient context.");
            assert_eq!(reminder.reminder.tags, vec!["host".to_string()]);
            assert_eq!(
                reminder.reminder.dedupe_key.as_deref(),
                Some("host-context")
            );
            assert_eq!(reminder.reminder.ttl_turns, Some(2));
            assert_eq!(
                reminder.reminder.source,
                crate::llm::helpers::ReminderSource::Bridge
            );
        });
    }

    #[test]
    fn bridge_remind_modes_honor_delivery_checkpoints() {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(async {
            // (mode, checkpoint that delivers it, checkpoint that must not)
            let cases = [
                (
                    "interrupt_immediate",
                    DeliveryCheckpoint::InterruptImmediate,
                    DeliveryCheckpoint::AfterCurrentOperation,
                ),
                (
                    "finish_step",
                    DeliveryCheckpoint::AfterCurrentOperation,
                    DeliveryCheckpoint::EndOfInteraction,
                ),
                (
                    "wait_for_completion",
                    DeliveryCheckpoint::EndOfInteraction,
                    DeliveryCheckpoint::InterruptImmediate,
                ),
            ];

            for (mode, expected_checkpoint, wrong_checkpoint) in cases {
                let bridge = test_bridge();
                bridge
                    .push_queued_session_remind_from_params(&serde_json::json!({
                        "body": format!("Reminder for {mode}"),
                        "mode": mode,
                    }))
                    .await
                    .expect("valid session/remind payload");

                let premature = bridge
                    .take_queued_transcript_injections_for(wrong_checkpoint)
                    .await;
                assert!(
                    premature.is_empty(),
                    "{mode} reminder must not be delivered at {wrong_checkpoint:?}"
                );

                let delivered = bridge
                    .take_queued_transcript_injections_for(expected_checkpoint)
                    .await;
                assert_eq!(delivered.len(), 1, "{mode} reminder was not delivered");
                let QueuedTranscriptInjection::Reminder(reminder) = &delivered[0] else {
                    panic!("expected reminder for {mode}");
                };
                assert_eq!(reminder.reminder.body, format!("Reminder for {mode}"));
            }
        });
    }

    #[test]
    fn session_remind_validation_rejects_user_message_shape() {
        let err = queued_session_remind_from_params(&serde_json::json!({
            "content": "this is still a user message",
            "mode": "interrupt_immediate",
        }))
        .expect_err("session/remind must require a reminder body");
        assert!(err.contains(Code::ReminderInvalidShape.as_str()));
        assert!(err.contains("body"));
    }

    #[test]
    fn session_remind_validation_rejects_unknown_options_separately() {
        let err = queued_session_remind_from_params(&serde_json::json!({
            "body": "valid body",
            "unknown_host_field": true,
        }))
        .expect_err("session/remind must reject unknown top-level fields");
        assert!(err.contains(Code::ReminderUnknownOption.as_str()));
        assert!(err.contains("unknown_host_field"));
    }

    #[test]
    fn session_remind_validation_rejects_unknown_propagate_with_specific_code() {
        let err = queued_session_remind_from_params(&serde_json::json!({
            "body": "valid body",
            "propagate": "workspace",
        }))
        .expect_err("session/remind must reject unknown propagate values");
        assert!(err.contains(Code::ReminderUnknownPropagate.as_str()));
        assert!(err.contains("propagate"));
    }

    #[test]
    fn test_json_result_to_vm_value_string() {
        let val = serde_json::json!("hello");
        let vm_val = json_result_to_vm_value(&val);
        assert_eq!(vm_val.display(), "hello");
    }

    #[test]
    fn test_json_result_to_vm_value_dict() {
        let val = serde_json::json!({"name": "test", "count": 42});
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("name").unwrap().display(), "test");
        assert_eq!(d.get("count").unwrap().display(), "42");
    }

    #[test]
    fn test_json_result_to_vm_value_null() {
        let val = serde_json::json!(null);
        let vm_val = json_result_to_vm_value(&val);
        assert!(matches!(vm_val, VmValue::Nil));
    }

    #[test]
    fn test_json_result_to_vm_value_nested() {
        let val = serde_json::json!({
            "text": "response",
            "tool_calls": [
                {"id": "tc_1", "name": "read_file", "arguments": {"path": "foo.rs"}}
            ],
            "input_tokens": 100,
            "output_tokens": 50,
        });
        let vm_val = json_result_to_vm_value(&val);
        let VmValue::Dict(d) = &vm_val else {
            unreachable!("Expected Dict, got {:?}", vm_val);
        };
        assert_eq!(d.get("text").unwrap().display(), "response");
        let VmValue::List(list) = d.get("tool_calls").unwrap() else {
            unreachable!("Expected List for tool_calls");
        };
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn parse_host_tools_list_accepts_object_wrapper() {
        let tools = parse_host_tools_list_response(serde_json::json!({
            "tools": [
                {
                    "name": "Read",
                    "description": "Read a file",
                    "schema": {"type": "object"},
                }
            ]
        }))
        .expect("tool list");

        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["name"], "Read");
        assert_eq!(tools[0]["deprecated"], false);
    }

    #[test]
    fn parse_host_tools_list_accepts_compat_fields() {
        let tools = parse_host_tools_list_response(serde_json::json!({
            "result": {
                "tools": [
                    {
                        "name": "Edit",
                        "short_description": "Apply an edit",
                        "input_schema": {"type": "object"},
                        "deprecated": true,
                    }
                ]
            }
        }))
        .expect("tool list");

        assert_eq!(tools[0]["description"], "Apply an edit");
        assert_eq!(tools[0]["schema"]["type"], "object");
        assert_eq!(tools[0]["deprecated"], true);
    }

    #[test]
    fn parse_host_tools_list_requires_tool_names() {
        let err = parse_host_tools_list_response(serde_json::json!({
            "tools": [
                {"description": "missing name"}
            ]
        }))
        .expect_err("expected error");
        assert!(err
            .to_string()
            .contains("host/tools/list: every tool must include a string `name`"));
    }

    // Each malformed shape maps to its own diagnostic.
    #[test]
    fn parse_host_tools_list_rejects_unexpected_shapes() {
        let err = parse_host_tools_list_response(serde_json::json!("not a list"))
            .expect_err("scalar responses are rejected");
        assert!(err
            .to_string()
            .contains("host/tools/list: unexpected response shape"));

        let err = parse_host_tools_list_response(serde_json::json!({"tools": "nope"}))
            .expect_err("non-array tools are rejected");
        assert!(err
            .to_string()
            .contains("host/tools/list: host response must be an array or { tools: [...] }"));

        let err = parse_host_tools_list_response(serde_json::json!(["Read"]))
            .expect_err("non-object tool entries are rejected");
        assert!(err
            .to_string()
            .contains("host/tools/list: every tool must be an object"));
    }

    #[test]
    fn test_timeout_duration() {
        assert_eq!(DEFAULT_TIMEOUT.as_secs(), 300);
    }
}