1use std::collections::{BTreeSet, HashMap};
4use std::path::PathBuf;
5use std::sync::{Arc, Mutex, Weak};
6
7use anyhow::Context;
8use arc_swap::ArcSwap;
9use chrono::Utc;
10use futures::stream::{BoxStream, StreamExt};
11use halter_hooks::{Hooks, RegisteredHooks};
12use halter_protocol::{
13 AssistantMessage, AssistantPart, BlockId, CacheScope, ContentHash, Delivery,
14 HookSessionStartSource, HookWarning, Message, MessageId, ModelId, ObservedState, PendingEvent,
15 PendingToolCall, PromptSegment, PromptSegmentId, PromptSegmentKind, ProviderError,
16 ProviderRequest, ReplayMeta, ResourceSnapshot, SessionBlueprint, SessionEvent,
17 SessionEventPayload, SessionId, SessionState, StopReason, StreamEvent, SubagentEventForwarding,
18 SystemMessage, ToolCall, ToolError, ToolExecutionOutcome, ToolResult, ToolResultMessage, Turn,
19 TurnId, Usage, Volatility,
20};
21use halter_providers::ModelRegistry;
22use halter_session::{SessionStore, StoredSession};
23use halter_tools::{
24 PathLockMap, SubagentControl, SubagentParentContext, ToolEventSink, ToolPolicy, ToolRuntime,
25 ToolRuntimeEvent, ToolSessionStore,
26};
27use sha2::{Digest, Sha256};
28use tokio::sync::mpsc;
29use tokio_stream::wrappers::UnboundedReceiverStream;
30use tokio_util::sync::CancellationToken;
31use tracing::{debug, error, info, warn};
32
33#[cfg(test)]
34use crate::DefaultContextManager;
35use crate::model_selection::select_models;
36use crate::turn_registry::TurnRegistry;
37use crate::{
38 ContextManager, EventBus, ExecutedHookDispatch, HookInvocationContext, PromptAssembler,
39 run_notification, run_post_compact, run_post_tool_use, run_post_tool_use_failure,
40 run_pre_compact, run_pre_tool_use, run_session_end, run_session_start, run_stop,
41 run_user_prompt_submit,
42};
43
/// Boxed, `'static` stream of `SessionEvent` results; this is what
/// `SessionHandle::submit_turn` hands back to callers.
pub type SessionEventStream = BoxStream<'static, anyhow::Result<SessionEvent>>;

