1use std::{collections::HashMap, sync::Arc};
2
3use async_stream::try_stream;
4use async_trait::async_trait;
5use futures::{stream::BoxStream, StreamExt};
6
7use agentrs_core::{
8 Agent as AgentTrait, AgentError, AgentEvent, AgentOutput, CompletionRequest,
9 CompletionResponse, LlmProvider, Memory, Message, Result, ToolOutput,
10};
11use agentrs_tools::ToolRegistry;
12
13#[derive(Debug, Clone)]
15pub struct AgentConfig {
16 pub model: String,
18 pub temperature: Option<f32>,
20 pub max_tokens: Option<u32>,
22 pub loop_strategy: LoopStrategy,
24 pub max_steps: usize,
26}
27
28impl Default for AgentConfig {
29 fn default() -> Self {
30 Self {
31 model: String::new(),
32 temperature: Some(0.2),
33 max_tokens: Some(4096),
34 loop_strategy: LoopStrategy::ReAct { max_steps: 8 },
35 max_steps: 8,
36 }
37 }
38}
39
/// Reasoning loop used by the agent for a single `run`.
#[derive(Debug, Clone)]
pub enum LoopStrategy {
    /// Iterative loop alternating model calls and tool execution.
    ReAct {
        /// Step budget for this strategy.
        max_steps: usize,
    },
    /// Single completion without tools.
    CoT,
    /// A planning completion followed by an execution loop.
    PlanAndExecute {
        /// Step budget for this strategy.
        max_steps: usize,
    },
    /// Free-form instruction prepended to the user task.
    Custom(String),
}

