swiftide_agents/
agent.rs

1#![allow(dead_code)]
2use crate::{
3    default_context::DefaultContext,
4    errors::AgentError,
5    hooks::{
6        AfterCompletionFn, AfterEachFn, AfterToolFn, BeforeAllFn, BeforeCompletionFn, BeforeToolFn,
7        Hook, HookTypes, MessageHookFn, OnStartFn, OnStopFn,
8    },
9    invoke_hooks,
10    state::{self, StopReason},
11    system_prompt::SystemPrompt,
12    tools::{arg_preprocessor::ArgPreprocessor, control::Stop},
13};
14use std::{
15    collections::{HashMap, HashSet},
16    hash::{DefaultHasher, Hash as _, Hasher as _},
17    sync::Arc,
18};
19
20use derive_builder::Builder;
21use swiftide_core::{
22    chat_completion::{
23        ChatCompletion, ChatCompletionRequest, ChatMessage, Tool, ToolCall, ToolOutput,
24    },
25    prompt::Prompt,
26    AgentContext, ToolBox,
27};
28use tracing::{debug, Instrument};
29
30/// Agents are the main interface for building agentic systems.
31///
32/// Construct agents by calling the builder, setting an llm, configure hooks, tools and other
33/// customizations.
34///
35/// # Important defaults
36///
37/// - The default context is the `DefaultContext`, executing tools locally with the `LocalExecutor`.
38/// - A default `stop` tool is provided for agents to explicitly stop if needed
39/// - The default `SystemPrompt` instructs the agent with chain of thought and some common
40///   safeguards, but is otherwise quite bare. In a lot of cases this can be sufficient.
41#[derive(Clone, Builder)]
42pub struct Agent {
43    /// Hooks are functions that are called at specific points in the agent's lifecycle.
44    #[builder(default, setter(into))]
45    pub(crate) hooks: Vec<Hook>,
46    /// The context in which the agent operates, by default this is the `DefaultContext`.
47    #[builder(
48        setter(custom),
49        default = Arc::new(DefaultContext::default()) as Arc<dyn AgentContext>
50    )]
51    pub(crate) context: Arc<dyn AgentContext>,
52    /// Tools the agent can use
53    #[builder(default = Agent::default_tools(), setter(custom))]
54    pub(crate) tools: HashSet<Box<dyn Tool>>,
55
56    /// Toolboxes are collections of tools that can be added to the agent.
57    ///
58    /// Toolboxes make their tools available to the agent at runtime.
59    #[builder(default)]
60    pub(crate) toolboxes: Vec<Box<dyn ToolBox>>,
61
62    /// The language model that the agent uses for completion.
63    #[builder(setter(custom))]
64    pub(crate) llm: Box<dyn ChatCompletion>,
65
66    /// System prompt for the agent when it starts
67    ///
68    /// Some agents profit significantly from a tailored prompt. But it is not always needed.
69    ///
70    /// See [`SystemPrompt`] for an opiniated, customizable system prompt.
71    ///
72    /// Swiftide provides a default system prompt for all agents.
73    ///
74    /// # Example
75    ///
76    /// ```no_run
77    /// # use swiftide_agents::system_prompt::SystemPrompt;
78    /// # use swiftide_agents::Agent;
79    /// Agent::builder()
80    ///     .system_prompt(
81    ///         SystemPrompt::builder().role("You are an expert engineer")
82    ///         .build().unwrap())
83    ///     .build().unwrap();
84    /// ```
85    #[builder(setter(into, strip_option), default = Some(SystemPrompt::default().into()))]
86    pub(crate) system_prompt: Option<Prompt>,
87
88    /// Initial state of the agent
89    #[builder(private, default = state::State::default())]
90    pub(crate) state: state::State,
91
92    /// Optional limit on the amount of loops the agent can run.
93    /// The counter is reset when the agent is stopped.
94    #[builder(default, setter(strip_option))]
95    pub(crate) limit: Option<usize>,
96
97    /// The maximum amount of times the failed output of a tool will be send
98    /// to an LLM before the agent stops. Defaults to 3.
99    ///
100    /// LLMs sometimes send missing arguments, or a tool might actually fail, but retrying could be
101    /// worth while. If the limit is not reached, the agent will send the formatted error back to
102    /// the LLM.
103    ///
104    /// The limit is hashed based on the tool call name and arguments, so the limit is per tool
105    /// call.
106    ///
107    /// This limit is _not_ reset when the agent is stopped.
108    #[builder(default = 3)]
109    pub(crate) tool_retry_limit: usize,
110
111    /// Internally tracks the amount of times a tool has been retried. The key is a hash based on
112    /// the name and args of the tool.
113    #[builder(private, default)]
114    pub(crate) tool_retries_counter: HashMap<u64, usize>,
115
116    /// Tools loaded from toolboxes
117    #[builder(private, default)]
118    pub(crate) toolbox_tools: HashSet<Box<dyn Tool>>,
119}
120
121impl std::fmt::Debug for Agent {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        f.debug_struct("Agent")
124            // display hooks as a list of type: number of hooks
125            .field(
126                "hooks",
127                &self
128                    .hooks
129                    .iter()
130                    .map(std::string::ToString::to_string)
131                    .collect::<Vec<_>>(),
132            )
133            .field(
134                "tools",
135                &self
136                    .tools
137                    .iter()
138                    .map(swiftide_core::Tool::name)
139                    .collect::<Vec<_>>(),
140            )
141            .field("llm", &"Box<dyn ChatCompletion>")
142            .field("state", &self.state)
143            .finish()
144    }
145}
146
147impl AgentBuilder {
148    /// The context in which the agent operates, by default this is the `DefaultContext`.
149    pub fn context(&mut self, context: impl AgentContext + 'static) -> &mut AgentBuilder
150    where
151        Self: Clone,
152    {
153        self.context = Some(Arc::new(context) as Arc<dyn AgentContext>);
154        self
155    }
156
157    /// Disable the system prompt.
158    pub fn no_system_prompt(&mut self) -> &mut Self {
159        self.system_prompt = Some(None);
160
161        self
162    }
163
164    /// Add a hook to the agent.
165    pub fn add_hook(&mut self, hook: Hook) -> &mut Self {
166        let hooks = self.hooks.get_or_insert_with(Vec::new);
167        hooks.push(hook);
168
169        self
170    }
171
172    /// Add a hook that runs once, before all completions. Even if the agent is paused and resumed,
173    /// before all will not trigger again.
174    pub fn before_all(&mut self, hook: impl BeforeAllFn + 'static) -> &mut Self {
175        self.add_hook(Hook::BeforeAll(Box::new(hook)))
176    }
177
178    /// Add a hook that runs once, when the agent starts. This hook also runs if the agent stopped
179    /// and then starts again. The hook runs after any `before_all` hooks and before the
180    /// `before_completion` hooks.
181    pub fn on_start(&mut self, hook: impl OnStartFn + 'static) -> &mut Self {
182        self.add_hook(Hook::OnStart(Box::new(hook)))
183    }
184
185    /// Add a hook that runs before each completion.
186    pub fn before_completion(&mut self, hook: impl BeforeCompletionFn + 'static) -> &mut Self {
187        self.add_hook(Hook::BeforeCompletion(Box::new(hook)))
188    }
189
190    /// Add a hook that runs after each tool. The `Result<ToolOutput, ToolError>` is provided
191    /// as mut, so the tool output can be fully modified.
192    ///
193    /// The `ToolOutput` also references the original `ToolCall`, allowing you to match at runtime
194    /// what tool to interact with.
195    pub fn after_tool(&mut self, hook: impl AfterToolFn + 'static) -> &mut Self {
196        self.add_hook(Hook::AfterTool(Box::new(hook)))
197    }
198
199    /// Add a hook that runs before each tool. Yields an immutable reference to the `ToolCall`.
200    pub fn before_tool(&mut self, hook: impl BeforeToolFn + 'static) -> &mut Self {
201        self.add_hook(Hook::BeforeTool(Box::new(hook)))
202    }
203
204    /// Add a hook that runs after each completion, before tool invocation and/or new messages.
205    pub fn after_completion(&mut self, hook: impl AfterCompletionFn + 'static) -> &mut Self {
206        self.add_hook(Hook::AfterCompletion(Box::new(hook)))
207    }
208
209    /// Add a hook that runs after each completion, after tool invocations, right before a new loop
210    /// might start
211    pub fn after_each(&mut self, hook: impl AfterEachFn + 'static) -> &mut Self {
212        self.add_hook(Hook::AfterEach(Box::new(hook)))
213    }
214
215    /// Add a hook that runs when a new message is added to the context. Note that each tool adds a
216    /// separate message.
217    pub fn on_new_message(&mut self, hook: impl MessageHookFn + 'static) -> &mut Self {
218        self.add_hook(Hook::OnNewMessage(Box::new(hook)))
219    }
220
221    pub fn on_stop(&mut self, hook: impl OnStopFn + 'static) -> &mut Self {
222        self.add_hook(Hook::OnStop(Box::new(hook)))
223    }
224
225    /// Set the LLM for the agent. An LLM must implement the `ChatCompletion` trait.
226    pub fn llm<LLM: ChatCompletion + Clone + 'static>(&mut self, llm: &LLM) -> &mut Self {
227        let boxed: Box<dyn ChatCompletion> = Box::new(llm.clone()) as Box<dyn ChatCompletion>;
228
229        self.llm = Some(boxed);
230        self
231    }
232
233    /// Define the available tools for the agent. Tools must implement the `Tool` trait.
234    ///
235    /// See the [tool attribute macro](`swiftide_macros::tool`) and the [tool derive
236    /// macro](`swiftide_macros::Tool`) for easy ways to create (many) tools.
237    pub fn tools<TOOL, I: IntoIterator<Item = TOOL>>(&mut self, tools: I) -> &mut Self
238    where
239        TOOL: Into<Box<dyn Tool>>,
240    {
241        self.tools = Some(
242            tools
243                .into_iter()
244                .map(Into::into)
245                .chain(Agent::default_tools())
246                .collect(),
247        );
248        self
249    }
250
251    /// Add a toolbox to the agent. Toolboxes are collections of tools that can be added to the
252    /// to the agent. Available tools are evaluated at runtime, when the agent starts for the first
253    /// time.
254    ///
255    /// Agents can have many toolboxes.
256    pub fn add_toolbox(&mut self, toolbox: impl ToolBox + 'static) -> &mut Self {
257        let toolboxes = self.toolboxes.get_or_insert_with(Vec::new);
258        toolboxes.push(Box::new(toolbox));
259
260        self
261    }
262}
263
264impl Agent {
265    /// Build a new agent
266    pub fn builder() -> AgentBuilder {
267        AgentBuilder::default()
268    }
269}
270
271impl Agent {
272    /// Default tools for the agent that it always includes
273    fn default_tools() -> HashSet<Box<dyn Tool>> {
274        HashSet::from([Box::new(Stop::default()) as Box<dyn Tool>])
275    }
276
277    /// Run the agent with a user message. The agent will loop completions, make tool calls, until
278    /// no new messages are available.
279    #[tracing::instrument(skip_all, name = "agent.query")]
280    pub async fn query(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
281        let query = query
282            .into()
283            .render()
284            .map_err(AgentError::FailedToRenderPrompt)?;
285        self.run_agent(Some(query), false).await
286    }
287
288    /// Run the agent with a user message once.
289    #[tracing::instrument(skip_all, name = "agent.query_once")]
290    pub async fn query_once(&mut self, query: impl Into<Prompt>) -> Result<(), AgentError> {
291        let query = query
292            .into()
293            .render()
294            .map_err(AgentError::FailedToRenderPrompt)?;
295        self.run_agent(Some(query), true).await
296    }
297
298    /// Run the agent with without user message. The agent will loop completions, make tool calls,
299    /// until no new messages are available.
300    #[tracing::instrument(skip_all, name = "agent.run")]
301    pub async fn run(&mut self) -> Result<(), AgentError> {
302        self.run_agent(None, false).await
303    }
304
305    /// Run the agent with without user message. The agent will loop completions, make tool calls,
306    /// until
307    #[tracing::instrument(skip_all, name = "agent.run_once")]
308    pub async fn run_once(&mut self) -> Result<(), AgentError> {
309        self.run_agent(None, true).await
310    }
311
312    /// Retrieve the message history of the agent
313    pub async fn history(&self) -> Vec<ChatMessage> {
314        self.context.history().await
315    }
316
317    async fn run_agent(
318        &mut self,
319        maybe_query: Option<String>,
320        just_once: bool,
321    ) -> Result<(), AgentError> {
322        if self.state.is_running() {
323            return Err(AgentError::AlreadyRunning);
324        }
325
326        if self.state.is_pending() {
327            if let Some(system_prompt) = &self.system_prompt {
328                self.context
329                    .add_messages(vec![ChatMessage::System(
330                        system_prompt
331                            .render()
332                            .map_err(AgentError::FailedToRenderSystemPrompt)?,
333                    )])
334                    .await;
335            }
336
337            invoke_hooks!(BeforeAll, self);
338
339            self.load_toolboxes().await?;
340        }
341
342        invoke_hooks!(OnStart, self);
343
344        self.state = state::State::Running;
345
346        if let Some(query) = maybe_query {
347            self.context.add_message(ChatMessage::User(query)).await;
348        }
349
350        let mut loop_counter = 0;
351
352        while let Some(messages) = self.context.next_completion().await {
353            if let Some(limit) = self.limit {
354                if loop_counter >= limit {
355                    tracing::warn!("Agent loop limit reached");
356                    break;
357                }
358            }
359            let result = self.run_completions(&messages).await;
360
361            if let Err(err) = result {
362                self.stop_with_error(&err).await;
363                tracing::error!(error = ?err, "Agent stopped with error {err}");
364                return Err(err);
365            }
366
367            if just_once || self.state.is_stopped() {
368                break;
369            }
370            loop_counter += 1;
371        }
372
373        // If there are no new messages, ensure we update our state
374        self.stop(StopReason::NoNewMessages).await;
375
376        Ok(())
377    }
378
379    #[tracing::instrument(skip_all, err)]
380    async fn run_completions(&mut self, messages: &[ChatMessage]) -> Result<(), AgentError> {
381        debug!(
382            "Running completion for agent with {} messages",
383            messages.len()
384        );
385
386        let mut chat_completion_request = ChatCompletionRequest::builder()
387            .messages(messages)
388            .tools_spec(
389                self.tools
390                    .iter()
391                    .map(swiftide_core::Tool::tool_spec)
392                    .collect::<HashSet<_>>(),
393            )
394            .build()
395            .map_err(AgentError::FailedToBuildRequest)?;
396
397        invoke_hooks!(BeforeCompletion, self, &mut chat_completion_request);
398
399        debug!(
400            "Calling LLM with the following new messages:\n {}",
401            self.context
402                .current_new_messages()
403                .await
404                .iter()
405                .map(ToString::to_string)
406                .collect::<Vec<_>>()
407                .join(",\n")
408        );
409
410        let mut response = self
411            .llm
412            .complete(&chat_completion_request)
413            .await
414            .map_err(AgentError::CompletionsFailed)?;
415
416        invoke_hooks!(AfterCompletion, self, &mut response);
417
418        self.add_message(ChatMessage::Assistant(
419            response.message,
420            response.tool_calls.clone(),
421        ))
422        .await?;
423
424        if let Some(tool_calls) = response.tool_calls {
425            self.invoke_tools(tool_calls).await?;
426        }
427
428        invoke_hooks!(AfterEach, self);
429
430        Ok(())
431    }
432
433    async fn invoke_tools(&mut self, tool_calls: Vec<ToolCall>) -> Result<(), AgentError> {
434        debug!("LLM returned tool calls: {:?}", tool_calls);
435
436        let mut handles = vec![];
437        for tool_call in tool_calls {
438            let Some(tool) = self.find_tool_by_name(tool_call.name()) else {
439                tracing::warn!("Tool {} not found", tool_call.name());
440                continue;
441            };
442            tracing::info!("Calling tool `{}`", tool_call.name());
443
444            let tool_args = tool_call.args().map(String::from);
445            let context: Arc<dyn AgentContext> = Arc::clone(&self.context);
446
447            invoke_hooks!(BeforeTool, self, &tool_call);
448
449            let tool_span = tracing::info_span!(
450                "tool",
451                "otel.name" = format!("tool.{}", tool.name().as_ref())
452            );
453
454            let handle = tokio::spawn(async move {
455                    let tool_args = ArgPreprocessor::preprocess(tool_args.as_deref());
456                    let output = tool.invoke(&*context, tool_args.as_deref()).await.map_err(|e| { tracing::error!(error = %e, "Failed tool call"); e })?;
457
458                    tracing::debug!(output = output.to_string(), args = ?tool_args, tool_name = tool.name().as_ref(), "Completed tool call");
459
460                    Ok(output)
461                }.instrument(tool_span.or_current()));
462
463            handles.push((handle, tool_call));
464        }
465
466        for (handle, tool_call) in handles {
467            let mut output = handle.await.map_err(AgentError::ToolFailedToJoin)?;
468
469            invoke_hooks!(AfterTool, self, &tool_call, &mut output);
470
471            if let Err(error) = output {
472                let stop = self.tool_calls_over_limit(&tool_call);
473                if stop {
474                    tracing::error!(
475                        ?error,
476                        "Tool call failed, retry limit reached, stopping agent: {error}",
477                    );
478                } else {
479                    tracing::warn!(
480                        ?error,
481                        tool_call = ?tool_call,
482                        "Tool call failed, retrying",
483                    );
484                }
485                self.add_message(ChatMessage::ToolOutput(
486                    tool_call.clone(),
487                    ToolOutput::Fail(error.to_string()),
488                ))
489                .await?;
490                if stop {
491                    self.stop(StopReason::ToolCallsOverLimit(tool_call)).await;
492                    return Err(error.into());
493                }
494                continue;
495            }
496
497            let output = output?;
498            self.handle_control_tools(&tool_call, &output).await;
499            self.add_message(ChatMessage::ToolOutput(tool_call, output))
500                .await?;
501        }
502
503        Ok(())
504    }
505
506    fn hooks_by_type(&self, hook_type: HookTypes) -> Vec<&Hook> {
507        self.hooks
508            .iter()
509            .filter(|h| hook_type == (*h).into())
510            .collect()
511    }
512
513    fn find_tool_by_name(&self, tool_name: &str) -> Option<Box<dyn Tool>> {
514        self.tools
515            .iter()
516            .find(|tool| tool.name() == tool_name)
517            .cloned()
518    }
519
520    // Handle any tool specific output (e.g. stop)
521    async fn handle_control_tools(&mut self, tool_call: &ToolCall, output: &ToolOutput) {
522        if let ToolOutput::Stop = output {
523            tracing::warn!("Stop tool called, stopping agent");
524            self.stop(StopReason::RequestedByTool(tool_call.clone()))
525                .await;
526        }
527    }
528
529    fn tool_calls_over_limit(&mut self, tool_call: &ToolCall) -> bool {
530        let mut s = DefaultHasher::new();
531        tool_call.hash(&mut s);
532        let hash = s.finish();
533
534        if let Some(retries) = self.tool_retries_counter.get_mut(&hash) {
535            let val = *retries >= self.tool_retry_limit;
536            *retries += 1;
537            val
538        } else {
539            self.tool_retries_counter.insert(hash, 1);
540            false
541        }
542    }
543
544    /// Add a message to the agent's context
545    ///
546    /// This will trigger a `OnNewMessage` hook if its present.
547    ///
548    /// If you want to add a message without triggering the hook, use the context directly.
549    #[tracing::instrument(skip_all, fields(message = message.to_string()))]
550    pub async fn add_message(&self, mut message: ChatMessage) -> Result<(), AgentError> {
551        invoke_hooks!(OnNewMessage, self, &mut message);
552
553        self.context.add_message(message).await;
554        Ok(())
555    }
556
557    /// Tell the agent to stop. It will finish it's current loop and then stop.
558    pub async fn stop(&mut self, reason: impl Into<StopReason>) {
559        if self.state.is_stopped() {
560            return;
561        }
562        let reason = reason.into();
563        invoke_hooks!(OnStop, self, reason.clone(), None);
564
565        self.state = state::State::Stopped(reason);
566    }
567
568    pub async fn stop_with_error(&mut self, error: &AgentError) {
569        if self.state.is_stopped() {
570            return;
571        }
572        invoke_hooks!(OnStop, self, StopReason::Error, Some(error));
573
574        self.state = state::State::Stopped(StopReason::Error);
575    }
576
577    /// Access the agent's context
578    pub fn context(&self) -> &dyn AgentContext {
579        &self.context
580    }
581
582    /// The agent is still running
583    pub fn is_running(&self) -> bool {
584        self.state.is_running()
585    }
586
587    /// The agent stopped
588    pub fn is_stopped(&self) -> bool {
589        self.state.is_stopped()
590    }
591
592    /// The agent has not (ever) started
593    pub fn is_pending(&self) -> bool {
594        self.state.is_pending()
595    }
596
597    /// Get a list of tools available to the agent
598    fn tools(&self) -> &HashSet<Box<dyn Tool>> {
599        &self.tools
600    }
601
602    async fn load_toolboxes(&mut self) -> Result<(), AgentError> {
603        for toolbox in &self.toolboxes {
604            let tools = toolbox
605                .available_tools()
606                .await
607                .map_err(AgentError::ToolBoxFailedToLoad)?;
608            self.toolbox_tools.extend(tools);
609        }
610
611        self.tools.extend(self.toolbox_tools.clone());
612
613        Ok(())
614    }
615}
616
617#[cfg(test)]
618mod tests {
619
620    use serde::ser::Error;
621    use swiftide_core::chat_completion::errors::ToolError;
622    use swiftide_core::chat_completion::{ChatCompletionResponse, ToolCall};
623    use swiftide_core::test_utils::MockChatCompletion;
624
625    use super::*;
626    use crate::{
627        assistant, chat_request, chat_response, summary, system, tool_failed, tool_output, user,
628    };
629
630    use crate::test_utils::{MockHook, MockTool};
631
    /// Builder defaults: the `stop` tool is always present, duplicate tools collapse into one,
    /// custom tools are added next to `stop`, and the context starts without history.
    #[test_log::test(tokio::test)]
    async fn test_agent_builder_defaults() {
        // Mock LLM; no completion is expected in this test
        let mock_llm = MockChatCompletion::new();

        // Build the agent
        let agent = Agent::builder().llm(&mock_llm).build().unwrap();

        // NOTE(review): the context type is not asserted; only its empty history is checked below

        // Check that the default tools are added
        assert!(agent.find_tool_by_name("stop").is_some());

        // Check it does not allow duplicates
        let agent = Agent::builder()
            .tools([Stop::default(), Stop::default()])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 1);

        // It should include the default tool if a different tool is provided
        let agent = Agent::builder()
            .tools([MockTool::new("mock_tool")])
            .llm(&mock_llm)
            .build()
            .unwrap();

        assert_eq!(agent.tools.len(), 2);
        assert!(agent.find_tool_by_name("mock_tool").is_some());
        assert!(agent.find_tool_by_name("stop").is_some());

        assert!(agent.context().history().await.is_empty());
    }
