1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Instant;
4
5use async_stream::try_stream;
6use futures_util::{Stream, StreamExt};
7use tokio::time::{Duration, sleep};
8
9use crate::error::{AgentError, ProviderError, ToolError};
10use crate::llm::{
11 ChatModel, ModelCompletion, ModelMessage, ModelToolCall, ModelToolChoice, ModelToolDefinition,
12};
13use crate::tools::{DependencyMap, ToolOutcome, ToolSpec};
14
/// Tool-selection policy configured on the agent.
///
/// Before each model invocation the agent translates this value into the
/// matching `ModelToolChoice` variant; when no tools are registered the agent
/// sends `ModelToolChoice::None` regardless of this setting.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AgentToolChoice {
    /// Forwarded to the model as `ModelToolChoice::Auto`.
    Auto,
    /// Forwarded to the model as `ModelToolChoice::Required`.
    Required,
    /// Forwarded to the model as `ModelToolChoice::None`.
    None,
    /// Forwarded to the model as `ModelToolChoice::Tool` carrying the given tool name.
    Tool(String),
}
27
28impl Default for AgentToolChoice {
29 fn default() -> Self {
30 Self::Auto
31 }
32}
33
34#[derive(Debug, Clone)]
35pub struct AgentConfig {
37 pub require_done_tool: bool,
39 pub max_iterations: u32,
41 pub system_prompt: Option<String>,
43 pub tool_choice: AgentToolChoice,
45 pub llm_max_retries: u32,
47 pub llm_retry_base_delay_ms: u64,
49 pub llm_retry_max_delay_ms: u64,
51 pub hidden_user_message_prompt: Option<String>,
53}
54
55impl Default for AgentConfig {
56 fn default() -> Self {
57 Self {
58 require_done_tool: false,
59 max_iterations: 24,
60 system_prompt: None,
61 tool_choice: AgentToolChoice::Auto,
62 llm_max_retries: 5,
63 llm_retry_base_delay_ms: 1_000,
64 llm_retry_max_delay_ms: 60_000,
65 hidden_user_message_prompt: None,
66 }
67 }
68}
69
/// Author of a message reported through `AgentEvent::MessageStart`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AgentRole {
    /// The message was supplied by the caller.
    User,
    /// The message was produced by the model.
    Assistant,
}
78
/// Outcome of a single tool-execution step.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StepStatus {
    /// The tool call finished without reporting an error.
    Completed,
    /// The tool call failed (unknown tool or tool error).
    Error,
}
87
88#[derive(Debug, Clone, PartialEq)]
89pub enum AgentEvent {
91 MessageStart {
93 message_id: String,
95 role: AgentRole,
97 },
98 MessageComplete {
100 message_id: String,
102 content: String,
104 },
105 HiddenUserMessage {
107 content: String,
109 },
110 StepStart {
112 step_id: String,
114 title: String,
116 step_number: u32,
118 },
119 StepComplete {
121 step_id: String,
123 status: StepStatus,
125 duration_ms: u128,
127 },
128 Thinking {
130 content: String,
132 },
133 Text {
135 content: String,
137 },
138 ToolCall {
140 tool: String,
142 args_json: serde_json::Value,
144 tool_call_id: String,
146 },
147 ToolResult {
149 tool: String,
151 result_text: String,
153 tool_call_id: String,
155 is_error: bool,
157 },
158 FinalResponse {
160 content: String,
162 },
163}
164
165pub struct AgentBuilder {
167 model: Option<Arc<dyn ChatModel>>,
168 tools: Vec<ToolSpec>,
169 config: AgentConfig,
170 dependencies: DependencyMap,
171 dependency_overrides: DependencyMap,
172}
173
174impl Default for AgentBuilder {
175 fn default() -> Self {
176 Self {
177 model: None,
178 tools: Vec::new(),
179 config: AgentConfig::default(),
180 dependencies: DependencyMap::new(),
181 dependency_overrides: DependencyMap::new(),
182 }
183 }
184}
185
186impl AgentBuilder {
187 pub fn model<M>(mut self, model: M) -> Self
189 where
190 M: ChatModel + 'static,
191 {
192 self.model = Some(Arc::new(model));
193 self
194 }
195
196 pub fn tool(mut self, tool: ToolSpec) -> Self {
198 self.tools.push(tool);
199 self
200 }
201
202 pub fn tools(mut self, tools: Vec<ToolSpec>) -> Self {
204 self.tools.extend(tools);
205 self
206 }
207
208 pub fn config(mut self, config: AgentConfig) -> Self {
210 self.config = config;
211 self
212 }
213
214 pub fn system_prompt(mut self, system_prompt: impl Into<String>) -> Self {
216 self.config.system_prompt = Some(system_prompt.into());
217 self
218 }
219
220 pub fn require_done_tool(mut self, require_done_tool: bool) -> Self {
222 self.config.require_done_tool = require_done_tool;
223 self
224 }
225
226 pub fn max_iterations(mut self, max_iterations: u32) -> Self {
228 self.config.max_iterations = max_iterations;
229 self
230 }
231
232 pub fn tool_choice(mut self, tool_choice: AgentToolChoice) -> Self {
234 self.config.tool_choice = tool_choice;
235 self
236 }
237
238 pub fn llm_retry_config(
240 mut self,
241 max_retries: u32,
242 base_delay_ms: u64,
243 max_delay_ms: u64,
244 ) -> Self {
245 self.config.llm_max_retries = max_retries;
246 self.config.llm_retry_base_delay_ms = base_delay_ms;
247 self.config.llm_retry_max_delay_ms = max_delay_ms;
248 self
249 }
250
251 pub fn hidden_user_message_prompt(mut self, prompt: impl Into<String>) -> Self {
253 self.config.hidden_user_message_prompt = Some(prompt.into());
254 self
255 }
256
257 pub fn dependency<T>(self, value: T) -> Self
259 where
260 T: Send + Sync + 'static,
261 {
262 self.dependencies.insert(value);
263 self
264 }
265
266 pub fn dependency_named<T>(self, key: impl Into<String>, value: T) -> Self
268 where
269 T: Send + Sync + 'static,
270 {
271 self.dependencies.insert_named(key, value);
272 self
273 }
274
275 pub fn dependency_override<T>(self, value: T) -> Self
277 where
278 T: Send + Sync + 'static,
279 {
280 self.dependency_overrides.insert(value);
281 self
282 }
283
284 pub fn dependency_override_named<T>(self, key: impl Into<String>, value: T) -> Self
286 where
287 T: Send + Sync + 'static,
288 {
289 self.dependency_overrides.insert_named(key, value);
290 self
291 }
292
293 pub fn build(self) -> Result<Agent, AgentError> {
295 let Some(model) = self.model else {
296 return Err(AgentError::Config(
297 "agent model must be configured via AgentBuilder::model(...)".to_string(),
298 ));
299 };
300
301 let mut tool_map = HashMap::new();
302 for tool in &self.tools {
303 if tool_map
304 .insert(tool.name().to_string(), tool.clone())
305 .is_some()
306 {
307 return Err(AgentError::Config(format!(
308 "duplicate tool registered: {}",
309 tool.name()
310 )));
311 }
312 }
313
314 Ok(Agent {
315 model,
316 tools: self.tools,
317 tool_map,
318 config: self.config,
319 dependencies: self.dependencies,
320 dependency_overrides: self.dependency_overrides,
321 history: Vec::new(),
322 next_message_id: 0,
323 })
324 }
325}
326
327pub struct Agent {
329 model: Arc<dyn ChatModel>,
330 tools: Vec<ToolSpec>,
331 tool_map: HashMap<String, ToolSpec>,
332 config: AgentConfig,
333 dependencies: DependencyMap,
334 dependency_overrides: DependencyMap,
335 history: Vec<ModelMessage>,
336 next_message_id: u64,
337}
338
339impl Agent {
340 pub fn builder() -> AgentBuilder {
342 AgentBuilder::default()
343 }
344
345 pub fn clear_history(&mut self) {
347 self.history.clear();
348 self.next_message_id = 0;
349 }
350
351 pub fn load_history(&mut self, messages: Vec<ModelMessage>) {
353 self.next_message_id = messages.len() as u64;
354 self.history = messages;
355 }
356
357 pub fn messages_len(&self) -> usize {
359 self.history.len()
360 }
361
362 pub fn messages(&self) -> &[ModelMessage] {
364 &self.history
365 }
366
367 pub async fn query(&mut self, user_message: impl Into<String>) -> Result<String, AgentError> {
369 let stream = self.query_stream(user_message);
370 futures_util::pin_mut!(stream);
371
372 let mut final_response: Option<String> = None;
373
374 while let Some(event) = stream.next().await {
375 match event? {
376 AgentEvent::FinalResponse { content } => final_response = Some(content),
377 AgentEvent::MessageStart { .. }
378 | AgentEvent::MessageComplete { .. }
379 | AgentEvent::HiddenUserMessage { .. }
380 | AgentEvent::StepStart { .. }
381 | AgentEvent::StepComplete { .. }
382 | AgentEvent::Thinking { .. }
383 | AgentEvent::Text { .. }
384 | AgentEvent::ToolCall { .. }
385 | AgentEvent::ToolResult { .. } => {}
386 }
387 }
388
389 final_response.ok_or(AgentError::MissingFinalResponse)
390 }
391
392 pub fn query_stream(
394 &mut self,
395 user_message: impl Into<String>,
396 ) -> impl Stream<Item = Result<AgentEvent, AgentError>> + '_ {
397 let user_message = user_message.into();
398
399 try_stream! {
400 if self.history.is_empty() {
401 if let Some(system_prompt) = &self.config.system_prompt {
402 self.history.push(ModelMessage::System(system_prompt.clone()));
403 }
404 }
405
406 let user_message_id = self.next_message_id(AgentRole::User);
407 yield AgentEvent::MessageStart {
408 message_id: user_message_id.clone(),
409 role: AgentRole::User,
410 };
411 self.history.push(ModelMessage::User(user_message.clone()));
412 yield AgentEvent::MessageComplete {
413 message_id: user_message_id,
414 content: user_message,
415 };
416
417 let tool_definitions = self
418 .tools
419 .iter()
420 .map(|tool| ModelToolDefinition {
421 name: tool.name().to_string(),
422 description: tool.description().to_string(),
423 parameters: tool.json_schema().clone(),
424 })
425 .collect::<Vec<_>>();
426
427 let tool_choice = self.resolve_tool_choice(!tool_definitions.is_empty());
428 let mut hidden_prompt_injected = false;
429
430 for _ in 0..self.config.max_iterations {
431 let completion = self
432 .invoke_with_retry(&tool_definitions, tool_choice.clone())
433 .await?;
434
435 let assistant_message_id = self.next_message_id(AgentRole::Assistant);
436 yield AgentEvent::MessageStart {
437 message_id: assistant_message_id.clone(),
438 role: AgentRole::Assistant,
439 };
440
441 if let Some(thinking) = completion.thinking.clone() {
442 yield AgentEvent::Thinking { content: thinking };
443 }
444
445 self.append_assistant_message(&completion);
446
447 if let Some(text) = completion.text.clone() {
448 if !text.is_empty() {
449 yield AgentEvent::Text {
450 content: text.clone(),
451 };
452 }
453 }
454
455 let assistant_content = completion.text.clone().unwrap_or_default();
456 yield AgentEvent::MessageComplete {
457 message_id: assistant_message_id,
458 content: assistant_content.clone(),
459 };
460
461 if completion.tool_calls.is_empty() {
462 if !self.config.require_done_tool {
463 if !hidden_prompt_injected {
464 if let Some(hidden_prompt) = self.config.hidden_user_message_prompt.clone() {
465 hidden_prompt_injected = true;
466 self.history.push(ModelMessage::User(hidden_prompt.clone()));
467 yield AgentEvent::HiddenUserMessage {
468 content: hidden_prompt,
469 };
470 continue;
471 }
472 }
473
474 yield AgentEvent::FinalResponse {
475 content: completion.text.unwrap_or_default(),
476 };
477 return;
478 }
479 continue;
480 }
481
482 let mut step_number = 0_u32;
483 for tool_call in completion.tool_calls {
484 step_number += 1;
485 yield AgentEvent::StepStart {
486 step_id: tool_call.id.clone(),
487 title: tool_call.name.clone(),
488 step_number,
489 };
490
491 yield AgentEvent::ToolCall {
492 tool: tool_call.name.clone(),
493 args_json: tool_call.arguments.clone(),
494 tool_call_id: tool_call.id.clone(),
495 };
496
497 let step_start = Instant::now();
498 let execution = self.execute_tool_call(&tool_call).await;
499 self.history.push(ModelMessage::ToolResult {
500 tool_call_id: tool_call.id.clone(),
501 tool_name: tool_call.name.clone(),
502 content: execution.result_text.clone(),
503 is_error: execution.is_error,
504 });
505
506 yield AgentEvent::ToolResult {
507 tool: tool_call.name.clone(),
508 result_text: execution.result_text.clone(),
509 tool_call_id: tool_call.id.clone(),
510 is_error: execution.is_error,
511 };
512
513 yield AgentEvent::StepComplete {
514 step_id: tool_call.id.clone(),
515 status: if execution.is_error {
516 StepStatus::Error
517 } else {
518 StepStatus::Completed
519 },
520 duration_ms: step_start.elapsed().as_millis(),
521 };
522
523 if let Some(done_message) = execution.done_message {
524 yield AgentEvent::FinalResponse {
525 content: done_message,
526 };
527 return;
528 }
529 }
530 }
531
532 Err::<(), AgentError>(AgentError::MaxIterationsReached {
533 max_iterations: self.config.max_iterations,
534 })?;
535 }
536 }
537
538 fn next_message_id(&mut self, role: AgentRole) -> String {
539 self.next_message_id += 1;
540 let role_label = match role {
541 AgentRole::User => "user",
542 AgentRole::Assistant => "assistant",
543 };
544 format!("msg_{}_{}", self.next_message_id, role_label)
545 }
546
547 fn resolve_tool_choice(&self, has_tools: bool) -> ModelToolChoice {
548 if !has_tools {
549 return ModelToolChoice::None;
550 }
551
552 match &self.config.tool_choice {
553 AgentToolChoice::Auto => ModelToolChoice::Auto,
554 AgentToolChoice::Required => ModelToolChoice::Required,
555 AgentToolChoice::None => ModelToolChoice::None,
556 AgentToolChoice::Tool(name) => ModelToolChoice::Tool(name.clone()),
557 }
558 }
559
560 async fn invoke_with_retry(
561 &self,
562 tool_definitions: &[ModelToolDefinition],
563 tool_choice: ModelToolChoice,
564 ) -> Result<ModelCompletion, AgentError> {
565 let max_retries = self.config.llm_max_retries.max(1);
566 for attempt in 0..max_retries {
567 match self
568 .model
569 .invoke(&self.history, tool_definitions, tool_choice.clone())
570 .await
571 {
572 Ok(completion) => return Ok(completion),
573 Err(err) => {
574 let should_retry =
575 is_retryable_provider_error(&err) && (attempt + 1) < max_retries;
576 if !should_retry {
577 return Err(AgentError::Provider(err));
578 }
579
580 let delay_ms = retry_delay_ms(
581 attempt,
582 self.config.llm_retry_base_delay_ms,
583 self.config.llm_retry_max_delay_ms,
584 );
585 sleep(Duration::from_millis(delay_ms)).await;
586 }
587 }
588 }
589
590 Err(AgentError::Config(
591 "retry loop failed unexpectedly".to_string(),
592 ))
593 }
594
595 fn append_assistant_message(&mut self, completion: &ModelCompletion) {
596 self.history.push(ModelMessage::Assistant {
597 content: completion.text.clone(),
598 tool_calls: completion.tool_calls.clone(),
599 });
600 }
601
602 async fn execute_tool_call(&self, tool_call: &ModelToolCall) -> ToolExecutionResult {
603 let Some(tool) = self.tool_map.get(&tool_call.name) else {
604 return ToolExecutionResult {
605 result_text: format!("Unknown tool '{}'.", tool_call.name),
606 is_error: true,
607 done_message: None,
608 };
609 };
610
611 let runtime_dependencies = self.dependencies.merged_with(&self.dependency_overrides);
612
613 match tool
614 .execute(tool_call.arguments.clone(), &runtime_dependencies)
615 .await
616 {
617 Ok(ToolOutcome::Text(text)) => ToolExecutionResult {
618 result_text: text,
619 is_error: false,
620 done_message: None,
621 },
622 Ok(ToolOutcome::Done(message)) => ToolExecutionResult {
623 result_text: format!("Task completed: {message}"),
624 is_error: false,
625 done_message: Some(message),
626 },
627 Err(err) => ToolExecutionResult {
628 result_text: format_tool_error(err),
629 is_error: true,
630 done_message: None,
631 },
632 }
633 }
634}
635
636fn is_retryable_provider_error(err: &ProviderError) -> bool {
637 match err {
638 ProviderError::Request(_) => true,
639 ProviderError::Response(_) => false,
640 }
641}
642
/// Computes the exponential back-off delay for a retry, in milliseconds.
///
/// The result is `base_delay_ms * 2^attempt`, saturating at `u64::MAX` and
/// then capped at `max_delay_ms`. `attempt` is zero-based, so the first retry
/// waits `base_delay_ms` (subject to the cap).
///
/// Runs in constant time for any `attempt`, rather than doubling once per
/// attempt.
fn retry_delay_ms(attempt: u32, base_delay_ms: u64, max_delay_ms: u64) -> u64 {
    let scaled = match 1_u64.checked_shl(attempt) {
        // 2^attempt fits in a u64: a single saturating multiply matches
        // repeated saturating doubling.
        Some(factor) => base_delay_ms.saturating_mul(factor),
        // 2^attempt overflows: doubling zero any number of times stays zero...
        None if base_delay_ms == 0 => 0,
        // ...while any non-zero base saturates.
        None => u64::MAX,
    };
    scaled.min(max_delay_ms)
}
650
651fn format_tool_error(err: ToolError) -> String {
652 err.to_string()
653}
654
/// Outcome of executing a single tool call, as consumed by the agent loop.
struct ToolExecutionResult {
    /// Text recorded in the history and reported as the tool's result.
    result_text: String,
    /// `true` when the tool was unknown or its execution failed.
    is_error: bool,
    /// Final answer supplied by a tool that reported completion; `None` otherwise.
    done_message: Option<String>,
}
660
661pub async fn query(
663 agent: &mut Agent,
664 user_message: impl Into<String>,
665) -> Result<String, AgentError> {
666 agent.query(user_message).await
667}
668
669pub fn query_stream(
671 agent: &mut Agent,
672 user_message: impl Into<String>,
673) -> impl Stream<Item = Result<AgentEvent, AgentError>> + '_ {
674 agent.query_stream(user_message)
675}
676
677#[cfg(test)]
678mod tests;