1use std::sync::Arc;
13
14use tokio_util::sync::CancellationToken;
15
16use crate::error::{CoreError, Result};
17use crate::event::{RunEvent, RunHooks};
18use crate::message::{AgentMessage, ContentBlock, Role};
19use crate::model::{Model, ModelProvider, ModelRequest, StreamEvent};
20use crate::policy::{PolicyVerdict, ToolPolicy};
21use crate::thinking::ThinkingLevel;
22use crate::tool::{InvokeContext, Tool, ToolCall, ToolResult};
23
24#[derive(Debug, Clone)]
26pub struct RunConfig {
27 pub max_turns: usize,
29 pub thinking: ThinkingLevel,
31 pub turn_timeout_ms: Option<u64>,
34 pub max_tool_calls_per_turn: usize,
37 pub tool_concurrency: usize,
41}
42
43impl Default for RunConfig {
44 fn default() -> Self {
45 Self {
46 max_turns: 12,
47 thinking: ThinkingLevel::default(),
48 turn_timeout_ms: Some(120_000),
49 max_tool_calls_per_turn: 10,
50 tool_concurrency: 1,
51 }
52 }
53}
54
/// Result of a successfully completed agent run.
#[derive(Debug, Clone)]
pub struct RunOutcome {
    /// Number of model turns that were executed (1-based count).
    pub turns: usize,
    /// Concatenated text blocks of the last assistant message.
    pub final_text: String,
}
64
65#[async_trait::async_trait]
80pub trait TurnSink: Send + Sync {
81 async fn after_turn(&self, turn: usize, messages: &[AgentMessage]) -> Result<()>;
84}
85
86pub struct FanoutTurnSink {
100 sinks: Vec<Box<dyn TurnSink>>,
101}
102
103impl FanoutTurnSink {
104 #[must_use]
106 pub fn new() -> Self {
107 Self { sinks: Vec::new() }
108 }
109
110 #[must_use]
112 pub fn push(mut self, sink: Box<dyn TurnSink>) -> Self {
113 self.sinks.push(sink);
114 self
115 }
116
117 #[must_use]
119 pub fn len(&self) -> usize {
120 self.sinks.len()
121 }
122
123 #[must_use]
125 pub fn is_empty(&self) -> bool {
126 self.sinks.is_empty()
127 }
128}
129
130impl Default for FanoutTurnSink {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136#[async_trait::async_trait]
137impl TurnSink for FanoutTurnSink {
138 async fn after_turn(&self, turn: usize, messages: &[AgentMessage]) -> Result<()> {
139 for sink in &self.sinks {
140 sink.after_turn(turn, messages).await?;
141 }
142 Ok(())
143 }
144}
145
#[cfg(test)]
mod fanout_tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Test sink that records every turn it is called with and can be told to
    /// fail on one specific turn.
    struct RecordingSink {
        /// Turn numbers observed so far, shared with the test body.
        calls: Arc<Mutex<Vec<usize>>>,
        /// Turn on which `after_turn` returns an error, if any.
        fail_at: Option<usize>,
    }

    #[async_trait::async_trait]
    impl TurnSink for RecordingSink {
        async fn after_turn(&self, turn: usize, _messages: &[AgentMessage]) -> Result<()> {
            // Record first so a failing call is still visible to the test.
            self.calls.lock().expect("lock poisoned").push(turn);
            if self.fail_at == Some(turn) {
                return Err(crate::error::CoreError::Transport(format!(
                    "injected failure at turn {turn}"
                )));
            }
            Ok(())
        }
    }

    #[tokio::test]
    async fn fanout_calls_sinks_in_order() {
        let calls_a = Arc::new(Mutex::new(Vec::new()));
        let calls_b = Arc::new(Mutex::new(Vec::new()));
        let fanout = FanoutTurnSink::new()
            .push(Box::new(RecordingSink {
                calls: calls_a.clone(),
                fail_at: None,
            }))
            .push(Box::new(RecordingSink {
                calls: calls_b.clone(),
                fail_at: None,
            }));
        TurnSink::after_turn(&fanout, 1, &[]).await.unwrap();
        assert_eq!(*calls_a.lock().unwrap(), vec![1]);
        assert_eq!(*calls_b.lock().unwrap(), vec![1]);
    }

    #[tokio::test]
    async fn fanout_propagates_error_and_stops() {
        let calls_a = Arc::new(Mutex::new(Vec::new()));
        let calls_b = Arc::new(Mutex::new(Vec::new()));
        let fanout = FanoutTurnSink::new()
            .push(Box::new(RecordingSink {
                calls: calls_a.clone(),
                fail_at: Some(2),
            }))
            .push(Box::new(RecordingSink {
                calls: calls_b.clone(),
                fail_at: None,
            }));
        let result = TurnSink::after_turn(&fanout, 2, &[]).await;
        // The test name promises propagation, so assert it instead of
        // discarding the result.
        assert!(
            matches!(result, Err(crate::error::CoreError::Transport(_))),
            "first sink's error must propagate out of the fan-out"
        );
        assert_eq!(*calls_a.lock().unwrap(), vec![2]);
        assert!(calls_b.lock().unwrap().is_empty(), "second sink ran");
    }
}
207
208pub async fn run_agent(
226 provider: &dyn ModelProvider,
227 tools: &[Arc<dyn Tool>],
228 messages: &mut Vec<AgentMessage>,
229 model: &Model,
230 config: &RunConfig,
231 cancel: &CancellationToken,
232 hooks: &RunHooks<'_>,
233) -> Result<RunOutcome> {
234 hooks.emit_event(|sid| RunEvent::SessionStarted { session: sid });
235 let mut turns = 0usize;
236 loop {
237 if cancel.is_cancelled() {
238 hooks.emit_event(|sid| crate::event::run_failed(sid, "cancelled"));
239 return Err(CoreError::Cancelled("agent run cancelled".into()));
240 }
241 if turns >= config.max_turns {
242 let msg = format!(
243 "max_turns ({}) exceeded — the model kept calling tools",
244 config.max_turns
245 );
246 hooks.emit_event(|sid| crate::event::run_failed(sid, msg.clone()));
247 return Err(CoreError::ModelResponse(msg));
248 }
249 turns += 1;
250 hooks.emit_event(|sid| RunEvent::TurnStarted {
251 session: sid,
252 turn: turns,
253 });
254
255 let request = ModelRequest {
256 model: model.clone(),
257 messages: messages.clone(),
258 tools: tools.iter().map(|t| t.definition()).collect(),
259 thinking: config.thinking,
260 params: Default::default(),
261 };
262 hooks.emit_event(|sid| RunEvent::ModelStarted {
263 session: sid,
264 turn: turns,
265 model: model.id.clone(),
266 });
267 let response =
269 match invoke_with_budget(provider, request, config.turn_timeout_ms, cancel).await {
270 Ok(r) => r,
271 Err(e) => {
272 hooks.emit_event(|sid| crate::event::run_failed(sid, e.to_string()));
273 return Err(e);
274 }
275 };
276 hooks.emit_event(|sid| RunEvent::ModelFinished {
277 session: sid,
278 turn: turns,
279 });
280 let tool_calls: Vec<(String, ToolCall)> = response
282 .messages
283 .iter()
284 .flat_map(|m| m.content.iter())
285 .filter_map(|block| {
286 if let ContentBlock::ToolUse { id, call } = block {
287 Some((id.clone(), call.clone()))
288 } else {
289 None
290 }
291 })
292 .collect();
293 messages.extend(response.messages);
295
296 if tool_calls.is_empty() {
297 let final_text = extract_final_text(messages);
299 if let Some(sink) = hooks.turn_sink {
301 sink.after_turn(turns, messages).await?;
302 }
303 hooks.emit_event(|sid| RunEvent::TurnFinished {
304 session: sid,
305 turn: turns,
306 });
307 return Ok(RunOutcome { turns, final_text });
308 }
309
310 if tool_calls.len() > config.max_tool_calls_per_turn {
312 let msg = format!(
313 "model issued {} tool calls in one turn (max {})",
314 tool_calls.len(),
315 config.max_tool_calls_per_turn
316 );
317 hooks.emit_event(|sid| crate::event::run_failed(sid, msg.clone()));
318 return Err(CoreError::ModelResponse(msg));
319 }
320
321 for (id, call) in &tool_calls {
323 hooks.emit_event(|sid| RunEvent::ToolStarted {
324 session: sid,
325 turn: turns,
326 tool: call.name.clone(),
327 call_id: id.clone(),
328 });
329 }
330
331 let results = execute_tool_calls(
335 tools,
336 &tool_calls,
337 cancel,
338 config.tool_concurrency,
339 hooks.policy,
340 )
341 .await;
342
343 for (i, (id, call)) in tool_calls.iter().enumerate() {
345 let result = &results[i];
346 let ok = tool_result_ok(result);
347 hooks.emit_event(|sid| RunEvent::ToolFinished {
348 session: sid,
349 turn: turns,
350 tool: call.name.clone(),
351 call_id: id.clone(),
352 ok,
353 });
354 let tool_msg = AgentMessage {
355 role: Role::Tool,
356 content: vec![ContentBlock::ToolResult {
357 tool_use_id: id.clone(),
358 content: serde_json::to_value(result)
359 .unwrap_or_else(|_| serde_json::json!({ "error": "serialize failed" })),
360 }],
361 };
362 messages.push(tool_msg);
363 }
364 if let Some(sink) = hooks.turn_sink {
368 sink.after_turn(turns, messages).await?;
369 }
370 hooks.emit_event(|sid| RunEvent::TurnFinished {
371 session: sid,
372 turn: turns,
373 });
374 }
375}
376
377#[derive(Debug, Clone, Default)]
380struct StreamedTurn {
381 text: String,
382 thinking: String,
383 tool_calls: Vec<(String, ToolCall)>,
384}
385
386async fn collect_streamed_turn(
396 stream: crate::model::StreamEventStream,
397 on_event: &mut (dyn FnMut(&StreamEvent) + Send),
398 turn_timeout_ms: Option<u64>,
399 cancel: &CancellationToken,
400) -> Result<StreamedTurn> {
401 use futures::StreamExt;
402 let mut turn = StreamedTurn::default();
403 let mut s = stream;
404 let deadline = turn_timeout_ms
409 .map(|ms| tokio::time::Instant::now() + std::time::Duration::from_millis(ms));
410 loop {
411 if cancel.is_cancelled() {
413 return Err(CoreError::Cancelled("turn cancelled during stream".into()));
414 }
415 let next = async { s.next().await };
421 let item = match deadline {
422 Some(deadline) => {
423 let to = tokio::time::timeout_at(deadline, next);
424 tokio::select! {
425 biased;
426 _ = cancel.cancelled() => {
427 return Err(CoreError::Cancelled(
428 "turn cancelled during stream".into(),
429 ));
430 }
431 res = to => match res {
432 Ok(Some(item)) => item,
433 Ok(None) => break, Err(_) => {
435 let ms = turn_timeout_ms.unwrap_or(0);
436 return Err(CoreError::TurnTimeout { ms });
437 }
438 },
439 }
440 }
441 None => {
442 tokio::select! {
443 biased;
444 _ = cancel.cancelled() => {
445 return Err(CoreError::Cancelled(
446 "turn cancelled during stream".into(),
447 ));
448 }
449 item = next => match item {
450 Some(item) => item,
451 None => break, },
453 }
454 }
455 };
456 match item {
457 Ok(StreamEvent::TextDelta(t)) => {
458 on_event(&StreamEvent::TextDelta(t.clone()));
459 turn.text.push_str(&t);
460 }
461 Ok(StreamEvent::ThinkingDelta(t)) => {
462 turn.thinking.push_str(&t);
463 }
464 Ok(StreamEvent::ToolCall(call)) => {
465 let id = format!("call_{}", turn.tool_calls.len());
469 turn.tool_calls.push((id, call));
470 }
471 Ok(StreamEvent::Done) => break,
472 Err(e) => return Err(e),
473 }
474 }
475 Ok(turn)
476}
477
478#[allow(clippy::too_many_arguments)]
486pub async fn run_agent_streaming(
487 provider: &dyn ModelProvider,
488 tools: &[Arc<dyn Tool>],
489 messages: &mut Vec<AgentMessage>,
490 model: &Model,
491 config: &RunConfig,
492 cancel: &CancellationToken,
493 on_event: &mut (dyn FnMut(&StreamEvent) + Send),
494 hooks: &RunHooks<'_>,
495) -> Result<RunOutcome> {
496 hooks.emit_event(|sid| RunEvent::SessionStarted { session: sid });
497 let mut turns = 0usize;
498 loop {
499 if cancel.is_cancelled() {
500 hooks.emit_event(|sid| crate::event::run_failed(sid, "cancelled"));
501 return Err(CoreError::Cancelled("agent run cancelled".into()));
502 }
503 if turns >= config.max_turns {
504 let msg = format!(
505 "max_turns ({}) exceeded — the model kept calling tools",
506 config.max_turns
507 );
508 hooks.emit_event(|sid| crate::event::run_failed(sid, msg.clone()));
509 return Err(CoreError::ModelResponse(msg));
510 }
511 turns += 1;
512 hooks.emit_event(|sid| RunEvent::TurnStarted {
513 session: sid,
514 turn: turns,
515 });
516
517 let request = ModelRequest {
518 model: model.clone(),
519 messages: messages.clone(),
520 tools: tools.iter().map(|t| t.definition()).collect(),
521 thinking: config.thinking,
522 params: Default::default(),
523 };
524 hooks.emit_event(|sid| RunEvent::ModelStarted {
525 session: sid,
526 turn: turns,
527 model: model.id.clone(),
528 });
529 let stream = provider.stream(request);
531 let turn =
532 match collect_streamed_turn(stream, on_event, config.turn_timeout_ms, cancel).await {
533 Ok(t) => t,
534 Err(e) => {
535 hooks.emit_event(|sid| crate::event::run_failed(sid, e.to_string()));
536 return Err(e);
537 }
538 };
539 hooks.emit_event(|sid| RunEvent::ModelFinished {
540 session: sid,
541 turn: turns,
542 });
543
544 let mut content: Vec<ContentBlock> = Vec::new();
546 if !turn.text.is_empty() {
547 content.push(ContentBlock::Text { text: turn.text });
548 }
549 for (id, call) in &turn.tool_calls {
550 content.push(ContentBlock::ToolUse {
551 id: id.clone(),
552 call: call.clone(),
553 });
554 }
555 messages.push(AgentMessage {
556 role: Role::Assistant,
557 content,
558 });
559
560 if turn.tool_calls.is_empty() {
561 let final_text = extract_final_text(messages);
562 if let Some(sink) = hooks.turn_sink {
564 sink.after_turn(turns, messages).await?;
565 }
566 hooks.emit_event(|sid| RunEvent::TurnFinished {
567 session: sid,
568 turn: turns,
569 });
570 return Ok(RunOutcome { turns, final_text });
571 }
572 if turn.tool_calls.len() > config.max_tool_calls_per_turn {
573 let msg = format!(
574 "model issued {} tool calls in one turn (max {})",
575 turn.tool_calls.len(),
576 config.max_tool_calls_per_turn
577 );
578 hooks.emit_event(|sid| crate::event::run_failed(sid, msg.clone()));
579 return Err(CoreError::ModelResponse(msg));
580 }
581
582 let owned_calls: Vec<(String, ToolCall)> = turn.tool_calls.clone();
583
584 for (id, call) in &owned_calls {
586 hooks.emit_event(|sid| RunEvent::ToolStarted {
587 session: sid,
588 turn: turns,
589 tool: call.name.clone(),
590 call_id: id.clone(),
591 });
592 }
593
594 let results = execute_tool_calls(
595 tools,
596 &owned_calls,
597 cancel,
598 config.tool_concurrency,
599 hooks.policy,
600 )
601 .await;
602
603 for (i, (id, call)) in owned_calls.iter().enumerate() {
605 let result = &results[i];
606 let ok = tool_result_ok(result);
607 hooks.emit_event(|sid| RunEvent::ToolFinished {
608 session: sid,
609 turn: turns,
610 tool: call.name.clone(),
611 call_id: id.clone(),
612 ok,
613 });
614 let tool_msg = AgentMessage {
615 role: Role::Tool,
616 content: vec![ContentBlock::ToolResult {
617 tool_use_id: id.clone(),
618 content: serde_json::to_value(result)
619 .unwrap_or_else(|_| serde_json::json!({ "error": "serialize failed" })),
620 }],
621 };
622 messages.push(tool_msg);
623 }
624 if let Some(sink) = hooks.turn_sink {
626 sink.after_turn(turns, messages).await?;
627 }
628 hooks.emit_event(|sid| RunEvent::TurnFinished {
629 session: sid,
630 turn: turns,
631 });
632 }
633}
634
/// Maximum length, in characters, of a panic summary (including the trailing
/// ellipsis when the message was cut).
const PANIC_SUMMARY_MAX_CHARS: usize = 200;