/// Byte cap (4 MiB) for provider stream output.
/// NOTE(review): not referenced in the visible part of this file — presumably it bounds
/// the bytes accumulated while materializing a provider response; confirm at the use site.
const PROVIDER_STREAM_OUTPUT_CAP_BYTES: usize = 4 * 1024 * 1024;
/// Event-count cap for a provider stream.
/// NOTE(review): use site is not visible in this part of the file — confirm.
const PROVIDER_STREAM_EVENT_CAP: usize = 8_192;
/// Maximum number of tool runtime events `SessionToolEventSink` buffers before truncating.
const TOOL_RUNTIME_EVENT_CAP: usize = 4_096;
/// Maximum number of bytes (1 MiB) `SessionToolEventSink` buffers before truncating.
const TOOL_RUNTIME_EVENT_BYTES_CAP: usize = 1024 * 1024;
51
/// Shared services that every session and turn of this runtime operates on.
///
/// Held behind an `Arc` by `SessionRuntime` and by each `SessionHandle`.
pub struct RuntimeServices {
    /// Swappable resource snapshot together with its hooks and hook warnings.
    pub resources: Arc<ResourceHandle>,
    /// NOTE(review): presumably the hooks registered when the runtime was built — confirm.
    pub registered_hooks: Arc<RegisteredHooks>,
    /// Per-session hook sets, keyed by session id.
    pub session_hook_store: Arc<Mutex<HashMap<SessionId, Arc<Hooks>>>>,
    /// Registry used to resolve models and their providers.
    pub models: Arc<ModelRegistry>,
    /// Tool runtime; its `specs()` are passed to the context manager.
    pub tools: Arc<ToolRuntime>,
    /// NOTE(review): presumably per-path locks shared by tool executions — confirm.
    pub path_locks: Arc<PathLockMap>,
    /// NOTE(review): presumably tool-side per-session state — confirm.
    pub tool_sessions: Arc<ToolSessionStore>,
    /// Session store used to load, commit, replay, and list sessions.
    pub sessions: Arc<dyn SessionStore>,
    /// Tool policy. NOTE(review): the enforcement point is not visible in this part of the file.
    pub policy: Arc<dyn ToolPolicy>,
    /// Assembles the provider prompt from a context plan.
    pub prompt_assembler: Arc<dyn PromptAssembler>,
    /// Plans the context for each provider request and performs compaction.
    pub context_manager: Arc<dyn ContextManager>,
    /// NOTE(review): presumably where committed events are published — confirm.
    pub event_bus: Arc<EventBus>,
    /// Live turn streams that can receive events forwarded for ancestor sessions.
    pub parent_streams: Arc<ParentStreamRegistry>,
    /// Registry of in-flight turns; also reports and drives runtime shutdown.
    pub turn_registry: Arc<TurnRegistry>,
    /// NOTE(review): presumably the runtime-wide default forwarding mode, overridable per
    /// session via `SessionInit::subagent_event_forwarding` — confirm.
    pub subagent_event_forwarding: SubagentEventForwarding,
    /// Maximum number of forwarded events per live turn stream; `0` disables the cap.
    pub subagent_event_forwarding_cap: u64,
    /// NOTE(review): presumably the shell tool timeout in seconds — confirm.
    pub shell_timeout_secs: u64,
    /// Optional trace recorder.
    pub trace_recorder: Option<Arc<crate::TraceRecorder>>,
}
76
/// Cloneable handle to the runtime's current resources.
///
/// Clones share one `ArcSwap` cell, so a `replace` through any clone is seen by all of them.
#[derive(Debug, Clone)]
pub struct ResourceHandle {
    /// Atomically swappable pointer to the current resource state.
    current: Arc<ArcSwap<ResourceState>>,
}
82
83#[derive(Clone, Debug)]
84struct ResourceState {
85 snapshot: ResourceSnapshot,
86 hooks: Arc<Hooks>,
87 hook_warnings: Arc<Vec<HookWarning>>,
88}
89
90impl ResourceHandle {
91 #[must_use]
93 pub fn new(
94 snapshot: ResourceSnapshot,
95 hooks: Arc<Hooks>,
96 hook_warnings: Vec<HookWarning>,
97 ) -> Self {
98 Self {
99 current: Arc::new(ArcSwap::from_pointee(ResourceState {
100 snapshot,
101 hooks,
102 hook_warnings: Arc::new(hook_warnings),
103 })),
104 }
105 }
106
107 #[must_use]
109 pub fn snapshot(&self) -> Arc<ResourceSnapshot> {
110 Arc::new(self.current.load().snapshot.clone())
111 }
112
113 #[must_use]
115 pub fn hooks(&self) -> Arc<Hooks> {
116 self.current.load().hooks.clone()
117 }
118
119 #[must_use]
121 pub fn hook_warnings(&self) -> Arc<Vec<HookWarning>> {
122 self.current.load().hook_warnings.clone()
123 }
124
125 pub fn replace(
127 &self,
128 snapshot: ResourceSnapshot,
129 hooks: Arc<Hooks>,
130 hook_warnings: Vec<HookWarning>,
131 ) {
132 info!(revision = %snapshot.revision, "replaced resource snapshot");
133 self.current.store(Arc::new(ResourceState {
134 snapshot,
135 hooks,
136 hook_warnings: Arc::new(hook_warnings),
137 }));
138 }
139}
140
/// Parameters for creating a session via `SessionRuntime::new_session`.
#[derive(Debug, Clone)]
pub struct SessionInit {
    /// Explicit session id. NOTE(review): `None` presumably lets the runtime pick one —
    /// confirm in `create_session_seeded`.
    pub session_id: Option<SessionId>,
    /// Id of the parent session, if this session is a child of another one.
    pub parent_session_id: Option<SessionId>,
    /// Working directory of the session; defaults to the process's current directory,
    /// or `"."` when that cannot be determined.
    pub working_dir: PathBuf,
    /// Initial system prompt segments; defaults to the crate's default system prompt segment.
    pub system_prompt_seed: Vec<PromptSegment>,
    /// Optional turn limit. NOTE(review): looks like it bounds provider iterations within
    /// a turn (see `ensure_provider_iteration_allowed`) — confirm.
    pub max_turns: Option<u32>,
    /// Default model for the session, if overridden.
    pub default_model: Option<ModelId>,
    /// Model for subagents, if overridden.
    pub subagent_model: Option<ModelId>,
    /// Subagent event forwarding mode, if overridden for this session.
    pub subagent_event_forwarding: Option<SubagentEventForwarding>,
    /// Subagent nesting depth; `0` by default.
    pub subagent_depth: u32,
}
154
155impl Default for SessionInit {
156 fn default() -> Self {
157 Self {
158 session_id: None,
159 parent_session_id: None,
160 working_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
161 system_prompt_seed: vec![crate::prompt::default_system_prompt_segment()],
162 max_turns: None,
163 default_model: None,
164 subagent_model: None,
165 subagent_event_forwarding: None,
166 subagent_depth: 0,
167 }
168 }
169}
170
171impl SessionInit {
172 #[must_use]
174 pub fn with_default_model(mut self, model: impl Into<ModelId>) -> Self {
175 self.default_model = Some(model.into());
176 self
177 }
178
179 #[must_use]
181 pub fn with_subagent_model(mut self, model: impl Into<ModelId>) -> Self {
182 self.subagent_model = Some(model.into());
183 self
184 }
185
186 #[must_use]
188 pub fn with_subagent_event_forwarding(mut self, mode: SubagentEventForwarding) -> Self {
189 self.subagent_event_forwarding = Some(mode);
190 self
191 }
192}
193
/// Drop guard that evicts a session's hooks (via `evict_session_hooks`) when dropped.
struct EvictionGuard {
    /// Services passed to `evict_session_hooks`.
    services: Arc<RuntimeServices>,
    /// Session whose hooks are evicted.
    session_id: SessionId,
}
211
impl Drop for EvictionGuard {
    /// Evicts the hooks of the guarded session.
    fn drop(&mut self) {
        evict_session_hooks(&self.services, &self.session_id);
    }
}
217
/// Cloneable handle to one session of the runtime.
#[derive(Clone)]
pub struct SessionHandle {
    /// Shared runtime services.
    services: Arc<RuntimeServices>,
    /// Id of the session this handle refers to.
    session_id: SessionId,
    /// Hooks resolved for this session when the handle was created.
    session_hooks: Arc<Hooks>,
    // Never read: held only so that the guard's `Drop` runs once the last clone of this
    // handle (all clones share the same `Arc`) goes away.
    #[allow(dead_code)]
    eviction: Arc<EvictionGuard>,
}
231
/// Alias for [`SessionHandle`]; this is the name used in the `SessionRuntime` API.
pub type HalterSession = SessionHandle;
235
/// Bounded accumulator for tool runtime events, filled by `SessionToolEventSink`.
#[derive(Debug, Default)]
struct ToolEventBuffer {
    /// Buffered events in emission order.
    events: Vec<ToolRuntimeEvent>,
    /// Running byte total of the buffered events (as measured by
    /// `tool_runtime_event_bytes`) plus the length of the truncation notice, if any.
    bytes: usize,
    /// Set once the truncation notice has been appended, so it is only added once.
    truncated: bool,
}
242
/// `ToolEventSink` implementation that collects events into a shared `ToolEventBuffer`.
struct SessionToolEventSink {
    /// Buffer shared with whoever drains the events (see `ToolEventDrain`).
    buffer: Arc<Mutex<ToolEventBuffer>>,
}
246
247impl ToolEventSink for SessionToolEventSink {
248 fn emit(&self, event: ToolRuntimeEvent) {
249 let mut buffer = self
250 .buffer
251 .lock()
252 .unwrap_or_else(|poisoned| poisoned.into_inner());
253 let event_bytes = tool_runtime_event_bytes(&event);
254 let over_count = buffer.events.len() >= TOOL_RUNTIME_EVENT_CAP;
255 let over_bytes = buffer.bytes.saturating_add(event_bytes) > TOOL_RUNTIME_EVENT_BYTES_CAP;
256 if over_count || over_bytes {
257 if !buffer.truncated {
258 let tool_name = tool_runtime_event_tool_name(&event).to_owned();
259 let chunk = format!(
260 "\n[tool output truncated after {} bytes]\n",
261 TOOL_RUNTIME_EVENT_BYTES_CAP
262 );
263 buffer.bytes = buffer.bytes.saturating_add(chunk.len());
264 buffer
265 .events
266 .push(ToolRuntimeEvent::ToolOutput { tool_name, chunk });
267 buffer.truncated = true;
268 }
269 return;
270 }
271 buffer.bytes = buffer.bytes.saturating_add(event_bytes);
272 buffer.events.push(event);
273 }
274}
275
/// Consuming reader for a shared `ToolEventBuffer`.
struct ToolEventDrain {
    /// Buffer whose events are taken by `into_events`.
    buffer: Arc<Mutex<ToolEventBuffer>>,
}
279
280impl ToolEventDrain {
281 fn into_events(self) -> Vec<ToolRuntimeEvent> {
282 self.buffer
283 .lock()
284 .unwrap_or_else(|poisoned| poisoned.into_inner())
285 .events
286 .drain(..)
287 .collect()
288 }
289}
290
/// Sending side of the event stream returned for a submitted turn.
///
/// Carries the turn's own events as well as events forwarded to it through
/// `ParentStreamRegistry`; the forwarded ones are subject to an optional cap.
#[derive(Clone)]
struct LiveTurnStream {
    /// Channel feeding the stream handed back to the caller of `submit_turn`.
    tx: mpsc::UnboundedSender<anyhow::Result<SessionEvent>>,
    /// Maximum number of forwarded events; `None` means unlimited.
    forwarded_event_cap: Option<u64>,
    /// Forwarding counters; behind an `Arc`, so clones of this stream share them.
    forwarded_state: Arc<Mutex<ForwardedEventState>>,
}
301
/// Bookkeeping for events forwarded into a `LiveTurnStream`.
#[derive(Debug, Default)]
struct ForwardedEventState {
    /// Number of forwarded events sent so far.
    forwarded_events: u64,
    /// Set once the cap was hit; from then on forwarded events are dropped.
    capped: bool,
}
307
308impl LiveTurnStream {
309 fn new(tx: mpsc::UnboundedSender<anyhow::Result<SessionEvent>>, cap: u64) -> Self {
310 Self {
311 tx,
312 forwarded_event_cap: (cap > 0).then_some(cap),
313 forwarded_state: Arc::new(Mutex::new(ForwardedEventState::default())),
314 }
315 }
316
317 fn emit_committed(&self, event: SessionEvent) {
318 let _ = self.tx.send(Ok(event));
319 }
320
321 fn emit_forwarded(&self, event: SessionEvent) {
322 let should_send_lagged = {
323 let mut state = self
324 .forwarded_state
325 .lock()
326 .unwrap_or_else(|poisoned| poisoned.into_inner());
327 if state.capped {
328 return;
329 }
330 if let Some(cap) = self.forwarded_event_cap
331 && state.forwarded_events >= cap
332 {
333 state.capped = true;
334 true
335 } else {
336 state.forwarded_events = state.forwarded_events.saturating_add(1);
337 false
338 }
339 };
340
341 if should_send_lagged {
342 let _ = self.tx.send(Ok(forwarding_lagged_event()));
343 return;
344 }
345
346 let _ = self.tx.send(Ok(event));
347 }
348
349 fn emit_error(&self, error: anyhow::Error) {
350 let _ = self.tx.send(Err(error));
351 }
352}
353
354fn forwarding_lagged_event() -> SessionEvent {
355 PendingEvent::new(
356 SessionId::from(crate::event_bus::BUS_SESSION_ID),
357 Delivery::BestEffort,
358 SessionEventPayload::Lagged { dropped_events: 1 },
359 )
360 .into_committed(0)
361}
362
/// Registry of live turn streams, keyed by the session that owns each stream.
///
/// Used to forward an event to the streams registered for a list of ancestor sessions.
#[derive(Default)]
pub struct ParentStreamRegistry {
    /// Streams per session. Held weakly, so the registry does not keep a stream alive;
    /// dead entries are pruned when forwarding or deregistering.
    active: Mutex<HashMap<SessionId, Vec<Weak<LiveTurnStream>>>>,
}
368
369impl ParentStreamRegistry {
370 fn register(
371 self: &Arc<Self>,
372 session_id: SessionId,
373 stream: &Arc<LiveTurnStream>,
374 ) -> ParentStreamRegistration {
375 let weak = Arc::downgrade(stream);
376 let mut active = self
377 .active
378 .lock()
379 .unwrap_or_else(|poisoned| poisoned.into_inner());
380 active
381 .entry(session_id.clone())
382 .or_default()
383 .push(weak.clone());
384 ParentStreamRegistration {
385 registry: self.clone(),
386 session_id,
387 stream: weak,
388 }
389 }
390
391 fn forward_to_ancestors(&self, ancestors: &[SessionId], event: &SessionEvent) {
392 if ancestors.is_empty() {
393 return;
394 }
395
396 let streams = {
397 let mut active = self
398 .active
399 .lock()
400 .unwrap_or_else(|poisoned| poisoned.into_inner());
401 let mut streams = Vec::new();
402 let mut empty_keys = Vec::new();
403 for ancestor in ancestors {
404 if let Some(entries) = active.get_mut(ancestor) {
405 entries.retain(|entry| {
406 if let Some(stream) = entry.upgrade() {
407 streams.push(stream);
408 true
409 } else {
410 false
411 }
412 });
413 if entries.is_empty() {
414 empty_keys.push(ancestor.clone());
415 }
416 }
417 }
418 for key in empty_keys {
419 active.remove(&key);
420 }
421 streams
422 };
423
424 for stream in streams {
425 stream.emit_forwarded(event.clone());
426 }
427 }
428
429 fn deregister(&self, session_id: &SessionId, stream: &Weak<LiveTurnStream>) {
430 let mut active = self
431 .active
432 .lock()
433 .unwrap_or_else(|poisoned| poisoned.into_inner());
434 if let Some(entries) = active.get_mut(session_id) {
435 entries.retain(|entry| !Weak::ptr_eq(entry, stream) && entry.strong_count() > 0);
436 if entries.is_empty() {
437 active.remove(session_id);
438 }
439 }
440 }
441}
442
/// Drop guard returned by `ParentStreamRegistry::register`; deregisters the stream when
/// dropped.
struct ParentStreamRegistration {
    /// Registry the stream was registered with.
    registry: Arc<ParentStreamRegistry>,
    /// Session the stream was registered under.
    session_id: SessionId,
    /// The registered stream, used to identify the entry to remove.
    stream: Weak<LiveTurnStream>,
}
448
impl Drop for ParentStreamRegistration {
    /// Removes this registration's stream from the registry.
    fn drop(&mut self) {
        self.registry.deregister(&self.session_id, &self.stream);
    }
}
454
455fn track_fired_hook_ids(fired_hook_ids: &mut BTreeSet<String>, dispatch: &ExecutedHookDispatch) {
456 for fired_hook_id in &dispatch.fired_hook_ids {
457 fired_hook_ids.insert(fired_hook_id.clone());
458 }
459}
460
/// Entry point for creating, resuming, and listing sessions on top of shared services.
#[derive(Clone)]
pub struct SessionRuntime {
    /// Services shared with every session created through this runtime.
    services: Arc<RuntimeServices>,
    /// Subagent control built from the same services.
    subagents: Arc<dyn SubagentControl>,
}
467
468impl SessionRuntime {
469 #[must_use]
471 pub fn new(services: Arc<RuntimeServices>) -> Self {
472 let subagents: Arc<dyn SubagentControl> = Arc::new(
473 crate::subagents::RuntimeSubagentControl::new(services.clone()),
474 );
475 Self {
476 services,
477 subagents,
478 }
479 }
480
481 #[must_use]
483 pub fn subagent_control(&self) -> Arc<dyn SubagentControl> {
484 self.subagents.clone()
485 }
486
487 pub async fn new_session(&self, init: SessionInit) -> anyhow::Result<HalterSession> {
489 debug!(
490 working_dir = %init.working_dir.display(),
491 parent_session_id = ?init.parent_session_id,
492 max_turns = ?init.max_turns,
493 default_model = ?init.default_model,
494 subagent_model = ?init.subagent_model,
495 subagent_depth = init.subagent_depth,
496 "creating session"
497 );
498 create_session_seeded(
499 self.services.clone(),
500 init,
501 SessionState::default(),
502 self.services.resources.snapshot(),
503 )
504 .await
505 }
506
507 pub async fn resume(&self, session_id: &SessionId) -> anyhow::Result<Option<HalterSession>> {
509 let existing = self.services.sessions.load_session(session_id).await?;
510 debug!(session_id = %session_id, found = existing.is_some(), "resuming session");
511 if let Some(mut stored) = existing {
512 let expected_state = stored.state.clone();
513 stored.state.pending_session_start_source = Some(HookSessionStartSource::Resume);
514 self.services
515 .sessions
516 .commit(
517 session_id,
518 None,
519 Some(expected_state),
520 Some(stored.state),
521 Vec::new(),
522 )
523 .await?;
524 return Ok(Some(HalterSession::new(
525 self.services.clone(),
526 session_id.clone(),
527 )?));
528 }
529 Ok(None)
530 }
531
532 pub async fn list_sessions(&self) -> anyhow::Result<Vec<SessionBlueprint>> {
534 let sessions = self.services.sessions.list_sessions().await?;
535 debug!(session_count = sessions.len(), "listed sessions");
536 Ok(sessions)
537 }
538
539 pub fn replace_resources(
541 &self,
542 snapshot: ResourceSnapshot,
543 hooks: Arc<Hooks>,
544 hook_warnings: Vec<HookWarning>,
545 ) {
546 self.services
547 .resources
548 .replace(snapshot, hooks, hook_warnings);
549 }
550
551 pub async fn shutdown(&self, drain: std::time::Duration) -> crate::ShutdownReport {
559 let report = self.services.turn_registry.shutdown(drain).await;
560 info!(
561 drained = report.turns_drained,
562 aborted = report.turns_aborted,
563 timed_out = report.timed_out,
564 drain_ms = %drain.as_millis(),
565 "runtime shutdown"
566 );
567 report
568 }
569}
570
571impl SessionHandle {
572 pub(crate) fn new(
573 services: Arc<RuntimeServices>,
574 session_id: SessionId,
575 ) -> anyhow::Result<Self> {
576 let session_hooks = lookup_or_create_session_hooks(&services, &session_id)?;
577 let eviction = Arc::new(EvictionGuard {
578 services: services.clone(),
579 session_id: session_id.clone(),
580 });
581 Ok(Self {
582 services,
583 session_id,
584 session_hooks,
585 eviction,
586 })
587 }
588
    /// Returns the id of the session this handle refers to.
    #[must_use]
    pub fn session_id(&self) -> &SessionId {
        &self.session_id
    }
594
    /// Returns the runtime services this handle operates on.
    pub(crate) fn services(&self) -> &Arc<RuntimeServices> {
        &self.services
    }
598
    /// Returns the hooks that were resolved for this session when the handle was created.
    pub(crate) fn session_hooks(&self) -> &Arc<Hooks> {
        &self.session_hooks
    }
602
603 pub async fn submit_turn(&self, turn: Turn) -> anyhow::Result<SessionEventStream> {
605 self.submit_turn_with_cancel(turn, CancellationToken::new())
606 .await
607 }
608
    /// Starts `turn` on a spawned Tokio task and returns the stream of events it emits.
    ///
    /// The stored session is loaded first and the call is refused when the turn registry
    /// reports that the runtime is shutting down. The spawned task commits a `TurnStarted`
    /// event, runs the turn, and then commits either the turn's result or a `TurnFailed`
    /// event. The task is registered with the turn registry together with `turn_cancel`.
    ///
    /// # Errors
    /// Returns an error when the session cannot be loaded or does not exist, when the
    /// runtime is shutting down, or when the turn cannot be registered. Failures that
    /// happen inside the spawned task are delivered as `Err` items on the returned stream.
    pub(crate) async fn submit_turn_with_cancel(
        &self,
        turn: Turn,
        turn_cancel: CancellationToken,
    ) -> anyhow::Result<SessionEventStream> {
        info!(
            session_id = %self.session_id,
            turn_id = %turn.id,
            user_part_count = turn.user_message.parts.len(),
            "submitting turn"
        );
        let stored = self
            .services
            .sessions
            .load_session(&self.session_id)
            .await?
            .with_context(|| {
                format!(
                    "failed to submit turn: unknown session '{}'",
                    self.session_id.0
                )
            })?;
        // Refuse new work once shutdown has begun.
        if self.services.turn_registry.is_shutting_down() {
            anyhow::bail!(
                "failed to submit turn '{}': runtime is shutting down",
                turn.id
            );
        }

        // `tx` feeds the stream returned to the caller; `live` is the task-side sender.
        let (tx, rx) = mpsc::unbounded_channel();
        let live = Arc::new(LiveTurnStream::new(
            tx,
            self.services.subagent_event_forwarding_cap,
        ));
        // Only sessions with forwarding enabled expose their live stream to the
        // parent-stream registry.
        let parent_stream_registration = stored
            .blueprint
            .subagent_event_forwarding
            .is_enabled()
            .then(|| {
                self.services
                    .parent_streams
                    .register(self.session_id.clone(), &live)
            });
        let session = self.clone();

        let registry = session.services.turn_registry.clone();
        let turn_id_for_dereg = turn.id.clone();
        let turn_id_for_register = turn.id.clone();
        let task_cancel = turn_cancel.clone();
        let task_cancel_status = turn_cancel.clone();
        let handle = tokio::spawn(async move {
            // Deregisters the turn when the task's future is dropped, whichever way the
            // task ends.
            struct DeregisterOnDrop {
                registry: Arc<TurnRegistry>,
                turn_id: TurnId,
            }
            impl Drop for DeregisterOnDrop {
                fn drop(&mut self) {
                    self.registry.deregister(&self.turn_id);
                }
            }
            let _guard = DeregisterOnDrop {
                registry: registry.clone(),
                turn_id: turn_id_for_dereg,
            };
            // Moved into the task so the stream stays registered for the task's lifetime.
            let _parent_stream_registration = parent_stream_registration;

            // Commit `TurnStarted` on its own, before any turn work happens.
            let expected_state = stored.state.clone();
            let started = session.make_event(SessionEventPayload::TurnStarted {
                turn_id: turn.id.clone(),
            });
            if let Err(error) = session
                .commit_and_publish(
                    None,
                    Some(expected_state),
                    None,
                    vec![started],
                    Some(live.as_ref()),
                )
                .await
            {
                error!(
                    session_id = %session.session_id,
                    turn_id = %turn.id,
                    error = %error,
                    "failed to commit turn start"
                );
                live.emit_error(error);
                return;
            }

            match session
                .run_turn(stored, turn.clone(), task_cancel, live.as_ref())
                .await
            {
                Ok(turn_commit) => {
                    // Commit whatever the turn left to commit, together with its final state.
                    if let Err(error) = session
                        .commit_and_publish(
                            Some(turn_commit.snapshot),
                            Some(turn_commit.expected_state),
                            Some(turn_commit.state),
                            turn_commit.events,
                            Some(live.as_ref()),
                        )
                        .await
                    {
                        error!(
                            session_id = %session.session_id,
                            turn_id = %turn.id,
                            error = %error,
                            "failed to commit successful turn"
                        );
                        live.emit_error(error);
                    }
                }
                Err(error) => {
                    // Classify the failure: retryable only if the provider error says so;
                    // cancelled if the turn's token fired or the provider reports it.
                    let provider_error = error.downcast_ref::<ProviderError>();
                    let retryable = provider_error
                        .map(|provider_error| provider_error.retryable)
                        .unwrap_or(false);
                    let cancelled = task_cancel_status.is_cancelled()
                        || provider_error.is_some_and(ProviderError::is_cancelled);
                    error!(
                        session_id = %session.session_id,
                        turn_id = %turn.id,
                        error = %error,
                        retryable,
                        cancelled,
                        "turn failed before commit"
                    );
                    let failure_events =
                        vec![session.make_event(SessionEventPayload::TurnFailed {
                            turn_id: turn.id.clone(),
                            error: error.to_string(),
                            cancelled,
                            retryable,
                        })];
                    if let Err(commit_error) = session
                        .commit_turn_failure(failure_events, live.as_ref())
                        .await
                    {
                        error!(
                            session_id = %session.session_id,
                            turn_id = %turn.id,
                            error = %commit_error,
                            "failed to commit failed turn"
                        );
                        live.emit_error(commit_error);
                    }
                }
            }
        });

        // NOTE(review): the task has already been spawned at this point. Whether
        // `register` aborts or cancels it when registration fails is not visible here —
        // confirm, otherwise the turn keeps running with its receiver dropped.
        if let Err(register_error) =
            self.services
                .turn_registry
                .register(turn_id_for_register, turn_cancel, handle)
        {
            anyhow::bail!("failed to register turn: {register_error}");
        }

        Ok(UnboundedReceiverStream::new(rx).boxed())
    }
779
    /// Returns this session's events as replayed by the session store.
    ///
    /// # Errors
    /// Propagates any error from the session store.
    pub async fn replay(&self) -> anyhow::Result<Vec<SessionEvent>> {
        self.services.sessions.replay(&self.session_id).await
    }
784
785 pub async fn shutdown(&self, reason: &str) -> anyhow::Result<()> {
787 let stored = self
788 .services
789 .sessions
790 .load_session(&self.session_id)
791 .await?
792 .with_context(|| {
793 format!(
794 "failed to shut down session: unknown session '{}'",
795 self.session_id.0
796 )
797 })?;
798 let expected_state = stored.state.clone();
799 let mut state = stored.state;
800 let mut events = Vec::new();
801 let turn_id = TurnId::new();
802 let hook_ctx = HookInvocationContext {
803 turn_id: &turn_id,
804 model: &stored.blueprint.default_model,
805 working_dir: &stored.blueprint.working_dir,
806 };
807 let fired_hook_ids = state
808 .fired_hook_ids
809 .iter()
810 .cloned()
811 .collect::<BTreeSet<_>>();
812 let dispatch = run_session_end(self, &fired_hook_ids, hook_ctx, reason).await?;
813 self.record_hook_dispatch(&mut events, &dispatch);
814 if dispatch.merged.block_reason.is_some() || dispatch.merged.stop_reason.is_some() {
815 warn!(session_id = %self.session_id, reason, "hooks.ignored_block");
816 }
817 for message in apply_hook_side_effects(&mut state, &dispatch) {
818 self.push_event(&mut events, SessionEventPayload::MessageItem { message });
819 }
820 self.push_event(&mut events, SessionEventPayload::SessionShutdownComplete);
821 let _ = self
822 .commit_and_publish(None, Some(expected_state), Some(state), events, None)
823 .await?;
824 evict_session_hooks(&self.services, &self.session_id);
825 Ok(())
826 }
827
828 pub async fn notify(&self, notification_type: &str, message: &str) -> anyhow::Result<()> {
830 let stored = self
831 .services
832 .sessions
833 .load_session(&self.session_id)
834 .await?
835 .with_context(|| {
836 format!(
837 "failed to emit notification: unknown session '{}'",
838 self.session_id.0
839 )
840 })?;
841 let expected_state = stored.state.clone();
842 let mut state = stored.state;
843 let turn_id = TurnId::new();
844 let hook_ctx = HookInvocationContext {
845 turn_id: &turn_id,
846 model: &stored.blueprint.default_model,
847 working_dir: &stored.blueprint.working_dir,
848 };
849 let fired_hook_ids = state
850 .fired_hook_ids
851 .iter()
852 .cloned()
853 .collect::<BTreeSet<_>>();
854 let dispatch =
855 run_notification(self, &fired_hook_ids, hook_ctx, notification_type, message).await?;
856 let mut events = Vec::new();
857 self.record_hook_dispatch(&mut events, &dispatch);
858 for message in apply_hook_side_effects(&mut state, &dispatch) {
859 self.push_event(&mut events, SessionEventPayload::MessageItem { message });
860 }
861 let _ = self
862 .commit_and_publish(None, Some(expected_state), Some(state), events, None)
863 .await?;
864 Ok(())
865 }
866
867 pub async fn compact(
869 &self,
870 trigger: &str,
871 custom_instructions: Option<&str>,
872 ) -> anyhow::Result<()> {
873 let stored = self
874 .services
875 .sessions
876 .load_session(&self.session_id)
877 .await?
878 .with_context(|| {
879 format!(
880 "failed to compact session: unknown session '{}'",
881 self.session_id.0
882 )
883 })?;
884 let expected_state = stored.state.clone();
885 let mut state = stored.state;
886 let mut events = Vec::new();
887 let turn_id = TurnId::new();
888 let hook_ctx = HookInvocationContext {
889 turn_id: &turn_id,
890 model: &stored.blueprint.default_model,
891 working_dir: &stored.blueprint.working_dir,
892 };
893 let mut fired_hook_ids = state
894 .fired_hook_ids
895 .iter()
896 .cloned()
897 .collect::<BTreeSet<_>>();
898 let pre_dispatch = run_pre_compact(
899 self,
900 &fired_hook_ids,
901 hook_ctx,
902 trigger,
903 custom_instructions,
904 )
905 .await?;
906 track_fired_hook_ids(&mut fired_hook_ids, &pre_dispatch);
907 self.record_hook_dispatch(&mut events, &pre_dispatch);
908 for message in apply_hook_side_effects(&mut state, &pre_dispatch) {
909 self.push_event(&mut events, SessionEventPayload::MessageItem { message });
910 }
911 if pre_dispatch.merged.block_reason.is_some() {
912 let _ = self
913 .commit_and_publish(None, Some(expected_state), Some(state), events, None)
914 .await?;
915 return Ok(());
916 }
917
918 let observed = observe_state(stored.blueprint.working_dir.clone());
919 let compaction_model = self
920 .services
921 .models
922 .model(&stored.blueprint.default_model)?;
923 let compaction_provider = self.services.models.provider(&compaction_model.provider)?;
924 let outcome = self
925 .services
926 .context_manager
927 .compact_now(
928 &stored.blueprint,
929 &state,
930 &observed,
931 stored.snapshot.as_ref(),
932 &self.services.tools.specs(),
933 &compaction_model,
934 compaction_provider.as_ref(),
935 custom_instructions,
936 )
937 .await?;
938 let summary = match outcome.apply(&mut state) {
939 Some(result) => result.summary,
940 None => "No compaction needed.".to_owned(),
941 };
942 self.push_event(
943 &mut events,
944 SessionEventPayload::ContextCompacted {
945 summary: summary.clone(),
946 },
947 );
948
949 let post_dispatch =
950 run_post_compact(self, &fired_hook_ids, hook_ctx, trigger, &summary).await?;
951 self.record_hook_dispatch(&mut events, &post_dispatch);
952 for message in apply_hook_side_effects(&mut state, &post_dispatch) {
953 self.push_event(&mut events, SessionEventPayload::MessageItem { message });
954 }
955
956 let _ = self
957 .commit_and_publish(None, Some(expected_state), Some(state), events, None)
958 .await?;
959 Ok(())
960 }
961
962 async fn run_turn(
963 &self,
964 stored: StoredSession,
965 turn: Turn,
966 turn_cancel: CancellationToken,
967 live: &LiveTurnStream,
968 ) -> anyhow::Result<TurnCommit> {
969 let snapshot = self.services.resources.snapshot();
970 let mut expected_state = stored.state.clone();
971 let mut state = stored.state;
972 let mut events = Vec::new();
973 let mut turn_usage = Usage::default();
974 let mut provider_iterations = 0u32;
975 let mut fired_hook_ids = state
976 .fired_hook_ids
977 .iter()
978 .cloned()
979 .collect::<BTreeSet<_>>();
980 let hook_model = turn
981 .default_model
982 .clone()
983 .unwrap_or_else(|| stored.blueprint.default_model.clone());
984 let hook_ctx = HookInvocationContext {
985 turn_id: &turn.id,
986 model: &hook_model,
987 working_dir: &stored.blueprint.working_dir,
988 };
989
990 for warning in std::mem::take(&mut state.pending_warning_messages) {
991 self.push_event(
992 &mut events,
993 SessionEventPayload::Warning {
994 message: format_hook_warning(&warning),
995 },
996 );
997 }
998
999 if let Some(source) = state.pending_session_start_source.take() {
1000 let hook_dispatch = run_session_start(self, &fired_hook_ids, hook_ctx, source).await?;
1001 track_fired_hook_ids(&mut fired_hook_ids, &hook_dispatch);
1002 self.record_hook_dispatch(&mut events, &hook_dispatch);
1003 if hook_dispatch.merged.block_reason.is_some()
1004 || hook_dispatch.merged.stop_reason.is_some()
1005 {
1006 warn!(
1007 session_id = %self.session_id,
1008 turn_id = %turn.id,
1009 "hooks.ignored_block"
1010 );
1011 }
1012 for message in apply_hook_side_effects(&mut state, &hook_dispatch) {
1013 self.push_event(&mut events, SessionEventPayload::MessageItem { message });
1014 }
1015 }
1016
1017 let prompt_dispatch = run_user_prompt_submit(
1018 self,
1019 &fired_hook_ids,
1020 hook_ctx,
1021 &turn.user_message.plain_text(),
1022 )
1023 .await?;
1024 track_fired_hook_ids(&mut fired_hook_ids, &prompt_dispatch);
1025 self.record_hook_dispatch(&mut events, &prompt_dispatch);
1026 for message in apply_hook_side_effects(&mut state, &prompt_dispatch) {
1027 self.push_event(&mut events, SessionEventPayload::MessageItem { message });
1028 }
1029
1030 if let Some(reason) = prompt_dispatch
1031 .merged
1032 .stop_reason
1033 .clone()
1034 .or_else(|| prompt_dispatch.merged.block_reason.clone())
1035 {
1036 let blocked = Message::System(SystemMessage {
1037 id: MessageId::new(),
1038 created_at: Utc::now(),
1039 text: reason,
1040 });
1041 state.messages.push(blocked.clone());
1042 self.push_event(
1043 &mut events,
1044 SessionEventPayload::MessageItem { message: blocked },
1045 );
1046 events.push(self.make_event(SessionEventPayload::TurnCompleted {
1047 turn_id: turn.id,
1048 usage: turn_usage,
1049 }));
1050 return Ok(TurnCommit {
1051 expected_state,
1052 snapshot,
1053 state,
1054 events,
1055 });
1056 }
1057
1058 let user_message = Message::User(turn.user_message.clone());
1059 state.messages.push(user_message.clone());
1060 self.push_event(
1061 &mut events,
1062 SessionEventPayload::MessageItem {
1063 message: user_message,
1064 },
1065 );
1066 self.flush_turn_progress(
1067 snapshot.clone(),
1068 &mut expected_state,
1069 &state,
1070 &mut events,
1071 live,
1072 )
1073 .await?;
1074
1075 loop {
1076 ensure_provider_iteration_allowed(stored.blueprint.max_turns, provider_iterations)?;
1077 provider_iterations = provider_iterations.saturating_add(1);
1078
1079 let compaction_model = self
1080 .services
1081 .models
1082 .model(&stored.blueprint.default_model)?;
1083 let compaction_provider = self.services.models.provider(&compaction_model.provider)?;
1084 let observed = observe_state(stored.blueprint.working_dir.clone());
1085 let plan = self
1086 .services
1087 .context_manager
1088 .plan(
1089 &stored.blueprint,
1090 &state,
1091 &observed,
1092 snapshot.as_ref(),
1093 &self.services.tools.specs(),
1094 &compaction_model,
1095 compaction_provider.as_ref(),
1096 )
1097 .await?;
1098
1099 let plan_outcome = crate::CompactionOutcome {
1100 messages: plan.messages.clone(),
1101 compacted_prefix: plan.compacted_prefix.clone(),
1102 compaction: plan.compaction.clone(),
1103 session_start_latch: None,
1104 };
1105 if let Some(result) = plan_outcome.apply(&mut state) {
1106 self.push_event(
1107 &mut events,
1108 SessionEventPayload::ContextCompacted {
1109 summary: result.summary,
1110 },
1111 );
1112 }
1113
1114 let prompt = self.services.prompt_assembler.assemble(&plan).await?;
1115
1116 let selected_models = select_models(
1117 &stored.blueprint.default_model,
1118 &stored.blueprint.subagent_model,
1119 turn.default_model.as_ref(),
1120 turn.subagent_model.as_ref(),
1121 );
1122 let model = self.services.models.model(&selected_models.default_model)?;
1123 let subagent_model = self
1124 .services
1125 .models
1126 .model(&selected_models.subagent_model)?;
1127 let provider = self.services.models.provider(&model.provider)?;
1128 let request = ProviderRequest {
1129 session_id: self.session_id.clone(),
1130 turn_id: turn.id.clone(),
1131 model: model.clone(),
1132 prompt,
1133 compacted_prefix: plan.compacted_prefix.clone(),
1134 messages: plan.messages.clone(),
1135 tools: plan.tool_specs.clone(),
1136 previous_response_id: plan.previous_response_id.clone(),
1137 new_messages_start: plan.new_messages_start,
1138 };
1139
1140 let provider_stream = provider.stream(request, turn_cancel.child_token()).await?;
1141 let mut materialized = materialize_assistant_message(provider_stream, &model).await?;
1142 accumulate_usage(&mut state.usage_so_far, &materialized.usage);
1143 accumulate_usage(&mut turn_usage, &materialized.usage);
1144 debug!(
1145 session_id = %self.session_id,
1146 turn_id = %turn.id,
1147 model_id = %model.id,
1148 subagent_model_id = %subagent_model.id,
1149 assistant_part_count = materialized.message.parts.len(),
1150 stop_reason = ?materialized.message.stop_reason,
1151 input_tokens = materialized.usage.input_tokens,
1152 output_tokens = materialized.usage.output_tokens,
1153 "materialized assistant message"
1154 );
1155
1156 let (deduped_parts, duplicate_tool_calls) =
1157 dedupe_assistant_tool_call_parts(std::mem::take(&mut materialized.message.parts));
1158 if duplicate_tool_calls > 0 {
1159 warn!(
1160 session_id = %self.session_id,
1161 turn_id = %turn.id,
1162 duplicate_tool_call_count = duplicate_tool_calls,
1163 "deduped duplicate tool calls from provider output"
1164 );
1165 materialized.message.parts = deduped_parts;
1166 } else {
1167 materialized.message.parts = deduped_parts;
1168 }
1169
1170 let assistant_message = Message::Assistant(materialized.message.clone());
1171 state.messages.push(assistant_message.clone());
1172
1173 if let Some(ref resp_id) = materialized.response_id {
1175 state.last_response_id = Some(resp_id.clone());
1176 state.messages_seen_by_provider = state.messages.len();
1177 }
1178
1179 for payload in materialized.events {
1180 self.push_event(&mut events, payload);
1181 }
1182 self.push_event(
1183 &mut events,
1184 SessionEventPayload::MessageItem {
1185 message: assistant_message,
1186 },
1187 );
1188
1189 let tool_calls = assistant_tool_calls(&materialized.message);
1190 if tool_calls.is_empty() {
1191 let stop_dispatch = run_stop(
1192 self,
1193 &fired_hook_ids,
1194 HookInvocationContext {
1195 turn_id: &turn.id,
1196 model: &model.id,
1197 working_dir: &stored.blueprint.working_dir,
1198 },
1199 Some(&materialized.message),
1200 true,
1201 )
1202 .await?;
1203 track_fired_hook_ids(&mut fired_hook_ids, &stop_dispatch);
1204 self.record_hook_dispatch(&mut events, &stop_dispatch);
1205 for message in apply_hook_side_effects(&mut state, &stop_dispatch) {
1206 self.push_event(&mut events, SessionEventPayload::MessageItem { message });
1207 }
1208 if let Some(reason) = stop_dispatch.merged.block_reason.clone() {
1209 let continuation = Message::User(halter_protocol::UserMessage::text(
1210 if reason.trim().is_empty() {
1211 "Continue."
1212 } else {
1213 &reason
1214 },
1215 ));
1216 state.messages.push(continuation.clone());
1217 self.push_event(
1218 &mut events,
1219 SessionEventPayload::MessageItem {
1220 message: continuation,
1221 },
1222 );
1223 self.flush_turn_progress(
1224 snapshot.clone(),
1225 &mut expected_state,
1226 &state,
1227 &mut events,
1228 live,
1229 )
1230 .await?;
1231 continue;
1232 }
1233
1234 info!(
1235 session_id = %self.session_id,
1236 turn_id = %turn.id,
1237 input_tokens = turn_usage.input_tokens,
1238 output_tokens = turn_usage.output_tokens,
1239 "turn completed without tool calls"
1240 );
1241 events.push(self.make_event(SessionEventPayload::TurnCompleted {
1242 turn_id: turn.id,
1243 usage: turn_usage,
1244 }));
1245 return Ok(TurnCommit {
1246 expected_state,
1247 snapshot,
1248 state,
1249 events,
1250 });
1251 }
1252
1253 info!(
1254 session_id = %self.session_id,
1255 turn_id = %turn.id,
1256 tool_call_count = tool_calls.len(),
1257 "assistant requested tool calls"
1258 );
1259
1260 let tool_events = self
1261 .execute_tool_calls(
1262 &stored.blueprint,
1263 snapshot.clone(),
1264 turn_cancel.child_token(),
1265 &selected_models.default_model,
1266 &selected_models.subagent_model,
1267 &turn.id,
1268 &mut fired_hook_ids,
1269 &mut state,
1270 tool_calls,
1271 )
1272 .await?;
1273 events.extend(tool_events);
1274 self.flush_turn_progress(
1275 snapshot.clone(),
1276 &mut expected_state,
1277 &state,
1278 &mut events,
1279 live,
1280 )
1281 .await?;
1282 }
1283 }
1284
/// Executes the tool calls requested by the assistant and returns the events to commit.
///
/// Calls are grouped by `batch_tool_calls_by_concurrency`. Within a batch every call is first
/// prepared sequentially (pre-tool-use hooks, `ToolExecutionStarted` event, tool context), then
/// all prepared calls are awaited together through `join_all`, and finally the results are
/// processed in the prepared order: buffered runtime events are drained, the matching
/// post-tool-use hook runs, and a `ToolResultMessage` is appended to `state.messages`.
///
/// A call whose pre-tool-use hook produced a block reason is never executed; it is recorded as
/// a failed tool result carrying that reason.
///
/// # Errors
///
/// Returns an error when a hook dispatch fails. A failing tool execution is not propagated; it
/// is converted into an error tool result instead.
#[expect(clippy::too_many_arguments)]
async fn execute_tool_calls(
    &self,
    blueprint: &SessionBlueprint,
    snapshot: Arc<ResourceSnapshot>,
    cancel: CancellationToken,
    effective_model: &ModelId,
    effective_subagent_model: &ModelId,
    turn_id: &halter_protocol::TurnId,
    fired_hook_ids: &mut BTreeSet<String>,
    state: &mut SessionState,
    tool_calls: Vec<ToolCall>,
) -> anyhow::Result<Vec<PendingEvent>> {
    let mut events = Vec::new();

    let tools = self.services.tools.clone();
    for batch in
        batch_tool_calls_by_concurrency(|name| tools.concurrency_for(&name.0), tool_calls)
    {
        let mut prepared: Vec<PreparedToolCall> = Vec::with_capacity(batch.len());
        // Phase 1: run pre-tool-use hooks and build the execution context for each call.
        for mut call in batch {
            let pre_dispatch = run_pre_tool_use(
                self,
                fired_hook_ids,
                HookInvocationContext {
                    turn_id,
                    model: effective_model,
                    working_dir: &blueprint.working_dir,
                },
                &call,
            )
            .await?;
            track_fired_hook_ids(fired_hook_ids, &pre_dispatch);
            self.record_hook_dispatch(&mut events, &pre_dispatch);
            for message in apply_hook_side_effects(state, &pre_dispatch) {
                self.push_event(&mut events, SessionEventPayload::MessageItem { message });
            }
            // A hook may rewrite the arguments before the call is announced or executed.
            if let Some(updated_input) = pre_dispatch.merged.updated_input.clone() {
                call.arguments = updated_input;
            }
            info!(
                session_id = %self.session_id,
                tool_call_id = %call.id,
                tool_name = %call.name,
                "executing tool call"
            );
            // The started event is emitted even for calls that a hook blocks below.
            self.push_event(
                &mut events,
                SessionEventPayload::ToolExecutionStarted { call: call.clone() },
            );

            // Blocked by a hook: synthesize a failed result and skip execution entirely.
            if let Some(reason) = pre_dispatch.merged.block_reason.clone() {
                let error = ToolError::new(reason);
                let outcome = ToolExecutionOutcome {
                    call: call.clone(),
                    result: Err(error.clone()),
                };
                let message = Message::Tool(ToolResultMessage {
                    id: MessageId::new(),
                    call_id: call.id.clone(),
                    content: ToolResult::Empty,
                    error: Some(error),
                    created_at: Utc::now(),
                });
                state.messages.push(message.clone());
                self.push_event(
                    &mut events,
                    SessionEventPayload::ToolExecutionCompleted { outcome },
                );
                self.push_event(&mut events, SessionEventPayload::MessageItem { message });
                continue;
            }

            // Each call gets its own event sink/drain pair so output can be attributed to it.
            let (emit, tool_event_drain) = self.spawn_tool_event_sink();
            let context = halter_tools::ToolContext {
                session_id: self.session_id.clone(),
                working_dir: blueprint.working_dir.clone(),
                path_locks: self.services.path_locks.clone(),
                tool_sessions: self.services.tool_sessions.clone(),
                snapshot: snapshot.clone(),
                cancel: cancel.child_token(),
                emit,
                policy: self.services.policy.clone(),
                shell_timeout_secs: self.services.shell_timeout_secs,
                // The parent context carries a clone of the state as it is at preparation time.
                subagent_parent: Some(Arc::new(SubagentParentContext {
                    blueprint: blueprint.clone(),
                    state: state.clone(),
                    snapshot: snapshot.clone(),
                    subagent_model: effective_subagent_model.clone(),
                })),
            };
            // Track the call as pending until its result is recorded in phase 3.
            state.pending_tool_calls.insert(
                call.id.clone(),
                PendingToolCall {
                    call: call.clone(),
                    submitted_at: Utc::now(),
                },
            );

            prepared.push(PreparedToolCall {
                call,
                context,
                tool_event_drain,
            });
        }

        // Phase 2: await every prepared call of the batch together; results keep `prepared` order.
        let tools = self.services.tools.clone();
        let executions: Vec<anyhow::Result<ToolResult>> =
            futures::future::join_all(prepared.iter().map(|p| {
                let tools = tools.clone();
                let context = p.context.clone();
                let args = p.call.arguments.clone();
                let name = p.call.name.0.clone();
                async move { tools.execute(&name, context, args).await }
            }))
            .await;

        // Phase 3: turn each execution result into events and a tool result message.
        for (prep, execution) in prepared.into_iter().zip(executions) {
            let PreparedToolCall {
                call,
                context,
                tool_event_drain,
            } = prep;
            // Release the context (and the sink handle it owns) before draining buffered events.
            drop(context);
            for payload in tool_event_drain
                .into_events()
                .into_iter()
                .filter_map(|event| tool_runtime_event_payload(&call.id, event))
            {
                self.push_event(&mut events, payload);
            }

            let (mut content, error) = match execution {
                Ok(result) => {
                    debug!(
                        session_id = %self.session_id,
                        tool_call_id = %call.id,
                        tool_name = %call.name,
                        result_kind = tool_result_kind(&result),
                        "tool call completed"
                    );
                    (result, None)
                }
                Err(error) => {
                    warn!(
                        session_id = %self.session_id,
                        tool_call_id = %call.id,
                        tool_name = %call.name,
                        error = %error,
                        "tool call failed"
                    );
                    // Failures are surfaced to the model as an empty result plus an error.
                    (ToolResult::Empty, Some(ToolError::new(error.to_string())))
                }
            };
            if error.is_none() {
                let post_dispatch = run_post_tool_use(
                    self,
                    fired_hook_ids,
                    HookInvocationContext {
                        turn_id,
                        model: effective_model,
                        working_dir: &blueprint.working_dir,
                    },
                    &call,
                    &content,
                )
                .await?;
                track_fired_hook_ids(fired_hook_ids, &post_dispatch);
                self.record_hook_dispatch(&mut events, &post_dispatch);
                for message in apply_hook_side_effects(state, &post_dispatch) {
                    self.push_event(&mut events, SessionEventPayload::MessageItem { message });
                }
                // A post-tool-use hook may replace the output that is recorded for the call.
                if let Some(updated_output) = post_dispatch.merged.updated_output {
                    content = tool_result_from_hook_value(updated_output);
                }
            } else if let Some(tool_error) = error.as_ref() {
                let post_dispatch = run_post_tool_use_failure(
                    self,
                    fired_hook_ids,
                    HookInvocationContext {
                        turn_id,
                        model: effective_model,
                        working_dir: &blueprint.working_dir,
                    },
                    &call,
                    tool_error,
                )
                .await?;
                track_fired_hook_ids(fired_hook_ids, &post_dispatch);
                self.record_hook_dispatch(&mut events, &post_dispatch);
                for message in apply_hook_side_effects(state, &post_dispatch) {
                    self.push_event(&mut events, SessionEventPayload::MessageItem { message });
                }
            }
            let outcome = ToolExecutionOutcome {
                call: call.clone(),
                result: error.clone().map_or_else(|| Ok(content.clone()), Err),
            };
            let message = Message::Tool(ToolResultMessage {
                id: MessageId::new(),
                call_id: call.id.clone(),
                content,
                error,
                created_at: Utc::now(),
            });

            // The call is no longer pending once its result message is part of the state.
            state.pending_tool_calls.shift_remove(&call.id);
            state.messages.push(message.clone());
            self.push_event(
                &mut events,
                SessionEventPayload::ToolExecutionCompleted { outcome },
            );
            self.push_event(&mut events, SessionEventPayload::MessageItem { message });
        }
    }

    Ok(events)
}
1508
1509 async fn commit_and_publish(
1510 &self,
1511 snapshot: Option<Arc<ResourceSnapshot>>,
1512 expected_state: Option<SessionState>,
1513 state: Option<SessionState>,
1514 events: Vec<PendingEvent>,
1515 live: Option<&LiveTurnStream>,
1516 ) -> anyhow::Result<Vec<SessionEvent>> {
1517 debug!(
1518 session_id = %self.session_id,
1519 event_count = events.len(),
1520 replace_snapshot = snapshot.is_some(),
1521 check_expected_state = expected_state.is_some(),
1522 replace_state = state.is_some(),
1523 "committing session events"
1524 );
1525 let committed = self
1526 .services
1527 .sessions
1528 .commit(&self.session_id, snapshot, expected_state, state, events)
1529 .await?;
1530 let forwarding_ancestors =
1531 forwarding_ancestors_for_session(&self.services, &self.session_id).await;
1532 for event in &committed {
1537 if let Some(live) = live {
1538 live.emit_committed(event.clone());
1539 }
1540 self.services
1541 .parent_streams
1542 .forward_to_ancestors(&forwarding_ancestors, event);
1543 self.services.event_bus.publish(event.clone());
1544 if let Some(recorder) = &self.services.trace_recorder {
1545 recorder.record(event);
1546 }
1547 }
1548 Ok(committed)
1549 }
1550
1551 async fn flush_turn_progress(
1552 &self,
1553 snapshot: Arc<ResourceSnapshot>,
1554 expected_state: &mut SessionState,
1555 state: &SessionState,
1556 events: &mut Vec<PendingEvent>,
1557 live: &LiveTurnStream,
1558 ) -> anyhow::Result<()> {
1559 if events.is_empty() && state == expected_state {
1560 return Ok(());
1561 }
1562
1563 self.commit_and_publish(
1564 Some(snapshot),
1565 Some(expected_state.clone()),
1566 Some(state.clone()),
1567 events.clone(),
1568 Some(live),
1569 )
1570 .await?;
1571 events.clear();
1572 *expected_state = state.clone();
1573 Ok(())
1574 }
1575
1576 async fn commit_turn_failure(
1577 &self,
1578 failure_events: Vec<PendingEvent>,
1579 live: &LiveTurnStream,
1580 ) -> anyhow::Result<()> {
1581 let stored = self
1582 .services
1583 .sessions
1584 .load_session(&self.session_id)
1585 .await?
1586 .with_context(|| {
1587 format!(
1588 "failed to commit failed turn: unknown session '{}'",
1589 self.session_id.0
1590 )
1591 })?;
1592 self.commit_and_publish(None, Some(stored.state), None, failure_events, Some(live))
1593 .await?;
1594 Ok(())
1595 }
1596
1597 fn spawn_tool_event_sink(&self) -> (Arc<dyn ToolEventSink>, ToolEventDrain) {
1598 let buffer = Arc::new(Mutex::new(ToolEventBuffer::default()));
1599 (
1600 Arc::new(SessionToolEventSink {
1601 buffer: buffer.clone(),
1602 }) as Arc<dyn ToolEventSink>,
1603 ToolEventDrain { buffer },
1604 )
1605 }
1606
1607 fn push_event(&self, events: &mut Vec<PendingEvent>, payload: SessionEventPayload) {
1608 events.push(self.make_event(payload));
1609 }
1610
1611 fn record_hook_dispatch(
1612 &self,
1613 events: &mut Vec<PendingEvent>,
1614 dispatch: &ExecutedHookDispatch,
1615 ) {
1616 for run in &dispatch.preview_runs {
1617 self.push_event(
1618 events,
1619 SessionEventPayload::HookStarted { run: run.clone() },
1620 );
1621 }
1622 for run in &dispatch.completed_runs {
1623 self.push_event(
1624 events,
1625 SessionEventPayload::HookCompleted { run: run.clone() },
1626 );
1627 }
1628 }
1629
1630 fn make_event(&self, payload: SessionEventPayload) -> PendingEvent {
1631 let pending = PendingEvent::new(self.session_id.clone(), Delivery::Lossless, payload);
1632 if let Some(recorder) = &self.services.trace_recorder {
1637 recorder.record_pending(&pending);
1638 }
1639 pending
1640 }
1641}
1642
1643fn tool_runtime_event_payload(
1644 call_id: &halter_protocol::ToolCallId,
1645 event: ToolRuntimeEvent,
1646) -> Option<SessionEventPayload> {
1647 match event {
1648 ToolRuntimeEvent::ToolOutput { tool_name, chunk } => {
1649 Some(SessionEventPayload::ToolOutput {
1650 call_id: call_id.clone(),
1651 tool_name: tool_name.into(),
1652 chunk,
1653 })
1654 }
1655 ToolRuntimeEvent::Started { .. } | ToolRuntimeEvent::Completed { .. } => None,
1656 }
1657}
1658
1659fn tool_runtime_event_tool_name(event: &ToolRuntimeEvent) -> &str {
1660 match event {
1661 ToolRuntimeEvent::Started { tool_name }
1662 | ToolRuntimeEvent::Completed { tool_name }
1663 | ToolRuntimeEvent::ToolOutput { tool_name, .. } => tool_name,
1664 }
1665}
1666
1667fn tool_runtime_event_bytes(event: &ToolRuntimeEvent) -> usize {
1668 match event {
1669 ToolRuntimeEvent::Started { tool_name } | ToolRuntimeEvent::Completed { tool_name } => {
1670 tool_name.len()
1671 }
1672 ToolRuntimeEvent::ToolOutput { tool_name, chunk } => tool_name.len() + chunk.len(),
1673 }
1674}
1675
/// Bundle returned by the turn loop when a turn finishes, holding what is still left to commit.
#[derive(Debug)]
struct TurnCommit {
    /// State the store is expected to hold when the commit is applied.
    expected_state: SessionState,
    /// Resource snapshot the turn ran against.
    snapshot: Arc<ResourceSnapshot>,
    /// Session state produced by the turn.
    state: SessionState,
    /// Events produced by the turn that have not been flushed yet.
    events: Vec<PendingEvent>,
}
1683
/// Thinking block that is still being streamed by the provider.
#[derive(Debug, Default)]
struct PendingThinkingBlock {
    /// Text accumulated from `ThinkingDelta` events.
    text: String,
    /// Signature delivered with the `ThinkingEnd` event, if any.
    signature: Option<String>,
}
1689
/// Tool call block that has started streaming but has not been closed yet.
#[derive(Debug)]
struct PendingToolCallBlock {
    /// Identifier of the tool call, taken from `ToolCallStart`.
    tool_call_id: halter_protocol::ToolCallId,
    /// Name of the requested tool, taken from `ToolCallStart`.
    name: halter_protocol::ToolName,
    /// Raw argument text accumulated from `ToolArgsDelta` events; parsed as JSON on close.
    arguments: String,
}
1696
/// Result of draining a provider stream with `materialize_assistant_message`.
#[derive(Debug)]
pub(crate) struct MaterializedAssistantMessage {
    /// The assembled assistant message.
    pub(crate) message: AssistantMessage,
    /// Usage taken from the most recent `UsageUpdate` event (default when none was received).
    pub(crate) usage: Usage,
    /// One `DeltaItem` payload per text delta, in stream order.
    pub(crate) events: Vec<SessionEventPayload>,
    /// Response id reported by a `MessageEnd` event, when one was present.
    pub(crate) response_id: Option<String>,
}
1705
/// Drains `provider_stream` and assembles its events into one assistant message.
///
/// Text deltas are buffered into text parts and also recorded as `DeltaItem` events. Thinking
/// and tool call blocks are accumulated until their end events arrive. Every appended text,
/// thinking, and tool-argument chunk counts against `PROVIDER_STREAM_OUTPUT_CAP_BYTES`, and the
/// number of delta events is limited by `PROVIDER_STREAM_EVENT_CAP`.
///
/// When the stream ends, buffered text is flushed and tool call blocks that never received
/// `ToolCallEnd` are closed automatically; if their arguments fail to parse, an empty JSON
/// object is substituted.
///
/// # Errors
///
/// Fails when an output or event cap is exceeded, when a tool argument delta or end event
/// references an unknown block, when a `ToolCallEnd` block has invalid JSON arguments, or when
/// the provider reports an error (as an `Error` event or an `Err` stream item).
pub(crate) async fn materialize_assistant_message(
    mut provider_stream: BoxStream<'static, Result<StreamEvent, ProviderError>>,
    model: &halter_protocol::ResolvedModel,
) -> anyhow::Result<MaterializedAssistantMessage> {
    debug!(provider = %model.provider, model = %model.model, "materializing provider stream");
    // Placeholder id; replaced when the provider sends `MessageStart`.
    let mut message_id = MessageId::new();
    let mut usage = Usage::default();
    // Used when the stream finishes without a `MessageEnd` event.
    let mut stop_reason = StopReason::EndTurn;
    let mut parts = Vec::new();
    let mut text_buffer = String::new();
    let mut delta_events = Vec::new();
    // Running total across text, thinking, and tool-argument chunks.
    let mut accumulated_output_bytes = 0usize;
    let mut thinking_block: Option<PendingThinkingBlock> = None;
    let mut tool_call_blocks: std::collections::BTreeMap<BlockId, PendingToolCallBlock> =
        std::collections::BTreeMap::new();
    let mut captured_response_id: Option<String> = None;

    while let Some(item) = provider_stream.next().await {
        match item {
            Ok(StreamEvent::MessageStart { id }) => {
                message_id = id;
            }
            Ok(StreamEvent::TextStart { .. }) => {}
            Ok(StreamEvent::TextDelta { delta, .. }) => {
                append_provider_stream_chunk(
                    "text",
                    &mut text_buffer,
                    &delta,
                    &mut accumulated_output_bytes,
                )?;
                if delta_events.len() >= PROVIDER_STREAM_EVENT_CAP {
                    anyhow::bail!(
                        "provider stream text exceeded event cap: {} events",
                        PROVIDER_STREAM_EVENT_CAP
                    );
                }
                delta_events.push(SessionEventPayload::DeltaItem {
                    delta: halter_protocol::DeltaItem { text: delta },
                });
            }
            Ok(StreamEvent::TextEnd { .. }) => {
                flush_text_buffer(&mut parts, &mut text_buffer);
            }
            Ok(StreamEvent::ThinkingStart { .. }) => {
                thinking_block = Some(PendingThinkingBlock::default());
            }
            Ok(StreamEvent::ThinkingDelta { delta, .. }) => {
                // Tolerates a delta that arrives without a preceding `ThinkingStart`.
                let thinking = thinking_block.get_or_insert_with(PendingThinkingBlock::default);
                append_provider_stream_chunk(
                    "thinking",
                    &mut thinking.text,
                    &delta,
                    &mut accumulated_output_bytes,
                )?;
            }
            Ok(StreamEvent::ThinkingEnd { signature, .. }) => {
                // An end event without an open thinking block is ignored.
                if let Some(mut thinking) = thinking_block.take() {
                    thinking.signature = signature;
                    parts.push(AssistantPart::Thinking(halter_protocol::ThinkingBlock {
                        text: thinking.text,
                        signature: thinking.signature,
                    }));
                }
            }
            Ok(StreamEvent::ToolCallStart {
                id,
                tool_call_id,
                name,
            }) => {
                // Flush first so text streamed before the call precedes it in `parts`.
                flush_text_buffer(&mut parts, &mut text_buffer);
                tool_call_blocks.insert(
                    id,
                    PendingToolCallBlock {
                        tool_call_id,
                        name,
                        arguments: String::new(),
                    },
                );
            }
            Ok(StreamEvent::ToolArgsDelta { id, delta }) => {
                let pending = tool_call_blocks.get_mut(&id).with_context(|| {
                    format!("failed to materialize tool call: missing block '{}'", id)
                })?;
                append_provider_stream_chunk(
                    "tool arguments",
                    &mut pending.arguments,
                    &delta,
                    &mut accumulated_output_bytes,
                )?;
            }
            Ok(StreamEvent::ToolCallEnd { id }) => {
                let pending = tool_call_blocks.remove(&id).with_context(|| {
                    format!("failed to materialize tool call: missing block '{}'", id)
                })?;
                // Unlike the end-of-stream auto-close below, invalid JSON here is an error.
                let arguments = parse_tool_call_arguments(&pending.arguments)?;
                parts.push(AssistantPart::ToolCall(ToolCall {
                    id: pending.tool_call_id,
                    name: pending.name,
                    arguments,
                }));
            }
            Ok(StreamEvent::UsageUpdate { usage: updated }) => {
                // The latest update replaces the previous value; updates are not summed.
                usage = updated;
            }
            Ok(StreamEvent::MessageEnd {
                stop_reason: ended_reason,
                response_id: resp_id,
                ..
            }) => {
                stop_reason = ended_reason;
                // Keep an earlier response id when this event does not carry one.
                if resp_id.is_some() {
                    captured_response_id = resp_id;
                }
            }
            Ok(StreamEvent::ProviderWarning { message }) => {
                warn!(provider = %model.provider, message = %message, "provider emitted warning");
            }
            Ok(StreamEvent::Error { error }) | Err(error) => {
                error!(provider = %model.provider, error = %error.message, "provider stream failed");
                return Err(anyhow::Error::new(error));
            }
        }
    }

    flush_text_buffer(&mut parts, &mut text_buffer);
    // NOTE(review): a thinking block still open at this point is not added to `parts`;
    // confirm that dropping it is intended.
    // Close tool call blocks the provider never terminated, in block-id order.
    for (block_id, pending) in std::mem::take(&mut tool_call_blocks) {
        let arguments = match parse_tool_call_arguments(&pending.arguments) {
            Ok(value) => value,
            Err(error) => {
                warn!(
                    provider = %model.provider,
                    model = %model.model,
                    tool_call_id = %pending.tool_call_id,
                    block_id = %block_id,
                    raw_arguments = %pending.arguments,
                    %error,
                    "stream ended with unterminated tool call whose arguments failed to parse; substituting empty object"
                );
                serde_json::json!({})
            }
        };
        warn!(
            provider = %model.provider,
            model = %model.model,
            tool_call_id = %pending.tool_call_id,
            block_id = %block_id,
            tool_name = %pending.name,
            "stream ended without ToolCallEnd; auto-closing tool call block"
        );
        parts.push(AssistantPart::ToolCall(ToolCall {
            id: pending.tool_call_id,
            name: pending.name,
            arguments,
        }));
    }
    debug!(
        provider = %model.provider,
        model = %model.model,
        message_id = %message_id,
        part_count = parts.len(),
        stop_reason = ?stop_reason,
        "finished materializing provider stream"
    );

    Ok(MaterializedAssistantMessage {
        message: AssistantMessage {
            id: message_id,
            created_at: Utc::now(),
            parts,
            stop_reason: Some(stop_reason),
            usage: Some(usage.clone()),
            replay_meta: ReplayMeta {
                provider_name: Some(model.provider.clone()),
                model: Some(model.id.clone()),
            },
        },
        usage,
        events: delta_events,
        response_id: captured_response_id,
    })
}
1898
1899fn flush_text_buffer(parts: &mut Vec<AssistantPart>, text_buffer: &mut String) {
1900 if text_buffer.is_empty() {
1901 return;
1902 }
1903
1904 parts.push(AssistantPart::Text {
1905 text: std::mem::take(text_buffer),
1906 });
1907}
1908
1909fn parse_tool_call_arguments(arguments: &str) -> anyhow::Result<serde_json::Value> {
1910 if arguments.trim().is_empty() {
1911 return Ok(serde_json::json!({}));
1912 }
1913
1914 serde_json::from_str(arguments)
1915 .with_context(|| "failed to materialize tool call: invalid json arguments")
1916}
1917
1918fn append_provider_stream_chunk(
1919 label: &str,
1920 target: &mut String,
1921 chunk: &str,
1922 accumulated_output_bytes: &mut usize,
1923) -> anyhow::Result<()> {
1924 let observed = accumulated_output_bytes.saturating_add(chunk.len());
1925 if observed > PROVIDER_STREAM_OUTPUT_CAP_BYTES {
1926 anyhow::bail!(
1927 "provider stream {label} exceeded output cap: {observed} bytes (cap {PROVIDER_STREAM_OUTPUT_CAP_BYTES})"
1928 );
1929 }
1930 target.push_str(chunk);
1931 *accumulated_output_bytes = observed;
1932 Ok(())
1933}
1934
1935fn assistant_tool_calls(message: &AssistantMessage) -> Vec<ToolCall> {
1936 message
1937 .parts
1938 .iter()
1939 .filter_map(|part| match part {
1940 AssistantPart::ToolCall(call) => Some(call.clone()),
1941 AssistantPart::Text { .. } | AssistantPart::Thinking(_) => None,
1942 })
1943 .collect()
1944}
1945
1946fn dedupe_assistant_tool_call_parts(parts: Vec<AssistantPart>) -> (Vec<AssistantPart>, usize) {
1947 let mut deduped = Vec::with_capacity(parts.len());
1948 let mut seen_tool_call_ids = BTreeSet::new();
1949 let mut duplicate_count = 0;
1950
1951 for part in parts {
1952 match &part {
1953 AssistantPart::ToolCall(call) => {
1954 if !seen_tool_call_ids.insert(call.id.clone()) {
1955 duplicate_count += 1;
1956 continue;
1957 }
1958 }
1959 AssistantPart::Text { .. } | AssistantPart::Thinking(_) => {}
1960 }
1961 deduped.push(part);
1962 }
1963
1964 (deduped, duplicate_count)
1965}
1966
1967fn ensure_provider_iteration_allowed(
1968 max_turns: Option<u32>,
1969 completed_iterations: u32,
1970) -> anyhow::Result<()> {
1971 if let Some(max_turns) = max_turns
1972 && completed_iterations >= max_turns
1973 {
1974 anyhow::bail!(
1975 "failed to run turn: max_turns {max_turns} exhausted before provider iteration {}",
1976 completed_iterations.saturating_add(1)
1977 );
1978 }
1979 Ok(())
1980}
1981
1982fn accumulate_usage(total: &mut Usage, delta: &Usage) {
1983 total.input_tokens += delta.input_tokens;
1984 total.output_tokens += delta.output_tokens;
1985 total.cache_creation_input_tokens += delta.cache_creation_input_tokens;
1986 total.cache_read_input_tokens += delta.cache_read_input_tokens;
1987}
1988
1989fn tool_result_kind(result: &ToolResult) -> &'static str {
1990 match result {
1991 ToolResult::Empty => "empty",
1992 ToolResult::Text { .. } => "text",
1993 ToolResult::Json { .. } => "json",
1994 }
1995}
1996
1997pub(crate) fn apply_hook_side_effects(
1998 state: &mut SessionState,
1999 dispatch: &ExecutedHookDispatch,
2000) -> Vec<Message> {
2001 for fired_hook_id in &dispatch.fired_hook_ids {
2002 if !state
2003 .fired_hook_ids
2004 .iter()
2005 .any(|seen| seen == fired_hook_id)
2006 {
2007 state.fired_hook_ids.push(fired_hook_id.clone());
2008 }
2009 }
2010
2011 for context in &dispatch.merged.additional_context {
2012 state
2013 .appended_prompt_segments
2014 .push(build_hook_prompt_segment(context));
2015 }
2016
2017 let mut messages = Vec::new();
2018 for text in &dispatch.merged.system_messages {
2019 let message = Message::System(SystemMessage {
2020 id: MessageId::new(),
2021 created_at: Utc::now(),
2022 text: text.clone(),
2023 });
2024 state.messages.push(message.clone());
2025 messages.push(message);
2026 }
2027
2028 messages
2029}
2030
2031fn build_hook_prompt_segment(text: &str) -> PromptSegment {
2032 PromptSegment {
2033 id: PromptSegmentId::new(),
2034 text: text.to_owned(),
2035 volatility: Volatility::TurnDynamic,
2036 cache_scope: CacheScope::Dynamic,
2037 content_hash: hash_text(text),
2038 kind: PromptSegmentKind::Append,
2039 }
2040}
2041
2042fn hash_text(text: &str) -> ContentHash {
2043 let mut hasher = Sha256::new();
2044 hasher.update(text.as_bytes());
2045 format!("{:x}", hasher.finalize())
2046}
2047
2048fn tool_result_from_hook_value(value: serde_json::Value) -> ToolResult {
2049 match value {
2050 serde_json::Value::Null => ToolResult::Empty,
2051 serde_json::Value::String(text) => ToolResult::Text { text },
2052 other => ToolResult::Json { value: other },
2053 }
2054}
2055
2056async fn forwarding_ancestors_for_session(
2057 services: &Arc<RuntimeServices>,
2058 session_id: &SessionId,
2059) -> Vec<SessionId> {
2060 let stored = match services.sessions.load_session(session_id).await {
2061 Ok(Some(stored)) => stored,
2062 Ok(None) => return Vec::new(),
2063 Err(error) => {
2064 warn!(
2065 session_id = %session_id,
2066 error = %error,
2067 "failed to load session for subagent event forwarding"
2068 );
2069 return Vec::new();
2070 }
2071 };
2072
2073 forwarding_ancestors_for_blueprint(services, &stored.blueprint).await
2074}
2075
2076async fn forwarding_ancestors_for_blueprint(
2077 services: &Arc<RuntimeServices>,
2078 blueprint: &SessionBlueprint,
2079) -> Vec<SessionId> {
2080 if !blueprint.subagent_event_forwarding.is_enabled() {
2081 return Vec::new();
2082 }
2083
2084 let mut ancestors = Vec::new();
2085 let mut seen = BTreeSet::new();
2086 let mut next = blueprint.parent_session_id.clone();
2087 while let Some(session_id) = next {
2088 if !seen.insert(session_id.clone()) {
2089 warn!(
2090 session_id = %blueprint.session_id,
2091 ancestor_session_id = %session_id,
2092 "stopped subagent event forwarding ancestor walk at cycle"
2093 );
2094 break;
2095 }
2096
2097 ancestors.push(session_id.clone());
2098 next = match services.sessions.load_session(&session_id).await {
2099 Ok(Some(stored)) => stored.blueprint.parent_session_id,
2100 Ok(None) => {
2101 warn!(
2102 session_id = %blueprint.session_id,
2103 ancestor_session_id = %session_id,
2104 "stopped subagent event forwarding ancestor walk at missing parent"
2105 );
2106 break;
2107 }
2108 Err(error) => {
2109 warn!(
2110 session_id = %blueprint.session_id,
2111 ancestor_session_id = %session_id,
2112 error = %error,
2113 "stopped subagent event forwarding ancestor walk after load failure"
2114 );
2115 break;
2116 }
2117 };
2118 }
2119
2120 ancestors
2121}
2122
/// Creates and persists a new session from `init`, seeded with `initial_state` and `snapshot`,
/// then commits and publishes its `SessionStarted` event.
///
/// # Errors
///
/// Fails when the default or subagent model cannot be resolved from the registry, when the
/// session store rejects the session or the commit, when the trace recorder cannot open the
/// session, or when constructing the `HalterSession` handle fails.
pub(crate) async fn create_session_seeded(
    services: Arc<RuntimeServices>,
    init: SessionInit,
    mut initial_state: SessionState,
    snapshot: Arc<ResourceSnapshot>,
) -> anyhow::Result<HalterSession> {
    let default_registry_model = services.models.default_model()?;
    let subagent_registry_model = services.models.subagent_model()?;
    let selected_models = select_models(
        &default_registry_model.id,
        &subagent_registry_model.id,
        init.default_model.as_ref(),
        init.subagent_model.as_ref(),
    );
    // Look both selected models up in the registry; a failed lookup aborts session creation.
    services.models.model(&selected_models.default_model)?;
    services.models.model(&selected_models.subagent_model)?;
    let session_id = init.session_id.unwrap_or_default();
    // A per-session forwarding setting overrides the runtime-wide one.
    let subagent_event_forwarding = init
        .subagent_event_forwarding
        .unwrap_or(services.subagent_event_forwarding);
    let blueprint = SessionBlueprint {
        session_id: session_id.clone(),
        parent_session_id: init.parent_session_id,
        default_model: selected_models.default_model,
        subagent_model: selected_models.subagent_model,
        subagent_event_forwarding,
        snapshot_revision: snapshot.revision.clone(),
        working_dir: init.working_dir,
        system_prompt_seed: init.system_prompt_seed,
        max_turns: init.max_turns,
        subagent_depth: init.subagent_depth,
    };
    info!(
        session_id = %session_id,
        default_model = %blueprint.default_model,
        subagent_model = %blueprint.subagent_model,
        subagent_event_forwarding = ?blueprint.subagent_event_forwarding,
        working_dir = %blueprint.working_dir.display(),
        snapshot_revision = %blueprint.snapshot_revision,
        "created session blueprint"
    );

    // Fill in defaults only where the seeded state did not already provide a value.
    if initial_state.pending_session_start_source.is_none() {
        initial_state.pending_session_start_source = Some(HookSessionStartSource::Startup);
    }
    if initial_state.pending_warning_messages.is_empty() {
        initial_state.pending_warning_messages =
            services.resources.hook_warnings().as_ref().clone();
    }

    services
        .sessions
        .create_session(StoredSession {
            blueprint: blueprint.clone(),
            state: initial_state,
            snapshot,
        })
        .await?;

    if let Some(recorder) = &services.trace_recorder {
        recorder.open_session(
            &session_id,
            blueprint.parent_session_id.as_ref(),
            &blueprint,
        )?;
    }

    let started = PendingEvent::new(
        session_id.clone(),
        Delivery::Lossless,
        SessionEventPayload::SessionStarted,
    );
    // Commit only the event: snapshot, expected state, and state are all left unchanged.
    let committed = services
        .sessions
        .commit(&session_id, None, None, None, vec![started])
        .await?;
    let forwarding_ancestors = forwarding_ancestors_for_blueprint(&services, &blueprint).await;
    for event in committed {
        if let Some(recorder) = &services.trace_recorder {
            recorder.record(&event);
        }
        services
            .parent_streams
            .forward_to_ancestors(&forwarding_ancestors, &event);
        services.event_bus.publish(event);
    }

    HalterSession::new(services, session_id)
}
2212
2213fn lookup_or_create_session_hooks(
2214 services: &Arc<RuntimeServices>,
2215 session_id: &SessionId,
2216) -> anyhow::Result<Arc<Hooks>> {
2217 let hooks = Arc::new(services.registered_hooks.instantiate()?);
2218 match services.session_hook_store.lock() {
2219 Ok(mut store) => {
2220 if let Some(existing) = store.get(session_id) {
2221 return Ok(existing.clone());
2222 }
2223 store.insert(session_id.clone(), hooks.clone());
2224 Ok(hooks)
2225 }
2226 Err(_) => {
2227 error!(
2228 session_id = %session_id,
2229 "session hook store lock poisoned; rebuilding uncached session hooks"
2230 );
2231 Ok(hooks)
2232 }
2233 }
2234}
2235
2236fn evict_session_hooks(services: &Arc<RuntimeServices>, session_id: &SessionId) {
2237 match services.session_hook_store.lock() {
2238 Ok(mut store) => {
2239 store.remove(session_id);
2240 }
2241 Err(_) => {
2242 error!(
2243 session_id = %session_id,
2244 "session hook store lock poisoned; skipping session hook eviction"
2245 );
2246 }
2247 }
2248}
2249
2250fn format_hook_warning(warning: &HookWarning) -> String {
2251 let mut prefix = String::new();
2252 if let Some(plugin_name) = warning.plugin_name.as_deref() {
2253 prefix.push_str(&format!("plugin '{plugin_name}' "));
2254 }
2255 prefix.push_str("hook warning");
2256 if !warning.category.trim().is_empty() {
2257 prefix.push_str(&format!(" [{}]", warning.category));
2258 }
2259 if let Some(source_path) = warning.source_path.as_ref() {
2260 prefix.push_str(&format!(" at {}", source_path.display()));
2261 }
2262 format!("{prefix}: {}", warning.message)
2263}
2264
2265fn observe_state(working_dir: PathBuf) -> ObservedState {
2266 let (git_branch, git_dirty) = probe_git(&working_dir);
2267 ObservedState {
2268 cwd: working_dir,
2269 git_branch,
2270 git_dirty,
2271 now_utc: Utc::now(),
2272 env_facts: Default::default(),
2273 }
2274}
2275
/// A tool call that has passed its pre-tool-use hooks and is ready to be executed.
struct PreparedToolCall {
    /// The call to execute, including any argument rewrite applied by hooks.
    call: ToolCall,
    /// Execution context handed to the tool runtime for this call.
    context: halter_tools::ToolContext,
    /// Drain used after execution to collect the runtime events emitted by this call.
    tool_event_drain: ToolEventDrain,
}
2281
2282fn batch_tool_calls_by_concurrency(
2294 mut resolve: impl FnMut(&halter_protocol::ToolName) -> Option<halter_protocol::ToolConcurrency>,
2295 tool_calls: Vec<ToolCall>,
2296) -> Vec<Vec<ToolCall>> {
2297 use halter_protocol::ToolConcurrency;
2298
2299 let mut concurrency_of = |call: &ToolCall| -> ToolConcurrency {
2300 resolve(&call.name).unwrap_or(ToolConcurrency::Exclusive)
2301 };
2302
2303 let mut batches: Vec<Vec<ToolCall>> = Vec::new();
2304 let mut current: Vec<ToolCall> = Vec::new();
2305
2306 for call in tool_calls {
2307 let concurrency = concurrency_of(&call);
2308 if matches!(concurrency, ToolConcurrency::Exclusive) {
2309 if !current.is_empty() {
2310 batches.push(std::mem::take(&mut current));
2311 }
2312 batches.push(vec![call]);
2313 } else {
2314 current.push(call);
2315 }
2316 }
2317 if !current.is_empty() {
2318 batches.push(current);
2319 }
2320 batches
2321}
2322
/// Inspects `working_dir` with the `git` CLI and reports `(branch, dirty)`.
///
/// * `branch` is the abbreviated name of `HEAD`. When that name is the literal
///   `HEAD`, the short commit hash is looked up instead, and `branch` is
///   `None` if that lookup fails or prints nothing.
/// * `dirty` is `Some(true)` when `git status --porcelain` prints anything,
///   `Some(false)` when it prints nothing, and `None` when the command fails.
///
/// Returns `(None, None)` when the initial `rev-parse` cannot be run, exits
/// unsuccessfully, or prints an empty name.
fn probe_git(working_dir: &std::path::Path) -> (Option<String>, Option<bool>) {
    use std::process::Command;

    /// Decodes captured stdout lossily and strips surrounding whitespace.
    fn lossy_trimmed(stdout: Vec<u8>) -> String {
        String::from_utf8_lossy(&stdout).trim().to_owned()
    }

    // Runs `git <args>` in `working_dir`; yields stdout only on a successful exit.
    let run_git = |args: &[&str]| -> Option<Vec<u8>> {
        Command::new("git")
            .args(args)
            .current_dir(working_dir)
            .output()
            .ok()
            .filter(|output| output.status.success())
            .map(|output| output.stdout)
    };

    let head = match run_git(&["rev-parse", "--abbrev-ref", "HEAD"]) {
        Some(stdout) => lossy_trimmed(stdout),
        None => return (None, None),
    };
    if head.is_empty() {
        return (None, None);
    }

    let branch = if head == "HEAD" {
        // The abbreviated ref is just `HEAD`: fall back to the short hash.
        run_git(&["rev-parse", "--short", "HEAD"])
            .map(lossy_trimmed)
            .filter(|sha| !sha.is_empty())
    } else {
        Some(head)
    };

    // The dirty check looks at the raw, untrimmed output.
    let dirty = run_git(&["status", "--porcelain"]).map(|stdout| !stdout.is_empty());

    (branch, dirty)
}
2369
// Test-only baseline: every service is an empty or default instance, subagent
// event forwarding is off, and no trace recorder is attached. Tests override
// individual fields as needed.
#[cfg(test)]
impl Default for RuntimeServices {
    fn default() -> Self {
        let resources = ResourceHandle::new(
            ResourceSnapshot::empty(),
            Arc::new(Hooks::default()),
            Vec::new(),
        );
        let policy = halter_tools::DefaultToolPolicy::new(Default::default());

        Self {
            resources: Arc::new(resources),
            registered_hooks: Arc::new(RegisteredHooks::default()),
            session_hook_store: Arc::new(Mutex::new(HashMap::new())),
            models: Arc::new(ModelRegistry::new()),
            tools: Arc::new(ToolRuntime::new()),
            path_locks: Arc::new(PathLockMap::default()),
            tool_sessions: Arc::new(ToolSessionStore::default()),
            sessions: Arc::new(halter_session::InMemorySessionStore::default()),
            policy: Arc::new(policy),
            prompt_assembler: Arc::new(crate::DefaultPromptAssembler),
            context_manager: Arc::new(DefaultContextManager::default()),
            event_bus: Arc::new(EventBus::default()),
            parent_streams: Arc::new(ParentStreamRegistry::default()),
            turn_registry: Arc::new(TurnRegistry::new()),
            subagent_event_forwarding: SubagentEventForwarding::Off,
            subagent_event_forwarding_cap: 100_000,
            shell_timeout_secs: 30,
            trace_recorder: None,
        }
    }
}
2400
2401#[cfg(test)]
2402mod tests {
2403 use std::sync::{Arc, Mutex};
2404 use std::time::Duration;
2405
2406 use async_trait::async_trait;
2407 use futures::stream::{self, BoxStream};
2408 use futures::{StreamExt, TryStreamExt};
2409 use halter_hooks::{
2410 Hook, HookEventName, HookResponse, HooksFile, RegisteredHookPriority, RegisteredHooks,
2411 };
2412 use halter_protocol::{
2413 ApiKind, BlockId, HookHandlerType, HookRunStatus, HookSessionStartSource, Message, ModelId,
2414 ModelRole, PluginId, ProviderCapabilities, ProviderError, ProviderKind, ProviderName,
2415 ProviderRequest, ResolvedModel, StopReason, StreamEvent, ToolCallId, ToolCapabilities,
2416 ToolConcurrency, ToolName, ToolResult, ToolSpec, Turn,
2417 };
2418 use halter_providers::{FakeProvider, Provider};
2419 use halter_tools::{
2420 DefaultToolPolicy, PolicySettings, Tool, ToolContext, register_builtin_tools,
2421 register_subagent_tools,
2422 };
2423 use serde_json::json;
2424 use tokio::sync::Notify;
2425
2426 use super::*;
2427 use test_support::{
2428 configured_services, empty_hooks, install_file_hooks, new_session, resolved_test_model,
2429 };
2430
2431 #[test]
2432 fn session_init_default_uses_embedded_system_prompt_seed() {
2433 let init = SessionInit::default();
2434
2435 assert_eq!(init.system_prompt_seed.len(), 1);
2436 assert_eq!(
2437 init.system_prompt_seed[0].text,
2438 crate::prompt::default_system_prompt_text()
2439 );
2440 }
2441
    /// Creates two `HalterSession` handles for the same session id, drops the
    /// second one, and checks that an event recorded afterwards still lands in
    /// the session's trace file.
    #[test]
    fn dropping_temporary_session_handle_does_not_close_trace_writer() {
        use halter_protocol::{
            Delivery, ModelId, PendingEvent, Revision, SessionBlueprint, SessionEventPayload,
            SessionId, SubagentEventForwarding,
        };

        let temp = tempfile::tempdir().expect("tempdir");
        // The recorder writes into the temp directory; the test keeps its own
        // `Arc` so it can record directly later on.
        let recorder =
            Arc::new(crate::TraceRecorder::open(temp.path().to_path_buf()).expect("recorder"));
        let services = RuntimeServices {
            trace_recorder: Some(recorder.clone()),
            ..RuntimeServices::default()
        };
        let services = Arc::new(services);

        let session_id = SessionId::from("regression-trace");
        let blueprint = SessionBlueprint {
            session_id: session_id.clone(),
            parent_session_id: None,
            default_model: ModelId::from("default"),
            subagent_model: ModelId::from("subagent"),
            subagent_event_forwarding: SubagentEventForwarding::Off,
            snapshot_revision: Revision::from("rev-1".to_owned()),
            working_dir: temp.path().to_path_buf(),
            system_prompt_seed: Vec::new(),
            max_turns: None,
            subagent_depth: 0,
        };
        recorder
            .open_session(&session_id, None, &blueprint)
            .expect("open session");

        // `primary` stays alive across the inner scope below.
        let primary =
            HalterSession::new(services.clone(), session_id.clone()).expect("primary handle");
        {
            // Second handle for the same session, dropped at the end of this scope.
            let _temporary =
                HalterSession::new(services.clone(), session_id.clone()).expect("temporary handle");
        }
        // Record an event only after the temporary handle is gone.
        let pending = PendingEvent::new(
            session_id.clone(),
            Delivery::Lossless,
            SessionEventPayload::Warning {
                message: "post-drop".to_owned(),
            },
        );
        recorder.record(&pending.into_committed(1));
        // NOTE(review): presumably dropping the last handle flushes or closes the
        // trace writer before the file is read back; confirm in `HalterSession`'s drop logic.
        drop(primary);

        // The trace file is named `<session id>.txt` inside the recorder directory.
        let path = temp.path().join(format!("{}.txt", session_id.0));
        let contents = std::fs::read_to_string(&path).expect("trace contents");
        assert!(
            contents.contains("post-drop"),
            "trace did not capture post-drop event:\n{contents}"
        );
    }
2515
2516 #[tokio::test]
2524 async fn materialize_handles_unterminated_tool_call_block() {
2525 let model = ResolvedModel {
2526 role: ModelRole::default(),
2527 id: ModelId::from("default"),
2528 provider: ProviderName::from("fake"),
2529 provider_kind: ProviderKind::Fake,
2530 api_kind: ApiKind::Fake,
2531 model: "halter/fake".to_owned(),
2532 max_input_tokens: Some(32_000),
2533 max_output_tokens: Some(4_096),
2534 reasoning: None,
2535 tokens_per_minute: None,
2536 };
2537 let message_id = halter_protocol::MessageId::new();
2538 let block_id = BlockId::new();
2539 let tool_call_id = ToolCallId::from("call-truncated");
2540 let stream: BoxStream<'static, Result<StreamEvent, ProviderError>> = stream::iter(vec![
2543 Ok(StreamEvent::MessageStart {
2544 id: message_id.clone(),
2545 }),
2546 Ok(StreamEvent::ToolCallStart {
2547 id: block_id.clone(),
2548 tool_call_id: tool_call_id.clone(),
2549 name: ToolName::from("write"),
2550 }),
2551 Ok(StreamEvent::ToolArgsDelta {
2552 id: block_id.clone(),
2553 delta: r#"{"path":"x.txt","content":"hi"}"#.to_owned(),
2554 }),
2555 Ok(StreamEvent::MessageEnd {
2556 id: message_id.clone(),
2557 stop_reason: StopReason::ToolUse,
2558 response_id: None,
2559 }),
2560 ])
2561 .boxed();
2562
2563 let materialized = super::materialize_assistant_message(stream, &model)
2564 .await
2565 .expect("materialize must recover from unterminated tool call");
2566 let tool_calls: Vec<_> = materialized
2567 .message
2568 .parts
2569 .iter()
2570 .filter_map(|part| match part {
2571 AssistantPart::ToolCall(call) => Some(call.clone()),
2572 _ => None,
2573 })
2574 .collect();
2575 assert_eq!(tool_calls.len(), 1, "one synthetic tool call expected");
2576 assert_eq!(tool_calls[0].id, tool_call_id);
2577 assert_eq!(tool_calls[0].name.0, "write");
2578 assert_eq!(
2579 tool_calls[0].arguments,
2580 serde_json::json!({"path": "x.txt", "content": "hi"})
2581 );
2582 }
2583
2584 #[tokio::test]
2589 async fn materialize_handles_unterminated_tool_call_with_invalid_json() {
2590 let model = ResolvedModel {
2591 role: ModelRole::default(),
2592 id: ModelId::from("default"),
2593 provider: ProviderName::from("fake"),
2594 provider_kind: ProviderKind::Fake,
2595 api_kind: ApiKind::Fake,
2596 model: "halter/fake".to_owned(),
2597 max_input_tokens: Some(32_000),
2598 max_output_tokens: Some(4_096),
2599 reasoning: None,
2600 tokens_per_minute: None,
2601 };
2602 let message_id = halter_protocol::MessageId::new();
2603 let block_id = BlockId::new();
2604 let tool_call_id = ToolCallId::from("call-bad-json");
2605 let stream: BoxStream<'static, Result<StreamEvent, ProviderError>> = stream::iter(vec![
2606 Ok(StreamEvent::MessageStart {
2607 id: message_id.clone(),
2608 }),
2609 Ok(StreamEvent::ToolCallStart {
2610 id: block_id.clone(),
2611 tool_call_id: tool_call_id.clone(),
2612 name: ToolName::from("write"),
2613 }),
2614 Ok(StreamEvent::ToolArgsDelta {
2615 id: block_id.clone(),
2616 delta: r#"{"path":"x.txt","content":"hel"#.to_owned(),
2618 }),
2619 Ok(StreamEvent::MessageEnd {
2620 id: message_id.clone(),
2621 stop_reason: StopReason::ToolUse,
2622 response_id: None,
2623 }),
2624 ])
2625 .boxed();
2626
2627 let materialized = super::materialize_assistant_message(stream, &model)
2628 .await
2629 .expect("materialize must recover even with unparsable arguments");
2630 let tool_calls: Vec<_> = materialized
2631 .message
2632 .parts
2633 .iter()
2634 .filter_map(|part| match part {
2635 AssistantPart::ToolCall(call) => Some(call.clone()),
2636 _ => None,
2637 })
2638 .collect();
2639 assert_eq!(tool_calls.len(), 1);
2640 assert_eq!(tool_calls[0].id, tool_call_id);
2641 assert_eq!(tool_calls[0].arguments, serde_json::json!({}));
2642 }
2643
2644 #[tokio::test]
2645 async fn materialize_rejects_oversized_provider_text() {
2646 let model = resolved_test_model("default", "fake", "halter/fake");
2647 let block_id = BlockId::new();
2648 let oversized = "x".repeat(PROVIDER_STREAM_OUTPUT_CAP_BYTES + 1);
2649 let stream: BoxStream<'static, Result<StreamEvent, ProviderError>> = stream::iter(vec![
2650 Ok(StreamEvent::MessageStart {
2651 id: halter_protocol::MessageId::new(),
2652 }),
2653 Ok(StreamEvent::TextStart {
2654 id: block_id.clone(),
2655 }),
2656 Ok(StreamEvent::TextDelta {
2657 id: block_id,
2658 delta: oversized,
2659 }),
2660 ])
2661 .boxed();
2662
2663 let error = super::materialize_assistant_message(stream, &model)
2664 .await
2665 .expect_err("oversized provider text should fail");
2666 assert!(error.to_string().contains("exceeded output cap"));
2667 }
2668
2669 #[test]
2670 fn batch_tool_calls_isolates_exclusive_tools() {
2671 use halter_protocol::ToolCallId;
2672
2673 let mut declared: HashMap<String, ToolConcurrency> = HashMap::new();
2674 for (name, concurrency) in [
2675 ("read_a", ToolConcurrency::ReadOnly),
2676 ("read_b", ToolConcurrency::ReadOnly),
2677 ("exclusive", ToolConcurrency::Exclusive),
2678 ("parallel", ToolConcurrency::ParallelSafe),
2679 ] {
2680 declared.insert(name.into(), concurrency);
2681 }
2682
2683 let mk = |tool: &str, id: &str| ToolCall {
2684 id: ToolCallId::from(id),
2685 name: ToolName::from(tool),
2686 arguments: serde_json::Value::Null,
2687 };
2688 let batches = super::batch_tool_calls_by_concurrency(
2689 |name| declared.get(&name.0).copied(),
2690 vec![
2691 mk("read_a", "1"),
2692 mk("read_b", "2"),
2693 mk("exclusive", "3"),
2694 mk("parallel", "4"),
2695 mk("read_a", "5"),
2696 ],
2697 );
2698 assert_eq!(batches.len(), 3, "three batches: [r,r], [excl], [p,r]");
2699 assert_eq!(batches[0].len(), 2);
2700 assert_eq!(batches[1].len(), 1);
2701 assert_eq!(batches[1][0].name.0, "exclusive");
2702 assert_eq!(batches[2].len(), 2);
2703 }
2704
2705 #[test]
2706 fn batch_tool_calls_treats_unknown_tools_as_exclusive() {
2707 use halter_protocol::ToolCallId;
2708
2709 let mk = |name: &str, id: &str| ToolCall {
2710 id: ToolCallId::from(id),
2711 name: ToolName::from(name),
2712 arguments: serde_json::Value::Null,
2713 };
2714 let batches = super::batch_tool_calls_by_concurrency(
2715 |_| None,
2716 vec![mk("mystery_a", "1"), mk("mystery_b", "2")],
2717 );
2718 assert_eq!(batches.len(), 2, "unknown tools must be exclusive");
2719 }
2720
2721 #[test]
2722 fn probe_git_returns_none_outside_working_tree() {
2723 let tmp = tempfile::tempdir().expect("tempdir");
2724 let (branch, dirty) = super::probe_git(tmp.path());
2725 assert_eq!(branch, None);
2726 assert_eq!(dirty, None);
2727 }
2728
2729 #[test]
2730 fn probe_git_reports_branch_and_dirty_flag() {
2731 let tmp = tempfile::tempdir().expect("tempdir");
2732 let root = tmp.path();
2733 let git = |args: &[&str]| {
2734 let status = std::process::Command::new("git")
2735 .args(args)
2736 .current_dir(root)
2737 .output()
2738 .expect("git command");
2739 assert!(status.status.success(), "git {:?} failed", args);
2740 };
2741
2742 git(&["init", "--initial-branch=trunk"]);
2743 git(&["config", "user.email", "test@example.com"]);
2744 git(&["config", "user.name", "Test User"]);
2745 std::fs::write(root.join("seed.txt"), b"initial").expect("write seed");
2746 git(&["add", "seed.txt"]);
2747 git(&["commit", "-m", "seed"]);
2748
2749 let (branch, dirty) = super::probe_git(root);
2750 assert_eq!(branch.as_deref(), Some("trunk"));
2751 assert_eq!(dirty, Some(false));
2752
2753 std::fs::write(root.join("dirty.txt"), b"unstaged").expect("write dirty");
2754 let (branch, dirty) = super::probe_git(root);
2755 assert_eq!(branch.as_deref(), Some("trunk"));
2756 assert_eq!(dirty, Some(true));
2757 }
2758
2759 mod test_support {
2760 use std::path::Path;
2761 use std::sync::Arc;
2762
2763 use halter_hooks::{HookRegistrySource, Hooks, HooksFile};
2764 use halter_protocol::{
2765 ApiKind, ModelId, ModelRole, PluginId, ProviderKind, ProviderName, ResolvedModel,
2766 ResourceSnapshot, SubagentEventForwarding,
2767 };
2768 use halter_providers::Provider;
2769 use halter_tools::{DefaultToolPolicy, PolicySettings};
2770
2771 use super::{HalterSession, ModelRegistry, RuntimeServices, SessionInit, SessionRuntime};
2772
2773 pub(super) fn configured_services(
2774 provider: Arc<dyn Provider>,
2775 working_dir: &Path,
2776 ) -> Arc<RuntimeServices> {
2777 configured_services_with_runtime(
2778 provider,
2779 working_dir,
2780 SubagentEventForwarding::Off,
2781 100_000,
2782 )
2783 }
2784
2785 pub(super) fn configured_services_with_runtime(
2786 provider: Arc<dyn Provider>,
2787 working_dir: &Path,
2788 subagent_event_forwarding: SubagentEventForwarding,
2789 subagent_event_forwarding_cap: u64,
2790 ) -> Arc<RuntimeServices> {
2791 configured_services_with_runtime_and_trace(
2792 provider,
2793 working_dir,
2794 subagent_event_forwarding,
2795 subagent_event_forwarding_cap,
2796 None,
2797 )
2798 }
2799
2800 pub(super) fn configured_services_with_runtime_and_trace(
2801 provider: Arc<dyn Provider>,
2802 working_dir: &Path,
2803 subagent_event_forwarding: SubagentEventForwarding,
2804 subagent_event_forwarding_cap: u64,
2805 trace_recorder: Option<Arc<crate::TraceRecorder>>,
2806 ) -> Arc<RuntimeServices> {
2807 let mut services = RuntimeServices::default();
2808 let mut models = ModelRegistry::new();
2809 models.set_default_model(ResolvedModel {
2810 role: ModelRole::default(),
2811 id: ModelId::from("default"),
2812 provider: ProviderName::from("fake"),
2813 provider_kind: ProviderKind::Fake,
2814 api_kind: ApiKind::Fake,
2815 model: "halter/fake".to_owned(),
2816 max_input_tokens: Some(32_000),
2817 max_output_tokens: Some(4_096),
2818 reasoning: None,
2819 tokens_per_minute: None,
2820 });
2821 models.set_subagent_model(ResolvedModel {
2822 role: ModelRole::subagent(),
2823 id: ModelId::from("subagent"),
2824 provider: ProviderName::from("fake"),
2825 provider_kind: ProviderKind::Fake,
2826 api_kind: ApiKind::Fake,
2827 model: "halter/fake".to_owned(),
2828 max_input_tokens: Some(32_000),
2829 max_output_tokens: Some(4_096),
2830 reasoning: None,
2831 tokens_per_minute: None,
2832 });
2833 models.set_small_model(ResolvedModel {
2834 role: ModelRole::small(),
2835 id: ModelId::from("small"),
2836 provider: ProviderName::from("fake"),
2837 provider_kind: ProviderKind::Fake,
2838 api_kind: ApiKind::Fake,
2839 model: "halter/fake-small".to_owned(),
2840 max_input_tokens: Some(32_000),
2841 max_output_tokens: Some(4_096),
2842 reasoning: None,
2843 tokens_per_minute: None,
2844 });
2845 models.register_provider(ProviderName::from("fake"), provider);
2846 services.models = Arc::new(models);
2847 services.policy = Arc::new(DefaultToolPolicy::new(PolicySettings {
2848 allowed_write_roots: vec![working_dir.to_path_buf()],
2849 ..PolicySettings::default()
2850 }));
2851 services.subagent_event_forwarding = subagent_event_forwarding;
2852 services.subagent_event_forwarding_cap = subagent_event_forwarding_cap;
2853 services.trace_recorder = trace_recorder;
2854 Arc::new(services)
2855 }
2856
2857 pub(super) async fn new_session(
2858 runtime: &SessionRuntime,
2859 working_dir: &Path,
2860 ) -> HalterSession {
2861 runtime
2862 .new_session(SessionInit {
2863 working_dir: working_dir.to_path_buf(),
2864 ..SessionInit::default()
2865 })
2866 .await
2867 .expect("session")
2868 }
2869
2870 pub(super) fn install_file_hooks(
2871 services: &Arc<RuntimeServices>,
2872 working_dir: &Path,
2873 hooks_file: HooksFile,
2874 ) {
2875 services.resources.replace(
2876 ResourceSnapshot::empty(),
2877 Arc::new(Hooks::from_sources(vec![HookRegistrySource {
2878 plugin_id: PluginId::from("test-plugin"),
2879 plugin_root: working_dir.to_path_buf(),
2880 source_path: working_dir.join("hooks/hooks.json"),
2881 allowed_http_hosts: Vec::new(),
2882 allowed_env_vars: Vec::new(),
2883 file: hooks_file,
2884 }])),
2885 Vec::new(),
2886 );
2887 }
2888
2889 pub(super) fn empty_hooks() -> Arc<Hooks> {
2890 Arc::new(Hooks::default())
2891 }
2892
2893 pub(super) fn resolved_test_model(id: &str, provider: &str, model: &str) -> ResolvedModel {
2894 ResolvedModel {
2895 role: if id == "subagent" {
2896 ModelRole::subagent()
2897 } else {
2898 ModelRole::default()
2899 },
2900 id: ModelId::from(id),
2901 provider: ProviderName::from(provider),
2902 provider_kind: ProviderKind::Fake,
2903 api_kind: ApiKind::Fake,
2904 model: model.to_owned(),
2905 max_input_tokens: Some(32_000),
2906 max_output_tokens: Some(4_096),
2907 reasoning: None,
2908 tokens_per_minute: None,
2909 }
2910 }
2911 }
2912
2913 #[tokio::test]
2914 async fn fake_provider_turn_produces_canonical_events() {
2915 let mut services = RuntimeServices::default();
2916 let mut models = ModelRegistry::new();
2917 models.set_default_model(ResolvedModel {
2918 role: ModelRole::default(),
2919 id: ModelId::from("default"),
2920 provider: ProviderName::from("fake"),
2921 provider_kind: ProviderKind::Fake,
2922 api_kind: ApiKind::Fake,
2923 model: "halter/fake".to_owned(),
2924 max_input_tokens: Some(32_000),
2925 max_output_tokens: Some(4_096),
2926 reasoning: None,
2927 tokens_per_minute: None,
2928 });
2929 models.set_subagent_model(ResolvedModel {
2930 role: ModelRole::subagent(),
2931 id: ModelId::from("subagent"),
2932 provider: ProviderName::from("fake"),
2933 provider_kind: ProviderKind::Fake,
2934 api_kind: ApiKind::Fake,
2935 model: "halter/fake".to_owned(),
2936 max_input_tokens: Some(32_000),
2937 max_output_tokens: Some(4_096),
2938 reasoning: None,
2939 tokens_per_minute: None,
2940 });
2941 models.register_provider(
2942 ProviderName::from("fake"),
2943 Arc::new(FakeProvider::default()),
2944 );
2945 services.models = Arc::new(models);
2946
2947 let runtime = SessionRuntime::new(Arc::new(services));
2948 let session = runtime
2949 .new_session(SessionInit::default())
2950 .await
2951 .expect("session");
2952 let events = session
2953 .submit_turn(Turn::user("hello runtime"))
2954 .await
2955 .expect("submit turn")
2956 .try_collect::<Vec<_>>()
2957 .await
2958 .expect("collect events");
2959
2960 assert!(
2961 events
2962 .iter()
2963 .any(|event| matches!(event.payload, SessionEventPayload::TurnCompleted { .. }))
2964 );
2965 }
2966
    /// Runs a full turn against `ToolLoopProvider` with the built-in tools
    /// registered and checks that a tool actually executed: `note.txt` exists,
    /// tool start/completion events were emitted, and an assistant message
    /// containing "tool completed" follows.
    // NOTE(review): `ToolLoopProvider` is defined outside this view; presumably it
    // requests a `write` of `note.txt` and then replies with "tool completed".
    #[tokio::test]
    async fn submit_turn_executes_tool_calls_until_completion() {
        let temp = tempfile::tempdir().expect("tempdir");
        let services = configured_services(Arc::new(ToolLoopProvider), temp.path());
        register_builtin_tools(&services.tools, &[]);
        let runtime = SessionRuntime::new(services.clone());
        let session = new_session(&runtime, temp.path()).await;

        // Drain the whole event stream so the turn has run to the end.
        let events = session
            .submit_turn(Turn::user("write a note"))
            .await
            .expect("submit turn")
            .try_collect::<Vec<_>>()
            .await
            .expect("collect events");

        // Side effect of the tool run on disk.
        assert!(temp.path().join("note.txt").exists());
        assert!(events.iter().any(|event| matches!(
            event.payload,
            SessionEventPayload::ToolExecutionStarted { .. }
        )));
        assert!(events.iter().any(|event| matches!(
            event.payload,
            SessionEventPayload::ToolExecutionCompleted { .. }
        )));
        // An assistant text part mentioning "tool completed" must be among the events.
        assert!(events.iter().any(|event| matches!(
            &event.payload,
            SessionEventPayload::MessageItem {
                message: Message::Assistant(assistant),
            } if assistant.parts.iter().any(|part| matches!(
                part,
                AssistantPart::Text { text } if text.contains("tool completed")
            ))
        )));
    }
