1#![allow(dead_code)]
2use crate::{
3 default_context::DefaultContext,
4 errors::AgentError,
5 hooks::{
6 AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7 Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn,
8 },
9 invoke_hooks,
10 state::{self, StopReason},
11 system_prompt::SystemPrompt,
12 tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15 collections::{HashMap, HashSet},
16 hash::{DefaultHasher, Hash as _, Hasher as _},
17 sync::Arc,
18};
19
20use derive_builder::Builder;
21use swiftide_core::{
22 chat_completion::{
23 ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
24 },
25 prompt::Prompt,
26 AgentContext, ToolBox,
27};
28use tracing::{debug, Instrument};
29
/// An agent that drives an LLM through a completion loop, dispatching tool calls and
/// lifecycle hooks along the way.
///
/// Construct one with [`Agent::builder`]. Every field has a builder default except `llm`,
/// which must be set.
#[derive(Clone, Builder)]
pub struct Agent {
    /// Lifecycle hooks; selected by `HookTypes` and run through `invoke_hooks!`.
    #[builder(default, setter(into))]
    pub(crate) hooks: Vec<Hook>,
    /// Shared context that stores the message history and is handed to tools when they run.
    /// Defaults to a [`DefaultContext`]; the custom builder setter accepts any `AgentContext`.
    #[builder(
        setter(custom),
        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
    )]
    pub(crate) context: Arc<dyn AgentContext>,
    /// Tools the LLM may call. Defaults to [`Agent::default_tools`]; the custom builder setter
    /// always adds those defaults next to the tools it is given.
    #[builder(default = Agent::default_tools(), setter(custom))]
    pub(crate) tools: HashSet<Box<dyn Tool>>,

    /// Toolboxes whose tools are loaded into `tools` when a pending agent first runs.
    #[builder(default)]
    pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,

    /// The LLM used for every completion. Required: the builder provides no default.
    #[builder(setter(custom))]
    pub(crate) llm: Box<dyn ChatCompletion>,

    /// System prompt added as the first message when a pending agent first runs.
    /// Defaults to [`SystemPrompt::default`]; `None` (see `no_system_prompt`) adds nothing.
    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
    pub(crate) system_prompt: Option<Prompt>,

    /// Lifecycle state (pending, running or stopped); its builder setter is private.
    #[builder(private, default = state::State::default())]
    pub(crate) state: state::State,

    /// Maximum number of completion-loop iterations per run; `None` means unlimited.
    #[builder(default, setter(strip_option))]
    pub(crate) limit: Option<usize>,

    /// Retry limit for failing tool calls, checked by `tool_calls_over_limit`; once it is
    /// exceeded the agent stops. Defaults to 3.
    #[builder(default = 3)]
    pub(crate) tool_retry_limit: usize,

    /// Failure counters per tool call, keyed by a hash of the `ToolCall`.
    #[builder(private, default)]
    pub(crate) tool_retries_counter: HashMap<u64, usize>,

    /// Tools fetched from `toolboxes`; merged into `tools` once loaded.
    #[builder(private, default)]
    pub(crate) toolbox_tools: HashSet<Box<dyn Tool>>,
}
120
121impl std::fmt::Debug for Agent {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 f.debug_struct("Agent")
124 .field(
126 "hooks",
127 &self
128 .hooks
129 .iter()
130 .map(std::string::ToString::to_string)
131 .collect::<Vec<_>>(),
132 )
133 .field(
134 "tools",
135 &self
136 .tools
137 .iter()
138 .map(swiftide_core::Tool::name)
139 .collect::<Vec<_>>(),
140 )
141 .field("llm", &"Box<dyn ChatCompletion>")
142 .field("state", &self.state)
143 .finish()
144 }
145}
146
147impl AgentBuilder {
148 pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
150 where
151 Self: Clone,
152 {
153 self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
154 self
155 }
156
157 pub fn no_system_prompt(&mut self) -> &mut Self {
159 self.system_prompt = Some(None);
160
161 self
162 }
163
164 pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
166 let hooks = self.hooks.get_or_insert_with(Vec::new);
167 hooks.push(hook);
168
169 self
170 }
171
172 pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
175 self.add_hook(Hook::BeforeAll(Box::new(hook)))
176 }
177
178 pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
182 self.add_hook(Hook::OnStart(Box::new(hook)))
183 }
184
185 pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
187 self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
188 }
189
190 pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
196 self.add_hook(Hook::AfterTool(Box::new(hook)))
197 }
198
199 pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
201 self.add_hook(Hook::BeforeTool(Box::new(hook)))
202 }
203
204 pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
206 self.add_hook(Hook::AfterCompletion(Box::new(hook)))
207 }
208
209 pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
212 self.add_hook(Hook::AfterEach(Box::new(hook)))
213 }
214
215 pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
218 self.add_hook(Hook::OnNewMessage(Box::new(hook)))
219 }
220
221 pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
222 self.add_hook(Hook::OnStop(Box::new(hook)))
223 }
224
225 pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
227 let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
228
229 self.llm = Some(boxed);
230 self
231 }
232
233 pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
238 where
239 TOOL: Into<Box<dyn Tool>>,
240 {
241 self.tools = Some(
242 tools
243 .into_iter()
244 .map(Into::into)
245 .chain(Agent::default_tools())
246 .collect(),
247 );
248 self
249 }
250
251 pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
257 let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
258 toolboxes.push(Box::new(toolbox));
259
260 self
261 }
262}
263
264impl Agent {
265 pub fn builder() -> AgentBuilder {
267 AgentBuilder::default()
268 }
269}
270
271impl Agent {
272 fn default_tools() -> HashSet<Box<dyn Tool>> {
274 HashSet::from([Box::new(Stop::default()) as Box<dyn Tool>])
275 }
276
277 #[tracing::instrument(skip_all, name = "agent.query")]
280 pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
281 let query = query
282 .into()
283 .render()
284 .map_err(AgentError::FailedToRenderPrompt)?;
285 self.run_agent(Some(query), false).await
286 }
287
288 #[tracing::instrument(skip_all, name = "agent.query_once")]
290 pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
291 let query = query
292 .into()
293 .render()
294 .map_err(AgentError::FailedToRenderPrompt)?;
295 self.run_agent(Some(query), true).await
296 }
297
298 #[tracing::instrument(skip_all, name = "agent.run")]
301 pub async fn run(&mut self) -> Result<(), AgentError> {
302 self.run_agent(None, false).await
303 }
304
305 #[tracing::instrument(skip_all, name = "agent.run_once")]
308 pub async fn run_once(&mut self) -> Result<(), AgentError> {
309 self.run_agent(None, true).await
310 }
311
312 pub async fn history(&self) -> Vec<ChatMessage> {
314 self.context.history().await
315 }
316
317 async fn run_agent(
318 &mut self,
319 maybe_query: Option<String>,
320 just_once: bool,
321 ) -> Result<(), AgentError> {
322 if self.state.is_running() {
323 return Err(AgentError::AlreadyRunning);
324 }
325
326 if self.state.is_pending() {
327 if let Some(system_prompt) = &self.system_prompt {
328 self.context
329 .add_messages(vec![ChatMessage::System(
330 system_prompt
331 .render()
332 .map_err(AgentError::FailedToRenderSystemPrompt)?,
333 )])
334 .await;
335 }
336
337 invoke_hooks!(BeforeAll, self);
338
339 self.load_toolboxes().await?;
340 }
341
342 invoke_hooks!(OnStart, self);
343
344 self.state = state::State::Running;
345
346 if let Some(query) = maybe_query {
347 self.context.add_message(ChatMessage::User(query)).await;
348 }
349
350 let mut loop_counter = 0;
351
352 while let Some(messages) = self.context.next_completion().await {
353 if let Some(limit) = self.limit {
354 if loop_counter >= limit {
355 tracing::warn!("Agent loop limit reached");
356 break;
357 }
358 }
359 let result = self.run_completions(&messages).await;
360
361 if let Err(err) = result {
362 self.stop_with_error(&err).await;
363 tracing::error!(error = ?err, "Agent stopped with error {err}");
364 return Err(err);
365 }
366
367 if just_once || self.state.is_stopped() {
368 break;
369 }
370 loop_counter += 1;
371 }
372
373 self.stop(StopReason::NoNewMessages).await;
375
376 Ok(())
377 }
378
379 #[tracing::instrument(skip_all, err)]
380 async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<(), AgentError> {
381 debug!(
382 "Running completion for agent with {} messages",
383 messages.len()
384 );
385
386 let mut chat_completion_request = ChatCompletionRequest::builder()
387 .messages(messages)
388 .tools_spec(
389 self.tools
390 .iter()
391 .map(swiftide_core::Tool::tool_spec)
392 .collect::<HashSet<_>>(),
393 )
394 .build()
395 .map_err(AgentError::FailedToBuildRequest)?;
396
397 invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
398
399 debug!(
400 "Calling LLM with the following new messages:\n {}",
401 self.context
402 .current_new_messages()
403 .await
404 .iter()
405 .map(ToString::to_string)
406 .collect::<Vec<_>>()
407 .join(",\n")
408 );
409
410 let mut response = self
411 .llm
412 .complete(&chat_completion_request)
413 .await
414 .map_err(AgentError::CompletionsFailed)?;
415
416 invoke_hooks!(AfterCompletion, self, &mut response);
417
418 self.add_message(ChatMessage::Assistant(
419 response.message,
420 response.tool_calls.clone(),
421 ))
422 .await?;
423
424 if let Some(tool_calls) = response.tool_calls {
425 self.invoke_tools(tool_calls).await?;
426 }
427
428 invoke_hooks!(AfterEach, self);
429
430 Ok(())
431 }
432
433 async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<(), AgentError> {
434 debug!("LLM returned tool calls: {:?}", tool_calls);
435
436 let mut handles = vec![];
437 for tool_call in tool_calls {
438 let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
439 tracing::warn!("Tool {} not found", tool_call.name());
440 continue;
441 };
442 tracing::info!("Calling tool `{}`", tool_call.name());
443
444 let tool_args = tool_call.args().map(String::from);
445 let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
446
447 invoke_hooks!(BeforeTool, self, &tool_call);
448
449 let tool_span = tracing::info_span!(
450 "tool",
451 "otel.name" = format!("tool.{}", tool.name().as_ref())
452 );
453
454 let handle = tokio::spawn(async move {
455 let tool_args = ArgPreprocessor::preprocess(tool_args.as_deref());
456 let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
457
458 tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name = tool.name().as_ref(), "Completed tool call");
459
460 Ok(output)
461 }.instrument(tool_span.or_current()));
462
463 handles.push((handle, tool_call));
464 }
465
466 for (handle, tool_call) in handles {
467 let mut output = handle.await.map_err(AgentError::ToolFailedToJoin)?;
468
469 invoke_hooks!(AfterTool, self, &tool_call, &mut output);
470
471 if let Err(error) = output {
472 let stop = self.tool_calls_over_limit(&tool_call);
473 if stop {
474 tracing::error!(
475 ?error,
476 "Tool call failed, retry limit reached, stopping agent: {error}",
477 );
478 } else {
479 tracing::warn!(
480 ?error,
481 tool_call = ?tool_call,
482 "Tool call failed, retrying",
483 );
484 }
485 self.add_message(ChatMessage::ToolOutput(
486 tool_call.clone(),
487 ToolOutput::Fail(error.to_string()),
488 ))
489 .await?;
490 if stop {
491 self.stop(StopReason::ToolCallsOverLimit(tool_call)).await;
492 return Err(error.into());
493 }
494 continue;
495 }
496
497 let output = output?;
498 self.handle_control_tools(&tool_call, &output).await;
499 self.add_message(ChatMessage::ToolOutput(tool_call, output))
500 .await?;
501 }
502
503 Ok(())
504 }
505
506 fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
507 self.hooks
508 .iter()
509 .filter(|h| hook_type == (*h).into())
510 .collect()
511 }
512
513 fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
514 self.tools
515 .iter()
516 .find(|tool| tool.name() == tool_name)
517 .cloned()
518 }
519
520 async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
522 if let ToolOutput::Stop = output {
523 tracing::warn!("Stop tool called, stopping agent");
524 self.stop(StopReason::RequestedByTool(tool_call.clone()))
525 .await;
526 }
527 }
528
529 fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
530 let mut s = DefaultHasher::new();
531 tool_call.hash(&mut s);
532 let hash = s.finish();
533
534 if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
535 let val = *retries >= self.tool_retry_limit;
536 *retries += 1;
537 val
538 } else {
539 self.tool_retries_counter.insert(hash, 1);
540 false
541 }
542 }
543
544 #[tracing::instrument(skip_all, fields(message = message.to_string()))]
550 pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
551 invoke_hooks!(OnNewMessage, self, &mut message);
552
553 self.context.add_message(message).await;
554 Ok(())
555 }
556
557 pub async fn stop(&mut self, reason: impl Into<StopReason>) {
559 if self.state.is_stopped() {
560 return;
561 }
562 let reason = reason.into();
563 invoke_hooks!(OnStop, self, reason.clone(), None);
564
565 self.state = state::State::Stopped(reason);
566 }
567
568 pub async fn stop_with_error(&mut self, error: &AgentError) {
569 if self.state.is_stopped() {
570 return;
571 }
572 invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
573
574 self.state = state::State::Stopped(StopReason::Error);
575 }
576
577 pub fn context(&self) -> &dyn AgentContext {
579 &self.context
580 }
581
582 pub fn is_running(&self) -> bool {
584 self.state.is_running()
585 }
586
587 pub fn is_stopped(&self) -> bool {
589 self.state.is_stopped()
590 }
591
592 pub fn is_pending(&self) -> bool {
594 self.state.is_pending()
595 }
596
597 fn tools(&self) -> &HashSet<Box<dyn Tool>> {
599 &self.tools
600 }
601
602 async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
603 for toolbox in &self.toolboxes {
604 let tools = toolbox
605 .available_tools()
606 .await
607 .map_err(AgentError::ToolBoxFailedToLoad)?;
608 self.toolbox_tools.extend(tools);
609 }
610
611 self.tools.extend(self.toolbox_tools.clone());
612
613 Ok(())
614 }
615}
616
#[cfg(test)]
mod tests {

