1use crate::agent::extension::{Cancel, Extension, ToolOutput};
2use crate::agent::provider::{Provider, StopReason, StreamEvent, ToolDef};
3use crate::agent::types::{AgentMessage, PendingMessageQueue, Role, ToolCall, ToolExecutionMode};
4use futures::future::join_all;
5
6pub fn collect_tool_defs(extensions: &[Box<dyn Extension>]) -> Vec<ToolDef> {
8 let mut defs = Vec::new();
9 for ext in extensions {
10 for tool in ext.tools() {
11 if !defs.iter().any(|d: &ToolDef| d.name == tool.name()) {
12 defs.push(ToolDef {
13 name: tool.name().to_string(),
14 description: tool.description().to_string(),
15 parameters: tool.parameters(),
16 });
17 }
18 }
19 }
20 defs
21}
22
23#[derive(Debug, Clone)]
25#[allow(dead_code)]
26pub enum AgentEvent {
27 AgentStart,
28 TurnStart,
29 TextDelta {
30 delta: String,
31 },
32 ThinkingDelta {
33 delta: String,
34 },
35 ToolCall {
36 id: String,
37 name: String,
38 args: serde_json::Value,
39 },
40 ToolCallArgsUpdate {
42 id: String,
43 args: serde_json::Value,
44 },
45 ToolResult {
46 id: String,
47 name: String,
48 content: String,
49 compact: Option<String>,
50 is_error: bool,
51 },
52 ToolProgress {
54 content: String,
55 is_error: bool,
56 },
57 Aborted {
59 reason: String,
60 },
61 UserMessage {
63 content: String,
64 },
65 TurnEnd,
66 AgentEnd {
67 messages: Vec<AgentMessage>,
68 },
69}
70
71pub type TransformFn = Box<dyn Fn(&[AgentMessage]) -> Vec<AgentMessage> + Send + Sync>;
73
74pub type PrepareNextTurnFn = Box<dyn Fn(&[AgentMessage]) -> Option<TurnUpdate> + Send + Sync>;
76
77pub type ShouldStopFn = Box<dyn Fn(&[AgentMessage]) -> bool + Send + Sync>;
79
80pub struct TurnUpdate {
82 pub context: Option<Vec<AgentMessage>>,
84}
85
86pub struct LoopConfig<'a> {
88 pub model: String,
89 pub system_prompt: String,
90 pub tools: Vec<ToolDef>,
91 pub agent_tools: &'a [Box<dyn crate::agent::extension::AgentTool>],
92 pub extensions: &'a [Box<dyn Extension>],
93 pub tool_execution: ToolExecutionMode,
95 pub steering_queue: Option<&'a std::sync::Mutex<PendingMessageQueue>>,
98 pub follow_up_queue: Option<&'a std::sync::Mutex<PendingMessageQueue>>,
101 pub transform_context: Option<TransformFn>,
105 pub prepare_next_turn: Option<PrepareNextTurnFn>,
109 pub should_stop_after_turn: Option<ShouldStopFn>,
113}
114
115fn find_tool<'a>(
117 tools: &'a [Box<dyn crate::agent::extension::AgentTool>],
118 name: &str,
119) -> Option<&'a dyn crate::agent::extension::AgentTool> {
120 tools.iter().find(|t| t.name() == name).map(|t| t.as_ref())
121}
122
/// Upper bound on provider calls within a single `run_agent_loop` invocation.
/// The counter spans every round of the run, follow-ups included, and is never reset.
const MAX_TOOL_ITERATIONS: usize = 25;
125
/// Result of running (or refusing to run) one tool call.
struct ToolExecOutcome {
    /// Id of the originating tool call.
    id: String,
    /// Name of the tool that was requested.
    name: String,
    /// Result text, or an error/blocked message.
    content: String,
    /// Optional shortened rendering supplied by the tool.
    compact: Option<String>,
    /// Whether `content` describes a failure.
    is_error: bool,
    /// Whether the tool asked the current round to end.
    terminate: bool,
}
136
137pub async fn run_agent_loop(
143 prompts: Vec<AgentMessage>,
144 history: Vec<AgentMessage>,
145 config: &LoopConfig<'_>,
146 provider: &dyn Provider,
147 emit: &mut (dyn FnMut(AgentEvent) + Send),
148) -> anyhow::Result<Vec<AgentMessage>> {
149 let mut messages: Vec<AgentMessage> = Vec::new();
150 messages.extend(history);
151 messages.extend(prompts.clone());
152
153 let mut new_messages: Vec<AgentMessage> = prompts.clone();
154
155 emit(AgentEvent::AgentStart);
156 emit(AgentEvent::TurnStart);
157
158 let mut iteration_count: usize = 0;
159
160 loop {
163 let mut has_more_tool_calls = true;
165
166 while has_more_tool_calls {
167 iteration_count += 1;
168 if iteration_count > MAX_TOOL_ITERATIONS {
169 let msg = format!(
170 "Agent loop exceeded maximum iterations ({}). Last response may be incomplete.",
171 MAX_TOOL_ITERATIONS
172 );
173 emit(AgentEvent::Aborted {
174 reason: msg.clone(),
175 });
176 emit(AgentEvent::AgentEnd {
177 messages: new_messages.clone(),
178 });
179 return Ok(new_messages);
180 }
181
182 drain_steering(config, &mut messages, &mut new_messages, emit);
186
187 let llm_messages: &[AgentMessage] = &messages;
191 let _transformed_holder;
192 let llm_messages = if let Some(ref transform) = config.transform_context {
193 _transformed_holder = transform(llm_messages);
194 &_transformed_holder
195 } else {
196 llm_messages
197 };
198 let mut stream = provider
199 .stream(
200 &config.model,
201 &config.system_prompt,
202 llm_messages,
203 &config.tools,
204 )
205 .await?;
206
207 let mut response_text = String::new();
209 let mut tool_calls: Vec<ToolCall> = Vec::new();
210 let mut stop_reason = StopReason::EndTurn;
211
212 while let Some(event) = futures::StreamExt::next(&mut stream).await {
213 match event {
214 StreamEvent::TextDelta { text } => {
215 response_text.push_str(&text);
216 emit(AgentEvent::TextDelta { delta: text });
217 }
218 StreamEvent::ThinkingDelta { text } => {
219 emit(AgentEvent::ThinkingDelta { delta: text });
220 }
221 StreamEvent::ToolCall {
222 id,
223 name,
224 arguments,
225 } => {
226 let args: serde_json::Value = serde_json::from_str(&arguments)
227 .unwrap_or(serde_json::Value::String(arguments.clone()));
228
229 if let Some(existing) = tool_calls.iter_mut().find(|tc| tc.id == id) {
230 existing.arguments = args;
231 } else {
232 tool_calls.push(ToolCall {
233 id,
234 name,
235 arguments: args,
236 });
237 }
238 }
239 StreamEvent::Done {
240 text,
241 stop_reason: sr,
242 tool_calls: tcs,
243 ..
244 } => {
245 if response_text.is_empty() && !text.is_empty() {
246 emit(AgentEvent::TextDelta {
247 delta: text.clone(),
248 });
249 }
250 response_text = text;
251 stop_reason = sr;
252 if !tcs.is_empty() {
253 tool_calls = tcs;
254 }
255 }
256 StreamEvent::Error { message } => {
257 emit(AgentEvent::Aborted {
258 reason: message.clone(),
259 });
260 emit(AgentEvent::ToolResult {
261 id: String::new(),
262 name: String::new(),
263 content: message.clone(),
264 compact: None,
265 is_error: true,
266 });
267 let error_msg =
268 AgentMessage::tool_result(String::new(), message.clone(), true);
269 new_messages.push(error_msg);
270 emit(AgentEvent::AgentEnd {
271 messages: new_messages.clone(),
272 });
273 return Ok(new_messages);
274 }
275 }
276 }
277
278 let assistant_msg = AgentMessage {
280 id: uuid::Uuid::new_v4().to_string(),
281 parent_id: None,
282 role: Role::Assistant,
283 content: response_text.clone(),
284 tool_calls: tool_calls.clone(),
285 tool_call_id: None,
286 usage: None,
287 is_error: false,
288 timestamp: chrono::Utc::now().timestamp_millis(),
289 };
290
291 messages.push(assistant_msg.clone());
292 new_messages.push(assistant_msg);
293
294 if stop_reason == StopReason::Error {
296 emit(AgentEvent::AgentEnd {
297 messages: new_messages.clone(),
298 });
299 return Ok(new_messages);
300 }
301
302 if !tool_calls.is_empty() {
304 let has_sequential_tool = tool_calls.iter().any(|tc| {
307 config
308 .agent_tools
309 .iter()
310 .find(|t| t.name() == tc.name)
311 .map(|t| t.execution_mode() == ToolExecutionMode::Sequential)
312 .unwrap_or(false)
313 });
314
315 let effective_mode = if has_sequential_tool {
316 ToolExecutionMode::Sequential
317 } else {
318 config.tool_execution
319 };
320
321 let outcomes = match effective_mode {
322 ToolExecutionMode::Parallel => {
323 execute_tool_calls_parallel(&tool_calls, config, emit).await
324 }
325 ToolExecutionMode::Sequential => {
326 execute_tool_calls_sequential(&tool_calls, config, emit).await
327 }
328 };
329
330 let all_terminate = !outcomes.is_empty() && outcomes.iter().all(|o| o.terminate);
331
332 for outcome in outcomes {
333 let msg =
334 AgentMessage::tool_result(&outcome.id, &outcome.content, outcome.is_error);
335 emit(AgentEvent::ToolResult {
336 id: outcome.id,
337 name: outcome.name,
338 content: outcome.content,
339 compact: outcome.compact,
340 is_error: outcome.is_error,
341 });
342 messages.push(msg.clone());
343 new_messages.push(msg);
344 }
345
346 apply_prepare_next_turn(config, &mut messages, &new_messages);
349
350 if all_terminate {
351 emit(AgentEvent::TurnEnd);
353 break;
354 }
355
356 continue;
358 }
359
360 has_more_tool_calls = false;
362 emit(AgentEvent::TurnEnd);
363
364 apply_prepare_next_turn(config, &mut messages, &new_messages);
366
367 if apply_should_stop_after_turn(config, &new_messages) {
369 emit(AgentEvent::AgentEnd {
370 messages: new_messages.clone(),
371 });
372 return Ok(new_messages);
373 }
374 }
375
376 if !drain_follow_up(config, &mut messages, &mut new_messages, emit) {
379 break;
380 }
381 }
382
383 emit(AgentEvent::AgentEnd {
384 messages: new_messages.clone(),
385 });
386 Ok(new_messages)
387}
388
389fn drain_steering(
391 config: &LoopConfig<'_>,
392 messages: &mut Vec<AgentMessage>,
393 new_messages: &mut Vec<AgentMessage>,
394 emit: &mut (dyn FnMut(AgentEvent) + Send),
395) -> bool {
396 let Some(queue) = config.steering_queue else {
397 return false;
398 };
399 let drained = queue.lock().unwrap().drain();
400 if drained.is_empty() {
401 return false;
402 }
403 for msg in drained {
404 emit(AgentEvent::UserMessage {
405 content: msg.content.clone(),
406 });
407 messages.push(msg.clone());
408 new_messages.push(msg);
409 }
410 true
411}
412
413fn drain_follow_up(
416 config: &LoopConfig<'_>,
417 messages: &mut Vec<AgentMessage>,
418 new_messages: &mut Vec<AgentMessage>,
419 emit: &mut (dyn FnMut(AgentEvent) + Send),
420) -> bool {
421 let Some(queue) = config.follow_up_queue else {
422 return false;
423 };
424 let drained = queue.lock().unwrap().drain();
425 if drained.is_empty() {
426 return false;
427 }
428 for msg in drained {
429 emit(AgentEvent::UserMessage {
430 content: msg.content.clone(),
431 });
432 messages.push(msg.clone());
433 new_messages.push(msg);
434 }
435 true
436}
437
438fn apply_prepare_next_turn(
441 config: &LoopConfig<'_>,
442 messages: &mut Vec<AgentMessage>,
443 new_messages: &[AgentMessage],
444) {
445 if let Some(ref prepare) = config.prepare_next_turn
446 && let Some(update) = prepare(new_messages)
447 && let Some(ctx) = update.context
448 {
449 *messages = ctx;
450 }
451}
452
453fn apply_should_stop_after_turn(config: &LoopConfig<'_>, new_messages: &[AgentMessage]) -> bool {
456 config
457 .should_stop_after_turn
458 .as_ref()
459 .map(|stop| stop(new_messages))
460 .unwrap_or(false)
461}
462
463async fn execute_tool_calls_sequential(
465 tool_calls: &[ToolCall],
466 config: &LoopConfig<'_>,
467 emit: &mut (dyn FnMut(AgentEvent) + Send),
468) -> Vec<ToolExecOutcome> {
469 let mut outcomes = Vec::new();
470
471 for tc in tool_calls {
472 emit(AgentEvent::ToolCall {
473 id: tc.id.clone(),
474 name: tc.name.clone(),
475 args: tc.arguments.clone(),
476 });
477
478 let mut blocked = false;
480 for ext in config.extensions {
481 if let Some(reason) = ext.before_tool_call(tc).await {
482 outcomes.push(ToolExecOutcome {
483 id: tc.id.clone(),
484 name: tc.name.clone(),
485 content: format!("Tool execution blocked: {:?}", reason),
486 compact: None,
487 is_error: true,
488 terminate: false,
489 });
490 blocked = true;
491 break;
492 }
493 }
494 if blocked {
495 continue;
496 }
497
498 let outcome = execute_single_tool(
500 tc,
501 config.agent_tools,
502 config.extensions,
503 None, )
505 .await;
506 outcomes.push(outcome);
507 }
508
509 outcomes
510}
511
512async fn execute_tool_calls_parallel(
517 tool_calls: &[ToolCall],
518 config: &LoopConfig<'_>,
519 emit: &mut (dyn FnMut(AgentEvent) + Send),
520) -> Vec<ToolExecOutcome> {
521 let mut outcomes: Vec<ToolExecOutcome> = Vec::with_capacity(tool_calls.len());
522 let mut futures: Vec<
523 std::pin::Pin<Box<dyn std::future::Future<Output = ToolExecOutcome> + Send + '_>>,
524 > = Vec::new();
525
526 for tc in tool_calls {
535 emit(AgentEvent::ToolCall {
536 id: tc.id.clone(),
537 name: tc.name.clone(),
538 args: tc.arguments.clone(),
539 });
540
541 let mut blocked = false;
542 for ext in config.extensions {
543 if let Some(reason) = ext.before_tool_call(tc).await {
544 outcomes.push(ToolExecOutcome {
545 id: tc.id.clone(),
546 name: tc.name.clone(),
547 content: format!("Tool execution blocked: {:?}", reason),
548 compact: None,
549 is_error: true,
550 terminate: false,
551 });
552 blocked = true;
553 break;
554 }
555 }
556 if blocked {
557 continue;
558 }
559
560 let tc_clone = tc.clone();
565 futures.push(Box::pin(async move {
566 execute_single_tool(
567 &tc_clone,
568 config.agent_tools,
569 config.extensions,
570 None, )
572 .await
573 }));
574 }
575
576 if !futures.is_empty() {
578 let results = join_all(futures).await;
579 outcomes.extend(results);
580 }
581
582 outcomes
583}
584
585async fn execute_single_tool(
589 tc: &ToolCall,
590 agent_tools: &[Box<dyn crate::agent::extension::AgentTool>],
591 extensions: &[Box<dyn Extension>],
592 progress_tx: Option<tokio::sync::mpsc::UnboundedSender<AgentEvent>>,
593) -> ToolExecOutcome {
594 let cancel = Cancel::new();
595
596 if let Some(tool) = find_tool(agent_tools, &tc.name) {
597 let args = tool.prepare_arguments(tc.arguments.clone());
599
600 let on_update = progress_tx.as_ref().map(|_| {
603 let (tool_tx, mut tool_rx) = tokio::sync::mpsc::unbounded_channel::<ToolOutput>();
604 if let Some(ref tx) = progress_tx {
605 let tx = tx.clone();
606 tokio::spawn(async move {
607 while let Some(output) = tool_rx.recv().await {
608 let _ = tx.send(AgentEvent::ToolProgress {
609 content: output.content,
610 is_error: output.is_error,
611 });
612 }
613 });
614 }
615 tool_tx
616 });
617
618 match tool.execute(tc.id.clone(), args, cancel, on_update).await {
619 Ok(output) => {
620 let mut final_result = output.content.clone();
622 for ext in extensions {
623 if let Some(overridden) = ext.after_tool_call(tc, &final_result).await {
624 final_result = overridden;
625 }
626 }
627
628 ToolExecOutcome {
629 id: tc.id.clone(),
630 name: tc.name.clone(),
631 content: final_result,
632 compact: output.compact,
633 is_error: false,
634 terminate: output.terminate,
635 }
636 }
637 Err(e) => ToolExecOutcome {
638 id: tc.id.clone(),
639 name: tc.name.clone(),
640 content: format!("{:#}", e),
641 compact: None,
642 is_error: true,
643 terminate: false,
644 },
645 }
646 } else {
647 ToolExecOutcome {
648 id: tc.id.clone(),
649 name: tc.name.clone(),
650 content: format!("Tool '{}' not found", tc.name),
651 compact: None,
652 is_error: true,
653 terminate: false,
654 }
655 }
656}
657
658#[cfg(test)]
659mod tests {
660 use super::*;
661 use crate::agent::extension::{AgentTool, BlockReason, Cancel, ToolOutput};
662 use crate::agent::provider::StreamEvent;
663 use crate::agent::types::{
664 AgentMessage, PendingMessageQueue, QueueMode, Role, ToolCall, ToolExecutionMode,
665 };
666 use async_trait::async_trait;
667 use futures::Stream;
668 use std::pin::Pin;
669 use std::sync::Arc;
670
671 struct MockProvider {
673 responses: Arc<std::sync::Mutex<Vec<MockResponse>>>,
674 sent_messages: Arc<std::sync::Mutex<Vec<Vec<AgentMessage>>>>,
676 }
677
678 struct MockResponse {
679 text: String,
680 tool_calls: Vec<ToolCall>,
681 stop_reason: StopReason,
682 thinking: String,
683 }
684
685 impl MockProvider {
686 fn new() -> Self {
687 Self {
688 responses: Arc::new(std::sync::Mutex::new(Vec::new())),
689 sent_messages: Arc::new(std::sync::Mutex::new(Vec::new())),
690 }
691 }
692
693 fn add_response(&self, text: &str) {
694 self.responses.lock().unwrap().push(MockResponse {
695 text: text.to_string(),
696 tool_calls: vec![],
697 stop_reason: StopReason::EndTurn,
698 thinking: String::new(),
699 });
700 }
701
702 fn add_tool_call_response(&self, text: &str, tool_calls: Vec<ToolCall>) {
703 self.responses.lock().unwrap().push(MockResponse {
704 text: text.to_string(),
705 tool_calls,
706 stop_reason: StopReason::ToolUse,
707 thinking: String::new(),
708 });
709 }
710
711 #[allow(dead_code)]
712 fn sent_message_count(&self) -> usize {
713 self.sent_messages.lock().unwrap().len()
714 }
715
716 #[allow(dead_code)]
717 fn last_sent_message_count(&self) -> usize {
718 let msgs = self.sent_messages.lock().unwrap();
719 msgs.last().map(|m| m.len()).unwrap_or(0)
720 }
721 }
722
723 #[async_trait]
724 impl Provider for MockProvider {
725 async fn stream(
726 &self,
727 _model: &str,
728 _system: &str,
729 messages: &[AgentMessage],
730 _tools: &[ToolDef],
731 ) -> anyhow::Result<Pin<Box<dyn Stream<Item = StreamEvent> + Send>>> {
732 self.sent_messages.lock().unwrap().push(messages.to_vec());
734
735 let mut resp = self.responses.lock().unwrap();
736 let response = if resp.is_empty() {
737 MockResponse {
739 text: String::new(),
740 tool_calls: vec![],
741 stop_reason: StopReason::EndTurn,
742 thinking: String::new(),
743 }
744 } else {
745 resp.remove(0)
746 };
747 drop(resp);
748
749 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
750
751 if !response.thinking.is_empty() {
753 let _ = tx.send(StreamEvent::ThinkingDelta {
754 text: response.thinking.clone(),
755 });
756 }
757
758 if !response.text.is_empty() {
760 let _ = tx.send(StreamEvent::TextDelta {
761 text: response.text.clone(),
762 });
763 }
764
765 let _ = tx.send(StreamEvent::Done {
767 text: response.text,
768 usage: crate::agent::types::Usage::default(),
769 stop_reason: response.stop_reason,
770 tool_calls: response.tool_calls,
771 });
772
773 use futures::stream::unfold;
775 let stream = unfold(rx, |mut rx| async move {
776 rx.recv().await.map(|event| (event, rx))
777 });
778 Ok(Box::pin(stream))
779 }
780 }
781
782 struct MockTool {
784 name: String,
785 execution_mode: ToolExecutionMode,
786 execute_delay: std::time::Duration,
787 executed: Arc<std::sync::Mutex<Vec<String>>>,
788 terminate: bool,
789 }
790
791 impl MockTool {
792 fn new(name: &str) -> Self {
793 Self {
794 name: name.to_string(),
795 execution_mode: ToolExecutionMode::Parallel,
796 execute_delay: std::time::Duration::ZERO,
797 executed: Arc::new(std::sync::Mutex::new(Vec::new())),
798 terminate: false,
799 }
800 }
801
802 #[allow(dead_code)]
803 fn with_sequential(mut self) -> Self {
804 self.execution_mode = ToolExecutionMode::Sequential;
805 self
806 }
807
808 fn with_delay(mut self, delay: std::time::Duration) -> Self {
809 self.execute_delay = delay;
810 self
811 }
812
813 fn with_terminate(mut self) -> Self {
814 self.terminate = true;
815 self
816 }
817 }
818
819 #[async_trait]
820 impl AgentTool for MockTool {
821 fn name(&self) -> &str {
822 &self.name
823 }
824 fn description(&self) -> &str {
825 "mock tool"
826 }
827 fn parameters(&self) -> serde_json::Value {
828 serde_json::json!({})
829 }
830 fn label(&self) -> &str {
831 &self.name
832 }
833 fn execution_mode(&self) -> ToolExecutionMode {
834 self.execution_mode
835 }
836
837 async fn execute(
838 &self,
839 tool_call_id: String,
840 _args: serde_json::Value,
841 _cancel: Cancel,
842 _on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
843 ) -> anyhow::Result<ToolOutput> {
844 self.executed.lock().unwrap().push(tool_call_id.clone());
845
846 if self.execute_delay > std::time::Duration::ZERO {
847 tokio::time::sleep(self.execute_delay).await;
848 }
849
850 Ok(ToolOutput {
851 content: format!("executed: {}", tool_call_id),
852 compact: None,
853 is_error: false,
854 terminate: self.terminate,
855 })
856 }
857 }
858
859 #[derive(Debug, Clone)]
861 struct EventRecorder {
862 events: Arc<std::sync::Mutex<Vec<AgentEvent>>>,
863 }
864
865 impl EventRecorder {
866 fn new() -> Self {
867 Self {
868 events: Arc::new(std::sync::Mutex::new(Vec::new())),
869 }
870 }
871
872 fn record(&self, event: AgentEvent) {
873 self.events.lock().unwrap().push(event);
874 }
875
876 fn events(&self) -> Vec<AgentEvent> {
877 self.events.lock().unwrap().clone()
878 }
879
880 fn event_types(&self) -> Vec<String> {
881 self.events()
882 .iter()
883 .map(|e| match e {
884 AgentEvent::AgentStart => "agent_start".to_string(),
885 AgentEvent::TurnStart => "turn_start".to_string(),
886 AgentEvent::TextDelta { .. } => "text_delta".to_string(),
887 AgentEvent::ThinkingDelta { .. } => "thinking_delta".to_string(),
888 AgentEvent::ToolCall { .. } => "tool_call".to_string(),
889 AgentEvent::ToolCallArgsUpdate { .. } => "tool_call_args_update".to_string(),
890 AgentEvent::ToolResult { .. } => "tool_result".to_string(),
891 AgentEvent::ToolProgress { .. } => "tool_progress".to_string(),
892 AgentEvent::Aborted { .. } => "aborted".to_string(),
893 AgentEvent::UserMessage { .. } => "user_message".to_string(),
894 AgentEvent::TurnEnd => "turn_end".to_string(),
895 AgentEvent::AgentEnd { .. } => "agent_end".to_string(),
896 })
897 .collect()
898 }
899
900 fn text_deltas(&self) -> Vec<String> {
901 self.events()
902 .iter()
903 .filter_map(|e| {
904 if let AgentEvent::TextDelta { delta } = e {
905 Some(delta.clone())
906 } else {
907 None
908 }
909 })
910 .collect()
911 }
912 }
913
914 #[tokio::test]
918 async fn test_basic_text_response() {
919 let provider = MockProvider::new();
920 provider.add_response("Hello, world!");
921
922 let recorder = EventRecorder::new();
923 let mut emit = |e: AgentEvent| recorder.record(e);
924
925 let config = LoopConfig {
926 model: "test".to_string(),
927 system_prompt: "You are helpful.".to_string(),
928 tools: vec![],
929 agent_tools: &[],
930 extensions: &[],
931 tool_execution: ToolExecutionMode::Parallel,
932 steering_queue: None,
933 follow_up_queue: None,
934 transform_context: None,
935 prepare_next_turn: None,
936 should_stop_after_turn: None,
937 };
938
939 let prompt = AgentMessage::user("Hi");
940 let result = run_agent_loop(vec![prompt], vec![], &config, &provider, &mut emit)
941 .await
942 .unwrap();
943
944 assert_eq!(result.len(), 2);
946 assert_eq!(result[0].role, Role::User);
947 assert_eq!(result[1].role, Role::Assistant);
948
949 let types = recorder.event_types();
951 assert!(types.contains(&"agent_start".to_string()));
952 assert!(types.contains(&"text_delta".to_string()));
953 assert!(types.contains(&"turn_end".to_string()));
954 assert!(types.contains(&"agent_end".to_string()));
955
956 let texts = recorder.text_deltas();
958 assert!(texts.iter().any(|t| t == "Hello, world!"));
959 }
960
961 #[tokio::test]
963 async fn test_sequential_tool_execution() {
964 let tool = MockTool::new("echo");
965 let tool_executed = Arc::clone(&tool.executed);
966 let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(tool)];
967
968 let provider = MockProvider::new();
969 provider.add_tool_call_response(
970 "",
971 vec![
972 ToolCall {
973 id: "call-1".to_string(),
974 name: "echo".to_string(),
975 arguments: serde_json::json!({}),
976 },
977 ToolCall {
978 id: "call-2".to_string(),
979 name: "echo".to_string(),
980 arguments: serde_json::json!({}),
981 },
982 ],
983 );
984 provider.add_response("Done after tools.");
985
986 let recorder = EventRecorder::new();
987 let mut emit = |e: AgentEvent| recorder.record(e);
988
989 let config = LoopConfig {
990 model: "test".to_string(),
991 system_prompt: "".to_string(),
992 tools: vec![],
993 agent_tools: &agent_tools,
994 extensions: &[],
995 tool_execution: ToolExecutionMode::Sequential,
996 steering_queue: None,
997 follow_up_queue: None,
998 transform_context: None,
999 prepare_next_turn: None,
1000 should_stop_after_turn: None,
1001 };
1002
1003 let result = run_agent_loop(
1004 vec![AgentMessage::user("run tools")],
1005 vec![],
1006 &config,
1007 &provider,
1008 &mut emit,
1009 )
1010 .await
1011 .unwrap();
1012
1013 assert_eq!(result.len(), 5);
1015
1016 let executed = tool_executed.lock().unwrap().clone();
1017 assert_eq!(executed.len(), 2);
1018 assert_eq!(executed[0], "call-1");
1019 assert_eq!(executed[1], "call-2");
1020
1021 let types = recorder.event_types();
1023 assert!(types.contains(&"tool_call".to_string()));
1024 assert!(types.contains(&"tool_result".to_string()));
1025 }
1026
1027 #[tokio::test]
1029 async fn test_parallel_tool_execution() {
1030 let fast_tool =
1031 Arc::new(MockTool::new("fast").with_delay(std::time::Duration::from_millis(50)));
1032 let slow_tool =
1033 Arc::new(MockTool::new("slow").with_delay(std::time::Duration::from_millis(100)));
1034 let _fast_executed = Arc::clone(&fast_tool.executed);
1035 let _slow_executed = Arc::clone(&slow_tool.executed);
1036
1037 let start_times: Arc<std::sync::Mutex<Vec<(String, std::time::Instant)>>> =
1039 Arc::new(std::sync::Mutex::new(Vec::new()));
1040 let start_times_clone = Arc::clone(&start_times);
1041
1042 struct TrackingTool {
1043 inner: MockTool,
1044 start_times: Arc<std::sync::Mutex<Vec<(String, std::time::Instant)>>>,
1045 }
1046 #[async_trait]
1047 impl AgentTool for TrackingTool {
1048 fn name(&self) -> &str {
1049 self.inner.name()
1050 }
1051 fn description(&self) -> &str {
1052 "tracking"
1053 }
1054 fn parameters(&self) -> serde_json::Value {
1055 serde_json::json!({})
1056 }
1057 fn label(&self) -> &str {
1058 self.inner.name()
1059 }
1060 async fn execute(
1061 &self,
1062 tool_call_id: String,
1063 args: serde_json::Value,
1064 cancel: Cancel,
1065 on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
1066 ) -> anyhow::Result<ToolOutput> {
1067 self.start_times
1068 .lock()
1069 .unwrap()
1070 .push((tool_call_id.clone(), std::time::Instant::now()));
1071 self.inner
1072 .execute(tool_call_id, args, cancel, on_update)
1073 .await
1074 }
1075 }
1076
1077 let agent_tools: Vec<Box<dyn AgentTool>> = vec![
1078 Box::new(TrackingTool {
1079 inner: MockTool::new("slow").with_delay(std::time::Duration::from_millis(100)),
1080 start_times: Arc::clone(&start_times),
1081 }),
1082 Box::new(TrackingTool {
1083 inner: MockTool::new("fast").with_delay(std::time::Duration::from_millis(50)),
1084 start_times: Arc::clone(&start_times_clone),
1085 }),
1086 ];
1087
1088 let provider = MockProvider::new();
1089 provider.add_tool_call_response(
1090 "",
1091 vec![
1092 ToolCall {
1093 id: "slow-1".to_string(),
1094 name: "slow".to_string(),
1095 arguments: serde_json::json!({}),
1096 },
1097 ToolCall {
1098 id: "fast-1".to_string(),
1099 name: "fast".to_string(),
1100 arguments: serde_json::json!({}),
1101 },
1102 ],
1103 );
1104 provider.add_response("All tools done.");
1105
1106 let recorder = EventRecorder::new();
1107 let mut emit = |e: AgentEvent| recorder.record(e);
1108
1109 let config = LoopConfig {
1110 model: "test".to_string(),
1111 system_prompt: "".to_string(),
1112 tools: vec![],
1113 agent_tools: &agent_tools,
1114 extensions: &[],
1115 tool_execution: ToolExecutionMode::Parallel,
1116 steering_queue: None,
1117 follow_up_queue: None,
1118 transform_context: None,
1119 prepare_next_turn: None,
1120 should_stop_after_turn: None,
1121 };
1122
1123 run_agent_loop(
1124 vec![AgentMessage::user("run tools")],
1125 vec![],
1126 &config,
1127 &provider,
1128 &mut emit,
1129 )
1130 .await
1131 .unwrap();
1132
1133 let times = start_times.lock().unwrap();
1134 assert_eq!(times.len(), 2, "both tools should have started");
1135
1136 let names: Vec<&str> = times.iter().map(|(n, _)| n.as_str()).collect();
1139 assert!(names.contains(&"slow-1"));
1140 assert!(names.contains(&"fast-1"));
1141 }
1142
1143 #[tokio::test]
1145 async fn test_per_tool_sequential_mode() {
1146 let executed = Arc::new(std::sync::Mutex::new(Vec::new()));
1147 {
1148 let _seq_exec = Arc::clone(&executed);
1150 let _par_exec = Arc::clone(&executed);
1151
1152 struct SeqTool;
1153 #[async_trait]
1154 impl AgentTool for SeqTool {
1155 fn name(&self) -> &str {
1156 "sequential_tool"
1157 }
1158 fn description(&self) -> &str {
1159 ""
1160 }
1161 fn parameters(&self) -> serde_json::Value {
1162 serde_json::json!({})
1163 }
1164 fn label(&self) -> &str {
1165 "sequential_tool"
1166 }
1167 fn execution_mode(&self) -> ToolExecutionMode {
1168 ToolExecutionMode::Sequential
1169 }
1170 async fn execute(
1171 &self,
1172 tool_call_id: String,
1173 _args: serde_json::Value,
1174 _cancel: Cancel,
1175 _on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
1176 ) -> anyhow::Result<ToolOutput> {
1177 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
1179 Ok(ToolOutput::ok(format!("done: {}", tool_call_id)))
1180 }
1181 }
1182
1183 struct ParTool {
1184 executed: Arc<std::sync::Mutex<Vec<String>>>,
1185 }
1186 #[async_trait]
1187 impl AgentTool for ParTool {
1188 fn name(&self) -> &str {
1189 "parallel_tool"
1190 }
1191 fn description(&self) -> &str {
1192 ""
1193 }
1194 fn parameters(&self) -> serde_json::Value {
1195 serde_json::json!({})
1196 }
1197 fn label(&self) -> &str {
1198 "parallel_tool"
1199 }
1200 async fn execute(
1201 &self,
1202 tool_call_id: String,
1203 _args: serde_json::Value,
1204 _cancel: Cancel,
1205 _on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
1206 ) -> anyhow::Result<ToolOutput> {
1207 self.executed.lock().unwrap().push(tool_call_id.clone());
1208 Ok(ToolOutput::ok(format!("done: {}", tool_call_id)))
1209 }
1210 }
1211
1212 let agent_tools: Vec<Box<dyn AgentTool>> = vec![
1213 Box::new(SeqTool),
1214 Box::new(ParTool {
1215 executed: Arc::clone(&executed),
1216 }),
1217 ];
1218
1219 let provider = MockProvider::new();
1220 provider.add_tool_call_response(
1221 "",
1222 vec![
1223 ToolCall {
1224 id: "seq-1".to_string(),
1225 name: "sequential_tool".to_string(),
1226 arguments: serde_json::json!({}),
1227 },
1228 ToolCall {
1229 id: "par-1".to_string(),
1230 name: "parallel_tool".to_string(),
1231 arguments: serde_json::json!({}),
1232 },
1233 ],
1234 );
1235 provider.add_response("Done.");
1236
1237 let recorder = EventRecorder::new();
1238 let mut emit = |e: AgentEvent| recorder.record(e);
1239
1240 let config = LoopConfig {
1241 model: "test".to_string(),
1242 system_prompt: "".to_string(),
1243 tools: vec![],
1244 agent_tools: &agent_tools,
1245 extensions: &[],
1246 tool_execution: ToolExecutionMode::Parallel,
1247 steering_queue: None,
1248 follow_up_queue: None,
1249 transform_context: None,
1250 prepare_next_turn: None,
1251 should_stop_after_turn: None,
1252 };
1253
1254 run_agent_loop(
1255 vec![AgentMessage::user("run")],
1256 vec![],
1257 &config,
1258 &provider,
1259 &mut emit,
1260 )
1261 .await
1262 .unwrap();
1263
1264 let exec_order = executed.lock().unwrap().clone();
1266 assert_eq!(
1267 exec_order.len(),
1268 1,
1269 "only parallel_tool records in executed"
1270 );
1271 }
1272 }
1273
1274 #[tokio::test]
1276 async fn test_terminate_stops_loop() {
1277 let agent_tools: Vec<Box<dyn AgentTool>> =
1278 vec![Box::new(MockTool::new("final").with_terminate())];
1279
1280 let provider = MockProvider::new();
1281 provider.add_tool_call_response(
1282 "",
1283 vec![ToolCall {
1284 id: "final-1".to_string(),
1285 name: "final".to_string(),
1286 arguments: serde_json::json!({}),
1287 }],
1288 );
1289 let recorder = EventRecorder::new();
1292 let mut emit = |e: AgentEvent| recorder.record(e);
1293
1294 let config = LoopConfig {
1295 model: "test".to_string(),
1296 system_prompt: "".to_string(),
1297 tools: vec![],
1298 agent_tools: &agent_tools,
1299 extensions: &[],
1300 tool_execution: ToolExecutionMode::Parallel,
1301 steering_queue: None,
1302 follow_up_queue: None,
1303 transform_context: None,
1304 prepare_next_turn: None,
1305 should_stop_after_turn: None,
1306 };
1307
1308 let result = run_agent_loop(
1309 vec![AgentMessage::user("final")],
1310 vec![],
1311 &config,
1312 &provider,
1313 &mut emit,
1314 )
1315 .await
1316 .unwrap();
1317
1318 assert_eq!(
1321 result.len(),
1322 3,
1323 "should stop after terminate without second LLM call"
1324 );
1325
1326 let types = recorder.event_types();
1327 assert!(types.contains(&"turn_end".to_string()));
1328 assert!(types.contains(&"agent_end".to_string()));
1329 }
1330
1331 #[tokio::test]
1333 async fn test_transform_context() {
1334 let provider = MockProvider::new();
1335 provider.add_response("Response");
1336
1337 let transform_called = Arc::new(std::sync::Mutex::new(false));
1338 let transform_called_clone = Arc::clone(&transform_called);
1339
1340 let config = LoopConfig {
1341 model: "test".to_string(),
1342 system_prompt: "".to_string(),
1343 tools: vec![],
1344 agent_tools: &[],
1345 extensions: &[],
1346 tool_execution: ToolExecutionMode::Parallel,
1347 steering_queue: None,
1348 follow_up_queue: None,
1349 transform_context: Some(Box::new(move |msgs| {
1350 *transform_called_clone.lock().unwrap() = true;
1351 msgs.to_vec()
1352 })),
1353 prepare_next_turn: None,
1354 should_stop_after_turn: None,
1355 };
1356
1357 let mut emit = |_: AgentEvent| {};
1358 run_agent_loop(
1359 vec![AgentMessage::user("hi")],
1360 vec![],
1361 &config,
1362 &provider,
1363 &mut emit,
1364 )
1365 .await
1366 .unwrap();
1367
1368 assert!(
1369 *transform_called.lock().unwrap(),
1370 "transform_context should be called"
1371 );
1372 }
1373
1374 #[tokio::test]
1376 async fn test_prepare_next_turn() {
1377 let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(MockTool::new("echo"))];
1378 let provider = MockProvider::new();
1379 provider.add_tool_call_response(
1380 "",
1381 vec![ToolCall {
1382 id: "tool-1".to_string(),
1383 name: "echo".to_string(),
1384 arguments: serde_json::json!({}),
1385 }],
1386 );
1387 provider.add_response("After prepare.");
1388
1389 let prepare_called = Arc::new(std::sync::Mutex::new(false));
1390 let prepare_called_clone = Arc::clone(&prepare_called);
1391
1392 let config = LoopConfig {
1393 model: "test".to_string(),
1394 system_prompt: "".to_string(),
1395 tools: vec![],
1396 agent_tools: &agent_tools,
1397 extensions: &[],
1398 tool_execution: ToolExecutionMode::Sequential,
1399 steering_queue: None,
1400 follow_up_queue: None,
1401 transform_context: None,
1402 prepare_next_turn: Some(Box::new(move |_new_msgs| {
1403 *prepare_called_clone.lock().unwrap() = true;
1404 None })),
1406 should_stop_after_turn: None,
1407 };
1408
1409 let mut emit = |_: AgentEvent| {};
1410 run_agent_loop(
1411 vec![AgentMessage::user("run")],
1412 vec![],
1413 &config,
1414 &provider,
1415 &mut emit,
1416 )
1417 .await
1418 .unwrap();
1419
1420 assert!(
1421 *prepare_called.lock().unwrap(),
1422 "prepare_next_turn should be called"
1423 );
1424 }
1425
1426 #[tokio::test]
1428 async fn test_should_stop_after_turn() {
1429 let provider = MockProvider::new();
1430 provider.add_response("First turn.");
1431
1432 let stop = Arc::new(std::sync::Mutex::new(true));
1433 let stop_clone = Arc::clone(&stop);
1434
1435 let config = LoopConfig {
1436 model: "test".to_string(),
1437 system_prompt: "".to_string(),
1438 tools: vec![],
1439 agent_tools: &[],
1440 extensions: &[],
1441 tool_execution: ToolExecutionMode::Parallel,
1442 steering_queue: None,
1443 follow_up_queue: None,
1444 transform_context: None,
1445 prepare_next_turn: None,
1446 should_stop_after_turn: Some(Box::new(move |_| *stop_clone.lock().unwrap())),
1447 };
1448
1449 let recorder = EventRecorder::new();
1450 let mut emit = |e: AgentEvent| recorder.record(e);
1451 run_agent_loop(
1452 vec![AgentMessage::user("hi")],
1453 vec![],
1454 &config,
1455 &provider,
1456 &mut emit,
1457 )
1458 .await
1459 .unwrap();
1460
1461 let types = recorder.event_types();
1463 let agent_end_count = types.iter().filter(|t| *t == "agent_end").count();
1464 assert_eq!(agent_end_count, 1, "should end exactly once");
1465 }
1466
1467 #[tokio::test]
1469 async fn test_steering_queue() {
1470 let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(MockTool::new("echo"))];
1471 let provider = MockProvider::new();
1472 provider.add_tool_call_response(
1473 "",
1474 vec![ToolCall {
1475 id: "tool-1".to_string(),
1476 name: "echo".to_string(),
1477 arguments: serde_json::json!({}),
1478 }],
1479 );
1480 provider.add_response("After tool.");
1481 provider.add_response("After steering.");
1482
1483 let steering_queue = std::sync::Mutex::new(PendingMessageQueue::new(QueueMode::OneAtATime));
1484 steering_queue
1486 .lock()
1487 .unwrap()
1488 .enqueue(AgentMessage::user("steer here"));
1489
1490 let recorder = EventRecorder::new();
1491 let mut emit = |e: AgentEvent| recorder.record(e);
1492
1493 let config = LoopConfig {
1494 model: "test".to_string(),
1495 system_prompt: "".to_string(),
1496 tools: vec![],
1497 agent_tools: &agent_tools,
1498 extensions: &[],
1499 tool_execution: ToolExecutionMode::Sequential,
1500 steering_queue: Some(&steering_queue),
1501 follow_up_queue: None,
1502 transform_context: None,
1503 prepare_next_turn: None,
1504 should_stop_after_turn: None,
1505 };
1506
1507 let result = run_agent_loop(
1508 vec![AgentMessage::user("run")],
1509 vec![],
1510 &config,
1511 &provider,
1512 &mut emit,
1513 )
1514 .await
1515 .unwrap();
1516
1517 let types = recorder.event_types();
1520 let user_msg_count = types.iter().filter(|t| *t == "user_message").count();
1521 assert!(
1522 user_msg_count >= 1,
1523 "steering should produce at least one user_message event, got {}",
1524 user_msg_count
1525 );
1526
1527 let user_messages: Vec<&AgentMessage> =
1529 result.iter().filter(|m| m.role == Role::User).collect();
1530 assert_eq!(
1531 user_messages.len(),
1532 2,
1533 "should have original prompt + steering message"
1534 );
1535 }
1536
1537 #[tokio::test]
1539 async fn test_follow_up_queue() {
1540 let provider = MockProvider::new();
1541 provider.add_response("First response.");
1542 provider.add_response("Follow-up response.");
1543
1544 let follow_up_queue =
1545 std::sync::Mutex::new(PendingMessageQueue::new(QueueMode::OneAtATime));
1546 follow_up_queue
1547 .lock()
1548 .unwrap()
1549 .enqueue(AgentMessage::user("follow up"));
1550
1551 let recorder = EventRecorder::new();
1552 let mut emit = |e: AgentEvent| recorder.record(e);
1553
1554 let config = LoopConfig {
1555 model: "test".to_string(),
1556 system_prompt: "".to_string(),
1557 tools: vec![],
1558 agent_tools: &[],
1559 extensions: &[],
1560 tool_execution: ToolExecutionMode::Parallel,
1561 steering_queue: None,
1562 follow_up_queue: Some(&follow_up_queue),
1563 transform_context: None,
1564 prepare_next_turn: None,
1565 should_stop_after_turn: None,
1566 };
1567
1568 let result = run_agent_loop(
1569 vec![AgentMessage::user("first")],
1570 vec![],
1571 &config,
1572 &provider,
1573 &mut emit,
1574 )
1575 .await
1576 .unwrap();
1577
1578 assert_eq!(
1580 result.len(),
1581 4,
1582 "follow-up should add another user+assistant pair"
1583 );
1584 assert_eq!(
1585 result[2].content, "follow up",
1586 "third message should be the injected follow-up"
1587 );
1588
1589 let types = recorder.event_types();
1590 assert!(types.contains(&"user_message".to_string()));
1591 }
1592
/// Exercises `PendingMessageQueue` directly.
///
/// - `OneAtATime` drains one message per call, in enqueue order.
/// - `All` drains everything in a single call.
/// - `clear` empties the queue.
///
/// The test never awaits anything, so it runs as a plain `#[test]` and does
/// not start a Tokio runtime.
#[test]
fn test_message_queue_modes() {
    // OneAtATime: each drain yields exactly the oldest pending message.
    let mut queue = PendingMessageQueue::new(QueueMode::OneAtATime);
    queue.enqueue(AgentMessage::user("msg1"));
    queue.enqueue(AgentMessage::user("msg2"));

    let batch1 = queue.drain();
    assert_eq!(batch1.len(), 1, "OneAtATime should drain 1");
    assert_eq!(batch1[0].content, "msg1");

    let batch2 = queue.drain();
    assert_eq!(batch2.len(), 1, "OneAtATime should drain 1 on second call");
    assert_eq!(batch2[0].content, "msg2");

    assert!(
        queue.drain().is_empty(),
        "should be empty after both drained"
    );

    // All: a single drain takes everything.
    let mut queue = PendingMessageQueue::new(QueueMode::All);
    queue.enqueue(AgentMessage::user("a"));
    queue.enqueue(AgentMessage::user("b"));

    let all = queue.drain();
    assert_eq!(all.len(), 2, "All mode should drain both");
    assert!(queue.drain().is_empty(), "should be empty after drain");

    // clear discards pending messages without draining them.
    let mut queue = PendingMessageQueue::new(QueueMode::OneAtATime);
    queue.enqueue(AgentMessage::user("x"));
    queue.clear();
    assert!(queue.is_empty());
}
1629
1630 #[tokio::test]
1632 async fn test_prepare_arguments() {
1633 struct PrepTool;
1634 #[async_trait]
1635 impl AgentTool for PrepTool {
1636 fn name(&self) -> &str {
1637 "prep_tool"
1638 }
1639 fn description(&self) -> &str {
1640 ""
1641 }
1642 fn parameters(&self) -> serde_json::Value {
1643 serde_json::json!({})
1644 }
1645 fn label(&self) -> &str {
1646 "prep_tool"
1647 }
1648 fn prepare_arguments(&self, args: serde_json::Value) -> serde_json::Value {
1649 let mut m = serde_json::Map::new();
1650 m.insert("prepared".to_string(), serde_json::json!(true));
1651 if let Some(obj) = args.as_object() {
1652 for (k, v) in obj {
1653 m.insert(k.clone(), v.clone());
1654 }
1655 }
1656 serde_json::Value::Object(m)
1657 }
1658 async fn execute(
1659 &self,
1660 _tool_call_id: String,
1661 args: serde_json::Value,
1662 _cancel: Cancel,
1663 _on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
1664 ) -> anyhow::Result<ToolOutput> {
1665 assert_eq!(args.get("prepared").and_then(|v| v.as_bool()), Some(true));
1667 Ok(ToolOutput::ok("prepared ok"))
1668 }
1669 }
1670
1671 let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(PrepTool)];
1672 let provider = MockProvider::new();
1673 provider.add_tool_call_response(
1674 "",
1675 vec![ToolCall {
1676 id: "tool-1".to_string(),
1677 name: "prep_tool".to_string(),
1678 arguments: serde_json::json!({"original": "value"}),
1679 }],
1680 );
1681 provider.add_response("Done.");
1682
1683 let config = LoopConfig {
1684 model: "test".to_string(),
1685 system_prompt: "".to_string(),
1686 tools: vec![],
1687 agent_tools: &agent_tools,
1688 extensions: &[],
1689 tool_execution: ToolExecutionMode::Sequential,
1690 steering_queue: None,
1691 follow_up_queue: None,
1692 transform_context: None,
1693 prepare_next_turn: None,
1694 should_stop_after_turn: None,
1695 };
1696
1697 let mut emit = |_: AgentEvent| {};
1698 let result = run_agent_loop(
1699 vec![AgentMessage::user("prep")],
1700 vec![],
1701 &config,
1702 &provider,
1703 &mut emit,
1704 )
1705 .await;
1706
1707 assert!(
1708 result.is_ok(),
1709 "prepare_arguments should work without error"
1710 );
1711 }
1712
1713 #[tokio::test]
1715 async fn test_before_tool_call_blocks() {
1716 struct BlockingExt;
1717 #[async_trait]
1718 impl Extension for BlockingExt {
1719 fn name(&self) -> std::borrow::Cow<'static, str> {
1720 std::borrow::Cow::Borrowed("blocker")
1721 }
1722 async fn before_tool_call(&self, _tc: &ToolCall) -> Option<BlockReason> {
1723 Some(BlockReason::Security("blocked for test".into()))
1724 }
1725 }
1726
1727 let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(MockTool::new("echo"))];
1728 let extensions: Vec<Box<dyn Extension>> = vec![Box::new(BlockingExt)];
1729
1730 let provider = MockProvider::new();
1731 provider.add_tool_call_response(
1732 "",
1733 vec![ToolCall {
1734 id: "tool-1".to_string(),
1735 name: "echo".to_string(),
1736 arguments: serde_json::json!({}),
1737 }],
1738 );
1739 provider.add_response("After blocked tool.");
1740
1741 let recorder = EventRecorder::new();
1742 let mut emit = |e: AgentEvent| recorder.record(e);
1743
1744 let config = LoopConfig {
1745 model: "test".to_string(),
1746 system_prompt: "".to_string(),
1747 tools: vec![],
1748 agent_tools: &agent_tools,
1749 extensions: &extensions,
1750 tool_execution: ToolExecutionMode::Sequential,
1751 steering_queue: None,
1752 follow_up_queue: None,
1753 transform_context: None,
1754 prepare_next_turn: None,
1755 should_stop_after_turn: None,
1756 };
1757
1758 let result = run_agent_loop(
1759 vec![AgentMessage::user("block test")],
1760 vec![],
1761 &config,
1762 &provider,
1763 &mut emit,
1764 )
1765 .await
1766 .unwrap();
1767
1768 assert!(
1770 result.len() >= 3,
1771 "blocked tool should still produce a result"
1772 );
1773
1774 let tool_results: Vec<&AgentMessage> = result
1776 .iter()
1777 .filter(|m| m.role == Role::ToolResult)
1778 .collect();
1779 assert!(!tool_results.is_empty());
1780 assert!(
1781 tool_results[0].is_error,
1782 "blocked tool result should be error"
1783 );
1784 assert!(
1785 tool_results[0].content.contains("blocked"),
1786 "blocked result should mention block reason"
1787 );
1788 }
1789
1790 #[tokio::test]
1792 async fn test_provider_error_aborts() {
1793 struct ErrorProvider;
1795 #[async_trait]
1796 impl Provider for ErrorProvider {
1797 async fn stream(
1798 &self,
1799 _model: &str,
1800 _system: &str,
1801 _messages: &[AgentMessage],
1802 _tools: &[ToolDef],
1803 ) -> anyhow::Result<Pin<Box<dyn Stream<Item = StreamEvent> + Send>>> {
1804 anyhow::bail!("provider error")
1805 }
1806 }
1807
1808 let recorder = EventRecorder::new();
1809 let mut emit = |e: AgentEvent| recorder.record(e);
1810
1811 let config = LoopConfig {
1812 model: "test".to_string(),
1813 system_prompt: "".to_string(),
1814 tools: vec![],
1815 agent_tools: &[],
1816 extensions: &[],
1817 tool_execution: ToolExecutionMode::Parallel,
1818 steering_queue: None,
1819 follow_up_queue: None,
1820 transform_context: None,
1821 prepare_next_turn: None,
1822 should_stop_after_turn: None,
1823 };
1824
1825 let result = run_agent_loop(
1826 vec![AgentMessage::user("hi")],
1827 vec![],
1828 &config,
1829 &ErrorProvider,
1830 &mut emit,
1831 )
1832 .await;
1833
1834 assert!(result.is_err(), "provider error should propagate");
1836 }
1837
1838 #[tokio::test]
1840 async fn test_tool_execution_error() {
1841 struct ErrorTool;
1842 #[async_trait]
1843 impl AgentTool for ErrorTool {
1844 fn name(&self) -> &str {
1845 "error_tool"
1846 }
1847 fn description(&self) -> &str {
1848 ""
1849 }
1850 fn parameters(&self) -> serde_json::Value {
1851 serde_json::json!({})
1852 }
1853 fn label(&self) -> &str {
1854 "error_tool"
1855 }
1856 async fn execute(
1857 &self,
1858 _tool_call_id: String,
1859 _args: serde_json::Value,
1860 _cancel: Cancel,
1861 _on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
1862 ) -> anyhow::Result<ToolOutput> {
1863 anyhow::bail!("tool crashed")
1864 }
1865 }
1866
1867 let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(ErrorTool)];
1868 let provider = MockProvider::new();
1869 provider.add_tool_call_response(
1870 "",
1871 vec![ToolCall {
1872 id: "tool-1".to_string(),
1873 name: "error_tool".to_string(),
1874 arguments: serde_json::json!({}),
1875 }],
1876 );
1877 provider.add_response("After error.");
1878
1879 let recorder = EventRecorder::new();
1880 let mut emit = |e: AgentEvent| recorder.record(e);
1881
1882 let config = LoopConfig {
1883 model: "test".to_string(),
1884 system_prompt: "".to_string(),
1885 tools: vec![],
1886 agent_tools: &agent_tools,
1887 extensions: &[],
1888 tool_execution: ToolExecutionMode::Sequential,
1889 steering_queue: None,
1890 follow_up_queue: None,
1891 transform_context: None,
1892 prepare_next_turn: None,
1893 should_stop_after_turn: None,
1894 };
1895
1896 let result = run_agent_loop(
1897 vec![AgentMessage::user("error test")],
1898 vec![],
1899 &config,
1900 &provider,
1901 &mut emit,
1902 )
1903 .await
1904 .unwrap();
1905
1906 let tool_results: Vec<&AgentMessage> = result
1908 .iter()
1909 .filter(|m| m.role == Role::ToolResult)
1910 .collect();
1911 assert!(!tool_results.is_empty());
1912 assert!(tool_results[0].is_error);
1913 }
1914
1915 #[tokio::test]
1917 async fn test_tool_not_found() {
1918 let provider = MockProvider::new();
1919 provider.add_tool_call_response(
1920 "",
1921 vec![ToolCall {
1922 id: "tool-1".to_string(),
1923 name: "nonexistent".to_string(),
1924 arguments: serde_json::json!({}),
1925 }],
1926 );
1927 provider.add_response("After missing tool.");
1928
1929 let agent_tools: Vec<Box<dyn AgentTool>> = vec![];
1931
1932 let recorder = EventRecorder::new();
1933 let mut emit = |e: AgentEvent| recorder.record(e);
1934
1935 let config = LoopConfig {
1936 model: "test".to_string(),
1937 system_prompt: "".to_string(),
1938 tools: vec![],
1939 agent_tools: &agent_tools,
1940 extensions: &[],
1941 tool_execution: ToolExecutionMode::Sequential,
1942 steering_queue: None,
1943 follow_up_queue: None,
1944 transform_context: None,
1945 prepare_next_turn: None,
1946 should_stop_after_turn: None,
1947 };
1948
1949 let result = run_agent_loop(
1950 vec![AgentMessage::user("test")],
1951 vec![],
1952 &config,
1953 &provider,
1954 &mut emit,
1955 )
1956 .await
1957 .unwrap();
1958
1959 let tool_results: Vec<&AgentMessage> = result
1960 .iter()
1961 .filter(|m| m.role == Role::ToolResult)
1962 .collect();
1963 assert!(!tool_results.is_empty());
1964 assert!(tool_results[0].is_error);
1965 assert!(tool_results[0].content.contains("not found"));
1966 }
1967}