1#![allow(dead_code)]
2use crate::{
3 default_context::DefaultContext,
4 errors::AgentError,
5 hooks::{
6 AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7 Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn,
8 },
9 invoke_hooks,
10 state::{self, StopReason},
11 system_prompt::SystemPrompt,
12 tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15 collections::{HashMap, HashSet},
16 hash::{DefaultHasher, Hash as _, Hasher as _},
17 sync::Arc,
18};
19
20use derive_builder::Builder;
21use futures_util::stream::StreamExt;
22use swiftide_core::{
23 AgentContext, ToolBox,
24 chat_completion::{
25 ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
26 },
27 prompt::Prompt,
28};
29use tracing::{Instrument, debug};
30
31#[derive(Builder)]
46pub struct Agent {
47 #[builder(default, setter(into))]
49 pub(crate) hooks: Vec<Hook>,
50 #[builder(
52 setter(custom),
53 default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
54 )]
55 pub(crate) context: Arc<dyn AgentContext>,
56 #[builder(default = Agent::default_tools(), setter(custom))]
58 pub(crate) tools: HashSet<Box<dyn Tool>>,
59
60 #[builder(default)]
64 pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,
65
66 #[builder(setter(custom))]
68 pub(crate) llm: Box<dyn ChatCompletion>,
69
70 #[builder(setter(into, strip_option), default = Some(SystemPrompt::default()))]
92 pub(crate) system_prompt: Option<SystemPrompt>,
93
94 #[builder(private, default = state::State::default())]
96 pub(crate) state: state::State,
97
98 #[builder(default, setter(strip_option))]
101 pub(crate) limit: Option<usize>,
102
103 #[builder(default = 3)]
115 pub(crate) tool_retry_limit: usize,
116
117 #[builder(default)]
119 pub(crate) streaming: bool,
120
121 #[builder(private, default)]
124 pub(crate) clear_default_tools: bool,
125
126 #[builder(private, default)]
129 pub(crate) tool_retries_counter: HashMap<u64, usize>,
130
131 #[builder(default = "unnamed_agent".into(), setter(into))]
133 pub(crate) name: String,
134}
135
136impl Clone for Agent {
137 fn clone(&self) -> Self {
138 Agent {
139 hooks: self.hooks.clone(),
140 context: Arc::new(self.context.clone()),
141 tools: self.tools.clone(),
142 toolboxes: self.toolboxes.clone(),
143 llm: self.llm.clone(),
144 system_prompt: self.system_prompt.clone(),
145 state: self.state.clone(),
146 limit: self.limit,
147 tool_retry_limit: self.tool_retry_limit,
148 tool_retries_counter: HashMap::new(),
149 streaming: self.streaming,
150 name: self.name.clone(),
151 clear_default_tools: self.clear_default_tools,
152 }
153 }
154}
155
156impl std::fmt::Debug for Agent {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 f.debug_struct("Agent")
159 .field("name", &self.name)
160 .field(
162 "hooks",
163 &self
164 .hooks
165 .iter()
166 .map(std::string::ToString::to_string)
167 .collect::<Vec<_>>(),
168 )
169 .field(
170 "tools",
171 &self
172 .tools
173 .iter()
174 .map(swiftide_core::Tool::name)
175 .collect::<Vec<_>>(),
176 )
177 .field("llm", &"Box<dyn ChatCompletion>")
178 .field("state", &self.state)
179 .finish()
180 }
181}
182
183impl AgentBuilder {
184 pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
186 where
187 Self: Clone,
188 {
189 self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
190 self
191 }
192
193 pub fn system_prompt_mut(&mut self) -> Option<&mut SystemPrompt> {
195 self.system_prompt.as_mut().and_then(Option::as_mut)
196 }
197
198 pub fn no_system_prompt(&mut self) -> &mut Self {
200 self.system_prompt = Some(None);
201
202 self
203 }
204
205 pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
207 let hooks = self.hooks.get_or_insert_with(Vec::new);
208 hooks.push(hook);
209
210 self
211 }
212
213 pub fn add_tool(&mut self, tool: impl Tool + 'static) -> &mut Self {
215 let tools = self.tools.get_or_insert_with(HashSet::new);
216 if let Some(tool) = tools.replace(tool.boxed()) {
217 tracing::debug!("Tool {} already exists, replacing", tool.name());
218 }
219
220 self
221 }
222
223 pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
226 self.add_hook(Hook::BeforeAll(Box::new(hook)))
227 }
228
229 pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
233 self.add_hook(Hook::OnStart(Box::new(hook)))
234 }
235
236 pub fn on_stream(&mut self, hook: impl OnStreamFn + 'static) -> &mut Self {
243 self.streaming = Some(true);
244 self.add_hook(Hook::OnStream(Box::new(hook)))
245 }
246
247 pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
249 self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
250 }
251
252 pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
258 self.add_hook(Hook::AfterTool(Box::new(hook)))
259 }
260
261 pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
263 self.add_hook(Hook::BeforeTool(Box::new(hook)))
264 }
265
266 pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
268 self.add_hook(Hook::AfterCompletion(Box::new(hook)))
269 }
270
271 pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
274 self.add_hook(Hook::AfterEach(Box::new(hook)))
275 }
276
277 pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
280 self.add_hook(Hook::OnNewMessage(Box::new(hook)))
281 }
282
283 pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
284 self.add_hook(Hook::OnStop(Box::new(hook)))
285 }
286
287 pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
289 let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
290
291 self.llm = Some(boxed);
292 self
293 }
294
295 pub fn without_default_stop_tool(&mut self) -> &mut Self {
300 self.clear_default_tools = Some(true);
301 self
302 }
303
304 fn builder_default_tools(&self) -> HashSet<Box<dyn Tool>> {
305 if self.clear_default_tools.is_some_and(|b| b) {
306 HashSet::new()
307 } else {
308 Agent::default_tools()
309 }
310 }
311
312 pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
317 where
318 TOOL: Into<Box<dyn Tool>>,
319 {
320 self.tools = Some(
321 self.builder_default_tools()
322 .into_iter()
323 .chain(tools.into_iter().map(Into::into))
324 .collect(),
325 );
326 self
327 }
328
329 pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
335 let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
336 toolboxes.push(Box::new(toolbox));
337
338 self
339 }
340}
341
342impl Agent {
343 pub fn builder() -> AgentBuilder {
345 AgentBuilder::default()
346 .tools(Agent::default_tools())
347 .to_owned()
348 }
349
350 pub fn name(&self) -> &str {
352 &self.name
353 }
354
355 pub fn default_tools() -> HashSet<Box<dyn Tool>> {
358 HashSet::from([Stop::default().boxed()])
359 }
360
361 #[tracing::instrument(skip_all, name = "agent.query", err)]
368 pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
369 let query = query
370 .into()
371 .render()
372 .map_err(AgentError::FailedToRenderPrompt)?;
373 self.run_agent(Some(query), false).await
374 }
375
376 pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
378 if let Some(tool) = self.tools.replace(tool) {
379 tracing::debug!("Tool {} already exists, replacing", tool.name());
380 }
381 }
382
383 pub fn tools_mut(&mut self) -> &mut HashSet<Box<dyn Tool>> {
388 &mut self.tools
389 }
390
391 #[tracing::instrument(skip_all, name = "agent.query_once", err)]
397 pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
398 self.run_agent(Some(query), true).await
399 }
400
401 #[tracing::instrument(skip_all, name = "agent.run", err)]
408 pub async fn run(&mut self) -> Result<(), AgentError> {
409 self.run_agent(None::<Prompt>, false).await
410 }
411
412 #[tracing::instrument(skip_all, name = "agent.run_once", err)]
419 pub async fn run_once(&mut self) -> Result<(), AgentError> {
420 self.run_agent(None::<Prompt>, true).await
421 }
422
423 pub async fn history(&self) -> Result<Vec<ChatMessage>, AgentError> {
430 self.context
431 .history()
432 .await
433 .map_err(AgentError::MessageHistoryError)
434 }
435
436 pub(crate) async fn run_agent(
437 &mut self,
438 maybe_query: Option<impl Into<Prompt>>,
439 just_once: bool,
440 ) -> Result<(), AgentError> {
441 let maybe_query = maybe_query
442 .map(|q| q.into().render())
443 .transpose()
444 .map_err(AgentError::FailedToRenderPrompt)?;
445 if self.state.is_running() {
446 return Err(AgentError::AlreadyRunning);
447 }
448
449 if self.state.is_pending() {
450 if let Some(system_prompt) = &self.system_prompt {
451 self.context
452 .add_messages(vec![ChatMessage::System(
453 system_prompt
454 .to_prompt()
455 .render()
456 .map_err(AgentError::FailedToRenderSystemPrompt)?,
457 )])
458 .await
459 .map_err(AgentError::MessageHistoryError)?;
460 }
461
462 invoke_hooks!(BeforeAll, self);
463
464 self.load_toolboxes().await?;
465 }
466
467 if let Some(query) = maybe_query {
468 if cfg!(feature = "langfuse") {
469 debug!(langfuse.input = query);
470 }
471 self.context
472 .add_message(ChatMessage::User(query))
473 .await
474 .map_err(AgentError::MessageHistoryError)?;
475 }
476
477 invoke_hooks!(OnStart, self);
478
479 self.state = state::State::Running;
480
481 let mut loop_counter = 0;
482
483 while let Some(messages) = self
484 .context
485 .next_completion()
486 .await
487 .map_err(AgentError::MessageHistoryError)?
488 {
489 if let Some(limit) = self.limit
490 && loop_counter >= limit
491 {
492 tracing::warn!("Agent loop limit reached");
493 break;
494 }
495
496 if let Some(&ChatMessage::Assistant(.., Some(ref tool_calls))) =
499 maybe_tool_call_without_output(&messages)
500 {
501 tracing::debug!("Uncompleted tool calls found; invoking tools");
502 self.invoke_tools(tool_calls).await?;
503 continue;
505 }
506
507 let result = self.step(&messages, loop_counter).await;
508
509 if let Err(err) = result {
510 self.stop_with_error(&err).await;
511 tracing::error!(error = ?err, "Agent stopped with error {err}");
512 return Err(err);
513 }
514
515 if just_once || self.state.is_stopped() {
516 break;
517 }
518 loop_counter += 1;
519 }
520
521 self.stop(StopReason::NoNewMessages).await;
523
524 Ok(())
525 }
526
527 #[tracing::instrument(skip(self, messages), err, fields(otel.name))]
528 async fn step(
529 &mut self,
530 messages: &[ChatMessage],
531 step_count: usize,
532 ) -> Result<(), AgentError> {
533 tracing::Span::current().record("otel.name", format!("step-{step_count}"));
534
535 debug!(
536 tools = ?self
537 .tools
538 .iter()
539 .map(|t| t.name())
540 .collect::<Vec<_>>()
541 ,
542 "Running completion for agent with {} new messages",
543 messages.len()
544 );
545
546 let mut chat_completion_request = ChatCompletionRequest::builder()
547 .messages(messages.to_vec())
548 .tool_specs(self.tools.iter().map(swiftide_core::Tool::tool_spec))
549 .build()
550 .map_err(AgentError::FailedToBuildRequest)?;
551
552 invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
553
554 debug!(
555 "Calling LLM with the following new messages:\n {}",
556 self.context
557 .current_new_messages()
558 .await
559 .map_err(AgentError::MessageHistoryError)?
560 .iter()
561 .map(ToString::to_string)
562 .collect::<Vec<_>>()
563 .join(",\n")
564 );
565
566 let mut response = if self.streaming {
567 let mut last_response = None;
568 let mut stream = self.llm.complete_stream(&chat_completion_request).await;
569
570 while let Some(response) = stream.next().await {
571 let response = response.map_err(AgentError::CompletionsFailed)?;
572 invoke_hooks!(OnStream, self, &response);
573 last_response = Some(response);
574 }
575 tracing::trace!(?last_response, "Streaming completed");
576 last_response.ok_or(AgentError::EmptyStream)
577 } else {
578 self.llm
579 .complete(&chat_completion_request)
580 .await
581 .map_err(AgentError::CompletionsFailed)
582 }?;
583
584 response
587 .tool_calls
588 .as_deref_mut()
589 .map(ArgPreprocessor::preprocess_tool_calls);
590
591 invoke_hooks!(AfterCompletion, self, &mut response);
592
593 self.add_message(ChatMessage::Assistant(
594 response.message,
595 response.tool_calls.clone(),
596 ))
597 .await?;
598
599 if let Some(tool_calls) = response.tool_calls {
600 self.invoke_tools(&tool_calls).await?;
601 }
602
603 invoke_hooks!(AfterEach, self);
604
605 Ok(())
606 }
607
608 async fn invoke_tools(&mut self, tool_calls: &[ToolCall]) -> Result<(), AgentError> {
609 tracing::debug!("LLM returned tool calls: {:?}", tool_calls);
610
611 let mut handles = vec![];
612 for tool_call in tool_calls {
613 let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
614 tracing::warn!("Tool {} not found", tool_call.name());
615 continue;
616 };
617 tracing::info!("Calling tool `{}`", tool_call.name());
618
619 let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
621
622 invoke_hooks!(BeforeTool, self, &tool_call);
623
624 let tool_span = tracing::info_span!(
625 "tool",
626 "otel.name" = format!("tool.{}", tool.name().as_ref()),
627 );
628
629 let handle_tool_call = tool_call.clone();
630 let handle = tokio::spawn(async move {
631 let handle_tool_call = handle_tool_call;
632 let output = tool.invoke(&*context, &handle_tool_call)
633 .await?;
634
635 if cfg!(feature = "langfuse") {
636 tracing::debug!(
637 langfuse.output = %output,
638 langfuse.input = handle_tool_call.args(),
639 tool_name = tool.name().as_ref(),
640 );
641 } else {
642 tracing::debug!(output = output.to_string(), args = ?handle_tool_call.args(), tool_name = tool.name().as_ref(), "Completed tool call");
643 }
644
645 Ok(output)
646 }.instrument(tool_span.or_current()));
647
648 handles.push((handle, tool_call));
649 }
650
651 for (handle, tool_call) in handles {
652 let mut output = handle
653 .await
654 .map_err(|err| AgentError::ToolFailedToJoin(tool_call.name().to_string(), err))?;
655
656 invoke_hooks!(AfterTool, self, &tool_call, &mut output);
657
658 if let Err(error) = output {
659 let stop = self.tool_calls_over_limit(tool_call);
660 if stop {
661 tracing::error!(
662 ?error,
663 "Tool call failed, retry limit reached, stopping agent: {error}",
664 );
665 } else {
666 tracing::warn!(
667 ?error,
668 tool_call = ?tool_call,
669 "Tool call failed, retrying",
670 );
671 }
672 self.add_message(ChatMessage::ToolOutput(
673 tool_call.clone(),
674 ToolOutput::fail(error.to_string()),
675 ))
676 .await?;
677 if stop {
678 self.stop(StopReason::ToolCallsOverLimit(tool_call.to_owned()))
679 .await;
680 return Err(error.into());
681 }
682 continue;
683 }
684
685 let output = output?;
686 self.handle_control_tools(tool_call, &output).await;
687
688 if !output.is_feedback_required() {
692 self.add_message(ChatMessage::ToolOutput(tool_call.to_owned(), output))
693 .await?;
694 }
695 }
696
697 Ok(())
698 }
699
700 fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
701 self.hooks
702 .iter()
703 .filter(|h| hook_type == (*h).into())
704 .collect()
705 }
706
707 fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
708 self.tools
709 .iter()
710 .find(|tool| tool.name() == tool_name)
711 .cloned()
712 }
713
714 async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
716 match output {
717 ToolOutput::Stop(maybe_message) => {
718 tracing::warn!("Stop tool called, stopping agent");
719 self.stop(StopReason::RequestedByTool(
720 tool_call.clone(),
721 maybe_message.clone(),
722 ))
723 .await;
724 }
725 ToolOutput::FeedbackRequired(maybe_payload) => {
726 tracing::warn!("Feedback required, stopping agent");
727 self.stop(StopReason::FeedbackRequired {
728 tool_call: tool_call.clone(),
729 payload: maybe_payload.clone(),
730 })
731 .await;
732 }
733 ToolOutput::AgentFailed(output) => {
734 tracing::warn!("Agent failed, stopping agent");
735 self.stop(StopReason::AgentFailed(output.clone())).await;
736 }
737 _ => (),
738 }
739 }
740
741 pub fn system_prompt(&self) -> Option<&SystemPrompt> {
743 self.system_prompt.as_ref()
744 }
745
746 pub fn system_prompt_mut(&mut self) -> Option<&mut SystemPrompt> {
750 self.system_prompt.as_mut()
751 }
752
753 fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
754 let mut s = DefaultHasher::new();
755 tool_call.hash(&mut s);
756 let hash = s.finish();
757
758 if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
759 let val = *retries >= self.tool_retry_limit;
760 *retries += 1;
761 val
762 } else {
763 self.tool_retries_counter.insert(hash, 1);
764 false
765 }
766 }
767
768 #[tracing::instrument(skip_all, fields(message = message.to_string()))]
779 pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
780 invoke_hooks!(OnNewMessage, self, &mut message);
781
782 self.context
783 .add_message(message)
784 .await
785 .map_err(AgentError::MessageHistoryError)?;
786 Ok(())
787 }
788
789 pub async fn stop(&mut self, reason: impl Into<StopReason>) {
791 if self.state.is_stopped() {
792 return;
793 }
794
795 let reason = reason.into();
796 invoke_hooks!(OnStop, self, reason.clone(), None);
797
798 if cfg!(feature = "langfuse") {
799 debug!(langfuse.output = serde_json::to_string_pretty(&reason).ok());
800 }
801
802 self.state = state::State::Stopped(reason);
803 }
804
805 pub async fn stop_with_error(&mut self, error: &AgentError) {
806 if self.state.is_stopped() {
807 return;
808 }
809 invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
810
811 self.state = state::State::Stopped(StopReason::Error);
812 }
813
814 pub fn context(&self) -> &dyn AgentContext {
816 &self.context
817 }
818
819 pub fn is_running(&self) -> bool {
821 self.state.is_running()
822 }
823
824 pub fn is_stopped(&self) -> bool {
826 self.state.is_stopped()
827 }
828
829 pub fn is_pending(&self) -> bool {
831 self.state.is_pending()
832 }
833
834 pub fn tools(&self) -> &HashSet<Box<dyn Tool>> {
836 &self.tools
837 }
838
839 pub fn state(&self) -> &state::State {
840 &self.state
841 }
842
843 pub fn stop_reason(&self) -> Option<&StopReason> {
844 self.state.stop_reason()
845 }
846
847 async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
848 for toolbox in &self.toolboxes {
849 let tools = toolbox
850 .available_tools()
851 .await
852 .map_err(AgentError::ToolBoxFailedToLoad)?;
853 self.tools.extend(tools);
854 }
855
856 Ok(())
857 }
858}
859
860fn maybe_tool_call_without_output(messages: &[ChatMessage]) -> Option<&ChatMessage> {
863 for message in messages.iter().rev() {
864 if let ChatMessage::ToolOutput(..) = message {
865 return None;
866 }
867
868 if let ChatMessage::Assistant(.., Some(tool_calls)) = message
869 && !tool_calls.is_empty()
870 {
871 return Some(message);
872 }
873 }
874
875 None
876}
877
878#[cfg(test)]
879mod tests {
880
881 use serde::ser::Error;
882 use swiftide_core::ToolFeedback;
883 use swiftide_core::chat_completion::errors::ToolError;
884 use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
885 use swiftide_core::test_utils::MockChatCompletion;
886
887 use super::*;
888 use crate::{
889 State, assistant, chat_request, chat_response, summary, system, tool_failed, tool_output,
890 user,
891 };
892
893 use crate::test_utils::{MockHook, MockTool};
894
895 #[test_log::test(tokio::test)]
896 async fn test_agent_builder_defaults() {
897 let mock_llm = MockChatCompletion::new();
899
900 let agent = Agent::builder().llm(&mock_llm).build().unwrap();
902
903 assert!(agent.find_tool_by_name("stop").is_some());
907
908 let agent = Agent::builder()
910 .tools([Stop::default(), Stop::default()])
911 .llm(&mock_llm)
912 .build()
913 .unwrap();
914
915 assert_eq!(agent.tools.len(), 1);
916
917 let agent = Agent::builder()
919 .tools([MockTool::new("mock_tool")])
920 .llm(&mock_llm)
921 .build()
922 .unwrap();
923
924 assert_eq!(agent.tools.len(), 2);
925 assert!(agent.find_tool_by_name("mock_tool").is_some());
926 assert!(agent.find_tool_by_name("stop").is_some());
927
928 assert!(agent.context().history().await.unwrap().is_empty());
929 }
930
931 #[test_log::test(tokio::test)]
932 async fn test_agent_tool_calling_loop() {
933 let prompt = "Write a poem";
934 let mock_llm = MockChatCompletion::new();
935 let mock_tool = MockTool::new("mock_tool");
936
937 let chat_request = chat_request! {
938 user!("Write a poem");
939
940 tools = [mock_tool.clone()]
941 };
942
943 let mock_tool_response = chat_response! {
944 "Roses are red";
945 tool_calls = ["mock_tool"]
946
947 };
948
949 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
950
951 let chat_request = chat_request! {
952 user!("Write a poem"),
953 assistant!("Roses are red", ["mock_tool"]),
954 tool_output!("mock_tool", "Great!");
955
956 tools = [mock_tool.clone()]
957 };
958
959 let stop_response = chat_response! {
960 "Roses are red";
961 tool_calls = ["stop"]
962 };
963
964 mock_llm.expect_complete(chat_request, Ok(stop_response));
965 mock_tool.expect_invoke_ok("Great!".into(), None);
966
967 let mut agent = Agent::builder()
968 .tools([mock_tool])
969 .llm(&mock_llm)
970 .no_system_prompt()
971 .build()
972 .unwrap();
973
974 agent.query(prompt).await.unwrap();
975 }
976
977 #[test_log::test(tokio::test)]
978 async fn test_agent_tool_run_once() {
979 let prompt = "Write a poem";
980 let mock_llm = MockChatCompletion::new();
981 let mock_tool = MockTool::default();
982
983 let chat_request = chat_request! {
984 system!("My system prompt"),
985 user!("Write a poem");
986
987 tools = [mock_tool.clone()]
988 };
989
990 let mock_tool_response = chat_response! {
991 "Roses are red";
992 tool_calls = ["mock_tool"]
993
994 };
995
996 mock_tool.expect_invoke_ok("Great!".into(), None);
997 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
998
999 let mut agent = Agent::builder()
1000 .tools([mock_tool])
1001 .system_prompt("My system prompt")
1002 .llm(&mock_llm)
1003 .build()
1004 .unwrap();
1005
1006 agent.query_once(prompt).await.unwrap();
1007 }
1008
1009 #[test_log::test(tokio::test)]
1010 async fn test_agent_tool_via_toolbox_run_once() {
1011 let prompt = "Write a poem";
1012 let mock_llm = MockChatCompletion::new();
1013 let mock_tool = MockTool::default();
1014
1015 let chat_request = chat_request! {
1016 system!("My system prompt"),
1017 user!("Write a poem");
1018
1019 tools = [mock_tool.clone()]
1020 };
1021
1022 let mock_tool_response = chat_response! {
1023 "Roses are red";
1024 tool_calls = ["mock_tool"]
1025
1026 };
1027
1028 mock_tool.expect_invoke_ok("Great!".into(), None);
1029 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1030
1031 let mut agent = Agent::builder()
1032 .add_toolbox(vec![mock_tool.boxed()])
1033 .system_prompt("My system prompt")
1034 .llm(&mock_llm)
1035 .build()
1036 .unwrap();
1037
1038 agent.query_once(prompt).await.unwrap();
1039 }
1040
1041 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1042 async fn test_multiple_tool_calls() {
1043 let prompt = "Write a poem";
1044 let mock_llm = MockChatCompletion::new();
1045 let mock_tool = MockTool::new("mock_tool1");
1046 let mock_tool2 = MockTool::new("mock_tool2");
1047
1048 let chat_request = chat_request! {
1049 system!("My system prompt"),
1050 user!("Write a poem");
1051
1052
1053
1054 tools = [mock_tool.clone(), mock_tool2.clone()]
1055 };
1056
1057 let mock_tool_response = chat_response! {
1058 "Roses are red";
1059
1060 tool_calls = ["mock_tool1", "mock_tool2"]
1061
1062 };
1063
1064 dbg!(&chat_request);
1065 mock_tool.expect_invoke_ok("Great!".into(), None);
1066 mock_tool2.expect_invoke_ok("Great!".into(), None);
1067 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1068
1069 let chat_request = chat_request! {
1070 system!("My system prompt"),
1071 user!("Write a poem"),
1072 assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
1073 tool_output!("mock_tool1", "Great!"),
1074 tool_output!("mock_tool2", "Great!");
1075
1076 tools = [mock_tool.clone(), mock_tool2.clone()]
1077 };
1078
1079 let mock_tool_response = chat_response! {
1080 "Ok!";
1081
1082 tool_calls = ["stop"]
1083
1084 };
1085
1086 mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
1087
1088 let mut agent = Agent::builder()
1089 .tools([mock_tool, mock_tool2])
1090 .system_prompt("My system prompt")
1091 .llm(&mock_llm)
1092 .build()
1093 .unwrap();
1094
1095 agent.query(prompt).await.unwrap();
1096 }
1097
1098 #[test_log::test(tokio::test)]
1099 async fn test_agent_state_machine() {
1100 let prompt = "Write a poem";
1101 let mock_llm = MockChatCompletion::new();
1102
1103 let chat_request = chat_request! {
1104 user!("Write a poem");
1105 tools = []
1106 };
1107 let mock_tool_response = chat_response! {
1108 "Roses are red";
1109 tool_calls = []
1110 };
1111
1112 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1113 let mut agent = Agent::builder()
1114 .llm(&mock_llm)
1115 .no_system_prompt()
1116 .build()
1117 .unwrap();
1118
1119 assert!(agent.state.is_pending());
1121 agent.query_once(prompt).await.unwrap();
1122
1123 assert!(agent.state.is_stopped());
1125 }
1126
1127 #[test_log::test(tokio::test)]
1128 async fn test_summary() {
1129 let prompt = "Write a poem";
1130 let mock_llm = MockChatCompletion::new();
1131
1132 let mock_tool_response = chat_response! {
1133 "Roses are red";
1134 tool_calls = []
1135
1136 };
1137
1138 let expected_chat_request = chat_request! {
1139 system!("My system prompt"),
1140 user!("Write a poem");
1141
1142 tools = []
1143 };
1144
1145 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1146
1147 let mut agent = Agent::builder()
1148 .system_prompt("My system prompt")
1149 .llm(&mock_llm)
1150 .build()
1151 .unwrap();
1152
1153 agent.query_once(prompt).await.unwrap();
1154
1155 agent
1156 .context
1157 .add_message(ChatMessage::new_summary("Summary"))
1158 .await
1159 .unwrap();
1160
1161 let expected_chat_request = chat_request! {
1162 system!("My system prompt"),
1163 summary!("Summary"),
1164 user!("Write another poem");
1165 tools = []
1166 };
1167 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1168
1169 agent.query_once("Write another poem").await.unwrap();
1170
1171 agent
1172 .context
1173 .add_message(ChatMessage::new_summary("Summary 2"))
1174 .await
1175 .unwrap();
1176
1177 let expected_chat_request = chat_request! {
1178 system!("My system prompt"),
1179 summary!("Summary 2"),
1180 user!("Write a third poem");
1181 tools = []
1182 };
1183 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));
1184
1185 agent.query_once("Write a third poem").await.unwrap();
1186 }
1187
1188 #[test_log::test(tokio::test)]
1189 async fn test_agent_hooks() {
1190 let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
1191 let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
1192 let mock_before_completion = MockHook::new("before_completion")
1193 .expect_calls(2)
1194 .to_owned();
1195 let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
1196 let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
1197 let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
1198 let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();
1199
1200 let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
1202 let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();
1203
1204 let prompt = "Write a poem";
1205 let mock_llm = MockChatCompletion::new();
1206 let mock_tool = MockTool::default();
1207
1208 let chat_request = chat_request! {
1209 user!("Write a poem");
1210
1211 tools = [mock_tool.clone()]
1212 };
1213
1214 let mock_tool_response = chat_response! {
1215 "Roses are red";
1216 tool_calls = ["mock_tool"]
1217
1218 };
1219
1220 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1221
1222 let chat_request = chat_request! {
1223 user!("Write a poem"),
1224 assistant!("Roses are red", ["mock_tool"]),
1225 tool_output!("mock_tool", "Great!");
1226
1227 tools = [mock_tool.clone()]
1228 };
1229
1230 let stop_response = chat_response! {
1231 "Roses are red";
1232 tool_calls = ["stop"]
1233 };
1234
1235 mock_llm.expect_complete(chat_request, Ok(stop_response));
1236 mock_tool.expect_invoke_ok("Great!".into(), None);
1237
1238 let mut agent = Agent::builder()
1239 .tools([mock_tool])
1240 .llm(&mock_llm)
1241 .no_system_prompt()
1242 .before_all(mock_before_all.hook_fn())
1243 .on_start(mock_on_start_fn.on_start_fn())
1244 .before_completion(mock_before_completion.before_completion_fn())
1245 .before_tool(mock_before_tool.before_tool_fn())
1246 .after_completion(mock_after_completion.after_completion_fn())
1247 .after_tool(mock_after_tool.after_tool_fn())
1248 .after_each(mock_after_each.hook_fn())
1249 .on_new_message(mock_on_message.message_hook_fn())
1250 .on_stop(mock_on_stop.stop_hook_fn())
1251 .build()
1252 .unwrap();
1253
1254 agent.query(prompt).await.unwrap();
1255 }
1256
1257 #[test_log::test(tokio::test)]
1258 async fn test_agent_loop_limit() {
1259 let prompt = "Generate content"; let mock_llm = MockChatCompletion::new();
1261 let mock_tool = MockTool::new("mock_tool");
1262
1263 let chat_request = chat_request! {
1264 user!(prompt);
1265 tools = [mock_tool.clone()]
1266 };
1267 mock_tool.expect_invoke_ok("Great!".into(), None);
1268
1269 let mock_tool_response = chat_response! {
1270 "Some response";
1271 tool_calls = ["mock_tool"]
1272 };
1273
1274 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));
1276
1277 let stop_response = chat_response! {
1279 "Final response";
1280 tool_calls = ["stop"]
1281 };
1282
1283 mock_llm.expect_complete(chat_request, Ok(stop_response));
1284
1285 let mut agent = Agent::builder()
1286 .tools([mock_tool])
1287 .llm(&mock_llm)
1288 .no_system_prompt()
1289 .limit(1) .build()
1291 .unwrap();
1292
1293 agent.query(prompt).await.unwrap();
1295
1296 let remaining = mock_llm.expectations.lock().unwrap().pop();
1298 assert!(remaining.is_some());
1299
1300 assert!(agent.is_stopped());
1302 }
1303
1304 #[test_log::test(tokio::test)]
1305 async fn test_tool_retry_mechanism() {
1306 let prompt = "Execute my tool";
1307 let mock_llm = MockChatCompletion::new();
1308 let mock_tool = MockTool::new("retry_tool");
1309
1310 mock_tool.expect_invoke(
1313 Err(ToolError::WrongArguments(serde_json::Error::custom(
1314 "missing `query`",
1315 ))),
1316 None,
1317 );
1318 mock_tool.expect_invoke(
1319 Err(ToolError::WrongArguments(serde_json::Error::custom(
1320 "missing `query`",
1321 ))),
1322 None,
1323 );
1324
1325 let chat_request = chat_request! {
1326 user!(prompt);
1327 tools = [mock_tool.clone()]
1328 };
1329 let retry_response = chat_response! {
1330 "First failing attempt";
1331 tool_calls = ["retry_tool"]
1332 };
1333 mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
1334
1335 let chat_request = chat_request! {
1336 user!(prompt),
1337 assistant!("First failing attempt", ["retry_tool"]),
1338 tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
1339
1340 tools = [mock_tool.clone()]
1341 };
1342 let will_fail_response = chat_response! {
1343 "Finished execution";
1344 tool_calls = ["retry_tool"]
1345 };
1346 mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
1347
1348 let mut agent = Agent::builder()
1349 .tools([mock_tool])
1350 .llm(&mock_llm)
1351 .no_system_prompt()
1352 .tool_retry_limit(1) .build()
1354 .unwrap();
1355
1356 let result = agent.query(prompt).await;
1358
1359 assert!(result.is_err());
1360 assert!(result.unwrap_err().to_string().contains("missing `query`"));
1361 assert!(agent.is_stopped());
1362 }
1363
1364 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1365 async fn test_streaming() {
1366 let prompt = "Generate content"; let mock_llm = MockChatCompletion::new();
1368 let on_stream_fn = MockHook::new("on_stream").expect_calls(3).to_owned();
1369
1370 let chat_request = chat_request! {
1371 user!(prompt);
1372
1373 tools = []
1374 };
1375
1376 let response = chat_response! {
1377 "one two three";
1378 tool_calls = ["stop"]
1379 };
1380
1381 mock_llm.expect_complete(chat_request, Ok(response));
1383
1384 let mut agent = Agent::builder()
1385 .llm(&mock_llm)
1386 .on_stream(on_stream_fn.on_stream_fn())
1387 .no_system_prompt()
1388 .build()
1389 .unwrap();
1390
1391 agent.query(prompt).await.unwrap();
1393
1394 tracing::debug!("Agent finished running");
1395
1396 assert!(agent.is_stopped());
1398 }
1399
1400 #[test_log::test(tokio::test)]
1401 async fn test_recovering_agent_existing_history() {
1402 let prompt = "Write a poem";
1404 let mock_llm = MockChatCompletion::new();
1405 let mock_tool = MockTool::new("mock_tool");
1406
1407 let chat_request = chat_request! {
1408 user!("Write a poem");
1409
1410 tools = [mock_tool.clone()]
1411 };
1412
1413 let mock_tool_response = chat_response! {
1414 "Roses are red";
1415 tool_calls = ["mock_tool"]
1416
1417 };
1418
1419 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1420
1421 let chat_request = chat_request! {
1422 user!("Write a poem"),
1423 assistant!("Roses are red", ["mock_tool"]),
1424 tool_output!("mock_tool", "Great!");
1425
1426 tools = [mock_tool.clone()]
1427 };
1428
1429 let stop_response = chat_response! {
1430 "Roses are red";
1431 tool_calls = ["stop"]
1432 };
1433
1434 mock_llm.expect_complete(chat_request, Ok(stop_response));
1435 mock_tool.expect_invoke_ok("Great!".into(), None);
1436
1437 let mut agent = Agent::builder()
1438 .tools([mock_tool.clone()])
1439 .llm(&mock_llm)
1440 .no_system_prompt()
1441 .build()
1442 .unwrap();
1443
1444 agent.query(prompt).await.unwrap();
1445
1446 let history = agent.history().await.unwrap();
1448
1449 let serialized = serde_json::to_string(&history).unwrap();
1451
1452 let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();
1454
1455 let context = DefaultContext::default()
1457 .with_existing_messages(history)
1458 .await
1459 .unwrap()
1460 .to_owned();
1461
1462 let stop_output = ToolOutput::stop();
1463 let expected_chat_request = chat_request! {
1464 user!("Write a poem"),
1465 assistant!("Roses are red", ["mock_tool"]),
1466 tool_output!("mock_tool", "Great!"),
1467 assistant!("Roses are red", ["stop"]),
1468 tool_output!("stop", stop_output),
1469 user!("Try again!");
1470
1471 tools = [mock_tool.clone()]
1472 };
1473
1474 let stop_response = chat_response! {
1475 "Really stopping now";
1476 tool_calls = ["stop"]
1477 };
1478
1479 mock_llm.expect_complete(expected_chat_request, Ok(stop_response));
1480
1481 let mut agent = Agent::builder()
1482 .context(context)
1483 .tools([mock_tool])
1484 .llm(&mock_llm)
1485 .no_system_prompt()
1486 .build()
1487 .unwrap();
1488
1489 agent.query_once("Try again!").await.unwrap();
1490 }
1491
1492 #[test_log::test(tokio::test)]
1493 async fn test_agent_with_approval_required_tool() {
1494 use super::*;
1495 use crate::tools::control::ApprovalRequired;
1496 use crate::{assistant, chat_request, chat_response, user};
1497 use swiftide_core::chat_completion::ToolCall;
1498
1499 let mock_tool = MockTool::default();
1501 mock_tool.expect_invoke_ok("Great!".into(), None);
1502
1503 let approval_tool = ApprovalRequired(mock_tool.boxed());
1504
1505 let mock_llm = MockChatCompletion::new();
1507
1508 let chat_req1 = chat_request! {
1509 user!("Request with approval");
1510 tools = [approval_tool.clone()]
1511 };
1512 let chat_resp1 = chat_response! {
1513 "Completion message";
1514 tool_calls = ["mock_tool"]
1515 };
1516 mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1517
1518 let chat_req2 = chat_request! {
1521 user!("Request with approval"),
1522 assistant!("Completion message", ["mock_tool"]),
1523 tool_output!("mock_tool", "Great!");
1524 tools = [approval_tool.clone()]
1526 };
1527 let chat_resp2 = chat_response! {
1528 "Post-feedback message";
1529 tool_calls = ["stop"]
1530 };
1531 mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1532
1533 let mut agent = Agent::builder()
1535 .tools([approval_tool])
1536 .llm(&mock_llm)
1537 .no_system_prompt()
1538 .build()
1539 .unwrap();
1540
1541 agent.query_once("Request with approval").await.unwrap();
1543
1544 assert!(matches!(
1545 agent.state,
1546 crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1547 ));
1548
1549 let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1550 else {
1551 panic!("Expected feedback required");
1552 };
1553
1554 agent
1556 .context
1557 .feedback_received(&tool_call, &ToolFeedback::approved())
1558 .await
1559 .unwrap();
1560
1561 tracing::debug!("running after approval");
1562 agent.run_once().await.unwrap();
1563 assert!(agent.is_stopped());
1564 }
1565
1566 #[test_log::test(tokio::test)]
1567 async fn test_agent_with_approval_required_tool_denied() {
1568 use super::*;
1569 use crate::tools::control::ApprovalRequired;
1570 use crate::{assistant, chat_request, chat_response, user};
1571 use swiftide_core::chat_completion::ToolCall;
1572
1573 let mock_tool = MockTool::default();
1575
1576 let approval_tool = ApprovalRequired(mock_tool.boxed());
1577
1578 let mock_llm = MockChatCompletion::new();
1580
1581 let chat_req1 = chat_request! {
1582 user!("Request with approval");
1583 tools = [approval_tool.clone()]
1584 };
1585 let chat_resp1 = chat_response! {
1586 "Completion message";
1587 tool_calls = ["mock_tool"]
1588 };
1589 mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1590
1591 let chat_req2 = chat_request! {
1594 user!("Request with approval"),
1595 assistant!("Completion message", ["mock_tool"]),
1596 tool_output!("mock_tool", "This tool call was refused");
1597 tools = [approval_tool.clone()]
1599 };
1600 let chat_resp2 = chat_response! {
1601 "Post-feedback message";
1602 tool_calls = ["stop"]
1603 };
1604 mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1605
1606 let mut agent = Agent::builder()
1608 .tools([approval_tool])
1609 .llm(&mock_llm)
1610 .no_system_prompt()
1611 .build()
1612 .unwrap();
1613
1614 agent.query_once("Request with approval").await.unwrap();
1616
1617 assert!(matches!(
1618 agent.state,
1619 crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1620 ));
1621
1622 let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1623 else {
1624 panic!("Expected feedback required");
1625 };
1626
1627 agent
1629 .context
1630 .feedback_received(&tool_call, &ToolFeedback::refused())
1631 .await
1632 .unwrap();
1633
1634 tracing::debug!("running after approval");
1635 agent.run_once().await.unwrap();
1636
1637 let history = agent.context().history().await.unwrap();
1638 history
1639 .iter()
1640 .rfind(|m| {
1641 let ChatMessage::ToolOutput(.., ToolOutput::Text(msg)) = m else {
1642 return false;
1643 };
1644 msg.contains("refused")
1645 })
1646 .expect("Could not find refusal message");
1647
1648 assert!(agent.is_stopped());
1649 }
1650
1651 #[test_log::test(tokio::test)]
1652 async fn test_removing_default_stop_tool() {
1653 let mock_llm = MockChatCompletion::new();
1654 let mock_tool = MockTool::new("mock_tool");
1655
1656 let agent = Agent::builder()
1658 .without_default_stop_tool()
1659 .tools([mock_tool.clone()])
1660 .llm(&mock_llm)
1661 .no_system_prompt()
1662 .build()
1663 .unwrap();
1664
1665 assert!(agent.find_tool_by_name("stop").is_none());
1667 assert!(agent.find_tool_by_name("mock_tool").is_some());
1669 }
1670}