/// Turns a caught panic payload into a short, printable message.
///
/// `&'static str` and `String` payloads are used as-is; any other payload type
/// becomes the fixed placeholder `<non-string panic payload>`. Messages longer
/// than [`PANIC_SUMMARY_MAX_CHARS`] characters are cut to one character less
/// than the limit and finished with `…`, so the result never exceeds the
/// limit. Counting is done in `char`s, so multi-byte text is never split.
fn summarize_panic(payload: &Box<dyn std::any::Any + Send>) -> String {
    let message = if let Some(text) = payload.downcast_ref::<&'static str>() {
        (*text).to_string()
    } else if let Some(text) = payload.downcast_ref::<String>() {
        text.clone()
    } else {
        "<non-string panic payload>".to_string()
    };

    if message.chars().count() <= PANIC_SUMMARY_MAX_CHARS {
        return message;
    }

    // Leave room for the ellipsis within the character limit.
    let mut summary: String = message
        .chars()
        .take(PANIC_SUMMARY_MAX_CHARS - 1)
        .collect();
    summary.push('…');
    summary
}
660
661async fn execute_tool_call(
671 tools: &[Arc<dyn Tool>],
672 id: &str,
673 call: &ToolCall,
674 cancel: &CancellationToken,
675) -> ToolResult {
676 let Some(tool) = tools.iter().find(|t| t.definition().name == call.name) else {
677 return error_result(&format!("unknown tool: `{}`", call.name));
678 };
679 let ctx = InvokeContext {
680 tool_call_id: id.to_string(),
681 cancel: cancel.clone(),
682 };
683 use futures::FutureExt;
684 use std::panic::AssertUnwindSafe;
685 match AssertUnwindSafe(tool.execute(ctx, call.input.clone()))
686 .catch_unwind()
687 .await
688 {
689 Ok(Ok(result)) => result,
690 Ok(Err(err)) => error_result(&err.to_string()),
691 Err(payload) => {
692 let summary = summarize_panic(&payload);
693 tracing::warn!(
698 tool = %call.name,
699 call_id = %id,
700 "tool panicked; converted to model-visible error result"
701 );
702 error_result(&format!("tool `{}` panicked: {summary}", call.name))
703 }
704 }
705}
706
707async fn invoke_with_budget(
710 provider: &dyn ModelProvider,
711 request: ModelRequest,
712 turn_timeout_ms: Option<u64>,
713 cancel: &CancellationToken,
714) -> Result<crate::model::ModelResponse> {
715 if cancel.is_cancelled() {
717 return Err(CoreError::Cancelled("turn cancelled before invoke".into()));
718 }
719 let invoke_fut = provider.invoke(request);
720 match turn_timeout_ms {
721 Some(ms) => {
722 let timeout = tokio::time::timeout(std::time::Duration::from_millis(ms), invoke_fut);
723 tokio::select! {
724 biased;
725 _ = cancel.cancelled() => {
726 Err(CoreError::Cancelled("turn cancelled during invoke".into()))
727 }
728 res = timeout => {
729 res.map_err(|_| CoreError::TurnTimeout { ms })?
730 }
731 }
732 }
733 None => {
734 tokio::select! {
735 biased;
736 _ = cancel.cancelled() => {
737 Err(CoreError::Cancelled("turn cancelled during invoke".into()))
738 }
739 res = invoke_fut => res,
740 }
741 }
742 }
743}
744
745enum PolicyOutcome {
747 Execute,
749 Denied(ToolResult),
752}
753
754async fn policy_check(
760 policy: Option<&dyn ToolPolicy>,
761 id: &str,
762 call: &ToolCall,
763 cancel: &CancellationToken,
764) -> PolicyOutcome {
765 let Some(policy) = policy else {
766 return PolicyOutcome::Execute;
767 };
768 let ctx = InvokeContext {
769 tool_call_id: id.to_string(),
770 cancel: cancel.clone(),
771 };
772 match policy.check(&call.name, &call.input, &ctx).await {
773 PolicyVerdict::Allow => PolicyOutcome::Execute,
774 PolicyVerdict::Confirm(reason) => {
775 tracing::info!(
776 tool = %call.name,
777 call_id = %id,
778 "tool policy returned Confirm; treating as Allow for this run: {reason}"
779 );
780 PolicyOutcome::Execute
781 }
782 PolicyVerdict::Deny(reason) => {
783 PolicyOutcome::Denied(error_result(&format!("denied by policy: {reason}")))
784 }
785 }
786}
787
788async fn execute_tool_calls(
799 tools: &[Arc<dyn Tool>],
800 calls: &[(String, ToolCall)],
801 cancel: &CancellationToken,
802 tool_concurrency: usize,
803 policy: Option<&dyn ToolPolicy>,
804) -> Vec<ToolResult> {
805 if tool_concurrency <= 1 {
806 let mut out = Vec::with_capacity(calls.len());
807 for (id, call) in calls {
808 let result = match policy_check(policy, id, call, cancel).await {
809 PolicyOutcome::Execute => execute_tool_call(tools, id, call, cancel).await,
810 PolicyOutcome::Denied(result) => result,
811 };
812 out.push(result);
813 }
814 return out;
815 }
816
817 use tokio::task::JoinSet;
819 let mut indexed: Vec<Option<ToolResult>> = (0..calls.len()).map(|_| None).collect();
824 let mut set: JoinSet<(usize, ToolResult)> = JoinSet::new();
825 for (i, (id, call)) in calls.iter().enumerate() {
826 if let PolicyOutcome::Denied(result) = policy_check(policy, id, call, cancel).await {
831 if let Some(slot) = indexed.get_mut(i) {
832 *slot = Some(result);
833 }
834 continue;
835 }
836 let tool = tools
838 .iter()
839 .find(|t| t.definition().name == call.name)
840 .cloned();
841 let ctx_cancel = cancel.child_token();
842 let ctx = InvokeContext {
843 tool_call_id: id.clone(),
844 cancel: ctx_cancel,
845 };
846 let input = call.input.clone();
847 let id_owned = id.clone();
848 let call_name = call.name.clone();
849 set.spawn(async move {
850 let result = match tool {
856 Some(t) => {
857 use futures::FutureExt;
858 use std::panic::AssertUnwindSafe;
859 match AssertUnwindSafe(t.execute(ctx, input)).catch_unwind().await {
860 Ok(Ok(r)) => r,
861 Ok(Err(err)) => error_result(&err.to_string()),
862 Err(payload) => {
863 let summary = summarize_panic(&payload);
864 tracing::warn!(
865 tool = %call_name,
866 call_id = %id_owned,
867 "tool panicked; converted to model-visible error result"
868 );
869 error_result(&format!("tool `{call_name}` panicked: {summary}"))
870 }
871 }
872 }
873 None => error_result(&format!("unknown tool: `{id_owned}`")),
874 };
875 (i, result)
876 });
877 while set.len() >= tool_concurrency {
881 let res = set.join_next().await;
882 if res.is_none() {
883 break; }
885 record_join_result(res, &mut indexed);
886 }
887 }
888 while let Some(res) = set.join_next().await {
895 record_join_result(Some(res), &mut indexed);
896 }
897 indexed
898 .into_iter()
899 .map(|opt| opt.unwrap_or_else(|| error_result("tool task produced no result")))
900 .collect()
902}
903
904fn record_join_result(
909 res: Option<std::result::Result<(usize, ToolResult), tokio::task::JoinError>>,
910 indexed: &mut [Option<ToolResult>],
911) {
912 match res {
913 Some(Ok((i, result))) => {
914 if let Some(slot) = indexed.get_mut(i) {
915 *slot = Some(result);
916 }
917 }
918 Some(Err(join_err)) => {
919 let slot = indexed.iter().position(Option::is_none).unwrap_or(0);
920 if let Some(s) = indexed.get_mut(slot) {
921 *s = Some(error_result(&format!("tool task failed: {join_err}")));
922 }
923 }
924 None => {}
925 }
926}
927
928fn error_result(message: &str) -> ToolResult {
930 ToolResult {
931 content: vec![serde_json::json!({ "type": "text", "text": format!("Error: {message}") })],
932 details: None,
933 }
934}
935
936fn tool_result_ok(result: &ToolResult) -> bool {
940 !result.content.iter().any(|c| {
941 c.get("text")
942 .and_then(|t| t.as_str())
943 .is_some_and(|t| t.starts_with("Error:"))
944 })
945}
946
947fn extract_final_text(messages: &[AgentMessage]) -> String {
949 messages
950 .iter()
951 .rev()
952 .find(|m| m.role == Role::Assistant)
953 .map(|m| {
954 m.content
955 .iter()
956 .filter_map(|b| {
957 if let ContentBlock::Text { text } = b {
958 Some(text.as_str())
959 } else {
960 None
961 }
962 })
963 .collect::<Vec<_>>()
964 .join("")
965 })
966 .unwrap_or_default()
967}
968
969#[cfg(test)]
970mod tests {
971 use super::*;
977 use crate::model::ModelResponse;
978 use async_trait::async_trait;
979 use serde_json::{json, Value};
980
981 struct MockProvider {
983 responses: std::sync::Mutex<std::collections::VecDeque<ModelResponse>>,
984 }
985
986 impl MockProvider {
987 fn new(responses: Vec<Vec<AgentMessage>>) -> Self {
988 let responses = responses
989 .into_iter()
990 .map(|msgs| ModelResponse { messages: msgs })
991 .collect();
992 Self {
993 responses: std::sync::Mutex::new(responses),
994 }
995 }
996 }
997
998 #[async_trait]
999 impl ModelProvider for MockProvider {
1000 async fn invoke(&self, _request: ModelRequest) -> Result<ModelResponse> {
1001 let next = self
1002 .responses
1003 .lock()
1004 .unwrap()
1005 .pop_front()
1006 .unwrap_or(ModelResponse { messages: vec![] });
1007 Ok(next)
1008 }
1009 }
1010
1011 struct EchoTool;
1013
1014 #[async_trait]
1015 impl Tool for EchoTool {
1016 fn definition(&self) -> crate::tool::ToolDefinition {
1017 crate::tool::ToolDefinition {
1018 name: "echo".into(),
1019 label: "Echo".into(),
1020 description: "Echo back the provided text.".into(),
1021 parameters: crate::tool::ParameterSchema::default(),
1022 }
1023 }
1024
1025 async fn execute(&self, _ctx: InvokeContext, input: Value) -> Result<ToolResult> {
1026 let text = input
1027 .get("text")
1028 .and_then(Value::as_str)
1029 .unwrap_or("(no text)")
1030 .to_string();
1031 Ok(ToolResult {
1032 content: vec![json!({ "type": "text", "text": format!("echo: {text}") })],
1033 details: None,
1034 })
1035 }
1036 }
1037
1038 fn assistant_text(t: &str) -> AgentMessage {
1039 AgentMessage {
1040 role: Role::Assistant,
1041 content: vec![ContentBlock::Text { text: t.into() }],
1042 }
1043 }
1044
1045 fn assistant_tool_use(id: &str, name: &str, input: Value) -> AgentMessage {
1046 AgentMessage {
1047 role: Role::Assistant,
1048 content: vec![ContentBlock::ToolUse {
1049 id: id.into(),
1050 call: ToolCall {
1051 name: name.into(),
1052 input,
1053 },
1054 }],
1055 }
1056 }
1057
1058 fn user(t: &str) -> AgentMessage {
1059 AgentMessage {
1060 role: Role::User,
1061 content: vec![ContentBlock::Text { text: t.into() }],
1062 }
1063 }
1064
1065 #[tokio::test]
1066 async fn loop_runs_tool_then_finishes() {
1067 let provider = MockProvider::new(vec![
1069 vec![assistant_tool_use(
1070 "call_1",
1071 "echo",
1072 json!({ "text": "hello" }),
1073 )],
1074 vec![assistant_text("done")],
1075 ]);
1076 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1077 let model = Model::new("mock/test");
1078 let mut messages = vec![user("please echo hello then say done")];
1079
1080 let outcome = run_agent(
1081 &provider,
1082 &tools,
1083 &mut messages,
1084 &model,
1085 &RunConfig::default(),
1086 &CancellationToken::new(),
1087 &RunHooks::default(),
1088 )
1089 .await
1090 .expect("loop should complete");
1091
1092 assert_eq!(outcome.turns, 2);
1093 assert_eq!(outcome.final_text, "done");
1094
1095 assert_eq!(messages.len(), 4);
1097 assert_eq!(messages[2].role, Role::Tool);
1098 match &messages[2].content[0] {
1100 ContentBlock::ToolResult { content, .. } => {
1101 let s = serde_json::to_string(content).unwrap_or_default();
1102 assert!(s.contains("echo: hello"), "tool result was: {s}");
1103 }
1104 other => panic!("expected ToolResult, got {other:?}"),
1105 }
1106 }
1107
1108 #[tokio::test]
1109 async fn loop_stops_when_no_tool_calls() {
1110 let provider = MockProvider::new(vec![vec![assistant_text("just text, no tools")]]);
1111 let tools: Vec<Arc<dyn Tool>> = vec![];
1112 let model = Model::new("mock/test");
1113 let mut messages = vec![user("hi")];
1114
1115 let outcome = run_agent(
1116 &provider,
1117 &tools,
1118 &mut messages,
1119 &model,
1120 &RunConfig::default(),
1121 &CancellationToken::new(),
1122 &RunHooks::default(),
1123 )
1124 .await
1125 .expect("loop should complete");
1126
1127 assert_eq!(outcome.turns, 1);
1128 assert_eq!(outcome.final_text, "just text, no tools");
1129 }
1130
1131 #[tokio::test]
1132 async fn loop_recovers_from_unknown_tool() {
1133 let provider = MockProvider::new(vec![
1136 vec![assistant_tool_use("c1", "nonexistent", json!({}))],
1137 vec![assistant_text("recovered")],
1138 ]);
1139 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1140 let model = Model::new("mock/test");
1141 let mut messages = vec![user("call a missing tool")];
1142
1143 let outcome = run_agent(
1144 &provider,
1145 &tools,
1146 &mut messages,
1147 &model,
1148 &RunConfig::default(),
1149 &CancellationToken::new(),
1150 &RunHooks::default(),
1151 )
1152 .await
1153 .expect("loop should recover");
1154
1155 assert_eq!(outcome.final_text, "recovered");
1156 let tool_msg = &messages[2];
1157 assert_eq!(tool_msg.role, Role::Tool);
1158 }
1159
1160 #[tokio::test]
1161 async fn loop_aborts_on_max_turns() {
1162 let repeat = || vec![assistant_tool_use("c", "echo", json!({ "text": "x" }))];
1164 let provider = MockProvider::new(vec![repeat(), repeat(), repeat(), repeat()]);
1165 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1166 let model = Model::new("mock/test");
1167 let mut messages = vec![user("loop forever")];
1168
1169 let result = run_agent(
1170 &provider,
1171 &tools,
1172 &mut messages,
1173 &model,
1174 &RunConfig {
1175 max_turns: 3,
1176 ..RunConfig::default()
1177 },
1178 &CancellationToken::new(),
1179 &RunHooks::default(),
1180 )
1181 .await;
1182
1183 assert!(result.is_err(), "must abort on max_turns");
1184 }
1185
1186 struct DenyAllPolicy;
1188
1189 #[async_trait]
1190 impl ToolPolicy for DenyAllPolicy {
1191 async fn check(&self, _tool: &str, _input: &Value, _ctx: &InvokeContext) -> PolicyVerdict {
1192 PolicyVerdict::Deny("blocked in test".into())
1193 }
1194 }
1195
1196 #[tokio::test]
1197 async fn policy_deny_blocks_tool_but_run_continues() {
1198 let provider = MockProvider::new(vec![
1202 vec![assistant_tool_use(
1203 "c1",
1204 "echo",
1205 json!({ "text": "secret" }),
1206 )],
1207 vec![assistant_text("done")],
1208 ]);
1209 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1210 let model = Model::new("mock/test");
1211 let mut messages = vec![user("call echo")];
1212 let policy = DenyAllPolicy;
1213 let hooks = RunHooks {
1214 policy: Some(&policy),
1215 ..RunHooks::default()
1216 };
1217
1218 let outcome = run_agent(
1219 &provider,
1220 &tools,
1221 &mut messages,
1222 &model,
1223 &RunConfig::default(),
1224 &CancellationToken::new(),
1225 &hooks,
1226 )
1227 .await
1228 .expect("loop completes despite denial");
1229
1230 assert_eq!(outcome.final_text, "done");
1231 let s = match &messages[2].content[0] {
1233 ContentBlock::ToolResult { content, .. } => content.to_string(),
1234 other => panic!("expected ToolResult, got {other:?}"),
1235 };
1236 assert!(s.contains("denied by policy"), "expected denial, got: {s}");
1237 assert!(
1238 !s.contains("echo: secret"),
1239 "denied tool must NOT have executed: {s}"
1240 );
1241 }
1242
1243 #[tokio::test]
1244 async fn policy_none_is_allow_all() {
1245 let provider = MockProvider::new(vec![
1248 vec![assistant_tool_use("c1", "echo", json!({ "text": "hi" }))],
1249 vec![assistant_text("done")],
1250 ]);
1251 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1252 let model = Model::new("mock/test");
1253 let mut messages = vec![user("call echo")];
1254
1255 let outcome = run_agent(
1256 &provider,
1257 &tools,
1258 &mut messages,
1259 &model,
1260 &RunConfig::default(),
1261 &CancellationToken::new(),
1262 &RunHooks::default(),
1263 )
1264 .await
1265 .expect("loop completes");
1266 assert_eq!(outcome.final_text, "done");
1267 let s = match &messages[2].content[0] {
1268 ContentBlock::ToolResult { content, .. } => content.to_string(),
1269 other => panic!("expected ToolResult, got {other:?}"),
1270 };
1271 assert!(s.contains("echo: hi"), "tool should have run: {s}");
1272 }
1273
1274 struct RecordingTool {
1279 name: String,
1280 log: Arc<std::sync::Mutex<Vec<String>>>,
1281 }
1282
1283 #[async_trait]
1284 impl Tool for RecordingTool {
1285 fn definition(&self) -> crate::tool::ToolDefinition {
1286 crate::tool::ToolDefinition {
1287 name: self.name.clone(),
1288 label: "Recording".into(),
1289 description: "Records each execution.".into(),
1290 parameters: crate::tool::ParameterSchema::default(),
1291 }
1292 }
1293
1294 async fn execute(&self, _ctx: InvokeContext, input: Value) -> Result<ToolResult> {
1295 let tag = input
1296 .get("tag")
1297 .and_then(Value::as_str)
1298 .unwrap_or("?")
1299 .to_string();
1300 self.log.lock().expect("lock poisoned").push(tag);
1301 Ok(ToolResult {
1302 content: vec![json!({ "type": "text", "text": "ran" })],
1303 details: None,
1304 })
1305 }
1306 }
1307
1308 #[tokio::test]
1318 async fn policy_deny_blocks_tools_on_the_parallel_path() {
1319 let log = Arc::new(std::sync::Mutex::new(Vec::new()));
1322 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(RecordingTool {
1323 name: "rec".into(),
1324 log: log.clone(),
1325 })];
1326 let turn = vec![
1327 assistant_tool_use("c1", "rec", json!({ "tag": "one" })),
1328 assistant_tool_use("c2", "rec", json!({ "tag": "two" })),
1329 assistant_tool_use("c3", "rec", json!({ "tag": "three" })),
1330 ];
1331 let provider = MockProvider::new(vec![turn, vec![assistant_text("done")]]);
1332 let model = Model::new("mock/test");
1333 let mut messages = vec![user("call all three")];
1334 let config = RunConfig {
1335 tool_concurrency: 4,
1336 ..RunConfig::default()
1337 };
1338 let policy = DenyAllPolicy;
1339 let hooks = RunHooks {
1340 policy: Some(&policy),
1341 ..RunHooks::default()
1342 };
1343
1344 let outcome = run_agent(
1345 &provider,
1346 &tools,
1347 &mut messages,
1348 &model,
1349 &config,
1350 &CancellationToken::new(),
1351 &hooks,
1352 )
1353 .await
1354 .expect("loop completes despite denials");
1355
1356 assert_eq!(outcome.final_text, "done");
1357
1358 let executed = log.lock().expect("lock poisoned").clone();
1362 assert!(
1363 executed.is_empty(),
1364 "denied tools must NOT execute on the parallel path: ran {executed:?}"
1365 );
1366
1367 let results: Vec<String> = messages
1370 .iter()
1371 .filter(|m| m.role == Role::Tool)
1372 .filter_map(|m| match &m.content[0] {
1373 ContentBlock::ToolResult {
1374 tool_use_id,
1375 content,
1376 ..
1377 } => {
1378 let text = content.to_string();
1379 Some(format!("{tool_use_id}:{text}"))
1380 }
1381 _ => None,
1382 })
1383 .collect();
1384 assert_eq!(
1385 results.len(),
1386 3,
1387 "all 3 denied calls must produce a result slot: {results:?}"
1388 );
1389 for r in &results {
1390 assert!(
1391 r.contains("denied by policy"),
1392 "parallel-path denial must surface to the model: {r}"
1393 );
1394 }
1395 assert!(
1398 results[0].starts_with("c1:")
1399 && results[1].starts_with("c2:")
1400 && results[2].starts_with("c3:"),
1401 "denial slots must preserve issued order: {results:?}"
1402 );
1403 }
1404
1405 #[tokio::test]
1406 async fn loop_respects_cancellation() {
1407 let provider = MockProvider::new(vec![vec![assistant_text("never reached")]]);
1409 let tools: Vec<Arc<dyn Tool>> = vec![];
1410 let model = Model::new("mock/test");
1411 let mut messages = vec![user("hi")];
1412 let cancel = CancellationToken::new();
1413 cancel.cancel();
1414
1415 let result = run_agent(
1416 &provider,
1417 &tools,
1418 &mut messages,
1419 &model,
1420 &RunConfig::default(),
1421 &cancel,
1422 &RunHooks::default(),
1423 )
1424 .await;
1425
1426 assert!(matches!(result, Err(CoreError::Cancelled(_))));
1427 }
1428
1429 struct SlowProvider {
1432 delay_ms: u64,
1433 responses: std::sync::Mutex<std::collections::VecDeque<ModelResponse>>,
1434 }
1435
1436 impl SlowProvider {
1437 fn new(delay_ms: u64, responses: Vec<Vec<AgentMessage>>) -> Self {
1438 let responses = responses
1439 .into_iter()
1440 .map(|m| ModelResponse { messages: m })
1441 .collect();
1442 Self {
1443 delay_ms,
1444 responses: std::sync::Mutex::new(responses),
1445 }
1446 }
1447 }
1448
1449 #[async_trait]
1450 impl ModelProvider for SlowProvider {
1451 async fn invoke(&self, _request: ModelRequest) -> Result<ModelResponse> {
1452 tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
1453 let next = self
1454 .responses
1455 .lock()
1456 .unwrap()
1457 .pop_front()
1458 .unwrap_or(ModelResponse { messages: vec![] });
1459 Ok(next)
1460 }
1461 }
1462
1463 struct SlowStreamingProvider {
1467 delay_ms: u64,
1468 }
1469
1470 #[async_trait]
1471 impl ModelProvider for SlowStreamingProvider {
1472 async fn invoke(&self, _request: ModelRequest) -> Result<ModelResponse> {
1473 Ok(ModelResponse { messages: vec![] })
1475 }
1476 fn stream(&self, _request: ModelRequest) -> crate::model::StreamEventStream {
1477 use futures::stream::StreamExt as _;
1478 let delay = self.delay_ms;
1479 Box::pin(
1480 futures::stream::once(async move {
1481 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
1482 Ok(StreamEvent::TextDelta("finally".to_string()))
1483 })
1484 .chain(futures::stream::once(async { Ok(StreamEvent::Done) })),
1485 )
1486 }
1487 }
1488
1489 #[tokio::test]
1490 async fn streaming_turn_times_out_on_slow_provider() {
1491 let provider = SlowStreamingProvider { delay_ms: 500 };
1495 let model = Model::new("mock/test");
1496 let mut messages = vec![user("hi")];
1497 let config = RunConfig {
1498 turn_timeout_ms: Some(100),
1499 ..RunConfig::default()
1500 };
1501 let mut events: Vec<StreamEvent> = Vec::new();
1502 let mut on_event = |ev: &StreamEvent| {
1503 events.push(ev.clone());
1504 };
1505 let result = run_agent_streaming(
1506 &provider,
1507 &[],
1508 &mut messages,
1509 &model,
1510 &config,
1511 &CancellationToken::new(),
1512 &mut on_event,
1513 &RunHooks::default(),
1514 )
1515 .await;
1516 assert!(
1517 matches!(result, Err(CoreError::TurnTimeout { ms: 100 })),
1518 "expected a streaming turn timeout, got {result:?}"
1519 );
1520 }
1521
1522 #[tokio::test]
1523 async fn streaming_turn_aborts_on_cancel() {
1524 let provider = SlowStreamingProvider { delay_ms: 60_000 };
1527 let model = Model::new("mock/test");
1528 let mut messages = vec![user("hi")];
1529 let cancel = CancellationToken::new();
1530 let cancel_for_run = cancel.clone();
1531 tokio::spawn(async move {
1533 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1534 cancel_for_run.cancel();
1535 });
1536 let result = run_agent_streaming(
1537 &provider,
1538 &[],
1539 &mut messages,
1540 &model,
1541 &RunConfig::default(),
1542 &cancel,
1543 &mut |_| {},
1544 &RunHooks::default(),
1545 )
1546 .await;
1547 assert!(
1548 matches!(result, Err(CoreError::Cancelled(_))),
1549 "expected cancellation to abort the streaming run, got {result:?}"
1550 );
1551 }
1552
1553 #[tokio::test]
1554 async fn turn_timeout_aborts_slow_provider() {
1555 let provider = SlowProvider::new(500, vec![vec![assistant_text("too slow")]]);
1557 let model = Model::new("mock/test");
1558 let mut messages = vec![user("hi")];
1559 let config = RunConfig {
1560 turn_timeout_ms: Some(100),
1561 ..RunConfig::default()
1562 };
1563
1564 let result = run_agent(
1565 &provider,
1566 &[],
1567 &mut messages,
1568 &model,
1569 &config,
1570 &CancellationToken::new(),
1571 &RunHooks::default(),
1572 )
1573 .await;
1574
1575 assert!(
1576 matches!(result, Err(CoreError::TurnTimeout { ms: 100 })),
1577 "expected turn timeout, got {result:?}"
1578 );
1579 }
1580
1581 #[tokio::test]
1582 async fn max_tool_calls_per_turn_rejects_runaway_response() {
1583 let runaway: Vec<AgentMessage> = (0..5)
1585 .map(|i| assistant_tool_use(&format!("c{i}"), "echo", json!({ "text": "x" })))
1586 .collect();
1587 let provider = MockProvider::new(vec![runaway]);
1588 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1589 let model = Model::new("mock/test");
1590 let mut messages = vec![user("call many tools")];
1591 let config = RunConfig {
1592 max_tool_calls_per_turn: 2,
1593 ..RunConfig::default()
1594 };
1595
1596 let result = run_agent(
1597 &provider,
1598 &tools,
1599 &mut messages,
1600 &model,
1601 &config,
1602 &CancellationToken::new(),
1603 &RunHooks::default(),
1604 )
1605 .await;
1606
1607 assert!(result.is_err(), "runaway tool calls must be rejected");
1608 let err = result.unwrap_err().to_string();
1609 assert!(err.contains("max"), "error should mention the cap: {err}");
1610 }
1611
1612 struct OrderingTool {
1614 name: String,
1615 delay_ms: u64,
1616 log: Arc<std::sync::Mutex<Vec<String>>>,
1617 }
1618
1619 #[async_trait]
1620 impl Tool for OrderingTool {
1621 fn definition(&self) -> crate::tool::ToolDefinition {
1622 crate::tool::ToolDefinition {
1623 name: self.name.clone(),
1624 label: "Ordering".into(),
1625 description: "Records completion order.".into(),
1626 parameters: crate::tool::ParameterSchema::default(),
1627 }
1628 }
1629
1630 async fn execute(&self, _ctx: InvokeContext, input: Value) -> Result<ToolResult> {
1631 tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
1632 self.log.lock().unwrap().push(
1633 input
1634 .get("tag")
1635 .and_then(Value::as_str)
1636 .unwrap_or("?")
1637 .to_string(),
1638 );
1639 Ok(ToolResult {
1640 content: vec![json!({ "type": "text", "text": "ok" })],
1641 details: None,
1642 })
1643 }
1644 }
1645
1646 #[tokio::test]
1647 async fn parallel_tool_calls_preserve_result_order() {
1648 let log = Arc::new(std::sync::Mutex::new(Vec::new()));
1652 let tools: Vec<Arc<dyn Tool>> = vec![
1653 Arc::new(OrderingTool {
1654 name: "slow".into(),
1655 delay_ms: 60,
1656 log: log.clone(),
1657 }),
1658 Arc::new(OrderingTool {
1659 name: "fast".into(),
1660 delay_ms: 5,
1661 log: log.clone(),
1662 }),
1663 ];
1664 let turn = vec![
1665 assistant_tool_use("c1", "slow", json!({ "tag": "slow" })),
1666 assistant_tool_use("c2", "fast", json!({ "tag": "fast" })),
1667 ];
1668 let provider = MockProvider::new(vec![turn, vec![assistant_text("done")]]);
1669 let model = Model::new("mock/test");
1670 let mut messages = vec![user("call both")];
1671 let config = RunConfig {
1672 tool_concurrency: 4,
1673 ..RunConfig::default()
1674 };
1675
1676 let outcome = run_agent(
1677 &provider,
1678 &tools,
1679 &mut messages,
1680 &model,
1681 &config,
1682 &CancellationToken::new(),
1683 &RunHooks::default(),
1684 )
1685 .await
1686 .expect("loop should complete");
1687
1688 assert_eq!(outcome.final_text, "done");
1689
1690 let completed = log.lock().unwrap().clone();
1693 assert_eq!(
1694 completed,
1695 vec!["fast", "slow"],
1696 "tools must have run concurrently: {completed:?}"
1697 );
1698
1699 let tool_ids: Vec<String> = messages
1701 .iter()
1702 .filter(|m| m.role == Role::Tool)
1703 .filter_map(|m| match &m.content[0] {
1704 ContentBlock::ToolResult { tool_use_id, .. } => Some(tool_use_id.clone()),
1705 _ => None,
1706 })
1707 .collect();
1708 assert_eq!(
1709 tool_ids,
1710 vec!["c1", "c2"],
1711 "results must be appended in issued order: {tool_ids:?}"
1712 );
1713 }
1714
1715 struct PanickingTool;
1718
1719 #[async_trait]
1720 impl Tool for PanickingTool {
1721 fn definition(&self) -> crate::tool::ToolDefinition {
1722 crate::tool::ToolDefinition {
1723 name: "boom".into(),
1724 label: "Boom".into(),
1725 description: "Always panics.".into(),
1726 parameters: crate::tool::ParameterSchema::default(),
1727 }
1728 }
1729
1730 async fn execute(&self, _ctx: InvokeContext, _input: Value) -> Result<ToolResult> {
1731 panic!("deliberate tool panic");
1732 }
1733 }
1734
1735 #[tokio::test]
1736 async fn parallel_path_survives_a_task_panic() {
1737 let turn = vec![
1740 assistant_tool_use("c1", "boom", json!({})),
1741 assistant_tool_use("c2", "echo", json!({ "text": "survived" })),
1742 ];
1743 let provider = MockProvider::new(vec![turn, vec![assistant_text("done")]]);
1744 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(PanickingTool), Arc::new(EchoTool)];
1745 let model = Model::new("mock/test");
1746 let mut messages = vec![user("call both")];
1747 let config = RunConfig {
1748 tool_concurrency: 4,
1749 ..RunConfig::default()
1750 };
1751
1752 let outcome = run_agent(
1754 &provider,
1755 &tools,
1756 &mut messages,
1757 &model,
1758 &config,
1759 &CancellationToken::new(),
1760 &RunHooks::default(),
1761 )
1762 .await
1763 .expect("loop must survive a tool panic");
1764
1765 assert_eq!(outcome.final_text, "done");
1766 let tool_ids: Vec<String> = messages
1768 .iter()
1769 .filter(|m| m.role == Role::Tool)
1770 .filter_map(|m| match &m.content[0] {
1771 ContentBlock::ToolResult { tool_use_id, .. } => Some(tool_use_id.clone()),
1772 _ => None,
1773 })
1774 .collect();
1775 assert_eq!(
1776 tool_ids,
1777 vec!["c1", "c2"],
1778 "both results must be present despite the panic: {tool_ids:?}"
1779 );
1780 }
1781
1782 #[tokio::test]
1788 async fn sequential_path_survives_a_tool_panic() {
1789 let turn = vec![
1790 assistant_tool_use("c1", "boom", json!({})),
1791 assistant_tool_use("c2", "echo", json!({ "text": "survived" })),
1792 ];
1793 let provider = MockProvider::new(vec![turn, vec![assistant_text("done")]]);
1794 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(PanickingTool), Arc::new(EchoTool)];
1795 let model = Model::new("mock/test");
1796 let mut messages = vec![user("call both")];
1797 let config = RunConfig::default();
1799
1800 let outcome = run_agent(
1802 &provider,
1803 &tools,
1804 &mut messages,
1805 &model,
1806 &config,
1807 &CancellationToken::new(),
1808 &RunHooks::default(),
1809 )
1810 .await
1811 .expect("sequential path must survive a tool panic");
1812 assert_eq!(outcome.final_text, "done");
1813
1814 let results: Vec<&ContentBlock> = messages
1816 .iter()
1817 .filter(|m| m.role == Role::Tool)
1818 .flat_map(|m| m.content.iter())
1819 .collect();
1820 assert_eq!(results.len(), 2, "both results appended");
1821 let c1_str = match &results[0] {
1823 ContentBlock::ToolResult { content, .. } => content.to_string(),
1824 _ => String::new(),
1825 };
1826 assert!(
1827 c1_str.contains("Error:"),
1828 "panic must surface as an Error: result, got: {c1_str}"
1829 );
1830 assert!(
1831 c1_str.contains("panicked"),
1832 "error result should mention the panic: {c1_str}"
1833 );
1834 }
1835
1836 #[tokio::test]
1842 async fn parallel_path_panic_preserves_call_id_and_summary() {
1843 let turn = vec![
1845 assistant_tool_use("c1", "boom", json!({})),
1846 assistant_tool_use("c2", "echo", json!({ "text": "ok" })),
1847 ];
1848 let provider = MockProvider::new(vec![turn, vec![assistant_text("done")]]);
1849 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(PanickingTool), Arc::new(EchoTool)];
1850 let model = Model::new("mock/test");
1851 let mut messages = vec![user("call both")];
1852 let config = RunConfig {
1853 tool_concurrency: 4,
1854 ..RunConfig::default()
1855 };
1856
1857 let outcome = run_agent(
1858 &provider,
1859 &tools,
1860 &mut messages,
1861 &model,
1862 &config,
1863 &CancellationToken::new(),
1864 &RunHooks::default(),
1865 )
1866 .await
1867 .expect("run survives parallel panic");
1868 assert_eq!(outcome.final_text, "done");
1869
1870 let tool_msgs: Vec<(&String, String)> = messages
1873 .iter()
1874 .filter(|m| m.role == Role::Tool)
1875 .flat_map(|m| m.content.iter())
1876 .filter_map(|b| match b {
1877 ContentBlock::ToolResult {
1878 tool_use_id,
1879 content,
1880 } => Some((tool_use_id, content.to_string())),
1881 _ => None,
1882 })
1883 .collect();
1884 assert_eq!(tool_msgs.len(), 2, "both results present");
1885 assert_eq!(tool_msgs[0].0, "c1", "c1 attributed correctly");
1887 assert_eq!(tool_msgs[1].0, "c2", "c2 attributed correctly");
1888 assert!(
1890 tool_msgs[0].1.contains("panicked"),
1891 "parallel panic should carry bounded summary, got: {}",
1892 tool_msgs[0].1
1893 );
1894 assert!(
1895 tool_msgs[0].1.contains("Error:"),
1896 "should be an Error: result, got: {}",
1897 tool_msgs[0].1
1898 );
1899 }
1900
1901 #[tokio::test]
1908 async fn parallel_path_keeps_all_results_under_throttling() {
1909 let turn = vec![
1913 assistant_tool_use("c1", "echo", json!({ "text": "one" })),
1914 assistant_tool_use("c2", "echo", json!({ "text": "two" })),
1915 assistant_tool_use("c3", "echo", json!({ "text": "three" })),
1916 ];
1917 let provider = MockProvider::new(vec![turn, vec![assistant_text("done")]]);
1918 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1919 let model = Model::new("mock/test");
1920 let mut messages = vec![user("call all three")];
1921 let config = RunConfig {
1922 tool_concurrency: 2,
1923 ..RunConfig::default()
1924 };
1925
1926 let outcome = run_agent(
1927 &provider,
1928 &tools,
1929 &mut messages,
1930 &model,
1931 &config,
1932 &CancellationToken::new(),
1933 &RunHooks::default(),
1934 )
1935 .await
1936 .expect("run completes");
1937 assert_eq!(outcome.final_text, "done");
1938
1939 let results: Vec<String> = messages
1943 .iter()
1944 .filter(|m| m.role == Role::Tool)
1945 .flat_map(|m| m.content.iter())
1946 .filter_map(|b| match b {
1947 ContentBlock::ToolResult {
1948 tool_use_id,
1949 content,
1950 } => {
1951 let text = content
1952 .get("content")
1953 .and_then(|c| c.get(0))
1954 .and_then(|c| c.get("text"))
1955 .and_then(|t| t.as_str())
1956 .unwrap_or("<missing>");
1957 Some(format!("{tool_use_id}={text}"))
1958 }
1959 _ => None,
1960 })
1961 .collect();
1962 assert_eq!(
1963 results,
1964 vec!["c1=echo: one", "c2=echo: two", "c3=echo: three"],
1965 "all 3 results must survive throttling, in order, with correct text: {results:?}"
1966 );
1967 }
1968
/// An owned `String` panic payload is summarized as its text.
#[test]
fn summarize_panic_handles_string_payloads() {
    let payload: Box<dyn std::any::Any + Send> = Box::new(String::from("boom!"));
    let summary = summarize_panic(&payload);
    assert_eq!(summary, "boom!");
}
1974
/// A `&'static str` panic payload is summarized as its text.
#[test]
fn summarize_panic_handles_str_payloads() {
    let message: &'static str = "static boom";
    let payload: Box<dyn std::any::Any + Send> = Box::new(message);
    let summary = summarize_panic(&payload);
    assert_eq!(summary, "static boom");
}
1981
/// A 10,000-character payload is cut down to at most
/// `PANIC_SUMMARY_MAX_CHARS` characters and ends with an ellipsis.
#[test]
fn summarize_panic_bounds_huge_payloads() {
    let payload: Box<dyn std::any::Any + Send> = Box::new("x".repeat(10_000));
    let summary = summarize_panic(&payload);

    // Measured in characters, not bytes, to match the limit's unit.
    let length = summary.chars().count();
    assert!(
        length <= PANIC_SUMMARY_MAX_CHARS,
        "summary not bounded: {} chars",
        length
    );
    assert!(
        summary.ends_with('…'),
        "should end with ellipsis: {summary}"
    );
}
1997
/// A payload that is not a string (here an `i32`) yields a summary
/// containing the "non-string" marker.
#[test]
fn summarize_panic_falls_back_for_non_string_payloads() {
    let payload: Box<dyn std::any::Any + Send> = Box::new(42_i32);
    let summary = summarize_panic(&payload);
    let has_marker = summary.contains("non-string");
    assert!(has_marker, "expected fallback marker: {summary}");
}
2007
2008 use crate::event::{EventSink, RunEvent};
2011 use std::sync::{Arc, Mutex};
2012 use uuid::Uuid;
2013
2014 struct RecordingSink {
2016 events: Arc<Mutex<Vec<RunEvent>>>,
2017 }
2018
2019 impl EventSink for RecordingSink {
2020 fn emit(&self, event: RunEvent) {
2021 self.events.lock().expect("lock poisoned").push(event);
2022 }
2023 }
2024
2025 #[tokio::test]
2026 async fn text_only_run_emits_complete_event_sequence() {
2027 let provider = MockProvider::new(vec![vec![assistant_text("hello")]]);
2028 let tools: Vec<Arc<dyn Tool>> = vec![];
2029 let model = Model::new("mock/test");
2030 let mut messages = vec![user("hi")];
2031 let sink = Arc::new(Mutex::new(Vec::new()));
2032 let hooks = RunHooks {
2033 session_id: Some(Uuid::nil()),
2034 turn_sink: None,
2035 event_sink: Some(&RecordingSink {
2036 events: sink.clone(),
2037 } as &dyn EventSink),
2038 policy: None,
2039 };
2040
2041 run_agent(
2042 &provider,
2043 &tools,
2044 &mut messages,
2045 &model,
2046 &RunConfig::default(),
2047 &CancellationToken::new(),
2048 &hooks,
2049 )
2050 .await
2051 .expect("run");
2052
2053 let events = sink.lock().expect("lock poisoned").clone();
2054 assert!(events
2056 .iter()
2057 .any(|e| matches!(e, RunEvent::SessionStarted { .. })));
2058 assert!(events
2059 .iter()
2060 .any(|e| matches!(e, RunEvent::TurnStarted { turn: 1, .. })));
2061 assert!(events.iter().any(
2062 |e| matches!(e, RunEvent::ModelStarted { turn: 1, model, .. } if model == "mock/test")
2063 ));
2064 assert!(events
2065 .iter()
2066 .any(|e| matches!(e, RunEvent::ModelFinished { turn: 1, .. })));
2067 assert!(events
2068 .iter()
2069 .any(|e| matches!(e, RunEvent::TurnFinished { turn: 1, .. })));
2070 assert!(!events
2072 .iter()
2073 .any(|e| matches!(e, RunEvent::ToolStarted { .. })));
2074 }
2075
2076 #[tokio::test]
2077 async fn tool_run_emits_tool_started_finished() {
2078 let echo_tool = Arc::new(EchoTool) as Arc<dyn Tool>;
2079 let tools = vec![echo_tool.clone()];
2080 let provider = MockProvider::new(vec![
2081 vec![assistant_tool_use(
2082 "call-1",
2083 "echo",
2084 json!({ "text": "hi" }),
2085 )],
2086 vec![assistant_text("done")],
2087 ]);
2088 let model = Model::new("mock/test");
2089 let mut messages = vec![user("echo hi")];
2090 let sink = Arc::new(Mutex::new(Vec::new()));
2091 let hooks = RunHooks {
2092 session_id: Some(Uuid::nil()),
2093 turn_sink: None,
2094 event_sink: Some(&RecordingSink {
2095 events: sink.clone(),
2096 } as &dyn EventSink),
2097 policy: None,
2098 };
2099
2100 run_agent(
2101 &provider,
2102 &tools,
2103 &mut messages,
2104 &model,
2105 &RunConfig::default(),
2106 &CancellationToken::new(),
2107 &hooks,
2108 )
2109 .await
2110 .expect("run");
2111
2112 let events = sink.lock().expect("lock poisoned").clone();
2113 assert!(
2115 events.iter().any(|e| matches!(e, RunEvent::ToolStarted { turn: 1, tool, call_id, .. } if tool == "echo" && call_id == "call-1")),
2116 "missing ToolStarted for echo/call-1"
2117 );
2118 assert!(
2119 events.iter().any(|e| matches!(e, RunEvent::ToolFinished { turn: 1, tool, call_id, ok: true, .. } if tool == "echo" && call_id == "call-1")),
2120 "missing ToolFinished for echo/call-1"
2121 );
2122 assert!(events
2124 .iter()
2125 .any(|e| matches!(e, RunEvent::TurnFinished { turn: 2, .. })));
2126 }
2127
2128 #[tokio::test]
2129 async fn no_events_when_session_id_is_none() {
2130 let provider = MockProvider::new(vec![vec![assistant_text("hello")]]);
2131 let tools: Vec<Arc<dyn Tool>> = vec![];
2132 let model = Model::new("mock/test");
2133 let mut messages = vec![user("hi")];
2134 let sink = Arc::new(Mutex::new(Vec::new()));
2135 let hooks = RunHooks {
2136 session_id: None, turn_sink: None,
2138 event_sink: Some(&RecordingSink {
2139 events: sink.clone(),
2140 } as &dyn EventSink),
2141 policy: None,
2142 };
2143
2144 run_agent(
2145 &provider,
2146 &tools,
2147 &mut messages,
2148 &model,
2149 &RunConfig::default(),
2150 &CancellationToken::new(),
2151 &hooks,
2152 )
2153 .await
2154 .expect("run");
2155
2156 assert!(
2157 sink.lock().expect("lock poisoned").is_empty(),
2158 "events emitted with no session_id"
2159 );
2160 }
2161}