1use std::collections::HashMap;
2use std::sync::Arc;
3
4use futures::future::{join_all, try_join_all};
5use motosan_agent_tool::{Tool, ToolContext, ToolDef, ToolResult};
6use tokio::sync::mpsc;
7
8use crate::context::ContextProvider;
9use crate::error::AgentError;
10use crate::llm::{LlmClient, TokenUsage, ToolCallItem};
11use crate::message::{Message, ToolCallRef};
12use crate::Result;
13
/// Strategy applied when the ops channel is full and a new op arrives.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum BackpressurePolicy {
    /// Block the sender until capacity frees up (the default).
    #[default]
    Block,
    /// Evict the oldest queued op to make room for the new one.
    DropOldest,
    /// Refuse the new op and leave the queue untouched.
    Reject,
}
29
/// Capacities and backpressure behavior for the agent's channels.
#[derive(Debug, Clone)]
pub struct ChannelConfig {
    /// Bounded capacity of the input channel; `build()` asserts it is >= 1.
    pub input_capacity: usize,
    /// Bounded capacity of the ops channel; `build()` asserts it is >= 1.
    pub ops_capacity: usize,
    /// What to do when the ops channel is saturated.
    pub ops_backpressure: BackpressurePolicy,
}
43
44impl Default for ChannelConfig {
45 fn default() -> Self {
46 Self {
47 input_capacity: 64,
48 ops_capacity: 64,
49 ops_backpressure: BackpressurePolicy::Block,
50 }
51 }
52}
53
/// The builder-registered tools merged with per-run extras (e.g. MCP tools).
struct MergedTools {
    /// Name -> implementation lookup used when executing tool calls.
    map: HashMap<String, Arc<dyn Tool>>,
    /// Definitions advertised to the LLM on every chat call.
    defs: Vec<ToolDef>,
}
67
68impl MergedTools {
69 fn new(
74 base_map: &HashMap<String, Arc<dyn Tool>>,
75 base_defs: &[ToolDef],
76 extra_tools: &[Arc<dyn Tool>],
77 ) -> Self {
78 if extra_tools.is_empty() {
79 return Self {
80 map: base_map.clone(),
81 defs: base_defs.to_vec(),
82 };
83 }
84 let mut map = base_map.clone();
85 let mut defs = base_defs.to_vec();
86 for t in extra_tools {
87 map.insert(t.def().name.clone(), Arc::clone(t));
88 defs.push(t.def());
89 }
90 Self { map, defs }
91 }
92
93 fn tool_map(&self) -> &HashMap<String, Arc<dyn Tool>> {
94 &self.map
95 }
96
97 fn tool_defs(&self) -> &[ToolDef] {
98 &self.defs
99 }
100}
101
/// Mutable state accumulated over one agent turn.
struct TurnState {
    /// Conversation transcript, grown as the turn progresses.
    messages: Vec<Message>,
    /// Audit trail of every tool invocation as (name, args).
    all_tool_calls: Vec<(String, serde_json::Value)>,
    /// Token usage summed across all LLM calls in the turn.
    total_usage: TokenUsage,
}
108
109impl TurnState {
110 fn new(messages: Vec<Message>) -> Self {
111 Self {
112 messages,
113 all_tool_calls: Vec::new(),
114 total_usage: TokenUsage::default(),
115 }
116 }
117
118 fn accumulate_usage(&mut self, usage: Option<TokenUsage>) {
120 if let Some(u) = usage {
121 self.total_usage.accumulate(u);
122 }
123 }
124
125 fn into_result(self, answer: String, iteration: usize) -> AgentResult {
127 AgentResult {
128 answer,
129 tool_calls: self.all_tool_calls,
130 iterations: iteration,
131 usage: self.total_usage,
132 messages: self.messages,
133 }
134 }
135}
136
137fn execute_and_record_tool_calls(
142 items: &[ToolCallItem],
143 results: Vec<ToolResult>,
144 state: &mut TurnState,
145 on_event: &(impl Fn(AgentEvent) + Send + Sync),
146) {
147 for (tc, result) in items.iter().zip(results.iter()) {
149 on_event(AgentEvent::ToolCompleted {
150 name: tc.name.clone(),
151 result: result.clone(),
152 });
153 state
154 .all_tool_calls
155 .push((tc.name.clone(), tc.args.clone()));
156 }
157
158 let tool_call_refs: Vec<ToolCallRef> = items
160 .iter()
161 .map(|tc| ToolCallRef {
162 id: tc.id.clone(),
163 name: tc.name.clone(),
164 args: tc.args.clone(),
165 })
166 .collect();
167 state
168 .messages
169 .push(Message::assistant_with_tool_calls("", tool_call_refs));
170
171 for (tc, result) in items.iter().zip(results.iter()) {
173 state
174 .messages
175 .push(Message::tool_result(&tc.id, &tool_result_to_string(result)));
176 }
177}
178
179fn emit_tool_started(items: &[ToolCallItem], on_event: &(impl Fn(AgentEvent) + Send + Sync)) {
181 for tc in items {
182 on_event(AgentEvent::ToolStarted {
183 name: tc.name.clone(),
184 });
185 }
186}
187
/// Events surfaced to the caller while a turn is running.
#[derive(Debug, Clone)]
pub enum AgentEvent {
    /// A tool invocation has been dispatched.
    ToolStarted { name: String },
    /// A tool invocation finished; carries its result.
    ToolCompleted { name: String, result: ToolResult },
    /// A fragment of assistant text (streaming runs stream deltas; non-streaming
    /// runs emit the whole answer as a single chunk).
    TextChunk(String),
    /// The complete assistant text for the turn (streaming runs only).
    TextDone(String),
    /// A new LLM round-trip is starting (1-based index).
    IterationStarted(usize),
    /// The turn stopped early because an `AgentOp::Interrupt` was received.
    Interrupted,
    /// The built-in `ask_user` tool wants an answer from the operator.
    AskUser {
        call_id: String,
        question: String,
        options: Vec<String>,
    },
    /// No answer arrived within the configured `ask_user` timeout (or no ops
    /// channel was available to deliver one).
    AskUserTimeout { call_id: String, question: String },
    /// The ops channel reached `capacity`.
    /// NOTE(review): no emission site is visible in this file — presumably
    /// emitted by the channel wiring that applies `BackpressurePolicy`; verify.
    OpsSaturated {
        capacity: usize,
    },
    /// An op was discarded; `reason` explains why.
    /// NOTE(review): emission site not visible here — likely the `DropOldest`
    /// policy path; confirm against the channel wiring.
    OpDropped {
        reason: String,
    },
    /// An op was refused; `reason` explains why.
    /// NOTE(review): emission site not visible here — likely the `Reject`
    /// policy path; confirm against the channel wiring.
    OpRejected {
        reason: String,
    },
}
229
/// Control operations that can be injected into a running turn via the ops
/// channel (see [`AgentLoop::run_with_ops`]).
#[derive(Debug, Clone)]
pub enum AgentOp {
    /// Stop the turn at the next checkpoint.
    Interrupt,
    /// Append a user message to the transcript before the next LLM call.
    InjectUserMessage(String),
    /// Append a hint; delivered as a user message wrapped in "[Note: …]".
    InjectHint(String),
    /// Answer a pending `ask_user` question. A `call_id` of `None` matches
    /// any pending question.
    AskUserAnswer {
        call_id: Option<String>,
        answer: String,
    },
}
245
/// Outcome of a completed agent turn.
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// The final assistant text.
    pub answer: String,
    /// Every tool invocation as (tool name, arguments), in execution order.
    pub tool_calls: Vec<(String, serde_json::Value)>,
    /// Number of LLM round-trips consumed.
    pub iterations: usize,
    /// Token usage accumulated across all LLM calls in the turn.
    pub usage: TokenUsage,
    /// Full conversation transcript at the end of the turn.
    pub messages: Vec<Message>,
}
263
/// Fluent builder for [`AgentLoop`]; obtain one via `AgentLoop::builder()`.
pub struct AgentLoopBuilder {
    /// Tools to register; keyed by `def().name` at build time.
    tools: Vec<Arc<dyn Tool>>,
    /// Context providers queried at the start of each turn.
    context_providers: Vec<Box<dyn ContextProvider>>,
    /// Maximum LLM round-trips per turn.
    max_iterations: usize,
    /// Optional per-tool-call timeout.
    tool_timeout: Option<std::time::Duration>,
    /// Context handed to tools; `ToolContext::default()` when unset.
    tool_context: Option<ToolContext>,
    /// Whether the built-in `ask_user` tool is advertised.
    ask_user_enabled: bool,
    /// How long to wait for an `ask_user` answer (30s default once enabled).
    ask_user_timeout: Option<std::time::Duration>,
    /// Channel capacities and backpressure policy.
    channel_config: ChannelConfig,
    /// MCP servers whose tools are merged in at run time.
    #[cfg(feature = "mcp-client")]
    mcp_servers: Vec<Arc<dyn crate::mcp::McpServer>>,
}
277
278impl AgentLoopBuilder {
279 pub fn max_iterations(mut self, n: usize) -> Self {
281 self.max_iterations = n;
282 self
283 }
284
285 pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
287 self.tools.push(tool);
288 self
289 }
290
291 pub fn tools(mut self, tools: impl IntoIterator<Item = Arc<dyn Tool>>) -> Self {
293 self.tools.extend(tools);
294 self
295 }
296
297 pub fn system_prompt(self, prompt: impl Into<String>) -> Self {
303 self.context(crate::context::StringContextProvider(prompt.into()))
304 }
305
306 pub fn context(mut self, provider: impl ContextProvider + 'static) -> Self {
312 self.context_providers.push(Box::new(provider));
313 self
314 }
315
316 pub fn contexts(
321 mut self,
322 providers: impl IntoIterator<Item = Box<dyn ContextProvider>>,
323 ) -> Self {
324 self.context_providers.extend(providers);
325 self
326 }
327
328 #[cfg(feature = "mcp-client")]
332 pub fn mcp_server(mut self, server: impl crate::mcp::McpServer + 'static) -> Self {
333 self.mcp_servers.push(std::sync::Arc::new(server));
334 self
335 }
336
337 #[cfg(feature = "mcp-client")]
345 pub fn mcp_server_arc(mut self, server: std::sync::Arc<dyn crate::mcp::McpServer>) -> Self {
346 self.mcp_servers.push(server);
347 self
348 }
349
350 pub fn tool_timeout(mut self, duration: std::time::Duration) -> Self {
355 self.tool_timeout = Some(duration);
356 self
357 }
358
359 pub fn tool_context(mut self, ctx: ToolContext) -> Self {
365 self.tool_context = Some(ctx);
366 self
367 }
368
369 pub fn channel_config(mut self, config: ChannelConfig) -> Self {
373 self.channel_config = config;
374 self
375 }
376
377 pub fn input_channel_capacity(mut self, capacity: usize) -> Self {
381 self.channel_config.input_capacity = capacity;
382 self
383 }
384
385 pub fn ops_channel_capacity(mut self, capacity: usize) -> Self {
389 self.channel_config.ops_capacity = capacity;
390 self
391 }
392
393 pub fn ops_backpressure(mut self, policy: BackpressurePolicy) -> Self {
397 self.channel_config.ops_backpressure = policy;
398 self
399 }
400
401 pub fn with_ask_user(mut self) -> Self {
403 self.ask_user_enabled = true;
404 if self.ask_user_timeout.is_none() {
405 self.ask_user_timeout = Some(std::time::Duration::from_secs(30));
406 }
407 self
408 }
409
410 pub fn with_ask_user_timeout(mut self, timeout: std::time::Duration) -> Self {
412 self.ask_user_enabled = true;
413 self.ask_user_timeout = Some(timeout);
414 self
415 }
416
417 pub fn build(self) -> AgentLoop {
424 assert!(
425 self.channel_config.input_capacity > 0,
426 "input_capacity must be >= 1"
427 );
428 assert!(
429 self.channel_config.ops_capacity > 0,
430 "ops_capacity must be >= 1"
431 );
432 let tool_map: HashMap<String, Arc<dyn Tool>> = self
433 .tools
434 .iter()
435 .map(|t| (t.def().name.clone(), Arc::clone(t)))
436 .collect();
437 let mut tool_defs: Vec<ToolDef> = self.tools.iter().map(|t| t.def()).collect();
438 if self.ask_user_enabled {
439 tool_defs.push(ask_user_tool_def());
440 }
441 AgentLoop {
442 tool_map,
443 tool_defs,
444 context_providers: self.context_providers,
445 max_iterations: self.max_iterations,
446 tool_timeout: self.tool_timeout,
447 tool_context: self.tool_context.unwrap_or_default(),
448 ask_user_timeout: if self.ask_user_enabled {
449 self.ask_user_timeout
450 } else {
451 None
452 },
453 channel_config: self.channel_config,
454 #[cfg(feature = "mcp-client")]
455 mcp_servers: self.mcp_servers,
456 }
457 }
458}
459
/// Orchestrates the LLM <-> tool conversation loop.
///
/// Built via `AgentLoop::builder()`; one instance can run multiple turns.
pub struct AgentLoop {
    /// Name -> tool lookup used when executing tool calls.
    tool_map: HashMap<String, Arc<dyn Tool>>,
    /// Tool definitions advertised to the LLM (includes `ask_user` when enabled).
    tool_defs: Vec<ToolDef>,
    /// Providers whose output is prepended as system messages each turn.
    context_providers: Vec<Box<dyn ContextProvider>>,
    /// Upper bound on LLM round-trips per turn.
    max_iterations: usize,
    /// Per-tool-call timeout; `None` means wait indefinitely.
    pub(crate) tool_timeout: Option<std::time::Duration>,
    /// Context handed to every tool invocation.
    tool_context: ToolContext,
    /// Wait limit for `ask_user` answers; `None` when the tool is disabled.
    ask_user_timeout: Option<std::time::Duration>,
    /// Channel capacities / backpressure settings.
    channel_config: ChannelConfig,
    /// MCP servers connected at the start of each turn.
    #[cfg(feature = "mcp-client")]
    mcp_servers: Vec<Arc<dyn crate::mcp::McpServer>>,
}
486
487impl AgentLoop {
488 pub fn builder() -> AgentLoopBuilder {
490 AgentLoopBuilder {
491 tools: Vec::new(),
492 context_providers: Vec::new(),
493 max_iterations: 10,
494 tool_timeout: None,
495 tool_context: None,
496 ask_user_enabled: false,
497 ask_user_timeout: None,
498 channel_config: ChannelConfig::default(),
499 #[cfg(feature = "mcp-client")]
500 mcp_servers: Vec::new(),
501 }
502 }
503
504 pub fn channel_config(&self) -> &ChannelConfig {
509 &self.channel_config
510 }
511
512 pub async fn run(
521 &self,
522 llm: &dyn LlmClient,
523 messages: Vec<Message>,
524 on_event: impl Fn(AgentEvent) + Send + Sync,
525 ) -> Result<AgentResult> {
526 #[cfg(feature = "mcp-client")]
528 let mcp_tools: Vec<Arc<dyn Tool>> = self.connect_mcp_servers().await?;
529 #[cfg(not(feature = "mcp-client"))]
530 let mcp_tools: Vec<Arc<dyn Tool>> = vec![];
531
532 let result = self.run_inner(llm, messages, &mcp_tools, &on_event).await;
533
534 #[cfg(feature = "mcp-client")]
536 for server in &self.mcp_servers {
537 let _ = server.disconnect().await;
538 }
539
540 result
541 }
542
543 pub async fn run_with_ops(
547 &self,
548 llm: &dyn LlmClient,
549 messages: Vec<Message>,
550 ops_rx: Option<mpsc::Receiver<AgentOp>>,
551 on_event: impl Fn(AgentEvent) + Send + Sync,
552 ) -> Result<AgentResult> {
553 #[cfg(feature = "mcp-client")]
554 let mcp_tools: Vec<Arc<dyn Tool>> = self.connect_mcp_servers().await?;
555 #[cfg(not(feature = "mcp-client"))]
556 let mcp_tools: Vec<Arc<dyn Tool>> = vec![];
557
558 let result = self
559 .run_inner_with_ops(llm, messages, &mcp_tools, &on_event, ops_rx)
560 .await;
561
562 #[cfg(feature = "mcp-client")]
563 for server in &self.mcp_servers {
564 let _ = server.disconnect().await;
565 }
566
567 result
568 }
569
570 #[cfg(feature = "mcp-client")]
575 async fn connect_mcp_servers(&self) -> Result<Vec<Arc<dyn Tool>>> {
576 use crate::mcp::adapter::McpToolAdapter;
577
578 let mut connected: Vec<&Arc<dyn crate::mcp::McpServer>> = Vec::new();
579 let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
580
581 for server in &self.mcp_servers {
582 match server.connect().await {
583 Ok(()) => {
584 connected.push(server);
585 match McpToolAdapter::from_server(Arc::clone(server)).await {
586 Ok(adapter_tools) => tools.extend(adapter_tools),
587 Err(e) => {
588 for s in &connected {
589 let _ = s.disconnect().await;
590 }
591 return Err(e);
592 }
593 }
594 }
595 Err(e) => {
596 for s in &connected {
597 let _ = s.disconnect().await;
598 }
599 return Err(e);
600 }
601 }
602 }
603
604 Ok(tools)
605 }
606
607 #[cfg(feature = "cancellation")]
619 pub async fn run_with_cancel(
620 &self,
621 llm: &dyn LlmClient,
622 messages: Vec<Message>,
623 cancel: tokio_util::sync::CancellationToken,
624 on_event: impl Fn(AgentEvent) + Send + Sync,
625 ) -> Result<AgentResult> {
626 #[cfg(feature = "mcp-client")]
627 let mcp_tools: Vec<Arc<dyn Tool>> = {
628 use crate::mcp::adapter::McpToolAdapter;
629 let mut tools = Vec::new();
630 for server in &self.mcp_servers {
631 server.connect().await?;
632 tools.extend(McpToolAdapter::from_server(Arc::clone(server)).await?);
633 }
634 tools
635 };
636 #[cfg(not(feature = "mcp-client"))]
637 let mcp_tools: Vec<Arc<dyn Tool>> = vec![];
638
639 let result = self
640 .run_inner_cancel(llm, messages, &mcp_tools, &on_event, &cancel)
641 .await;
642
643 #[cfg(feature = "mcp-client")]
644 for server in &self.mcp_servers {
645 let _ = server.disconnect().await;
646 }
647
648 result
649 }
650
651 #[cfg(feature = "cancellation")]
658 pub async fn run_streaming_with_cancel(
659 &self,
660 llm: &dyn LlmClient,
661 messages: Vec<Message>,
662 cancel: tokio_util::sync::CancellationToken,
663 on_event: impl Fn(AgentEvent) + Send + Sync,
664 ) -> Result<AgentResult> {
665 use futures::StreamExt;
666
667 #[cfg(feature = "mcp-client")]
668 let mcp_tools: Vec<Arc<dyn Tool>> = {
669 use crate::mcp::adapter::McpToolAdapter;
670 let mut tools = Vec::new();
671 for server in &self.mcp_servers {
672 server.connect().await?;
673 tools.extend(McpToolAdapter::from_server(Arc::clone(server)).await?);
674 }
675 tools
676 };
677 #[cfg(not(feature = "mcp-client"))]
678 let mcp_tools: Vec<Arc<dyn Tool>> = vec![];
679
680 let tools = MergedTools::new(&self.tool_map, &self.tool_defs, &mcp_tools);
681 let mut state = TurnState::new(self.prepare_messages(messages).await?);
682
683 let result = async {
684 for iteration in 1..=self.max_iterations {
685 if cancel.is_cancelled() {
687 return Err(AgentError::Cancelled);
688 }
689 on_event(AgentEvent::IterationStarted(iteration));
690
691 let (accumulated_text, response) = {
693 let mut stream = llm.chat_stream(&state.messages, tools.tool_defs());
694 let mut accumulated = String::new();
695 let mut final_response: Option<crate::LlmResponse> = None;
696
697 loop {
698 tokio::select! {
699 chunk_opt = stream.next() => {
700 match chunk_opt {
701 Some(chunk_result) => {
702 let chunk = chunk_result?;
703 match chunk {
704 crate::llm::StreamChunk::TextDelta(delta) => {
705 accumulated.push_str(&delta);
706 on_event(AgentEvent::TextChunk(delta));
707 }
708 crate::llm::StreamChunk::Done(resp) => {
709 final_response = Some(resp);
710 }
711 crate::llm::StreamChunk::Usage(usage) => {
712 state.total_usage.accumulate(usage);
713 }
714 }
715 }
716 None => break,
717 }
718 }
719 _ = cancel.cancelled() => {
720 return Err(AgentError::Cancelled);
721 }
722 }
723 }
724
725 let resp = final_response
726 .unwrap_or_else(|| crate::LlmResponse::Message(accumulated.clone()));
727 (accumulated, resp)
728 };
729
730 match response {
731 crate::LlmResponse::Message(text) => {
732 if accumulated_text.is_empty() && !text.is_empty() {
733 on_event(AgentEvent::TextChunk(text.clone()));
734 }
735 on_event(AgentEvent::TextDone(text.clone()));
736 return Ok(state.into_result(text, iteration));
737 }
738 crate::LlmResponse::ToolCalls(items) => {
739 emit_tool_started(&items, &on_event);
740 let results = Self::execute_tools_parallel(
741 tools.tool_map(),
742 &items,
743 self.tool_timeout,
744 &self.tool_context,
745 )
746 .await;
747 execute_and_record_tool_calls(&items, results, &mut state, &on_event);
748 }
749 }
750 }
751
752 Err(AgentError::MaxIterations(self.max_iterations))
753 }
754 .await;
755
756 #[cfg(feature = "mcp-client")]
757 for server in &self.mcp_servers {
758 let _ = server.disconnect().await;
759 }
760
761 result
762 }
763
764 async fn consume_stream(
768 llm: &dyn LlmClient,
769 messages: &[Message],
770 tool_defs: &[ToolDef],
771 total_usage: &mut TokenUsage,
772 on_event: &(impl Fn(AgentEvent) + Send + Sync),
773 ) -> Result<(String, crate::LlmResponse)> {
774 use futures::StreamExt;
775
776 let mut stream = llm.chat_stream(messages, tool_defs);
777 let mut accumulated = String::new();
778 let mut final_response: Option<crate::LlmResponse> = None;
779
780 while let Some(chunk_result) = stream.next().await {
781 let chunk = chunk_result?;
782 match chunk {
783 crate::llm::StreamChunk::TextDelta(delta) => {
784 accumulated.push_str(&delta);
785 on_event(AgentEvent::TextChunk(delta));
786 }
787 crate::llm::StreamChunk::Done(response) => {
788 final_response = Some(response);
789 }
790 crate::llm::StreamChunk::Usage(usage) => {
791 total_usage.accumulate(usage);
792 }
793 }
794 }
795
796 let response =
797 final_response.unwrap_or_else(|| crate::LlmResponse::Message(accumulated.clone()));
798 Ok((accumulated, response))
799 }
800
801 async fn prepare_messages(&self, mut messages: Vec<Message>) -> Result<Vec<Message>> {
807 if self.context_providers.is_empty() {
808 return Ok(messages);
809 }
810 let query: String = messages
811 .iter()
812 .rev()
813 .find(|m| m.role == crate::message::Role::User)
814 .map(|m| m.content.clone())
815 .unwrap_or_default();
816
817 let contexts = try_join_all(self.context_providers.iter().map(|p| p.build(&query))).await?;
818
819 let mut insert_idx = 0;
820 for ctx in contexts {
821 if !ctx.is_empty() {
822 messages.insert(insert_idx, Message::system(&ctx));
823 insert_idx += 1;
824 }
825 }
826 Ok(messages)
827 }
828
829 #[cfg(feature = "cancellation")]
830 async fn run_inner_cancel(
831 &self,
832 llm: &dyn LlmClient,
833 messages: Vec<Message>,
834 extra_tools: &[Arc<dyn Tool>],
835 on_event: &(impl Fn(AgentEvent) + Send + Sync),
836 cancel: &tokio_util::sync::CancellationToken,
837 ) -> Result<AgentResult> {
838 let tools = MergedTools::new(&self.tool_map, &self.tool_defs, extra_tools);
839 let mut state = TurnState::new(self.prepare_messages(messages).await?);
840
841 for iteration in 1..=self.max_iterations {
842 if cancel.is_cancelled() {
844 return Err(AgentError::Cancelled);
845 }
846 on_event(AgentEvent::IterationStarted(iteration));
847
848 let output = tokio::select! {
850 output = llm.chat(&state.messages, tools.tool_defs()) => output?,
851 _ = cancel.cancelled() => return Err(AgentError::Cancelled),
852 };
853 state.accumulate_usage(output.usage);
854
855 match output.response {
856 crate::LlmResponse::Message(text) => {
857 on_event(AgentEvent::TextChunk(text.clone()));
858 return Ok(state.into_result(text, iteration));
859 }
860 crate::LlmResponse::ToolCalls(items) => {
861 emit_tool_started(&items, on_event);
862 let results = Self::execute_tools_parallel(
863 tools.tool_map(),
864 &items,
865 self.tool_timeout,
866 &self.tool_context,
867 )
868 .await;
869 execute_and_record_tool_calls(&items, results, &mut state, on_event);
870 }
871 }
872 }
873
874 Err(AgentError::MaxIterations(self.max_iterations))
875 }
876
877 async fn run_inner_with_ops(
878 &self,
879 llm: &dyn LlmClient,
880 messages: Vec<Message>,
881 extra_tools: &[Arc<dyn Tool>],
882 on_event: &(impl Fn(AgentEvent) + Send + Sync),
883 mut ops_rx: Option<mpsc::Receiver<AgentOp>>,
884 ) -> Result<AgentResult> {
885 let tools = MergedTools::new(&self.tool_map, &self.tool_defs, extra_tools);
886 let mut state = TurnState::new(self.prepare_messages(messages).await?);
887 let mut ops_state = OpsState::default();
888
889 for iteration in 1..=self.max_iterations {
890 Self::drain_ops(&mut state.messages, &mut ops_rx, &mut ops_state);
892 if ops_state.interrupted {
893 on_event(AgentEvent::Interrupted);
894 return Ok(
895 state.into_result("(interrupted)".to_string(), iteration.saturating_sub(1))
896 );
897 }
898
899 on_event(AgentEvent::IterationStarted(iteration));
900
901 let output = llm.chat(&state.messages, tools.tool_defs()).await?;
903 state.accumulate_usage(output.usage);
904
905 match output.response {
906 crate::LlmResponse::Message(text) => {
907 on_event(AgentEvent::TextChunk(text.clone()));
908 state.messages.push(Message::assistant(&text));
909 return Ok(state.into_result(text, iteration));
910 }
911 crate::LlmResponse::ToolCalls(items) => {
912 emit_tool_started(&items, on_event);
914 let results = self
915 .execute_tools_with_policy(
916 tools.tool_map(),
917 &items,
918 &mut state.messages,
919 &mut ops_rx,
920 &mut ops_state,
921 on_event,
922 )
923 .await;
924 execute_and_record_tool_calls(&items, results, &mut state, on_event);
925 }
926 }
927 }
928
929 Err(AgentError::MaxIterations(self.max_iterations))
930 }
931
932 async fn wait_for_ask_user_answer(
933 &self,
934 call_id: &str,
935 question: &str,
936 messages: &mut Vec<Message>,
937 ops_rx: &mut Option<mpsc::Receiver<AgentOp>>,
938 ops_state: &mut OpsState,
939 on_event: &(impl Fn(AgentEvent) + Send + Sync),
940 ) -> String {
941 if let Some(answer) = pop_matching_answer(&mut ops_state.pending_answers, call_id) {
942 return answer;
943 }
944
945 let timeout = self.ask_user_timeout;
946 let started = tokio::time::Instant::now();
947
948 loop {
949 let next_op = if let Some(rx) = ops_rx.as_mut() {
950 if let Some(limit) = timeout {
951 let elapsed = started.elapsed();
952 if elapsed >= limit {
953 on_event(AgentEvent::AskUserTimeout {
954 call_id: call_id.to_string(),
955 question: question.to_string(),
956 });
957 return String::new();
958 }
959 let remaining = limit - elapsed;
960 match tokio::time::timeout(remaining, rx.recv()).await {
961 Ok(op) => op,
962 Err(_) => {
963 on_event(AgentEvent::AskUserTimeout {
964 call_id: call_id.to_string(),
965 question: question.to_string(),
966 });
967 return String::new();
968 }
969 }
970 } else {
971 rx.recv().await
972 }
973 } else {
974 on_event(AgentEvent::AskUserTimeout {
975 call_id: call_id.to_string(),
976 question: question.to_string(),
977 });
978 return String::new();
979 };
980
981 let Some(op) = next_op else {
982 on_event(AgentEvent::AskUserTimeout {
983 call_id: call_id.to_string(),
984 question: question.to_string(),
985 });
986 return String::new();
987 };
988 Self::apply_op(op, messages, ops_state);
989 if ops_state.interrupted {
990 return String::new();
991 }
992
993 if let Some(answer) = pop_matching_answer(&mut ops_state.pending_answers, call_id) {
994 return answer;
995 }
996 }
997 }
998
999 async fn run_inner(
1000 &self,
1001 llm: &dyn LlmClient,
1002 messages: Vec<Message>,
1003 extra_tools: &[Arc<dyn Tool>],
1004 on_event: &(impl Fn(AgentEvent) + Send + Sync),
1005 ) -> Result<AgentResult> {
1006 let tools = MergedTools::new(&self.tool_map, &self.tool_defs, extra_tools);
1007 let mut state = TurnState::new(self.prepare_messages(messages).await?);
1008
1009 for iteration in 1..=self.max_iterations {
1010 on_event(AgentEvent::IterationStarted(iteration));
1011
1012 let output = llm.chat(&state.messages, tools.tool_defs()).await?;
1014 state.accumulate_usage(output.usage);
1015
1016 match output.response {
1017 crate::LlmResponse::Message(text) => {
1018 on_event(AgentEvent::TextChunk(text.clone()));
1019 state.messages.push(Message::assistant(&text));
1020 return Ok(state.into_result(text, iteration));
1021 }
1022 crate::LlmResponse::ToolCalls(items) => {
1023 emit_tool_started(&items, on_event);
1025 let results = Self::execute_tools_parallel(
1026 tools.tool_map(),
1027 &items,
1028 self.tool_timeout,
1029 &self.tool_context,
1030 )
1031 .await;
1032 execute_and_record_tool_calls(&items, results, &mut state, on_event);
1033 }
1034 }
1035 }
1036
1037 Err(AgentError::MaxIterations(self.max_iterations))
1038 }
1039
1040 pub async fn run_streaming(
1049 &self,
1050 llm: &dyn LlmClient,
1051 messages: Vec<Message>,
1052 on_event: impl Fn(AgentEvent) + Send + Sync,
1053 ) -> Result<AgentResult> {
1054 #[cfg(feature = "mcp-client")]
1056 let mcp_tools: Vec<Arc<dyn Tool>> = self.connect_mcp_servers().await?;
1057 #[cfg(not(feature = "mcp-client"))]
1058 let mcp_tools: Vec<Arc<dyn Tool>> = vec![];
1059
1060 let tools = MergedTools::new(&self.tool_map, &self.tool_defs, &mcp_tools);
1061 let mut state = TurnState::new(self.prepare_messages(messages).await?);
1062
1063 let result = async {
1064 for iteration in 1..=self.max_iterations {
1065 on_event(AgentEvent::IterationStarted(iteration));
1066
1067 let (accumulated_text, response) = Self::consume_stream(
1069 llm,
1070 &state.messages,
1071 tools.tool_defs(),
1072 &mut state.total_usage,
1073 &on_event,
1074 )
1075 .await?;
1076
1077 match response {
1078 crate::LlmResponse::Message(text) => {
1079 if accumulated_text.is_empty() && !text.is_empty() {
1080 on_event(AgentEvent::TextChunk(text.clone()));
1081 }
1082 on_event(AgentEvent::TextDone(text.clone()));
1083 state.messages.push(Message::assistant(&text));
1084 return Ok(state.into_result(text, iteration));
1085 }
1086 crate::LlmResponse::ToolCalls(items) => {
1087 emit_tool_started(&items, &on_event);
1088 let results = Self::execute_tools_parallel(
1089 tools.tool_map(),
1090 &items,
1091 self.tool_timeout,
1092 &self.tool_context,
1093 )
1094 .await;
1095 execute_and_record_tool_calls(&items, results, &mut state, &on_event);
1096 }
1097 }
1098 }
1099
1100 Err(AgentError::MaxIterations(self.max_iterations))
1101 }
1102 .await;
1103
1104 #[cfg(feature = "mcp-client")]
1106 for server in &self.mcp_servers {
1107 let _ = server.disconnect().await;
1108 }
1109
1110 result
1111 }
1112
1113 fn drain_ops(
1114 messages: &mut Vec<Message>,
1115 ops_rx: &mut Option<mpsc::Receiver<AgentOp>>,
1116 ops_state: &mut OpsState,
1117 ) {
1118 if let Some(rx) = ops_rx.as_mut() {
1119 while let Ok(op) = rx.try_recv() {
1120 Self::apply_op(op, messages, ops_state);
1121 }
1122 }
1123 }
1124
1125 fn apply_op(op: AgentOp, messages: &mut Vec<Message>, ops_state: &mut OpsState) {
1126 match op {
1127 AgentOp::Interrupt => {
1128 ops_state.interrupted = true;
1129 }
1130 AgentOp::InjectUserMessage(text) => {
1131 messages.push(Message::user(&text));
1132 }
1133 AgentOp::InjectHint(hint) => {
1134 messages.push(Message::user(&format!("[Note: {hint}]")));
1135 }
1136 AgentOp::AskUserAnswer { call_id, answer } => {
1137 ops_state
1138 .pending_answers
1139 .push(PendingAskUserAnswer { call_id, answer });
1140 }
1141 }
1142 }
1143
1144 async fn execute_tools_with_policy(
1145 &self,
1146 tool_map: &HashMap<String, Arc<dyn Tool>>,
1147 items: &[crate::llm::ToolCallItem],
1148 messages: &mut Vec<Message>,
1149 ops_rx: &mut Option<mpsc::Receiver<AgentOp>>,
1150 ops_state: &mut OpsState,
1151 on_event: &(impl Fn(AgentEvent) + Send + Sync),
1152 ) -> Vec<ToolResult> {
1153 let policy = ToolExecutionPolicy::from_items(items);
1154 match policy {
1155 ToolExecutionPolicy::ParallelOnly => {
1156 Self::execute_tools_parallel(tool_map, items, self.tool_timeout, &self.tool_context)
1157 .await
1158 }
1159 ToolExecutionPolicy::InteractiveAskUser => {
1160 let non_ask_future = join_all(
1161 items
1162 .iter()
1163 .enumerate()
1164 .filter(|(_, tc)| tc.name != "ask_user")
1165 .map(|(idx, tc)| async move {
1166 (
1167 idx,
1168 Self::execute_tool(
1169 tool_map,
1170 &tc.name,
1171 tc.args.clone(),
1172 self.tool_timeout,
1173 &self.tool_context,
1174 )
1175 .await,
1176 )
1177 }),
1178 );
1179
1180 let ask_future = async {
1181 let mut ask_results = Vec::new();
1182 for (idx, tc) in items
1183 .iter()
1184 .enumerate()
1185 .filter(|(_, tc)| tc.name == "ask_user")
1186 {
1187 let (question, options) = parse_ask_user_args(&tc.args);
1188 on_event(AgentEvent::AskUser {
1189 call_id: tc.id.clone(),
1190 question: question.clone(),
1191 options,
1192 });
1193
1194 let answer = self
1195 .wait_for_ask_user_answer(
1196 &tc.id, &question, messages, ops_rx, ops_state, on_event,
1197 )
1198 .await;
1199 ask_results.push((idx, ToolResult::text(answer)));
1200 }
1201 ask_results
1202 };
1203
1204 let (non_ask_results, ask_results) = futures::join!(non_ask_future, ask_future);
1205 let mut merged: Vec<Option<ToolResult>> = vec![None; items.len()];
1206 for (idx, result) in non_ask_results.into_iter().chain(ask_results) {
1207 merged[idx] = Some(result);
1208 }
1209 merged
1210 .into_iter()
1211 .map(|result| result.expect("tool result must be present"))
1212 .collect()
1213 }
1214 }
1215 }
1216
1217 async fn execute_tools_parallel(
1218 tool_map: &HashMap<String, Arc<dyn Tool>>,
1219 items: &[crate::llm::ToolCallItem],
1220 timeout: Option<std::time::Duration>,
1221 ctx: &ToolContext,
1222 ) -> Vec<ToolResult> {
1223 join_all(
1224 items
1225 .iter()
1226 .map(|tc| Self::execute_tool(tool_map, &tc.name, tc.args.clone(), timeout, ctx)),
1227 )
1228 .await
1229 }
1230
1231 async fn execute_tool(
1234 tool_map: &HashMap<String, Arc<dyn Tool>>,
1235 name: &str,
1236 args: serde_json::Value,
1237 timeout: Option<std::time::Duration>,
1238 ctx: &ToolContext,
1239 ) -> ToolResult {
1240 let fut = async {
1241 if let Some(tool) = tool_map.get(name) {
1242 tool.call(args, &ctx).await
1243 } else {
1244 ToolResult::error(format!("unknown tool: {name}"))
1245 }
1246 };
1247 if let Some(dur) = timeout {
1248 match tokio::time::timeout(dur, fut).await {
1249 Ok(result) => result,
1250 Err(_) => ToolResult::error(format!("tool '{name}' timed out after {dur:?}")),
1251 }
1252 } else {
1253 fut.await
1254 }
1255 }
1256}
1257
/// How a batch of tool calls should be executed.
#[derive(Debug, Clone, Copy)]
enum ToolExecutionPolicy {
    /// No `ask_user` present: run everything concurrently.
    ParallelOnly,
    /// At least one `ask_user` call: run it interactively alongside the rest.
    InteractiveAskUser,
}
1263
1264impl ToolExecutionPolicy {
1265 fn from_items(items: &[crate::llm::ToolCallItem]) -> Self {
1266 if items.iter().any(|tc| tc.name == "ask_user") {
1267 Self::InteractiveAskUser
1268 } else {
1269 Self::ParallelOnly
1270 }
1271 }
1272}
1273
/// Bookkeeping for ops received during a turn.
#[derive(Default)]
struct OpsState {
    // Set once an `AgentOp::Interrupt` is applied; checked at loop checkpoints.
    interrupted: bool,
    // Answers received before (or while) their `ask_user` question is pending.
    pending_answers: Vec<PendingAskUserAnswer>,
}
1279
/// A queued `ask_user` answer; `call_id: None` matches any pending question.
struct PendingAskUserAnswer {
    call_id: Option<String>,
    answer: String,
}
1284
1285fn pop_matching_answer(pending: &mut Vec<PendingAskUserAnswer>, call_id: &str) -> Option<String> {
1286 let pos = pending
1287 .iter()
1288 .position(|item| item.call_id.as_deref() == Some(call_id) || item.call_id.is_none())?;
1289 Some(pending.remove(pos).answer)
1290}
1291
1292fn parse_ask_user_args(args: &serde_json::Value) -> (String, Vec<String>) {
1293 let question = args
1294 .get("question")
1295 .and_then(|value| value.as_str())
1296 .unwrap_or("")
1297 .to_string();
1298 let options = args
1299 .get("options")
1300 .and_then(|value| value.as_array())
1301 .map(|list| {
1302 list.iter()
1303 .filter_map(|entry| entry.as_str().map(ToString::to_string))
1304 .collect::<Vec<_>>()
1305 })
1306 .unwrap_or_default();
1307 (question, options)
1308}
1309
1310fn ask_user_tool_def() -> ToolDef {
1311 ToolDef {
1312 name: "ask_user".to_string(),
1313 description: "Ask the user a question and wait for a reply.".to_string(),
1314 input_schema: serde_json::json!({
1315 "type": "object",
1316 "properties": {
1317 "question": { "type": "string" },
1318 "options": {
1319 "type": "array",
1320 "items": { "type": "string" }
1321 }
1322 },
1323 "required": ["question"]
1324 }),
1325 }
1326}
1327
1328fn tool_result_to_string(result: &ToolResult) -> String {
1330 match result.as_text() {
1331 Some(text) => text.to_string(),
1332 None => {
1333 serde_json::to_string(&result.content).unwrap_or_else(|_| "<no content>".to_string())
1335 }
1336 }
1337}
1338
1339#[cfg(test)]
1340mod tests {
1341 use super::*;
1342 use crate::context::ContextProvider;
1343 use crate::llm::{ChatOutput, LlmResponse, TokenUsage, ToolCallItem};
1344 use async_trait::async_trait;
1345 use std::sync::{Arc, Mutex};
1346
    /// Scripted LLM: returns queued responses in order, panics when empty.
    struct MockLlm {
        // FIFO queue of responses handed out by `chat`.
        responses: Mutex<Vec<LlmResponse>>,
        // Usage reported with every response, if any.
        usage_per_call: Option<TokenUsage>,
    }
1356
1357 impl MockLlm {
1358 fn new(responses: Vec<LlmResponse>) -> Self {
1359 Self {
1360 responses: Mutex::new(responses),
1361 usage_per_call: None,
1362 }
1363 }
1364
1365 fn with_usage(responses: Vec<LlmResponse>, usage: TokenUsage) -> Self {
1366 Self {
1367 responses: Mutex::new(responses),
1368 usage_per_call: Some(usage),
1369 }
1370 }
1371 }
1372
1373 #[async_trait]
1374 impl LlmClient for MockLlm {
1375 async fn chat(
1376 &self,
1377 _messages: &[Message],
1378 _tools: &[ToolDef],
1379 ) -> crate::Result<ChatOutput> {
1380 let mut responses = self.responses.lock().unwrap();
1381 if responses.is_empty() {
1382 panic!("MockLlm: no more responses");
1383 }
1384 let response = responses.remove(0);
1385 Ok(ChatOutput {
1386 response,
1387 usage: self.usage_per_call,
1388 })
1389 }
1390 }
1391
    /// Tool stub that always returns a fixed text result.
    struct MockTool {
        // Tool name advertised in its definition.
        name: String,
        // Canned text returned by every call.
        result_text: String,
    }
1400
1401 impl MockTool {
1402 fn new(name: &str, result_text: &str) -> Self {
1403 Self {
1404 name: name.to_string(),
1405 result_text: result_text.to_string(),
1406 }
1407 }
1408 }
1409
1410 impl Tool for MockTool {
1411 fn def(&self) -> ToolDef {
1412 ToolDef {
1413 name: self.name.clone(),
1414 description: format!("Mock tool: {}", self.name),
1415 input_schema: serde_json::json!({
1416 "type": "object",
1417 "properties": {
1418 "input": { "type": "string" }
1419 },
1420 "required": []
1421 }),
1422 }
1423 }
1424
1425 fn call(
1426 &self,
1427 _args: serde_json::Value,
1428 _ctx: &ToolContext,
1429 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolResult> + Send + '_>> {
1430 let text = self.result_text.clone();
1431 Box::pin(async move { ToolResult::text(text) })
1432 }
1433 }
1434
1435 #[tokio::test]
1440 async fn direct_message_response() {
1441 let llm = MockLlm::new(vec![LlmResponse::Message("Hello!".to_string())]);
1442
1443 let agent = AgentLoop::builder().build();
1444 let result = agent
1445 .run(&llm, vec![Message::user("Hi")], |_| {})
1446 .await
1447 .unwrap();
1448
1449 assert_eq!(result.answer, "Hello!");
1450 assert!(result.tool_calls.is_empty());
1451 assert_eq!(result.iterations, 1);
1452 }
1453
1454 #[tokio::test]
1455 async fn single_tool_call_then_message() {
1456 let llm = MockLlm::new(vec![
1457 LlmResponse::single_tool_call(
1458 "call_1".to_string(),
1459 "search".to_string(),
1460 serde_json::json!({"input": "rust"}),
1461 ),
1462 LlmResponse::Message("Found results about Rust.".to_string()),
1463 ]);
1464
1465 let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "Rust is a systems language"));
1466
1467 let agent = AgentLoop::builder().tool(tool).build();
1468 let result = agent
1469 .run(&llm, vec![Message::user("Search for rust")], |_| {})
1470 .await
1471 .unwrap();
1472
1473 assert_eq!(result.answer, "Found results about Rust.");
1474 assert_eq!(result.tool_calls.len(), 1);
1475 assert_eq!(result.tool_calls[0].0, "search");
1476 assert_eq!(result.iterations, 2);
1477 }
1478
1479 #[tokio::test]
1480 async fn parallel_tool_calls() {
1481 let llm = MockLlm::new(vec![
1482 LlmResponse::ToolCalls(vec![
1483 ToolCallItem {
1484 id: "call_1".to_string(),
1485 name: "search".to_string(),
1486 args: serde_json::json!({"input": "a"}),
1487 },
1488 ToolCallItem {
1489 id: "call_2".to_string(),
1490 name: "fetch".to_string(),
1491 args: serde_json::json!({"input": "b"}),
1492 },
1493 ]),
1494 LlmResponse::Message("Combined results.".to_string()),
1495 ]);
1496
1497 let search: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result_a"));
1498 let fetch: Arc<dyn Tool> = Arc::new(MockTool::new("fetch", "result_b"));
1499
1500 let agent = AgentLoop::builder().tool(search).tool(fetch).build();
1501 let result = agent
1502 .run(&llm, vec![Message::user("Do both")], |_| {})
1503 .await
1504 .unwrap();
1505
1506 assert_eq!(result.answer, "Combined results.");
1507 assert_eq!(result.tool_calls.len(), 2);
1508 assert_eq!(result.tool_calls[0].0, "search");
1509 assert_eq!(result.tool_calls[1].0, "fetch");
1510 assert_eq!(result.iterations, 2);
1511 }
1512
1513 #[tokio::test]
1514 async fn max_iterations_exceeded() {
1515 let responses: Vec<LlmResponse> = (0..5)
1517 .map(|i| {
1518 LlmResponse::single_tool_call(
1519 format!("call_{i}"),
1520 "search".to_string(),
1521 serde_json::json!({}),
1522 )
1523 })
1524 .collect();
1525 let llm = MockLlm::new(responses);
1526
1527 let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result"));
1528
1529 let agent = AgentLoop::builder().tool(tool).max_iterations(3).build();
1530 let err = agent
1531 .run(&llm, vec![Message::user("loop forever")], |_| {})
1532 .await
1533 .unwrap_err();
1534
1535 assert!(matches!(err, AgentError::MaxIterations(3)));
1536 }
1537
1538 #[tokio::test]
1539 async fn events_emitted_correctly() {
1540 let llm = MockLlm::new(vec![
1541 LlmResponse::single_tool_call(
1542 "call_1".to_string(),
1543 "search".to_string(),
1544 serde_json::json!({}),
1545 ),
1546 LlmResponse::Message("Done.".to_string()),
1547 ]);
1548
1549 let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result"));
1550 let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
1551 let events_clone = events.clone();
1552
1553 let agent = AgentLoop::builder().tool(tool).build();
1554 let _ = agent
1555 .run(&llm, vec![Message::user("test")], move |event| {
1556 let label = match &event {
1557 AgentEvent::ToolStarted { name } => format!("started:{name}"),
1558 AgentEvent::ToolCompleted { name, .. } => format!("completed:{name}"),
1559 AgentEvent::TextChunk(t) => format!("text:{t}"),
1560 AgentEvent::TextDone(t) => format!("done:{t}"),
1561 AgentEvent::IterationStarted(n) => format!("iter:{n}"),
1562 AgentEvent::Interrupted => "interrupted".to_string(),
1563 AgentEvent::AskUser { question, .. } => format!("ask_user:{question}"),
1564 AgentEvent::AskUserTimeout { call_id, .. } => {
1565 format!("ask_user_timeout:{call_id}")
1566 }
1567 AgentEvent::OpsSaturated { capacity } => {
1568 format!("ops_saturated:{capacity}")
1569 }
1570 AgentEvent::OpDropped { reason } => format!("op_dropped:{reason}"),
1571 AgentEvent::OpRejected { reason } => format!("op_rejected:{reason}"),
1572 };
1573 events_clone.lock().unwrap().push(label);
1574 })
1575 .await
1576 .unwrap();
1577
1578 let events = events.lock().unwrap();
1579 assert_eq!(
1580 *events,
1581 vec![
1582 "iter:1",
1583 "started:search",
1584 "completed:search",
1585 "iter:2",
1586 "text:Done.",
1587 ]
1588 );
1589 }
1590
1591 #[tokio::test]
1592 async fn unknown_tool_produces_error_result() {
1593 let llm = MockLlm::new(vec![
1594 LlmResponse::single_tool_call(
1595 "call_1".to_string(),
1596 "nonexistent".to_string(),
1597 serde_json::json!({}),
1598 ),
1599 LlmResponse::Message("Handled error.".to_string()),
1600 ]);
1601
1602 let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
1603 let events_clone = events.clone();
1604
1605 let agent = AgentLoop::builder().build();
1606 let result = agent
1607 .run(&llm, vec![Message::user("call missing")], move |event| {
1608 if let AgentEvent::ToolCompleted { result, .. } = &event {
1609 if result.is_error {
1610 events_clone
1611 .lock()
1612 .unwrap()
1613 .push("error_tool_result".to_string());
1614 }
1615 }
1616 })
1617 .await
1618 .unwrap();
1619
1620 assert_eq!(result.answer, "Handled error.");
1621 let events = events.lock().unwrap();
1622 assert!(events.contains(&"error_tool_result".to_string()));
1623 }
1624
1625 #[tokio::test]
1626 async fn builder_tools_method() {
1627 let tools: Vec<Arc<dyn Tool>> = vec![
1628 Arc::new(MockTool::new("a", "ra")),
1629 Arc::new(MockTool::new("b", "rb")),
1630 ];
1631
1632 let llm = MockLlm::new(vec![LlmResponse::Message("ok".to_string())]);
1633
1634 let agent = AgentLoop::builder().tools(tools).build();
1635 let result = agent
1636 .run(&llm, vec![Message::user("hi")], |_| {})
1637 .await
1638 .unwrap();
1639
1640 assert_eq!(result.answer, "ok");
1641 }
1642
1643 #[tokio::test]
1644 async fn noop_event_callback() {
1645 let llm = MockLlm::new(vec![LlmResponse::Message("ok".to_string())]);
1646 let agent = AgentLoop::builder().build();
1647 let result = agent
1649 .run(&llm, vec![Message::user("hi")], |_| {})
1650 .await
1651 .unwrap();
1652 assert_eq!(result.answer, "ok");
1653 }
1654
1655 struct MockContextProvider {
1660 context: String,
1661 }
1662
1663 impl MockContextProvider {
1664 fn new(context: &str) -> Self {
1665 Self {
1666 context: context.to_string(),
1667 }
1668 }
1669 }
1670
1671 #[async_trait]
1672 impl ContextProvider for MockContextProvider {
1673 async fn build(&self, _query: &str) -> crate::Result<String> {
1674 Ok(self.context.clone())
1675 }
1676 }
1677
1678 struct EchoContextProvider;
1681
1682 #[async_trait]
1683 impl ContextProvider for EchoContextProvider {
1684 async fn build(&self, query: &str) -> crate::Result<String> {
1685 Ok(format!("echo: {query}"))
1686 }
1687 }
1688
1689 struct CapturingLlm {
1696 captured: Mutex<Vec<Vec<Message>>>,
1697 responses: Mutex<Vec<LlmResponse>>,
1698 }
1699
1700 impl CapturingLlm {
1701 fn new(responses: Vec<LlmResponse>) -> Self {
1702 Self {
1703 captured: Mutex::new(Vec::new()),
1704 responses: Mutex::new(responses),
1705 }
1706 }
1707
1708 fn captured_messages(&self) -> Vec<Vec<Message>> {
1709 self.captured.lock().unwrap().clone()
1710 }
1711 }
1712
1713 #[async_trait]
1714 impl LlmClient for CapturingLlm {
1715 async fn chat(
1716 &self,
1717 messages: &[Message],
1718 _tools: &[motosan_agent_tool::ToolDef],
1719 ) -> crate::Result<ChatOutput> {
1720 self.captured.lock().unwrap().push(messages.to_vec());
1721 let mut responses = self.responses.lock().unwrap();
1722 if responses.is_empty() {
1723 panic!("CapturingLlm: no more responses");
1724 }
1725 Ok(ChatOutput::new(responses.remove(0)))
1726 }
1727 }
1728
1729 #[tokio::test]
1730 async fn context_provider_injects_system_message() {
1731 let llm = CapturingLlm::new(vec![LlmResponse::Message("answer".to_string())]);
1732
1733 let agent = AgentLoop::builder()
1734 .context(MockContextProvider::new("You have access to RAG docs."))
1735 .build();
1736
1737 let result = agent
1738 .run(&llm, vec![Message::user("tell me about rust")], |_| {})
1739 .await
1740 .unwrap();
1741
1742 assert_eq!(result.answer, "answer");
1743
1744 let calls = llm.captured_messages();
1747 assert_eq!(calls.len(), 1);
1748 let msgs = &calls[0];
1749 assert_eq!(msgs.len(), 2);
1751 assert_eq!(msgs[0].role, crate::message::Role::System);
1752 assert_eq!(msgs[0].content, "You have access to RAG docs.");
1753 assert_eq!(msgs[1].role, crate::message::Role::User);
1754 assert_eq!(msgs[1].content, "tell me about rust");
1755 }
1756
1757 #[tokio::test]
1758 async fn empty_context_is_skipped() {
1759 let llm = CapturingLlm::new(vec![LlmResponse::Message("answer".to_string())]);
1760
1761 let agent = AgentLoop::builder()
1762 .context(MockContextProvider::new("")) .build();
1764
1765 let result = agent
1766 .run(&llm, vec![Message::user("hi")], |_| {})
1767 .await
1768 .unwrap();
1769
1770 assert_eq!(result.answer, "answer");
1771
1772 let calls = llm.captured_messages();
1773 assert_eq!(calls.len(), 1);
1774 let msgs = &calls[0];
1775 assert_eq!(msgs.len(), 1);
1777 assert_eq!(msgs[0].role, crate::message::Role::User);
1778 }
1779
1780 #[tokio::test]
1781 async fn multiple_context_providers() {
1782 let llm = CapturingLlm::new(vec![LlmResponse::Message("done".to_string())]);
1783
1784 let agent = AgentLoop::builder()
1785 .context(MockContextProvider::new("RAG context here"))
1786 .context(MockContextProvider::new("")) .context(MockContextProvider::new("User profile: premium"))
1788 .build();
1789
1790 let result = agent
1791 .run(&llm, vec![Message::user("query")], |_| {})
1792 .await
1793 .unwrap();
1794
1795 assert_eq!(result.answer, "done");
1796
1797 let calls = llm.captured_messages();
1798 assert_eq!(calls.len(), 1);
1799 let msgs = &calls[0];
1800 assert_eq!(msgs.len(), 3);
1802 assert_eq!(msgs[0].role, crate::message::Role::System);
1803 assert_eq!(msgs[0].content, "RAG context here");
1804 assert_eq!(msgs[1].role, crate::message::Role::System);
1805 assert_eq!(msgs[1].content, "User profile: premium");
1806 assert_eq!(msgs[2].role, crate::message::Role::User);
1807 }
1808
1809 #[tokio::test]
1810 async fn context_provider_receives_user_query() {
1811 let llm = CapturingLlm::new(vec![LlmResponse::Message("ok".to_string())]);
1812
1813 let agent = AgentLoop::builder().context(EchoContextProvider).build();
1814
1815 let result = agent
1816 .run(&llm, vec![Message::user("my question")], |_| {})
1817 .await
1818 .unwrap();
1819
1820 assert_eq!(result.answer, "ok");
1821
1822 let calls = llm.captured_messages();
1823 let msgs = &calls[0];
1824 assert_eq!(msgs.len(), 2);
1825 assert_eq!(msgs[0].role, crate::message::Role::System);
1826 assert_eq!(msgs[0].content, "echo: my question");
1827 }
1828
1829 #[tokio::test]
1830 async fn no_context_providers_leaves_messages_unchanged() {
1831 let llm = CapturingLlm::new(vec![LlmResponse::Message("ok".to_string())]);
1832
1833 let agent = AgentLoop::builder().build(); let _ = agent
1836 .run(&llm, vec![Message::user("hi")], |_| {})
1837 .await
1838 .unwrap();
1839
1840 let calls = llm.captured_messages();
1841 let msgs = &calls[0];
1842 assert_eq!(msgs.len(), 1);
1843 assert_eq!(msgs[0].content, "hi");
1844 }
1845
1846 #[tokio::test]
1847 async fn builder_contexts_batch_method() {
1848 let llm = CapturingLlm::new(vec![LlmResponse::Message("ok".to_string())]);
1849
1850 let providers: Vec<Box<dyn ContextProvider>> = vec![
1851 Box::new(MockContextProvider::new("ctx-a")),
1852 Box::new(MockContextProvider::new("ctx-b")),
1853 ];
1854
1855 let agent = AgentLoop::builder().contexts(providers).build();
1856 let result = agent
1857 .run(&llm, vec![Message::user("hi")], |_| {})
1858 .await
1859 .unwrap();
1860
1861 assert_eq!(result.answer, "ok");
1862
1863 let calls = llm.captured_messages();
1864 let msgs = &calls[0];
1865 assert_eq!(msgs.len(), 3);
1867 assert_eq!(msgs[0].role, crate::message::Role::System);
1868 assert_eq!(msgs[0].content, "ctx-a");
1869 assert_eq!(msgs[1].role, crate::message::Role::System);
1870 assert_eq!(msgs[1].content, "ctx-b");
1871 assert_eq!(msgs[2].role, crate::message::Role::User);
1872 assert_eq!(msgs[2].content, "hi");
1873 }
1874
1875 struct DelayContextProvider {
1878 context: String,
1879 delay: std::time::Duration,
1880 }
1881
1882 impl DelayContextProvider {
1883 fn new(context: &str, delay: std::time::Duration) -> Self {
1884 Self {
1885 context: context.to_string(),
1886 delay,
1887 }
1888 }
1889 }
1890
1891 #[async_trait]
1892 impl ContextProvider for DelayContextProvider {
1893 async fn build(&self, _query: &str) -> crate::Result<String> {
1894 tokio::time::sleep(self.delay).await;
1895 Ok(self.context.clone())
1896 }
1897 }
1898
1899 #[tokio::test]
1900 async fn context_providers_run_in_parallel() {
1901 let llm = CapturingLlm::new(vec![LlmResponse::Message("ok".to_string())]);
1902 let delay = std::time::Duration::from_millis(100);
1903
1904 let agent = AgentLoop::builder()
1905 .context(DelayContextProvider::new("ctx-a", delay))
1906 .context(DelayContextProvider::new("ctx-b", delay))
1907 .context(DelayContextProvider::new("ctx-c", delay))
1908 .build();
1909
1910 let start = std::time::Instant::now();
1911 let result = agent
1912 .run(&llm, vec![Message::user("hi")], |_| {})
1913 .await
1914 .unwrap();
1915 let elapsed = start.elapsed();
1916
1917 assert!(
1920 elapsed < std::time::Duration::from_millis(250),
1921 "Expected parallel execution (<250ms), but took {elapsed:?}",
1922 );
1923
1924 assert_eq!(result.answer, "ok");
1925
1926 let calls = llm.captured_messages();
1928 let msgs = &calls[0];
1929 assert_eq!(msgs.len(), 4); assert_eq!(msgs[0].content, "ctx-a");
1931 assert_eq!(msgs[1].content, "ctx-b");
1932 assert_eq!(msgs[2].content, "ctx-c");
1933 assert_eq!(msgs[3].content, "hi");
1934 }
1935
1936 #[tokio::test]
1941 async fn system_prompt_injects_system_message() {
1942 let llm = CapturingLlm::new(vec![LlmResponse::Message("ok".to_string())]);
1943
1944 let agent = AgentLoop::builder()
1945 .system_prompt("You are a helpful assistant.")
1946 .build();
1947
1948 let result = agent
1949 .run(&llm, vec![Message::user("hello")], |_| {})
1950 .await
1951 .unwrap();
1952
1953 assert_eq!(result.answer, "ok");
1954
1955 let calls = llm.captured_messages();
1956 assert_eq!(calls.len(), 1);
1957 let msgs = &calls[0];
1958 assert_eq!(msgs.len(), 2);
1959 assert_eq!(msgs[0].role, crate::message::Role::System);
1960 assert_eq!(msgs[0].content, "You are a helpful assistant.");
1961 assert_eq!(msgs[1].role, crate::message::Role::User);
1962 assert_eq!(msgs[1].content, "hello");
1963 }
1964
1965 struct StreamingMockLlm {
1972 responses: Mutex<Vec<Vec<crate::llm::StreamChunk>>>,
1974 }
1975
1976 impl StreamingMockLlm {
1977 fn new(responses: Vec<Vec<crate::llm::StreamChunk>>) -> Self {
1978 Self {
1979 responses: Mutex::new(responses),
1980 }
1981 }
1982 }
1983
1984 #[async_trait]
1985 impl LlmClient for StreamingMockLlm {
1986 async fn chat(
1987 &self,
1988 _messages: &[Message],
1989 _tools: &[ToolDef],
1990 ) -> crate::Result<ChatOutput> {
1991 panic!("StreamingMockLlm: chat() should not be called");
1992 }
1993
1994 fn chat_stream<'a>(
1995 &'a self,
1996 _messages: &'a [Message],
1997 _tools: &'a [ToolDef],
1998 ) -> std::pin::Pin<
1999 Box<dyn futures::Stream<Item = crate::Result<crate::llm::StreamChunk>> + Send + 'a>,
2000 > {
2001 let chunks = {
2002 let mut responses = self.responses.lock().unwrap();
2003 if responses.is_empty() {
2004 panic!("StreamingMockLlm: no more responses");
2005 }
2006 responses.remove(0)
2007 };
2008 Box::pin(futures::stream::iter(chunks.into_iter().map(Ok)))
2009 }
2010 }
2011
2012 #[tokio::test]
2013 async fn run_streaming_emits_text_chunks_and_done() {
2014 let llm = StreamingMockLlm::new(vec![vec![
2015 crate::llm::StreamChunk::TextDelta("Hel".into()),
2016 crate::llm::StreamChunk::TextDelta("lo!".into()),
2017 crate::llm::StreamChunk::Done(LlmResponse::Message("Hello!".into())),
2018 ]]);
2019
2020 let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
2021 let events_clone = events.clone();
2022
2023 let agent = AgentLoop::builder().build();
2024 let result = agent
2025 .run_streaming(&llm, vec![Message::user("Hi")], move |event| {
2026 let label = match &event {
2027 AgentEvent::TextChunk(t) => format!("chunk:{t}"),
2028 AgentEvent::TextDone(t) => format!("done:{t}"),
2029 AgentEvent::IterationStarted(n) => format!("iter:{n}"),
2030 AgentEvent::ToolStarted { name } => format!("started:{name}"),
2031 AgentEvent::ToolCompleted { name, .. } => format!("completed:{name}"),
2032 AgentEvent::Interrupted => "interrupted".to_string(),
2033 AgentEvent::AskUser { question, .. } => format!("ask_user:{question}"),
2034 AgentEvent::AskUserTimeout { call_id, .. } => {
2035 format!("ask_user_timeout:{call_id}")
2036 }
2037 _ => format!("{event:?}"),
2038 };
2039 events_clone.lock().unwrap().push(label);
2040 })
2041 .await
2042 .unwrap();
2043
2044 assert_eq!(result.answer, "Hello!");
2045 assert_eq!(result.iterations, 1);
2046
2047 let events = events.lock().unwrap();
2048 assert_eq!(
2049 *events,
2050 vec!["iter:1", "chunk:Hel", "chunk:lo!", "done:Hello!"]
2051 );
2052 }
2053
2054 #[tokio::test]
2055 async fn run_streaming_with_tool_calls() {
2056 let llm = StreamingMockLlm::new(vec![
2057 vec![crate::llm::StreamChunk::Done(
2059 LlmResponse::single_tool_call(
2060 "call_1".into(),
2061 "search".into(),
2062 serde_json::json!({"input": "rust"}),
2063 ),
2064 )],
2065 vec![
2067 crate::llm::StreamChunk::TextDelta("Found ".into()),
2068 crate::llm::StreamChunk::TextDelta("it.".into()),
2069 crate::llm::StreamChunk::Done(LlmResponse::Message("Found it.".into())),
2070 ],
2071 ]);
2072
2073 let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "Rust is great"));
2074 let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
2075 let events_clone = events.clone();
2076
2077 let agent = AgentLoop::builder().tool(tool).build();
2078 let result = agent
2079 .run_streaming(&llm, vec![Message::user("search rust")], move |event| {
2080 let label = match &event {
2081 AgentEvent::TextChunk(t) => format!("chunk:{t}"),
2082 AgentEvent::TextDone(t) => format!("done:{t}"),
2083 AgentEvent::IterationStarted(n) => format!("iter:{n}"),
2084 AgentEvent::ToolStarted { name } => format!("started:{name}"),
2085 AgentEvent::ToolCompleted { name, .. } => format!("completed:{name}"),
2086 AgentEvent::Interrupted => "interrupted".to_string(),
2087 AgentEvent::AskUser { question, .. } => format!("ask_user:{question}"),
2088 AgentEvent::AskUserTimeout { call_id, .. } => {
2089 format!("ask_user_timeout:{call_id}")
2090 }
2091 _ => format!("{event:?}"),
2092 };
2093 events_clone.lock().unwrap().push(label);
2094 })
2095 .await
2096 .unwrap();
2097
2098 assert_eq!(result.answer, "Found it.");
2099 assert_eq!(result.tool_calls.len(), 1);
2100 assert_eq!(result.iterations, 2);
2101
2102 let events = events.lock().unwrap();
2103 assert_eq!(
2104 *events,
2105 vec![
2106 "iter:1",
2107 "started:search",
2108 "completed:search",
2109 "iter:2",
2110 "chunk:Found ",
2111 "chunk:it.",
2112 "done:Found it.",
2113 ]
2114 );
2115 }
2116
2117 #[tokio::test]
2118 async fn run_streaming_fallback_non_streaming_llm() {
2119 let llm = MockLlm::new(vec![LlmResponse::Message("Hello!".into())]);
2122
2123 let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
2124 let events_clone = events.clone();
2125
2126 let agent = AgentLoop::builder().build();
2127 let result = agent
2128 .run_streaming(&llm, vec![Message::user("Hi")], move |event| {
2129 let label = match &event {
2130 AgentEvent::TextChunk(t) => format!("chunk:{t}"),
2131 AgentEvent::TextDone(t) => format!("done:{t}"),
2132 AgentEvent::IterationStarted(n) => format!("iter:{n}"),
2133 _ => "other".into(),
2134 };
2135 events_clone.lock().unwrap().push(label);
2136 })
2137 .await
2138 .unwrap();
2139
2140 assert_eq!(result.answer, "Hello!");
2141
2142 let events = events.lock().unwrap();
2143 assert_eq!(*events, vec!["iter:1", "chunk:Hello!", "done:Hello!"]);
2146 }
2147
2148 #[tokio::test]
2153 async fn usage_is_nonzero_after_mocked_llm_call() {
2154 let usage = TokenUsage {
2155 input_tokens: 100,
2156 output_tokens: 50,
2157 };
2158 let llm = MockLlm::with_usage(vec![LlmResponse::Message("Hello!".to_string())], usage);
2159
2160 let agent = AgentLoop::builder().build();
2161 let result = agent
2162 .run(&llm, vec![Message::user("Hi")], |_| {})
2163 .await
2164 .unwrap();
2165
2166 assert_eq!(result.usage.input_tokens, 100);
2167 assert_eq!(result.usage.output_tokens, 50);
2168 }
2169
2170 #[tokio::test]
2171 async fn usage_accumulates_across_iterations() {
2172 let usage = TokenUsage {
2173 input_tokens: 10,
2174 output_tokens: 20,
2175 };
2176 let llm = MockLlm::with_usage(
2177 vec![
2178 LlmResponse::single_tool_call(
2179 "call_1".to_string(),
2180 "search".to_string(),
2181 serde_json::json!({}),
2182 ),
2183 LlmResponse::Message("Done.".to_string()),
2184 ],
2185 usage,
2186 );
2187
2188 let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result"));
2189 let agent = AgentLoop::builder().tool(tool).build();
2190 let result = agent
2191 .run(&llm, vec![Message::user("test")], |_| {})
2192 .await
2193 .unwrap();
2194
2195 assert_eq!(result.usage.input_tokens, 20);
2197 assert_eq!(result.usage.output_tokens, 40);
2198 assert_eq!(result.iterations, 2);
2199 }
2200
2201 #[tokio::test]
2202 async fn usage_zero_when_llm_reports_no_usage() {
2203 let llm = MockLlm::new(vec![LlmResponse::Message("ok".to_string())]);
2204
2205 let agent = AgentLoop::builder().build();
2206 let result = agent
2207 .run(&llm, vec![Message::user("hi")], |_| {})
2208 .await
2209 .unwrap();
2210
2211 assert_eq!(result.usage.input_tokens, 0);
2212 assert_eq!(result.usage.output_tokens, 0);
2213 }
2214
2215 #[tokio::test]
2216 async fn streaming_usage_accumulates() {
2217 let llm = StreamingMockLlm::new(vec![vec![
2218 crate::llm::StreamChunk::TextDelta("Hi".into()),
2219 crate::llm::StreamChunk::Done(LlmResponse::Message("Hi".into())),
2220 crate::llm::StreamChunk::Usage(TokenUsage {
2221 input_tokens: 50,
2222 output_tokens: 25,
2223 }),
2224 ]]);
2225
2226 let agent = AgentLoop::builder().build();
2227 let result = agent
2228 .run_streaming(&llm, vec![Message::user("Hi")], |_| {})
2229 .await
2230 .unwrap();
2231
2232 assert_eq!(result.usage.input_tokens, 50);
2233 assert_eq!(result.usage.output_tokens, 25);
2234 }
2235
2236 #[tokio::test]
2237 async fn run_streaming_max_iterations() {
2238 let responses: Vec<Vec<crate::llm::StreamChunk>> = (0..5)
2239 .map(|i| {
2240 vec![crate::llm::StreamChunk::Done(
2241 LlmResponse::single_tool_call(
2242 format!("call_{i}"),
2243 "search".into(),
2244 serde_json::json!({}),
2245 ),
2246 )]
2247 })
2248 .collect();
2249 let llm = StreamingMockLlm::new(responses);
2250
2251 let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result"));
2252 let agent = AgentLoop::builder().tool(tool).max_iterations(3).build();
2253 let err = agent
2254 .run_streaming(&llm, vec![Message::user("loop")], |_| {})
2255 .await
2256 .unwrap_err();
2257
2258 assert!(matches!(err, AgentError::MaxIterations(3)));
2259 }
2260
2261 #[tokio::test]
2266 async fn result_messages_contains_full_history() {
2267 let llm = MockLlm::new(vec![LlmResponse::Message("Hello!".to_string())]);
2268
2269 let agent = AgentLoop::builder().build();
2270 let result = agent
2271 .run(&llm, vec![Message::user("Hi")], |_| {})
2272 .await
2273 .unwrap();
2274
2275 assert_eq!(result.messages.len(), 2);
2277 assert_eq!(result.messages[0].role, crate::message::Role::User);
2278 assert_eq!(result.messages[0].content, "Hi");
2279 assert_eq!(result.messages[1].role, crate::message::Role::Assistant);
2280 assert_eq!(result.messages[1].content, "Hello!");
2281 }
2282
2283 #[tokio::test]
2284 async fn result_messages_includes_tool_call_pairs() {
2285 let llm = MockLlm::new(vec![
2286 LlmResponse::single_tool_call(
2287 "call_1".to_string(),
2288 "search".to_string(),
2289 serde_json::json!({"input": "rust"}),
2290 ),
2291 LlmResponse::Message("Found it.".to_string()),
2292 ]);
2293
2294 let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result text"));
2295
2296 let agent = AgentLoop::builder().tool(tool).build();
2297 let result = agent
2298 .run(&llm, vec![Message::user("Search")], |_| {})
2299 .await
2300 .unwrap();
2301
2302 assert_eq!(result.messages.len(), 4);
2304 assert_eq!(result.messages[0].role, crate::message::Role::User);
2305 assert_eq!(result.messages[1].role, crate::message::Role::Assistant);
2306 assert_eq!(result.messages[1].tool_calls.len(), 1);
2307 assert_eq!(result.messages[1].tool_calls[0].name, "search");
2308 assert_eq!(result.messages[2].role, crate::message::Role::Tool);
2309 assert_eq!(result.messages[3].role, crate::message::Role::Assistant);
2310 assert_eq!(result.messages[3].content, "Found it.");
2311 }
2312
2313 #[tokio::test]
2314 async fn multi_turn_continuation_via_messages() {
2315 let llm1 = MockLlm::new(vec![LlmResponse::Message("I'm fine!".to_string())]);
2317 let agent = AgentLoop::builder().build();
2318 let result1 = agent
2319 .run(&llm1, vec![Message::user("How are you?")], |_| {})
2320 .await
2321 .unwrap();
2322
2323 let llm2 = MockLlm::new(vec![LlmResponse::Message("Goodbye!".to_string())]);
2325 let mut next_messages = result1.messages;
2326 next_messages.push(Message::user("Bye!"));
2327
2328 let result2 = agent.run(&llm2, next_messages, |_| {}).await.unwrap();
2329
2330 assert_eq!(result2.answer, "Goodbye!");
2331 assert_eq!(result2.messages.len(), 4);
2333 assert_eq!(result2.messages[0].content, "How are you?");
2334 assert_eq!(result2.messages[1].content, "I'm fine!");
2335 assert_eq!(result2.messages[2].content, "Bye!");
2336 assert_eq!(result2.messages[3].content, "Goodbye!");
2337 }
2338
2339 #[tokio::test]
2340 async fn streaming_result_messages_contains_full_history() {
2341 let llm = StreamingMockLlm::new(vec![vec![
2342 crate::llm::StreamChunk::TextDelta("Hi".to_string()),
2343 crate::llm::StreamChunk::TextDelta(" there".to_string()),
2344 crate::llm::StreamChunk::Done(LlmResponse::Message("Hi there".to_string())),
2345 ]]);
2346
2347 let agent = AgentLoop::builder().build();
2348 let result = agent
2349 .run_streaming(&llm, vec![Message::user("Hello")], |_| {})
2350 .await
2351 .unwrap();
2352
2353 assert_eq!(result.messages.len(), 2);
2354 assert_eq!(result.messages[0].role, crate::message::Role::User);
2355 assert_eq!(result.messages[0].content, "Hello");
2356 assert_eq!(result.messages[1].role, crate::message::Role::Assistant);
2357 assert_eq!(result.messages[1].content, "Hi there");
2358 }
2359
2360 #[test]
2361 fn tool_execution_policy_detects_ask_user_presence() {
2362 let parallel_items = vec![ToolCallItem {
2363 id: "call_1".to_string(),
2364 name: "search".to_string(),
2365 args: serde_json::json!({}),
2366 }];
2367 let interactive_items = vec![
2368 ToolCallItem {
2369 id: "call_1".to_string(),
2370 name: "search".to_string(),
2371 args: serde_json::json!({}),
2372 },
2373 ToolCallItem {
2374 id: "call_2".to_string(),
2375 name: "ask_user".to_string(),
2376 args: serde_json::json!({"question": "continue?"}),
2377 },
2378 ];
2379
2380 assert!(matches!(
2381 ToolExecutionPolicy::from_items(¶llel_items),
2382 ToolExecutionPolicy::ParallelOnly
2383 ));
2384 assert!(matches!(
2385 ToolExecutionPolicy::from_items(&interactive_items),
2386 ToolExecutionPolicy::InteractiveAskUser
2387 ));
2388 }
2389
2390 #[tokio::test]
2398 async fn tool_result_ordering_matches_call_ordering() {
2399 let llm = MockLlm::new(vec![
2400 LlmResponse::ToolCalls(vec![
2401 ToolCallItem {
2402 id: "c1".to_string(),
2403 name: "alpha".to_string(),
2404 args: serde_json::json!({}),
2405 },
2406 ToolCallItem {
2407 id: "c2".to_string(),
2408 name: "beta".to_string(),
2409 args: serde_json::json!({}),
2410 },
2411 ToolCallItem {
2412 id: "c3".to_string(),
2413 name: "gamma".to_string(),
2414 args: serde_json::json!({}),
2415 },
2416 ]),
2417 LlmResponse::Message("done".to_string()),
2418 ]);
2419
2420 let alpha: Arc<dyn Tool> = Arc::new(MockTool::new("alpha", "res_alpha"));
2421 let beta: Arc<dyn Tool> = Arc::new(MockTool::new("beta", "res_beta"));
2422 let gamma: Arc<dyn Tool> = Arc::new(MockTool::new("gamma", "res_gamma"));
2423
2424 let agent = AgentLoop::builder()
2425 .tool(alpha)
2426 .tool(beta)
2427 .tool(gamma)
2428 .build();
2429
2430 let result = agent
2431 .run(&llm, vec![Message::user("go")], |_| {})
2432 .await
2433 .unwrap();
2434
2435 assert_eq!(result.tool_calls[0].0, "alpha");
2437 assert_eq!(result.tool_calls[1].0, "beta");
2438 assert_eq!(result.tool_calls[2].0, "gamma");
2439
2440 assert_eq!(result.messages.len(), 6);
2442 assert_eq!(result.messages[1].tool_calls.len(), 3);
2444 assert_eq!(result.messages[1].tool_calls[0].id, "c1");
2445 assert_eq!(result.messages[1].tool_calls[1].id, "c2");
2446 assert_eq!(result.messages[1].tool_calls[2].id, "c3");
2447 assert_eq!(result.messages[2].tool_call_id.as_deref(), Some("c1"));
2449 assert_eq!(result.messages[2].content, "res_alpha");
2450 assert_eq!(result.messages[3].tool_call_id.as_deref(), Some("c2"));
2451 assert_eq!(result.messages[3].content, "res_beta");
2452 assert_eq!(result.messages[4].tool_call_id.as_deref(), Some("c3"));
2453 assert_eq!(result.messages[4].content, "res_gamma");
2454 }
2455
2456 #[tokio::test]
2460 async fn stage_boundary_event_ordering() {
2461 let llm = MockLlm::new(vec![
2462 LlmResponse::ToolCalls(vec![
2463 ToolCallItem {
2464 id: "c1".to_string(),
2465 name: "x".to_string(),
2466 args: serde_json::json!({}),
2467 },
2468 ToolCallItem {
2469 id: "c2".to_string(),
2470 name: "y".to_string(),
2471 args: serde_json::json!({}),
2472 },
2473 ]),
2474 LlmResponse::Message("final".to_string()),
2475 ]);
2476
2477 let x: Arc<dyn Tool> = Arc::new(MockTool::new("x", "rx"));
2478 let y: Arc<dyn Tool> = Arc::new(MockTool::new("y", "ry"));
2479
2480 let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
2481 let events_clone = events.clone();
2482
2483 let agent = AgentLoop::builder().tool(x).tool(y).build();
2484 let _ = agent
2485 .run(&llm, vec![Message::user("go")], move |event| {
2486 let label = match &event {
2487 AgentEvent::IterationStarted(n) => format!("iter:{n}"),
2488 AgentEvent::ToolStarted { name } => format!("started:{name}"),
2489 AgentEvent::ToolCompleted { name, .. } => format!("completed:{name}"),
2490 AgentEvent::TextChunk(t) => format!("text:{t}"),
2491 AgentEvent::TextDone(_) => "done".to_string(),
2492 _ => "other".to_string(),
2493 };
2494 events_clone.lock().unwrap().push(label);
2495 })
2496 .await
2497 .unwrap();
2498
2499 let events = events.lock().unwrap();
2500 assert_eq!(events[0], "iter:1");
2507 assert_eq!(events[1], "started:x");
2508 assert_eq!(events[2], "started:y");
2509 assert_eq!(events[3], "completed:x");
2510 assert_eq!(events[4], "completed:y");
2511 assert_eq!(events[5], "iter:2");
2512 assert_eq!(events[6], "text:final");
2513 }
2514
2515 #[tokio::test]
2518 async fn turn_state_accumulation_across_iterations() {
2519 let usage = TokenUsage {
2520 input_tokens: 7,
2521 output_tokens: 3,
2522 };
2523 let llm = MockLlm::with_usage(
2524 vec![
2525 LlmResponse::single_tool_call(
2526 "c1".to_string(),
2527 "t1".to_string(),
2528 serde_json::json!({}),
2529 ),
2530 LlmResponse::single_tool_call(
2531 "c2".to_string(),
2532 "t2".to_string(),
2533 serde_json::json!({}),
2534 ),
2535 LlmResponse::Message("end".to_string()),
2536 ],
2537 usage,
2538 );
2539
2540 let t1: Arc<dyn Tool> = Arc::new(MockTool::new("t1", "r1"));
2541 let t2: Arc<dyn Tool> = Arc::new(MockTool::new("t2", "r2"));
2542
2543 let agent = AgentLoop::builder().tool(t1).tool(t2).build();
2544 let result = agent
2545 .run(&llm, vec![Message::user("go")], |_| {})
2546 .await
2547 .unwrap();
2548
2549 assert_eq!(result.iterations, 3);
2550 assert_eq!(result.tool_calls.len(), 2);
2551 assert_eq!(result.usage.input_tokens, 21);
2553 assert_eq!(result.usage.output_tokens, 9);
2554 assert_eq!(result.messages.len(), 6);
2557 }
2558
2559 #[tokio::test]
2562 async fn streaming_and_non_streaming_produce_same_messages() {
2563 let llm_sync = MockLlm::new(vec![
2565 LlmResponse::single_tool_call(
2566 "c1".to_string(),
2567 "search".to_string(),
2568 serde_json::json!({"q": "rust"}),
2569 ),
2570 LlmResponse::Message("Found it.".to_string()),
2571 ]);
2572 let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result text"));
2573 let agent = AgentLoop::builder().tool(tool).build();
2574 let result_sync = agent
2575 .run(&llm_sync, vec![Message::user("Search")], |_| {})
2576 .await
2577 .unwrap();
2578
2579 let llm_stream = MockLlm::new(vec![
2581 LlmResponse::single_tool_call(
2582 "c1".to_string(),
2583 "search".to_string(),
2584 serde_json::json!({"q": "rust"}),
2585 ),
2586 LlmResponse::Message("Found it.".to_string()),
2587 ]);
2588 let tool2: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result text"));
2589 let agent2 = AgentLoop::builder().tool(tool2).build();
2590 let result_stream = agent2
2591 .run_streaming(&llm_stream, vec![Message::user("Search")], |_| {})
2592 .await
2593 .unwrap();
2594
2595 assert_eq!(result_sync.messages.len(), result_stream.messages.len());
2597 for (a, b) in result_sync
2598 .messages
2599 .iter()
2600 .zip(result_stream.messages.iter())
2601 {
2602 assert_eq!(a.role, b.role);
2603 assert_eq!(a.content, b.content);
2604 assert_eq!(a.tool_call_id, b.tool_call_id);
2605 assert_eq!(a.tool_calls.len(), b.tool_calls.len());
2606 }
2607 assert_eq!(result_sync.iterations, result_stream.iterations);
2608 assert_eq!(result_sync.tool_calls.len(), result_stream.tool_calls.len());
2609 }
2610
    /// Tests for the optional cancellation support (`run_with_cancel` /
    /// `run_streaming_with_cancel`), gated on the `cancellation` feature.
    #[cfg(feature = "cancellation")]
    mod cancellation_tests {
        use super::*;
        use tokio_util::sync::CancellationToken;

        /// A token cancelled before the run starts must fail with
        /// `AgentError::Cancelled` instead of calling the LLM script.
        #[tokio::test]
        async fn cancel_before_run_returns_cancelled() {
            let llm = MockLlm::new(vec![LlmResponse::Message("Hello!".into())]);
            let agent = AgentLoop::builder().build();

            let token = CancellationToken::new();
            token.cancel();

            let err = agent
                .run_with_cancel(&llm, vec![Message::user("Hi")], token, |_| {})
                .await
                .unwrap_err();

            assert!(matches!(err, AgentError::Cancelled));
        }

        /// Cancelling a parent token from inside the event callback must
        /// stop a run that was handed a child token, within roughly one
        /// iteration of the cancellation point.
        #[tokio::test]
        async fn cancel_mid_run_returns_cancelled() {
            // Five tool-call responses keep the loop spinning long enough
            // for the mid-run cancellation to land.
            let responses: Vec<LlmResponse> = (0..5)
                .map(|i| {
                    LlmResponse::single_tool_call(
                        format!("call_{i}"),
                        "search".to_string(),
                        serde_json::json!({}),
                    )
                })
                .collect();
            let llm = MockLlm::new(responses);
            let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result"));

            let agent = AgentLoop::builder().tool(tool).max_iterations(5).build();

            // The run receives a child token; cancelling the parent must
            // propagate to it.
            let token = CancellationToken::new();
            let child = token.child_token();
            let iterations = Arc::new(Mutex::new(0usize));
            let iterations_clone = iterations.clone();
            let token_clone = token.clone();

            let err = agent
                .run_with_cancel(&llm, vec![Message::user("loop")], child, move |event| {
                    if let AgentEvent::IterationStarted(_) = &event {
                        let mut count = iterations_clone.lock().unwrap();
                        *count += 1;
                        // Trigger cancellation once the second iteration starts.
                        if *count >= 2 {
                            token_clone.cancel();
                        }
                    }
                })
                .await
                .unwrap_err();

            assert!(matches!(err, AgentError::Cancelled));
            // Cancellation is checked between stages, so one extra
            // iteration may start before the loop observes it.
            let count = *iterations.lock().unwrap();
            assert!(
                count >= 2 && count <= 3,
                "unexpected iteration count: {count}"
            );
        }

        /// The streaming variant must honor a pre-cancelled token too.
        #[tokio::test]
        async fn cancel_streaming_returns_cancelled() {
            let llm = MockLlm::new(vec![LlmResponse::Message("Hello!".into())]);
            let agent = AgentLoop::builder().build();

            let token = CancellationToken::new();
            token.cancel();

            let err = agent
                .run_streaming_with_cancel(&llm, vec![Message::user("Hi")], token, |_| {})
                .await
                .unwrap_err();

            assert!(matches!(err, AgentError::Cancelled));
        }

        /// A live (never-cancelled) token must not interfere with a
        /// normal single-iteration run.
        #[tokio::test]
        async fn uncancelled_token_runs_normally() {
            let llm = MockLlm::new(vec![LlmResponse::Message("Hello!".into())]);
            let agent = AgentLoop::builder().build();

            let token = CancellationToken::new();

            let result = agent
                .run_with_cancel(&llm, vec![Message::user("Hi")], token, |_| {})
                .await
                .unwrap();

            assert_eq!(result.answer, "Hello!");
            assert_eq!(result.iterations, 1);
        }
    }
