1use std::sync::Arc;
13
14use tokio_util::sync::CancellationToken;
15
16use crate::error::{CoreError, Result};
17use crate::event::{RunEvent, RunHooks};
18use crate::message::{AgentMessage, ContentBlock, Role};
19use crate::model::{Model, ModelProvider, ModelRequest, StreamEvent};
20use crate::policy::{PolicyVerdict, ToolPolicy};
21use crate::thinking::ThinkingLevel;
22use crate::tool::{InvokeContext, Tool, ToolCall, ToolResult};
23
24#[derive(Debug, Clone)]
26pub struct RunConfig {
27 pub max_turns: usize,
29 pub thinking: ThinkingLevel,
31 pub turn_timeout_ms: Option<u64>,
34 pub max_tool_calls_per_turn: usize,
37 pub tool_concurrency: usize,
41}
42
43impl Default for RunConfig {
44 fn default() -> Self {
45 Self {
46 max_turns: 12,
47 thinking: ThinkingLevel::default(),
48 turn_timeout_ms: Some(120_000),
49 max_tool_calls_per_turn: 10,
50 tool_concurrency: 1,
51 }
52 }
53}
54
/// Summary returned by a successful agent run.
#[derive(Debug, Clone)]
pub struct RunOutcome {
    /// Number of model turns the run took, including the final one.
    pub turns: usize,
    /// Concatenated text blocks of the most recent assistant message when the run ended.
    pub final_text: String,
}
64
65#[async_trait::async_trait]
80pub trait TurnSink: Send + Sync {
81 async fn after_turn(&self, turn: usize, messages: &[AgentMessage]) -> Result<()>;
84}
85
86pub struct FanoutTurnSink {
100 sinks: Vec<Box<dyn TurnSink>>,
101}
102
103impl FanoutTurnSink {
104 #[must_use]
106 pub fn new() -> Self {
107 Self { sinks: Vec::new() }
108 }
109
110 #[must_use]
112 pub fn push(mut self, sink: Box<dyn TurnSink>) -> Self {
113 self.sinks.push(sink);
114 self
115 }
116
117 #[must_use]
119 pub fn len(&self) -> usize {
120 self.sinks.len()
121 }
122
123 #[must_use]
125 pub fn is_empty(&self) -> bool {
126 self.sinks.is_empty()
127 }
128}
129
130impl Default for FanoutTurnSink {
131 fn default() -> Self {
132 Self::new()
133 }
134}
135
136#[async_trait::async_trait]
137impl TurnSink for FanoutTurnSink {
138 async fn after_turn(&self, turn: usize, messages: &[AgentMessage]) -> Result<()> {
139 for sink in &self.sinks {
140 sink.after_turn(turn, messages).await?;
141 }
142 Ok(())
143 }
144}
145
#[cfg(test)]
mod fanout_tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Sink that records every turn number it sees and can be told to fail on one.
    struct RecordingSink {
        /// Turn numbers observed so far, shared with the test body.
        calls: Arc<Mutex<Vec<usize>>>,
        /// Turn number on which `after_turn` returns an error, if any.
        fail_at: Option<usize>,
    }

    #[async_trait::async_trait]
    impl TurnSink for RecordingSink {
        async fn after_turn(&self, turn: usize, _messages: &[AgentMessage]) -> Result<()> {
            // Record first so a failing call is still visible to the test.
            self.calls.lock().expect("lock poisoned").push(turn);
            if self.fail_at == Some(turn) {
                return Err(crate::error::CoreError::Transport(format!(
                    "injected failure at turn {turn}"
                )));
            }
            Ok(())
        }
    }

    /// Builds a two-sink fan-out and returns it with both call logs.
    fn fanout_pair(
        first_fails_at: Option<usize>,
    ) -> (
        FanoutTurnSink,
        Arc<Mutex<Vec<usize>>>,
        Arc<Mutex<Vec<usize>>>,
    ) {
        let calls_a = Arc::new(Mutex::new(Vec::new()));
        let calls_b = Arc::new(Mutex::new(Vec::new()));
        let fanout = FanoutTurnSink::new()
            .push(Box::new(RecordingSink {
                calls: calls_a.clone(),
                fail_at: first_fails_at,
            }))
            .push(Box::new(RecordingSink {
                calls: calls_b.clone(),
                fail_at: None,
            }));
        (fanout, calls_a, calls_b)
    }

    #[tokio::test]
    async fn fanout_calls_sinks_in_order() {
        let (fanout, calls_a, calls_b) = fanout_pair(None);
        TurnSink::after_turn(&fanout, 1, &[]).await.unwrap();
        assert_eq!(*calls_a.lock().unwrap(), vec![1]);
        assert_eq!(*calls_b.lock().unwrap(), vec![1]);
    }

    #[tokio::test]
    async fn fanout_propagates_error_and_stops() {
        let (fanout, calls_a, calls_b) = fanout_pair(Some(2));
        let result = TurnSink::after_turn(&fanout, 2, &[]).await;
        // The test name promises propagation, so the returned error itself is checked
        // instead of being discarded.
        assert!(
            matches!(result, Err(crate::error::CoreError::Transport(_))),
            "first sink's error must be returned to the caller"
        );
        assert_eq!(*calls_a.lock().unwrap(), vec![2]);
        assert!(calls_b.lock().unwrap().is_empty(), "second sink ran");
    }
}
207
208pub async fn run_agent(
226 provider: &dyn ModelProvider,
227 tools: &[Arc<dyn Tool>],
228 messages: &mut Vec<AgentMessage>,
229 model: &Model,
230 config: &RunConfig,
231 cancel: &CancellationToken,
232 hooks: &RunHooks<'_>,
233) -> Result<RunOutcome> {
234 hooks.emit_event(|sid| RunEvent::SessionStarted { session: sid });
235 let mut turns = 0usize;
236 loop {
237 if cancel.is_cancelled() {
238 hooks.emit_event(|sid| crate::event::run_failed(sid, "cancelled"));
239 return Err(CoreError::Cancelled("agent run cancelled".into()));
240 }
241 if turns >= config.max_turns {
242 let msg = format!(
243 "max_turns ({}) exceeded — the model kept calling tools",
244 config.max_turns
245 );
246 hooks.emit_event(|sid| crate::event::run_failed(sid, msg.clone()));
247 return Err(CoreError::ModelResponse(msg));
248 }
249 turns += 1;
250 hooks.emit_event(|sid| RunEvent::TurnStarted {
251 session: sid,
252 turn: turns,
253 });
254
255 let request = ModelRequest {
256 model: model.clone(),
257 messages: messages.clone(),
258 tools: tools.iter().map(|t| t.definition()).collect(),
259 thinking: config.thinking,
260 params: Default::default(),
261 };
262 hooks.emit_event(|sid| RunEvent::ModelStarted {
263 session: sid,
264 turn: turns,
265 model: model.id.clone(),
266 });
267 let response =
269 match invoke_with_budget(provider, request, config.turn_timeout_ms, cancel).await {
270 Ok(r) => r,
271 Err(e) => {
272 hooks.emit_event(|sid| crate::event::run_failed(sid, e.to_string()));
273 return Err(e);
274 }
275 };
276 hooks.emit_event(|sid| RunEvent::ModelFinished {
277 session: sid,
278 turn: turns,
279 });
280 let tool_calls: Vec<(String, ToolCall)> = response
282 .messages
283 .iter()
284 .flat_map(|m| m.content.iter())
285 .filter_map(|block| {
286 if let ContentBlock::ToolUse { id, call } = block {
287 Some((id.clone(), call.clone()))
288 } else {
289 None
290 }
291 })
292 .collect();
293 messages.extend(response.messages);
295
296 if tool_calls.is_empty() {
297 let final_text = extract_final_text(messages);
299 if let Some(sink) = hooks.turn_sink {
301 sink.after_turn(turns, messages).await?;
302 }
303 hooks.emit_event(|sid| RunEvent::TurnFinished {
304 session: sid,
305 turn: turns,
306 });
307 return Ok(RunOutcome { turns, final_text });
308 }
309
310 if tool_calls.len() > config.max_tool_calls_per_turn {
312 let msg = format!(
313 "model issued {} tool calls in one turn (max {})",
314 tool_calls.len(),
315 config.max_tool_calls_per_turn
316 );
317 hooks.emit_event(|sid| crate::event::run_failed(sid, msg.clone()));
318 return Err(CoreError::ModelResponse(msg));
319 }
320
321 for (id, call) in &tool_calls {
323 hooks.emit_event(|sid| RunEvent::ToolStarted {
324 session: sid,
325 turn: turns,
326 tool: call.name.clone(),
327 call_id: id.clone(),
328 });
329 }
330
331 let results = execute_tool_calls(
335 tools,
336 &tool_calls,
337 cancel,
338 config.tool_concurrency,
339 hooks.policy,
340 )
341 .await;
342
343 for (i, (id, call)) in tool_calls.iter().enumerate() {
345 let result = &results[i];
346 let ok = tool_result_ok(result);
347 hooks.emit_event(|sid| RunEvent::ToolFinished {
348 session: sid,
349 turn: turns,
350 tool: call.name.clone(),
351 call_id: id.clone(),
352 ok,
353 });
354 let tool_msg = AgentMessage {
355 role: Role::Tool,
356 content: vec![ContentBlock::ToolResult {
357 tool_use_id: id.clone(),
358 content: serde_json::to_value(result)
359 .unwrap_or_else(|_| serde_json::json!({ "error": "serialize failed" })),
360 }],
361 };
362 messages.push(tool_msg);
363 }
364 if let Some(sink) = hooks.turn_sink {
368 sink.after_turn(turns, messages).await?;
369 }
370 hooks.emit_event(|sid| RunEvent::TurnFinished {
371 session: sid,
372 turn: turns,
373 });
374 }
375}
376
377#[derive(Debug, Clone, Default)]
380struct StreamedTurn {
381 text: String,
382 thinking: String,
383 tool_calls: Vec<(String, ToolCall)>,
384}
385
386async fn collect_streamed_turn(
396 stream: crate::model::StreamEventStream,
397 on_event: &mut (dyn FnMut(&StreamEvent) + Send),
398 turn_timeout_ms: Option<u64>,
399 cancel: &CancellationToken,
400) -> Result<StreamedTurn> {
401 use futures::StreamExt;
402 let mut turn = StreamedTurn::default();
403 let mut s = stream;
404 let deadline = turn_timeout_ms
409 .map(|ms| tokio::time::Instant::now() + std::time::Duration::from_millis(ms));
410 loop {
411 if cancel.is_cancelled() {
413 return Err(CoreError::Cancelled("turn cancelled during stream".into()));
414 }
415 let next = async { s.next().await };
421 let item = match deadline {
422 Some(deadline) => {
423 let to = tokio::time::timeout_at(deadline, next);
424 tokio::select! {
425 biased;
426 _ = cancel.cancelled() => {
427 return Err(CoreError::Cancelled(
428 "turn cancelled during stream".into(),
429 ));
430 }
431 res = to => match res {
432 Ok(Some(item)) => item,
433 Ok(None) => break, Err(_) => {
435 let ms = turn_timeout_ms.unwrap_or(0);
436 return Err(CoreError::Cancelled(format!(
437 "turn timed out after {ms}ms"
438 )));
439 }
440 },
441 }
442 }
443 None => {
444 tokio::select! {
445 biased;
446 _ = cancel.cancelled() => {
447 return Err(CoreError::Cancelled(
448 "turn cancelled during stream".into(),
449 ));
450 }
451 item = next => match item {
452 Some(item) => item,
453 None => break, },
455 }
456 }
457 };
458 match item {
459 Ok(StreamEvent::TextDelta(t)) => {
460 on_event(&StreamEvent::TextDelta(t.clone()));
461 turn.text.push_str(&t);
462 }
463 Ok(StreamEvent::ThinkingDelta(t)) => {
464 turn.thinking.push_str(&t);
465 }
466 Ok(StreamEvent::ToolCall(call)) => {
467 let id = format!("call_{}", turn.tool_calls.len());
471 turn.tool_calls.push((id, call));
472 }
473 Ok(StreamEvent::Done) => break,
474 Err(e) => return Err(e),
475 }
476 }
477 Ok(turn)
478}
479
480#[allow(clippy::too_many_arguments)]
488pub async fn run_agent_streaming(
489 provider: &dyn ModelProvider,
490 tools: &[Arc<dyn Tool>],
491 messages: &mut Vec<AgentMessage>,
492 model: &Model,
493 config: &RunConfig,
494 cancel: &CancellationToken,
495 on_event: &mut (dyn FnMut(&StreamEvent) + Send),
496 hooks: &RunHooks<'_>,
497) -> Result<RunOutcome> {
498 hooks.emit_event(|sid| RunEvent::SessionStarted { session: sid });
499 let mut turns = 0usize;
500 loop {
501 if cancel.is_cancelled() {
502 hooks.emit_event(|sid| crate::event::run_failed(sid, "cancelled"));
503 return Err(CoreError::Cancelled("agent run cancelled".into()));
504 }
505 if turns >= config.max_turns {
506 let msg = format!(
507 "max_turns ({}) exceeded — the model kept calling tools",
508 config.max_turns
509 );
510 hooks.emit_event(|sid| crate::event::run_failed(sid, msg.clone()));
511 return Err(CoreError::ModelResponse(msg));
512 }
513 turns += 1;
514 hooks.emit_event(|sid| RunEvent::TurnStarted {
515 session: sid,
516 turn: turns,
517 });
518
519 let request = ModelRequest {
520 model: model.clone(),
521 messages: messages.clone(),
522 tools: tools.iter().map(|t| t.definition()).collect(),
523 thinking: config.thinking,
524 params: Default::default(),
525 };
526 hooks.emit_event(|sid| RunEvent::ModelStarted {
527 session: sid,
528 turn: turns,
529 model: model.id.clone(),
530 });
531 let stream = provider.stream(request);
533 let turn =
534 match collect_streamed_turn(stream, on_event, config.turn_timeout_ms, cancel).await {
535 Ok(t) => t,
536 Err(e) => {
537 hooks.emit_event(|sid| crate::event::run_failed(sid, e.to_string()));
538 return Err(e);
539 }
540 };
541 hooks.emit_event(|sid| RunEvent::ModelFinished {
542 session: sid,
543 turn: turns,
544 });
545
546 let mut content: Vec<ContentBlock> = Vec::new();
548 if !turn.text.is_empty() {
549 content.push(ContentBlock::Text { text: turn.text });
550 }
551 for (id, call) in &turn.tool_calls {
552 content.push(ContentBlock::ToolUse {
553 id: id.clone(),
554 call: call.clone(),
555 });
556 }
557 messages.push(AgentMessage {
558 role: Role::Assistant,
559 content,
560 });
561
562 if turn.tool_calls.is_empty() {
563 let final_text = extract_final_text(messages);
564 if let Some(sink) = hooks.turn_sink {
566 sink.after_turn(turns, messages).await?;
567 }
568 hooks.emit_event(|sid| RunEvent::TurnFinished {
569 session: sid,
570 turn: turns,
571 });
572 return Ok(RunOutcome { turns, final_text });
573 }
574 if turn.tool_calls.len() > config.max_tool_calls_per_turn {
575 let msg = format!(
576 "model issued {} tool calls in one turn (max {})",
577 turn.tool_calls.len(),
578 config.max_tool_calls_per_turn
579 );
580 hooks.emit_event(|sid| crate::event::run_failed(sid, msg.clone()));
581 return Err(CoreError::ModelResponse(msg));
582 }
583
584 let owned_calls: Vec<(String, ToolCall)> = turn.tool_calls.clone();
585
586 for (id, call) in &owned_calls {
588 hooks.emit_event(|sid| RunEvent::ToolStarted {
589 session: sid,
590 turn: turns,
591 tool: call.name.clone(),
592 call_id: id.clone(),
593 });
594 }
595
596 let results = execute_tool_calls(
597 tools,
598 &owned_calls,
599 cancel,
600 config.tool_concurrency,
601 hooks.policy,
602 )
603 .await;
604
605 for (i, (id, call)) in owned_calls.iter().enumerate() {
607 let result = &results[i];
608 let ok = tool_result_ok(result);
609 hooks.emit_event(|sid| RunEvent::ToolFinished {
610 session: sid,
611 turn: turns,
612 tool: call.name.clone(),
613 call_id: id.clone(),
614 ok,
615 });
616 let tool_msg = AgentMessage {
617 role: Role::Tool,
618 content: vec![ContentBlock::ToolResult {
619 tool_use_id: id.clone(),
620 content: serde_json::to_value(result)
621 .unwrap_or_else(|_| serde_json::json!({ "error": "serialize failed" })),
622 }],
623 };
624 messages.push(tool_msg);
625 }
626 if let Some(sink) = hooks.turn_sink {
628 sink.after_turn(turns, messages).await?;
629 }
630 hooks.emit_event(|sid| RunEvent::TurnFinished {
631 session: sid,
632 turn: turns,
633 });
634 }
635}
636
/// Maximum length, in characters, of a panic summary (ellipsis included).
const PANIC_SUMMARY_MAX_CHARS: usize = 200;