3002
    /// With `DuplicateToolCallProvider`, the counting tool must run exactly
    /// once and some assistant message must carry exactly one tool-call part.
    // NOTE(review): `DuplicateToolCallProvider` and `CountingTool` are defined
    // outside this view; presumably the provider emits the same tool call id twice.
    #[tokio::test]
    async fn submit_turn_dedupes_duplicate_tool_call_ids() {
        let temp = tempfile::tempdir().expect("tempdir");
        // Shared counter handed to the tool so the test can read how often it ran.
        let executions = Arc::new(Mutex::new(0usize));
        let services = configured_services(Arc::new(DuplicateToolCallProvider), temp.path());
        services
            .tools
            .register(Arc::new(CountingTool::new(executions.clone())));
        let runtime = SessionRuntime::new(services.clone());
        let session = new_session(&runtime, temp.path()).await;

        let events = session
            .submit_turn(Turn::user("dedupe tool calls"))
            .await
            .expect("submit turn")
            .try_collect::<Vec<_>>()
            .await
            .expect("collect events");

        // Exactly one execution despite the duplicated call.
        assert_eq!(*executions.lock().expect("executions"), 1);
        // At least one assistant message holds exactly one tool-call part.
        assert!(events.iter().any(|event| matches!(
            &event.payload,
            SessionEventPayload::MessageItem {
                message: Message::Assistant(assistant),
            } if assistant.parts.iter().filter(|part| matches!(part, AssistantPart::ToolCall(_))).count() == 1
        )));
    }