2711}
2712
/// Integration tests for MCP server support in the agent loop, gated on
/// the `mcp-client` feature.
#[cfg(all(test, feature = "mcp-client"))]
mod mcp_integration_tests {
    use super::*;
    use crate::mcp::McpServer;
    use async_trait::async_trait;
    use motosan_agent_tool::ToolDef;
    use serde_json::json;

    /// Minimal `McpServer` stub exposing a single `echo` tool that mirrors
    /// the `msg` argument back as `"echo: <msg>"`.
    struct EchoMcpServer {
        name: String,
    }

    #[async_trait]
    impl McpServer for EchoMcpServer {
        fn name(&self) -> &str {
            &self.name
        }
        // Connect/disconnect are no-ops for this stub.
        async fn connect(&self) -> crate::Result<()> {
            Ok(())
        }
        async fn list_tools(&self) -> crate::Result<Vec<ToolDef>> {
            Ok(vec![ToolDef {
                name: "echo".to_string(),
                description: "Echo input".to_string(),
                input_schema: json!({"type": "object", "properties": {"msg": {"type": "string"}}}),
            }])
        }
        async fn call_tool(&self, _name: &str, args: serde_json::Value) -> crate::Result<String> {
            // A missing or non-string `msg` falls back to the empty string.
            Ok(format!(
                "echo: {}",
                args.get("msg").and_then(|v| v.as_str()).unwrap_or("")
            ))
        }
        async fn disconnect(&self) -> crate::Result<()> {
            Ok(())
        }
    }

