1#![allow(dead_code)]
2use crate::{
3 default_context::DefaultContext,
4 errors::AgentError,
5 hooks::{
6 AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7 Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn, OnStreamFn,
8 },
9 invoke_hooks,
10 state::{self, StopReason},
11 system_prompt::SystemPrompt,
12 tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15 collections::{HashMap, HashSet},
16 hash::{DefaultHasher, Hash as _, Hasher as _},
17 sync::Arc,
18};
19
20use derive_builder::Builder;
21use futures_util::stream::StreamExt;
22use swiftide_core::{
23 AgentContext, ToolBox,
24 chat_completion::{
25 ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
26 },
27 prompt::Prompt,
28};
29use tracing::{Instrument, debug};
30
/// An agent that repeatedly asks an LLM for completions and invokes the tools the LLM calls.
///
/// Instances are created through [`Agent::builder`]; the `AgentBuilder` type is generated by
/// `derive_builder` from the attributes below.
#[derive(Builder)]
pub struct Agent {
    /// Hooks registered on the agent. They are looked up by type and invoked at the matching
    /// points of a run.
    #[builder(default, setter(into))]
    pub(crate) hooks: Vec<Hook>,
    /// Context that stores the message history and is handed to tools when they are invoked.
    /// Defaults to a [`DefaultContext`].
    #[builder(
        setter(custom),
        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
    )]
    pub(crate) context: Arc<dyn AgentContext>,
    /// Tools whose specs are sent with every completion request and that can be invoked by name.
    /// Defaults to [`Agent::default_tools`].
    #[builder(default = Agent::default_tools(), setter(custom))]
    pub(crate) tools: HashSet<Box<dyn Tool>>,

    /// Toolboxes whose available tools are loaded into `tools` the first time the agent runs.
    #[builder(default)]
    pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,

    /// The language model used for (streaming) chat completions. Must be set on the builder.
    #[builder(setter(custom))]
    pub(crate) llm: Box<dyn ChatCompletion>,

    /// System prompt added to the context when a pending agent first runs. `None` disables it;
    /// defaults to `SystemPrompt::default()`.
    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default()))]
    pub(crate) system_prompt: Option<SystemPrompt>,

    /// Current lifecycle state of the agent (pending, running or stopped).
    #[builder(private, default = state::State::default())]
    pub(crate) state: state::State,

    /// Optional cap on the number of completion steps within a single run. The counter is
    /// local to each run.
    #[builder(default, setter(strip_option))]
    pub(crate) limit: Option<usize>,

    /// How often a failing tool call may be retried before the agent stops with an error.
    /// Defaults to 3.
    #[builder(default = 3)]
    pub(crate) tool_retry_limit: usize,

    /// When true, completions are requested as a stream and `OnStream` hooks are invoked for
    /// every streamed response.
    #[builder(default)]
    pub(crate) streaming: bool,

    /// When true, the builder's `tools` setter does not add [`Agent::default_tools`].
    #[builder(private, default)]
    pub(crate) clear_default_tools: bool,

    /// Number of recorded failures per tool call, keyed by the hash of the [`ToolCall`].
    #[builder(private, default)]
    pub(crate) tool_retries_counter: HashMap<u64, usize>,

    /// Name of the agent, included in its `Debug` output. Defaults to `"unnamed_agent"`.
    #[builder(default = "unnamed_agent".into(), setter(into))]
    pub(crate) name: String,
}
133
134impl Clone for Agent {
135 fn clone(&self) -> Self {
136 Agent {
137 hooks: self.hooks.clone(),
138 context: Arc::new(self.context.clone()),
139 tools: self.tools.clone(),
140 toolboxes: self.toolboxes.clone(),
141 llm: self.llm.clone(),
142 system_prompt: self.system_prompt.clone(),
143 state: self.state.clone(),
144 limit: self.limit,
145 tool_retry_limit: self.tool_retry_limit,
146 tool_retries_counter: HashMap::new(),
147 streaming: self.streaming,
148 name: self.name.clone(),
149 clear_default_tools: self.clear_default_tools,
150 }
151 }
152}
153
154impl std::fmt::Debug for Agent {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 f.debug_struct("Agent")
157 .field("name", &self.name)
158 .field(
160 "hooks",
161 &self
162 .hooks
163 .iter()
164 .map(std::string::ToString::to_string)
165 .collect::<Vec<_>>(),
166 )
167 .field(
168 "tools",
169 &self
170 .tools
171 .iter()
172 .map(swiftide_core::Tool::name)
173 .collect::<Vec<_>>(),
174 )
175 .field("llm", &"Box<dyn ChatCompletion>")
176 .field("state", &self.state)
177 .finish()
178 }
179}
180
181impl AgentBuilder {
182 pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
184 where
185 Self: Clone,
186 {
187 self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
188 self
189 }
190
191 pub fn system_prompt_mut(&mut self) -> Option<&mut SystemPrompt> {
193 self.system_prompt.as_mut().and_then(Option::as_mut)
194 }
195
196 pub fn no_system_prompt(&mut self) -> &mut Self {
198 self.system_prompt = Some(None);
199
200 self
201 }
202
203 pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
205 let hooks = self.hooks.get_or_insert_with(Vec::new);
206 hooks.push(hook);
207
208 self
209 }
210
211 pub fn add_tool(&mut self, tool: impl Tool + 'static) -> &mut Self {
213 let tools = self.tools.get_or_insert_with(HashSet::new);
214 if let Some(tool) = tools.replace(tool.boxed()) {
215 tracing::debug!("Tool {} already exists, replacing", tool.name());
216 }
217
218 self
219 }
220
221 pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
224 self.add_hook(Hook::BeforeAll(Box::new(hook)))
225 }
226
227 pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
231 self.add_hook(Hook::OnStart(Box::new(hook)))
232 }
233
234 pub fn on_stream(&mut self, hook: impl OnStreamFn + 'static) -> &mut Self {
241 self.streaming = Some(true);
242 self.add_hook(Hook::OnStream(Box::new(hook)))
243 }
244
245 pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
247 self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
248 }
249
250 pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
256 self.add_hook(Hook::AfterTool(Box::new(hook)))
257 }
258
259 pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
261 self.add_hook(Hook::BeforeTool(Box::new(hook)))
262 }
263
264 pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
266 self.add_hook(Hook::AfterCompletion(Box::new(hook)))
267 }
268
269 pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
272 self.add_hook(Hook::AfterEach(Box::new(hook)))
273 }
274
275 pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
278 self.add_hook(Hook::OnNewMessage(Box::new(hook)))
279 }
280
281 pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
282 self.add_hook(Hook::OnStop(Box::new(hook)))
283 }
284
285 pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
287 let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
288
289 self.llm = Some(boxed);
290 self
291 }
292
293 pub fn without_default_stop_tool(&mut self) -> &mut Self {
298 self.clear_default_tools = Some(true);
299 self
300 }
301
302 fn builder_default_tools(&self) -> HashSet<Box<dyn Tool>> {
303 if self.clear_default_tools.is_some_and(|b| b) {
304 HashSet::new()
305 } else {
306 Agent::default_tools()
307 }
308 }
309
310 pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
315 where
316 TOOL: Into<Box<dyn Tool>>,
317 {
318 self.tools = Some(
319 self.builder_default_tools()
320 .into_iter()
321 .chain(tools.into_iter().map(Into::into))
322 .collect(),
323 );
324 self
325 }
326
327 pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
333 let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
334 toolboxes.push(Box::new(toolbox));
335
336 self
337 }
338}
339
340impl Agent {
341 pub fn builder() -> AgentBuilder {
343 AgentBuilder::default()
344 .tools(Agent::default_tools())
345 .to_owned()
346 }
347
348 pub fn name(&self) -> &str {
350 &self.name
351 }
352
353 pub fn default_tools() -> HashSet<Box<dyn Tool>> {
356 HashSet::from([Stop::default().boxed()])
357 }
358
359 #[tracing::instrument(skip_all, name = "agent.query", err)]
366 pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
367 let query = query
368 .into()
369 .render()
370 .map_err(AgentError::FailedToRenderPrompt)?;
371 self.run_agent(Some(query), false).await
372 }
373
374 pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
376 if let Some(tool) = self.tools.replace(tool) {
377 tracing::debug!("Tool {} already exists, replacing", tool.name());
378 }
379 }
380
    /// Returns a mutable reference to the agent's tool set, e.g. to add or remove tools at
    /// runtime.
    pub fn tools_mut(&mut self) -> &mut HashSet<Box<dyn Tool>> {
        &mut self.tools
    }
