Skip to main content

agentrs_agents/
runner.rs

1use std::{collections::HashMap, sync::Arc};
2
3use async_stream::try_stream;
4use async_trait::async_trait;
5use futures::{stream::BoxStream, StreamExt};
6
7use agentrs_core::{
8    Agent as AgentTrait, AgentError, AgentEvent, AgentOutput, CompletionRequest,
9    CompletionResponse, LlmProvider, Memory, Message, Result, ToolOutput,
10};
11use agentrs_tools::ToolRegistry;
12
/// Runtime agent configuration.
#[derive(Debug, Clone, PartialEq)]
pub struct AgentConfig {
    /// Model name sent with every completion request.
    pub model: String,
    /// Sampling temperature for the main completion requests (the planning
    /// request of `PlanAndExecute` uses a fixed `0.1` instead).
    pub temperature: Option<f32>,
    /// Maximum output tokens per completion request.
    pub max_tokens: Option<u32>,
    /// Loop strategy used by [`AgentRunner`] to answer a task.
    pub loop_strategy: LoopStrategy,
    /// Fallback iteration budget for strategies that do not carry their own
    /// limit.
    ///
    /// Under [`LoopStrategy::ReAct`] the variant's own `max_steps` takes
    /// precedence over this value.
    pub max_steps: usize,
}

/// Step budget used by `AgentConfig::default`, both for its default
/// `LoopStrategy::ReAct` strategy and for the `max_steps` fallback, so the two
/// cannot drift apart.
const DEFAULT_MAX_STEPS: usize = 8;

impl Default for AgentConfig {
    /// Returns an empty model name, temperature `0.2`, a 4096-token output cap
    /// and a ReAct loop limited to `DEFAULT_MAX_STEPS` (8) steps.
    fn default() -> Self {
        Self {
            model: String::new(),
            temperature: Some(0.2),
            max_tokens: Some(4096),
            loop_strategy: LoopStrategy::ReAct {
                max_steps: DEFAULT_MAX_STEPS,
            },
            max_steps: DEFAULT_MAX_STEPS,
        }
    }
}

/// Built-in loop strategies.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LoopStrategy {
    /// Standard reason-and-act loop: the model may request tool calls and is
    /// re-prompted with their results until it answers without tool calls.
    ReAct {
        /// Maximum number of reasoning/tool iterations.
        max_steps: usize,
    },
    /// Single completion call with no tool definitions attached.
    CoT,
    /// Planner + executor loop using the same LLM.
    PlanAndExecute {
        /// Maximum number of execution steps to budget after planning.
        max_steps: usize,
    },
    /// Custom instruction prepended to the user task, which is then answered
    /// like [`LoopStrategy::CoT`].
    Custom(String),
}