    /// The builder must register an owned MCP server.
    #[test]
    fn builder_accepts_mcp_server() {
        let agent = AgentLoop::builder()
            .mcp_server(EchoMcpServer {
                name: "test_mcp".to_string(),
            })
            .max_iterations(5)
            .build();
        assert_eq!(agent.mcp_servers.len(), 1);
    }

    /// `mcp_server_arc` must share one server instance between agents
    /// rather than cloning it.
    #[test]
    fn builder_accepts_shared_arc_mcp_server() {
        let shared: Arc<dyn McpServer> = Arc::new(EchoMcpServer {
            name: "shared_mcp".to_string(),
        });

        let agent_a = AgentLoop::builder()
            .mcp_server_arc(Arc::clone(&shared))
            .max_iterations(5)
            .build();

        let agent_b = AgentLoop::builder()
            .mcp_server_arc(Arc::clone(&shared))
            .max_iterations(5)
            .build();

        assert_eq!(agent_a.mcp_servers.len(), 1);
        assert_eq!(agent_b.mcp_servers.len(), 1);
        // Pointer equality proves both agents hold the same allocation.
        assert!(Arc::ptr_eq(
            &agent_a.mcp_servers[0],
            &agent_b.mcp_servers[0]
        ));
    }

    /// Scripted `LlmClient` that pops responses from a queue; panics if
    /// the agent makes more LLM calls than the script provides.
    struct McpTestLlm {
        responses: std::sync::Mutex<Vec<crate::llm::LlmResponse>>,
    }