/// Renders a panic payload as a short, single `String`.
///
/// `&'static str` and `String` payloads are used verbatim; any other payload type
/// becomes a fixed placeholder. Messages longer than
/// [`PANIC_SUMMARY_MAX_CHARS`] characters are cut to one character short of the
/// limit and terminated with `…`, so the result never exceeds the limit.
/// Truncation counts `char`s, so multi-byte text is never split mid-character.
fn summarize_panic(payload: &Box<dyn std::any::Any + Send>) -> String {
    let message = if let Some(text) = payload.downcast_ref::<&'static str>() {
        (*text).to_owned()
    } else if let Some(text) = payload.downcast_ref::<String>() {
        text.clone()
    } else {
        String::from("<non-string panic payload>")
    };

    if message.chars().count() <= PANIC_SUMMARY_MAX_CHARS {
        return message;
    }

    let mut clipped: String = message
        .chars()
        .take(PANIC_SUMMARY_MAX_CHARS - 1)
        .collect();
    clipped.push('…');
    clipped
}
662
663async fn execute_tool_call(
673 tools: &[Arc<dyn Tool>],
674 id: &str,
675 call: &ToolCall,
676 cancel: &CancellationToken,
677) -> ToolResult {
678 let Some(tool) = tools.iter().find(|t| t.definition().name == call.name) else {
679 return error_result(&format!("unknown tool: `{}`", call.name));
680 };
681 let ctx = InvokeContext {
682 tool_call_id: id.to_string(),
683 cancel: cancel.clone(),
684 };
685 use futures::FutureExt;
686 use std::panic::AssertUnwindSafe;
687 match AssertUnwindSafe(tool.execute(ctx, call.input.clone()))
688 .catch_unwind()
689 .await
690 {
691 Ok(Ok(result)) => result,
692 Ok(Err(err)) => error_result(&err.to_string()),
693 Err(payload) => {
694 let summary = summarize_panic(&payload);
695 tracing::warn!(
700 tool = %call.name,
701 call_id = %id,
702 "tool panicked; converted to model-visible error result"
703 );
704 error_result(&format!("tool `{}` panicked: {summary}", call.name))
705 }
706 }
707}
708
709async fn invoke_with_budget(
712 provider: &dyn ModelProvider,
713 request: ModelRequest,
714 turn_timeout_ms: Option<u64>,
715 cancel: &CancellationToken,
716) -> Result<crate::model::ModelResponse> {
717 if cancel.is_cancelled() {
719 return Err(CoreError::Cancelled("turn cancelled before invoke".into()));
720 }
721 let invoke_fut = provider.invoke(request);
722 match turn_timeout_ms {
723 Some(ms) => {
724 let timeout = tokio::time::timeout(std::time::Duration::from_millis(ms), invoke_fut);
725 tokio::select! {
726 biased;
727 _ = cancel.cancelled() => {
728 Err(CoreError::Cancelled("turn cancelled during invoke".into()))
729 }
730 res = timeout => {
731 res.map_err(|_| {
732 CoreError::Cancelled(format!(
733 "turn timed out after {ms}ms"
734 ))
735 })?
736 }
737 }
738 }
739 None => {
740 tokio::select! {
741 biased;
742 _ = cancel.cancelled() => {
743 Err(CoreError::Cancelled("turn cancelled during invoke".into()))
744 }
745 res = invoke_fut => res,
746 }
747 }
748 }
749}
750
751enum PolicyOutcome {
753 Execute,
755 Denied(ToolResult),
758}
759
760async fn policy_check(
766 policy: Option<&dyn ToolPolicy>,
767 id: &str,
768 call: &ToolCall,
769 cancel: &CancellationToken,
770) -> PolicyOutcome {
771 let Some(policy) = policy else {
772 return PolicyOutcome::Execute;
773 };
774 let ctx = InvokeContext {
775 tool_call_id: id.to_string(),
776 cancel: cancel.clone(),
777 };
778 match policy.check(&call.name, &call.input, &ctx).await {
779 PolicyVerdict::Allow => PolicyOutcome::Execute,
780 PolicyVerdict::Confirm(reason) => {
781 tracing::info!(
782 tool = %call.name,
783 call_id = %id,
784 "tool policy returned Confirm; treating as Allow for this run: {reason}"
785 );
786 PolicyOutcome::Execute
787 }
788 PolicyVerdict::Deny(reason) => {
789 PolicyOutcome::Denied(error_result(&format!("denied by policy: {reason}")))
790 }
791 }
792}
793
794async fn execute_tool_calls(
805 tools: &[Arc<dyn Tool>],
806 calls: &[(String, ToolCall)],
807 cancel: &CancellationToken,
808 tool_concurrency: usize,
809 policy: Option<&dyn ToolPolicy>,
810) -> Vec<ToolResult> {
811 if tool_concurrency <= 1 {
812 let mut out = Vec::with_capacity(calls.len());
813 for (id, call) in calls {
814 let result = match policy_check(policy, id, call, cancel).await {
815 PolicyOutcome::Execute => execute_tool_call(tools, id, call, cancel).await,
816 PolicyOutcome::Denied(result) => result,
817 };
818 out.push(result);
819 }
820 return out;
821 }
822
823 use tokio::task::JoinSet;
825 let mut indexed: Vec<Option<ToolResult>> = (0..calls.len()).map(|_| None).collect();
830 let mut set: JoinSet<(usize, ToolResult)> = JoinSet::new();
831 for (i, (id, call)) in calls.iter().enumerate() {
832 if let PolicyOutcome::Denied(result) = policy_check(policy, id, call, cancel).await {
837 if let Some(slot) = indexed.get_mut(i) {
838 *slot = Some(result);
839 }
840 continue;
841 }
842 let tool = tools
844 .iter()
845 .find(|t| t.definition().name == call.name)
846 .cloned();
847 let ctx_cancel = cancel.child_token();
848 let ctx = InvokeContext {
849 tool_call_id: id.clone(),
850 cancel: ctx_cancel,
851 };
852 let input = call.input.clone();
853 let id_owned = id.clone();
854 let call_name = call.name.clone();
855 set.spawn(async move {
856 let result = match tool {
862 Some(t) => {
863 use futures::FutureExt;
864 use std::panic::AssertUnwindSafe;
865 match AssertUnwindSafe(t.execute(ctx, input)).catch_unwind().await {
866 Ok(Ok(r)) => r,
867 Ok(Err(err)) => error_result(&err.to_string()),
868 Err(payload) => {
869 let summary = summarize_panic(&payload);
870 tracing::warn!(
871 tool = %call_name,
872 call_id = %id_owned,
873 "tool panicked; converted to model-visible error result"
874 );
875 error_result(&format!("tool `{call_name}` panicked: {summary}"))
876 }
877 }
878 }
879 None => error_result(&format!("unknown tool: `{id_owned}`")),
880 };
881 (i, result)
882 });
883 while set.len() >= tool_concurrency {
887 let res = set.join_next().await;
888 if res.is_none() {
889 break; }
891 record_join_result(res, &mut indexed);
892 }
893 }
894 while let Some(res) = set.join_next().await {
901 record_join_result(Some(res), &mut indexed);
902 }
903 indexed
904 .into_iter()
905 .map(|opt| opt.unwrap_or_else(|| error_result("tool task produced no result")))
906 .collect()
908}
909
910fn record_join_result(
915 res: Option<std::result::Result<(usize, ToolResult), tokio::task::JoinError>>,
916 indexed: &mut [Option<ToolResult>],
917) {
918 match res {
919 Some(Ok((i, result))) => {
920 if let Some(slot) = indexed.get_mut(i) {
921 *slot = Some(result);
922 }
923 }
924 Some(Err(join_err)) => {
925 let slot = indexed.iter().position(Option::is_none).unwrap_or(0);
926 if let Some(s) = indexed.get_mut(slot) {
927 *s = Some(error_result(&format!("tool task failed: {join_err}")));
928 }
929 }
930 None => {}
931 }
932}
933
934fn error_result(message: &str) -> ToolResult {
936 ToolResult {
937 content: vec![serde_json::json!({ "type": "text", "text": format!("Error: {message}") })],
938 details: None,
939 }
940}
941
942fn tool_result_ok(result: &ToolResult) -> bool {
946 !result.content.iter().any(|c| {
947 c.get("text")
948 .and_then(|t| t.as_str())
949 .is_some_and(|t| t.starts_with("Error:"))
950 })
951}
952
953fn extract_final_text(messages: &[AgentMessage]) -> String {
955 messages
956 .iter()
957 .rev()
958 .find(|m| m.role == Role::Assistant)
959 .map(|m| {
960 m.content
961 .iter()
962 .filter_map(|b| {
963 if let ContentBlock::Text { text } = b {
964 Some(text.as_str())
965 } else {
966 None
967 }
968 })
969 .collect::<Vec<_>>()
970 .join("")
971 })
972 .unwrap_or_default()
973}
974
975#[cfg(test)]
976mod tests {
977 use super::*;
983 use crate::model::ModelResponse;
984 use async_trait::async_trait;
985 use serde_json::{json, Value};
986
987 struct MockProvider {
989 responses: std::sync::Mutex<std::collections::VecDeque<ModelResponse>>,
990 }
991
992 impl MockProvider {
993 fn new(responses: Vec<Vec<AgentMessage>>) -> Self {
994 let responses = responses
995 .into_iter()
996 .map(|msgs| ModelResponse { messages: msgs })
997 .collect();
998 Self {
999 responses: std::sync::Mutex::new(responses),
1000 }
1001 }
1002 }
1003
1004 #[async_trait]
1005 impl ModelProvider for MockProvider {
1006 async fn invoke(&self, _request: ModelRequest) -> Result<ModelResponse> {
1007 let next = self
1008 .responses
1009 .lock()
1010 .unwrap()
1011 .pop_front()
1012 .unwrap_or(ModelResponse { messages: vec![] });
1013 Ok(next)
1014 }
1015 }
1016
1017 struct EchoTool;
1019
1020 #[async_trait]
1021 impl Tool for EchoTool {
1022 fn definition(&self) -> crate::tool::ToolDefinition {
1023 crate::tool::ToolDefinition {
1024 name: "echo".into(),
1025 label: "Echo".into(),
1026 description: "Echo back the provided text.".into(),
1027 parameters: crate::tool::ParameterSchema::default(),
1028 }
1029 }
1030
1031 async fn execute(&self, _ctx: InvokeContext, input: Value) -> Result<ToolResult> {
1032 let text = input
1033 .get("text")
1034 .and_then(Value::as_str)
1035 .unwrap_or("(no text)")
1036 .to_string();
1037 Ok(ToolResult {
1038 content: vec![json!({ "type": "text", "text": format!("echo: {text}") })],
1039 details: None,
1040 })
1041 }
1042 }
1043
1044 fn assistant_text(t: &str) -> AgentMessage {
1045 AgentMessage {
1046 role: Role::Assistant,
1047 content: vec![ContentBlock::Text { text: t.into() }],
1048 }
1049 }
1050
1051 fn assistant_tool_use(id: &str, name: &str, input: Value) -> AgentMessage {
1052 AgentMessage {
1053 role: Role::Assistant,
1054 content: vec![ContentBlock::ToolUse {
1055 id: id.into(),
1056 call: ToolCall {
1057 name: name.into(),
1058 input,
1059 },
1060 }],
1061 }
1062 }
1063
1064 fn user(t: &str) -> AgentMessage {
1065 AgentMessage {
1066 role: Role::User,
1067 content: vec![ContentBlock::Text { text: t.into() }],
1068 }
1069 }
1070
1071 #[tokio::test]
1072 async fn loop_runs_tool_then_finishes() {
1073 let provider = MockProvider::new(vec![
1075 vec![assistant_tool_use(
1076 "call_1",
1077 "echo",
1078 json!({ "text": "hello" }),
1079 )],
1080 vec![assistant_text("done")],
1081 ]);
1082 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1083 let model = Model::new("mock/test");
1084 let mut messages = vec![user("please echo hello then say done")];
1085
1086 let outcome = run_agent(
1087 &provider,
1088 &tools,
1089 &mut messages,
1090 &model,
1091 &RunConfig::default(),
1092 &CancellationToken::new(),
1093 &RunHooks::default(),
1094 )
1095 .await
1096 .expect("loop should complete");
1097
1098 assert_eq!(outcome.turns, 2);
1099 assert_eq!(outcome.final_text, "done");
1100
1101 assert_eq!(messages.len(), 4);
1103 assert_eq!(messages[2].role, Role::Tool);
1104 match &messages[2].content[0] {
1106 ContentBlock::ToolResult { content, .. } => {
1107 let s = serde_json::to_string(content).unwrap_or_default();
1108 assert!(s.contains("echo: hello"), "tool result was: {s}");
1109 }
1110 other => panic!("expected ToolResult, got {other:?}"),
1111 }
1112 }
1113
1114 #[tokio::test]
1115 async fn loop_stops_when_no_tool_calls() {
1116 let provider = MockProvider::new(vec![vec![assistant_text("just text, no tools")]]);
1117 let tools: Vec<Arc<dyn Tool>> = vec![];
1118 let model = Model::new("mock/test");
1119 let mut messages = vec![user("hi")];
1120
1121 let outcome = run_agent(
1122 &provider,
1123 &tools,
1124 &mut messages,
1125 &model,
1126 &RunConfig::default(),
1127 &CancellationToken::new(),
1128 &RunHooks::default(),
1129 )
1130 .await
1131 .expect("loop should complete");
1132
1133 assert_eq!(outcome.turns, 1);
1134 assert_eq!(outcome.final_text, "just text, no tools");
1135 }
1136
1137 #[tokio::test]
1138 async fn loop_recovers_from_unknown_tool() {
1139 let provider = MockProvider::new(vec![
1142 vec![assistant_tool_use("c1", "nonexistent", json!({}))],
1143 vec![assistant_text("recovered")],
1144 ]);
1145 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1146 let model = Model::new("mock/test");
1147 let mut messages = vec![user("call a missing tool")];
1148
1149 let outcome = run_agent(
1150 &provider,
1151 &tools,
1152 &mut messages,
1153 &model,
1154 &RunConfig::default(),
1155 &CancellationToken::new(),
1156 &RunHooks::default(),
1157 )
1158 .await
1159 .expect("loop should recover");
1160
1161 assert_eq!(outcome.final_text, "recovered");
1162 let tool_msg = &messages[2];
1163 assert_eq!(tool_msg.role, Role::Tool);
1164 }
1165
1166 #[tokio::test]
1167 async fn loop_aborts_on_max_turns() {
1168 let repeat = || vec![assistant_tool_use("c", "echo", json!({ "text": "x" }))];
1170 let provider = MockProvider::new(vec![repeat(), repeat(), repeat(), repeat()]);
1171 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1172 let model = Model::new("mock/test");
1173 let mut messages = vec![user("loop forever")];
1174
1175 let result = run_agent(
1176 &provider,
1177 &tools,
1178 &mut messages,
1179 &model,
1180 &RunConfig {
1181 max_turns: 3,
1182 ..RunConfig::default()
1183 },
1184 &CancellationToken::new(),
1185 &RunHooks::default(),
1186 )
1187 .await;
1188
1189 assert!(result.is_err(), "must abort on max_turns");
1190 }
1191
1192 struct DenyAllPolicy;
1194
1195 #[async_trait]
1196 impl ToolPolicy for DenyAllPolicy {
1197 async fn check(&self, _tool: &str, _input: &Value, _ctx: &InvokeContext) -> PolicyVerdict {
1198 PolicyVerdict::Deny("blocked in test".into())
1199 }
1200 }
1201
1202 #[tokio::test]
1203 async fn policy_deny_blocks_tool_but_run_continues() {
1204 let provider = MockProvider::new(vec![
1208 vec![assistant_tool_use(
1209 "c1",
1210 "echo",
1211 json!({ "text": "secret" }),
1212 )],
1213 vec![assistant_text("done")],
1214 ]);
1215 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1216 let model = Model::new("mock/test");
1217 let mut messages = vec![user("call echo")];
1218 let policy = DenyAllPolicy;
1219 let hooks = RunHooks {
1220 policy: Some(&policy),
1221 ..RunHooks::default()
1222 };
1223
1224 let outcome = run_agent(
1225 &provider,
1226 &tools,
1227 &mut messages,
1228 &model,
1229 &RunConfig::default(),
1230 &CancellationToken::new(),
1231 &hooks,
1232 )
1233 .await
1234 .expect("loop completes despite denial");
1235
1236 assert_eq!(outcome.final_text, "done");
1237 let s = match &messages[2].content[0] {
1239 ContentBlock::ToolResult { content, .. } => content.to_string(),
1240 other => panic!("expected ToolResult, got {other:?}"),
1241 };
1242 assert!(s.contains("denied by policy"), "expected denial, got: {s}");
1243 assert!(
1244 !s.contains("echo: secret"),
1245 "denied tool must NOT have executed: {s}"
1246 );
1247 }
1248
1249 #[tokio::test]
1250 async fn policy_none_is_allow_all() {
1251 let provider = MockProvider::new(vec![
1254 vec![assistant_tool_use("c1", "echo", json!({ "text": "hi" }))],
1255 vec![assistant_text("done")],
1256 ]);
1257 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1258 let model = Model::new("mock/test");
1259 let mut messages = vec![user("call echo")];
1260
1261 let outcome = run_agent(
1262 &provider,
1263 &tools,
1264 &mut messages,
1265 &model,
1266 &RunConfig::default(),
1267 &CancellationToken::new(),
1268 &RunHooks::default(),
1269 )
1270 .await
1271 .expect("loop completes");
1272 assert_eq!(outcome.final_text, "done");
1273 let s = match &messages[2].content[0] {
1274 ContentBlock::ToolResult { content, .. } => content.to_string(),
1275 other => panic!("expected ToolResult, got {other:?}"),
1276 };
1277 assert!(s.contains("echo: hi"), "tool should have run: {s}");
1278 }
1279
1280 struct RecordingTool {
1285 name: String,
1286 log: Arc<std::sync::Mutex<Vec<String>>>,
1287 }
1288
1289 #[async_trait]
1290 impl Tool for RecordingTool {
1291 fn definition(&self) -> crate::tool::ToolDefinition {
1292 crate::tool::ToolDefinition {
1293 name: self.name.clone(),
1294 label: "Recording".into(),
1295 description: "Records each execution.".into(),
1296 parameters: crate::tool::ParameterSchema::default(),
1297 }
1298 }
1299
1300 async fn execute(&self, _ctx: InvokeContext, input: Value) -> Result<ToolResult> {
1301 let tag = input
1302 .get("tag")
1303 .and_then(Value::as_str)
1304 .unwrap_or("?")
1305 .to_string();
1306 self.log.lock().expect("lock poisoned").push(tag);
1307 Ok(ToolResult {
1308 content: vec![json!({ "type": "text", "text": "ran" })],
1309 details: None,
1310 })
1311 }
1312 }
1313
1314 #[tokio::test]
1324 async fn policy_deny_blocks_tools_on_the_parallel_path() {
1325 let log = Arc::new(std::sync::Mutex::new(Vec::new()));
1328 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(RecordingTool {
1329 name: "rec".into(),
1330 log: log.clone(),
1331 })];
1332 let turn = vec![
1333 assistant_tool_use("c1", "rec", json!({ "tag": "one" })),
1334 assistant_tool_use("c2", "rec", json!({ "tag": "two" })),
1335 assistant_tool_use("c3", "rec", json!({ "tag": "three" })),
1336 ];
1337 let provider = MockProvider::new(vec![turn, vec![assistant_text("done")]]);
1338 let model = Model::new("mock/test");
1339 let mut messages = vec![user("call all three")];
1340 let config = RunConfig {
1341 tool_concurrency: 4,
1342 ..RunConfig::default()
1343 };
1344 let policy = DenyAllPolicy;
1345 let hooks = RunHooks {
1346 policy: Some(&policy),
1347 ..RunHooks::default()
1348 };
1349
1350 let outcome = run_agent(
1351 &provider,
1352 &tools,
1353 &mut messages,
1354 &model,
1355 &config,
1356 &CancellationToken::new(),
1357 &hooks,
1358 )
1359 .await
1360 .expect("loop completes despite denials");
1361
1362 assert_eq!(outcome.final_text, "done");
1363
1364 let executed = log.lock().expect("lock poisoned").clone();
1368 assert!(
1369 executed.is_empty(),
1370 "denied tools must NOT execute on the parallel path: ran {executed:?}"
1371 );
1372
1373 let results: Vec<String> = messages
1376 .iter()
1377 .filter(|m| m.role == Role::Tool)
1378 .filter_map(|m| match &m.content[0] {
1379 ContentBlock::ToolResult {
1380 tool_use_id,
1381 content,
1382 ..
1383 } => {
1384 let text = content.to_string();
1385 Some(format!("{tool_use_id}:{text}"))
1386 }
1387 _ => None,
1388 })
1389 .collect();
1390 assert_eq!(
1391 results.len(),
1392 3,
1393 "all 3 denied calls must produce a result slot: {results:?}"
1394 );
1395 for r in &results {
1396 assert!(
1397 r.contains("denied by policy"),
1398 "parallel-path denial must surface to the model: {r}"
1399 );
1400 }
1401 assert!(
1404 results[0].starts_with("c1:")
1405 && results[1].starts_with("c2:")
1406 && results[2].starts_with("c3:"),
1407 "denial slots must preserve issued order: {results:?}"
1408 );
1409 }
1410
1411 #[tokio::test]
1412 async fn loop_respects_cancellation() {
1413 let provider = MockProvider::new(vec![vec![assistant_text("never reached")]]);
1415 let tools: Vec<Arc<dyn Tool>> = vec![];
1416 let model = Model::new("mock/test");
1417 let mut messages = vec![user("hi")];
1418 let cancel = CancellationToken::new();
1419 cancel.cancel();
1420
1421 let result = run_agent(
1422 &provider,
1423 &tools,
1424 &mut messages,
1425 &model,
1426 &RunConfig::default(),
1427 &cancel,
1428 &RunHooks::default(),
1429 )
1430 .await;
1431
1432 assert!(matches!(result, Err(CoreError::Cancelled(_))));
1433 }
1434
1435 struct SlowProvider {
1438 delay_ms: u64,
1439 responses: std::sync::Mutex<std::collections::VecDeque<ModelResponse>>,
1440 }
1441
1442 impl SlowProvider {
1443 fn new(delay_ms: u64, responses: Vec<Vec<AgentMessage>>) -> Self {
1444 let responses = responses
1445 .into_iter()
1446 .map(|m| ModelResponse { messages: m })
1447 .collect();
1448 Self {
1449 delay_ms,
1450 responses: std::sync::Mutex::new(responses),
1451 }
1452 }
1453 }
1454
1455 #[async_trait]
1456 impl ModelProvider for SlowProvider {
1457 async fn invoke(&self, _request: ModelRequest) -> Result<ModelResponse> {
1458 tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
1459 let next = self
1460 .responses
1461 .lock()
1462 .unwrap()
1463 .pop_front()
1464 .unwrap_or(ModelResponse { messages: vec![] });
1465 Ok(next)
1466 }
1467 }
1468
1469 struct SlowStreamingProvider {
1473 delay_ms: u64,
1474 }
1475
1476 #[async_trait]
1477 impl ModelProvider for SlowStreamingProvider {
1478 async fn invoke(&self, _request: ModelRequest) -> Result<ModelResponse> {
1479 Ok(ModelResponse { messages: vec![] })
1481 }
1482 fn stream(&self, _request: ModelRequest) -> crate::model::StreamEventStream {
1483 use futures::stream::StreamExt as _;
1484 let delay = self.delay_ms;
1485 Box::pin(
1486 futures::stream::once(async move {
1487 tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
1488 Ok(StreamEvent::TextDelta("finally".to_string()))
1489 })
1490 .chain(futures::stream::once(async { Ok(StreamEvent::Done) })),
1491 )
1492 }
1493 }
1494
1495 #[tokio::test]
1496 async fn streaming_turn_times_out_on_slow_provider() {
1497 let provider = SlowStreamingProvider { delay_ms: 500 };
1501 let model = Model::new("mock/test");
1502 let mut messages = vec![user("hi")];
1503 let config = RunConfig {
1504 turn_timeout_ms: Some(100),
1505 ..RunConfig::default()
1506 };
1507 let mut events: Vec<StreamEvent> = Vec::new();
1508 let mut on_event = |ev: &StreamEvent| {
1509 events.push(ev.clone());
1510 };
1511 let result = run_agent_streaming(
1512 &provider,
1513 &[],
1514 &mut messages,
1515 &model,
1516 &config,
1517 &CancellationToken::new(),
1518 &mut on_event,
1519 &RunHooks::default(),
1520 )
1521 .await;
1522 assert!(
1523 matches!(result, Err(CoreError::Cancelled(_))),
1524 "expected a streaming timeout cancellation, got {result:?}"
1525 );
1526 }
1527
1528 #[tokio::test]
1529 async fn streaming_turn_aborts_on_cancel() {
1530 let provider = SlowStreamingProvider { delay_ms: 60_000 };
1533 let model = Model::new("mock/test");
1534 let mut messages = vec![user("hi")];
1535 let cancel = CancellationToken::new();
1536 let cancel_for_run = cancel.clone();
1537 tokio::spawn(async move {
1539 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1540 cancel_for_run.cancel();
1541 });
1542 let result = run_agent_streaming(
1543 &provider,
1544 &[],
1545 &mut messages,
1546 &model,
1547 &RunConfig::default(),
1548 &cancel,
1549 &mut |_| {},
1550 &RunHooks::default(),
1551 )
1552 .await;
1553 assert!(
1554 matches!(result, Err(CoreError::Cancelled(_))),
1555 "expected cancellation to abort the streaming run, got {result:?}"
1556 );
1557 }
1558
1559 #[tokio::test]
1560 async fn turn_timeout_aborts_slow_provider() {
1561 let provider = SlowProvider::new(500, vec![vec![assistant_text("too slow")]]);
1563 let model = Model::new("mock/test");
1564 let mut messages = vec![user("hi")];
1565 let config = RunConfig {
1566 turn_timeout_ms: Some(100),
1567 ..RunConfig::default()
1568 };
1569
1570 let result = run_agent(
1571 &provider,
1572 &[],
1573 &mut messages,
1574 &model,
1575 &config,
1576 &CancellationToken::new(),
1577 &RunHooks::default(),
1578 )
1579 .await;
1580
1581 assert!(
1582 matches!(result, Err(CoreError::Cancelled(_))),
1583 "expected cancelled, got {result:?}"
1584 );
1585 }
1586
1587 #[tokio::test]
1588 async fn max_tool_calls_per_turn_rejects_runaway_response() {
1589 let runaway: Vec<AgentMessage> = (0..5)
1591 .map(|i| assistant_tool_use(&format!("c{i}"), "echo", json!({ "text": "x" })))
1592 .collect();
1593 let provider = MockProvider::new(vec![runaway]);
1594 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1595 let model = Model::new("mock/test");
1596 let mut messages = vec![user("call many tools")];
1597 let config = RunConfig {
1598 max_tool_calls_per_turn: 2,
1599 ..RunConfig::default()
1600 };
1601
1602 let result = run_agent(
1603 &provider,
1604 &tools,
1605 &mut messages,
1606 &model,
1607 &config,
1608 &CancellationToken::new(),
1609 &RunHooks::default(),
1610 )
1611 .await;
1612
1613 assert!(result.is_err(), "runaway tool calls must be rejected");
1614 let err = result.unwrap_err().to_string();
1615 assert!(err.contains("max"), "error should mention the cap: {err}");
1616 }
1617
1618 struct OrderingTool {
1620 name: String,
1621 delay_ms: u64,
1622 log: Arc<std::sync::Mutex<Vec<String>>>,
1623 }
1624
1625 #[async_trait]
1626 impl Tool for OrderingTool {
1627 fn definition(&self) -> crate::tool::ToolDefinition {
1628 crate::tool::ToolDefinition {
1629 name: self.name.clone(),
1630 label: "Ordering".into(),
1631 description: "Records completion order.".into(),
1632 parameters: crate::tool::ParameterSchema::default(),
1633 }
1634 }
1635
1636 async fn execute(&self, _ctx: InvokeContext, input: Value) -> Result<ToolResult> {
1637 tokio::time::sleep(std::time::Duration::from_millis(self.delay_ms)).await;
1638 self.log.lock().unwrap().push(
1639 input
1640 .get("tag")
1641 .and_then(Value::as_str)
1642 .unwrap_or("?")
1643 .to_string(),
1644 );
1645 Ok(ToolResult {
1646 content: vec![json!({ "type": "text", "text": "ok" })],
1647 details: None,
1648 })
1649 }
1650 }
1651
1652 #[tokio::test]
1653 async fn parallel_tool_calls_preserve_result_order() {
1654 let log = Arc::new(std::sync::Mutex::new(Vec::new()));
1658 let tools: Vec<Arc<dyn Tool>> = vec![
1659 Arc::new(OrderingTool {
1660 name: "slow".into(),
1661 delay_ms: 60,
1662 log: log.clone(),
1663 }),
1664 Arc::new(OrderingTool {
1665 name: "fast".into(),
1666 delay_ms: 5,
1667 log: log.clone(),
1668 }),
1669 ];
1670 let turn = vec![
1671 assistant_tool_use("c1", "slow", json!({ "tag": "slow" })),
1672 assistant_tool_use("c2", "fast", json!({ "tag": "fast" })),
1673 ];
1674 let provider = MockProvider::new(vec![turn, vec![assistant_text("done")]]);
1675 let model = Model::new("mock/test");
1676 let mut messages = vec![user("call both")];
1677 let config = RunConfig {
1678 tool_concurrency: 4,
1679 ..RunConfig::default()
1680 };
1681
1682 let outcome = run_agent(
1683 &provider,
1684 &tools,
1685 &mut messages,
1686 &model,
1687 &config,
1688 &CancellationToken::new(),
1689 &RunHooks::default(),
1690 )
1691 .await
1692 .expect("loop should complete");
1693
1694 assert_eq!(outcome.final_text, "done");
1695
1696 let completed = log.lock().unwrap().clone();
1699 assert_eq!(
1700 completed,
1701 vec!["fast", "slow"],
1702 "tools must have run concurrently: {completed:?}"
1703 );
1704
1705 let tool_ids: Vec<String> = messages
1707 .iter()
1708 .filter(|m| m.role == Role::Tool)
1709 .filter_map(|m| match &m.content[0] {
1710 ContentBlock::ToolResult { tool_use_id, .. } => Some(tool_use_id.clone()),
1711 _ => None,
1712 })
1713 .collect();
1714 assert_eq!(
1715 tool_ids,
1716 vec!["c1", "c2"],
1717 "results must be appended in issued order: {tool_ids:?}"
1718 );
1719 }
1720
1721 struct PanickingTool;
1724
1725 #[async_trait]
1726 impl Tool for PanickingTool {
1727 fn definition(&self) -> crate::tool::ToolDefinition {
1728 crate::tool::ToolDefinition {
1729 name: "boom".into(),
1730 label: "Boom".into(),
1731 description: "Always panics.".into(),
1732 parameters: crate::tool::ParameterSchema::default(),
1733 }
1734 }
1735
1736 async fn execute(&self, _ctx: InvokeContext, _input: Value) -> Result<ToolResult> {
1737 panic!("deliberate tool panic");
1738 }
1739 }
1740
1741 #[tokio::test]
1742 async fn parallel_path_survives_a_task_panic() {
1743 let turn = vec![
1746 assistant_tool_use("c1", "boom", json!({})),
1747 assistant_tool_use("c2", "echo", json!({ "text": "survived" })),
1748 ];
1749 let provider = MockProvider::new(vec![turn, vec![assistant_text("done")]]);
1750 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(PanickingTool), Arc::new(EchoTool)];
1751 let model = Model::new("mock/test");
1752 let mut messages = vec![user("call both")];
1753 let config = RunConfig {
1754 tool_concurrency: 4,
1755 ..RunConfig::default()
1756 };
1757
1758 let outcome = run_agent(
1760 &provider,
1761 &tools,
1762 &mut messages,
1763 &model,
1764 &config,
1765 &CancellationToken::new(),
1766 &RunHooks::default(),
1767 )
1768 .await
1769 .expect("loop must survive a tool panic");
1770
1771 assert_eq!(outcome.final_text, "done");
1772 let tool_ids: Vec<String> = messages
1774 .iter()
1775 .filter(|m| m.role == Role::Tool)
1776 .filter_map(|m| match &m.content[0] {
1777 ContentBlock::ToolResult { tool_use_id, .. } => Some(tool_use_id.clone()),
1778 _ => None,
1779 })
1780 .collect();
1781 assert_eq!(
1782 tool_ids,
1783 vec!["c1", "c2"],
1784 "both results must be present despite the panic: {tool_ids:?}"
1785 );
1786 }
1787
1788 #[tokio::test]
1794 async fn sequential_path_survives_a_tool_panic() {
1795 let turn = vec![
1796 assistant_tool_use("c1", "boom", json!({})),
1797 assistant_tool_use("c2", "echo", json!({ "text": "survived" })),
1798 ];
1799 let provider = MockProvider::new(vec![turn, vec![assistant_text("done")]]);
1800 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(PanickingTool), Arc::new(EchoTool)];
1801 let model = Model::new("mock/test");
1802 let mut messages = vec![user("call both")];
1803 let config = RunConfig::default();
1805
1806 let outcome = run_agent(
1808 &provider,
1809 &tools,
1810 &mut messages,
1811 &model,
1812 &config,
1813 &CancellationToken::new(),
1814 &RunHooks::default(),
1815 )
1816 .await
1817 .expect("sequential path must survive a tool panic");
1818 assert_eq!(outcome.final_text, "done");
1819
1820 let results: Vec<&ContentBlock> = messages
1822 .iter()
1823 .filter(|m| m.role == Role::Tool)
1824 .flat_map(|m| m.content.iter())
1825 .collect();
1826 assert_eq!(results.len(), 2, "both results appended");
1827 let c1_str = match &results[0] {
1829 ContentBlock::ToolResult { content, .. } => content.to_string(),
1830 _ => String::new(),
1831 };
1832 assert!(
1833 c1_str.contains("Error:"),
1834 "panic must surface as an Error: result, got: {c1_str}"
1835 );
1836 assert!(
1837 c1_str.contains("panicked"),
1838 "error result should mention the panic: {c1_str}"
1839 );
1840 }
1841
1842 #[tokio::test]
1848 async fn parallel_path_panic_preserves_call_id_and_summary() {
1849 let turn = vec![
1851 assistant_tool_use("c1", "boom", json!({})),
1852 assistant_tool_use("c2", "echo", json!({ "text": "ok" })),
1853 ];
1854 let provider = MockProvider::new(vec![turn, vec![assistant_text("done")]]);
1855 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(PanickingTool), Arc::new(EchoTool)];
1856 let model = Model::new("mock/test");
1857 let mut messages = vec![user("call both")];
1858 let config = RunConfig {
1859 tool_concurrency: 4,
1860 ..RunConfig::default()
1861 };
1862
1863 let outcome = run_agent(
1864 &provider,
1865 &tools,
1866 &mut messages,
1867 &model,
1868 &config,
1869 &CancellationToken::new(),
1870 &RunHooks::default(),
1871 )
1872 .await
1873 .expect("run survives parallel panic");
1874 assert_eq!(outcome.final_text, "done");
1875
1876 let tool_msgs: Vec<(&String, String)> = messages
1879 .iter()
1880 .filter(|m| m.role == Role::Tool)
1881 .flat_map(|m| m.content.iter())
1882 .filter_map(|b| match b {
1883 ContentBlock::ToolResult {
1884 tool_use_id,
1885 content,
1886 } => Some((tool_use_id, content.to_string())),
1887 _ => None,
1888 })
1889 .collect();
1890 assert_eq!(tool_msgs.len(), 2, "both results present");
1891 assert_eq!(tool_msgs[0].0, "c1", "c1 attributed correctly");
1893 assert_eq!(tool_msgs[1].0, "c2", "c2 attributed correctly");
1894 assert!(
1896 tool_msgs[0].1.contains("panicked"),
1897 "parallel panic should carry bounded summary, got: {}",
1898 tool_msgs[0].1
1899 );
1900 assert!(
1901 tool_msgs[0].1.contains("Error:"),
1902 "should be an Error: result, got: {}",
1903 tool_msgs[0].1
1904 );
1905 }
1906
1907 #[tokio::test]
1914 async fn parallel_path_keeps_all_results_under_throttling() {
1915 let turn = vec![
1919 assistant_tool_use("c1", "echo", json!({ "text": "one" })),
1920 assistant_tool_use("c2", "echo", json!({ "text": "two" })),
1921 assistant_tool_use("c3", "echo", json!({ "text": "three" })),
1922 ];
1923 let provider = MockProvider::new(vec![turn, vec![assistant_text("done")]]);
1924 let tools: Vec<Arc<dyn Tool>> = vec![Arc::new(EchoTool)];
1925 let model = Model::new("mock/test");
1926 let mut messages = vec![user("call all three")];
1927 let config = RunConfig {
1928 tool_concurrency: 2,
1929 ..RunConfig::default()
1930 };
1931
1932 let outcome = run_agent(
1933 &provider,
1934 &tools,
1935 &mut messages,
1936 &model,
1937 &config,
1938 &CancellationToken::new(),
1939 &RunHooks::default(),
1940 )
1941 .await
1942 .expect("run completes");
1943 assert_eq!(outcome.final_text, "done");
1944
1945 let results: Vec<String> = messages
1949 .iter()
1950 .filter(|m| m.role == Role::Tool)
1951 .flat_map(|m| m.content.iter())
1952 .filter_map(|b| match b {
1953 ContentBlock::ToolResult {
1954 tool_use_id,
1955 content,
1956 } => {
1957 let text = content
1958 .get("content")
1959 .and_then(|c| c.get(0))
1960 .and_then(|c| c.get("text"))
1961 .and_then(|t| t.as_str())
1962 .unwrap_or("<missing>");
1963 Some(format!("{tool_use_id}={text}"))
1964 }
1965 _ => None,
1966 })
1967 .collect();
1968 assert_eq!(
1969 results,
1970 vec!["c1=echo: one", "c2=echo: two", "c3=echo: three"],
1971 "all 3 results must survive throttling, in order, with correct text: {results:?}"
1972 );
1973 }
1974
/// An owned `String` panic payload is summarized as its own text.
#[test]
fn summarize_panic_handles_string_payloads() {
    let payload: Box<dyn std::any::Any + Send> = Box::new(String::from("boom!"));
    let summary = summarize_panic(&payload);
    assert_eq!(summary, "boom!");
}
1980
/// A `&'static str` panic payload is summarized as its own text.
#[test]
fn summarize_panic_handles_str_payloads() {
    let literal: &'static str = "static boom";
    let payload: Box<dyn std::any::Any + Send> = Box::new(literal);
    let summary = summarize_panic(&payload);
    assert_eq!(summary, "static boom");
}
1987
/// A 10,000-character payload must be cut down to at most
/// `PANIC_SUMMARY_MAX_CHARS` characters and end with an ellipsis.
#[test]
fn summarize_panic_bounds_huge_payloads() {
    let payload: Box<dyn std::any::Any + Send> = Box::new("x".repeat(10_000));
    let summary = summarize_panic(&payload);

    // Length is measured in chars, not bytes.
    assert!(
        summary.chars().count() <= PANIC_SUMMARY_MAX_CHARS,
        "summary not bounded: {} chars",
        summary.chars().count()
    );
    assert!(
        summary.ends_with('…'),
        "should end with ellipsis: {summary}"
    );
}
2003
/// A payload that is neither `String` nor `&str` (here an `i32`) must yield
/// a summary containing the "non-string" marker.
#[test]
fn summarize_panic_falls_back_for_non_string_payloads() {
    let payload: Box<dyn std::any::Any + Send> = Box::new(42_i32);
    let summary = summarize_panic(&payload);
    assert!(
        summary.contains("non-string"),
        "expected fallback marker: {summary}"
    );
}
2013
2014 use crate::event::{EventSink, RunEvent};
2017 use std::sync::{Arc, Mutex};
2018 use uuid::Uuid;
2019
2020 struct RecordingSink {
2022 events: Arc<Mutex<Vec<RunEvent>>>,
2023 }
2024
2025 impl EventSink for RecordingSink {
2026 fn emit(&self, event: RunEvent) {
2027 self.events.lock().expect("lock poisoned").push(event);
2028 }
2029 }
2030
2031 #[tokio::test]
2032 async fn text_only_run_emits_complete_event_sequence() {
2033 let provider = MockProvider::new(vec![vec![assistant_text("hello")]]);
2034 let tools: Vec<Arc<dyn Tool>> = vec![];
2035 let model = Model::new("mock/test");
2036 let mut messages = vec![user("hi")];
2037 let sink = Arc::new(Mutex::new(Vec::new()));
2038 let hooks = RunHooks {
2039 session_id: Some(Uuid::nil()),
2040 turn_sink: None,
2041 event_sink: Some(&RecordingSink {
2042 events: sink.clone(),
2043 } as &dyn EventSink),
2044 policy: None,
2045 };
2046
2047 run_agent(
2048 &provider,
2049 &tools,
2050 &mut messages,
2051 &model,
2052 &RunConfig::default(),
2053 &CancellationToken::new(),
2054 &hooks,
2055 )
2056 .await
2057 .expect("run");
2058
2059 let events = sink.lock().expect("lock poisoned").clone();
2060 assert!(events
2062 .iter()
2063 .any(|e| matches!(e, RunEvent::SessionStarted { .. })));
2064 assert!(events
2065 .iter()
2066 .any(|e| matches!(e, RunEvent::TurnStarted { turn: 1, .. })));
2067 assert!(events.iter().any(
2068 |e| matches!(e, RunEvent::ModelStarted { turn: 1, model, .. } if model == "mock/test")
2069 ));
2070 assert!(events
2071 .iter()
2072 .any(|e| matches!(e, RunEvent::ModelFinished { turn: 1, .. })));
2073 assert!(events
2074 .iter()
2075 .any(|e| matches!(e, RunEvent::TurnFinished { turn: 1, .. })));
2076 assert!(!events
2078 .iter()
2079 .any(|e| matches!(e, RunEvent::ToolStarted { .. })));
2080 }
2081
2082 #[tokio::test]
2083 async fn tool_run_emits_tool_started_finished() {
2084 let echo_tool = Arc::new(EchoTool) as Arc<dyn Tool>;
2085 let tools = vec![echo_tool.clone()];
2086 let provider = MockProvider::new(vec![
2087 vec![assistant_tool_use(
2088 "call-1",
2089 "echo",
2090 json!({ "text": "hi" }),
2091 )],
2092 vec![assistant_text("done")],
2093 ]);
2094 let model = Model::new("mock/test");
2095 let mut messages = vec![user("echo hi")];
2096 let sink = Arc::new(Mutex::new(Vec::new()));
2097 let hooks = RunHooks {
2098 session_id: Some(Uuid::nil()),
2099 turn_sink: None,
2100 event_sink: Some(&RecordingSink {
2101 events: sink.clone(),
2102 } as &dyn EventSink),
2103 policy: None,
2104 };
2105
2106 run_agent(
2107 &provider,
2108 &tools,
2109 &mut messages,
2110 &model,
2111 &RunConfig::default(),
2112 &CancellationToken::new(),
2113 &hooks,
2114 )
2115 .await
2116 .expect("run");
2117
2118 let events = sink.lock().expect("lock poisoned").clone();
2119 assert!(
2121 events.iter().any(|e| matches!(e, RunEvent::ToolStarted { turn: 1, tool, call_id, .. } if tool == "echo" && call_id == "call-1")),
2122 "missing ToolStarted for echo/call-1"
2123 );
2124 assert!(
2125 events.iter().any(|e| matches!(e, RunEvent::ToolFinished { turn: 1, tool, call_id, ok: true, .. } if tool == "echo" && call_id == "call-1")),
2126 "missing ToolFinished for echo/call-1"
2127 );
2128 assert!(events
2130 .iter()
2131 .any(|e| matches!(e, RunEvent::TurnFinished { turn: 2, .. })));
2132 }
2133
2134 #[tokio::test]
2135 async fn no_events_when_session_id_is_none() {
2136 let provider = MockProvider::new(vec![vec![assistant_text("hello")]]);
2137 let tools: Vec<Arc<dyn Tool>> = vec![];
2138 let model = Model::new("mock/test");
2139 let mut messages = vec![user("hi")];
2140 let sink = Arc::new(Mutex::new(Vec::new()));
2141 let hooks = RunHooks {
2142 session_id: None, turn_sink: None,
2144 event_sink: Some(&RecordingSink {
2145 events: sink.clone(),
2146 } as &dyn EventSink),
2147 policy: None,
2148 };
2149
2150 run_agent(
2151 &provider,
2152 &tools,
2153 &mut messages,
2154 &model,
2155 &RunConfig::default(),
2156 &CancellationToken::new(),
2157 &hooks,
2158 )
2159 .await
2160 .expect("run");
2161
2162 assert!(
2163 sink.lock().expect("lock poisoned").is_empty(),
2164 "events emitted with no session_id"
2165 );
2166 }
2167}