388
389 #[tracing::instrument(skip_all, name = "agent.query_once", err)]
395 pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
396 self.run_agent(Some(query), true).await
397 }
398
399 #[tracing::instrument(skip_all, name = "agent.run", err)]
406 pub async fn run(&mut self) -> Result<(), AgentError> {
407 self.run_agent(None::<Prompt>, false).await
408 }
409
410 #[tracing::instrument(skip_all, name = "agent.run_once", err)]
417 pub async fn run_once(&mut self) -> Result<(), AgentError> {
418 self.run_agent(None::<Prompt>, true).await
419 }
420
421 pub async fn history(&self) -> Result<Vec<ChatMessage>, AgentError> {
428 self.context
429 .history()
430 .await
431 .map_err(AgentError::MessageHistoryError)
432 }
433
434 pub(crate) async fn run_agent(
435 &mut self,
436 maybe_query: Option<impl Into<Prompt>>,
437 just_once: bool,
438 ) -> Result<(), AgentError> {
439 let maybe_query = maybe_query
440 .map(|q| q.into().render())
441 .transpose()
442 .map_err(AgentError::FailedToRenderPrompt)?;
443 if self.state.is_running() {
444 return Err(AgentError::AlreadyRunning);
445 }
446
447 if self.state.is_pending() {
448 if let Some(system_prompt) = &self.system_prompt {
449 self.context
450 .add_messages(vec![ChatMessage::System(
451 system_prompt
452 .to_prompt()
453 .render()
454 .map_err(AgentError::FailedToRenderSystemPrompt)?,
455 )])
456 .await
457 .map_err(AgentError::MessageHistoryError)?;
458 }
459
460 invoke_hooks!(BeforeAll, self);
461
462 self.load_toolboxes().await?;
463 }
464
465 if let Some(query) = maybe_query {
466 if cfg!(feature = "langfuse") {
467 debug!(langfuse.input = query);
468 }
469 self.context
470 .add_message(ChatMessage::User(query))
471 .await
472 .map_err(AgentError::MessageHistoryError)?;
473 }
474
475 invoke_hooks!(OnStart, self);
476
477 self.state = state::State::Running;
478
479 let mut loop_counter = 0;
480
481 while let Some(messages) = self
482 .context
483 .next_completion()
484 .await
485 .map_err(AgentError::MessageHistoryError)?
486 {
487 if let Some(limit) = self.limit
488 && loop_counter >= limit
489 {
490 tracing::warn!("Agent loop limit reached");
491 break;
492 }
493
494 if let Some(&ChatMessage::Assistant(.., Some(ref tool_calls))) =
497 maybe_tool_call_without_output(&messages)
498 {
499 tracing::debug!("Uncompleted tool calls found; invoking tools");
500 self.invoke_tools(tool_calls).await?;
501 continue;
503 }
504
505 let result = self.step(&messages, loop_counter).await;
506
507 if let Err(err) = result {
508 self.stop_with_error(&err).await;
509 tracing::error!(error = ?err, "Agent stopped with error {err}");
510 return Err(err);
511 }
512
513 if just_once || self.state.is_stopped() {
514 break;
515 }
516 loop_counter += 1;
517 }
518
519 self.stop(StopReason::NoNewMessages).await;
521
522 Ok(())
523 }
524
    /// Runs a single agent step: one completion followed by the tool calls it requested.
    ///
    /// The request is built from `messages` and the specs of all registered tools. Hooks run in
    /// this order: `BeforeCompletion`, `OnStream` (streaming only, once per streamed response),
    /// `AfterCompletion`, then the tool hooks inside `invoke_tools`, and finally `AfterEach`.
    ///
    /// # Errors
    ///
    /// Returns an error when the request cannot be built, the completion fails, a stream ends
    /// without yielding a response, the context cannot be read or written, or tool invocation
    /// fails.
    #[tracing::instrument(skip(self, messages), err, fields(otel.name))]
    async fn step(
        &mut self,
        messages: &[ChatMessage],
        step_count: usize,
    ) -> Result<(), AgentError> {
        // Name the span after the step number; the field is declared empty in the attribute.
        tracing::Span::current().record("otel.name", format!("step-{step_count}"));

        debug!(
            tools = ?self
                .tools
                .iter()
                .map(|t| t.name())
                .collect::<Vec<_>>()
                ,
            "Running completion for agent with {} new messages",
            messages.len()
        );

        let mut chat_completion_request = ChatCompletionRequest::builder()
            .messages(messages)
            .tools_spec(
                self.tools
                    .iter()
                    .map(swiftide_core::Tool::tool_spec)
                    .collect::<HashSet<_>>(),
            )
            .build()
            .map_err(AgentError::FailedToBuildRequest)?;

        // Hooks may still modify the request before it is sent.
        invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);

        debug!(
            "Calling LLM with the following new messages:\n {}",
            self.context
                .current_new_messages()
                .await
                .map_err(AgentError::MessageHistoryError)?
                .iter()
                .map(ToString::to_string)
                .collect::<Vec<_>>()
                .join(",\n")
        );

        let mut response = if self.streaming {
            // Only the last streamed response is kept and processed below.
            // NOTE(review): presumably each streamed response carries the full accumulated
            // message; confirm against the `ChatCompletion` implementation.
            let mut last_response = None;
            let mut stream = self.llm.complete_stream(&chat_completion_request).await;

            while let Some(response) = stream.next().await {
                let response = response.map_err(AgentError::CompletionsFailed)?;
                invoke_hooks!(OnStream, self, &response);
                last_response = Some(response);
            }
            tracing::trace!(?last_response, "Streaming completed");
            last_response.ok_or(AgentError::EmptyStream)
        } else {
            self.llm
                .complete(&chat_completion_request)
                .await
                .map_err(AgentError::CompletionsFailed)
        }?;

        // Preprocess the tool call arguments in place before hooks and tools see them.
        response
            .tool_calls
            .as_deref_mut()
            .map(ArgPreprocessor::preprocess_tool_calls);

        invoke_hooks!(AfterCompletion, self, &mut response);

        // Store the assistant message before invoking its tools, so that tool outputs follow it
        // in the history.
        self.add_message(ChatMessage::Assistant(
            response.message,
            response.tool_calls.clone(),
        ))
        .await?;

        if let Some(tool_calls) = response.tool_calls {
            self.invoke_tools(&tool_calls).await?;
        }

        invoke_hooks!(AfterEach, self);

        Ok(())
    }