667
    /// Full `query` loop without a system prompt: the LLM requests `mock_tool`, its output is fed
    /// back in the next completion, and the LLM then calls the `stop` tool.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_calling_loop() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::new("mock_tool");

        // First completion: only the user message is sent
        let chat_request = chat_request! {
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        // Second completion: history includes the assistant message and the tool output
        let chat_request = chat_request! {
            user!("Write a poem"),
            assistant!("Roses are red", ["mock_tool"]),
            tool_output!("mock_tool", "Great!");

            tools = [mock_tool.clone()]
        };

        // The LLM asks to stop, so no further completion is expected
        let stop_response = chat_response! {
            "Roses are red";
            tool_calls = ["stop"]
        };

        mock_llm.expect_complete(chat_request, Ok(stop_response));
        mock_tool.expect_invoke_ok("Great!".into(), None);

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        agent.query(prompt).await.unwrap();
    }
713
    /// `query_once` performs exactly one completion with a custom system prompt. The tool it
    /// requests is still invoked, but no follow-up completion is made.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        // Only a single completion is expected
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .tools([mock_tool])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }
745
    /// Same scenario as `test_agent_tool_run_once`, but the tool comes from a toolbox (a `Vec` of
    /// boxed tools) instead of `tools`. The expected request still lists the tool, so toolbox
    /// tools must be loaded before the first completion.
    #[test_log::test(tokio::test)]
    async fn test_agent_tool_via_toolbox_run_once() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();
        let mock_tool = MockTool::default();

        let chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = [mock_tool.clone()]
        };

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = ["mock_tool"]

        };

        mock_tool.expect_invoke_ok("Great!".into(), None);
        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));

        let mut agent = Agent::builder()
            .add_toolbox(vec![mock_tool.boxed()])
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();
    }
