agent_sdk/
agent_loop.rs

1//! Agent loop orchestration module.
2//!
3//! This module contains the core agent loop that orchestrates LLM calls,
4//! tool execution, and event handling. The agent loop is the main entry point
5//! for running an AI agent.
6//!
7//! # Architecture
8//!
9//! The agent loop works as follows:
10//! 1. Receives a user message
11//! 2. Sends the message to the LLM provider
12//! 3. Processes the LLM response (text or tool calls)
13//! 4. If tool calls are present, executes them and feeds results back to LLM
14//! 5. Repeats until the LLM responds with only text (no tool calls)
15//! 6. Emits events throughout for real-time UI updates
16//!
17//! # Building an Agent
18//!
19//! Use the builder pattern via [`builder()`] or [`AgentLoopBuilder`]:
20//!
21//! ```ignore
22//! use agent_sdk::{builder, providers::AnthropicProvider};
23//!
24//! let agent = builder()
25//!     .provider(AnthropicProvider::sonnet(api_key))
26//!     .tools(my_tools)
27//!     .build();
28//! ```
29
30use crate::context::{CompactionConfig, ContextCompactor, LlmContextCompactor};
31use crate::events::AgentEvent;
32use crate::hooks::{AgentHooks, DefaultHooks, ToolDecision};
33use crate::llm::{
34    ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, Message, Role,
35    StopReason,
36};
37use crate::skills::Skill;
38use crate::stores::{InMemoryStore, MessageStore, StateStore};
39use crate::tools::{ToolContext, ToolRegistry};
40use crate::types::{AgentConfig, AgentState, RetryConfig, ThreadId, TokenUsage, ToolResult};
41use anyhow::Result;
42use std::sync::Arc;
43use std::time::{Duration, Instant};
44use tokio::sync::mpsc;
45use tokio::time::sleep;
46use tracing::{debug, error, info, warn};
47
48/// Builder for constructing an `AgentLoop`.
49///
50/// # Example
51///
52/// ```ignore
53/// let agent = AgentLoop::builder()
54///     .provider(my_provider)
55///     .tools(my_tools)
56///     .config(AgentConfig::default())
57///     .build();
58/// ```
59pub struct AgentLoopBuilder<Ctx, P, H, M, S> {
60    provider: Option<P>,
61    tools: Option<ToolRegistry<Ctx>>,
62    hooks: Option<H>,
63    message_store: Option<M>,
64    state_store: Option<S>,
65    config: Option<AgentConfig>,
66    compaction_config: Option<CompactionConfig>,
67}
68
69impl<Ctx> AgentLoopBuilder<Ctx, (), (), (), ()> {
70    /// Create a new builder with no components set.
71    #[must_use]
72    pub const fn new() -> Self {
73        Self {
74            provider: None,
75            tools: None,
76            hooks: None,
77            message_store: None,
78            state_store: None,
79            config: None,
80            compaction_config: None,
81        }
82    }
83}
84
85impl<Ctx> Default for AgentLoopBuilder<Ctx, (), (), (), ()> {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S> {
92    /// Set the LLM provider.
93    #[must_use]
94    pub fn provider<P2: LlmProvider>(self, provider: P2) -> AgentLoopBuilder<Ctx, P2, H, M, S> {
95        AgentLoopBuilder {
96            provider: Some(provider),
97            tools: self.tools,
98            hooks: self.hooks,
99            message_store: self.message_store,
100            state_store: self.state_store,
101            config: self.config,
102            compaction_config: self.compaction_config,
103        }
104    }
105
106    /// Set the tool registry.
107    #[must_use]
108    pub fn tools(mut self, tools: ToolRegistry<Ctx>) -> Self {
109        self.tools = Some(tools);
110        self
111    }
112
113    /// Set the agent hooks.
114    #[must_use]
115    pub fn hooks<H2: AgentHooks>(self, hooks: H2) -> AgentLoopBuilder<Ctx, P, H2, M, S> {
116        AgentLoopBuilder {
117            provider: self.provider,
118            tools: self.tools,
119            hooks: Some(hooks),
120            message_store: self.message_store,
121            state_store: self.state_store,
122            config: self.config,
123            compaction_config: self.compaction_config,
124        }
125    }
126
127    /// Set the message store.
128    #[must_use]
129    pub fn message_store<M2: MessageStore>(
130        self,
131        message_store: M2,
132    ) -> AgentLoopBuilder<Ctx, P, H, M2, S> {
133        AgentLoopBuilder {
134            provider: self.provider,
135            tools: self.tools,
136            hooks: self.hooks,
137            message_store: Some(message_store),
138            state_store: self.state_store,
139            config: self.config,
140            compaction_config: self.compaction_config,
141        }
142    }
143
144    /// Set the state store.
145    #[must_use]
146    pub fn state_store<S2: StateStore>(
147        self,
148        state_store: S2,
149    ) -> AgentLoopBuilder<Ctx, P, H, M, S2> {
150        AgentLoopBuilder {
151            provider: self.provider,
152            tools: self.tools,
153            hooks: self.hooks,
154            message_store: self.message_store,
155            state_store: Some(state_store),
156            config: self.config,
157            compaction_config: self.compaction_config,
158        }
159    }
160
161    /// Set the agent configuration.
162    #[must_use]
163    pub fn config(mut self, config: AgentConfig) -> Self {
164        self.config = Some(config);
165        self
166    }
167
168    /// Enable context compaction with the given configuration.
169    ///
170    /// When enabled, the agent will automatically compact conversation history
171    /// when it exceeds the configured token threshold.
172    ///
173    /// # Example
174    ///
175    /// ```ignore
176    /// use agent_sdk::{builder, context::CompactionConfig};
177    ///
178    /// let agent = builder()
179    ///     .provider(my_provider)
180    ///     .with_compaction(CompactionConfig::default())
181    ///     .build();
182    /// ```
183    #[must_use]
184    pub const fn with_compaction(mut self, config: CompactionConfig) -> Self {
185        self.compaction_config = Some(config);
186        self
187    }
188
189    /// Enable context compaction with default settings.
190    ///
191    /// This is a convenience method equivalent to:
192    /// ```ignore
193    /// builder.with_compaction(CompactionConfig::default())
194    /// ```
195    #[must_use]
196    pub fn with_auto_compaction(self) -> Self {
197        self.with_compaction(CompactionConfig::default())
198    }
199
200    /// Apply a skill configuration.
201    ///
202    /// This merges the skill's system prompt with the existing configuration
203    /// and filters tools based on the skill's allowed/denied lists.
204    ///
205    /// # Example
206    ///
207    /// ```ignore
208    /// let skill = Skill::new("code-review", "You are a code reviewer...")
209    ///     .with_denied_tools(vec!["bash".into()]);
210    ///
211    /// let agent = builder()
212    ///     .provider(provider)
213    ///     .tools(tools)
214    ///     .with_skill(skill)
215    ///     .build();
216    /// ```
217    #[must_use]
218    pub fn with_skill(mut self, skill: Skill) -> Self
219    where
220        Ctx: Send + Sync + 'static,
221    {
222        // Filter tools based on skill configuration first (before moving skill)
223        if let Some(ref mut tools) = self.tools {
224            tools.filter(|name| skill.is_tool_allowed(name));
225        }
226
227        // Merge system prompt
228        let mut config = self.config.take().unwrap_or_default();
229        if config.system_prompt.is_empty() {
230            config.system_prompt = skill.system_prompt;
231        } else {
232            config.system_prompt = format!("{}\n\n{}", config.system_prompt, skill.system_prompt);
233        }
234        self.config = Some(config);
235
236        self
237    }
238}
239
240impl<Ctx, P> AgentLoopBuilder<Ctx, P, (), (), ()>
241where
242    Ctx: Send + Sync + 'static,
243    P: LlmProvider + 'static,
244{
245    /// Build the agent loop with default hooks and in-memory stores.
246    ///
247    /// This is a convenience method that uses:
248    /// - `DefaultHooks` for hooks
249    /// - `InMemoryStore` for message store
250    /// - `InMemoryStore` for state store
251    /// - `AgentConfig::default()` if no config is set
252    ///
253    /// # Panics
254    ///
255    /// Panics if a provider has not been set.
256    #[must_use]
257    pub fn build(self) -> AgentLoop<Ctx, P, DefaultHooks, InMemoryStore, InMemoryStore> {
258        let provider = self.provider.expect("provider is required");
259        let tools = self.tools.unwrap_or_default();
260        let config = self.config.unwrap_or_default();
261
262        AgentLoop {
263            provider: Arc::new(provider),
264            tools: Arc::new(tools),
265            hooks: Arc::new(DefaultHooks),
266            message_store: Arc::new(InMemoryStore::new()),
267            state_store: Arc::new(InMemoryStore::new()),
268            config,
269            compaction_config: self.compaction_config,
270        }
271    }
272}
273
274impl<Ctx, P, H, M, S> AgentLoopBuilder<Ctx, P, H, M, S>
275where
276    Ctx: Send + Sync + 'static,
277    P: LlmProvider + 'static,
278    H: AgentHooks + 'static,
279    M: MessageStore + 'static,
280    S: StateStore + 'static,
281{
282    /// Build the agent loop with all custom components.
283    ///
284    /// # Panics
285    ///
286    /// Panics if any of the following have not been set:
287    /// - `provider`
288    /// - `hooks`
289    /// - `message_store`
290    /// - `state_store`
291    #[must_use]
292    pub fn build_with_stores(self) -> AgentLoop<Ctx, P, H, M, S> {
293        let provider = self.provider.expect("provider is required");
294        let tools = self.tools.unwrap_or_default();
295        let hooks = self
296            .hooks
297            .expect("hooks is required when using build_with_stores");
298        let message_store = self
299            .message_store
300            .expect("message_store is required when using build_with_stores");
301        let state_store = self
302            .state_store
303            .expect("state_store is required when using build_with_stores");
304        let config = self.config.unwrap_or_default();
305
306        AgentLoop {
307            provider: Arc::new(provider),
308            tools: Arc::new(tools),
309            hooks: Arc::new(hooks),
310            message_store: Arc::new(message_store),
311            state_store: Arc::new(state_store),
312            config,
313            compaction_config: self.compaction_config,
314        }
315    }
316}
317
318/// The main agent loop that orchestrates LLM calls and tool execution.
319///
320/// `AgentLoop` is the core component that:
321/// - Manages conversation state via message and state stores
322/// - Calls the LLM provider and processes responses
323/// - Executes tools through the tool registry
324/// - Emits events for real-time updates via an async channel
325/// - Enforces hooks for tool permissions and lifecycle events
326///
327/// # Type Parameters
328///
329/// - `Ctx`: Application-specific context passed to tools (e.g., user ID, database)
330/// - `P`: The LLM provider implementation
331/// - `H`: The hooks implementation for lifecycle customization
332/// - `M`: The message store implementation
333/// - `S`: The state store implementation
334///
335/// # Running the Agent
336///
337/// ```ignore
338/// let mut events = agent.run(thread_id, "Hello!".to_string(), tool_ctx);
339/// while let Some(event) = events.recv().await {
340///     match event {
341///         AgentEvent::Text { text } => println!("{}", text),
342///         AgentEvent::Done { .. } => break,
343///         _ => {}
344///     }
345/// }
346/// ```
347pub struct AgentLoop<Ctx, P, H, M, S>
348where
349    P: LlmProvider,
350    H: AgentHooks,
351    M: MessageStore,
352    S: StateStore,
353{
354    provider: Arc<P>,
355    tools: Arc<ToolRegistry<Ctx>>,
356    hooks: Arc<H>,
357    message_store: Arc<M>,
358    state_store: Arc<S>,
359    config: AgentConfig,
360    compaction_config: Option<CompactionConfig>,
361}
362
363/// Create a new builder for constructing an `AgentLoop`.
364#[must_use]
365pub const fn builder<Ctx>() -> AgentLoopBuilder<Ctx, (), (), (), ()> {
366    AgentLoopBuilder::new()
367}
368
369impl<Ctx, P, H, M, S> AgentLoop<Ctx, P, H, M, S>
370where
371    Ctx: Send + Sync + 'static,
372    P: LlmProvider + 'static,
373    H: AgentHooks + 'static,
374    M: MessageStore + 'static,
375    S: StateStore + 'static,
376{
377    /// Create a new agent loop with all components specified directly.
378    #[must_use]
379    pub fn new(
380        provider: P,
381        tools: ToolRegistry<Ctx>,
382        hooks: H,
383        message_store: M,
384        state_store: S,
385        config: AgentConfig,
386    ) -> Self {
387        Self {
388            provider: Arc::new(provider),
389            tools: Arc::new(tools),
390            hooks: Arc::new(hooks),
391            message_store: Arc::new(message_store),
392            state_store: Arc::new(state_store),
393            config,
394            compaction_config: None,
395        }
396    }
397
398    /// Create a new agent loop with compaction enabled.
399    #[must_use]
400    pub fn with_compaction(
401        provider: P,
402        tools: ToolRegistry<Ctx>,
403        hooks: H,
404        message_store: M,
405        state_store: S,
406        config: AgentConfig,
407        compaction_config: CompactionConfig,
408    ) -> Self {
409        Self {
410            provider: Arc::new(provider),
411            tools: Arc::new(tools),
412            hooks: Arc::new(hooks),
413            message_store: Arc::new(message_store),
414            state_store: Arc::new(state_store),
415            config,
416            compaction_config: Some(compaction_config),
417        }
418    }
419
420    /// Run the agent loop for a single user message.
421    /// Returns a channel receiver that yields `AgentEvents`.
422    pub fn run(
423        &self,
424        thread_id: ThreadId,
425        user_message: String,
426        tool_context: ToolContext<Ctx>,
427    ) -> mpsc::Receiver<AgentEvent>
428    where
429        Ctx: Clone,
430    {
431        let (tx, rx) = mpsc::channel(100);
432
433        let provider = Arc::clone(&self.provider);
434        let tools = Arc::clone(&self.tools);
435        let hooks = Arc::clone(&self.hooks);
436        let message_store = Arc::clone(&self.message_store);
437        let state_store = Arc::clone(&self.state_store);
438        let config = self.config.clone();
439        let compaction_config = self.compaction_config.clone();
440
441        tokio::spawn(async move {
442            let result = run_loop(
443                tx.clone(),
444                thread_id,
445                user_message,
446                tool_context,
447                provider,
448                tools,
449                hooks,
450                message_store,
451                state_store,
452                config,
453                compaction_config,
454            )
455            .await;
456
457            if let Err(e) = result {
458                let _ = tx.send(AgentEvent::error(e.to_string(), false)).await;
459            }
460        });
461
462        rx
463    }
464}
465
466#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
467async fn run_loop<Ctx, P, H, M, S>(
468    tx: mpsc::Sender<AgentEvent>,
469    thread_id: ThreadId,
470    user_message: String,
471    tool_context: ToolContext<Ctx>,
472    provider: Arc<P>,
473    tools: Arc<ToolRegistry<Ctx>>,
474    hooks: Arc<H>,
475    message_store: Arc<M>,
476    state_store: Arc<S>,
477    config: AgentConfig,
478    compaction_config: Option<CompactionConfig>,
479) -> Result<()>
480where
481    Ctx: Send + Sync + Clone + 'static,
482    P: LlmProvider,
483    H: AgentHooks,
484    M: MessageStore,
485    S: StateStore,
486{
487    // Add event channel to tool context so tools can emit events
488    let tool_context = tool_context.with_event_tx(tx.clone());
489
490    let start_time = Instant::now();
491    let mut turn = 0;
492    let mut total_usage = TokenUsage::default();
493
494    // Load or create state
495    let mut state = state_store
496        .load(&thread_id)
497        .await?
498        .unwrap_or_else(|| AgentState::new(thread_id.clone()));
499
500    // Add user message to history
501    let user_msg = Message::user(&user_message);
502    message_store.append(&thread_id, user_msg).await?;
503
504    // Main agent loop
505    loop {
506        turn += 1;
507        state.turn_count = turn;
508
509        if turn > config.max_turns {
510            warn!(turn, max = config.max_turns, "Max turns reached");
511            tx.send(AgentEvent::error(
512                format!("Maximum turns ({}) reached", config.max_turns),
513                true,
514            ))
515            .await?;
516            break;
517        }
518
519        // Emit start event
520        tx.send(AgentEvent::start(thread_id.clone(), turn)).await?;
521        hooks
522            .on_event(&AgentEvent::start(thread_id.clone(), turn))
523            .await;
524
525        // Get message history
526        let mut messages = message_store.get_history(&thread_id).await?;
527
528        // Check if compaction is needed
529        if let Some(ref compact_config) = compaction_config {
530            let compactor = LlmContextCompactor::new(Arc::clone(&provider), compact_config.clone());
531            if compactor.needs_compaction(&messages) {
532                debug!(
533                    turn,
534                    message_count = messages.len(),
535                    "Context compaction triggered"
536                );
537
538                match compactor.compact_history(messages).await {
539                    Ok(result) => {
540                        // Replace history in store
541                        message_store
542                            .replace_history(&thread_id, result.messages.clone())
543                            .await?;
544
545                        // Emit compaction event
546                        tx.send(AgentEvent::context_compacted(
547                            result.original_count,
548                            result.new_count,
549                            result.original_tokens,
550                            result.new_tokens,
551                        ))
552                        .await?;
553
554                        info!(
555                            original_count = result.original_count,
556                            new_count = result.new_count,
557                            original_tokens = result.original_tokens,
558                            new_tokens = result.new_tokens,
559                            "Context compacted successfully"
560                        );
561
562                        // Use the compacted messages
563                        messages = result.messages;
564                    }
565                    Err(e) => {
566                        warn!(error = %e, "Context compaction failed, continuing with full history");
567                        // Continue with original messages on failure
568                        messages = message_store.get_history(&thread_id).await?;
569                    }
570                }
571            }
572        }
573
574        // Build chat request
575        let llm_tools = if tools.is_empty() {
576            None
577        } else {
578            Some(tools.to_llm_tools())
579        };
580
581        let request = ChatRequest {
582            system: config.system_prompt.clone(),
583            messages,
584            tools: llm_tools,
585            max_tokens: config.max_tokens,
586        };
587
588        // Call LLM with retry logic for transient errors
589        debug!(turn, "Calling LLM");
590        let max_retries = config.retry.max_retries;
591        let response = {
592            let mut attempt = 0u32;
593            loop {
594                let outcome = provider.chat(request.clone()).await?;
595                match outcome {
596                    ChatOutcome::Success(response) => break Some(response),
597                    ChatOutcome::RateLimited => {
598                        attempt += 1;
599                        if attempt > max_retries {
600                            error!("Rate limited by LLM provider after {max_retries} retries");
601                            tx.send(AgentEvent::error(
602                                format!("Rate limited after {max_retries} retries"),
603                                true,
604                            ))
605                            .await?;
606                            break None;
607                        }
608                        let delay = calculate_backoff_delay(attempt, &config.retry);
609                        warn!(
610                            attempt,
611                            delay_ms = delay.as_millis(),
612                            "Rate limited, retrying after backoff"
613                        );
614                        tx.send(AgentEvent::text(format!(
615                            "\n[Rate limited, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
616                            delay.as_secs_f64()
617                        )))
618                        .await?;
619                        sleep(delay).await;
620                    }
621                    ChatOutcome::InvalidRequest(msg) => {
622                        error!(msg, "Invalid request to LLM");
623                        tx.send(AgentEvent::error(format!("Invalid request: {msg}"), false))
624                            .await?;
625                        break None;
626                    }
627                    ChatOutcome::ServerError(msg) => {
628                        attempt += 1;
629                        if attempt > max_retries {
630                            error!(msg, "LLM server error after {max_retries} retries");
631                            tx.send(AgentEvent::error(
632                                format!("Server error after {max_retries} retries: {msg}"),
633                                true,
634                            ))
635                            .await?;
636                            break None;
637                        }
638                        let delay = calculate_backoff_delay(attempt, &config.retry);
639                        warn!(
640                            attempt,
641                            delay_ms = delay.as_millis(),
642                            error = msg,
643                            "Server error, retrying after backoff"
644                        );
645                        tx.send(AgentEvent::text(format!(
646                            "\n[Server error: {msg}, retrying in {:.1}s... (attempt {attempt}/{max_retries})]\n",
647                            delay.as_secs_f64()
648                        )))
649                        .await?;
650                        sleep(delay).await;
651                    }
652                }
653            }
654        };
655
656        // If we failed to get a response after retries, exit the loop
657        let Some(response) = response else {
658            break;
659        };
660
661        // Track usage
662        let turn_usage = TokenUsage {
663            input_tokens: response.usage.input_tokens,
664            output_tokens: response.usage.output_tokens,
665        };
666        total_usage.add(&turn_usage);
667        state.total_usage = total_usage.clone();
668
669        // Process response content
670        let (text_content, tool_uses) = extract_content(&response);
671
672        // Emit text if present
673        if let Some(text) = &text_content {
674            tx.send(AgentEvent::text(text.clone())).await?;
675            hooks.on_event(&AgentEvent::text(text.clone())).await;
676        }
677
678        // If no tool uses, we're done
679        if tool_uses.is_empty() {
680            info!(turn, "Agent completed (no tool use)");
681            break;
682        }
683
684        // Store assistant message with tool uses
685        let assistant_msg = build_assistant_message(&response);
686        message_store.append(&thread_id, assistant_msg).await?;
687
688        // Execute tools
689        let mut tool_results = Vec::new();
690        for (tool_id, tool_name, tool_input) in &tool_uses {
691            let Some(tool) = tools.get(tool_name) else {
692                let result = ToolResult::error(format!("Unknown tool: {tool_name}"));
693                tool_results.push((tool_id.clone(), result));
694                continue;
695            };
696
697            let tier = tool.tier();
698
699            // Emit tool call start
700            tx.send(AgentEvent::tool_call_start(
701                tool_id,
702                tool_name,
703                tool_input.clone(),
704                tier,
705            ))
706            .await?;
707
708            // Check hooks for permission
709            let decision = hooks.pre_tool_use(tool_name, tool_input, tier).await;
710
711            match decision {
712                ToolDecision::Allow => {
713                    // Execute tool
714                    let tool_start = Instant::now();
715                    let result = match tool.execute(&tool_context, tool_input.clone()).await {
716                        Ok(mut r) => {
717                            r.duration_ms = Some(millis_to_u64(tool_start.elapsed().as_millis()));
718                            r
719                        }
720                        Err(e) => ToolResult::error(format!("Tool error: {e}"))
721                            .with_duration(millis_to_u64(tool_start.elapsed().as_millis())),
722                    };
723
724                    // Post-tool hook
725                    hooks.post_tool_use(tool_name, &result).await;
726
727                    // Emit tool call end
728                    tx.send(AgentEvent::tool_call_end(
729                        tool_id,
730                        tool_name,
731                        result.clone(),
732                    ))
733                    .await?;
734
735                    tool_results.push((tool_id.clone(), result));
736                }
737                ToolDecision::Block(reason) => {
738                    let result = ToolResult::error(format!("Blocked: {reason}"));
739                    tx.send(AgentEvent::tool_call_end(
740                        tool_id,
741                        tool_name,
742                        result.clone(),
743                    ))
744                    .await?;
745                    tool_results.push((tool_id.clone(), result));
746                }
747                ToolDecision::RequiresConfirmation(description) => {
748                    tx.send(AgentEvent::ToolRequiresConfirmation {
749                        id: tool_id.clone(),
750                        name: tool_name.clone(),
751                        input: tool_input.clone(),
752                        description,
753                    })
754                    .await?;
755                    // For now, treat as blocked - caller should handle confirmation flow
756                    let result = ToolResult::error("Awaiting user confirmation");
757                    tool_results.push((tool_id.clone(), result));
758                }
759                ToolDecision::RequiresPin(description) => {
760                    tx.send(AgentEvent::ToolRequiresPin {
761                        id: tool_id.clone(),
762                        name: tool_name.clone(),
763                        input: tool_input.clone(),
764                        description,
765                    })
766                    .await?;
767                    // For now, treat as blocked - caller should handle PIN flow
768                    let result = ToolResult::error("Awaiting PIN verification");
769                    tool_results.push((tool_id.clone(), result));
770                }
771            }
772        }
773
774        // Add tool results to message history
775        for (tool_id, result) in &tool_results {
776            let tool_result_msg = Message::tool_result(tool_id, &result.output, !result.success);
777            message_store.append(&thread_id, tool_result_msg).await?;
778        }
779
780        // Emit turn complete
781        tx.send(AgentEvent::TurnComplete {
782            turn,
783            usage: turn_usage,
784        })
785        .await?;
786
787        // Check stop reason
788        if response.stop_reason == Some(StopReason::EndTurn) {
789            info!(turn, "Agent completed (end_turn)");
790            break;
791        }
792
793        // Save state checkpoint
794        state_store.save(&state).await?;
795    }
796
797    // Final state save
798    state_store.save(&state).await?;
799
800    // Emit done
801    let duration = start_time.elapsed();
802    tx.send(AgentEvent::done(thread_id, turn, total_usage, duration))
803        .await?;
804
805    Ok(())
806}
807
/// Narrow a millisecond count from `u128` (as returned by `Duration::as_millis`) to `u64`,
/// saturating at `u64::MAX`.
// The cast below only runs after the range check, so it cannot truncate.
#[allow(clippy::cast_possible_truncation)]
const fn millis_to_u64(millis: u128) -> u64 {
    const LIMIT: u128 = u64::MAX as u128;
    if millis <= LIMIT {
        millis as u64
    } else {
        u64::MAX
    }
}
817
818/// Calculate exponential backoff delay with jitter.
819///
820/// Uses exponential backoff with the formula: `base * 2^(attempt-1) + jitter`,
821/// capped at the maximum delay. Jitter (0-1000ms) helps avoid thundering herd.
822fn calculate_backoff_delay(attempt: u32, config: &RetryConfig) -> Duration {
823    // Exponential backoff: base, base*2, base*4, base*8, ...
824    let base_delay = config
825        .base_delay_ms
826        .saturating_mul(1u64 << (attempt.saturating_sub(1)));
827
828    // Add jitter (0-1000ms or 10% of base, whichever is smaller) to avoid thundering herd
829    let max_jitter = config.base_delay_ms.min(1000);
830    let jitter = if max_jitter > 0 {
831        u64::from(
832            std::time::SystemTime::now()
833                .duration_since(std::time::UNIX_EPOCH)
834                .unwrap_or_default()
835                .subsec_nanos(),
836        ) % max_jitter
837    } else {
838        0
839    };
840
841    let delay_ms = base_delay.saturating_add(jitter).min(config.max_delay_ms);
842    Duration::from_millis(delay_ms)
843}
844
845fn extract_content(
846    response: &ChatResponse,
847) -> (Option<String>, Vec<(String, String, serde_json::Value)>) {
848    let mut text_parts = Vec::new();
849    let mut tool_uses = Vec::new();
850
851    for block in &response.content {
852        match block {
853            ContentBlock::Text { text } => {
854                text_parts.push(text.clone());
855            }
856            ContentBlock::ToolUse {
857                id, name, input, ..
858            } => {
859                tool_uses.push((id.clone(), name.clone(), input.clone()));
860            }
861            ContentBlock::ToolResult { .. } => {
862                // Shouldn't appear in response, but ignore if it does
863            }
864        }
865    }
866
867    let text = if text_parts.is_empty() {
868        None
869    } else {
870        Some(text_parts.join("\n"))
871    };
872
873    (text, tool_uses)
874}
875
876fn build_assistant_message(response: &ChatResponse) -> Message {
877    let mut blocks = Vec::new();
878
879    for block in &response.content {
880        match block {
881            ContentBlock::Text { text } => {
882                blocks.push(ContentBlock::Text { text: text.clone() });
883            }
884            ContentBlock::ToolUse {
885                id,
886                name,
887                input,
888                thought_signature,
889            } => {
890                blocks.push(ContentBlock::ToolUse {
891                    id: id.clone(),
892                    name: name.clone(),
893                    input: input.clone(),
894                    thought_signature: thought_signature.clone(),
895                });
896            }
897            ContentBlock::ToolResult { .. } => {}
898        }
899    }
900
901    Message {
902        role: Role::Assistant,
903        content: Content::Blocks(blocks),
904    }
905}
906
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::AllowAllHooks;
    use crate::llm::{ChatOutcome, ChatRequest, ChatResponse, ContentBlock, StopReason, Usage};
    use crate::stores::InMemoryStore;
    use crate::tools::{Tool, ToolContext, ToolRegistry};
    use crate::types::{AgentConfig, ToolResult, ToolTier};
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::RwLock;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // ===================
    // Shared Test Helpers
    // ===================

    /// Receives from an agent event receiver until `recv()` yields `None`, and
    /// evaluates to every event in arrival order.
    ///
    /// This is a macro rather than a function so the tests never have to name the
    /// concrete receiver type returned by `run`. It awaits, so it must be expanded
    /// inside an `async` fn.
    macro_rules! collect_events {
        ($rx:expr) => {{
            let mut rx = $rx;
            let mut events = Vec::new();
            while let Some(event) = rx.recv().await {
                events.push(event);
            }
            events
        }};
    }

    /// Default agent config with `RetryConfig::fast()`, so retry tests finish quickly.
    fn fast_retry_config() -> AgentConfig {
        AgentConfig {
            retry: crate::types::RetryConfig::fast(),
            ..Default::default()
        }
    }

    /// Wraps `content` in a `ChatResponse` with no stop reason and zero token usage.
    /// Used to exercise the pure response helpers.
    fn response_with(content: Vec<ContentBlock>) -> ChatResponse {
        ChatResponse {
            id: "msg_1".to_string(),
            content,
            model: "test".to_string(),
            stop_reason: None,
            usage: Usage {
                input_tokens: 0,
                output_tokens: 0,
            },
        }
    }

    // ===================
    // Mock LLM Provider
    // ===================

    /// Scripted provider. The n-th `chat` call returns the n-th queued outcome; any
    /// call past the end of the script returns a plain "Done" text response.
    struct MockProvider {
        responses: RwLock<Vec<ChatOutcome>>,
        call_count: AtomicUsize,
    }

    impl MockProvider {
        fn new(responses: Vec<ChatOutcome>) -> Self {
            Self {
                responses: RwLock::new(responses),
                call_count: AtomicUsize::new(0),
            }
        }

        /// A successful, turn-ending response consisting of a single text block.
        fn text_response(text: &str) -> ChatOutcome {
            ChatOutcome::Success(ChatResponse {
                id: "msg_1".to_string(),
                content: vec![ContentBlock::Text {
                    text: text.to_string(),
                }],
                model: "mock-model".to_string(),
                stop_reason: Some(StopReason::EndTurn),
                usage: Usage {
                    input_tokens: 10,
                    output_tokens: 20,
                },
            })
        }

        /// A successful response requesting exactly one tool call.
        fn tool_use_response(
            tool_id: &str,
            tool_name: &str,
            input: serde_json::Value,
        ) -> ChatOutcome {
            ChatOutcome::Success(ChatResponse {
                id: "msg_1".to_string(),
                content: vec![ContentBlock::ToolUse {
                    id: tool_id.to_string(),
                    name: tool_name.to_string(),
                    input,
                    thought_signature: None,
                }],
                model: "mock-model".to_string(),
                stop_reason: Some(StopReason::ToolUse),
                usage: Usage {
                    input_tokens: 10,
                    output_tokens: 20,
                },
            })
        }
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
            let responses = self.responses.read().unwrap();
            if idx < responses.len() {
                Ok(responses[idx].clone())
            } else {
                // Default: end conversation
                Ok(Self::text_response("Done"))
            }
        }

        fn model(&self) -> &'static str {
            "mock-model"
        }

        fn provider(&self) -> &'static str {
            "mock"
        }
    }

    // Test-only `Clone` so `MockProvider` can hand out copies of its scripted outcomes.
    impl Clone for ChatOutcome {
        fn clone(&self) -> Self {
            match self {
                Self::Success(r) => Self::Success(r.clone()),
                Self::RateLimited => Self::RateLimited,
                Self::InvalidRequest(s) => Self::InvalidRequest(s.clone()),
                Self::ServerError(s) => Self::ServerError(s.clone()),
            }
        }
    }

    // ===================
    // Mock Tool
    // ===================

    /// Observe-tier tool that replies `Echo: <message>`, or `Echo: no message` when
    /// the `message` field is missing or not a string.
    struct EchoTool;

    #[async_trait]
    impl Tool<()> for EchoTool {
        fn name(&self) -> &'static str {
            "echo"
        }

        fn description(&self) -> &'static str {
            "Echo the input message"
        }

        fn input_schema(&self) -> serde_json::Value {
            json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                },
                "required": ["message"]
            })
        }

        fn tier(&self) -> ToolTier {
            ToolTier::Observe
        }

        async fn execute(
            &self,
            _ctx: &ToolContext<()>,
            input: serde_json::Value,
        ) -> Result<ToolResult> {
            let message = input
                .get("message")
                .and_then(|v| v.as_str())
                .unwrap_or("no message");
            Ok(ToolResult::success(format!("Echo: {message}")))
        }
    }

    // ===================
    // Builder Tests
    // ===================

    #[test]
    fn test_builder_creates_agent_loop() {
        let provider = MockProvider::new(vec![]);
        let agent = builder::<()>().provider(provider).build();

        assert_eq!(agent.config.max_turns, 10);
        assert_eq!(agent.config.max_tokens, 4096);
    }

    #[test]
    fn test_builder_with_custom_config() {
        let provider = MockProvider::new(vec![]);
        let config = AgentConfig {
            max_turns: 5,
            max_tokens: 2048,
            system_prompt: "Custom prompt".to_string(),
            model: "custom-model".to_string(),
            ..Default::default()
        };

        let agent = builder::<()>().provider(provider).config(config).build();

        assert_eq!(agent.config.max_turns, 5);
        assert_eq!(agent.config.max_tokens, 2048);
        assert_eq!(agent.config.system_prompt, "Custom prompt");
        assert_eq!(agent.config.model, "custom-model");
    }

    #[test]
    fn test_builder_with_tools() {
        let provider = MockProvider::new(vec![]);
        let mut tools = ToolRegistry::new();
        tools.register(EchoTool);

        let agent = builder::<()>().provider(provider).tools(tools).build();

        assert_eq!(agent.tools.len(), 1);
    }

    #[test]
    fn test_builder_with_custom_stores() {
        let provider = MockProvider::new(vec![]);
        let message_store = InMemoryStore::new();
        let state_store = InMemoryStore::new();

        let agent = builder::<()>()
            .provider(provider)
            .hooks(AllowAllHooks)
            .message_store(message_store)
            .state_store(state_store)
            .build_with_stores();

        // Just verify it builds without panicking
        assert_eq!(agent.config.max_turns, 10);
    }

    // ===================
    // Run Loop Tests
    // ===================

    #[tokio::test]
    async fn test_simple_text_response() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![MockProvider::text_response("Hello, user!")]);

        let agent = builder::<()>().provider(provider).build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);
        let events = collect_events!(rx);

        // Should have: Start, Text, Done
        assert!(events.iter().any(|e| matches!(e, AgentEvent::Text { .. })));
        assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));

        Ok(())
    }

    #[tokio::test]
    async fn test_tool_execution() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![
            // First call: request tool use
            MockProvider::tool_use_response("tool_1", "echo", json!({"message": "test"})),
            // Second call: respond with text
            MockProvider::text_response("Tool executed successfully"),
        ]);

        let mut tools = ToolRegistry::new();
        tools.register(EchoTool);

        let agent = builder::<()>().provider(provider).tools(tools).build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let rx = agent.run(thread_id, "Run echo".to_string(), tool_ctx);
        let events = collect_events!(rx);

        // Should have tool call events
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::ToolCallStart { .. }))
        );
        assert!(
            events
                .iter()
                .any(|e| matches!(e, AgentEvent::ToolCallEnd { .. }))
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_max_turns_limit() -> anyhow::Result<()> {
        // Provider that requests a tool on every scripted call, more times than max_turns.
        let provider = MockProvider::new(vec![
            MockProvider::tool_use_response("tool_1", "echo", json!({"message": "1"})),
            MockProvider::tool_use_response("tool_2", "echo", json!({"message": "2"})),
            MockProvider::tool_use_response("tool_3", "echo", json!({"message": "3"})),
            MockProvider::tool_use_response("tool_4", "echo", json!({"message": "4"})),
        ]);

        let mut tools = ToolRegistry::new();
        tools.register(EchoTool);

        let config = AgentConfig {
            max_turns: 2,
            ..Default::default()
        };

        let agent = builder::<()>()
            .provider(provider)
            .tools(tools)
            .config(config)
            .build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let rx = agent.run(thread_id, "Loop".to_string(), tool_ctx);
        let events = collect_events!(rx);

        // Should have an error about max turns
        assert!(events.iter().any(|e| {
            matches!(e, AgentEvent::Error { message, .. } if message.contains("Maximum turns"))
        }));

        Ok(())
    }

    #[tokio::test]
    async fn test_unknown_tool_handling() -> anyhow::Result<()> {
        let provider = MockProvider::new(vec![
            // Request unknown tool
            MockProvider::tool_use_response("tool_1", "nonexistent_tool", json!({})),
            // LLM gets tool error and ends conversation
            MockProvider::text_response("I couldn't find that tool."),
        ]);

        // Empty tool registry
        let tools = ToolRegistry::new();

        let agent = builder::<()>().provider(provider).tools(tools).build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let rx = agent.run(thread_id, "Call unknown".to_string(), tool_ctx);
        let events = collect_events!(rx);

        // The unknown-tool error is fed back to the LLM, so the conversation should
        // still complete with a Done event.
        assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));

        // The LLM's response about the missing tool should be in the events
        assert!(
            events.iter().any(|e| {
                matches!(e, AgentEvent::Text { text } if text.contains("couldn't find"))
            })
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rate_limit_handling() -> anyhow::Result<()> {
        // Six RateLimited outcomes, intended to exhaust every attempt (max_retries + 1).
        // This assumes `RetryConfig::fast()` allows at most 5 retries (TODO confirm). A
        // larger budget would reach the mock's "Done" fallback and fail the error check.
        let provider = MockProvider::new(vec![
            ChatOutcome::RateLimited,
            ChatOutcome::RateLimited,
            ChatOutcome::RateLimited,
            ChatOutcome::RateLimited,
            ChatOutcome::RateLimited,
            ChatOutcome::RateLimited,
        ]);

        let agent = builder::<()>()
            .provider(provider)
            .config(fast_retry_config())
            .build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);
        let events = collect_events!(rx);

        // Should have rate limit error after exhausting retries
        assert!(events.iter().any(|e| {
            matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Rate limited"))
        }));

        // Should have retry text events
        assert!(
            events
                .iter()
                .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_rate_limit_recovery() -> anyhow::Result<()> {
        // Rate limited once, then succeeds
        let provider = MockProvider::new(vec![
            ChatOutcome::RateLimited,
            MockProvider::text_response("Recovered after rate limit"),
        ]);

        let agent = builder::<()>()
            .provider(provider)
            .config(fast_retry_config())
            .build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);
        let events = collect_events!(rx);

        // Should have successful completion after retry
        assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));

        // Should have retry text event
        assert!(
            events
                .iter()
                .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_server_error_handling() -> anyhow::Result<()> {
        // Six ServerError outcomes, intended to exhaust every attempt (max_retries + 1).
        // This assumes `RetryConfig::fast()` allows at most 5 retries (TODO confirm).
        let provider = MockProvider::new(vec![
            ChatOutcome::ServerError("Internal error".to_string()),
            ChatOutcome::ServerError("Internal error".to_string()),
            ChatOutcome::ServerError("Internal error".to_string()),
            ChatOutcome::ServerError("Internal error".to_string()),
            ChatOutcome::ServerError("Internal error".to_string()),
            ChatOutcome::ServerError("Internal error".to_string()),
        ]);

        let agent = builder::<()>()
            .provider(provider)
            .config(fast_retry_config())
            .build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);
        let events = collect_events!(rx);

        // Should have server error after exhausting retries
        assert!(events.iter().any(|e| {
            matches!(e, AgentEvent::Error { message, recoverable: true } if message.contains("Server error"))
        }));

        // Should have retry text events
        assert!(
            events
                .iter()
                .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_server_error_recovery() -> anyhow::Result<()> {
        // Server error once, then succeeds
        let provider = MockProvider::new(vec![
            ChatOutcome::ServerError("Temporary error".to_string()),
            MockProvider::text_response("Recovered after server error"),
        ]);

        let agent = builder::<()>()
            .provider(provider)
            .config(fast_retry_config())
            .build();

        let thread_id = ThreadId::new();
        let tool_ctx = ToolContext::new(());
        let rx = agent.run(thread_id, "Hi".to_string(), tool_ctx);
        let events = collect_events!(rx);

        // Should have successful completion after retry
        assert!(events.iter().any(|e| matches!(e, AgentEvent::Done { .. })));

        // Should have retry text event
        assert!(
            events
                .iter()
                .any(|e| { matches!(e, AgentEvent::Text { text } if text.contains("retrying")) })
        );

        Ok(())
    }

    // ===================
    // Helper Function Tests
    // ===================

    #[test]
    fn test_extract_content_text_only() {
        let response = response_with(vec![ContentBlock::Text {
            text: "Hello".to_string(),
        }]);

        let (text, tool_uses) = extract_content(&response);
        assert_eq!(text, Some("Hello".to_string()));
        assert!(tool_uses.is_empty());
    }

    #[test]
    fn test_extract_content_joins_multiple_text_blocks() {
        let response = response_with(vec![
            ContentBlock::Text {
                text: "first".to_string(),
            },
            ContentBlock::Text {
                text: "second".to_string(),
            },
        ]);

        let (text, tool_uses) = extract_content(&response);
        assert_eq!(text, Some("first\nsecond".to_string()));
        assert!(tool_uses.is_empty());
    }

    #[test]
    fn test_extract_content_empty_response() {
        let response = response_with(vec![]);

        let (text, tool_uses) = extract_content(&response);
        assert!(text.is_none());
        assert!(tool_uses.is_empty());
    }

    #[test]
    fn test_extract_content_tool_use() {
        let response = response_with(vec![ContentBlock::ToolUse {
            id: "tool_1".to_string(),
            name: "test_tool".to_string(),
            input: json!({"key": "value"}),
            thought_signature: None,
        }]);

        let (text, tool_uses) = extract_content(&response);
        assert!(text.is_none());
        assert_eq!(tool_uses.len(), 1);
        assert_eq!(tool_uses[0].0, "tool_1");
        assert_eq!(tool_uses[0].1, "test_tool");
        assert_eq!(tool_uses[0].2, json!({"key": "value"}));
    }

    #[test]
    fn test_extract_content_mixed() {
        let response = response_with(vec![
            ContentBlock::Text {
                text: "Let me help".to_string(),
            },
            ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "helper".to_string(),
                input: json!({}),
                thought_signature: None,
            },
        ]);

        let (text, tool_uses) = extract_content(&response);
        assert_eq!(text, Some("Let me help".to_string()));
        assert_eq!(tool_uses.len(), 1);
    }

    #[test]
    fn test_millis_to_u64() {
        assert_eq!(millis_to_u64(0), 0);
        assert_eq!(millis_to_u64(1000), 1000);
        assert_eq!(millis_to_u64(u128::from(u64::MAX)), u64::MAX);
        assert_eq!(millis_to_u64(u128::from(u64::MAX) + 1), u64::MAX);
    }

    #[test]
    fn test_build_assistant_message() {
        let response = response_with(vec![
            ContentBlock::Text {
                text: "Response text".to_string(),
            },
            ContentBlock::ToolUse {
                id: "tool_1".to_string(),
                name: "echo".to_string(),
                input: json!({"message": "test"}),
                thought_signature: None,
            },
        ]);

        let msg = build_assistant_message(&response);
        assert_eq!(msg.role, Role::Assistant);

        if let Content::Blocks(blocks) = msg.content {
            assert_eq!(blocks.len(), 2);
            // Block order from the response must be preserved.
            assert!(matches!(blocks[0], ContentBlock::Text { .. }));
            assert!(matches!(blocks[1], ContentBlock::ToolUse { .. }));
        } else {
            panic!("Expected Content::Blocks");
        }
    }
}