impl LoopStrategy {
    /// Returns the step budget carried by this strategy, or `fallback` for the
    /// strategies (`CoT`, `Custom`) that do not carry one.
    pub(crate) fn max_steps_hint(&self, fallback: usize) -> usize {
        match self {
            Self::ReAct { max_steps } | Self::PlanAndExecute { max_steps } => *max_steps,
            Self::CoT | Self::Custom(_) => fallback,
        }
    }
}
67
68/// Runnable agent implementation.
69pub struct AgentRunner<M> {
70    llm: Arc<dyn LlmProvider>,
71    memory: M,
72    tools: ToolRegistry,
73    system_prompt: Option<String>,
74    config: AgentConfig,
75}
76
77impl<M> AgentRunner<M>
78where
79    M: Memory,
80{
81    /// Creates a new agent runner.
82    pub fn new(
83        llm: Arc<dyn LlmProvider>,
84        memory: M,
85        tools: ToolRegistry,
86        system_prompt: Option<String>,
87        config: AgentConfig,
88    ) -> Self {
89        Self {
90            llm,
91            memory,
92            tools,
93            system_prompt,
94            config,
95        }
96    }
97
98    /// Runs the agent to completion.
99    pub async fn run(&mut self, input: &str) -> Result<AgentOutput> {
100        AgentTrait::run(self, input).await
101    }
102
103    /// Runs the agent as a stream of events.
104    pub async fn stream_run(&mut self, input: &str) -> Result<BoxStream<'_, Result<AgentEvent>>> {
105        AgentTrait::stream_run(self, input).await
106    }
107
108    async fn run_react(&mut self, input: &str) -> Result<AgentOutput> {
109        self.memory.store("user", Message::user(input)).await?;
110
111        let max_steps = match self.config.loop_strategy {
112            LoopStrategy::ReAct { max_steps } => max_steps,
113            _ => self.config.max_steps,
114        };
115
116        for step in 1..=max_steps {
117            let history = self.memory.history().await?;
118            let request = self.build_request(history, !self.tools.is_empty());
119            let response = self.llm.complete(request).await?;
120            let assistant_message = response.message.clone();
121            self.memory
122                .store("assistant", assistant_message.clone())
123                .await?;
124
125            if let Some(tool_calls) = assistant_message
126                .tool_calls
127                .clone()
128                .filter(|calls| !calls.is_empty())
129            {
130                for message in self.execute_tool_calls(tool_calls).await? {
131                    self.memory.store("tool", message).await?;
132                }
133                continue;
134            }
135
136            return self.finish_output(response, step).await;
137        }
138
139        Err(AgentError::MaxStepsReached { steps: max_steps })
140    }
141
142    async fn run_cot(&mut self, input: &str) -> Result<AgentOutput> {
143        self.memory.store("user", Message::user(input)).await?;
144        let history = self.memory.history().await?;
145        let response = self
146            .llm
147            .complete(self.build_request(history, false))
148            .await?;
149        self.memory
150            .store("assistant", response.message.clone())
151            .await?;
152        self.finish_output(response, 1).await
153    }
154
155    async fn run_plan_execute(&mut self, input: &str, max_steps: usize) -> Result<AgentOutput> {
156        let planner_prompt =
157            format!("Create a concise numbered plan to solve the user task. Task: {input}");
158        let plan_response = self
159            .llm
160            .complete(CompletionRequest {
161                messages: vec![Message::user(planner_prompt)],
162                tools: None,
163                model: self.config.model.clone(),
164                temperature: Some(0.1),
165                max_tokens: self.config.max_tokens,
166                stream: false,
167                system: self.system_prompt.clone(),
168                extra: HashMap::new(),
169            })
170            .await?;
171
172        self.memory
173            .store(
174                "plan",
175                Message::assistant(plan_response.message.text_content()),
176            )
177            .await?;
178        let execution_prompt = format!(
179            "Use this plan to solve the task.\nPlan:\n{}\n\nTask: {input}",
180            plan_response.message.text_content()
181        );
182        self.memory
183            .store("user", Message::user(execution_prompt))
184            .await?;
185
186        let mut output = self.run_react(input).await?;
187        output.steps = output.steps.max(max_steps.min(output.steps.max(1)));
188        Ok(output)
189    }
190
191    fn build_request(&self, history: Vec<Message>, include_tools: bool) -> CompletionRequest {
192        CompletionRequest {
193            messages: history,
194            tools: include_tools.then(|| self.tools.to_definitions()),
195            model: self.config.model.clone(),
196            temperature: self.config.temperature,
197            max_tokens: self.config.max_tokens,
198            stream: false,
199            system: self.system_prompt.clone(),
200            extra: HashMap::new(),
201        }
202    }
203
204    async fn execute_tool_calls(
205        &self,
206        tool_calls: Vec<agentrs_core::ToolCall>,
207    ) -> Result<Vec<Message>> {
208        let futures = tool_calls.into_iter().map(|tool_call| {
209            let tools = self.tools.clone();
210            async move {
211                let output = match tools
212                    .call(&tool_call.name, tool_call.arguments.clone())
213                    .await
214                {
215                    Ok(output) => output,
216                    Err(error) => ToolOutput::error(error.to_string()),
217                };
218                Ok::<_, AgentError>(Message::tool_result(tool_call.id, tool_call.name, output))
219            }
220        });
221        futures::future::try_join_all(futures).await
222    }
223
224    async fn finish_output(
225        &self,
226        response: CompletionResponse,
227        steps: usize,
228    ) -> Result<AgentOutput> {
229        let history = self.memory.history().await?;
230        Ok(AgentOutput {
231            text: response.message.text_content(),
232            steps,
233            usage: response.usage,
234            messages: history,
235            metadata: HashMap::new(),
236        })
237    }
238}
239
240#[async_trait]
241impl<M> AgentTrait for AgentRunner<M>
242where
243    M: Memory,
244{
245    async fn run(&mut self, input: &str) -> Result<AgentOutput> {
246        match self.config.loop_strategy.clone() {
247            LoopStrategy::ReAct { .. } => self.run_react(input).await,
248            LoopStrategy::CoT => self.run_cot(input).await,
249            LoopStrategy::PlanAndExecute { max_steps } => {
250                self.run_plan_execute(input, max_steps).await
251            }
252            LoopStrategy::Custom(instruction) => {
253                let input = format!("{instruction}\n\nUser task: {input}");
254                self.run_cot(&input).await
255            }
256        }
257    }
258
259    async fn stream_run(&mut self, input: &str) -> Result<BoxStream<'_, Result<AgentEvent>>> {
260        let output = self.run(input).await?;
261        Ok(try_stream! {
262            yield AgentEvent::Thinking("completed".to_string());
263            for token in output.text.split_whitespace() {
264                yield AgentEvent::Token(format!("{token} "));
265            }
266            yield AgentEvent::Done(output);
267        }
268        .boxed())
269    }
270}