610
    /// Invokes the given tool calls and records their outputs as messages.
    ///
    /// Every tool call is spawned as its own tokio task, so the tools run concurrently. The
    /// results are then awaited and handled in the order of `tool_calls`. `BeforeTool` hooks
    /// run before each spawn and `AfterTool` hooks run on each awaited result.
    ///
    /// A failed tool call is recorded as a failed tool output so it can be retried; once the
    /// same call has failed more often than the retry limit allows, the agent is stopped and
    /// the error is returned.
    ///
    /// # Errors
    ///
    /// Returns an error when a task cannot be joined, a message cannot be added, or a tool
    /// keeps failing past the retry limit.
    async fn invoke_tools(&mut self, tool_calls: &[ToolCall]) -> Result<(), AgentError> {
        tracing::debug!("LLM returned tool calls: {:?}", tool_calls);

        let mut handles = vec![];
        for tool_call in tool_calls {
            // NOTE(review): calls to unknown tools are skipped without adding a tool output for
            // them to the context; confirm this is intended.
            let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
                tracing::warn!("Tool {} not found", tool_call.name());
                continue;
            };
            tracing::info!("Calling tool `{}`", tool_call.name());

            // The spawned task needs its own handle to the shared context.
            let context: Arc<dyn AgentContext> = Arc::clone(&self.context);

            invoke_hooks!(BeforeTool, self, &tool_call);

            let tool_span = tracing::info_span!(
                "tool",
                "otel.name" = format!("tool.{}", tool.name().as_ref()),
            );

            // The task must own its tool call, so it gets a clone.
            let handle_tool_call = tool_call.clone();
            let handle = tokio::spawn(async move {
                let handle_tool_call = handle_tool_call;
                let output = tool.invoke(&*context, &handle_tool_call)
                    .await
                    .map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;

                if cfg!(feature = "langfuse") {
                    tracing::debug!(
                        langfuse.output = %output,
                        langfuse.input = handle_tool_call.args(),
                        tool_name = tool.name().as_ref(),
                    );
                } else {
                    tracing::debug!(output = output.to_string(), args = ?handle_tool_call.args(), tool_name = tool.name().as_ref(), "Completed tool call");
                }

                Ok(output)
            }.instrument(tool_span.or_current()));

            handles.push((handle, tool_call));
        }

        for (handle, tool_call) in handles {
            let mut output = handle
                .await
                .map_err(|err| AgentError::ToolFailedToJoin(tool_call.name().to_string(), err))?;

            // Hooks receive the result mutably and may replace it.
            invoke_hooks!(AfterTool, self, &tool_call, &mut output);

            if let Err(error) = output {
                // Counts this failure and tells whether the retry limit has been reached.
                let stop = self.tool_calls_over_limit(tool_call);
                if stop {
                    tracing::error!(
                        ?error,
                        "Tool call failed, retry limit reached, stopping agent: {error}",
                    );
                } else {
                    tracing::warn!(
                        ?error,
                        tool_call = ?tool_call,
                        "Tool call failed, retrying",
                    );
                }
                // The failure is always recorded as a tool output, also when stopping.
                self.add_message(ChatMessage::ToolOutput(
                    tool_call.clone(),
                    ToolOutput::fail(error.to_string()),
                ))
                .await?;
                if stop {
                    self.stop(StopReason::ToolCallsOverLimit(tool_call.to_owned()))
                        .await;
                    return Err(error.into());
                }
                continue;
            }

            let output = output?;
            // Stop, feedback-required and agent-failed outputs stop the agent.
            self.handle_control_tools(tool_call, &output).await;

            // Outputs that require feedback are not added to the history here.
            if !output.is_feedback_required() {
                self.add_message(ChatMessage::ToolOutput(tool_call.to_owned(), output))
                    .await?;
            }
        }

        Ok(())
    }