3030
    /// Installs a file-based `PreToolUse` command hook matching `write` and
    /// checks that the tool does not run: no `note.txt`, a blocked hook run is
    /// reported, the tool result carries the hook's message, and no pending
    /// tool calls remain in the stored session.
    #[tokio::test]
    async fn pre_tool_use_hook_can_block_tool_execution() {
        let temp = tempfile::tempdir().expect("tempdir");
        let services = configured_services(Arc::new(ToolLoopProvider), temp.path());
        register_builtin_tools(&services.tools, &[]);
        // The hook command prints to stderr and exits with status 2.
        // NOTE(review): presumably exit status 2 is what marks the hook as blocking; confirm in halter_hooks.
        let (hooks_file, warnings) = HooksFile::from_json_bytes(
            br#"{
                "hooks": {
                    "PreToolUse": [
                        {
                            "matcher": "write",
                            "hooks": [
                                {
                                    "type": "command",
                                    "command": "echo blocked by hook >&2; exit 2"
                                }
                            ]
                        }
                    ]
                }
            }"#,
        )
        .expect("parse hooks");
        assert!(warnings.is_empty());
        install_file_hooks(&services, temp.path(), hooks_file);

        let runtime = SessionRuntime::new(services.clone());
        let session = new_session(&runtime, temp.path()).await;

        let events = session
            .submit_turn(Turn::user("write a note"))
            .await
            .expect("submit turn")
            .try_collect::<Vec<_>>()
            .await
            .expect("collect events");

        // The write never happened.
        assert!(!temp.path().join("note.txt").exists());
        assert!(
            events
                .iter()
                .any(|event| matches!(event.payload, SessionEventPayload::HookStarted { .. }))
        );
        assert!(events.iter().any(|event| matches!(
            &event.payload,
            SessionEventPayload::HookCompleted { run } if run.status == HookRunStatus::Blocked
        )));
        // The tool result message carries an error containing the hook's stderr text.
        assert!(events.iter().any(|event| matches!(
            &event.payload,
            SessionEventPayload::MessageItem {
                message: Message::Tool(tool),
            } if tool
                .error
                .as_ref()
                .is_some_and(|error| error.message.contains("blocked by hook"))
        )));

        // The persisted session state must not keep the blocked call as pending.
        let stored = services
            .sessions
            .load_session(session.session_id())
            .await
            .expect("load session")
            .expect("session present");
        assert!(
            stored.state.pending_tool_calls.is_empty(),
            "blocked-tool path left pending_tool_calls populated: {} entries",
            stored.state.pending_tool_calls.len()
        );
    }
