Skip to main content

aether_core/core/
agent.rs

1use crate::context::{CompactionConfig, Compactor, TokenTracker};
2use crate::core::PromptCache;
3pub use crate::core::retry_config::RetryConfig;
4use crate::events::{AgentCommand, AgentMessage, Command, UserCommand};
5use crate::mcp::run_mcp_task::{McpCommand, ToolExecutionEvent};
6use futures::Stream;
7use llm::types::IsoString;
8use llm::{
9    AssistantReasoning, ChatMessage, Context, EncryptedReasoningContent, LlmError, LlmResponse, StopReason,
10    StreamingModelProvider, TokenUsage, ToolCallError, ToolCallRequest, ToolCallResult,
11};
12use std::collections::{HashMap, HashSet, VecDeque};
13use std::pin::Pin;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::mpsc;
17use tokio::time::sleep;
18use tokio_stream::StreamExt;
19use tokio_stream::StreamMap;
20use tokio_stream::wrappers::ReceiverStream;
21
22/// Internal event type for merging LLM and tool result streams
23#[derive(Debug)]
24enum StreamEvent {
25    Llm(Result<LlmResponse, LlmError>),
26    ToolExecution(ToolExecutionEvent),
27    Command(Command),
28}
29
30type EventStream = Pin<Box<dyn Stream<Item = StreamEvent> + Send>>;
31
32const USER_STREAM_KEY: &str = "user";
33const LLM_STREAM_KEY: &str = "llm";
34
35pub(crate) struct AgentConfig {
36    pub llm: Arc<dyn StreamingModelProvider>,
37    pub context: Context,
38    pub mcp_command_tx: Option<mpsc::Sender<McpCommand>>,
39    pub tool_timeout: Duration,
40    pub compaction_config: Option<CompactionConfig>,
41    pub auto_continue: AutoContinue,
42    pub retry_config: RetryConfig,
43    pub context_window: Option<u32>,
44    pub prompt_cache: PromptCache,
45}
46
47pub struct Agent {
48    llm: Arc<dyn StreamingModelProvider>,
49    context: Context,
50    mcp_command_tx: Option<mpsc::Sender<McpCommand>>,
51    message_tx: mpsc::Sender<AgentMessage>,
52    streams: StreamMap<String, EventStream>,
53    tool_timeout: Duration,
54    token_tracker: TokenTracker,
55    compaction_config: Option<CompactionConfig>,
56    auto_continue: AutoContinue,
57    retry_config: RetryConfig,
58    active_requests: HashMap<String, ToolCallRequest>,
59    queued_user_messages: VecDeque<Vec<llm::ContentBlock>>,
60    context_window: Option<u32>,
61    prompt_cache: PromptCache,
62}
63
64impl Agent {
65    pub(crate) fn new(
66        config: AgentConfig,
67        command_rx: mpsc::Receiver<Command>,
68        message_tx: mpsc::Sender<AgentMessage>,
69    ) -> Self {
70        let mut streams: StreamMap<String, EventStream> = StreamMap::new();
71        streams
72            .insert(USER_STREAM_KEY.to_string(), Box::pin(ReceiverStream::new(command_rx).map(StreamEvent::Command)));
73
74        let context_limit = config.context_window.or_else(|| config.llm.context_window());
75
76        Self {
77            llm: config.llm,
78            context: config.context,
79            mcp_command_tx: config.mcp_command_tx,
80            message_tx,
81            streams,
82            tool_timeout: config.tool_timeout,
83            token_tracker: TokenTracker::new(context_limit),
84            compaction_config: config.compaction_config,
85            auto_continue: config.auto_continue,
86            retry_config: config.retry_config,
87            active_requests: HashMap::new(),
88            queued_user_messages: VecDeque::new(),
89            context_window: config.context_window,
90            prompt_cache: config.prompt_cache,
91        }
92    }
93
94    pub fn current_model_display_name(&self) -> String {
95        self.llm.display_name()
96    }
97
98    /// Get a reference to the token tracker
99    pub fn token_tracker(&self) -> &TokenTracker {
100        &self.token_tracker
101    }
102
103    pub async fn run(mut self) {
104        let mut state = IterationState::new();
105
106        while let Some((_, event)) = self.streams.next().await {
107            match event {
108                StreamEvent::Command(Command::UserCommand(UserCommand::Cancel)) => {
109                    self.on_user_cancel(&mut state).await;
110                }
111
112                StreamEvent::Command(Command::UserCommand(UserCommand::ClearContext)) => {
113                    self.on_user_clear_context(&mut state).await;
114                }
115
116                StreamEvent::Command(Command::UserCommand(UserCommand::Text { content })) => {
117                    if self.is_busy() {
118                        self.queued_user_messages.push_back(content);
119                    } else {
120                        state = IterationState::new();
121                        self.on_user_text(content);
122                    }
123                }
124
125                StreamEvent::Command(Command::AgentCommand(AgentCommand::SwitchModel(new_provider))) => {
126                    self.on_switch_model(new_provider).await;
127                }
128
129                StreamEvent::Command(Command::AgentCommand(AgentCommand::UpdateTools(tools))) => {
130                    self.context.set_tools(tools);
131                }
132
133                StreamEvent::Command(Command::AgentCommand(AgentCommand::UpdateMcpInstructions { server, body })) => {
134                    self.on_update_instruction(server, body).await;
135                }
136
137                StreamEvent::Command(Command::AgentCommand(AgentCommand::SetReasoningEffort(effort))) => {
138                    self.context.set_reasoning_effort(effort);
139                }
140
141                StreamEvent::Command(Command::AgentCommand(AgentCommand::ReplaceConversation(messages))) => {
142                    self.on_replace_conversation(messages, &mut state).await;
143                }
144
145                StreamEvent::Llm(llm_event) => {
146                    self.on_llm_event(llm_event, &mut state).await;
147                }
148
149                StreamEvent::ToolExecution(tool_event) => {
150                    self.on_tool_execution_event(tool_event, &mut state).await;
151                }
152            }
153
154            if state.is_complete() {
155                let Some(id) = state.current_message_id.take() else {
156                    continue;
157                };
158                let iteration = std::mem::replace(&mut state, IterationState::new());
159                self.on_iteration_complete(id, iteration).await;
160            }
161        }
162
163        tracing::debug!("Agent task shutting down - input channel closed");
164    }
165
166    async fn on_iteration_complete(&mut self, id: String, iteration: IterationState) {
167        let IterationState {
168            message_content,
169            reasoning_summary_text,
170            encrypted_reasoning,
171            completed_tool_calls,
172            stop_reason,
173            ..
174        } = iteration;
175        let has_tool_calls = !completed_tool_calls.is_empty();
176        let has_content = !message_content.is_empty() || has_tool_calls;
177        let should_auto_continue = self.auto_continue.should_continue(stop_reason.as_ref());
178
179        if has_content {
180            let reasoning = AssistantReasoning::from_parts(reasoning_summary_text.clone(), encrypted_reasoning);
181            self.update_context(&message_content, reasoning, completed_tool_calls);
182
183            let _ = self
184                .message_tx
185                .send(AgentMessage::Text {
186                    message_id: id.clone(),
187                    chunk: message_content.clone(),
188                    is_complete: true,
189                    model_name: self.llm.display_name(),
190                })
191                .await;
192
193            if !reasoning_summary_text.is_empty() {
194                let _ = self
195                    .message_tx
196                    .send(AgentMessage::Thought {
197                        message_id: id.clone(),
198                        chunk: reasoning_summary_text,
199                        is_complete: true,
200                        model_name: self.llm.display_name(),
201                    })
202                    .await;
203            }
204        }
205
206        let has_queued_text = !self.queued_user_messages.is_empty();
207        if has_queued_text {
208            let content: Vec<_> = self.queued_user_messages.drain(..).flatten().collect();
209            self.context.add_message(ChatMessage::User { content, timestamp: IsoString::now() });
210        }
211
212        if has_queued_text || has_tool_calls {
213            self.auto_continue.reset();
214            self.start_next_turn().await;
215        } else if should_auto_continue {
216            self.auto_continue.advance();
217            tracing::info!(
218                "LLM stopped with {:?}, auto-continuing (attempt {}/{})",
219                stop_reason,
220                self.auto_continue.count(),
221                self.auto_continue.max()
222            );
223
224            let _ = self
225                .message_tx
226                .send(AgentMessage::AutoContinue {
227                    attempt: self.auto_continue.count(),
228                    max_attempts: self.auto_continue.max(),
229                })
230                .await;
231
232            self.inject_continuation_prompt(&message_content, stop_reason.as_ref());
233            self.start_next_turn().await;
234        } else {
235            tracing::debug!("LLM completed turn with stop reason: {:?}", stop_reason);
236            self.auto_continue.reset();
237            if let Err(e) = self.message_tx.send(AgentMessage::Done).await {
238                tracing::warn!("Failed to send Done message: {:?}", e);
239            }
240        }
241    }
242
243    async fn start_next_turn(&mut self) {
244        self.maybe_preflight_compact().await;
245        self.start_llm_stream(None);
246    }
247
248    async fn on_user_cancel(&mut self, state: &mut IterationState) {
249        self.abort_in_flight_work();
250        *state = IterationState::new();
251        let _ = self.message_tx.send(AgentMessage::Cancelled { message: "Processing cancelled".to_string() }).await;
252        let _ = self.message_tx.send(AgentMessage::Done).await;
253    }
254
255    async fn on_user_clear_context(&mut self, state: &mut IterationState) {
256        self.abort_in_flight_work();
257        self.context.clear_conversation();
258        self.token_tracker.reset_current_usage();
259        self.auto_continue.reset();
260        *state = IterationState::new();
261
262        let _ = self.message_tx.send(AgentMessage::ContextCleared).await;
263    }
264
265    async fn on_replace_conversation(&mut self, messages: Vec<ChatMessage>, state: &mut IterationState) {
266        self.abort_in_flight_work();
267        self.context.replace_conversation(messages);
268        self.auto_continue.reset();
269        *state = IterationState::new();
270        let _ = self.message_tx.send(self.context_usage_message()).await;
271    }
272
273    fn on_user_text(&mut self, content: Vec<llm::ContentBlock>) {
274        self.context.add_message(ChatMessage::User { content, timestamp: IsoString::now() });
275        self.auto_continue.reset();
276        self.start_llm_stream(None);
277    }
278
279    async fn on_update_instruction(&mut self, server: String, body: Option<String>) {
280        self.prompt_cache.update_mcp_instruction(server, body);
281        match self.prompt_cache.render().await {
282            Ok(content) => self.context.set_system_content(content),
283            Err(e) => tracing::warn!("Failed to rebuild system prompt after instructions update: {e}"),
284        }
285    }
286
287    async fn on_switch_model(&mut self, new_provider: Box<dyn StreamingModelProvider>) {
288        let previous = self.llm.display_name();
289        let new_context_limit = self.context_window.or_else(|| new_provider.context_window());
290        self.llm = Arc::from(new_provider);
291        self.token_tracker.reset_current_usage();
292        self.token_tracker.set_context_limit(new_context_limit);
293        let new = self.llm.display_name();
294        let _ = self.message_tx.send(AgentMessage::ModelSwitched { previous, new }).await;
295
296        let _ = self.message_tx.send(self.context_usage_message()).await;
297    }
298
299    fn start_llm_stream(&mut self, delay: Option<Duration>) {
300        self.streams.remove(LLM_STREAM_KEY);
301        let stream: EventStream = match delay {
302            None => Box::pin(self.llm.stream_response(&self.context).map(StreamEvent::Llm)),
303            Some(delay) => {
304                let llm = Arc::clone(&self.llm);
305                let context = self.context.clone();
306                Box::pin(async_stream::stream! {
307                    sleep(delay).await;
308                    let mut inner = llm.stream_response(&context);
309                    while let Some(item) = inner.next().await {
310                        yield StreamEvent::Llm(item);
311                    }
312                })
313            }
314        };
315        self.streams.insert(LLM_STREAM_KEY.to_string(), stream);
316    }
317
318    async fn on_llm_error(&mut self, error: LlmError, state: &mut IterationState) {
319        if !error.is_retryable() || state.retry_attempt >= self.retry_config.max_attempts {
320            let _ = self.message_tx.send(AgentMessage::Error { message: error.to_string() }).await;
321            return;
322        }
323
324        state.retry_attempt += 1;
325        let delay = self.retry_config.compute_delay(state.retry_attempt);
326        let delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX);
327
328        tracing::warn!(
329            attempt = state.retry_attempt,
330            max_attempts = self.retry_config.max_attempts,
331            delay_ms,
332            error = %error,
333            "Retrying LLM request after transient failure"
334        );
335
336        let _ = self
337            .message_tx
338            .send(AgentMessage::Retrying {
339                attempt: state.retry_attempt,
340                max_attempts: self.retry_config.max_attempts,
341                delay_ms,
342                error: error.to_string(),
343            })
344            .await;
345
346        // The previous stream may have emitted partial tool-call deltas
347        // before interrupting so we drop them to ensure we rebuild tool state
348        self.active_requests.clear();
349        self.start_llm_stream(Some(delay));
350    }
351
352    fn is_busy(&self) -> bool {
353        self.streams.contains_key(LLM_STREAM_KEY) || !self.active_requests.is_empty()
354    }
355
356    fn abort_in_flight_work(&mut self) {
357        self.streams.remove(LLM_STREAM_KEY);
358        for stream_key in self.active_requests.keys().cloned().collect::<Vec<_>>() {
359            self.streams.remove(&stream_key);
360        }
361        self.active_requests.clear();
362        self.queued_user_messages.clear();
363    }
364
365    /// Inject a continuation prompt when the LLM stops due to a resumable reason.
366    fn inject_continuation_prompt(&mut self, previous_response: &str, stop_reason: Option<&StopReason>) {
367        if !previous_response.is_empty() {
368            self.context.add_message(ChatMessage::Assistant {
369                content: previous_response.to_string(),
370                reasoning: AssistantReasoning::default(),
371                timestamp: IsoString::now(),
372                tool_calls: Vec::new(),
373            });
374        }
375
376        let reason = stop_reason.map_or_else(|| "Unknown".to_string(), |reason| format!("{reason:?}"));
377
378        self.context.add_message(ChatMessage::User {
379            content: vec![llm::ContentBlock::text(format!(
380                "<system-notification>The LLM API stopped with reason '{reason}'. Continue from where you left off and finish your task.</system-notification>"
381            ))],
382            timestamp: IsoString::now(),
383        });
384    }
385
386    async fn on_llm_event(&mut self, result: Result<LlmResponse, LlmError>, state: &mut IterationState) {
387        use LlmResponse::{
388            Done, EncryptedReasoning, Error, Reasoning, Start, Text, ToolRequestArg, ToolRequestComplete,
389            ToolRequestStart, Usage,
390        };
391
392        let response = match result {
393            Ok(response) => response,
394            Err(e) => {
395                self.on_llm_error(e, state).await;
396                return;
397            }
398        };
399
400        match response {
401            Start { message_id } => {
402                state.on_llm_start(message_id);
403            }
404
405            Text { chunk } => {
406                self.handle_llm_text(chunk, state).await;
407            }
408
409            Reasoning { chunk } => {
410                state.reasoning_summary_text.push_str(&chunk);
411                if let Some(id) = &state.current_message_id {
412                    let _ = self
413                        .message_tx
414                        .send(AgentMessage::Thought {
415                            message_id: id.clone(),
416                            chunk,
417                            is_complete: false,
418                            model_name: self.llm.display_name(),
419                        })
420                        .await;
421                }
422            }
423
424            EncryptedReasoning { id, content } => {
425                if let Some(model) = self.llm.model() {
426                    state.encrypted_reasoning = Some(EncryptedReasoningContent { id, model, content });
427                }
428            }
429
430            ToolRequestStart { id, name } => {
431                self.handle_tool_request_start(id, name).await;
432            }
433
434            ToolRequestArg { id, chunk } => {
435                self.handle_tool_request_arg(id, chunk).await;
436            }
437
438            ToolRequestComplete { tool_call } => {
439                self.handle_tool_completion(tool_call, state).await;
440            }
441
442            Done { stop_reason } => {
443                state.llm_done = true;
444                state.stop_reason = stop_reason;
445            }
446
447            Error { message } => {
448                let _ = self.message_tx.send(AgentMessage::Error { message }).await;
449            }
450
451            Usage { tokens: sample } => {
452                self.handle_llm_usage(sample).await;
453            }
454        }
455    }
456
457    async fn handle_llm_text(&mut self, chunk: String, state: &mut IterationState) {
458        state.message_content.push_str(&chunk);
459
460        if let Some(id) = &state.current_message_id {
461            let _ = self
462                .message_tx
463                .send(AgentMessage::Text {
464                    message_id: id.clone(),
465                    chunk,
466                    is_complete: false,
467                    model_name: self.llm.display_name(),
468                })
469                .await;
470        }
471    }
472
473    async fn handle_tool_request_start(&mut self, id: String, name: String) {
474        let request = ToolCallRequest { id: id.clone(), name, arguments: String::new() };
475        self.active_requests.insert(id, request.clone());
476
477        let _ = self.message_tx.send(AgentMessage::ToolCall { request, model_name: self.llm.display_name() }).await;
478    }
479
480    async fn handle_tool_request_arg(&mut self, id: String, chunk: String) {
481        let Some(request) = self.active_requests.get_mut(&id) else {
482            return;
483        };
484        request.arguments.push_str(&chunk);
485
486        let _ = self
487            .message_tx
488            .send(AgentMessage::ToolCallUpdate { tool_call_id: id, chunk, model_name: self.llm.display_name() })
489            .await;
490    }
491
492    async fn handle_tool_completion(&mut self, tool_call: ToolCallRequest, state: &mut IterationState) {
493        state.pending_tool_ids.insert(tool_call.id.clone());
494        debug_assert!(
495            self.active_requests.contains_key(&tool_call.id),
496            "tool call {} should already be in active_requests from handle_tool_request_start",
497            tool_call.id
498        );
499
500        let (tx, rx) = mpsc::channel(100);
501        let stream = ReceiverStream::new(rx).map(StreamEvent::ToolExecution);
502        let stream_key = tool_call.id.clone();
503        self.streams.insert(stream_key, Box::pin(stream));
504
505        if let Some(ref mcp_command_tx) = self.mcp_command_tx {
506            let mcp_future =
507                mcp_command_tx.send(McpCommand::ExecuteTool { request: tool_call, timeout: self.tool_timeout, tx });
508            if let Err(e) = mcp_future.await {
509                tracing::warn!("Failed to send tool request to MCP task: {:?}", e);
510            }
511        }
512    }
513
514    async fn handle_llm_usage(&mut self, sample: TokenUsage) {
515        self.token_tracker.record_usage(sample);
516        let ratio_pct = self.token_tracker.usage_ratio().map(|r| r * 100.0);
517        let remaining = self.token_tracker.tokens_remaining();
518        tracing::debug!(?sample, ?ratio_pct, ?remaining, "Token usage");
519
520        let _ = self.message_tx.send(self.context_usage_message()).await;
521
522        self.maybe_compact_context().await;
523    }
524
525    fn context_usage_message(&self) -> AgentMessage {
526        let last = self.token_tracker.last_usage();
527        AgentMessage::ContextUsageUpdate {
528            usage_ratio: self.token_tracker.usage_ratio(),
529            context_limit: self.token_tracker.context_limit(),
530            input_tokens: last.input_tokens,
531            output_tokens: last.output_tokens,
532            cache_read_tokens: last.cache_read_tokens,
533            cache_creation_tokens: last.cache_creation_tokens,
534            reasoning_tokens: last.reasoning_tokens,
535            total_input_tokens: self.token_tracker.total_input_tokens(),
536            total_output_tokens: self.token_tracker.total_output_tokens(),
537            total_cache_read_tokens: self.token_tracker.total_cache_read_tokens(),
538            total_cache_creation_tokens: self.token_tracker.total_cache_creation_tokens(),
539            total_reasoning_tokens: self.token_tracker.total_reasoning_tokens(),
540        }
541    }
542
543    /// Pre-flight check: estimate context size and compact proactively if it would
544    /// overflow before the LLM even sees it. This catches the case where large tool
545    /// results push context past the limit before usage-based compaction can fire.
546    async fn maybe_preflight_compact(&mut self) {
547        let Some(context_limit) = self.token_tracker.context_limit() else {
548            return;
549        };
550        let Some(config) = self.compaction_config.as_ref() else {
551            return;
552        };
553        let estimated = self.context.estimated_token_count();
554        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
555        let threshold = (f64::from(context_limit) * config.threshold).ceil() as u32;
556        if estimated >= threshold {
557            tracing::info!(
558                "Pre-flight compaction triggered: estimated {estimated} tokens >= {:.1}% of {context_limit} limit",
559                config.threshold * 100.0
560            );
561            if let CompactionOutcome::Failed(e) = self.compact_context().await {
562                tracing::warn!("Pre-flight compaction failed: {e}");
563            }
564        }
565    }
566
567    /// Check if compaction is needed and perform it if so.
568    async fn maybe_compact_context(&mut self) {
569        if !self.compaction_config.as_ref().is_some_and(|config| self.token_tracker.should_compact(config.threshold)) {
570            return;
571        }
572
573        if let CompactionOutcome::Failed(error_message) = self.compact_context().await {
574            tracing::warn!("Context compaction failed: {}", error_message);
575        }
576    }
577
578    async fn compact_context(&mut self) -> CompactionOutcome {
579        let Some(ref _config) = self.compaction_config else {
580            tracing::warn!("Context compaction requested but compaction is disabled");
581            return CompactionOutcome::SkippedDisabled;
582        };
583
584        match self.token_tracker.usage_ratio() {
585            Some(usage_ratio) => {
586                tracing::info!(
587                    "Starting context compaction - {} messages, {:.1}% of context limit",
588                    self.context.message_count(),
589                    usage_ratio * 100.0
590                );
591            }
592            None => {
593                tracing::info!(
594                    "Starting context compaction - {} messages (context limit unknown)",
595                    self.context.message_count(),
596                );
597            }
598        }
599
600        let _ = self
601            .message_tx
602            .send(AgentMessage::ContextCompactionStarted { message_count: self.context.message_count() })
603            .await;
604
605        let compactor = Compactor::new(self.llm.clone());
606
607        match compactor.compact(&self.context).await {
608            Ok(result) => {
609                tracing::info!("Context compacted: {} messages removed", result.messages_removed);
610
611                self.context = result.context;
612                self.token_tracker.reset_current_usage();
613
614                let _ = self
615                    .message_tx
616                    .send(AgentMessage::ContextCompactionResult {
617                        summary: result.summary,
618                        messages_removed: result.messages_removed,
619                    })
620                    .await;
621                CompactionOutcome::Compacted
622            }
623            Err(e) => CompactionOutcome::Failed(e.to_string()),
624        }
625    }
626
627    async fn on_tool_execution_event(&mut self, event: ToolExecutionEvent, state: &mut IterationState) {
628        match event {
629            ToolExecutionEvent::Started { tool_id, tool_name } => {
630                tracing::debug!("Tool execution started: {} ({})", tool_name, tool_id);
631            }
632
633            ToolExecutionEvent::Progress { tool_id, progress } => {
634                tracing::debug!(
635                    "Tool progress for {}: {}/{}",
636                    tool_id,
637                    progress.progress,
638                    progress.total.unwrap_or(0.0)
639                );
640
641                if let Some(request) = self.active_requests.get(&tool_id) {
642                    let _ = self
643                        .message_tx
644                        .send(AgentMessage::ToolProgress {
645                            request: request.clone(),
646                            progress: progress.progress,
647                            total: progress.total,
648                            message: progress.message.clone(),
649                        })
650                        .await;
651                }
652            }
653
654            ToolExecutionEvent::Complete { tool_id: _, result, result_meta } => match result {
655                Ok(tool_result) => {
656                    tracing::debug!("Tool result received: {} -> {}", tool_result.name, tool_result.result.len());
657
658                    if state.pending_tool_ids.remove(&tool_result.id) {
659                        self.active_requests.remove(&tool_result.id);
660                        state.completed_tool_calls.push(Ok(tool_result.clone()));
661
662                        let msg = AgentMessage::ToolResult {
663                            result: tool_result,
664                            result_meta,
665                            model_name: self.llm.display_name(),
666                        };
667
668                        if let Err(e) = self.message_tx.send(msg).await {
669                            tracing::warn!("Failed to send ToolCall completion message: {:?}", e);
670                        }
671                    } else {
672                        tracing::debug!("Ignoring stale tool result for id: {}", tool_result.id);
673                    }
674                }
675
676                Err(tool_error) => {
677                    if state.pending_tool_ids.remove(&tool_error.id) {
678                        self.active_requests.remove(&tool_error.id);
679                        state.completed_tool_calls.push(Err(tool_error.clone()));
680
681                        let _ = self
682                            .message_tx
683                            .send(AgentMessage::ToolError { error: tool_error, model_name: self.llm.display_name() })
684                            .await;
685                    }
686                }
687            },
688        }
689    }
690
691    fn update_context(
692        &mut self,
693        message_content: &str,
694        reasoning: AssistantReasoning,
695        completed_tools: Vec<Result<ToolCallResult, ToolCallError>>,
696    ) {
697        self.context.push_assistant_turn(message_content, reasoning, completed_tools);
698    }
699}
700
/// Result of a context compaction attempt.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CompactionOutcome {
    /// The context was replaced with its compacted form.
    Compacted,
    /// Nothing was done because no compaction config is set.
    SkippedDisabled,
    /// The compactor returned an error; carries the error's message.
    Failed(String),
}
707
708pub(crate) struct AutoContinue {
709    max: u32,
710    count: u32,
711}
712
713impl AutoContinue {
714    pub(crate) fn new(max: u32) -> Self {
715        Self { max, count: 0 }
716    }
717
718    fn reset(&mut self) {
719        self.count = 0;
720    }
721
722    fn should_continue(&self, stop_reason: Option<&StopReason>) -> bool {
723        matches!(stop_reason, Some(StopReason::Length)) && self.count < self.max
724    }
725
726    fn advance(&mut self) {
727        self.count += 1;
728    }
729
730    fn count(&self) -> u32 {
731        self.count
732    }
733
734    fn max(&self) -> u32 {
735        self.max
736    }
737}
738
739#[derive(Debug)]
740struct IterationState {
741    current_message_id: Option<String>,
742    message_content: String,
743    reasoning_summary_text: String,
744    encrypted_reasoning: Option<EncryptedReasoningContent>,
745    pending_tool_ids: HashSet<String>,
746    completed_tool_calls: Vec<Result<ToolCallResult, ToolCallError>>,
747    llm_done: bool,
748    stop_reason: Option<StopReason>,
749    retry_attempt: u32,
750}
751
752impl IterationState {
753    fn new() -> Self {
754        Self {
755            current_message_id: None,
756            message_content: String::new(),
757            reasoning_summary_text: String::new(),
758            encrypted_reasoning: None,
759            pending_tool_ids: HashSet::new(),
760            completed_tool_calls: Vec::new(),
761            llm_done: false,
762            stop_reason: None,
763            retry_attempt: 0,
764        }
765    }
766
767    fn on_llm_start(&mut self, message_id: String) {
768        self.current_message_id = Some(message_id);
769        self.message_content.clear();
770        self.reasoning_summary_text.clear();
771        self.encrypted_reasoning = None;
772        self.stop_reason = None;
773    }
774
775    fn is_complete(&self) -> bool {
776        self.llm_done && self.pending_tool_ids.is_empty()
777    }
778}
779
780#[cfg(test)]
781mod tests {
782    use crate::core::{AgentBuilder, Prompt};
783
784    use super::*;
785    use llm::{ContentBlock, testing::FakeLlmProvider};
786    use tokio::sync::mpsc;
787
788    #[tokio::test]
789    async fn replace_conversation_preserves_system_prompt_for_next_request() {
790        let llm = FakeLlmProvider::with_single_response(vec![LlmResponse::start("msg"), LlmResponse::done()]);
791
792        let captured_contexts = llm.captured_contexts();
793        let (tx, mut rx, handle) =
794            AgentBuilder::new(Arc::new(llm)).system_prompt(Prompt::text("original system")).spawn().await.unwrap();
795
796        tx.send(Command::AgentCommand(AgentCommand::ReplaceConversation(vec![
797            ChatMessage::User { content: vec![ContentBlock::text("old user")], timestamp: IsoString::now() },
798            ChatMessage::Assistant {
799                content: "old assistant".to_string(),
800                reasoning: AssistantReasoning::default(),
801                timestamp: IsoString::now(),
802                tool_calls: vec![],
803            },
804        ])))
805        .await
806        .unwrap();
807
808        tx.send(Command::UserCommand(UserCommand::Text { content: vec![ContentBlock::text("new user")] }))
809            .await
810            .unwrap();
811
812        while let Some(message) = rx.recv().await {
813            if matches!(message, AgentMessage::Done) {
814                break;
815            }
816        }
817
818        let contexts = captured_contexts.lock().unwrap();
819        let messages = contexts.last().expect("provider should receive a context").messages();
820        assert!(matches!(messages[0], ChatMessage::System { ref content, .. } if content == "original system"));
821        assert!(
822            matches!(messages[1], ChatMessage::User { ref content, .. } if content == &vec![llm::ContentBlock::text("old user")])
823        );
824        assert!(matches!(messages[2], ChatMessage::Assistant { ref content, .. } if content == "old assistant"));
825        assert!(
826            matches!(messages[3], ChatMessage::User { ref content, .. } if content == &vec![llm::ContentBlock::text("new user")])
827        );
828        handle.abort();
829    }
830
831    #[tokio::test]
832    async fn replace_conversation_preserves_token_usage() {
833        let llm = FakeLlmProvider::new(vec![vec![
834            LlmResponse::start("msg"),
835            LlmResponse::usage(800, 10),
836            LlmResponse::done(),
837        ]])
838        .with_context_window(Some(1000));
839        let (tx, mut rx, handle) = AgentBuilder::new(Arc::new(llm)).spawn().await.unwrap();
840
841        tx.send(Command::UserCommand(UserCommand::Text { content: vec![llm::ContentBlock::text("first user")] }))
842            .await
843            .unwrap();
844
845        while let Some(message) = rx.recv().await {
846            if matches!(message, AgentMessage::Done) {
847                break;
848            }
849        }
850
851        tx.send(Command::AgentCommand(AgentCommand::ReplaceConversation(vec![ChatMessage::User {
852            content: vec![ContentBlock::text("replacement user")],
853            timestamp: IsoString::now(),
854        }])))
855        .await
856        .unwrap();
857
858        let Some(AgentMessage::ContextUsageUpdate { input_tokens, usage_ratio, .. }) = rx.recv().await else {
859            panic!("expected context usage update after conversation replacement");
860        };
861
862        assert_eq!(input_tokens, 800);
863        assert_eq!(usage_ratio, Some(0.8));
864        handle.abort();
865    }
866
867    #[tokio::test]
868    async fn test_preflight_compaction_uses_configured_threshold() {
869        let llm = Arc::new(
870            FakeLlmProvider::with_single_response(vec![
871                LlmResponse::start("summary"),
872                LlmResponse::text("summary"),
873                LlmResponse::done(),
874            ])
875            .with_context_window(Some(100)),
876        );
877        let context = Context::new(
878            vec![ChatMessage::User {
879                content: vec![llm::ContentBlock::text("x".repeat(344))],
880                timestamp: IsoString::now(),
881            }],
882            vec![],
883        );
884        let (user_tx, user_rx) = mpsc::channel(1);
885        let (message_tx, _message_rx) = mpsc::channel(8);
886        drop(user_tx);
887
888        let mut agent = Agent::new(
889            AgentConfig {
890                llm,
891                context,
892                mcp_command_tx: None,
893                tool_timeout: Duration::from_secs(1),
894                compaction_config: Some(CompactionConfig::with_threshold(0.85)),
895                auto_continue: AutoContinue::new(0),
896                retry_config: RetryConfig::disabled(),
897                context_window: None,
898                prompt_cache: PromptCache::new(vec![]),
899            },
900            user_rx,
901            message_tx,
902        );
903
904        agent.maybe_preflight_compact().await;
905
906        assert!(
907            matches!(
908                agent.context.messages().as_slice(),
909                [ChatMessage::Summary { content, .. }] if content == "summary"
910            ),
911            "expected context to be compacted, got {:?}",
912            agent.context.messages()
913        );
914    }
915}