703
704 fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
705 self.hooks
706 .iter()
707 .filter(|h| hook_type == (*h).into())
708 .collect()
709 }
710
711 fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
712 self.tools
713 .iter()
714 .find(|tool| tool.name() == tool_name)
715 .cloned()
716 }
717
718 async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
720 match output {
721 ToolOutput::Stop(maybe_message) => {
722 tracing::warn!("Stop tool called, stopping agent");
723 self.stop(StopReason::RequestedByTool(
724 tool_call.clone(),
725 maybe_message.clone(),
726 ))
727 .await;
728 }
729 ToolOutput::FeedbackRequired(maybe_payload) => {
730 tracing::warn!("Feedback required, stopping agent");
731 self.stop(StopReason::FeedbackRequired {
732 tool_call: tool_call.clone(),
733 payload: maybe_payload.clone(),
734 })
735 .await;
736 }
737 ToolOutput::AgentFailed(output) => {
738 tracing::warn!("Agent failed, stopping agent");
739 self.stop(StopReason::AgentFailed(output.clone())).await;
740 }
741 _ => (),
742 }
743 }
744
745 fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
746 let mut s = DefaultHasher::new();
747 tool_call.hash(&mut s);
748 let hash = s.finish();
749
750 if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
751 let val = *retries >= self.tool_retry_limit;
752 *retries += 1;
753 val
754 } else {
755 self.tool_retries_counter.insert(hash, 1);
756 false
757 }
758 }
759
760 #[tracing::instrument(skip_all, fields(message = message.to_string()))]
771 pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
772 invoke_hooks!(OnNewMessage, self, &mut message);
773
774 self.context
775 .add_message(message)
776 .await
777 .map_err(AgentError::MessageHistoryError)?;
778 Ok(())
779 }
780
781 pub async fn stop(&mut self, reason: impl Into<StopReason>) {
783 if self.state.is_stopped() {
784 return;
785 }
786
787 let reason = reason.into();
788 invoke_hooks!(OnStop, self, reason.clone(), None);
789
790 if cfg!(feature = "langfuse") {
791 debug!(langfuse.output = serde_json::to_string_pretty(&reason).ok());
792 }
793
794 self.state = state::State::Stopped(reason);
795 }
796
797 pub async fn stop_with_error(&mut self, error: &AgentError) {
798 if self.state.is_stopped() {
799 return;
800 }
801 invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
802
803 self.state = state::State::Stopped(StopReason::Error);
804 }
805
    /// Returns the agent's context as a trait object reference.
    pub fn context(&self) -> &dyn AgentContext {
        &self.context
    }
810
    /// Returns true when the agent's state is running.
    pub fn is_running(&self) -> bool {
        self.state.is_running()
    }
815
    /// Returns true when the agent's state is stopped.
    pub fn is_stopped(&self) -> bool {
        self.state.is_stopped()
    }
820
    /// Returns true when the agent's state is pending, i.e. it has not been run yet.
    pub fn is_pending(&self) -> bool {
        self.state.is_pending()
    }
825
    /// Returns the tools currently registered on the agent.
    pub fn tools(&self) -> &HashSet<Box<dyn Tool>> {
        &self.tools
    }
830
    /// Returns the current state of the agent.
    pub fn state(&self) -> &state::State {
        &self.state
    }
834
    /// Returns the stop reason reported by the agent's current state, if any.
    pub fn stop_reason(&self) -> Option<&StopReason> {
        self.state.stop_reason()
    }