    impl McpTestLlm {
        fn new(responses: Vec<crate::llm::LlmResponse>) -> Self {
            Self {
                responses: std::sync::Mutex::new(responses),
            }
        }
    }

    #[async_trait]
    impl crate::llm::LlmClient for McpTestLlm {
        async fn chat(
            &self,
            _messages: &[Message],
            _tools: &[ToolDef],
        ) -> crate::Result<crate::llm::ChatOutput> {
            let mut responses = self.responses.lock().unwrap();
            assert!(!responses.is_empty(), "McpTestLlm: no more responses");
            // Consume responses front-to-back; usage tracking is not
            // exercised by these tests.
            let response = responses.remove(0);
            Ok(crate::llm::ChatOutput {
                response,
                usage: None,
            })
        }
    }

    /// End-to-end streaming run where the only tool comes from an MCP
    /// server; the tool is addressed by its namespaced name
    /// (`<server>__<tool>`).
    #[tokio::test]
    async fn run_streaming_with_mcp_server() {
        use crate::llm::{LlmResponse, ToolCallItem};
        use std::sync::Mutex;

        let llm = McpTestLlm::new(vec![
            LlmResponse::ToolCalls(vec![ToolCallItem {
                id: "call_mcp".to_string(),
                name: "test_mcp__echo".to_string(),
                args: json!({"msg": "hello"}),
            }]),
            LlmResponse::Message("MCP says: echo hello".to_string()),
        ]);

        let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let events_clone = events.clone();

        let agent = AgentLoop::builder()
            .mcp_server(EchoMcpServer {
                name: "test_mcp".to_string(),
            })
            .max_iterations(5)
            .build();

        let result = agent
            .run_streaming(&llm, vec![Message::user("call echo")], move |event| {
                // Exhaustive match (no catch-all) so adding an AgentEvent
                // variant forces this test to be revisited.
                let label = match &event {
                    AgentEvent::ToolStarted { name } => format!("started:{name}"),
                    AgentEvent::ToolCompleted { name, result } => {
                        format!("completed:{name}:{}", result.as_text().unwrap_or(""))
                    }
                    AgentEvent::TextChunk(t) => format!("chunk:{t}"),
                    AgentEvent::TextDone(t) => format!("done:{t}"),
                    AgentEvent::IterationStarted(n) => format!("iter:{n}"),
                    AgentEvent::Interrupted => "interrupted".to_string(),
                    AgentEvent::AskUser { question, .. } => format!("ask_user:{question}"),
                    AgentEvent::AskUserTimeout { call_id, .. } => {
                        format!("ask_user_timeout:{call_id}")
                    }
                };
                events_clone.lock().unwrap().push(label);
            })
            .await
            .unwrap();

        assert_eq!(result.answer, "MCP says: echo hello");
        assert_eq!(result.tool_calls.len(), 1);
        assert_eq!(result.tool_calls[0].0, "test_mcp__echo");

        // The MCP-backed tool must have completed with the echoed payload.
        let events = events.lock().unwrap();
        assert!(events
            .iter()
            .any(|e| e.starts_with("completed:test_mcp__echo:echo: hello")));
    }