777
778    #[test_log::test(tokio::test(flavor = "multi_thread"))]
779    async fn test_multiple_tool_calls() {
780        let prompt = "Write a poem";
781        let mock_llm = MockChatCompletion::new();
782        let mock_tool = MockTool::new("mock_tool1");
783        let mock_tool2 = MockTool::new("mock_tool2");
784
785        let chat_request = chat_request! {
786            system!("My system prompt"),
787            user!("Write a poem");
788
789
790
791            tools = [mock_tool.clone(), mock_tool2.clone()]
792        };
793
794        let mock_tool_response = chat_response! {
795            "Roses are red";
796
797            tool_calls = ["mock_tool1", "mock_tool2"]
798
799        };
800
801        dbg!(&chat_request);
802        mock_tool.expect_invoke_ok("Great!".into(), None);
803        mock_tool2.expect_invoke_ok("Great!".into(), None);
804        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
805
806        let chat_request = chat_request! {
807            system!("My system prompt"),
808            user!("Write a poem"),
809            assistant!("Roses are red", ["mock_tool1", "mock_tool2"]),
810            tool_output!("mock_tool1", "Great!"),
811            tool_output!("mock_tool2", "Great!");
812
813            tools = [mock_tool.clone(), mock_tool2.clone()]
814        };
815
816        let mock_tool_response = chat_response! {
817            "Ok!";
818
819            tool_calls = ["stop"]
820
821        };
822
823        mock_llm.expect_complete(chat_request, Ok(mock_tool_response));
824
825        let mut agent = Agent::builder()
826            .tools([mock_tool, mock_tool2])
827            .system_prompt("My system prompt")
828            .llm(&mock_llm)
829            .build()
830            .unwrap();
831
832        agent.query(prompt).await.unwrap();
833    }
834
    /// State transitions: the agent is pending before its first run and stopped after a
    /// `query_once` whose completion requests no tools.
    #[test_log::test(tokio::test)]
    async fn test_agent_state_machine() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let chat_request = chat_request! {
            user!("Write a poem");
            tools = []
        };
        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []
        };

        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
        let mut agent = Agent::builder()
            .llm(&mock_llm)
            .no_system_prompt()
            .build()
            .unwrap();

        // Agent has never run and is pending
        assert!(agent.state.is_pending());
        agent.query_once(prompt).await.unwrap();

        // Agent is stopped, there might be more messages
        assert!(agent.state.is_stopped());
    }