3102
3103 #[tokio::test]
3104 async fn ac2_1_dropping_a_clone_does_not_evict_hooks_from_the_original_handle() {
3105 let temp = tempfile::tempdir().expect("tempdir");
3106 let services = configured_services(Arc::new(ToolLoopProvider), temp.path());
3107 let runtime = SessionRuntime::new(services.clone());
3108 let session = new_session(&runtime, temp.path()).await;
3109 let session_id = session.session_id().clone();
3110
3111 {
3113 let store = services
3114 .session_hook_store
3115 .lock()
3116 .expect("session hook store lock");
3117 assert!(
3118 store.contains_key(&session_id),
3119 "session hooks should be registered after new_session"
3120 );
3121 }
3122
3123 let cloned = session.clone();
3128 drop(cloned);
3129
3130 {
3131 let store = services
3132 .session_hook_store
3133 .lock()
3134 .expect("session hook store lock");
3135 assert!(
3136 store.contains_key(&session_id),
3137 "dropping a clone evicted hooks; the original handle must keep them alive"
3138 );
3139 }
3140
3141 drop(session);
3142
3143 {
3144 let store = services
3145 .session_hook_store
3146 .lock()
3147 .expect("session hook store lock");
3148 assert!(
3149 !store.contains_key(&session_id),
3150 "session hooks should be evicted once the last handle is dropped"
3151 );
3152 }
3153 }
3154
    /// Registers an in-process `PreToolUse` callback hook that blocks the
    /// `write` tool and checks that the tool does not run, that a blocked
    /// callback-type hook run is reported, and that the tool result carries
    /// the callback's message.
    #[tokio::test]
    async fn sdk_callback_hook_can_block_tool_execution() {
        let temp = tempfile::tempdir().expect("tempdir");
        let mut services = configured_services(Arc::new(ToolLoopProvider), temp.path());
        register_builtin_tools(&services.tools, &[]);

        let mut registered = RegisteredHooks::default();
        registered.register(
            PluginId::from("internal"),
            RegisteredHookPriority::AfterPlugins,
            // Block only the `write` tool; every other tool passes through.
            Hook::callback(HookEventName::PreToolUse, |input| async move {
                if input.tool_name() == Some("write") {
                    HookResponse::block("blocked by callback hook")
                } else {
                    HookResponse::passthrough()
                }
            }),
        );
        // `services` is still uniquely owned here, so it can be mutated in place.
        Arc::get_mut(&mut services)
            .expect("unique services")
            .registered_hooks = Arc::new(registered);

        let runtime = SessionRuntime::new(services.clone());
        let session = new_session(&runtime, temp.path()).await;

        let events = session
            .submit_turn(Turn::user("write a note"))
            .await
            .expect("submit turn")
            .try_collect::<Vec<_>>()
            .await
            .expect("collect events");

        // The write never happened.
        assert!(!temp.path().join("note.txt").exists());
        assert!(events.iter().any(|event| matches!(
            &event.payload,
            SessionEventPayload::HookCompleted { run }
                if run.status == HookRunStatus::Blocked
                    && run.handler_type == HookHandlerType::Callback
        )));
        assert!(events.iter().any(|event| matches!(
            &event.payload,
            SessionEventPayload::MessageItem {
                message: Message::Tool(tool),
            } if tool
                .error
                .as_ref()
                .is_some_and(|error| error.message.contains("blocked by callback hook"))
        )));
    }
3205
    /// Installs a `prompt`-type `UserPromptSubmit` hook and checks that the
    /// only provider request made targets the `small` model, that the hook run
    /// is reported as blocked, and that a system message containing
    /// "blocked by prompt hook" is emitted.
    // NOTE(review): `JsonHookProvider` is defined outside this view; presumably it
    // records each request into `requests` and answers with a blocking decision.
    #[tokio::test]
    async fn prompt_hook_uses_small_model_and_blocks_turn() {
        let temp = tempfile::tempdir().expect("tempdir");
        // Shared log of provider requests, inspected after the turn.
        let requests = Arc::new(Mutex::new(Vec::<ProviderRequest>::new()));
        let mut services = RuntimeServices::default();
        let mut models = ModelRegistry::new();
        models.set_default_model(resolved_test_model("default", "fake", "default/model"));
        // Default, small, and subagent models get distinct model strings.
        models.set_small_model(ResolvedModel {
            role: ModelRole::small(),
            id: ModelId::from("small"),
            provider: ProviderName::from("fake"),
            provider_kind: ProviderKind::Fake,
            api_kind: ApiKind::Fake,
            model: "small/model".to_owned(),
            max_input_tokens: Some(32_000),
            max_output_tokens: Some(4_096),
            reasoning: None,
            tokens_per_minute: None,
        });
        models.set_subagent_model(resolved_test_model("subagent", "fake", "subagent/model"));
        models.register_provider(
            ProviderName::from("fake"),
            Arc::new(JsonHookProvider::new(requests.clone())),
        );
        services.models = Arc::new(models);
        services.policy = Arc::new(DefaultToolPolicy::new(PolicySettings {
            allowed_write_roots: vec![temp.path().to_path_buf()],
            ..PolicySettings::default()
        }));

        let (hooks_file, warnings) = HooksFile::from_json_bytes(
            br#"{
                "hooks": {
                    "UserPromptSubmit": [
                        {
                            "hooks": [
                                {
                                    "type": "prompt",
                                    "prompt": "HOOK_PROMPT $ARGUMENTS"
                                }
                            ]
                        }
                    ]
                }
            }"#,
        )
        .expect("parse hooks");
        assert!(warnings.is_empty());
        let services = Arc::new(services);
        install_file_hooks(&services, temp.path(), hooks_file);

        let runtime = SessionRuntime::new(services);
        let session = new_session(&runtime, temp.path()).await;

        let events = session
            .submit_turn(Turn::user("blocked prompt"))
            .await
            .expect("submit turn")
            .try_collect::<Vec<_>>()
            .await
            .expect("collect events");

        // Exactly one provider request was made, and it went to the small model.
        let requests = requests.lock().expect("requests");
        assert_eq!(requests.len(), 1);
        assert_eq!(requests[0].model.id, ModelId::from("small"));
        assert!(events.iter().any(|event| matches!(
            &event.payload,
            SessionEventPayload::HookCompleted { run } if run.status == HookRunStatus::Blocked
        )));
        assert!(events.iter().any(|event| matches!(
            &event.payload,
            SessionEventPayload::MessageItem {
                message: Message::System(system),
            } if system.text.contains("blocked by prompt hook")
        )));
    }
