1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Instant;
4
5use async_stream::try_stream;
6use futures_util::{Stream, StreamExt};
7use tokio::time::{Duration, sleep};
8
9use crate::error::{AgentError, ProviderError, ToolError};
10use crate::llm::{
11 ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice, ModelToolDefinition,
12};
13use crate::tools::{DependencyMap, ToolOutcome, ToolSpec};
14
/// How the agent asks the model to use tools on each invocation.
///
/// The value is translated into a `ModelToolChoice` when a query runs. If the
/// agent has no tools registered, `ModelToolChoice::None` is sent regardless
/// of this setting.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum AgentToolChoice {
    /// Forwarded to the model as `ModelToolChoice::Auto`. This is the default.
    #[default]
    Auto,
    /// Forwarded to the model as `ModelToolChoice::Required`.
    Required,
    /// Forwarded to the model as `ModelToolChoice::None`.
    None,
    /// Forwarded to the model as `ModelToolChoice::Tool` with this tool name.
    Tool(String),
}
28
29#[derive(Debug, Clone)]
30pub struct AgentConfig {
31 pub require_done_tool: bool,
32 pub max_iterations: u32,
33 pub system_prompt: Option<String>,
34 pub tool_choice: AgentToolChoice,
35 pub llm_max_retries: u32,
36 pub llm_retry_base_delay_ms: u64,
37 pub llm_retry_max_delay_ms: u64,
38 pub hidden_user_message_prompt: Option<String>,
39}
40
41impl Default for AgentConfig {
42 fn default() -> Self {
43 Self {
44 require_done_tool: false,
45 max_iterations: 24,
46 system_prompt: None,
47 tool_choice: AgentToolChoice::Auto,
48 llm_max_retries: 5,
49 llm_retry_base_delay_ms: 1_000,
50 llm_retry_max_delay_ms: 60_000,
51 hidden_user_message_prompt: None,
52 }
53 }
54}
55
/// Author of a message announced through [`AgentEvent::MessageStart`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AgentRole {
    /// The message was supplied by the caller.
    User,
    /// The message was produced by the model.
    Assistant,
}
61
/// Outcome of a single tool-execution step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StepStatus {
    /// The tool ran and returned a result.
    Completed,
    /// The tool was unknown or its execution failed.
    Error,
}
67
68#[derive(Debug, Clone, PartialEq)]
69pub enum AgentEvent {
70 MessageStart {
71 message_id: String,
72 role: AgentRole,
73 },
74 MessageComplete {
75 message_id: String,
76 content: String,
77 },
78 HiddenUserMessage {
79 content: String,
80 },
81 StepStart {
82 step_id: String,
83 title: String,
84 step_number: u32,
85 },
86 StepComplete {
87 step_id: String,
88 status: StepStatus,
89 duration_ms: u128,
90 },
91 Thinking {
92 content: String,
93 },
94 Text {
95 content: String,
96 },
97 ToolCall {
98 tool: String,
99 args_json: serde_json::Value,
100 tool_call_id: String,
101 },
102 ToolResult {
103 tool: String,
104 result_text: String,
105 tool_call_id: String,
106 is_error: bool,
107 },
108 FinalResponse {
109 content: String,
110 },
111}
112
113pub struct AgentBuilder {
114 model: Option<Arc<dyn ChatModel>>,
115 tools: Vec<ToolSpec>,
116 config: AgentConfig,
117 dependencies: DependencyMap,
118 dependency_overrides: DependencyMap,
119}
120
121impl Default for AgentBuilder {
122 fn default() -> Self {
123 Self {
124 model: None,
125 tools: Vec::new(),
126 config: AgentConfig::default(),
127 dependencies: DependencyMap::new(),
128 dependency_overrides: DependencyMap::new(),
129 }
130 }
131}
132
133impl AgentBuilder {
134 pub fn model<M>(mut self, model: M) -> Self
135 where
136 M: ChatModel + 'static,
137 {
138 self.model = Some(Arc::new(model));
139 self
140 }
141
142 pub fn tool(mut self, tool: ToolSpec) -> Self {
143 self.tools.push(tool);
144 self
145 }
146
147 pub fn tools(mut self, tools: Vec<ToolSpec>) -> Self {
148 self.tools.extend(tools);
149 self
150 }
151
152 pub fn config(mut self, config: AgentConfig) -> Self {
153 self.config = config;
154 self
155 }
156
157 pub fn system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
158 self.config.system_prompt = Some(system_prompt.into());
159 self
160 }
161
162 pub fn require_done_tool(mut self, require_done_tool: bool) -> Self {
163 self.config.require_done_tool = require_done_tool;
164 self
165 }
166
167 pub fn max_iterations(mut self, max_iterations: u32) -> Self {
168 self.config.max_iterations = max_iterations;
169 self
170 }
171
172 pub fn tool_choice(mut self, tool_choice: AgentToolChoice) -> Self {
173 self.config.tool_choice = tool_choice;
174 self
175 }
176
177 pub fn llm_retry_config(
178 mut self,
179 max_retries: u32,
180 base_delay_ms: u64,
181 max_delay_ms: u64,
182 ) -> Self {
183 self.config.llm_max_retries = max_retries;
184 self.config.llm_retry_base_delay_ms = base_delay_ms;
185 self.config.llm_retry_max_delay_ms = max_delay_ms;
186 self
187 }
188
189 pub fn hidden_user_message_prompt(mut self, prompt: impl Into<String>) -> Self {
190 self.config.hidden_user_message_prompt = Some(prompt.into());
191 self
192 }
193
194 pub fn dependency<T>(self, value: T) -> Self
195 where
196 T: Send + Sync + 'static,
197 {
198 self.dependencies.insert(value);
199 self
200 }
201
202 pub fn dependency_named<T>(self, key: impl Into<String>, value: T) -> Self
203 where
204 T: Send + Sync + 'static,
205 {
206 self.dependencies.insert_named(key, value);
207 self
208 }
209
210 pub fn dependency_override<T>(self, value: T) -> Self
211 where
212 T: Send + Sync + 'static,
213 {
214 self.dependency_overrides.insert(value);
215 self
216 }
217
218 pub fn dependency_override_named<T>(self, key: impl Into<String>, value: T) -> Self
219 where
220 T: Send + Sync + 'static,
221 {
222 self.dependency_overrides.insert_named(key, value);
223 self
224 }
225
226 pub fn build(self) -> Result<Agent, AgentError> {
227 let Some(model) = self.model else {
228 return Err(AgentError::Config(
229 "agent model must be configured via AgentBuilder::model(...)".to_string(),
230 ));
231 };
232
233 let mut tool_map = HashMap::new();
234 for tool in &self.tools {
235 if tool_map
236 .insert(tool.name().to_string(), tool.clone())
237 .is_some()
238 {
239 return Err(AgentError::Config(format!(
240 "duplicate tool registered: {}",
241 tool.name()
242 )));
243 }
244 }
245
246 Ok(Agent {
247 model,
248 tools: self.tools,
249 tool_map,
250 config: self.config,
251 dependencies: self.dependencies,
252 dependency_overrides: self.dependency_overrides,
253 history: Vec::new(),
254 next_message_id: 0,
255 })
256 }
257}
258
259pub struct Agent {
260 model: Arc<dyn ChatModel>,
261 tools: Vec<ToolSpec>,
262 tool_map: HashMap<String, ToolSpec>,
263 config: AgentConfig,
264 dependencies: DependencyMap,
265 dependency_overrides: DependencyMap,
266 history: Vec<ModelMessage>,
267 next_message_id: u64,
268}
269
270impl Agent {
271 pub fn builder() -> AgentBuilder {
272 AgentBuilder::default()
273 }
274
275 pub fn clear_history(&mut self) {
276 self.history.clear();
277 self.next_message_id = 0;
278 }
279
280 pub fn load_history(&mut self, messages: Vec<ModelMessage>) {
281 self.next_message_id = messages.len() as u64;
282 self.history = messages;
283 }
284
285 pub fn messages_len(&self) -> usize {
286 self.history.len()
287 }
288
289 pub fn messages(&self) -> &[ModelMessage] {
290 &self.history
291 }
292
293 pub async fn query(&mut self, user_message: impl Into<String>) -> Result<String, AgentError> {
294 let stream = self.query_stream(user_message);
295 futures_util::pin_mut!(stream);
296
297 let mut final_response: Option<String> = None;
298
299 while let Some(event) = stream.next().await {
300 match event? {
301 AgentEvent::FinalResponse { content } => final_response = Some(content),
302 AgentEvent::MessageStart { .. }
303 | AgentEvent::MessageComplete { .. }
304 | AgentEvent::HiddenUserMessage { .. }
305 | AgentEvent::StepStart { .. }
306 | AgentEvent::StepComplete { .. }
307 | AgentEvent::Thinking { .. }
308 | AgentEvent::Text { .. }
309 | AgentEvent::ToolCall { .. }
310 | AgentEvent::ToolResult { .. } => {}
311 }
312 }
313
314 final_response.ok_or(AgentError::MissingFinalResponse)
315 }
316
317 pub fn query_stream(
318 &mut self,
319 user_message: impl Into<String>,
320 ) -> impl Stream<Item = Result<AgentEvent, AgentError>> + '_ {
321 let user_message = user_message.into();
322
323 try_stream! {
324 if self.history.is_empty() {
325 if let Some(system_prompt) = &self.config.system_prompt {
326 self.history.push(ModelMessage::System(system_prompt.clone()));
327 }
328 }
329
330 let user_message_id = self.next_message_id(AgentRole::User);
331 yield AgentEvent::MessageStart {
332 message_id: user_message_id.clone(),
333 role: AgentRole::User,
334 };
335 self.history.push(ModelMessage::User(user_message.clone()));
336 yield AgentEvent::MessageComplete {
337 message_id: user_message_id,
338 content: user_message,
339 };
340
341 let tool_definitions = self
342 .tools
343 .iter()
344 .map(|tool| ModelToolDefinition {
345 name: tool.name().to_string(),
346 description: tool.description().to_string(),
347 parameters: tool.json_schema().clone(),
348 })
349 .collect::<Vec<_>>();
350
351 let tool_choice = self.resolve_tool_choice(!tool_definitions.is_empty());
352 let mut hidden_prompt_injected = false;
353
354 for _ in 0..self.config.max_iterations {
355 let completion = self
356 .invoke_with_retry(&tool_definitions, tool_choice.clone())
357 .await?;
358
359 let assistant_message_id = self.next_message_id(AgentRole::Assistant);
360 yield AgentEvent::MessageStart {
361 message_id: assistant_message_id.clone(),
362 role: AgentRole::Assistant,
363 };
364
365 if let Some(thinking) = completion.thinking.clone() {
366 yield AgentEvent::Thinking { content: thinking };
367 }
368
369 self.append_assistant_message(&completion);
370
371 if let Some(text) = completion.text.clone() {
372 if !text.is_empty() {
373 yield AgentEvent::Text {
374 content: text.clone(),
375 };
376 }
377 }
378
379 let assistant_content = completion.text.clone().unwrap_or_default();
380 yield AgentEvent::MessageComplete {
381 message_id: assistant_message_id,
382 content: assistant_content.clone(),
383 };
384
385 if completion.tool_calls.is_empty() {
386 if !self.config.require_done_tool {
387 if !hidden_prompt_injected {
388 if let Some(hidden_prompt) = self.config.hidden_user_message_prompt.clone() {
389 hidden_prompt_injected = true;
390 self.history.push(ModelMessage::User(hidden_prompt.clone()));
391 yield AgentEvent::HiddenUserMessage {
392 content: hidden_prompt,
393 };
394 continue;
395 }
396 }
397
398 yield AgentEvent::FinalResponse {
399 content: completion.text.unwrap_or_default(),
400 };
401 return;
402 }
403 continue;
404 }
405
406 let mut step_number = 0_u32;
407 for tool_call in completion.tool_calls {
408 step_number += 1;
409 yield AgentEvent::StepStart {
410 step_id: tool_call.id.clone(),
411 title: tool_call.name.clone(),
412 step_number,
413 };
414
415 yield AgentEvent::ToolCall {
416 tool: tool_call.name.clone(),
417 args_json: tool_call.arguments.clone(),
418 tool_call_id: tool_call.id.clone(),
419 };
420
421 let step_start = Instant::now();
422 let execution = self.execute_tool_call(&tool_call).await;
423 self.history.push(ModelMessage::ToolResult {
424 tool_call_id: tool_call.id.clone(),
425 tool_name: tool_call.name.clone(),
426 content: execution.result_text.clone(),
427 is_error: execution.is_error,
428 });
429
430 yield AgentEvent::ToolResult {
431 tool: tool_call.name.clone(),
432 result_text: execution.result_text.clone(),
433 tool_call_id: tool_call.id.clone(),
434 is_error: execution.is_error,
435 };
436
437 yield AgentEvent::StepComplete {
438 step_id: tool_call.id.clone(),
439 status: if execution.is_error {
440 StepStatus::Error
441 } else {
442 StepStatus::Completed
443 },
444 duration_ms: step_start.elapsed().as_millis(),
445 };
446
447 if let Some(done_message) = execution.done_message {
448 yield AgentEvent::FinalResponse {
449 content: done_message,
450 };
451 return;
452 }
453 }
454 }
455
456 Err::<(), AgentError>(AgentError::MaxIterationsReached {
457 max_iterations: self.config.max_iterations,
458 })?;
459 }
460 }
461
462 fn next_message_id(&mut self, role: AgentRole) -> String {
463 self.next_message_id += 1;
464 let role_label = match role {
465 AgentRole::User => "user",
466 AgentRole::Assistant => "assistant",
467 };
468 format!("msg_{}_{}", self.next_message_id, role_label)
469 }
470
471 fn resolve_tool_choice(&self, has_tools: bool) -> ModelToolChoice {
472 if !has_tools {
473 return ModelToolChoice::None;
474 }
475
476 match &self.config.tool_choice {
477 AgentToolChoice::Auto => ModelToolChoice::Auto,
478 AgentToolChoice::Required => ModelToolChoice::Required,
479 AgentToolChoice::None => ModelToolChoice::None,
480 AgentToolChoice::Tool(name) => ModelToolChoice::Tool(name.clone()),
481 }
482 }
483
484 async fn invoke_with_retry(
485 &self,
486 tool_definitions: &[ModelToolDefinition],
487 tool_choice: ModelToolChoice,
488 ) -> Result<ModelCompletion, AgentError> {
489 let max_retries = self.config.llm_max_retries.max(1);
490 for attempt in 0..max_retries {
491 match self
492 .model
493 .invoke(&self.history, tool_definitions, tool_choice.clone())
494 .await
495 {
496 Ok(completion) => return Ok(completion),
497 Err(err) => {
498 let should_retry =
499 is_retryable_provider_error(&err) && (attempt + 1) < max_retries;
500 if !should_retry {
501 return Err(AgentError::Provider(err));
502 }
503
504 let delay_ms = retry_delay_ms(
505 attempt,
506 self.config.llm_retry_base_delay_ms,
507 self.config.llm_retry_max_delay_ms,
508 );
509 sleep(Duration::from_millis(delay_ms)).await;
510 }
511 }
512 }
513
514 Err(AgentError::Config(
515 "retry loop failed unexpectedly".to_string(),
516 ))
517 }
518
519 fn append_assistant_message(&mut self, completion: &ModelCompletion) {
520 self.history.push(ModelMessage::Assistant {
521 content: completion.text.clone(),
522 tool_calls: completion.tool_calls.clone(),
523 });
524 }
525
526 async fn execute_tool_call(&self, tool_call: &ModelToolCall) -> ToolExecutionResult {
527 let Some(tool) = self.tool_map.get(&tool_call.name) else {
528 return ToolExecutionResult {
529 result_text: format!("Unknown tool '{}'.", tool_call.name),
530 is_error: true,
531 done_message: None,
532 };
533 };
534
535 let runtime_dependencies = self.dependencies.merged_with(&self.dependency_overrides);
536
537 match tool
538 .execute(tool_call.arguments.clone(), &runtime_dependencies)
539 .await
540 {
541 Ok(ToolOutcome::Text(text)) => ToolExecutionResult {
542 result_text: text,
543 is_error: false,
544 done_message: None,
545 },
546 Ok(ToolOutcome::Done(message)) => ToolExecutionResult {
547 result_text: format!("Task completed: {message}"),
548 is_error: false,
549 done_message: Some(message),
550 },
551 Err(err) => ToolExecutionResult {
552 result_text: format_tool_error(err),
553 is_error: true,
554 done_message: None,
555 },
556 }
557 }
558}
559
560fn is_retryable_provider_error(err: &ProviderError) -> bool {
561 match err {
562 ProviderError::Request(_) => true,
563 ProviderError::Response(_) => false,
564 }
565}
566
/// Exponential backoff delay, in milliseconds, for a zero-based `attempt`.
///
/// The result is `base_delay_ms * 2^attempt`, saturating at `u64::MAX`
/// instead of overflowing, and then capped at `max_delay_ms`. It is computed
/// in constant time, whatever the value of `attempt`.
fn retry_delay_ms(attempt: u32, base_delay_ms: u64, max_delay_ms: u64) -> u64 {
    // `checked_shl` yields `None` once the shift reaches 64 bits; from there
    // on the multiplier is treated as "infinitely large" so that the
    // saturating multiply below pins any non-zero base at `u64::MAX`.
    let multiplier = 1_u64.checked_shl(attempt).unwrap_or(u64::MAX);
    base_delay_ms.saturating_mul(multiplier).min(max_delay_ms)
}
574
575fn format_tool_error(err: ToolError) -> String {
576 err.to_string()
577}
578
/// Normalised outcome of one tool call, as consumed by the agent loop.
struct ToolExecutionResult {
    /// Text recorded in the history and reported to the caller.
    result_text: String,
    /// `true` when the tool was unknown or its execution failed.
    is_error: bool,
    /// Set when the tool signalled completion; its content becomes the final
    /// response of the query.
    done_message: Option<String>,
}
584
585pub async fn query(
586 agent: &mut Agent,
587 user_message: impl Into<String>,
588) -> Result<String, AgentError> {
589 agent.query(user_message).await
590}
591
592pub fn query_stream(
593 agent: &mut Agent,
594 user_message: impl Into<String>,
595) -> impl Stream<Item = Result<AgentEvent, AgentError>> + '_ {
596 agent.query_stream(user_message)
597}
598
599#[cfg(test)]
600mod tests;