impl LoopStrategy {
    /// Returns the step budget carried by this variant, or `fallback` when the
    /// variant has none (`CoT`, `Custom`).
    pub(crate) fn max_steps_hint(&self, fallback: usize) -> usize {
        let own_budget = match self {
            Self::ReAct { max_steps } => Some(*max_steps),
            Self::PlanAndExecute { max_steps } => Some(*max_steps),
            Self::CoT | Self::Custom(_) => None,
        };
        own_budget.unwrap_or(fallback)
    }
}
67
68pub struct AgentRunner<M> {
70 llm: Arc<dyn LlmProvider>,
71 memory: M,
72 tools: ToolRegistry,
73 system_prompt: Option<String>,
74 config: AgentConfig,
75}
76
77impl<M> AgentRunner<M>
78where
79 M: Memory,
80{
81 pub fn new(
83 llm: Arc<dyn LlmProvider>,
84 memory: M,
85 tools: ToolRegistry,
86 system_prompt: Option<String>,
87 config: AgentConfig,
88 ) -> Self {
89 Self {
90 llm,
91 memory,
92 tools,
93 system_prompt,
94 config,
95 }
96 }
97
98 pub async fn run(&mut self, input: &str) -> Result<AgentOutput> {
100 AgentTrait::run(self, input).await
101 }
102
103 pub async fn stream_run(&mut self, input: &str) -> Result<BoxStream<'_, Result<AgentEvent>>> {
105 AgentTrait::stream_run(self, input).await
106 }
107
108 async fn run_react(&mut self, input: &str) -> Result<AgentOutput> {
109 self.memory.store("user", Message::user(input)).await?;
110
111 let max_steps = match self.config.loop_strategy {
112 LoopStrategy::ReAct { max_steps } => max_steps,
113 _ => self.config.max_steps,
114 };
115
116 for step in 1..=max_steps {
117 let history = self.memory.history().await?;
118 let request = self.build_request(history, !self.tools.is_empty());
119 let response = self.llm.complete(request).await?;
120 let assistant_message = response.message.clone();
121 self.memory
122 .store("assistant", assistant_message.clone())
123 .await?;
124
125 if let Some(tool_calls) = assistant_message
126 .tool_calls
127 .clone()
128 .filter(|calls| !calls.is_empty())
129 {
130 for message in self.execute_tool_calls(tool_calls).await? {
131 self.memory.store("tool", message).await?;
132 }
133 continue;
134 }
135
136 return self.finish_output(response, step).await;
137 }
138
139 Err(AgentError::MaxStepsReached { steps: max_steps })
140 }
141
142 async fn run_cot(&mut self, input: &str) -> Result<AgentOutput> {
143 self.memory.store("user", Message::user(input)).await?;
144 let history = self.memory.history().await?;
145 let response = self
146 .llm
147 .complete(self.build_request(history, false))
148 .await?;
149 self.memory
150 .store("assistant", response.message.clone())
151 .await?;
152 self.finish_output(response, 1).await
153 }
154
155 async fn run_plan_execute(&mut self, input: &str, max_steps: usize) -> Result<AgentOutput> {
156 let planner_prompt =
157 format!("Create a concise numbered plan to solve the user task. Task: {input}");
158 let plan_response = self
159 .llm
160 .complete(CompletionRequest {
161 messages: vec![Message::user(planner_prompt)],
162 tools: None,
163 model: self.config.model.clone(),
164 temperature: Some(0.1),
165 max_tokens: self.config.max_tokens,
166 stream: false,
167 system: self.system_prompt.clone(),
168 extra: HashMap::new(),
169 })
170 .await?;
171
172 self.memory
173 .store(
174 "plan",
175 Message::assistant(plan_response.message.text_content()),
176 )
177 .await?;
178 let execution_prompt = format!(
179 "Use this plan to solve the task.\nPlan:\n{}\n\nTask: {input}",
180 plan_response.message.text_content()
181 );
182 self.memory
183 .store("user", Message::user(execution_prompt))
184 .await?;
185
186 let mut output = self.run_react(input).await?;
187 output.steps = output.steps.max(max_steps.min(output.steps.max(1)));
188 Ok(output)
189 }
190
191 fn build_request(&self, history: Vec<Message>, include_tools: bool) -> CompletionRequest {
192 CompletionRequest {
193 messages: history,
194 tools: include_tools.then(|| self.tools.to_definitions()),
195 model: self.config.model.clone(),
196 temperature: self.config.temperature,
197 max_tokens: self.config.max_tokens,
198 stream: false,
199 system: self.system_prompt.clone(),
200 extra: HashMap::new(),
201 }
202 }
203
204 async fn execute_tool_calls(
205 &self,
206 tool_calls: Vec<agentrs_core::ToolCall>,
207 ) -> Result<Vec<Message>> {
208 let futures = tool_calls.into_iter().map(|tool_call| {
209 let tools = self.tools.clone();
210 async move {
211 let output = match tools
212 .call(&tool_call.name, tool_call.arguments.clone())
213 .await
214 {
215 Ok(output) => output,
216 Err(error) => ToolOutput::error(error.to_string()),
217 };
218 Ok::<_, AgentError>(Message::tool_result(tool_call.id, tool_call.name, output))
219 }
220 });
221 futures::future::try_join_all(futures).await
222 }
223
224 async fn finish_output(
225 &self,
226 response: CompletionResponse,
227 steps: usize,
228 ) -> Result<AgentOutput> {
229 let history = self.memory.history().await?;
230 Ok(AgentOutput {
231 text: response.message.text_content(),
232 steps,
233 usage: response.usage,
234 messages: history,
235 metadata: HashMap::new(),
236 })
237 }
238}
239
240#[async_trait]
241impl<M> AgentTrait for AgentRunner<M>
242where
243 M: Memory,
244{
245 async fn run(&mut self, input: &str) -> Result<AgentOutput> {
246 match self.config.loop_strategy.clone() {
247 LoopStrategy::ReAct { .. } => self.run_react(input).await,
248 LoopStrategy::CoT => self.run_cot(input).await,
249 LoopStrategy::PlanAndExecute { max_steps } => {
250 self.run_plan_execute(input, max_steps).await
251 }
252 LoopStrategy::Custom(instruction) => {
253 let input = format!("{instruction}\n\nUser task: {input}");
254 self.run_cot(&input).await
255 }
256 }
257 }
258
259 async fn stream_run(&mut self, input: &str) -> Result<BoxStream<'_, Result<AgentEvent>>> {
260 let output = self.run(input).await?;
261 Ok(try_stream! {
262 yield AgentEvent::Thinking("completed".to_string());
263 for token in output.text.split_whitespace() {
264 yield AgentEvent::Token(format!("{token} "));
265 }
266 yield AgentEvent::Done(output);
267 }
268 .boxed())
269 }
270}