3282
    /// Registers a function hook whose factory creates a fresh counter and
    /// checks that the counter advances within one session (1, then 2) while
    /// a second session starts again at 1.
    #[tokio::test]
    async fn sdk_function_hooks_keep_state_per_session() {
        let temp = tempfile::tempdir().expect("tempdir");
        let mut services = configured_services(Arc::new(FakeProvider::default()), temp.path());
        let mut registered = RegisteredHooks::default();
        registered.register(
            PluginId::from("internal"),
            RegisteredHookPriority::AfterPlugins,
            // The outer closure is a factory: each call creates a new counter and
            // returns the handler that owns it.
            // NOTE(review): presumably the factory runs once per session; the
            // assertions on session B below depend on that.
            Hook::function(HookEventName::UserPromptSubmit, || {
                let seen = Arc::new(Mutex::new(0usize));
                move |_input| {
                    let seen = seen.clone();
                    async move {
                        // The guard is not held across any `.await`.
                        let mut seen = seen.lock().expect("seen");
                        *seen += 1;
                        HookResponse::passthrough()
                            .with_system_message(format!("function count {}", *seen))
                    }
                }
            }),
        );
        // `services` is still uniquely owned here, so it can be mutated in place.
        Arc::get_mut(&mut services)
            .expect("unique services")
            .registered_hooks = Arc::new(registered);

        let runtime = SessionRuntime::new(services.clone());
        let session_a = new_session(&runtime, temp.path()).await;
        let session_b = new_session(&runtime, temp.path()).await;

        // Two turns on session A, then one on session B.
        let events_a1 = session_a
            .submit_turn(Turn::user("first"))
            .await
            .expect("submit turn")
            .try_collect::<Vec<_>>()
            .await
            .expect("collect events");
        let events_a2 = session_a
            .submit_turn(Turn::user("second"))
            .await
            .expect("submit turn")
            .try_collect::<Vec<_>>()
            .await
            .expect("collect events");
        let events_b1 = session_b
            .submit_turn(Turn::user("third"))
            .await
            .expect("submit turn")
            .try_collect::<Vec<_>>()
            .await
            .expect("collect events");

        assert!(events_a1.iter().any(|event| matches!(
            &event.payload,
            SessionEventPayload::HookCompleted { run }
                if run.handler_type == HookHandlerType::Function
        )));
        // Session A counts 1 then 2.
        assert!(events_a1.iter().any(|event| matches!(
            &event.payload,
            SessionEventPayload::MessageItem {
                message: Message::System(system),
            } if system.text.contains("function count 1")
        )));
        assert!(events_a2.iter().any(|event| matches!(
            &event.payload,
            SessionEventPayload::MessageItem {
                message: Message::System(system),
            } if system.text.contains("function count 2")
        )));
        // Session B starts from its own counter.
        assert!(events_b1.iter().any(|event| matches!(
            &event.payload,
            SessionEventPayload::MessageItem {
                message: Message::System(system),
            } if system.text.contains("function count 1")
        )));
    }
