agent_sdk/
agent_loop.rs

1//! Agent loop orchestration module.
2//!
3//! This module contains the core agent loop that orchestrates LLM calls,
4//! tool execution, and event handling. The agent loop is the main entry point
5//! for running an AI agent.
6//!
7//! # Architecture
8//!
9//! The agent loop works as follows:
10//! 1. Receives a user message
11//! 2. Sends the message to the LLM provider
12//! 3. Processes the LLM response (text or tool calls)
13//! 4. If tool calls are present, executes them and feeds results back to LLM
14//! 5. Repeats until the LLM responds with only text (no tool calls)
15//! 6. Emits events throughout for real-time UI updates
16//!
17//! # Building an Agent
18//!
19//! Use the builder pattern via [`builder()`] or [`AgentLoopBuilder`]:
20//!
21//! ```ignore
22//! use agent_sdk::{builder, providers::AnthropicProvider};
23//!
24//! let agent = builder()
25//!     .provider(AnthropicProvider::sonnet(api_key))
26//!     .tools(my_tools)
27//!     .build();
28//! ```
29
30use crate::context::{CompactionConfig, ContextCompactor, LlmContextCompactor};
31use crate::events::AgentEvent;
32use crate::hooks::{AgentHooks, DefaultHooks, ToolDecision};
33use crate::llm::{
34    ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, Message, Role,
35    StopReason,
36};
37use crate::skills::Skill;
38use crate::stores::{InMemoryStore, MessageStore, StateStore};
39use crate::tools::{ToolContext, ToolRegistry};
40use crate::types::{AgentConfig, AgentState, RetryConfig, ThreadId, TokenUsage, ToolResult};
41use anyhow::Result;
42use std::sync::Arc;
43use std::time::{Duration, Instant};
44use tokio::sync::mpsc;
45use tokio::time::sleep;
46use tracing::{debug, error, info, warn};
47
48/// Builder for constructing an `AgentLoop`.
49///
50/// # Example
51///
52/// ```ignore
53/// let agent = AgentLoop::builder()
54///     .provider(my_provider)
55///     .tools(my_tools)
56///     .config(AgentConfig::default())
57///     .build();
58/// ```
59pub struct AgentLoopBuilder<Ctx, P, H, M, S> {
60    provider: Option<P>,
61    tools: Option<ToolRegistry<Ctx>>,
62    hooks: Option<H>,
63    message_store: Option<M>,
64    state_store: Option<S>,
65    config: Option<AgentConfig>,
66    compaction_config: Option<CompactionConfig>,
67}
68
69impl<Ctx> AgentLoopBuilder<Ctx, (), (), (), ()> {
70    /// Create a new builder with no components set.
71    #[must_use]
72    pub const fn new() -> Self {
73        Self {
74            provider: None,
75            tools: None,
76            hooks: None,
77            message_store: None,
78            state_store: None,
79            config: None,
80            compaction_config: None,
81        }
82    }
83}
84
85impl<Ctx> Default for AgentLoopBuilder<Ctx, (), (), (), ()> {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S> {
92    /// Set the LLM provider.
93    #[must_use]
94    pub fn provider<P2: LlmProvider>(self, provider: P2) -> AgentLoopBuilder<Ctx, P2, H, M, S> {
95        AgentLoopBuilder {
96            provider: Some(provider),
97            tools: self.tools,
98            hooks: self.hooks,
99            message_store: self.message_store,
100            state_store: self.state_store,
101            config: self.config,
102            compaction_config: self.compaction_config,
103        }
104    }
105
106    /// Set the tool registry.
107    #[must_use]
108    pub fn tools(mut self, tools: ToolRegistry<Ctx>) -> Self {
109        self.tools = Some(tools);
110        self
111    }
112
113    /// Set the agent hooks.
114    #[must_use]
115    pub fn hooks<H2: AgentHooks>(self, hooks: H2) -> AgentLoopBuilder<Ctx, P, H2, M, S> {
116        AgentLoopBuilder {
117            provider: self.provider,
118            tools: self.tools,
119            hooks: Some(hooks),
120            message_store: self.message_store,
121            state_store: self.state_store,
122            config: self.config,
123            compaction_config: self.compaction_config,
124        }
125    }
126
127    /// Set the message store.
128    #[must_use]
129    pub fn message_store<M2: MessageStore>(
130        self,
131        message_store: M2,
132    ) -> AgentLoopBuilder<Ctx, P, H, M2, S> {
133        AgentLoopBuilder {
134            provider: self.provider,
135            tools: self.tools,
136            hooks: self.hooks,
137            message_store: Some(message_store),
138            state_store: self.state_store,
139            config: self.config,
140            compaction_config: self.compaction_config,
141        }
142    }
143
144    /// Set the state store.
145    #[must_use]
146    pub fn state_store<S2: StateStore>(
147        self,
148        state_store: S2,
149    ) -> AgentLoopBuilder<Ctx, P, H, M, S2> {
150        AgentLoopBuilder {
151            provider: self.provider,
152            tools: self.tools,
153            hooks: self.hooks,
154            message_store: self.message_store,
155            state_store: Some(state_store),
156            config: self.config,
157            compaction_config: self.compaction_config,
158        }
159    }
160
161    /// Set the agent configuration.
162    #[must_use]
163    pub fn config(mut self, config: AgentConfig) -> Self {
164        self.config = Some(config);
165        self
166    }
167
168    /// Enable context compaction with the given configuration.
169    ///
170    /// When enabled, the agent will automatically compact conversation history
171    /// when it exceeds the configured token threshold.
172    ///
173    /// # Example
174    ///
175    /// ```ignore
176    /// use agent_sdk::{builder, context::CompactionConfig};
177    ///
178    /// let agent = builder()
179    ///     .provider(my_provider)
180    ///     .with_compaction(CompactionConfig::default())
181    ///     .build();
182    /// ```
183    #[must_use]
184    pub const fn with_compaction(mut self, config: CompactionConfig) -> Self {
185        self.compaction_config = Some(config);
186        self
187    }
188
189    /// Enable context compaction with default settings.
190    ///
191    /// This is a convenience method equivalent to:
192    /// ```ignore
193    /// builder.with_compaction(CompactionConfig::default())
194    /// ```
195    #[must_use]
196    pub fn with_auto_compaction(self) -> Self {
197        self.with_compaction(CompactionConfig::default())
198    }
199
200    /// Apply a skill configuration.
201    ///
202    /// This merges the skill's system prompt with the existing configuration
203    /// and filters tools based on the skill's allowed/denied lists.
204    ///
205    /// # Example
206    ///
207    /// ```ignore
208    /// let skill = Skill::new("code-review", "You are a code reviewer...")
209    ///     .with_denied_tools(vec!["bash".into()]);
210    ///
211    /// let agent = builder()
212    ///     .provider(provider)
213    ///     .tools(tools)
214    ///     .with_skill(skill)
215    ///     .build();
216    /// ```
217    #[must_use]
218    pub fn with_skill(mut self, skill: Skill) -> Self
219    where
220        Ctx: Send + Sync + 'static,
221    {
222        // Filter tools based on skill configuration first (before moving skill)
223        if let Some(ref mut tools) = self.tools {
224            tools.filter(|name| skill.is_tool_allowed(name));
225        }
226
227        // Merge system prompt
228        let mut config = self.config.take().unwrap_or_default();
229        if config.system_prompt.is_empty() {
230            config.system_prompt = skill.system_prompt;
231        } else {
232            config.system_prompt = format!("{}\n\n{}", config.system_prompt, skill.system_prompt);
233        }
234        self.config = Some(config);
235
236        self
237    }
238}
239
240impl<Ctx, P> AgentLoopBuilder<Ctx, P, (), (), ()>
241where
242    Ctx: Send + Sync + 'static,
243    P: LlmProvider + 'static,
244{
245    /// Build the agent loop with default hooks and in-memory stores.
246    ///
247    /// This is a convenience method that uses:
248    /// - `DefaultHooks` for hooks
249    /// - `InMemoryStore` for message store
250    /// - `InMemoryStore` for state store
251    /// - `AgentConfig::default()` if no config is set
252    ///
253    /// # Panics
254    ///
255    /// Panics if a provider has not been set.
256    #[must_use]
257    pub fn build(self) -> AgentLoop<Ctx, P, DefaultHooks, InMemoryStore, InMemoryStore> {
258        let provider = self.provider.expect("provider is required");
259        let tools = self.tools.unwrap_or_default();
260        let config = self.config.unwrap_or_default();
261
262        AgentLoop {
263            provider: Arc::new(provider),
264            tools: Arc::new(tools),
265            hooks: Arc::new(DefaultHooks),
266            message_store: Arc::new(InMemoryStore::new()),
267            state_store: Arc::new(InMemoryStore::new()),
268            config,
269            compaction_config: self.compaction_config,
270        }
271    }
272}
273
274impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S>
275where
276    Ctx: Send + Sync + 'static,
277    P: LlmProvider + 'static,
278    H: AgentHooks + 'static,
279    M: MessageStore + 'static,
280    S: StateStore + 'static,
281{
282    /// Build the agent loop with all custom components.
283    ///
284    /// # Panics
285    ///
286    /// Panics if any of the following have not been set:
287    /// - `provider`
288    /// - `hooks`
289    /// - `message_store`
290    /// - `state_store`
291    #[must_use]
292    pub fn build_with_stores(self) -> AgentLoop<Ctx, P, H, M, S> {
293        let provider = self.provider.expect("provider is required");
294        let tools = self.tools.unwrap_or_default();
295        let hooks = self
296            .hooks
297            .expect("hooks is required when using build_with_stores");
298        let message_store = self
299            .message_store
300            .expect("message_store is required when using build_with_stores");
301        let state_store = self
302            .state_store
303            .expect("state_store is required when using build_with_stores");
304        let config = self.config.unwrap_or_default();
305
306        AgentLoop {
307            provider: Arc::new(provider),
308            tools: Arc::new(tools),
309            hooks: Arc::new(hooks),
310            message_store: Arc::new(message_store),
311            state_store: Arc::new(state_store),
312            config,
313            compaction_config: self.compaction_config,
314        }
315    }
316}
317
318/// The main agent loop that orchestrates LLM calls and tool execution.
319///
320/// `AgentLoop` is the core component that:
321/// - Manages conversation state via message and state stores
322/// - Calls the LLM provider and processes responses
323/// - Executes tools through the tool registry
324/// - Emits events for real-time updates via an async channel
325/// - Enforces hooks for tool permissions and lifecycle events
326///
327/// # Type Parameters
328///
329/// - `Ctx`: Application-specific context passed to tools (e.g., user ID, database)
330/// - `P`: The LLM provider implementation
331/// - `H`: The hooks implementation for lifecycle customization
332/// - `M`: The message store implementation
333/// - `S`: The state store implementation
334///
335/// # Running the Agent
336///
337/// ```ignore
338/// let mut events = agent.run(thread_id, "Hello!".to_string(), tool_ctx);
339/// while let Some(event) = events.recv().await {
340///     match event {
341///         AgentEvent::Text { text } => println!("{}", text),
342///         AgentEvent::Done { .. } => break,
343///         _ => {}
344///     }
345/// }
346/// ```
347pub struct AgentLoop<Ctx, P, H, M, S>
348where
349    P: LlmProvider,
350    H: AgentHooks,
351    M: MessageStore,
352    S: StateStore,
353{
354    provider: Arc<P>,
355    tools: Arc<ToolRegistry<Ctx>>,
356    hooks: Arc<H>,
357    message_store: Arc<M>,
358    state_store: Arc<S>,
359    config: AgentConfig,
360    compaction_config: Option<CompactionConfig>,
361}
362
363/// Create a new builder for constructing an `AgentLoop`.
364#[must_use]
365pub const fn builder<Ctx>() -> AgentLoopBuilder<Ctx, (), (), (), ()> {
366    AgentLoopBuilder::new()
367}
368
369impl<Ctx, P, H, M, S> AgentLoop<Ctx, P, H, M, S>
370where
371    Ctx: Send + Sync + 'static,
372    P: LlmProvider + 'static,
373    H: AgentHooks + 'static,
374    M: MessageStore + 'static,
375    S: StateStore + 'static,
376{
377    /// Create a new agent loop with all components specified directly.
378    #[must_use]
379    pub fn new(
380        provider: P,
381        tools: ToolRegistry<Ctx>,
382        hooks: H,
383        message_store: M,
384        state_store: S,
385        config: AgentConfig,
386    ) -> Self {
387        Self {
388            provider: Arc::new(provider),
389            tools: Arc::new(tools),
390            hooks: Arc::new(hooks),
391            message_store: Arc::new(message_store),
392            state_store: Arc::new(state_store),
393            config,
394            compaction_config: None,
395        }
396    }
397
398    /// Create a new agent loop with compaction enabled.
399    #[must_use]
400    pub fn with_compaction(
401        provider: P,
402        tools: ToolRegistry<Ctx>,
403        hooks: H,
404        message_store: M,
405        state_store: S,
406        config: AgentConfig,
407        compaction_config: CompactionConfig,
408    ) -> Self {
409        Self {
410            provider: Arc::new(provider),
411            tools: Arc::new(tools),
412            hooks: Arc::new(hooks),
413            message_store: Arc::new(message_store),
414            state_store: Arc::new(state_store),
415            config,
416            compaction_config: Some(compaction_config),
417        }
418    }
419
420    /// Run the agent loop for a single user message.
421    /// Returns a channel receiver that yields `AgentEvents`.
422    pub fn run(
423        &self,
424        thread_id: ThreadId,
425        user_message: String,
426        tool_context: ToolContext<Ctx>,
427    ) -> mpsc::Receiver<AgentEvent>
428    where
429        Ctx: Clone,
430    {
431        let (tx, rx) = mpsc::channel(100);
432
433        let provider = Arc::clone(&self.provider);
434        let tools = Arc::clone(&self.tools);
435        let hooks = Arc::clone(&self.hooks);
436        let message_store = Arc::clone(&self.message_store);
437        let state_store = Arc::clone(&self.state_store);
438        let config = self.config.clone();
439        let compaction_config = self.compaction_config.clone();
440
441        tokio::spawn(async move {
442            let result = run_loop(
443                tx.clone(),
444                thread_id,
445                user_message,
446                tool_context,
447                provider,
448                tools,
449                hooks,
450                message_store,
451                state_store,
452                config,
453                compaction_config,
454            )
455            .await;
456
457            if let Err(e) = result {
458                let _ = tx.send(AgentEvent::error(e.to_string(), false)).await;
459            }
460        });
461
462        rx
463    }
464}
465
466#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
467async fn run_loop<Ctx, P, H, M, S>(
468    tx: mpsc::Sender<AgentEvent>,
469    thread_id: ThreadId,
470    user_message: String,
471    tool_context: ToolContext<Ctx>,
472    provider: Arc<P>,
473    tools: Arc<ToolRegistry<Ctx>>,
474    hooks: Arc<H>,
475    message_store: Arc<M>,
476    state_store: Arc<S>,
477    config: AgentConfig,
478    compaction_config: Option<CompactionConfig>,
479) -> Result<()>
480where
481    Ctx: Send + Sync + Clone + 'static,
482    P: LlmProvider,
483    H: AgentHooks,
484    M: MessageStore,
485    S: StateStore,
486{
487    // Add event channel to tool context so tools can emit events
488    let tool_context = tool_context.with_event_tx(tx.clone());
489
490    let start_time = Instant::now();
491    let mut turn = 0;
492    let mut total_usage = TokenUsage::default();
493
494    // Load or create state
495    let mut state = state_store
496        .load(&thread_id)
497        .await?
498        .unwrap_or_else(|| AgentState::new(thread_id.clone()));
499
500    // Add user message to history
501    let user_msg = Message::user(&user_message);
502    message_store.append(&thread_id, user_msg).await?;
503
504    // Main agent loop
505    loop {
506        turn += 1;
507        state.turn_count = turn;
508
509        if turn > config.max_turns {
510            warn!(turn, max = config.max_turns, "Max turns reached");
511            tx.send(AgentEvent::error(
512                format!("Maximum turns ({}) reached", config.max_turns),
513                true,
514            ))
515            .await?;
516            break;
517        }
518
519        // Emit start event
520        tx.send(AgentEvent::start(thread_id.clone(), turn)).await?;
521        hooks
522            .on_event(&AgentEvent::start(thread_id.clone(), turn))
523            .await;
524
525        // Get message history
526        let mut messages = message_store.get_history(&thread_id).await?;
527
528        // Check if compaction is needed
529        if let Some(ref compact_config) = compaction_config {
530            let compactor = LlmContextCompactor::new(Arc::clone(&provider), compact_config.clone());
531            if compactor.needs_compaction(&messages) {
532                debug!(
533                    turn,
534                    message_count = messages.len(),
535                    "Context compaction triggered"
536                );
537
538                match compactor.compact_history(messages).await {
539                    Ok(result) => {
540                        // Replace history in store
541                        message_store
542                            .replace_history(&thread_id, result.messages.clone())
543                            .await?;
544
545                        // Emit compaction event
546                        tx.send(AgentEvent::context_compacted(
547                            result.original_count,
548                            result.new_count,
549                            result.original_tokens,
550                            result.new_tokens,
551                        ))
552                        .await?;
553
554                        info!(
555                            original_count = result.original_count,
556                            new_count = result.new_count,
557                            original_tokens = result.original_tokens,
558                            new_tokens = result.new_tokens,
559                            "Context compacted successfully"
560                        );
561
562                        // Use the compacted messages
563                        messages = result.messages;
564                    }
565                    Err(e) => {
566                        warn!(error = %e, "Context compaction failed, continuing with full history");
567                        // Continue with original messages on failure
568                        messages = message_store.get_history(&thread_id).await?;
569                    }
570                }
571            }
572        }
573
574        // Build chat request
575        let llm_tools = if tools.is_empty() {
576            None
577        } else {
578            Some(tools.to_llm_tools())
579        };
580
581        let request = ChatRequest {
582            system: config.system_prompt.clone(),
583            messages,
584            tools: llm_tools,
585            max_tokens: config.max_tokens,
586        };
587
588        // Call LLM with retry logic for transient errors
589        debug!(turn, "Calling LLM");
590        let max_retries = config.retry.max_retries;
591        let response = {
592            let mut attempt = 0u32;
593            loop {
594                let outcome = provider.chat(request.clone()).await?;
595                match outcome {
596                    ChatOutcome::Success(response) => break Some(response),
597                    ChatOutcome::RateLimited => {
598                        attempt += 1;
599                        if attempt > max_retries {
600                            error!("Rate limited by LLM provider after {max_retries} retries");
601                            tx.send(AgentEvent::error(
602                                format!("Rate limited after {max_retries} retries"),
603                                true,
604                            ))
605                            .await?;
606                            break None;
607                        }
608                        let delay = calculate_backoff_delay(attempt, &config.retry);
609                        warn!(
610                            attempt,
611                            delay_ms = delay.as_millis(),
612                            "Rate limited, retrying after backoff"
613                        );
614                        tx.send(AgentEvent::text(format!(
615                            "\n[Rate limited, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
616                            delay.as_secs_f64()
617                        )))
618                        .await?;
619                        sleep(delay).await;
620                    }
621                    ChatOutcome::InvalidRequest(msg) => {
622                        error!(msg, "Invalid request to LLM");
623                        tx.send(AgentEvent::error(format!("Invalid request: {msg}"), false))
624                            .await?;
625                        break None;
626                    }
627                    ChatOutcome::ServerError(msg) => {
628                        attempt += 1;
629                        if attempt > max_retries {
630                            error!(msg, "LLM server error after {max_retries} retries");
631                            tx.send(AgentEvent::error(
632                                format!("Server error after {max_retries} retries: {msg}"),
633                                true,
634                            ))
635                            .await?;
636                            break None;
637                        }
638                        let delay = calculate_backoff_delay(attempt, &config.retry);
639                        warn!(
640                            attempt,
641                            delay_ms = delay.as_millis(),
642                            error = msg,
643                            "Server error, retrying after backoff"
644                        );
645                        tx.send(AgentEvent::text(format!(
646                            "\n[Server error: {msg}, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
647                            delay.as_secs_f64()
648                        )))
649                        .await?;
650                        sleep(delay).await;
651                    }
652                }
653            }
654        };
655
656        // If we failed to get a response after retries, exit the loop
657        let Some(response) = response else {
658            break;
659        };
660
661        // Track usage
662        let turn_usage = TokenUsage {
663            input_tokens: response.usage.input_tokens,
664            output_tokens: response.usage.output_tokens,
665        };
666        total_usage.add(&turn_usage);
667        state.total_usage = total_usage.clone();
668
669        // Process response content
670        let (text_content, tool_uses) = extract_content(&response);
671
672        // Emit text if present
673        if let Some(text) = &text_content {
674            tx.send(AgentEvent::text(text.clone())).await?;
675            hooks.on_event(&AgentEvent::text(text.clone())).await;
676        }
677
678        // If no tool uses, we're done
679        if tool_uses.is_empty() {
680            info!(turn, "Agent completed (no tool use)");
681            break;
682        }
683
684        // Store assistant message with tool uses
685        let assistant_msg = build_assistant_message(&response);
686        message_store.append(&thread_id, assistant_msg).await?;
687
688        // Execute tools
689        let mut tool_results = Vec::new();
690        for (tool_id, tool_name, tool_input) in &tool_uses {
691            let Some(tool) = tools.get(tool_name) else {
692                let result = ToolResult::error(format!("Unknown tool: {tool_name}"));
693                tool_results.push((tool_id.clone(), result));
694                continue;
695            };
696
697            let tier = tool.tier();
698
699            // Emit tool call start
700            tx.send(AgentEvent::tool_call_start(
701                tool_id,
702                tool_name,
703                tool_input.clone(),
704                tier,
705            ))
706            .await?;
707
708            // Check hooks for permission
709            let decision = hooks.pre_tool_use(tool_name, tool_input, tier).await;
710
711            match decision {
712                ToolDecision::Allow => {
713                    // Execute tool
714                    let tool_start = Instant::now();
715                    let result = match tool.execute(&tool_context, tool_input.clone()).await {
716                        Ok(mut r) => {
717                            r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
718                            r
719                        }
720                        Err(e) => ToolResult::error(format!("Tool error: {e}"))
721                            .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
722                    };
723
724                    // Post-tool hook
725                    hooks.post_tool_use(tool_name, &result).await;
726
727                    // Emit tool call end
728                    tx.send(AgentEvent::tool_call_end(
729                        tool_id,
730                        tool_name,
731                        result.clone(),
732                    ))
733                    .await?;
734
735                    tool_results.push((tool_id.clone(), result));
736                }
737                ToolDecision::Block(reason) => {
738                    let result = ToolResult::error(format!("Blocked: {reason}"));
739                    tx.send(AgentEvent::tool_call_end(
740                        tool_id,
741                        tool_name,
742                        result.clone(),
743                    ))
744                    .await?;
745                    tool_results.push((tool_id.clone(), result));
746                }
747                ToolDecision::RequiresConfirmation(description) => {
748                    tx.send(AgentEvent::ToolRequiresConfirmation {
749                        id: tool_id.clone(),
750                        name: tool_name.clone(),
751                        input: tool_input.clone(),
752                        description,
753                    })
754                    .await?;
755                    // For now, treat as blocked - caller should handle confirmation flow
756                    let result = ToolResult::error("Awaiting user confirmation");
757                    tool_results.push((tool_id.clone(), result));
758                }
759                ToolDecision::RequiresPin(description) => {
760                    tx.send(AgentEvent::ToolRequiresPin {
761                        id: tool_id.clone(),
762                        name: tool_name.clone(),
763                        input: tool_input.clone(),
764                        description,
765                    })
766                    .await?;
767                    // For now, treat as blocked - caller should handle PIN flow
768                    let result = ToolResult::error("Awaiting PIN verification");
769                    tool_results.push((tool_id.clone(), result));
770                }
771            }
772        }
773
774        // Add tool results to message history
775        for (tool_id, result) in &tool_results {
776            let tool_result_msg = Message::tool_result(tool_id, &result.output, !result.success);
777            message_store.append(&thread_id, tool_result_msg).await?;
778        }
779
780        // Emit turn complete
781        tx.send(AgentEvent::TurnComplete {
782            turn,
783            usage: turn_usage,
784        })
785        .await?;
786
787        // Check stop reason
788        if response.stop_reason == Some(StopReason::EndTurn) {
789            info!(turn, "Agent completed (end_turn)");
790            break;
791        }
792
793        // Save state checkpoint
794        state_store.save(&state).await?;
795    }
796
797    // Final state save
798    state_store.save(&state).await?;
799
800    // Emit done
801    let duration = start_time.elapsed();
802    tx.send(AgentEvent::done(thread_id, turn, total_usage, duration))
803        .await?;
804
805    Ok(())
806}
807
/// Narrow a `u128` millisecond count (as produced by `Duration::as_millis`)
/// to `u64`, saturating at `u64::MAX` instead of truncating.
#[allow(clippy::cast_possible_truncation)] // the cast only runs after the range check
const fn millis_to_u64(millis: u128) -> u64 {
    const LIMIT: u128 = u64::MAX as u128;
    if millis <= LIMIT {
        millis as u64
    } else {
        u64::MAX
    }
}
817
818/// Calculate exponential backoff delay with jitter.
819///
820/// Uses exponential backoff with the formula: `base * 2^(attempt-1) + jitter`,
821/// capped at the maximum delay. Jitter (0-1000ms) helps avoid thundering herd.
822fn calculate_backoff_delay(attempt: u32, config: &RetryConfig) -> Duration {
823    // Exponential backoff: base, base*2, base*4, base*8, ...
824    let base_delay = config
825        .base_delay_ms
826        .saturating_mul(1u64 << (attempt.saturating_sub(1)));
827
828    // Add jitter (0-1000ms or 10% of base, whichever is smaller) to avoid thundering herd
829    let max_jitter = config.base_delay_ms.min(1000);
830    let jitter = if max_jitter > 0 {
831        u64::from(
832            std::time::SystemTime::now()
833                .duration_since(std::time::UNIX_EPOCH)
834                .unwrap_or_default()
835                .subsec_nanos(),
836        ) % max_jitter
837    } else {
838        0
839    };
840
841    let delay_ms = base_delay.saturating_add(jitter).min(config.max_delay_ms);
842    Duration::from_millis(delay_ms)
843}
844
845fn extract_content(
846    response: &ChatResponse,
847) -> (Option<String>, Vec<(String, String, serde_json::Value)>) {
848    let mut text_parts = Vec::new();
849    let mut tool_uses = Vec::new();
850
851    for block in &response.content {
852        match block {
853            ContentBlock::Text { text } => {
854                text_parts.push(text.clone());
855            }
856            ContentBlock::ToolUse { id, name, input } => {
857                tool_uses.push((id.clone(), name.clone(), input.clone()));
858            }
859            ContentBlock::ToolResult { .. } => {
860                // Shouldn't appear in response, but ignore if it does
861            }
862        }
863    }
864
865    let text = if text_parts.is_empty() {
866        None
867    } else {
868        Some(text_parts.join("\n"))
869    };
870
871    (text, tool_uses)
872}
873
874fn build_assistant_message(response: &ChatResponse) -> Message {
875    let mut blocks = Vec::new();
876
877    for block in &response.content {
878        match block {
879            ContentBlock::Text { text } => {
880                blocks.push(ContentBlock::Text { text: text.clone() });
881            }
882            ContentBlock::ToolUse { id, name, input } => {
883                blocks.push(ContentBlock::ToolUse {
884                    id: id.clone(),
885                    name: name.clone(),
886                    input: input.clone(),
887                });
888            }
889            ContentBlock::ToolResult { .. } => {}
890        }
891    }
892
893    Message {
894        role: Role::Assistant,
895        content: Content::Blocks(blocks),
896    }
897}
898
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::AllowAllHooks;
    use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, ContentBlock, StopReason, Usage};
    use crate::stores::InMemoryStore;
    use crate::tools::{Tool, ToolContext, ToolRegistry};
    use crate::types::{AgentConfig, ToolResult, ToolTier};
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::RwLock;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // ===================
    // Shared Fixtures
    // ===================

    /// Default agent config with the fast retry policy, so retry tests finish quickly.
    fn fast_retry_config() -> AgentConfig {
        AgentConfig {
            retry: crate::types::RetryConfig::fast(),
            ..Default::default()
        }
    }

    /// Wraps `content` in a minimal `ChatResponse` (no stop reason, zero usage)
    /// for exercising the pure helper functions directly.
    fn response_with(content: Vec<ContentBlock>) -> ChatResponse {
        ChatResponse {
            id: "msg_1".to_string(),
            content,
            model: "test".to_string(),
            stop_reason: None,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
            },
        }
    }

    // ===================
    // Mock LLM Provider
    // ===================

    /// LLM provider that replays a scripted list of outcomes, one per `chat` call.
    ///
    /// After the script runs out, every call answers with a plain "Done" text
    /// response, which ends the agent loop.
    struct MockProvider {
        responses: RwLock<Vec<ChatOutcome>>,
        call_count: AtomicUsize,
    }

    impl MockProvider {
        fn new(responses: Vec<ChatOutcome>) -> Self {
            Self {
                responses: RwLock::new(responses),
                call_count: AtomicUsize::new(0),
            }
        }

        /// Successful outcome with a fixed id, model and usage, plus the given content.
        fn success(content: Vec<ContentBlock>, stop_reason: StopReason) -> ChatOutcome {
            ChatOutcome::Success(ChatResponse {
                id: "msg_1".to_string(),
                content,
                model: "mock-model".to_string(),
                stop_reason: Some(stop_reason),
                usage: Usage {
                    input_tokens: 10,
                    output_tokens: 20,
                },
            })
        }

        /// A single-text-block response that ends the turn.
        fn text_response(text: &str) -> ChatOutcome {
            Self::success(
                vec![ContentBlock::Text {
                    text: text.to_string(),
                }],
                StopReason::EndTurn,
            )
        }

        /// A single-tool-call response that asks the loop to run `tool_name`.
        fn tool_use_response(
            tool_id: &str,
            tool_name: &str,
            input: serde_json::Value,
        ) -> ChatOutcome {
            Self::success(
                vec![ContentBlock::ToolUse {
                    id: tool_id.to_string(),
                    name: tool_name.to_string(),
                    input,
                }],
                StopReason::ToolUse,
            )
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
            // Past the end of the script, default to ending the conversation.
            let outcome = self
                .responses
                .read()
                .unwrap()
                .get(idx)
                .cloned()
                .unwrap_or_else(|| Self::text_response("Done"));
            Ok(outcome)
        }

        fn model(&self) -> &'static str {
            "mock-model"
        }

        fn provider(&self) -> &'static str {
            "mock"
        }
    }

    // Make ChatOutcome clonable for tests
    impl Clone for ChatOutcome {
        fn clone(&self) -> Self {
            match self {
                Self::Success(r) => Self::Success(r.clone()),
                Self::RateLimited => Self::RateLimited,
                Self::InvalidRequest(s) => Self::InvalidRequest(s.clone()),
                Self::ServerError(s) => Self::ServerError(s.clone()),
            }
        }
    }

    // ===================
    // Mock Tool
    // ===================

    /// Observe-tier tool that echoes its `message` argument back as `"Echo: …"`.
    struct EchoTool;

    #[async_trait]
    impl Tool<()> for EchoTool {
        fn name(&self) -> &'static str {
            "echo"
        }

        fn description(&self) -> &'static str {
            "Echo the input message"
        }

        fn input_schema(&self) -> serde_json::Value {
            json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                },
                "required": ["message"]
            })
        }

        fn tier(&self) -> ToolTier {
            ToolTier::Observe
        }

        async fn execute(
            &self,
            _ctx: &ToolContext<()>,
            input: serde_json::Value,
        ) -> Result<ToolResult> {
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message");
            Ok(ToolResult::success(format!("Echo: {message}")))
        }
    }

    // ===================
    // Builder Tests
    // ===================

    #[test]
    fn test_builder_creates_agent_loop() {
        let provider = MockProvider::new(vec![]);
        let agent = builder::<()>().provider(provider).build();

        assert_eq!(agent.config.max_turns, 10);
        assert_eq!(agent.config.max_tokens, 4096);
    }

    #[test]
    fn test_builder_with_custom_config() {
        let provider = MockProvider::new(vec![]);
        let config = AgentConfig {
            max_turns: 5,
            max_tokens: 2048,
            system_prompt: "Custom prompt".to_string(),
            model: "custom-model".to_string(),
            ..Default::default()
        };

        let agent = builder::<()>().provider(provider).config(config).build();

        assert_eq!(agent.config.max_turns, 5);
        assert_eq!(agent.config.max_tokens, 2048);
        assert_eq!(agent.config.system_prompt, "Custom prompt");
    }

    #[test]
    fn test_builder_with_tools() {
        let provider = MockProvider::new(vec![]);
        let mut tools = ToolRegistry::new();
        tools.register(EchoTool);

        let agent = builder::<()>().provider(provider).tools(tools).build();

        assert_eq!(agent.tools.len(), 1);
    }

    #[test]
    fn test_builder_with_custom_stores() {
        let provider = MockProvider::new(vec![]);
        let message_store = InMemoryStore::new();
        let state_store = InMemoryStore::new();

        let agent = builder::<()>()
            .provider(provider)
            .hooks(AllowAllHooks)
            .message_store(message_store)
            .state_store(state_store)
            .build_with_stores();

        // Just verify it builds without panicking
        assert_eq!(agent.config.max_turns, 10);
    }

    // ===================
    // Run Loop Tests
    // ===================

    #[tokio::test]
    async fn test_simple_text_response() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![MockProvider::text_response("Hello, user!")]);

        let agent = builder::<()>().provider(provider).build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // Should have: Start, Text, Done
        assert!(events.iter().any(|e| matches!(e, AgentEvent::Text { .. })));
        assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));

        Ok(())
    }

    #[tokio::test]
    async fn test_tool_execution() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![
            // First call: request tool use
            MockProvider::tool_use_response("tool_1", "echo", json!({"message": "test"})),
            // Second call: respond with text
            MockProvider::text_response("Tool executed successfully"),
        ]);

        let mut tools = ToolRegistry::new();
        tools.register(EchoTool);

        let agent = builder::<()>().provider(provider).tools(tools).build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Run echo".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // Should have tool call events
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::ToolCallStart { .. }))
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::ToolCallEnd { .. }))
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_max_turns_limit() -> anyhow::Result<()> {
        // Provider that always requests a tool
        let provider = MockProvider::new(vec![
            MockProvider::tool_use_response("tool_1", "echo", json!({"message": "1"})),
            MockProvider::tool_use_response("tool_2", "echo", json!({"message": "2"})),
            MockProvider::tool_use_response("tool_3", "echo", json!({"message": "3"})),
            MockProvider::tool_use_response("tool_4", "echo", json!({"message": "4"})),
        ]);

        let mut tools = ToolRegistry::new();
        tools.register(EchoTool);

        let config = AgentConfig {
            max_turns: 2,
            ..Default::default()
        };

        let agent = builder::<()>()
            .provider(provider)
            .tools(tools)
            .config(config)
            .build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Loop".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // Should have an error about max turns
        assert!(events.iter().any(|e| {
            matches!(e, AgentEvent::Error { message, .. } if message.contains("Maximum turns"))
        }));

        Ok(())
    }

    #[tokio::test]
    async fn test_unknown_tool_handling() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![
            // Request unknown tool
            MockProvider::tool_use_response("tool_1", "nonexistent_tool", json!({})),
            // LLM gets tool error and ends conversation
            MockProvider::text_response("I couldn't find that tool."),
        ]);

        // Empty tool registry
        let tools = ToolRegistry::new();

        let agent = builder::<()>().provider(provider).tools(tools).build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Call unknown".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // Unknown tool errors are returned to the LLM (not emitted as ToolCallEnd)
        // The conversation should complete successfully with a Done event
        assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));

        // The LLM's response about the missing tool should be in the events
        assert!(
            events.iter().any(|e| {
                matches!(e, AgentEvent::Text { text } if text.contains("couldn't find"))
            })
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rate_limit_handling() -> anyhow::Result<()> {
        // Enough RateLimited outcomes to exhaust every attempt (max_retries + 1).
        // NOTE: assumes `RetryConfig::fast()` allows 5 retries; if it allows more,
        // the mock falls back to "Done" and this test will fail.
        let provider = MockProvider::new(vec![ChatOutcome::RateLimited; 6]);

        let agent = builder::<()>()
            .provider(provider)
            .config(fast_retry_config())
            .build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // Should have rate limit error after exhausting retries
        assert!(events.iter().any(|e| {
            matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Rate limited"))
        }));

        // Should have retry text events
        assert!(
            events
                .iter()
                .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rate_limit_recovery() -> anyhow::Result<()> {
        // Rate limited once, then succeeds
        let provider = MockProvider::new(vec![
            ChatOutcome::RateLimited,
            MockProvider::text_response("Recovered after rate limit"),
        ]);

        let agent = builder::<()>()
            .provider(provider)
            .config(fast_retry_config())
            .build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // Should have successful completion after retry
        assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));

        // Should have retry text event
        assert!(
            events
                .iter()
                .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_server_error_handling() -> anyhow::Result<()> {
        // Enough ServerError outcomes to exhaust every attempt (max_retries + 1).
        // NOTE: assumes `RetryConfig::fast()` allows 5 retries; if it allows more,
        // the mock falls back to "Done" and this test will fail.
        let provider =
            MockProvider::new(vec![ChatOutcome::ServerError("Internal error".to_string()); 6]);

        let agent = builder::<()>()
            .provider(provider)
            .config(fast_retry_config())
            .build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // Should have server error after exhausting retries
        assert!(events.iter().any(|e| {
            matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Server error"))
        }));

        // Should have retry text events
        assert!(
            events
                .iter()
                .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_server_error_recovery() -> anyhow::Result<()> {
        // Server error once, then succeeds
        let provider = MockProvider::new(vec![
            ChatOutcome::ServerError("Temporary error".to_string()),
            MockProvider::text_response("Recovered after server error"),
        ]);

        let agent = builder::<()>()
            .provider(provider)
            .config(fast_retry_config())
            .build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let mut rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);

        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }

        // Should have successful completion after retry
        assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));

        // Should have retry text event
        assert!(
            events
                .iter()
                .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
        );

        Ok(())
    }

    // ===================
    // Helper Function Tests
    // ===================

    #[test]
    fn test_extract_content_text_only() {
        let response = response_with(vec![ContentBlock::Text {
            text: "Hello".to_string(),
        }]);

        let (text, tool_uses) = extract_content(&response);
        assert_eq!(text, Some("Hello".to_string()));
        assert!(tool_uses.is_empty());
    }

    #[test]
    fn test_extract_content_tool_use() {
        let response = response_with(vec![ContentBlock::ToolUse {
            id: "tool_1".to_string(),
            name: "test_tool".to_string(),
            input: json!({"key": "value"}),
        }]);

        let (text, tool_uses) = extract_content(&response);
        assert!(text.is_none());
        assert_eq!(tool_uses.len(), 1);
        assert_eq!(tool_uses[0].0, "tool_1");
        assert_eq!(tool_uses[0].1, "test_tool");
        assert_eq!(tool_uses[0].2, json!({"key": "value"}));
    }

    #[test]
    fn test_extract_content_mixed() {
        let response = response_with(vec![
            ContentBlock::Text {
                text: "Let me help".to_string(),
            },
            ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "helper".to_string(),
                input: json!({}),
            },
        ]);

        let (text, tool_uses) = extract_content(&response);
        assert_eq!(text, Some("Let me help".to_string()));
        assert_eq!(tool_uses.len(), 1);
    }

    #[test]
    fn test_extract_content_joins_text_blocks_with_newline() {
        let response = response_with(vec![
            ContentBlock::Text {
                text: "first".to_string(),
            },
            ContentBlock::Text {
                text: "second".to_string(),
            },
        ]);

        let (text, tool_uses) = extract_content(&response);
        assert_eq!(text, Some("first\nsecond".to_string()));
        assert!(tool_uses.is_empty());
    }

    #[test]
    fn test_extract_content_empty_response() {
        let response = response_with(vec![]);

        let (text, tool_uses) = extract_content(&response);
        assert!(text.is_none());
        assert!(tool_uses.is_empty());
    }

    #[test]
    fn test_millis_to_u64() {
        assert_eq!(millis_to_u64(0), 0);
        assert_eq!(millis_to_u64(1000), 1000);
        assert_eq!(millis_to_u64(u128::from(u64::MAX)), u64::MAX);
        assert_eq!(millis_to_u64(u128::from(u64::MAX) + 1), u64::MAX);
    }

    #[test]
    fn test_build_assistant_message() {
        let response = response_with(vec![
            ContentBlock::Text {
                text: "Response text".to_string(),
            },
            ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "echo".to_string(),
                input: json!({"message": "test"}),
            },
        ]);

        let msg = build_assistant_message(&response);
        assert_eq!(msg.role, Role::Assistant);

        let Content::Blocks(blocks) = msg.content else {
            panic!("Expected Content::Blocks");
        };
        assert_eq!(blocks.len(), 2);
        // Block order from the response must be preserved.
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Response text"));
        assert!(matches!(&blocks[1], ContentBlock::ToolUse { id, .. } if id == "tool_1"));
    }
}