    use serde::ser::Error;
    use swiftide_core::chat_completion::errors::ToolError;
    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
    use swiftide_core::test_utils::MockChatCompletion;

    use super::*;
    use crate::{
        assistant, chat_request, chat_response, summary, system, tool_failed, tool_output, user,
    };

    use crate::test_utils::{MockHook, MockTool};

    // The `stop` tool is always present, and duplicate tools collapse.
    #[test_log::test(tokio::test)]
    async fn test_agent_builder_defaults() {
        let mock_llm = MockChatCompletion::new();

        let agent = Agent::builder().llm(&mock_llm).build().unwrap();

        assert!(agent.find_tool_by_name("stop").is_some());

        let agent = Agent::builder()
            .tools([Stop::default(), Stop::default()])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 1);

        let agent = Agent::builder()
            .tools([MockTool::new("mock_tool")])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 2);
        assert!(agent.find_tool_by_name("mock_tool").is_some());
        assert!(agent.find_tool_by_name("stop").is_some());

        assert!(agent.context().history().await.is_empty());
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_tool_calling_loop() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_tool_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_tool_via_toolbox_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .add_toolbox(vec![mock_tool.boxed()])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test(flavor = "multi_thread"))]
    async fn test_multiple_tool_calls() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool1");
        let mock_tool2 = MockTool::new("mock_tool2");

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";

            tool_calls = ["mock_tool1", "mock_tool2"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_tool2.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
            tool_output!("mock_tool1", "Great!"),
            tool_output!("mock_tool2", "Great!");

            tools = [mock_tool.clone(), mock_tool2.clone()]
        };

        let mock_tool_response = chat_response! {
            "Ok!";

            tool_calls = ["stop"]

        };

        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool, mock_tool2])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_state_machine() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = []
        };
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        assert!(agent.state.is_pending());
        agent.query_once(prompt).await.unwrap();

        assert!(agent.state.is_stopped());
    }

    // Summaries added to the context replace earlier history in later requests.
    #[test_log::test(tokio::test)]
    async fn test_summary() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []

        };

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = []
        };

        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        let mut agent = Agent::builder()
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary"),
            user!("Write another poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        agent.query_once("Write another poem").await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary 2"))
            .await;

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary 2"),
            user!("Write a third poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));

        agent.query_once("Write a third poem").await.unwrap();
    }

    #[test_log::test(tokio::test)]
    async fn test_agent_hooks() {
        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
        let mock_before_completion = MockHook::new("before_completion")
            .expect_calls(2)
            .to_owned();
        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
        let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();

        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();

        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .before_all(mock_before_all.hook_fn())
            .on_start(mock_on_start_fn.on_start_fn())
            .before_completion(mock_before_completion.before_completion_fn())
            .before_tool(mock_before_tool.before_tool_fn())
            .after_completion(mock_after_completion.after_completion_fn())
            .after_tool(mock_after_tool.after_tool_fn())
            .after_each(mock_after_each.hook_fn())
            .on_new_message(mock_on_message.message_hook_fn())
            .on_stop(mock_on_stop.stop_hook_fn())
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }

    // With `limit(1)` only one completion runs, so the second expectation stays unconsumed.
    #[test_log::test(tokio::test)]
    async fn test_agent_loop_limit() {
        let prompt = "Generate content";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mock_tool_response = chat_response! {
            "Some response";
            tool_calls = ["mock_tool"]
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));

        let stop_response = chat_response! {
            "Final response";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .limit(1)
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        let remaining = mock_llm.expectations.lock().unwrap().pop();
        assert!(remaining.is_some());

        assert!(agent.is_stopped());
    }

    // With `tool_retry_limit(1)`, a second identical failing call stops the agent with an error.
    #[test_log::test(tokio::test)]
    async fn test_tool_retry_mechanism() {
        let prompt = "Execute my tool";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("retry_tool");

        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );
        mock_tool.expect_invoke(
            Err(ToolError::WrongArguments(serde_json::Error::custom(
                "missing `query`",
            ))),
            None,
        );

        let chat_request = chat_request! {
            user!(prompt);
            tools = [mock_tool.clone()]
        };
        let retry_response = chat_response! {
            "First failing attempt";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));

        let chat_request = chat_request! {
            user!(prompt),
            assistant!("First failing attempt", ["retry_tool"]),
            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");

            tools = [mock_tool.clone()]
        };
        let will_fail_response = chat_response! {
            "Finished execution";
            tool_calls = ["retry_tool"]
        };
        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .tool_retry_limit(1)
            .build()
            .unwrap();

        let result = agent.query(prompt).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing `query`"));
        assert!(agent.is_stopped());
    }

    // A history round-tripped through serde can seed the context of a new agent.
    #[test_log::test(tokio::test)]
    async fn test_recovering_agent_existing_history() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool.clone()])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();

        let history = agent.history().await;

        let serialized = serde_json::to_string(&history).unwrap();

        let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();

        let context = DefaultContext::default()
            .with_message_history(history)
            .to_owned();

        let expected_chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!"),
            assistant!("Roses are red", ["stop"]),
            tool_output!("stop", ToolOutput::Stop),
            user!("Try again!");

            tools = [mock_tool.clone()]
        };

        let stop_response = chat_response! {
            "Really stopping now";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(expected_chat_request, Ok(stop_response));

        let mut agent = Agent::builder()
            .context(context)
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query_once("Try again!").await.unwrap();
    }
}