3358
3359 #[tokio::test]
3360 async fn hook_warnings_emit_warning_events_on_next_turn() {
3361 let temp = tempfile::tempdir().expect("tempdir");
3362 let services = configured_services(Arc::new(FakeProvider::default()), temp.path());
3363 services.resources.replace(
3364 ResourceSnapshot::empty(),
3365 empty_hooks(),
3366 vec![halter_protocol::HookWarning {
3367 category: "test".to_owned(),
3368 message: "hook warning".to_owned(),
3369 ..halter_protocol::HookWarning::default()
3370 }],
3371 );
3372 let runtime = SessionRuntime::new(services);
3373 let session = new_session(&runtime, temp.path()).await;
3374
3375 let events = session
3376 .submit_turn(Turn::user("hello"))
3377 .await
3378 .expect("submit turn")
3379 .try_collect::<Vec<_>>()
3380 .await
3381 .expect("collect events");
3382
3383 assert!(events.iter().any(|event| matches!(
3384 &event.payload,
3385 SessionEventPayload::Warning { message } if message.contains("hook warning")
3386 )));
3387 }
3388
3389 #[tokio::test]
3390 async fn hook_events_delivered_after_turn_commits() {
3391 let temp = tempfile::tempdir().expect("tempdir");
3395 let mut services = RuntimeServices::default();
3396 let mut hooks = RegisteredHooks::default();
3397 hooks.register(
3398 PluginId::from("test-plugin"),
3399 RegisteredHookPriority::AfterPlugins,
3400 Hook::callback(HookEventName::UserPromptSubmit, move |_input| async move {
3401 HookResponse::passthrough()
3402 }),
3403 );
3404 services.registered_hooks = Arc::new(hooks);
3405 let mut models = ModelRegistry::new();
3406 models.set_default_model(resolved_test_model("default", "fake", "default/model"));
3407 models.set_subagent_model(resolved_test_model("subagent", "fake", "subagent/model"));
3408 models.register_provider(
3409 ProviderName::from("fake"),
3410 Arc::new(FakeProvider::default()),
3411 );
3412 services.models = Arc::new(models);
3413 services.policy = Arc::new(DefaultToolPolicy::new(PolicySettings {
3414 allowed_write_roots: vec![temp.path().to_path_buf()],
3415 ..PolicySettings::default()
3416 }));
3417
3418 let runtime = SessionRuntime::new(Arc::new(services));
3419 let session = new_session(&runtime, temp.path()).await;
3420
3421 let events = session
3422 .submit_turn(Turn::user("hello"))
3423 .await
3424 .expect("submit turn")
3425 .try_collect::<Vec<_>>()
3426 .await
3427 .expect("collect events");
3428
3429 assert!(
3430 events
3431 .iter()
3432 .any(|event| matches!(event.payload, SessionEventPayload::HookStarted { .. }))
3433 );
3434 assert!(
3435 events
3436 .iter()
3437 .any(|event| matches!(event.payload, SessionEventPayload::HookCompleted { .. }))
3438 );
3439 }
3440
3441 #[tokio::test]
3442 async fn notify_runs_notification_hooks() {
3443 let temp = tempfile::tempdir().expect("tempdir");
3444 let services = configured_services(Arc::new(FakeProvider::default()), temp.path());
3445 let (hooks_file, warnings) = HooksFile::from_json_bytes(
3446 br#"{
3447 "hooks": {
3448 "Notification": [
3449 {
3450 "matcher": "policy",
3451 "hooks": [
3452 {
3453 "type": "command",
3454 "command": "printf '{\"systemMessage\":\"notification seen\"}'"
3455 }
3456 ]
3457 }
3458 ]
3459 }
3460 }"#,
3461 )
3462 .expect("parse hooks");
3463 assert!(warnings.is_empty());
3464 install_file_hooks(&services, temp.path(), hooks_file);
3465 let runtime = SessionRuntime::new(services.clone());
3466 let session = new_session(&runtime, temp.path()).await;
3467
3468 session
3469 .notify("policy", "denied")
3470 .await
3471 .expect("notify succeeds");
3472
3473 let replay = session.replay().await.expect("replay");
3474 assert!(replay.iter().any(|event| matches!(
3475 &event.payload,
3476 SessionEventPayload::HookCompleted { run } if run.event_name == "Notification"
3477 )));
3478 }
3479
3480 #[tokio::test]
3481 async fn compact_summarizes_older_messages_and_sets_session_start_latch() {
3482 let temp = tempfile::tempdir().expect("tempdir");
3483 let services = configured_services(Arc::new(FakeProvider::default()), temp.path());
3484 let runtime = SessionRuntime::new(services.clone());
3485 let session = new_session(&runtime, temp.path()).await;
3486
3487 let mut stored = services
3488 .sessions
3489 .load_session(session.session_id())
3490 .await
3491 .expect("load session")
3492 .expect("session exists");
3493 stored.state.messages = (0..12)
3494 .map(|index| Message::User(halter_protocol::UserMessage::text(format!("msg {index}"))))
3495 .collect();
3496 let _ = services
3497 .sessions
3498 .commit(
3499 session.session_id(),
3500 None,
3501 None,
3502 Some(stored.state),
3503 Vec::new(),
3504 )
3505 .await
3506 .expect("commit state");
3507
3508 session
3509 .compact("manual", Some("Focus on decisions"))
3510 .await
3511 .expect("compact succeeds");
3512
3513 let stored = services
3514 .sessions
3515 .load_session(session.session_id())
3516 .await
3517 .expect("load compacted session")
3518 .expect("session exists");
3519 assert!(!stored.state.compacted_prefix.is_empty());
3520 assert!(stored.state.messages.is_empty());
3521 assert_eq!(
3522 stored.state.pending_session_start_source,
3523 Some(HookSessionStartSource::Compact)
3524 );
3525
3526 let replay = session.replay().await.expect("replay");
3527 assert!(
3528 replay
3529 .iter()
3530 .any(|event| matches!(event.payload, SessionEventPayload::ContextCompacted { .. }))
3531 );
3532 }
3533
3534 #[tokio::test]
3535 async fn submit_turn_compacts_immediately_after_response_when_threshold_is_reached() {
3536 let temp = tempfile::tempdir().expect("tempdir");
3537 let mut services = configured_services(Arc::new(FakeProvider::default()), temp.path());
3538 Arc::get_mut(&mut services)
3539 .expect("unique services")
3540 .context_manager = Arc::new(DefaultContextManager::new(
3541 150,
3542 0,
3543 halter_protocol::PruneSignalThreshold::VeryLow,
3544 ));
3545 let runtime = SessionRuntime::new(services.clone());
3546 let session = new_session(&runtime, temp.path()).await;
3547
3548 session
3549 .submit_turn(Turn::user("x".repeat(150)))
3550 .await
3551 .expect("submit turn")
3552 .try_collect::<Vec<_>>()
3553 .await
3554 .expect("collect events");
3555
3556 let stored = services
3557 .sessions
3558 .load_session(session.session_id())
3559 .await
3560 .expect("load session")
3561 .expect("session exists");
3562 assert!(!stored.state.compacted_prefix.is_empty());
3563 assert_eq!(stored.state.messages.len(), 1);
3564 assert!(matches!(stored.state.messages[0], Message::Assistant(_)));
3565 }
3566
3567 #[tokio::test]
3568 async fn submit_turn_failure_preserves_valid_transcript() {
3569 let temp = tempfile::tempdir().expect("tempdir");
3570 let services = configured_services(Arc::new(FailingProvider), temp.path());
3571 let runtime = SessionRuntime::new(services.clone());
3572 let session = new_session(&runtime, temp.path()).await;
3573
3574 let events = session
3575 .submit_turn(Turn::user("will fail"))
3576 .await
3577 .expect("submit turn")
3578 .try_collect::<Vec<_>>()
3579 .await
3580 .expect("collect events");
3581
3582 let stored = services
3583 .sessions
3584 .load_session(session.session_id())
3585 .await
3586 .expect("load session")
3587 .expect("session exists");
3588 assert_eq!(stored.state.messages.len(), 1);
3589 assert!(matches!(
3590 &stored.state.messages[0],
3591 Message::User(user) if user.plain_text() == "will fail"
3592 ));
3593 assert!(
3594 events
3595 .iter()
3596 .any(|event| matches!(event.payload, SessionEventPayload::TurnFailed { .. }))
3597 );
3598
3599 let events = services
3600 .sessions
3601 .replay(session.session_id())
3602 .await
3603 .expect("replay");
3604 assert!(
3605 events
3606 .iter()
3607 .any(|event| matches!(event.payload, SessionEventPayload::TurnFailed { .. }))
3608 );
3609 assert!(events.iter().any(|event| matches!(
3610 &event.payload,
3611 SessionEventPayload::MessageItem {
3612 message: Message::User(user),
3613 } if user.plain_text() == "will fail"
3614 )));
3615 }
3616
3617 #[tokio::test]
3618 async fn later_turns_commit_latest_resource_snapshot() {
3619 let temp = tempfile::tempdir().expect("tempdir");
3620 let services = configured_services(Arc::new(FakeProvider::default()), temp.path());
3621 let runtime = SessionRuntime::new(services.clone());
3622 let session = new_session(&runtime, temp.path()).await;
3623
3624 let mut reloaded = ResourceSnapshot::empty();
3625 reloaded.revision = halter_protocol::Revision::from("reloaded");
3626 runtime.replace_resources(reloaded, empty_hooks(), Vec::new());
3627
3628 session
3629 .submit_turn(Turn::user("after reload"))
3630 .await
3631 .expect("submit turn")
3632 .try_collect::<Vec<_>>()
3633 .await
3634 .expect("collect events");
3635
3636 let stored = services
3637 .sessions
3638 .load_session(session.session_id())
3639 .await
3640 .expect("load session")
3641 .expect("session exists");
3642 assert_eq!(stored.snapshot.revision.0, "reloaded");
3643 }
3644
3645 #[tokio::test]
3646 async fn session_init_can_override_subagent_model() {
3647 let temp = tempfile::tempdir().expect("tempdir");
3648 let services = configured_services(Arc::new(FakeProvider::default()), temp.path());
3649 let runtime = SessionRuntime::new(services.clone());
3650 let session = runtime
3651 .new_session(SessionInit::default().with_subagent_model("default"))
3652 .await
3653 .expect("session");
3654
3655 let stored = services
3656 .sessions
3657 .load_session(session.session_id())
3658 .await
3659 .expect("load session")
3660 .expect("session exists");
3661
3662 assert_eq!(stored.blueprint.default_model, ModelId::from("default"));
3663 assert_eq!(stored.blueprint.subagent_model, ModelId::from("default"));
3664 }
3665
3666 #[tokio::test]
3667 async fn turn_default_model_override_selects_overridden_provider() {
3668 let temp = tempfile::tempdir().expect("tempdir");
3669 let default_requests = Arc::new(Mutex::new(Vec::<ProviderRequest>::new()));
3670 let subagent_requests = Arc::new(Mutex::new(Vec::<ProviderRequest>::new()));
3671 let mut services = RuntimeServices::default();
3672 let mut models = ModelRegistry::new();
3673 models.set_default_model(resolved_test_model(
3674 "default",
3675 "default-provider",
3676 "default/model",
3677 ));
3678 models.set_subagent_model(resolved_test_model(
3679 "subagent",
3680 "subagent-provider",
3681 "subagent/model",
3682 ));
3683 models.register_provider(
3684 ProviderName::from("default-provider"),
3685 Arc::new(RecordingProvider::new(
3686 default_requests.clone(),
3687 "default provider reply",
3688 )),
3689 );
3690 models.register_provider(
3691 ProviderName::from("subagent-provider"),
3692 Arc::new(RecordingProvider::new(
3693 subagent_requests.clone(),
3694 "subagent provider reply",
3695 )),
3696 );
3697 services.models = Arc::new(models);
3698 services.policy = Arc::new(DefaultToolPolicy::new(PolicySettings {
3699 allowed_write_roots: vec![temp.path().to_path_buf()],
3700 ..PolicySettings::default()
3701 }));
3702
3703 let runtime = SessionRuntime::new(Arc::new(services));
3704 let session = new_session(&runtime, temp.path()).await;
3705
3706 let events = session
3707 .submit_turn(Turn::user("hello").with_default_model("subagent"))
3708 .await
3709 .expect("submit turn")
3710 .try_collect::<Vec<_>>()
3711 .await
3712 .expect("collect events");
3713
3714 assert!(default_requests.lock().expect("requests").is_empty());
3715 let subagent_requests = subagent_requests.lock().expect("requests");
3716 assert_eq!(subagent_requests.len(), 1);
3717 assert_eq!(subagent_requests[0].model.id, ModelId::from("subagent"));
3718 assert_eq!(
3719 subagent_requests[0].model.provider,
3720 ProviderName::from("subagent-provider")
3721 );
3722 assert!(events.iter().any(|event| matches!(
3723 &event.payload,
3724 SessionEventPayload::MessageItem {
3725 message: Message::Assistant(assistant),
3726 } if assistant.parts.iter().any(|part| matches!(
3727 part,
3728 AssistantPart::Text { text } if text.contains("subagent provider reply")
3729 ))
3730 )));
3731 }
3732
3733 #[tokio::test]
3734 async fn turn_rejects_unknown_subagent_model_override() {
3735 let temp = tempfile::tempdir().expect("tempdir");
3736 let services = configured_services(Arc::new(FakeProvider::default()), temp.path());
3737 let runtime = SessionRuntime::new(services.clone());
3738 let session = new_session(&runtime, temp.path()).await;
3739
3740 let events = session
3741 .submit_turn(Turn::user("hello").with_subagent_model("missing"))
3742 .await
3743 .expect("submit turn")
3744 .try_collect::<Vec<_>>()
3745 .await
3746 .expect("collect events");
3747
3748 assert!(events.iter().any(|event| matches!(
3749 &event.payload,
3750 SessionEventPayload::TurnFailed { error, .. } if error.contains("unknown model 'missing'")
3751 )));
3752 }
3753
3754 #[tokio::test]
3755 async fn submit_turn_delivers_tool_output_after_turn_commits() {
3756 let temp = tempfile::tempdir().expect("tempdir");
3761 let services = configured_services(Arc::new(StreamingToolProvider), temp.path());
3762 services.tools.register(Arc::new(StreamingTestTool));
3763 let runtime = SessionRuntime::new(services);
3764 let session = new_session(&runtime, temp.path()).await;
3765
3766 let events = session
3767 .submit_turn(Turn::user("stream tool output"))
3768 .await
3769 .expect("submit turn")
3770 .try_collect::<Vec<_>>()
3771 .await
3772 .expect("collect events");
3773
3774 let tool_output = events
3775 .iter()
3776 .find_map(|event| match &event.payload {
3777 SessionEventPayload::ToolOutput { chunk, .. } => Some(chunk.clone()),
3778 _ => None,
3779 })
3780 .expect("tool output chunk present in committed events");
3781 assert_eq!(tool_output, "streamed chunk");
3782 assert!(
3783 events
3784 .iter()
3785 .any(|event| matches!(event.payload, SessionEventPayload::TurnCompleted { .. }))
3786 );
3787 }
3788
3789 #[tokio::test]
3790 async fn mutating_tool_result_survives_later_provider_failure() {
3791 let temp = tempfile::tempdir().expect("tempdir");
3792 let services = configured_services(Arc::new(ToolThenFailProvider), temp.path());
3793 register_builtin_tools(&services.tools, &[]);
3794 let runtime = SessionRuntime::new(services.clone());
3795 let session = new_session(&runtime, temp.path()).await;
3796
3797 let events = session
3798 .submit_turn(Turn::user("write then fail"))
3799 .await
3800 .expect("submit turn")
3801 .try_collect::<Vec<_>>()
3802 .await
3803 .expect("collect events");
3804
3805 let written = std::fs::read_to_string(temp.path().join("note.txt")).expect("written file");
3806 assert_eq!(written, "hello from tool");
3807 assert!(events.iter().any(|event| matches!(
3808 &event.payload,
3809 SessionEventPayload::ToolExecutionCompleted { outcome }
3810 if outcome.call.name.0 == "write" && outcome.result.is_ok()
3811 )));
3812 assert!(events.iter().any(|event| matches!(
3813 &event.payload,
3814 SessionEventPayload::TurnFailed { error, .. }
3815 if error.contains("provider failed after tool")
3816 )));
3817
3818 let replayed = session.replay().await.expect("replay events");
3819 assert!(replayed.iter().any(|event| matches!(
3820 &event.payload,
3821 SessionEventPayload::ToolExecutionCompleted { outcome }
3822 if outcome.call.name.0 == "write" && outcome.result.is_ok()
3823 )));
3824 assert!(replayed.iter().any(|event| matches!(
3825 &event.payload,
3826 SessionEventPayload::TurnFailed { error, .. }
3827 if error.contains("provider failed after tool")
3828 )));
3829
3830 let stored = services
3831 .sessions
3832 .load_session(session.session_id())
3833 .await
3834 .expect("load session")
3835 .expect("session exists");
3836 assert!(stored.state.messages.iter().any(|message| {
3837 matches!(
3838 message,
3839 Message::Tool(tool) if tool.error.is_none()
3840 )
3841 }));
3842 }
3843
3844 #[tokio::test]
3845 async fn max_turns_caps_provider_iterations() {
3846 let temp = tempfile::tempdir().expect("tempdir");
3847 let services = configured_services(Arc::new(ToolLoopProvider), temp.path());
3848 register_builtin_tools(&services.tools, &[]);
3849 let runtime = SessionRuntime::new(services);
3850 let session = runtime
3851 .new_session(SessionInit {
3852 working_dir: temp.path().to_path_buf(),
3853 max_turns: Some(1),
3854 ..SessionInit::default()
3855 })
3856 .await
3857 .expect("session");
3858
3859 let events = session
3860 .submit_turn(Turn::user("one iteration only"))
3861 .await
3862 .expect("submit turn")
3863 .try_collect::<Vec<_>>()
3864 .await
3865 .expect("collect events");
3866
3867 assert_eq!(
3868 std::fs::read_to_string(temp.path().join("note.txt")).expect("written file"),
3869 "hello from tool"
3870 );
3871 assert!(events.iter().any(|event| matches!(
3872 &event.payload,
3873 SessionEventPayload::TurnFailed { error, .. } if error.contains("max_turns 1")
3874 )));
3875 assert!(
3876 !events
3877 .iter()
3878 .any(|event| matches!(event.payload, SessionEventPayload::TurnCompleted { .. }))
3879 );
3880 }
3881
3882 #[tokio::test]
3883 async fn subagent_event_forwarding_defaults_off() {
3884 let (parent_id, events) =
3885 run_subagent_firehose_turn(SubagentEventForwarding::Off, None, 100_000, "single").await;
3886
3887 assert!(
3888 events.iter().all(|event| event.session_id == parent_id),
3889 "default-off parent stream should contain only parent session events: {events:?}"
3890 );
3891 }
3892
3893 #[tokio::test]
3894 async fn traces_record_subagent_events_when_forwarding_is_off() {
3895 let temp = tempfile::tempdir().expect("tempdir");
3896 let traces_dir = temp.path().join("traces");
3897 let trace_recorder =
3898 Arc::new(crate::TraceRecorder::open(traces_dir.clone()).expect("trace recorder"));
3899 let provider = Arc::new(SubagentFirehoseProvider);
3900 let services = test_support::configured_services_with_runtime_and_trace(
3901 provider,
3902 temp.path(),
3903 SubagentEventForwarding::Off,
3904 100_000,
3905 Some(trace_recorder.clone()),
3906 );
3907 let runtime = SessionRuntime::new(services.clone());
3908 install_subagent_tools(&runtime, &services);
3909 let session = runtime
3910 .new_session(SessionInit {
3911 working_dir: temp.path().to_path_buf(),
3912 ..SessionInit::default()
3913 })
3914 .await
3915 .expect("session");
3916 let parent_id = session.session_id().clone();
3917
3918 let events = session
3919 .submit_turn(Turn::user("single"))
3920 .await
3921 .expect("turn")
3922 .try_collect::<Vec<_>>()
3923 .await
3924 .expect("events");
3925 assert!(
3926 events.iter().all(|event| event.session_id == parent_id),
3927 "forwarding is off, so live stream should stay parent-only: {events:?}"
3928 );
3929
3930 trace_recorder.close_session(&parent_id);
3931 let contents = std::fs::read_to_string(traces_dir.join(format!("{}.txt", parent_id.0)))
3932 .expect("read trace");
3933 let lines = contents
3934 .lines()
3935 .map(|line| serde_json::from_str::<serde_json::Value>(line).expect("trace json"))
3936 .collect::<Vec<_>>();
3937
3938 assert!(
3939 lines.iter().any(|line| {
3940 line.get("kind").and_then(serde_json::Value::as_str) == Some("subagent_header")
3941 }),
3942 "trace should include a subagent header even when forwarding is off:\n{contents}"
3943 );
3944 assert!(
3945 lines.iter().any(|line| {
3946 line.get("sequence").is_some()
3947 && line.get("session_id").and_then(serde_json::Value::as_str)
3948 != Some(parent_id.0.as_str())
3949 && line
3950 .pointer("/payload/kind")
3951 .and_then(serde_json::Value::as_str)
3952 == Some("delta_item")
3953 && line
3954 .pointer("/payload/delta/text")
3955 .and_then(serde_json::Value::as_str)
3956 .is_some_and(|text| text.contains("child done"))
3957 }),
3958 "trace should include committed subagent delta events even when forwarding is off:\n{contents}"
3959 );
3960 }
3961
3962 #[tokio::test]
3963 async fn subagent_event_forwarding_includes_single_level_events() {
3964 let (parent_id, events) =
3965 run_subagent_firehose_turn(SubagentEventForwarding::All, None, 100_000, "single").await;
3966
3967 assert!(
3968 events.iter().any(|event| event.session_id == parent_id),
3969 "parent events should still be present"
3970 );
3971 assert!(
3972 forwarded_events(&events, &parent_id)
3973 .iter()
3974 .any(|event| event_has_delta_text(event, "child done")),
3975 "parent stream should include child session deltas: {events:?}"
3976 );
3977 }
3978
3979 #[tokio::test]
3980 async fn subagent_event_forwarding_includes_recursive_events() {
3981 let (parent_id, events) =
3982 run_subagent_firehose_turn(SubagentEventForwarding::All, None, 100_000, "recursive")
3983 .await;
3984 let forwarded = forwarded_events(&events, &parent_id);
3985 let forwarded_session_ids = forwarded
3986 .iter()
3987 .map(|event| event.session_id.clone())
3988 .collect::<BTreeSet<_>>();
3989
3990 assert!(
3991 forwarded_session_ids.len() >= 2,
3992 "recursive forwarding should include child and grandchild sessions: {events:?}"
3993 );
3994 assert!(
3995 forwarded
3996 .iter()
3997 .any(|event| event_has_delta_text(event, "grandchild done")),
3998 "top-level stream should include grandchild deltas: {events:?}"
3999 );
4000 }
4001
4002 #[tokio::test]
4003 async fn session_init_can_override_subagent_event_forwarding() {
4004 let temp = tempfile::tempdir().expect("tempdir");
4005 let provider = Arc::new(SubagentFirehoseProvider);
4006 let services = test_support::configured_services_with_runtime(
4007 provider,
4008 temp.path(),
4009 SubagentEventForwarding::Off,
4010 100_000,
4011 );
4012 let runtime = SessionRuntime::new(services.clone());
4013 install_subagent_tools(&runtime, &services);
4014
4015 let enabled_session = runtime
4016 .new_session(SessionInit {
4017 working_dir: temp.path().to_path_buf(),
4018 subagent_event_forwarding: Some(SubagentEventForwarding::All),
4019 ..SessionInit::default()
4020 })
4021 .await
4022 .expect("enabled session");
4023 let enabled_parent_id = enabled_session.session_id().clone();
4024 let enabled_events = enabled_session
4025 .submit_turn(Turn::user("single"))
4026 .await
4027 .expect("enabled turn")
4028 .try_collect::<Vec<_>>()
4029 .await
4030 .expect("enabled events");
4031
4032 let default_session = runtime
4033 .new_session(SessionInit {
4034 working_dir: temp.path().to_path_buf(),
4035 ..SessionInit::default()
4036 })
4037 .await
4038 .expect("default session");
4039 let default_parent_id = default_session.session_id().clone();
4040 let default_events = default_session
4041 .submit_turn(Turn::user("single"))
4042 .await
4043 .expect("default turn")
4044 .try_collect::<Vec<_>>()
4045 .await
4046 .expect("default events");
4047
4048 assert!(
4049 !forwarded_events(&enabled_events, &enabled_parent_id).is_empty(),
4050 "per-session override should enable forwarding"
4051 );
4052 assert!(
4053 forwarded_events(&default_events, &default_parent_id).is_empty(),
4054 "harness default off should still apply to other sessions"
4055 );
4056 }
4057
4058 #[tokio::test]
4059 async fn subagent_event_forwarding_cap_emits_lagged_and_stops_forwarding() {
4060 let (parent_id, events) =
4061 run_subagent_firehose_turn(SubagentEventForwarding::All, None, 2, "many child events")
4062 .await;
4063
4064 assert!(
4065 events.iter().any(|event| {
4066 event.session_id.0 == crate::event_bus::BUS_SESSION_ID
4067 && matches!(
4068 event.payload,
4069 SessionEventPayload::Lagged { dropped_events: 1 }
4070 )
4071 }),
4072 "cap should emit a synthetic lagged event: {events:?}"
4073 );
4074 assert_eq!(
4075 forwarded_events(&events, &parent_id).len(),
4076 2,
4077 "forwarded child events should stop at the configured cap: {events:?}"
4078 );
4079 }
4080
4081 async fn run_subagent_firehose_turn(
4082 default_forwarding: SubagentEventForwarding,
4083 session_forwarding: Option<SubagentEventForwarding>,
4084 cap: u64,
4085 prompt: &str,
4086 ) -> (SessionId, Vec<SessionEvent>) {
4087 let temp = tempfile::tempdir().expect("tempdir");
4088 let provider = Arc::new(SubagentFirehoseProvider);
4089 let services = test_support::configured_services_with_runtime(
4090 provider,
4091 temp.path(),
4092 default_forwarding,
4093 cap,
4094 );
4095 let runtime = SessionRuntime::new(services.clone());
4096 install_subagent_tools(&runtime, &services);
4097 let session = runtime
4098 .new_session(SessionInit {
4099 working_dir: temp.path().to_path_buf(),
4100 subagent_event_forwarding: session_forwarding,
4101 ..SessionInit::default()
4102 })
4103 .await
4104 .expect("session");
4105 let parent_id = session.session_id().clone();
4106 let events = session
4107 .submit_turn(Turn::user(prompt))
4108 .await
4109 .expect("turn")
4110 .try_collect::<Vec<_>>()
4111 .await
4112 .expect("events");
4113 (parent_id, events)
4114 }
4115
4116 fn install_subagent_tools(runtime: &SessionRuntime, services: &Arc<RuntimeServices>) {
4117 let snapshot = services.resources.snapshot();
4118 let available_model_ids = services
4119 .models
4120 .model_ids()
4121 .into_iter()
4122 .map(|model_id| model_id.0)
4123 .collect::<Vec<_>>();
4124 register_subagent_tools(
4125 &services.tools,
4126 runtime.subagent_control(),
4127 &[],
4128 snapshot.as_ref(),
4129 &available_model_ids,
4130 );
4131 }
4132
4133 fn forwarded_events<'a>(
4134 events: &'a [SessionEvent],
4135 parent_id: &SessionId,
4136 ) -> Vec<&'a SessionEvent> {
4137 events
4138 .iter()
4139 .filter(|event| {
4140 &event.session_id != parent_id
4141 && event.session_id.0 != crate::event_bus::BUS_SESSION_ID
4142 })
4143 .collect()
4144 }
4145
4146 fn event_has_delta_text(event: &SessionEvent, needle: &str) -> bool {
4147 matches!(
4148 &event.payload,
4149 SessionEventPayload::DeltaItem { delta } if delta.text.contains(needle)
4150 )
4151 }
4152
/// Stateless test provider. Judging by the visible start of its `Provider`
/// impl, it answers with the text "tool completed" once the request already
/// contains a tool message; the first-request branch is not fully shown here.
#[derive(Debug)]
struct ToolLoopProvider;
4155
/// Stateless marker type for a test double; its behaviour is defined by an
/// impl outside this section.
/// NOTE(review): presumably emits duplicate tool calls, going by the name — confirm.
#[derive(Debug)]
struct DuplicateToolCallProvider;
4158
/// Stateless scripted provider for the subagent event-forwarding tests; each
/// reply is chosen purely from the request's model id and message history.
#[derive(Debug, Default)]
struct SubagentFirehoseProvider;
4161
/// Test provider that records each request it is asked to stream and replies
/// with either a JSON "block" decision or a plain text reply.
#[derive(Debug)]
struct JsonHookProvider {
    // Shared log of every request passed to `stream`, inspectable by the test.
    requests: Arc<Mutex<Vec<ProviderRequest>>>,
}
4166
4167 impl JsonHookProvider {
4168 fn new(requests: Arc<Mutex<Vec<ProviderRequest>>>) -> Self {
4169 Self { requests }
4170 }
4171 }
4172
4173 #[async_trait]
4174 impl Provider for JsonHookProvider {
4175 fn capabilities(&self) -> ProviderCapabilities {
4176 ProviderCapabilities::default()
4177 }
4178
4179 async fn stream(
4180 &self,
4181 request: ProviderRequest,
4182 _cancel: CancellationToken,
4183 ) -> anyhow::Result<BoxStream<'static, Result<StreamEvent, ProviderError>>> {
4184 self.requests
4185 .lock()
4186 .expect("requests")
4187 .push(request.clone());
4188 let latest_user_text = request
4189 .messages
4190 .iter()
4191 .rev()
4192 .find_map(|message| match message {
4193 Message::User(user) => Some(user.plain_text()),
4194 Message::System(_) | Message::Assistant(_) | Message::Tool(_) => None,
4195 })
4196 .unwrap_or_default();
4197 let reply = if latest_user_text.starts_with("HOOK_PROMPT ") {
4198 "{\"decision\":\"block\",\"reason\":\"blocked by prompt hook\"}".to_owned()
4199 } else {
4200 "normal reply".to_owned()
4201 };
4202 let message_id = halter_protocol::MessageId::new();
4203 let block_id = BlockId::new();
4204 Ok(stream::iter(vec![
4205 Ok(StreamEvent::MessageStart {
4206 id: message_id.clone(),
4207 }),
4208 Ok(StreamEvent::TextStart {
4209 id: block_id.clone(),
4210 }),
4211 Ok(StreamEvent::TextDelta {
4212 id: block_id.clone(),
4213 delta: reply,
4214 }),
4215 Ok(StreamEvent::TextEnd { id: block_id }),
4216 Ok(StreamEvent::UsageUpdate {
4217 usage: Usage {
4218 input_tokens: 4,
4219 output_tokens: 4,
4220 cache_creation_input_tokens: 0,
4221 cache_read_input_tokens: 0,
4222 },
4223 }),
4224 Ok(StreamEvent::MessageEnd {
4225 id: message_id,
4226 stop_reason: StopReason::EndTurn,
4227 response_id: None,
4228 }),
4229 ])
4230 .boxed())
4231 }
4232 }
4233
4234 #[async_trait]
4235 impl Provider for SubagentFirehoseProvider {
4236 fn capabilities(&self) -> ProviderCapabilities {
4237 ProviderCapabilities::default()
4238 }
4239
4240 async fn stream(
4241 &self,
4242 request: ProviderRequest,
4243 _cancel: CancellationToken,
4244 ) -> anyhow::Result<BoxStream<'static, Result<StreamEvent, ProviderError>>> {
4245 let latest_user_text = latest_user_text(&request.messages);
4246 let has_wait_response = has_wait_response(&request.messages);
4247 let latest_agent_id = latest_spawn_agent_id(&request.messages);
4248
4249 if request.model.id.0 == "subagent" {
4250 if latest_user_text.contains("spawn grandchild") {
4251 if has_wait_response {
4252 return Ok(text_stream(vec!["child done"]));
4253 }
4254 if let Some(agent_id) = latest_agent_id {
4255 return Ok(tool_call_stream(
4256 "wait_agent",
4257 json!({ "targets": [agent_id], "timeout_ms": 5_000 }),
4258 ));
4259 }
4260 return Ok(tool_call_stream(
4261 "spawn_agent",
4262 json!({ "message": "grandchild task", "fork_context": false }),
4263 ));
4264 }
4265
4266 if latest_user_text.contains("grandchild") {
4267 return Ok(text_stream(vec!["grandchild done"]));
4268 }
4269 if latest_user_text.contains("many child events") {
4270 return Ok(text_stream(vec![
4271 "child ", "done ", "with ", "many ", "events",
4272 ]));
4273 }
4274 return Ok(text_stream(vec!["child done"]));
4275 }
4276
4277 if has_wait_response {
4278 return Ok(text_stream(vec!["parent done"]));
4279 }
4280 if let Some(agent_id) = latest_agent_id {
4281 return Ok(tool_call_stream(
4282 "wait_agent",
4283 json!({ "targets": [agent_id], "timeout_ms": 5_000 }),
4284 ));
4285 }
4286
4287 let child_task = if latest_user_text.contains("recursive") {
4288 "spawn grandchild"
4289 } else if latest_user_text.contains("many child events") {
4290 "many child events"
4291 } else {
4292 "child task"
4293 };
4294 Ok(tool_call_stream(
4295 "spawn_agent",
4296 json!({ "message": child_task, "fork_context": false }),
4297 ))
4298 }
4299 }
4300
4301 fn latest_user_text(messages: &[Message]) -> String {
4302 messages
4303 .iter()
4304 .rev()
4305 .find_map(|message| match message {
4306 Message::User(user) => Some(user.plain_text()),
4307 Message::System(_) | Message::Assistant(_) | Message::Tool(_) => None,
4308 })
4309 .unwrap_or_default()
4310 }
4311
4312 fn latest_spawn_agent_id(messages: &[Message]) -> Option<String> {
4313 messages.iter().rev().find_map(|message| match message {
4314 Message::Tool(tool) => match &tool.content {
4315 ToolResult::Json { value } => value
4316 .get("agent_id")
4317 .and_then(serde_json::Value::as_str)
4318 .map(ToOwned::to_owned),
4319 ToolResult::Empty | ToolResult::Text { .. } => None,
4320 },
4321 Message::System(_) | Message::User(_) | Message::Assistant(_) => None,
4322 })
4323 }
4324
4325 fn has_wait_response(messages: &[Message]) -> bool {
4326 messages.iter().any(|message| {
4327 matches!(
4328 message,
4329 Message::Tool(tool)
4330 if matches!(&tool.content, ToolResult::Json { value } if value.get("timed_out").is_some())
4331 )
4332 })
4333 }
4334
4335 fn text_stream(
4336 chunks: Vec<&'static str>,
4337 ) -> BoxStream<'static, Result<StreamEvent, ProviderError>> {
4338 let message_id = halter_protocol::MessageId::new();
4339 let block_id = BlockId::new();
4340 let mut events = vec![
4341 Ok(StreamEvent::MessageStart {
4342 id: message_id.clone(),
4343 }),
4344 Ok(StreamEvent::TextStart {
4345 id: block_id.clone(),
4346 }),
4347 ];
4348 for chunk in chunks {
4349 events.push(Ok(StreamEvent::TextDelta {
4350 id: block_id.clone(),
4351 delta: chunk.to_owned(),
4352 }));
4353 }
4354 events.extend([
4355 Ok(StreamEvent::TextEnd { id: block_id }),
4356 Ok(StreamEvent::UsageUpdate {
4357 usage: Usage {
4358 input_tokens: 1,
4359 output_tokens: 1,
4360 cache_creation_input_tokens: 0,
4361 cache_read_input_tokens: 0,
4362 },
4363 }),
4364 Ok(StreamEvent::MessageEnd {
4365 id: message_id,
4366 stop_reason: StopReason::EndTurn,
4367 response_id: None,
4368 }),
4369 ]);
4370 stream::iter(events).boxed()
4371 }
4372
4373 fn tool_call_stream(
4374 name: &'static str,
4375 arguments: serde_json::Value,
4376 ) -> BoxStream<'static, Result<StreamEvent, ProviderError>> {
4377 let message_id = halter_protocol::MessageId::new();
4378 let block_id = BlockId::new();
4379 let tool_call_id = ToolCallId::new();
4380 stream::iter(vec![
4381 Ok(StreamEvent::MessageStart {
4382 id: message_id.clone(),
4383 }),
4384 Ok(StreamEvent::ToolCallStart {
4385 id: block_id.clone(),
4386 tool_call_id,
4387 name: ToolName::from(name),
4388 }),
4389 Ok(StreamEvent::ToolArgsDelta {
4390 id: block_id.clone(),
4391 delta: arguments.to_string(),
4392 }),
4393 Ok(StreamEvent::ToolCallEnd { id: block_id }),
4394 Ok(StreamEvent::UsageUpdate {
4395 usage: Usage {
4396 input_tokens: 1,
4397 output_tokens: 1,
4398 cache_creation_input_tokens: 0,
4399 cache_read_input_tokens: 0,
4400 },
4401 }),
4402 Ok(StreamEvent::MessageEnd {
4403 id: message_id,
4404 stop_reason: StopReason::ToolUse,
4405 response_id: None,
4406 }),
4407 ])
4408 .boxed()
4409 }
4410
4411 #[async_trait]
4412 impl Provider for ToolLoopProvider {
4413 fn capabilities(&self) -> ProviderCapabilities {
4414 ProviderCapabilities::default()
4415 }
4416
4417 async fn stream(
4418 &self,
4419 request: ProviderRequest,
4420 _cancel: CancellationToken,
4421 ) -> anyhow::Result<BoxStream<'static, Result<StreamEvent, ProviderError>>> {
4422 if request
4423 .messages
4424 .iter()
4425 .any(|message| matches!(message, Message::Tool(_)))
4426 {
4427 return Ok(stream::iter(vec![
4428 Ok(StreamEvent::MessageStart {
4429 id: halter_protocol::MessageId::new(),
4430 }),
4431 Ok(StreamEvent::TextStart { id: BlockId::new() }),
4432 Ok(StreamEvent::TextDelta {
4433 id: BlockId::new(),
4434 delta: "tool completed".to_owned(),
4435 }),
4436 Ok(StreamEvent::TextEnd { id: BlockId::new() }),
4437 Ok(StreamEvent::UsageUpdate {
4438 usage: Usage {
4439 input_tokens: 10,
4440 output_tokens: 2,
4441 cache_creation_input_tokens: 0,
4442 cache_read_input_tokens: 0,
4443 },
4444 }),
4445 Ok(StreamEvent::MessageEnd {
4446 id: halter_protocol::MessageId::new(),
4447 stop_reason: StopReason::EndTurn,
4448 response_id: None,
4449 }),
4450 ])
4451 .boxed());
4452 }
4453
4454 let block_id = BlockId::new();
4455 let message_id = halter_protocol::MessageId::new();
4456 let tool_call_id = ToolCallId::new();
4457 Ok(stream::iter(vec![
4458 Ok(StreamEvent::MessageStart {
4459 id: message_id.clone(),
4460 }),
4461 Ok(StreamEvent::ToolCallStart {
4462 id: block_id.clone(),
4463 tool_call_id,
4464 name: halter_protocol::ToolName::from("write"),
4465 }),
4466 Ok(StreamEvent::ToolArgsDelta {
4467 id: block_id.clone(),
4468 delta: serde_json::json!({
4469 "path": "note.txt",
4470 "content": "hello from tool"
4471 })
4472 .to_string(),
4473 }),
4474 Ok(StreamEvent::ToolCallEnd { id: block_id }),
4475 Ok(StreamEvent::UsageUpdate {
4476 usage: Usage {
4477 input_tokens: 8,
4478 output_tokens: 0,
4479 cache_creation_input_tokens: 0,
4480 cache_read_input_tokens: 0,
4481 },
4482 }),
4483 Ok(StreamEvent::MessageEnd {
4484 id: message_id,
4485 stop_reason: StopReason::ToolUse,
4486 response_id: None,
4487 }),
4488 ])
4489 .boxed())
4490 }
4491 }
4492
4493 struct ToolThenFailProvider;
4494
4495 #[async_trait]
4496 impl Provider for ToolThenFailProvider {
4497 fn capabilities(&self) -> ProviderCapabilities {
4498 ProviderCapabilities::default()
4499 }
4500
4501 async fn stream(
4502 &self,
4503 request: ProviderRequest,
4504 _cancel: CancellationToken,
4505 ) -> anyhow::Result<BoxStream<'static, Result<StreamEvent, ProviderError>>> {
4506 if request
4507 .messages
4508 .iter()
4509 .any(|message| matches!(message, Message::Tool(_)))
4510 {
4511 return Ok(stream::iter(vec![Err(ProviderError::new(
4512 "provider failed after tool",
4513 false,
4514 ))])
4515 .boxed());
4516 }
4517
4518 let block_id = BlockId::new();
4519 let message_id = halter_protocol::MessageId::new();
4520 Ok(stream::iter(vec![
4521 Ok(StreamEvent::MessageStart {
4522 id: message_id.clone(),
4523 }),
4524 Ok(StreamEvent::ToolCallStart {
4525 id: block_id.clone(),
4526 tool_call_id: ToolCallId::new(),
4527 name: ToolName::from("write"),
4528 }),
4529 Ok(StreamEvent::ToolArgsDelta {
4530 id: block_id.clone(),
4531 delta: serde_json::json!({
4532 "path": "note.txt",
4533 "content": "hello from tool"
4534 })
4535 .to_string(),
4536 }),
4537 Ok(StreamEvent::ToolCallEnd { id: block_id }),
4538 Ok(StreamEvent::MessageEnd {
4539 id: message_id,
4540 stop_reason: StopReason::ToolUse,
4541 response_id: None,
4542 }),
4543 ])
4544 .boxed())
4545 }
4546 }
4547
4548 #[async_trait]
4549 impl Provider for DuplicateToolCallProvider {
4550 fn capabilities(&self) -> ProviderCapabilities {
4551 ProviderCapabilities::default()
4552 }
4553
4554 async fn stream(
4555 &self,
4556 request: ProviderRequest,
4557 _cancel: CancellationToken,
4558 ) -> anyhow::Result<BoxStream<'static, Result<StreamEvent, ProviderError>>> {
4559 if request
4560 .messages
4561 .iter()
4562 .any(|message| matches!(message, Message::Tool(_)))
4563 {
4564 return Ok(stream::iter(vec![
4565 Ok(StreamEvent::MessageStart {
4566 id: halter_protocol::MessageId::new(),
4567 }),
4568 Ok(StreamEvent::TextStart { id: BlockId::new() }),
4569 Ok(StreamEvent::TextDelta {
4570 id: BlockId::new(),
4571 delta: "tool completed".to_owned(),
4572 }),
4573 Ok(StreamEvent::TextEnd { id: BlockId::new() }),
4574 Ok(StreamEvent::MessageEnd {
4575 id: halter_protocol::MessageId::new(),
4576 stop_reason: StopReason::EndTurn,
4577 response_id: None,
4578 }),
4579 ])
4580 .boxed());
4581 }
4582
4583 let first_block_id = BlockId::new();
4584 let second_block_id = BlockId::new();
4585 let tool_call_id = ToolCallId::from("call-dedupe");
4586 Ok(stream::iter(vec![
4587 Ok(StreamEvent::MessageStart {
4588 id: halter_protocol::MessageId::new(),
4589 }),
4590 Ok(StreamEvent::ToolCallStart {
4591 id: first_block_id.clone(),
4592 tool_call_id: tool_call_id.clone(),
4593 name: ToolName::from("count_test"),
4594 }),
4595 Ok(StreamEvent::ToolArgsDelta {
4596 id: first_block_id.clone(),
4597 delta: json!({ "value": 1 }).to_string(),
4598 }),
4599 Ok(StreamEvent::ToolCallEnd { id: first_block_id }),
4600 Ok(StreamEvent::ToolCallStart {
4601 id: second_block_id.clone(),
4602 tool_call_id,
4603 name: ToolName::from("count_test"),
4604 }),
4605 Ok(StreamEvent::ToolArgsDelta {
4606 id: second_block_id.clone(),
4607 delta: json!({ "value": 1 }).to_string(),
4608 }),
4609 Ok(StreamEvent::ToolCallEnd {
4610 id: second_block_id,
4611 }),
4612 Ok(StreamEvent::MessageEnd {
4613 id: halter_protocol::MessageId::new(),
4614 stop_reason: StopReason::ToolUse,
4615 response_id: None,
4616 }),
4617 ])
4618 .boxed())
4619 }
4620 }
4621
4622 #[derive(Debug)]
4623 struct FailingProvider;
4624
4625 #[async_trait]
4626 impl Provider for FailingProvider {
4627 fn capabilities(&self) -> ProviderCapabilities {
4628 ProviderCapabilities::default()
4629 }
4630
4631 async fn stream(
4632 &self,
4633 _request: ProviderRequest,
4634 _cancel: CancellationToken,
4635 ) -> anyhow::Result<BoxStream<'static, Result<StreamEvent, ProviderError>>> {
4636 Ok(stream::iter(vec![Err(ProviderError::new("provider exploded", false))]).boxed())
4637 }
4638 }
4639
4640 #[derive(Debug)]
4643 struct RetryableFailingProvider;
4644
4645 #[async_trait]
4646 impl Provider for RetryableFailingProvider {
4647 fn capabilities(&self) -> ProviderCapabilities {
4648 ProviderCapabilities::default()
4649 }
4650
4651 async fn stream(
4652 &self,
4653 _request: ProviderRequest,
4654 _cancel: CancellationToken,
4655 ) -> anyhow::Result<BoxStream<'static, Result<StreamEvent, ProviderError>>> {
4656 Ok(stream::iter(vec![Err(ProviderError::new("rate limited", true))]).boxed())
4657 }
4658 }
4659
4660 #[derive(Debug)]
4664 struct CancellableBlockingProvider {
4665 started: Arc<Notify>,
4666 }
4667
4668 #[async_trait]
4669 impl Provider for CancellableBlockingProvider {
4670 fn capabilities(&self) -> ProviderCapabilities {
4671 ProviderCapabilities::default()
4672 }
4673
4674 async fn stream(
4675 &self,
4676 _request: ProviderRequest,
4677 cancel: CancellationToken,
4678 ) -> anyhow::Result<BoxStream<'static, Result<StreamEvent, ProviderError>>> {
4679 let started = self.started.clone();
4680 let s = futures::stream::unfold(Some((cancel, started, false)), |state| async move {
4681 let (cancel, started, emitted) = state?;
4682 if !emitted {
4683 started.notify_one();
4684 return Some((
4685 Ok(StreamEvent::MessageStart {
4686 id: halter_protocol::MessageId::new(),
4687 }),
4688 Some((cancel, started, true)),
4689 ));
4690 }
4691 cancel.cancelled().await;
4692 Some((Err(ProviderError::new("cancelled", false)), None))
4693 });
4694 Ok(s.boxed())
4695 }
4696 }
4697
4698 #[derive(Debug)]
4702 struct UncancellableBlockingProvider {
4703 started: Arc<Notify>,
4704 }
4705
4706 #[async_trait]
4707 impl Provider for UncancellableBlockingProvider {
4708 fn capabilities(&self) -> ProviderCapabilities {
4709 ProviderCapabilities::default()
4710 }
4711
4712 async fn stream(
4713 &self,
4714 _request: ProviderRequest,
4715 _cancel: CancellationToken,
4716 ) -> anyhow::Result<BoxStream<'static, Result<StreamEvent, ProviderError>>> {
4717 let started = self.started.clone();
4718 let s = futures::stream::unfold(Some(started), |state| async move {
4719 let started = state?;
4720 started.notify_one();
4721 std::future::pending::<()>().await;
4722 #[allow(unreachable_code)]
4723 Some((
4724 Ok(StreamEvent::MessageStart {
4725 id: halter_protocol::MessageId::new(),
4726 }),
4727 None,
4728 ))
4729 });
4730 Ok(s.boxed())
4731 }
4732 }
4733
4734 #[tokio::test]
4735 async fn turn_failed_carries_retryable_flag_from_provider_error() {
4736 let temp = tempfile::tempdir().expect("tempdir");
4739 let services = configured_services(Arc::new(RetryableFailingProvider), temp.path());
4740 let runtime = SessionRuntime::new(services.clone());
4741 let session = new_session(&runtime, temp.path()).await;
4742
4743 let events = session
4744 .submit_turn(Turn::user("retryable failure"))
4745 .await
4746 .expect("submit turn")
4747 .try_collect::<Vec<_>>()
4748 .await
4749 .expect("collect events");
4750
4751 let live_failure = events
4752 .iter()
4753 .find_map(|event| match &event.payload {
4754 SessionEventPayload::TurnFailed { retryable, .. } => Some(*retryable),
4755 _ => None,
4756 })
4757 .expect("TurnFailed in live stream");
4758 assert!(live_failure, "live TurnFailed must preserve retryable=true");
4759
4760 let replay = services
4761 .sessions
4762 .replay(session.session_id())
4763 .await
4764 .expect("replay");
4765 let persisted = replay
4766 .iter()
4767 .find_map(|event| match &event.payload {
4768 SessionEventPayload::TurnFailed { retryable, .. } => Some(*retryable),
4769 _ => None,
4770 })
4771 .expect("TurnFailed in persisted store");
4772 assert!(
4773 persisted,
4774 "persisted TurnFailed must preserve retryable=true"
4775 );
4776 }
4777
4778 #[tokio::test]
4779 async fn turn_failed_non_retryable_default_path() {
4780 let temp = tempfile::tempdir().expect("tempdir");
4783 let services = configured_services(Arc::new(FailingProvider), temp.path());
4784 let runtime = SessionRuntime::new(services.clone());
4785 let session = new_session(&runtime, temp.path()).await;
4786
4787 let events = session
4788 .submit_turn(Turn::user("non-retryable failure"))
4789 .await
4790 .expect("submit turn")
4791 .try_collect::<Vec<_>>()
4792 .await
4793 .expect("collect events");
4794
4795 let retryable = events
4796 .iter()
4797 .find_map(|event| match &event.payload {
4798 SessionEventPayload::TurnFailed { retryable, .. } => Some(*retryable),
4799 _ => None,
4800 })
4801 .expect("TurnFailed payload");
4802 assert!(!retryable);
4803 }
4804
4805 #[tokio::test]
4806 async fn turn_failed_carries_originating_turn_id() {
4807 let temp = tempfile::tempdir().expect("tempdir");
4810 let services = configured_services(Arc::new(FailingProvider), temp.path());
4811 let runtime = SessionRuntime::new(services.clone());
4812 let session = new_session(&runtime, temp.path()).await;
4813
4814 let turn = Turn::user("track id");
4815 let expected_id = turn.id.clone();
4816 let events = session
4817 .submit_turn(turn)
4818 .await
4819 .expect("submit turn")
4820 .try_collect::<Vec<_>>()
4821 .await
4822 .expect("collect events");
4823
4824 let live_id = events
4825 .iter()
4826 .find_map(|event| match &event.payload {
4827 SessionEventPayload::TurnFailed { turn_id, .. } => Some(turn_id.clone()),
4828 _ => None,
4829 })
4830 .expect("TurnFailed in stream");
4831 assert_eq!(live_id, expected_id, "stream turn_id must match submission");
4832
4833 let replay = services
4834 .sessions
4835 .replay(session.session_id())
4836 .await
4837 .expect("replay");
4838 let persisted_id = replay
4839 .iter()
4840 .find_map(|event| match &event.payload {
4841 SessionEventPayload::TurnFailed { turn_id, .. } => Some(turn_id.clone()),
4842 _ => None,
4843 })
4844 .expect("TurnFailed in store");
4845 assert_eq!(
4846 persisted_id, expected_id,
4847 "persisted turn_id must match submission"
4848 );
4849 }
4850
4851 #[tokio::test]
4852 async fn shutdown_drains_cooperative_in_flight_turn() {
4853 let temp = tempfile::tempdir().expect("tempdir");
4857 let started = Arc::new(Notify::new());
4858 let services = configured_services(
4859 Arc::new(CancellableBlockingProvider {
4860 started: started.clone(),
4861 }),
4862 temp.path(),
4863 );
4864 let runtime = SessionRuntime::new(services.clone());
4865 let session = new_session(&runtime, temp.path()).await;
4866
4867 let _stream = session
4868 .submit_turn(Turn::user("blocking turn"))
4869 .await
4870 .expect("submit turn");
4871
4872 started.notified().await;
4876
4877 assert_eq!(services.turn_registry.in_flight_count(), 1);
4878
4879 let report = runtime.shutdown(Duration::from_secs(2)).await;
4880 assert!(!report.timed_out, "cooperative drain must not time out");
4881 assert_eq!(report.turns_drained, 1);
4882 assert_eq!(report.turns_aborted, 0);
4883 assert!(services.turn_registry.is_shutting_down());
4884 }
4885
4886 #[tokio::test]
4887 async fn shutdown_aborts_uncooperative_turn_after_deadline() {
4888 let temp = tempfile::tempdir().expect("tempdir");
4892 let started = Arc::new(Notify::new());
4893 let services = configured_services(
4894 Arc::new(UncancellableBlockingProvider {
4895 started: started.clone(),
4896 }),
4897 temp.path(),
4898 );
4899 let runtime = SessionRuntime::new(services.clone());
4900 let session = new_session(&runtime, temp.path()).await;
4901
4902 let _stream = session
4903 .submit_turn(Turn::user("stuck turn"))
4904 .await
4905 .expect("submit turn");
4906 started.notified().await;
4907
4908 let report = runtime.shutdown(Duration::from_millis(100)).await;
4909 assert!(report.timed_out, "uncooperative drain must time out");
4910 assert!(
4911 report.turns_aborted >= 1,
4912 "at least one task must be aborted, got {report:?}"
4913 );
4914 }
4915
4916 #[tokio::test]
4917 async fn submit_turn_after_shutdown_is_rejected() {
4918 let temp = tempfile::tempdir().expect("tempdir");
4922 let services = configured_services(Arc::new(FakeProvider::default()), temp.path());
4923 let runtime = SessionRuntime::new(services.clone());
4924 let session = new_session(&runtime, temp.path()).await;
4925
4926 let _ = runtime.shutdown(Duration::from_millis(0)).await;
4927
4928 let err = match session.submit_turn(Turn::user("late submission")).await {
4929 Ok(_) => panic!("must reject post-shutdown submission"),
4930 Err(e) => e,
4931 };
4932 assert!(
4933 err.to_string().contains("runtime is shutting down"),
4934 "unexpected error: {err}"
4935 );
4936 }
4937
4938 #[derive(Debug)]
4939 struct RecordingProvider {
4940 requests: Arc<Mutex<Vec<ProviderRequest>>>,
4941 reply: &'static str,
4942 }
4943
4944 impl RecordingProvider {
4945 fn new(requests: Arc<Mutex<Vec<ProviderRequest>>>, reply: &'static str) -> Self {
4946 Self { requests, reply }
4947 }
4948 }
4949
4950 #[async_trait]
4951 impl Provider for RecordingProvider {
4952 fn capabilities(&self) -> ProviderCapabilities {
4953 ProviderCapabilities::default()
4954 }
4955
4956 async fn stream(
4957 &self,
4958 request: ProviderRequest,
4959 _cancel: CancellationToken,
4960 ) -> anyhow::Result<BoxStream<'static, Result<StreamEvent, ProviderError>>> {
4961 self.requests.lock().expect("requests").push(request);
4962 let block_id = BlockId::new();
4963 let message_id = halter_protocol::MessageId::new();
4964 Ok(stream::iter(vec![
4965 Ok(StreamEvent::MessageStart {
4966 id: message_id.clone(),
4967 }),
4968 Ok(StreamEvent::TextStart {
4969 id: block_id.clone(),
4970 }),
4971 Ok(StreamEvent::TextDelta {
4972 id: block_id.clone(),
4973 delta: self.reply.to_owned(),
4974 }),
4975 Ok(StreamEvent::TextEnd { id: block_id }),
4976 Ok(StreamEvent::MessageEnd {
4977 id: message_id,
4978 stop_reason: StopReason::EndTurn,
4979 response_id: None,
4980 }),
4981 ])
4982 .boxed())
4983 }
4984 }
4985
4986 #[derive(Debug)]
4987 struct StreamingToolProvider;
4988
4989 #[async_trait]
4990 impl Provider for StreamingToolProvider {
4991 fn capabilities(&self) -> ProviderCapabilities {
4992 ProviderCapabilities::default()
4993 }
4994
4995 async fn stream(
4996 &self,
4997 request: ProviderRequest,
4998 _cancel: CancellationToken,
4999 ) -> anyhow::Result<BoxStream<'static, Result<StreamEvent, ProviderError>>> {
5000 if request
5001 .messages
5002 .iter()
5003 .any(|message| matches!(message, Message::Tool(_)))
5004 {
5005 return Ok(stream::iter(vec![
5006 Ok(StreamEvent::MessageStart {
5007 id: halter_protocol::MessageId::new(),
5008 }),
5009 Ok(StreamEvent::TextStart { id: BlockId::new() }),
5010 Ok(StreamEvent::TextDelta {
5011 id: BlockId::new(),
5012 delta: "stream done".to_owned(),
5013 }),
5014 Ok(StreamEvent::TextEnd { id: BlockId::new() }),
5015 Ok(StreamEvent::MessageEnd {
5016 id: halter_protocol::MessageId::new(),
5017 stop_reason: StopReason::EndTurn,
5018 response_id: None,
5019 }),
5020 ])
5021 .boxed());
5022 }
5023
5024 let block_id = BlockId::new();
5025 Ok(stream::iter(vec![
5026 Ok(StreamEvent::MessageStart {
5027 id: halter_protocol::MessageId::new(),
5028 }),
5029 Ok(StreamEvent::ToolCallStart {
5030 id: block_id.clone(),
5031 tool_call_id: ToolCallId::new(),
5032 name: ToolName::from("stream_test"),
5033 }),
5034 Ok(StreamEvent::ToolArgsDelta {
5035 id: block_id.clone(),
5036 delta: json!({}).to_string(),
5037 }),
5038 Ok(StreamEvent::ToolCallEnd { id: block_id }),
5039 Ok(StreamEvent::MessageEnd {
5040 id: halter_protocol::MessageId::new(),
5041 stop_reason: StopReason::ToolUse,
5042 response_id: None,
5043 }),
5044 ])
5045 .boxed())
5046 }
5047 }
5048
/// Test tool registered under the name `stream_test`. Its `Tool`
/// implementation emits one `ToolOutput` chunk, sleeps for 300 ms and then
/// returns a JSON result.
#[derive(Debug)]
struct StreamingTestTool;
5051
5052 #[derive(Debug)]
5053 struct CountingTool {
5054 executions: Arc<Mutex<usize>>,
5055 }
5056
5057 impl CountingTool {
5058 fn new(executions: Arc<Mutex<usize>>) -> Self {
5059 Self { executions }
5060 }
5061 }
5062
5063 #[async_trait]
5064 impl Tool for CountingTool {
5065 fn spec(&self) -> ToolSpec {
5066 ToolSpec {
5067 name: ToolName::from("count_test"),
5068 description: "Count tool executions".to_owned(),
5069 input_schema: json!({
5070 "type": "object",
5071 "properties": {}
5072 }),
5073 concurrency: ToolConcurrency::Exclusive,
5074 capabilities: ToolCapabilities::default(),
5075 provider_aliases: Default::default(),
5076 }
5077 }
5078
5079 async fn execute(
5080 &self,
5081 _context: ToolContext,
5082 _input: serde_json::Value,
5083 ) -> anyhow::Result<ToolResult> {
5084 *self.executions.lock().expect("executions") += 1;
5085 Ok(ToolResult::Json {
5086 value: json!({ "ok": true }),
5087 })
5088 }
5089 }
5090
5091 #[derive(Debug)]
5092 struct ParallelBatchProvider {
5093 tool_name: &'static str,
5094 }
5095
5096 #[async_trait]
5097 impl Provider for ParallelBatchProvider {
5098 fn capabilities(&self) -> ProviderCapabilities {
5099 ProviderCapabilities::default()
5100 }
5101
5102 async fn stream(
5103 &self,
5104 request: ProviderRequest,
5105 _cancel: CancellationToken,
5106 ) -> anyhow::Result<BoxStream<'static, Result<StreamEvent, ProviderError>>> {
5107 if request
5108 .messages
5109 .iter()
5110 .any(|message| matches!(message, Message::Tool(_)))
5111 {
5112 return Ok(stream::iter(vec![
5113 Ok(StreamEvent::MessageStart {
5114 id: halter_protocol::MessageId::new(),
5115 }),
5116 Ok(StreamEvent::TextStart { id: BlockId::new() }),
5117 Ok(StreamEvent::TextDelta {
5118 id: BlockId::new(),
5119 delta: "done".to_owned(),
5120 }),
5121 Ok(StreamEvent::TextEnd { id: BlockId::new() }),
5122 Ok(StreamEvent::MessageEnd {
5123 id: halter_protocol::MessageId::new(),
5124 stop_reason: StopReason::EndTurn,
5125 response_id: None,
5126 }),
5127 ])
5128 .boxed());
5129 }
5130 let first_block = BlockId::new();
5131 let second_block = BlockId::new();
5132 Ok(stream::iter(vec![
5133 Ok(StreamEvent::MessageStart {
5134 id: halter_protocol::MessageId::new(),
5135 }),
5136 Ok(StreamEvent::ToolCallStart {
5137 id: first_block.clone(),
5138 tool_call_id: ToolCallId::new(),
5139 name: ToolName::from(self.tool_name),
5140 }),
5141 Ok(StreamEvent::ToolArgsDelta {
5142 id: first_block.clone(),
5143 delta: json!({}).to_string(),
5144 }),
5145 Ok(StreamEvent::ToolCallEnd { id: first_block }),
5146 Ok(StreamEvent::ToolCallStart {
5147 id: second_block.clone(),
5148 tool_call_id: ToolCallId::new(),
5149 name: ToolName::from(self.tool_name),
5150 }),
5151 Ok(StreamEvent::ToolArgsDelta {
5152 id: second_block.clone(),
5153 delta: json!({}).to_string(),
5154 }),
5155 Ok(StreamEvent::ToolCallEnd { id: second_block }),
5156 Ok(StreamEvent::MessageEnd {
5157 id: halter_protocol::MessageId::new(),
5158 stop_reason: StopReason::ToolUse,
5159 response_id: None,
5160 }),
5161 ])
5162 .boxed())
5163 }
5164 }
5165
5166 #[derive(Debug)]
5167 struct BarrierTool {
5168 barrier: Arc<tokio::sync::Barrier>,
5169 concurrency: ToolConcurrency,
5170 name: &'static str,
5171 }
5172
5173 #[async_trait]
5174 impl Tool for BarrierTool {
5175 fn spec(&self) -> ToolSpec {
5176 ToolSpec {
5177 name: ToolName::from(self.name),
5178 description: "barrier-synchronized tool".to_owned(),
5179 input_schema: json!({ "type": "object", "properties": {} }),
5180 concurrency: self.concurrency,
5181 capabilities: ToolCapabilities::default(),
5182 provider_aliases: Default::default(),
5183 }
5184 }
5185
5186 async fn execute(
5187 &self,
5188 _context: ToolContext,
5189 _input: serde_json::Value,
5190 ) -> anyhow::Result<ToolResult> {
5191 self.barrier.wait().await;
5192 Ok(ToolResult::Empty)
5193 }
5194 }
5195
5196 #[tokio::test]
5197 async fn parallel_safe_tools_execute_concurrently() {
5198 let temp = tempfile::tempdir().expect("tempdir");
5202 let services = configured_services(
5203 Arc::new(ParallelBatchProvider {
5204 tool_name: "parallel_barrier",
5205 }),
5206 temp.path(),
5207 );
5208 let barrier = Arc::new(tokio::sync::Barrier::new(2));
5209 services.tools.register(Arc::new(BarrierTool {
5210 barrier: barrier.clone(),
5211 concurrency: ToolConcurrency::ParallelSafe,
5212 name: "parallel_barrier",
5213 }));
5214 let runtime = SessionRuntime::new(services.clone());
5215 let session = new_session(&runtime, temp.path()).await;
5216
5217 let result = tokio::time::timeout(
5218 Duration::from_secs(5),
5219 session
5220 .submit_turn(Turn::user("run in parallel"))
5221 .await
5222 .expect("submit turn")
5223 .try_collect::<Vec<_>>(),
5224 )
5225 .await
5226 .expect("parallel-safe tools must not deadlock the barrier")
5227 .expect("collect events");
5228
5229 let completed = result
5230 .iter()
5231 .filter(|event| {
5232 matches!(
5233 event.payload,
5234 SessionEventPayload::ToolExecutionCompleted { .. }
5235 )
5236 })
5237 .count();
5238 assert_eq!(
5239 completed, 2,
5240 "both parallel-safe tool executions must complete"
5241 );
5242 }
5243
5244 #[tokio::test]
5245 async fn exclusive_tools_run_serially_in_batch() {
5246 let temp = tempfile::tempdir().expect("tempdir");
5251 let services = configured_services(
5252 Arc::new(ParallelBatchProvider {
5253 tool_name: "exclusive_barrier",
5254 }),
5255 temp.path(),
5256 );
5257 let barrier = Arc::new(tokio::sync::Barrier::new(2));
5258 services.tools.register(Arc::new(BarrierTool {
5259 barrier: barrier.clone(),
5260 concurrency: ToolConcurrency::Exclusive,
5261 name: "exclusive_barrier",
5262 }));
5263 let runtime = SessionRuntime::new(services.clone());
5264 let session = new_session(&runtime, temp.path()).await;
5265
5266 let timed = tokio::time::timeout(
5267 Duration::from_millis(500),
5268 session
5269 .submit_turn(Turn::user("serialize exclusive"))
5270 .await
5271 .expect("submit turn")
5272 .try_collect::<Vec<_>>(),
5273 )
5274 .await;
5275 assert!(
5276 timed.is_err(),
5277 "exclusive tools must serialize; barrier should deadlock"
5278 );
5279 }
5280
5281 #[async_trait]
5282 impl Tool for StreamingTestTool {
5283 fn spec(&self) -> ToolSpec {
5284 ToolSpec {
5285 name: ToolName::from("stream_test"),
5286 description: "Emit output before completing".to_owned(),
5287 input_schema: json!({
5288 "type": "object",
5289 "properties": {}
5290 }),
5291 concurrency: ToolConcurrency::Exclusive,
5292 capabilities: ToolCapabilities {
5293 mutating: false,
5294 requires_approval: false,
5295 cancellable: false,
5296 long_running: true,
5297 },
5298 provider_aliases: Default::default(),
5299 }
5300 }
5301
5302 async fn execute(
5303 &self,
5304 context: ToolContext,
5305 _input: serde_json::Value,
5306 ) -> anyhow::Result<ToolResult> {
5307 context.emit.emit(ToolRuntimeEvent::ToolOutput {
5308 tool_name: "stream_test".to_owned(),
5309 chunk: "streamed chunk".to_owned(),
5310 });
5311 tokio::time::sleep(Duration::from_millis(300)).await;
5312 Ok(ToolResult::Json {
5313 value: json!({ "ok": true }),
5314 })
5315 }
5316 }
5317}