838
839 async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
840 for toolbox in &self.toolboxes {
841 let tools = toolbox
842 .available_tools()
843 .await
844 .map_err(AgentError::ToolBoxFailedToLoad)?;
845 self.tools.extend(tools);
846 }
847
848 Ok(())
849 }
850}
851
852fn maybe_tool_call_without_output(messages: &[ChatMessage]) -> Option<&ChatMessage> {
855 for message in messages.iter().rev() {
856 if let ChatMessage::ToolOutput(..) = message {
857 return None;
858 }
859
860 if let ChatMessage::Assistant(.., Some(tool_calls)) = message
861 && !tool_calls.is_empty()
862 {
863 return Some(message);
864 }
865 }
866
867 None
868}
869
870#[cfg(test)]
871mod tests {
872
873 use serde::ser::Error;
874 use swiftide_core::ToolFeedback;
875 use swiftide_core::chat_completion::errors::ToolError;
876 use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
877 use swiftide_core::test_utils::MockChatCompletion;
878
879 use super::*;
880 use crate::{
881 State, assistant, chat_request, chat_response, summary, system, tool_failed, tool_output,
882 user,
883 };
884
885 use crate::test_utils::{MockHook, MockTool};
886
887 #[test_log::test(tokio::test)]
888 async fn test_agent_builder_defaults() {
889 let mock_llm = MockChatCompletion::new();
891
892 let agent = Agent::builder().llm(&mock_llm).build().unwrap();
894
895 assert!(agent.find_tool_by_name("stop").is_some());
899
900 let agent = Agent::builder()
902 .tools([Stop::default(), Stop::default()])
903 .llm(&mock_llm)
904 .build()
905 .unwrap();
906
907 assert_eq!(agent.tools.len(), 1);
908
909 let agent = Agent::builder()
911 .tools([MockTool::new("mock_tool")])
912 .llm(&mock_llm)
913 .build()
914 .unwrap();
915
916 assert_eq!(agent.tools.len(), 2);
917 assert!(agent.find_tool_by_name("mock_tool").is_some());
918 assert!(agent.find_tool_by_name("stop").is_some());
919
920 assert!(agent.context().history().await.unwrap().is_empty());
921 }
922
923 #[test_log::test(tokio::test)]
924 async fn test_agent_tool_calling_loop() {
925 let prompt = "Write a poem";
926 let mock_llm = MockChatCompletion::new();
927 let mock_tool = MockTool::new("mock_tool");
928
929 let chat_request = chat_request! {
930 user!("Write a poem");
931
932 tools = [mock_tool.clone()]
933 };
934
935 let mock_tool_response = chat_response! {
936 "Roses are red";
937 tool_calls = ["mock_tool"]
938
939 };
940
941 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
942
943 let chat_request = chat_request! {
944 user!("Write a poem"),
945 assistant!("Roses are red", ["mock_tool"]),
946 tool_output!("mock_tool", "Great!");
947
948 tools = [mock_tool.clone()]
949 };
950
951 let stop_response = chat_response! {
952 "Roses are red";
953 tool_calls = ["stop"]
954 };
955
956 mock_llm.expect_complete(chat_request, Ok(stop_response));
957 mock_tool.expect_invoke_ok("Great!".into(), None);
958
959 let mut agent = Agent::builder()
960 .tools([mock_tool])
961 .llm(&mock_llm)
962 .no_system_prompt()
963 .build()
964 .unwrap();
965
966 agent.query(prompt).await.unwrap();
967 }
968
969 #[test_log::test(tokio::test)]
970 async fn test_agent_tool_run_once() {
971 let prompt = "Write a poem";
972 let mock_llm = MockChatCompletion::new();
973 let mock_tool = MockTool::default();
974
975 let chat_request = chat_request! {
976 system!("My system prompt"),
977 user!("Write a poem");
978
979 tools = [mock_tool.clone()]
980 };
981
982 let mock_tool_response = chat_response! {
983 "Roses are red";
984 tool_calls = ["mock_tool"]
985
986 };
987
988 mock_tool.expect_invoke_ok("Great!".into(), None);
989 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
990
991 let mut agent = Agent::builder()
992 .tools([mock_tool])
993 .system_prompt("My system prompt")
994 .llm(&mock_llm)
995 .build()
996 .unwrap();
997
998 agent.query_once(prompt).await.unwrap();
999 }
1000
1001 #[test_log::test(tokio::test)]
1002 async fn test_agent_tool_via_toolbox_run_once() {
1003 let prompt = "Write a poem";
1004 let mock_llm = MockChatCompletion::new();
1005 let mock_tool = MockTool::default();
1006
1007 let chat_request = chat_request! {
1008 system!("My system prompt"),
1009 user!("Write a poem");
1010
1011 tools = [mock_tool.clone()]
1012 };
1013
1014 let mock_tool_response = chat_response! {
1015 "Roses are red";
1016 tool_calls = ["mock_tool"]
1017
1018 };
1019
1020 mock_tool.expect_invoke_ok("Great!".into(), None);
1021 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1022
1023 let mut agent = Agent::builder()
1024 .add_toolbox(vec![mock_tool.boxed()])
1025 .system_prompt("My system prompt")
1026 .llm(&mock_llm)
1027 .build()
1028 .unwrap();
1029
1030 agent.query_once(prompt).await.unwrap();
1031 }
1032
1033 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1034 async fn test_multiple_tool_calls() {
1035 let prompt = "Write a poem";
1036 let mock_llm = MockChatCompletion::new();
1037 let mock_tool = MockTool::new("mock_tool1");
1038 let mock_tool2 = MockTool::new("mock_tool2");
1039
1040 let chat_request = chat_request! {
1041 system!("My system prompt"),
1042 user!("Write a poem");
1043
1044
1045
1046 tools = [mock_tool.clone(), mock_tool2.clone()]
1047 };
1048
1049 let mock_tool_response = chat_response! {
1050 "Roses are red";
1051
1052 tool_calls = ["mock_tool1", "mock_tool2"]
1053
1054 };
1055
1056 dbg!(&chat_request);
1057 mock_tool.expect_invoke_ok("Great!".into(), None);
1058 mock_tool2.expect_invoke_ok("Great!".into(), None);
1059 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1060
1061 let chat_request = chat_request! {
1062 system!("My system prompt"),
1063 user!("Write a poem"),
1064 assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
1065 tool_output!("mock_tool1", "Great!"),
1066 tool_output!("mock_tool2", "Great!");
1067
1068 tools = [mock_tool.clone(), mock_tool2.clone()]
1069 };
1070
1071 let mock_tool_response = chat_response! {
1072 "Ok!";
1073
1074 tool_calls = ["stop"]
1075
1076 };
1077
1078 mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
1079
1080 let mut agent = Agent::builder()
1081 .tools([mock_tool, mock_tool2])
1082 .system_prompt("My system prompt")
1083 .llm(&mock_llm)
1084 .build()
1085 .unwrap();
1086
1087 agent.query(prompt).await.unwrap();
1088 }
1089
1090 #[test_log::test(tokio::test)]
1091 async fn test_agent_state_machine() {
1092 let prompt = "Write a poem";
1093 let mock_llm = MockChatCompletion::new();
1094
1095 let chat_request = chat_request! {
1096 user!("Write a poem");
1097 tools = []
1098 };
1099 let mock_tool_response = chat_response! {
1100 "Roses are red";
1101 tool_calls = []
1102 };
1103
1104 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1105 let mut agent = Agent::builder()
1106 .llm(&mock_llm)
1107 .no_system_prompt()
1108 .build()
1109 .unwrap();
1110
1111 assert!(agent.state.is_pending());
1113 agent.query_once(prompt).await.unwrap();
1114
1115 assert!(agent.state.is_stopped());
1117 }
1118
1119 #[test_log::test(tokio::test)]
1120 async fn test_summary() {
1121 let prompt = "Write a poem";
1122 let mock_llm = MockChatCompletion::new();
1123
1124 let mock_tool_response = chat_response! {
1125 "Roses are red";
1126 tool_calls = []
1127
1128 };
1129
1130 let expected_chat_request = chat_request! {
1131 system!("My system prompt"),
1132 user!("Write a poem");
1133
1134 tools = []
1135 };
1136
1137 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1138
1139 let mut agent = Agent::builder()
1140 .system_prompt("My system prompt")
1141 .llm(&mock_llm)
1142 .build()
1143 .unwrap();
1144
1145 agent.query_once(prompt).await.unwrap();
1146
1147 agent
1148 .context
1149 .add_message(ChatMessage::new_summary("Summary"))
1150 .await
1151 .unwrap();
1152
1153 let expected_chat_request = chat_request! {
1154 system!("My system prompt"),
1155 summary!("Summary"),
1156 user!("Write another poem");
1157 tools = []
1158 };
1159 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));
1160
1161 agent.query_once("Write another poem").await.unwrap();
1162
1163 agent
1164 .context
1165 .add_message(ChatMessage::new_summary("Summary 2"))
1166 .await
1167 .unwrap();
1168
1169 let expected_chat_request = chat_request! {
1170 system!("My system prompt"),
1171 summary!("Summary 2"),
1172 user!("Write a third poem");
1173 tools = []
1174 };
1175 mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));
1176
1177 agent.query_once("Write a third poem").await.unwrap();
1178 }
1179
1180 #[test_log::test(tokio::test)]
1181 async fn test_agent_hooks() {
1182 let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
1183 let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
1184 let mock_before_completion = MockHook::new("before_completion")
1185 .expect_calls(2)
1186 .to_owned();
1187 let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
1188 let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
1189 let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
1190 let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();
1191
1192 let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
1194 let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();
1195
1196 let prompt = "Write a poem";
1197 let mock_llm = MockChatCompletion::new();
1198 let mock_tool = MockTool::default();
1199
1200 let chat_request = chat_request! {
1201 user!("Write a poem");
1202
1203 tools = [mock_tool.clone()]
1204 };
1205
1206 let mock_tool_response = chat_response! {
1207 "Roses are red";
1208 tool_calls = ["mock_tool"]
1209
1210 };
1211
1212 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1213
1214 let chat_request = chat_request! {
1215 user!("Write a poem"),
1216 assistant!("Roses are red", ["mock_tool"]),
1217 tool_output!("mock_tool", "Great!");
1218
1219 tools = [mock_tool.clone()]
1220 };
1221
1222 let stop_response = chat_response! {
1223 "Roses are red";
1224 tool_calls = ["stop"]
1225 };
1226
1227 mock_llm.expect_complete(chat_request, Ok(stop_response));
1228 mock_tool.expect_invoke_ok("Great!".into(), None);
1229
1230 let mut agent = Agent::builder()
1231 .tools([mock_tool])
1232 .llm(&mock_llm)
1233 .no_system_prompt()
1234 .before_all(mock_before_all.hook_fn())
1235 .on_start(mock_on_start_fn.on_start_fn())
1236 .before_completion(mock_before_completion.before_completion_fn())
1237 .before_tool(mock_before_tool.before_tool_fn())
1238 .after_completion(mock_after_completion.after_completion_fn())
1239 .after_tool(mock_after_tool.after_tool_fn())
1240 .after_each(mock_after_each.hook_fn())
1241 .on_new_message(mock_on_message.message_hook_fn())
1242 .on_stop(mock_on_stop.stop_hook_fn())
1243 .build()
1244 .unwrap();
1245
1246 agent.query(prompt).await.unwrap();
1247 }
1248
1249 #[test_log::test(tokio::test)]
1250 async fn test_agent_loop_limit() {
1251 let prompt = "Generate content"; let mock_llm = MockChatCompletion::new();
1253 let mock_tool = MockTool::new("mock_tool");
1254
1255 let chat_request = chat_request! {
1256 user!(prompt);
1257 tools = [mock_tool.clone()]
1258 };
1259 mock_tool.expect_invoke_ok("Great!".into(), None);
1260
1261 let mock_tool_response = chat_response! {
1262 "Some response";
1263 tool_calls = ["mock_tool"]
1264 };
1265
1266 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));
1268
1269 let stop_response = chat_response! {
1271 "Final response";
1272 tool_calls = ["stop"]
1273 };
1274
1275 mock_llm.expect_complete(chat_request, Ok(stop_response));
1276
1277 let mut agent = Agent::builder()
1278 .tools([mock_tool])
1279 .llm(&mock_llm)
1280 .no_system_prompt()
1281 .limit(1) .build()
1283 .unwrap();
1284
1285 agent.query(prompt).await.unwrap();
1287
1288 let remaining = mock_llm.expectations.lock().unwrap().pop();
1290 assert!(remaining.is_some());
1291
1292 assert!(agent.is_stopped());
1294 }
1295
1296 #[test_log::test(tokio::test)]
1297 async fn test_tool_retry_mechanism() {
1298 let prompt = "Execute my tool";
1299 let mock_llm = MockChatCompletion::new();
1300 let mock_tool = MockTool::new("retry_tool");
1301
1302 mock_tool.expect_invoke(
1305 Err(ToolError::WrongArguments(serde_json::Error::custom(
1306 "missing `query`",
1307 ))),
1308 None,
1309 );
1310 mock_tool.expect_invoke(
1311 Err(ToolError::WrongArguments(serde_json::Error::custom(
1312 "missing `query`",
1313 ))),
1314 None,
1315 );
1316
1317 let chat_request = chat_request! {
1318 user!(prompt);
1319 tools = [mock_tool.clone()]
1320 };
1321 let retry_response = chat_response! {
1322 "First failing attempt";
1323 tool_calls = ["retry_tool"]
1324 };
1325 mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
1326
1327 let chat_request = chat_request! {
1328 user!(prompt),
1329 assistant!("First failing attempt", ["retry_tool"]),
1330 tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
1331
1332 tools = [mock_tool.clone()]
1333 };
1334 let will_fail_response = chat_response! {
1335 "Finished execution";
1336 tool_calls = ["retry_tool"]
1337 };
1338 mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
1339
1340 let mut agent = Agent::builder()
1341 .tools([mock_tool])
1342 .llm(&mock_llm)
1343 .no_system_prompt()
1344 .tool_retry_limit(1) .build()
1346 .unwrap();
1347
1348 let result = agent.query(prompt).await;
1350
1351 assert!(result.is_err());
1352 assert!(result.unwrap_err().to_string().contains("missing `query`"));
1353 assert!(agent.is_stopped());
1354 }
1355
1356 #[test_log::test(tokio::test(flavor = "multi_thread"))]
1357 async fn test_streaming() {
1358 let prompt = "Generate content"; let mock_llm = MockChatCompletion::new();
1360 let on_stream_fn = MockHook::new("on_stream").expect_calls(3).to_owned();
1361
1362 let chat_request = chat_request! {
1363 user!(prompt);
1364
1365 tools = []
1366 };
1367
1368 let response = chat_response! {
1369 "one two three";
1370 tool_calls = ["stop"]
1371 };
1372
1373 mock_llm.expect_complete(chat_request, Ok(response));
1375
1376 let mut agent = Agent::builder()
1377 .llm(&mock_llm)
1378 .on_stream(on_stream_fn.on_stream_fn())
1379 .no_system_prompt()
1380 .build()
1381 .unwrap();
1382
1383 agent.query(prompt).await.unwrap();
1385
1386 tracing::debug!("Agent finished running");
1387
1388 assert!(agent.is_stopped());
1390 }
1391
1392 #[test_log::test(tokio::test)]
1393 async fn test_recovering_agent_existing_history() {
1394 let prompt = "Write a poem";
1396 let mock_llm = MockChatCompletion::new();
1397 let mock_tool = MockTool::new("mock_tool");
1398
1399 let chat_request = chat_request! {
1400 user!("Write a poem");
1401
1402 tools = [mock_tool.clone()]
1403 };
1404
1405 let mock_tool_response = chat_response! {
1406 "Roses are red";
1407 tool_calls = ["mock_tool"]
1408
1409 };
1410
1411 mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1412
1413 let chat_request = chat_request! {
1414 user!("Write a poem"),
1415 assistant!("Roses are red", ["mock_tool"]),
1416 tool_output!("mock_tool", "Great!");
1417
1418 tools = [mock_tool.clone()]
1419 };
1420
1421 let stop_response = chat_response! {
1422 "Roses are red";
1423 tool_calls = ["stop"]
1424 };
1425
1426 mock_llm.expect_complete(chat_request, Ok(stop_response));
1427 mock_tool.expect_invoke_ok("Great!".into(), None);
1428
1429 let mut agent = Agent::builder()
1430 .tools([mock_tool.clone()])
1431 .llm(&mock_llm)
1432 .no_system_prompt()
1433 .build()
1434 .unwrap();
1435
1436 agent.query(prompt).await.unwrap();
1437
1438 let history = agent.history().await.unwrap();
1440
1441 let serialized = serde_json::to_string(&history).unwrap();
1443
1444 let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();
1446
1447 let context = DefaultContext::default()
1449 .with_existing_messages(history)
1450 .await
1451 .unwrap()
1452 .to_owned();
1453
1454 let stop_output = ToolOutput::stop();
1455 let expected_chat_request = chat_request! {
1456 user!("Write a poem"),
1457 assistant!("Roses are red", ["mock_tool"]),
1458 tool_output!("mock_tool", "Great!"),
1459 assistant!("Roses are red", ["stop"]),
1460 tool_output!("stop", stop_output),
1461 user!("Try again!");
1462
1463 tools = [mock_tool.clone()]
1464 };
1465
1466 let stop_response = chat_response! {
1467 "Really stopping now";
1468 tool_calls = ["stop"]
1469 };
1470
1471 mock_llm.expect_complete(expected_chat_request, Ok(stop_response));
1472
1473 let mut agent = Agent::builder()
1474 .context(context)
1475 .tools([mock_tool])
1476 .llm(&mock_llm)
1477 .no_system_prompt()
1478 .build()
1479 .unwrap();
1480
1481 agent.query_once("Try again!").await.unwrap();
1482 }
1483
1484 #[test_log::test(tokio::test)]
1485 async fn test_agent_with_approval_required_tool() {
1486 use super::*;
1487 use crate::tools::control::ApprovalRequired;
1488 use crate::{assistant, chat_request, chat_response, user};
1489 use swiftide_core::chat_completion::ToolCall;
1490
1491 let mock_tool = MockTool::default();
1493 mock_tool.expect_invoke_ok("Great!".into(), None);
1494
1495 let approval_tool = ApprovalRequired(mock_tool.boxed());
1496
1497 let mock_llm = MockChatCompletion::new();
1499
1500 let chat_req1 = chat_request! {
1501 user!("Request with approval");
1502 tools = [approval_tool.clone()]
1503 };
1504 let chat_resp1 = chat_response! {
1505 "Completion message";
1506 tool_calls = ["mock_tool"]
1507 };
1508 mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1509
1510 let chat_req2 = chat_request! {
1513 user!("Request with approval"),
1514 assistant!("Completion message", ["mock_tool"]),
1515 tool_output!("mock_tool", "Great!");
1516 tools = [approval_tool.clone()]
1518 };
1519 let chat_resp2 = chat_response! {
1520 "Post-feedback message";
1521 tool_calls = ["stop"]
1522 };
1523 mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1524
1525 let mut agent = Agent::builder()
1527 .tools([approval_tool])
1528 .llm(&mock_llm)
1529 .no_system_prompt()
1530 .build()
1531 .unwrap();
1532
1533 agent.query_once("Request with approval").await.unwrap();
1535
1536 assert!(matches!(
1537 agent.state,
1538 crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1539 ));
1540
1541 let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1542 else {
1543 panic!("Expected feedback required");
1544 };
1545
1546 agent
1548 .context
1549 .feedback_received(&tool_call, &ToolFeedback::approved())
1550 .await
1551 .unwrap();
1552
1553 tracing::debug!("running after approval");
1554 agent.run_once().await.unwrap();
1555 assert!(agent.is_stopped());
1556 }
1557
1558 #[test_log::test(tokio::test)]
1559 async fn test_agent_with_approval_required_tool_denied() {
1560 use super::*;
1561 use crate::tools::control::ApprovalRequired;
1562 use crate::{assistant, chat_request, chat_response, user};
1563 use swiftide_core::chat_completion::ToolCall;
1564
1565 let mock_tool = MockTool::default();
1567
1568 let approval_tool = ApprovalRequired(mock_tool.boxed());
1569
1570 let mock_llm = MockChatCompletion::new();
1572
1573 let chat_req1 = chat_request! {
1574 user!("Request with approval");
1575 tools = [approval_tool.clone()]
1576 };
1577 let chat_resp1 = chat_response! {
1578 "Completion message";
1579 tool_calls = ["mock_tool"]
1580 };
1581 mock_llm.expect_complete(chat_req1.clone(), Ok(chat_resp1));
1582
1583 let chat_req2 = chat_request! {
1586 user!("Request with approval"),
1587 assistant!("Completion message", ["mock_tool"]),
1588 tool_output!("mock_tool", "This tool call was refused");
1589 tools = [approval_tool.clone()]
1591 };
1592 let chat_resp2 = chat_response! {
1593 "Post-feedback message";
1594 tool_calls = ["stop"]
1595 };
1596 mock_llm.expect_complete(chat_req2.clone(), Ok(chat_resp2));
1597
1598 let mut agent = Agent::builder()
1600 .tools([approval_tool])
1601 .llm(&mock_llm)
1602 .no_system_prompt()
1603 .build()
1604 .unwrap();
1605
1606 agent.query_once("Request with approval").await.unwrap();
1608
1609 assert!(matches!(
1610 agent.state,
1611 crate::state::State::Stopped(crate::state::StopReason::FeedbackRequired { .. })
1612 ));
1613
1614 let State::Stopped(StopReason::FeedbackRequired { tool_call, .. }) = agent.state.clone()
1615 else {
1616 panic!("Expected feedback required");
1617 };
1618
1619 agent
1621 .context
1622 .feedback_received(&tool_call, &ToolFeedback::refused())
1623 .await
1624 .unwrap();
1625
1626 tracing::debug!("running after approval");
1627 agent.run_once().await.unwrap();
1628
1629 let history = agent.context().history().await.unwrap();
1630 history
1631 .iter()
1632 .rfind(|m| {
1633 let ChatMessage::ToolOutput(.., ToolOutput::Text(msg)) = m else {
1634 return false;
1635 };
1636 msg.contains("refused")
1637 })
1638 .expect("Could not find refusal message");
1639
1640 assert!(agent.is_stopped());
1641 }
1642
1643 #[test_log::test(tokio::test)]
1644 async fn test_removing_default_stop_tool() {
1645 let mock_llm = MockChatCompletion::new();
1646 let mock_tool = MockTool::new("mock_tool");
1647
1648 let agent = Agent::builder()
1650 .without_default_stop_tool()
1651 .tools([mock_tool.clone()])
1652 .llm(&mock_llm)
1653 .no_system_prompt()
1654 .build()
1655 .unwrap();
1656
1657 assert!(agent.find_tool_by_name("stop").is_none());
1659 assert!(agent.find_tool_by_name("mock_tool").is_some());
1661 }
1662}