863
    /// After a summary message is added to the context, the next completion request contains
    /// only the system prompt, the latest summary and the new user message. Earlier messages are
    /// no longer sent.
    #[test_log::test(tokio::test)]
    async fn test_summary() {
        let prompt = "Write a poem";
        let mock_llm = MockChatCompletion::new();

        let mock_tool_response = chat_response! {
            "Roses are red";
            tool_calls = []

        };

        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            user!("Write a poem");

            tools = []
        };

        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        let mut agent = Agent::builder()
            .system_prompt("My system prompt")
            .llm(&mock_llm)
            .build()
            .unwrap();

        agent.query_once(prompt).await.unwrap();

        // Added directly on the context, so no `OnNewMessage` hook runs
        agent
            .context
            .add_message(ChatMessage::new_summary("Summary"))
            .await;

        // The first poem and its answer are replaced by the summary
        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary"),
            user!("Write another poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response.clone()));

        agent.query_once("Write another poem").await.unwrap();

        agent
            .context
            .add_message(ChatMessage::new_summary("Summary 2"))
            .await;

        // Only the most recent summary is kept
        let expected_chat_request = chat_request! {
            system!("My system prompt"),
            summary!("Summary 2"),
            user!("Write a third poem");
            tools = []
        };
        mock_llm.expect_complete(expected_chat_request, Ok(mock_tool_response));

        agent.query_once("Write a third poem").await.unwrap();
    }