    /// `McpServer` stub that records connect/disconnect activity and can
    /// be configured to fail on `connect`.
    struct TrackingMcpServer {
        name: String,
        fail_connect: bool,
        connected: Arc<std::sync::atomic::AtomicBool>,
        disconnect_count: Arc<std::sync::atomic::AtomicUsize>,
    }

    #[async_trait]
    impl McpServer for TrackingMcpServer {
        fn name(&self) -> &str {
            &self.name
        }
        async fn connect(&self) -> crate::Result<()> {
            if self.fail_connect {
                Err(crate::AgentError::Mcp("connect failed".into()))
            } else {
                self.connected
                    .store(true, std::sync::atomic::Ordering::SeqCst);
                Ok(())
            }
        }
        async fn list_tools(&self) -> crate::Result<Vec<ToolDef>> {
            Ok(vec![])
        }
        async fn call_tool(&self, _name: &str, _args: serde_json::Value) -> crate::Result<String> {
            Ok(String::new())
        }
        async fn disconnect(&self) -> crate::Result<()> {
            self.disconnect_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Ok(())
        }
    }

    /// When a later server's `connect` fails, servers that connected
    /// earlier must be disconnected (no leaked connections), and the
    /// failing server must not receive a `disconnect` call.
    #[tokio::test]
    async fn partial_connect_failure_disconnects_already_connected() {
        use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

        let server1_connected = Arc::new(AtomicBool::new(false));
        let server1_disconnects = Arc::new(AtomicUsize::new(0));
        let server2_disconnects = Arc::new(AtomicUsize::new(0));

        // server1 connects successfully; server2 always fails to connect.
        let server1: Arc<dyn McpServer> = Arc::new(TrackingMcpServer {
            name: "ok_server".to_string(),
            fail_connect: false,
            connected: server1_connected.clone(),
            disconnect_count: server1_disconnects.clone(),
        });
        let server2: Arc<dyn McpServer> = Arc::new(TrackingMcpServer {
            name: "fail_server".to_string(),
            fail_connect: true,
            connected: Arc::new(AtomicBool::new(false)),
            disconnect_count: server2_disconnects.clone(),
        });

        let agent = AgentLoop::builder()
            .mcp_server_arc(server1)
            .mcp_server_arc(server2)
            .max_iterations(1)
            .build();

        // No LLM responses needed: the run must fail during MCP setup.
        let llm = McpTestLlm::new(vec![]);
        let err = agent
            .run(&llm, vec![Message::user("hi")], |_| {})
            .await
            .unwrap_err();

        assert!(
            matches!(err, crate::AgentError::Mcp(_)),
            "expected connect error"
        );
        // server1 was connected and then cleaned up exactly once;
        // server2 never connected, so it must not be disconnected.
        assert!(server1_connected.load(Ordering::SeqCst));
        assert_eq!(server1_disconnects.load(Ordering::SeqCst), 1);
        assert_eq!(server2_disconnects.load(Ordering::SeqCst), 0);
    }
}