922
923    #[test_log::test(tokio::test)]
924    async fn test_agent_hooks() {
925        let mock_before_all = MockHook::new("before_all").expect_calls(1).to_owned();
926        let mock_on_start_fn = MockHook::new("on_start").expect_calls(1).to_owned();
927        let mock_before_completion = MockHook::new("before_completion")
928            .expect_calls(2)
929            .to_owned();
930        let mock_after_completion = MockHook::new("after_completion").expect_calls(2).to_owned();
931        let mock_after_each = MockHook::new("after_each").expect_calls(2).to_owned();
932        let mock_on_message = MockHook::new("on_message").expect_calls(4).to_owned();
933        let mock_on_stop = MockHook::new("on_stop").expect_calls(1).to_owned();
934
935        // Once for mock tool and once for stop
936        let mock_before_tool = MockHook::new("before_tool").expect_calls(2).to_owned();
937        let mock_after_tool = MockHook::new("after_tool").expect_calls(2).to_owned();
938
939        let prompt = "Write a poem";
940        let mock_llm = MockChatCompletion::new();
941        let mock_tool = MockTool::default();
942
943        let chat_request = chat_request! {
944            user!("Write a poem");
945
946            tools = [mock_tool.clone()]
947        };
948
949        let mock_tool_response = chat_response! {
950            "Roses are red";
951            tool_calls = ["mock_tool"]
952
953        };
954
955        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
956
957        let chat_request = chat_request! {
958            user!("Write a poem"),
959            assistant!("Roses are red", ["mock_tool"]),
960            tool_output!("mock_tool", "Great!");
961
962            tools = [mock_tool.clone()]
963        };
964
965        let stop_response = chat_response! {
966            "Roses are red";
967            tool_calls = ["stop"]
968        };
969
970        mock_llm.expect_complete(chat_request, Ok(stop_response));
971        mock_tool.expect_invoke_ok("Great!".into(), None);
972
973        let mut agent = Agent::builder()
974            .tools([mock_tool])
975            .llm(&mock_llm)
976            .no_system_prompt()
977            .before_all(mock_before_all.hook_fn())
978            .on_start(mock_on_start_fn.on_start_fn())
979            .before_completion(mock_before_completion.before_completion_fn())
980            .before_tool(mock_before_tool.before_tool_fn())
981            .after_completion(mock_after_completion.after_completion_fn())
982            .after_tool(mock_after_tool.after_tool_fn())
983            .after_each(mock_after_each.hook_fn())
984            .on_new_message(mock_on_message.message_hook_fn())
985            .on_stop(mock_on_stop.stop_hook_fn())
986            .build()
987            .unwrap();
988
989        agent.query(prompt).await.unwrap();
990    }
991
992    #[test_log::test(tokio::test)]
993    async fn test_agent_loop_limit() {
994        let prompt = "Generate content"; // Example prompt
995        let mock_llm = MockChatCompletion::new();
996        let mock_tool = MockTool::new("mock_tool");
997
998        let chat_request = chat_request! {
999            user!(prompt);
1000            tools = [mock_tool.clone()]
1001        };
1002        mock_tool.expect_invoke_ok("Great!".into(), None);
1003
1004        let mock_tool_response = chat_response! {
1005            "Some response";
1006            tool_calls = ["mock_tool"]
1007        };
1008
1009        // Set expectations for the mock LLM responses
1010        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response.clone()));
1011
1012        // // Response for terminating the loop
1013        let stop_response = chat_response! {
1014            "Final response";
1015            tool_calls = ["stop"]
1016        };
1017
1018        mock_llm.expect_complete(chat_request, Ok(stop_response));
1019
1020        let mut agent = Agent::builder()
1021            .tools([mock_tool])
1022            .llm(&mock_llm)
1023            .no_system_prompt()
1024            .limit(1) // Setting the loop limit to 1
1025            .build()
1026            .unwrap();
1027
1028        // Run the agent
1029        agent.query(prompt).await.unwrap();
1030
1031        // Assert that the remaining message is still in the queue
1032        let remaining = mock_llm.expectations.lock().unwrap().pop();
1033        assert!(remaining.is_some());
1034
1035        // Assert that the agent is stopped after reaching the loop limit
1036        assert!(agent.is_stopped());
1037    }
1038
1039    #[test_log::test(tokio::test)]
1040    async fn test_tool_retry_mechanism() {
1041        let prompt = "Execute my tool";
1042        let mock_llm = MockChatCompletion::new();
1043        let mock_tool = MockTool::new("retry_tool");
1044
1045        // Configure mock tool to fail twice. First time is fed back to the LLM, second time is an
1046        // error
1047        mock_tool.expect_invoke(
1048            Err(ToolError::WrongArguments(serde_json::Error::custom(
1049                "missing `query`",
1050            ))),
1051            None,
1052        );
1053        mock_tool.expect_invoke(
1054            Err(ToolError::WrongArguments(serde_json::Error::custom(
1055                "missing `query`",
1056            ))),
1057            None,
1058        );
1059
1060        let chat_request = chat_request! {
1061            user!(prompt);
1062            tools = [mock_tool.clone()]
1063        };
1064        let retry_response = chat_response! {
1065            "First failing attempt";
1066            tool_calls = ["retry_tool"]
1067        };
1068        mock_llm.expect_complete(chat_request.clone(), Ok(retry_response));
1069
1070        let chat_request = chat_request! {
1071            user!(prompt),
1072            assistant!("First failing attempt", ["retry_tool"]),
1073            tool_failed!("retry_tool", "arguments for tool failed to parse: missing `query`");
1074
1075            tools = [mock_tool.clone()]
1076        };
1077        let will_fail_response = chat_response! {
1078            "Finished execution";
1079            tool_calls = ["retry_tool"]
1080        };
1081        mock_llm.expect_complete(chat_request.clone(), Ok(will_fail_response));
1082
1083        let mut agent = Agent::builder()
1084            .tools([mock_tool])
1085            .llm(&mock_llm)
1086            .no_system_prompt()
1087            .tool_retry_limit(1) // The test relies on a limit of 2 retries.
1088            .build()
1089            .unwrap();
1090
1091        // Run the agent
1092        let result = agent.query(prompt).await;
1093
1094        assert!(result.is_err());
1095        assert!(result.unwrap_err().to_string().contains("missing `query`"));
1096        assert!(agent.is_stopped());
1097    }
1098
1099    #[test_log::test(tokio::test)]
1100    async fn test_recovering_agent_existing_history() {
1101        // First, let's run an agent
1102        let prompt = "Write a poem";
1103        let mock_llm = MockChatCompletion::new();
1104        let mock_tool = MockTool::new("mock_tool");
1105
1106        let chat_request = chat_request! {
1107            user!("Write a poem");
1108
1109            tools = [mock_tool.clone()]
1110        };
1111
1112        let mock_tool_response = chat_response! {
1113            "Roses are red";
1114            tool_calls = ["mock_tool"]
1115
1116        };
1117
1118        mock_llm.expect_complete(chat_request.clone(), Ok(mock_tool_response));
1119
1120        let chat_request = chat_request! {
1121            user!("Write a poem"),
1122            assistant!("Roses are red", ["mock_tool"]),
1123            tool_output!("mock_tool", "Great!");
1124
1125            tools = [mock_tool.clone()]
1126        };
1127
1128        let stop_response = chat_response! {
1129            "Roses are red";
1130            tool_calls = ["stop"]
1131        };
1132
1133        mock_llm.expect_complete(chat_request, Ok(stop_response));
1134        mock_tool.expect_invoke_ok("Great!".into(), None);
1135
1136        let mut agent = Agent::builder()
1137            .tools([mock_tool.clone()])
1138            .llm(&mock_llm)
1139            .no_system_prompt()
1140            .build()
1141            .unwrap();
1142
1143        agent.query(prompt).await.unwrap();
1144
1145        // Let's retrieve the history of the agent
1146        let history = agent.history().await;
1147
1148        // Store it as a string somewhere
1149        let serialized = serde_json::to_string(&history).unwrap();
1150
1151        // Retrieve it
1152        let history: Vec<ChatMessage> = serde_json::from_str(&serialized).unwrap();
1153
1154        // Build a context from the history
1155        let context = DefaultContext::default()
1156            .with_message_history(history)
1157            .to_owned();
1158
1159        let expected_chat_request = chat_request! {
1160            user!("Write a poem"),
1161            assistant!("Roses are red", ["mock_tool"]),
1162            tool_output!("mock_tool", "Great!"),
1163            assistant!("Roses are red", ["stop"]),
1164            tool_output!("stop", ToolOutput::Stop),
1165            user!("Try again!");
1166
1167            tools = [mock_tool.clone()]
1168        };
1169
1170        let stop_response = chat_response! {
1171            "Really stopping now";
1172            tool_calls = ["stop"]
1173        };
1174
1175        mock_llm.expect_complete(expected_chat_request, Ok(stop_response));
1176
1177        let mut agent = Agent::builder()
1178            .context(context)
1179            .tools([mock_tool])
1180            .llm(&mock_llm)
1181            .no_system_prompt()
1182            .build()
1183            .unwrap();
1184
1185        agent.query_once("Try again!").await.unwrap();
1186    }
1187}