Skip to main content

aether_core/core/
agent.rs

1use crate::context::{CompactionConfig, Compactor, TokenTracker};
2pub use crate::core::retry_config::RetryConfig;
3use crate::events::{AgentMessage, UserMessage};
4use crate::mcp::run_mcp_task::{McpCommand, ToolExecutionEvent};
5use futures::Stream;
6use llm::types::IsoString;
7use llm::{
8    AssistantReasoning, ChatMessage, Context, EncryptedReasoningContent, LlmError, LlmResponse, StopReason,
9    StreamingModelProvider, TokenUsage, ToolCallError, ToolCallRequest, ToolCallResult,
10};
11use std::collections::{HashMap, HashSet, VecDeque};
12use std::pin::Pin;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::mpsc;
16use tokio::time::sleep;
17use tokio_stream::StreamExt;
18use tokio_stream::StreamMap;
19use tokio_stream::wrappers::ReceiverStream;
20
21/// Internal event type for merging LLM and tool result streams
22#[derive(Debug)]
23enum StreamEvent {
24    Llm(Result<LlmResponse, LlmError>),
25    ToolExecution(ToolExecutionEvent),
26    UserMessage(UserMessage),
27}
28
29type EventStream = Pin<Box<dyn Stream<Item = StreamEvent> + Send>>;
30
/// Stream-map key of the LLM response stream (at most one is registered at a time).
const LLM_STREAM_KEY: &str = "llm";
/// Stream-map key of the user-input stream registered in `Agent::new`.
const USER_STREAM_KEY: &str = "user";
33
/// Everything needed to construct an [`Agent`] apart from its input/output channels.
pub(crate) struct AgentConfig {
    /// Model provider used for responses (and, by the agent, for context compaction).
    pub llm: Arc<dyn StreamingModelProvider>,
    /// Initial conversation context.
    pub context: Context,
    /// Channel to the MCP task that executes tool calls; tool requests are only dispatched
    /// when this is `Some`.
    pub mcp_command_tx: Option<mpsc::Sender<McpCommand>>,
    /// Timeout forwarded with every tool execution request.
    pub tool_timeout: Duration,
    /// Compaction settings; `None` disables context compaction.
    pub compaction_config: Option<CompactionConfig>,
    /// Budget for automatically continuing after the LLM stops with `StopReason::Length`.
    pub auto_continue: AutoContinue,
    /// Retry policy applied to retryable LLM errors.
    pub retry_config: RetryConfig,
    /// Explicit context-window override; when `None` the provider's reported window is used.
    pub context_window: Option<u32>,
}
44
/// Event-driven conversation agent.
///
/// [`Agent::run`] multiplexes user input, the current LLM response stream and one stream per
/// executing tool call. It keeps the conversation [`Context`] and token accounting up to date,
/// and reports progress through `message_tx`.
pub struct Agent {
    /// Active model provider; replaced on `UserMessage::SwitchModel`.
    llm: Arc<dyn StreamingModelProvider>,
    /// Conversation history (and tool definitions) sent to the LLM.
    context: Context,
    /// Channel to the MCP task that executes tool calls.
    mcp_command_tx: Option<mpsc::Sender<McpCommand>>,
    /// Outgoing channel for every `AgentMessage` the agent emits.
    message_tx: mpsc::Sender<AgentMessage>,
    /// All live event streams, keyed by `USER_STREAM_KEY`, `LLM_STREAM_KEY`, or a tool-call id.
    streams: StreamMap<String, EventStream>,
    /// Timeout forwarded with each tool execution request.
    tool_timeout: Duration,
    /// Token usage and context-limit bookkeeping.
    token_tracker: TokenTracker,
    /// Compaction settings; `None` disables compaction.
    compaction_config: Option<CompactionConfig>,
    /// Counter/limit for automatic continuation after length-limited stops.
    auto_continue: AutoContinue,
    /// Retry policy for retryable LLM errors.
    retry_config: RetryConfig,
    /// Tool calls announced by the LLM that have not produced a result yet, keyed by call id.
    active_requests: HashMap<String, ToolCallRequest>,
    /// User messages received while busy; merged into the context when the iteration completes.
    queued_user_messages: VecDeque<Vec<llm::ContentBlock>>,
    /// Explicit context-window override from the config (takes precedence over the provider).
    context_window: Option<u32>,
}
60
61impl Agent {
62    pub(crate) fn new(
63        config: AgentConfig,
64        user_message_rx: mpsc::Receiver<UserMessage>,
65        message_tx: mpsc::Sender<AgentMessage>,
66    ) -> Self {
67        let mut streams: StreamMap<String, EventStream> = StreamMap::new();
68        streams.insert(
69            USER_STREAM_KEY.to_string(),
70            Box::pin(ReceiverStream::new(user_message_rx).map(StreamEvent::UserMessage)),
71        );
72
73        let context_limit = config.context_window.or_else(|| config.llm.context_window());
74
75        Self {
76            llm: config.llm,
77            context: config.context,
78            mcp_command_tx: config.mcp_command_tx,
79            message_tx,
80            streams,
81            tool_timeout: config.tool_timeout,
82            token_tracker: TokenTracker::new(context_limit),
83            compaction_config: config.compaction_config,
84            auto_continue: config.auto_continue,
85            retry_config: config.retry_config,
86            active_requests: HashMap::new(),
87            queued_user_messages: VecDeque::new(),
88            context_window: config.context_window,
89        }
90    }
91
92    pub fn current_model_display_name(&self) -> String {
93        self.llm.display_name()
94    }
95
96    /// Get a reference to the token tracker
97    pub fn token_tracker(&self) -> &TokenTracker {
98        &self.token_tracker
99    }
100
101    pub async fn run(mut self) {
102        let mut state = IterationState::new();
103
104        while let Some((_, event)) = self.streams.next().await {
105            use UserMessage::{Cancel, ClearContext, SetReasoningEffort, SwitchModel, Text, UpdateTools};
106            match event {
107                StreamEvent::UserMessage(Cancel) => {
108                    self.on_user_cancel(&mut state).await;
109                }
110
111                StreamEvent::UserMessage(ClearContext) => {
112                    self.on_user_clear_context(&mut state).await;
113                }
114
115                StreamEvent::UserMessage(Text { content }) => {
116                    if self.is_busy() {
117                        self.queued_user_messages.push_back(content);
118                    } else {
119                        state = IterationState::new();
120                        self.on_user_text(content);
121                    }
122                }
123
124                StreamEvent::UserMessage(SwitchModel(new_provider)) => {
125                    self.on_switch_model(new_provider).await;
126                }
127
128                StreamEvent::UserMessage(UpdateTools(tools)) => {
129                    self.context.set_tools(tools);
130                }
131
132                StreamEvent::UserMessage(SetReasoningEffort(effort)) => {
133                    self.context.set_reasoning_effort(effort);
134                }
135
136                StreamEvent::Llm(llm_event) => {
137                    self.on_llm_event(llm_event, &mut state).await;
138                }
139
140                StreamEvent::ToolExecution(tool_event) => {
141                    self.on_tool_execution_event(tool_event, &mut state).await;
142                }
143            }
144
145            if state.is_complete() {
146                let Some(id) = state.current_message_id.take() else {
147                    continue;
148                };
149                let iteration = std::mem::replace(&mut state, IterationState::new());
150                self.on_iteration_complete(id, iteration).await;
151            }
152        }
153
154        tracing::debug!("Agent task shutting down - input channel closed");
155    }
156
157    async fn on_iteration_complete(&mut self, id: String, iteration: IterationState) {
158        let IterationState {
159            message_content,
160            reasoning_summary_text,
161            encrypted_reasoning,
162            completed_tool_calls,
163            stop_reason,
164            ..
165        } = iteration;
166        let has_tool_calls = !completed_tool_calls.is_empty();
167        let has_content = !message_content.is_empty() || has_tool_calls;
168        let should_auto_continue = self.auto_continue.should_continue(stop_reason.as_ref());
169
170        if has_content {
171            let reasoning = AssistantReasoning::from_parts(reasoning_summary_text.clone(), encrypted_reasoning);
172            self.update_context(&message_content, reasoning, completed_tool_calls);
173
174            let _ = self
175                .message_tx
176                .send(AgentMessage::Text {
177                    message_id: id.clone(),
178                    chunk: message_content.clone(),
179                    is_complete: true,
180                    model_name: self.llm.display_name(),
181                })
182                .await;
183
184            if !reasoning_summary_text.is_empty() {
185                let _ = self
186                    .message_tx
187                    .send(AgentMessage::Thought {
188                        message_id: id.clone(),
189                        chunk: reasoning_summary_text,
190                        is_complete: true,
191                        model_name: self.llm.display_name(),
192                    })
193                    .await;
194            }
195        }
196
197        let has_queued_text = !self.queued_user_messages.is_empty();
198        if has_queued_text {
199            let content: Vec<_> = self.queued_user_messages.drain(..).flatten().collect();
200            self.context.add_message(ChatMessage::User { content, timestamp: IsoString::now() });
201        }
202
203        if has_queued_text || has_tool_calls {
204            self.auto_continue.reset();
205            self.start_next_turn().await;
206        } else if should_auto_continue {
207            self.auto_continue.advance();
208            tracing::info!(
209                "LLM stopped with {:?}, auto-continuing (attempt {}/{})",
210                stop_reason,
211                self.auto_continue.count(),
212                self.auto_continue.max()
213            );
214
215            let _ = self
216                .message_tx
217                .send(AgentMessage::AutoContinue {
218                    attempt: self.auto_continue.count(),
219                    max_attempts: self.auto_continue.max(),
220                })
221                .await;
222
223            self.inject_continuation_prompt(&message_content, stop_reason.as_ref());
224            self.start_next_turn().await;
225        } else {
226            tracing::debug!("LLM completed turn with stop reason: {:?}", stop_reason);
227            self.auto_continue.reset();
228            if let Err(e) = self.message_tx.send(AgentMessage::Done).await {
229                tracing::warn!("Failed to send Done message: {:?}", e);
230            }
231        }
232    }
233
234    async fn start_next_turn(&mut self) {
235        self.maybe_preflight_compact().await;
236        self.start_llm_stream(None);
237    }
238
239    async fn on_user_cancel(&mut self, state: &mut IterationState) {
240        self.abort_in_flight_work();
241        *state = IterationState::new();
242        let _ = self.message_tx.send(AgentMessage::Cancelled { message: "Processing cancelled".to_string() }).await;
243        let _ = self.message_tx.send(AgentMessage::Done).await;
244    }
245
246    async fn on_user_clear_context(&mut self, state: &mut IterationState) {
247        self.abort_in_flight_work();
248        self.context.clear_conversation();
249        self.token_tracker.reset_current_usage();
250        self.auto_continue.reset();
251        *state = IterationState::new();
252
253        let _ = self.message_tx.send(AgentMessage::ContextCleared).await;
254    }
255
256    fn on_user_text(&mut self, content: Vec<llm::ContentBlock>) {
257        self.context.add_message(ChatMessage::User { content, timestamp: IsoString::now() });
258        self.auto_continue.reset();
259        self.start_llm_stream(None);
260    }
261
262    async fn on_switch_model(&mut self, new_provider: Box<dyn StreamingModelProvider>) {
263        let previous = self.llm.display_name();
264        let new_context_limit = self.context_window.or_else(|| new_provider.context_window());
265        self.llm = Arc::from(new_provider);
266        self.token_tracker.reset_current_usage();
267        self.token_tracker.set_context_limit(new_context_limit);
268        let new = self.llm.display_name();
269        let _ = self.message_tx.send(AgentMessage::ModelSwitched { previous, new }).await;
270
271        let _ = self.message_tx.send(self.context_usage_message()).await;
272    }
273
274    fn start_llm_stream(&mut self, delay: Option<Duration>) {
275        self.streams.remove(LLM_STREAM_KEY);
276        let stream: EventStream = match delay {
277            None => Box::pin(self.llm.stream_response(&self.context).map(StreamEvent::Llm)),
278            Some(delay) => {
279                let llm = Arc::clone(&self.llm);
280                let context = self.context.clone();
281                Box::pin(async_stream::stream! {
282                    sleep(delay).await;
283                    let mut inner = llm.stream_response(&context);
284                    while let Some(item) = inner.next().await {
285                        yield StreamEvent::Llm(item);
286                    }
287                })
288            }
289        };
290        self.streams.insert(LLM_STREAM_KEY.to_string(), stream);
291    }
292
293    async fn on_llm_error(&mut self, error: LlmError, state: &mut IterationState) {
294        if !error.is_retryable() || state.retry_attempt >= self.retry_config.max_attempts {
295            let _ = self.message_tx.send(AgentMessage::Error { message: error.to_string() }).await;
296            return;
297        }
298
299        state.retry_attempt += 1;
300        let delay = self.retry_config.compute_delay(state.retry_attempt);
301        let delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX);
302
303        tracing::warn!(
304            attempt = state.retry_attempt,
305            max_attempts = self.retry_config.max_attempts,
306            delay_ms,
307            error = %error,
308            "Retrying LLM request after transient failure"
309        );
310
311        let _ = self
312            .message_tx
313            .send(AgentMessage::Retrying {
314                attempt: state.retry_attempt,
315                max_attempts: self.retry_config.max_attempts,
316                delay_ms,
317                error: error.to_string(),
318            })
319            .await;
320
321        // The previous stream may have emitted partial tool-call deltas
322        // before interrupting so we drop them to ensure we rebuild tool state
323        self.active_requests.clear();
324        self.start_llm_stream(Some(delay));
325    }
326
327    fn is_busy(&self) -> bool {
328        self.streams.contains_key(LLM_STREAM_KEY) || !self.active_requests.is_empty()
329    }
330
331    fn abort_in_flight_work(&mut self) {
332        self.streams.remove(LLM_STREAM_KEY);
333        for stream_key in self.active_requests.keys().cloned().collect::<Vec<_>>() {
334            self.streams.remove(&stream_key);
335        }
336        self.active_requests.clear();
337        self.queued_user_messages.clear();
338    }
339
340    /// Inject a continuation prompt when the LLM stops due to a resumable reason.
341    fn inject_continuation_prompt(&mut self, previous_response: &str, stop_reason: Option<&StopReason>) {
342        if !previous_response.is_empty() {
343            self.context.add_message(ChatMessage::Assistant {
344                content: previous_response.to_string(),
345                reasoning: AssistantReasoning::default(),
346                timestamp: IsoString::now(),
347                tool_calls: Vec::new(),
348            });
349        }
350
351        let reason = stop_reason.map_or_else(|| "Unknown".to_string(), |reason| format!("{reason:?}"));
352
353        self.context.add_message(ChatMessage::User {
354            content: vec![llm::ContentBlock::text(format!(
355                "<system-notification>The LLM API stopped with reason '{reason}'. Continue from where you left off and finish your task.</system-notification>"
356            ))],
357            timestamp: IsoString::now(),
358        });
359    }
360
361    async fn on_llm_event(&mut self, result: Result<LlmResponse, LlmError>, state: &mut IterationState) {
362        use LlmResponse::{
363            Done, EncryptedReasoning, Error, Reasoning, Start, Text, ToolRequestArg, ToolRequestComplete,
364            ToolRequestStart, Usage,
365        };
366
367        let response = match result {
368            Ok(response) => response,
369            Err(e) => {
370                self.on_llm_error(e, state).await;
371                return;
372            }
373        };
374
375        match response {
376            Start { message_id } => {
377                state.on_llm_start(message_id);
378            }
379
380            Text { chunk } => {
381                self.handle_llm_text(chunk, state).await;
382            }
383
384            Reasoning { chunk } => {
385                state.reasoning_summary_text.push_str(&chunk);
386                if let Some(id) = &state.current_message_id {
387                    let _ = self
388                        .message_tx
389                        .send(AgentMessage::Thought {
390                            message_id: id.clone(),
391                            chunk,
392                            is_complete: false,
393                            model_name: self.llm.display_name(),
394                        })
395                        .await;
396                }
397            }
398
399            EncryptedReasoning { id, content } => {
400                if let Some(model) = self.llm.model() {
401                    state.encrypted_reasoning = Some(EncryptedReasoningContent { id, model, content });
402                }
403            }
404
405            ToolRequestStart { id, name } => {
406                self.handle_tool_request_start(id, name).await;
407            }
408
409            ToolRequestArg { id, chunk } => {
410                self.handle_tool_request_arg(id, chunk).await;
411            }
412
413            ToolRequestComplete { tool_call } => {
414                self.handle_tool_completion(tool_call, state).await;
415            }
416
417            Done { stop_reason } => {
418                state.llm_done = true;
419                state.stop_reason = stop_reason;
420            }
421
422            Error { message } => {
423                let _ = self.message_tx.send(AgentMessage::Error { message }).await;
424            }
425
426            Usage { tokens: sample } => {
427                self.handle_llm_usage(sample).await;
428            }
429        }
430    }
431
432    async fn handle_llm_text(&mut self, chunk: String, state: &mut IterationState) {
433        state.message_content.push_str(&chunk);
434
435        if let Some(id) = &state.current_message_id {
436            let _ = self
437                .message_tx
438                .send(AgentMessage::Text {
439                    message_id: id.clone(),
440                    chunk,
441                    is_complete: false,
442                    model_name: self.llm.display_name(),
443                })
444                .await;
445        }
446    }
447
448    async fn handle_tool_request_start(&mut self, id: String, name: String) {
449        let request = ToolCallRequest { id: id.clone(), name, arguments: String::new() };
450        self.active_requests.insert(id, request.clone());
451
452        let _ = self.message_tx.send(AgentMessage::ToolCall { request, model_name: self.llm.display_name() }).await;
453    }
454
455    async fn handle_tool_request_arg(&mut self, id: String, chunk: String) {
456        let Some(request) = self.active_requests.get_mut(&id) else {
457            return;
458        };
459        request.arguments.push_str(&chunk);
460
461        let _ = self
462            .message_tx
463            .send(AgentMessage::ToolCallUpdate { tool_call_id: id, chunk, model_name: self.llm.display_name() })
464            .await;
465    }
466
467    async fn handle_tool_completion(&mut self, tool_call: ToolCallRequest, state: &mut IterationState) {
468        state.pending_tool_ids.insert(tool_call.id.clone());
469        debug_assert!(
470            self.active_requests.contains_key(&tool_call.id),
471            "tool call {} should already be in active_requests from handle_tool_request_start",
472            tool_call.id
473        );
474
475        let (tx, rx) = mpsc::channel(100);
476        let stream = ReceiverStream::new(rx).map(StreamEvent::ToolExecution);
477        let stream_key = tool_call.id.clone();
478        self.streams.insert(stream_key, Box::pin(stream));
479
480        if let Some(ref mcp_command_tx) = self.mcp_command_tx {
481            let mcp_future =
482                mcp_command_tx.send(McpCommand::ExecuteTool { request: tool_call, timeout: self.tool_timeout, tx });
483            if let Err(e) = mcp_future.await {
484                tracing::warn!("Failed to send tool request to MCP task: {:?}", e);
485            }
486        }
487    }
488
489    async fn handle_llm_usage(&mut self, sample: TokenUsage) {
490        self.token_tracker.record_usage(sample);
491        let ratio_pct = self.token_tracker.usage_ratio().map(|r| r * 100.0);
492        let remaining = self.token_tracker.tokens_remaining();
493        tracing::debug!(?sample, ?ratio_pct, ?remaining, "Token usage");
494
495        let _ = self.message_tx.send(self.context_usage_message()).await;
496
497        self.maybe_compact_context().await;
498    }
499
500    fn context_usage_message(&self) -> AgentMessage {
501        let last = self.token_tracker.last_usage();
502        AgentMessage::ContextUsageUpdate {
503            usage_ratio: self.token_tracker.usage_ratio(),
504            context_limit: self.token_tracker.context_limit(),
505            input_tokens: last.input_tokens,
506            output_tokens: last.output_tokens,
507            cache_read_tokens: last.cache_read_tokens,
508            cache_creation_tokens: last.cache_creation_tokens,
509            reasoning_tokens: last.reasoning_tokens,
510            total_input_tokens: self.token_tracker.total_input_tokens(),
511            total_output_tokens: self.token_tracker.total_output_tokens(),
512            total_cache_read_tokens: self.token_tracker.total_cache_read_tokens(),
513            total_cache_creation_tokens: self.token_tracker.total_cache_creation_tokens(),
514            total_reasoning_tokens: self.token_tracker.total_reasoning_tokens(),
515        }
516    }
517
518    /// Pre-flight check: estimate context size and compact proactively if it would
519    /// overflow before the LLM even sees it. This catches the case where large tool
520    /// results push context past the limit before usage-based compaction can fire.
521    async fn maybe_preflight_compact(&mut self) {
522        let Some(context_limit) = self.token_tracker.context_limit() else {
523            return;
524        };
525        let Some(config) = self.compaction_config.as_ref() else {
526            return;
527        };
528        let estimated = self.context.estimated_token_count();
529        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
530        let threshold = (f64::from(context_limit) * config.threshold).ceil() as u32;
531        if estimated >= threshold {
532            tracing::info!(
533                "Pre-flight compaction triggered: estimated {estimated} tokens >= {:.1}% of {context_limit} limit",
534                config.threshold * 100.0
535            );
536            if let CompactionOutcome::Failed(e) = self.compact_context().await {
537                tracing::warn!("Pre-flight compaction failed: {e}");
538            }
539        }
540    }
541
542    /// Check if compaction is needed and perform it if so.
543    async fn maybe_compact_context(&mut self) {
544        if !self.compaction_config.as_ref().is_some_and(|config| self.token_tracker.should_compact(config.threshold)) {
545            return;
546        }
547
548        if let CompactionOutcome::Failed(error_message) = self.compact_context().await {
549            tracing::warn!("Context compaction failed: {}", error_message);
550        }
551    }
552
553    async fn compact_context(&mut self) -> CompactionOutcome {
554        let Some(ref _config) = self.compaction_config else {
555            tracing::warn!("Context compaction requested but compaction is disabled");
556            return CompactionOutcome::SkippedDisabled;
557        };
558
559        match self.token_tracker.usage_ratio() {
560            Some(usage_ratio) => {
561                tracing::info!(
562                    "Starting context compaction - {} messages, {:.1}% of context limit",
563                    self.context.message_count(),
564                    usage_ratio * 100.0
565                );
566            }
567            None => {
568                tracing::info!(
569                    "Starting context compaction - {} messages (context limit unknown)",
570                    self.context.message_count(),
571                );
572            }
573        }
574
575        let _ = self
576            .message_tx
577            .send(AgentMessage::ContextCompactionStarted { message_count: self.context.message_count() })
578            .await;
579
580        let compactor = Compactor::new(self.llm.clone());
581
582        match compactor.compact(&self.context).await {
583            Ok(result) => {
584                tracing::info!("Context compacted: {} messages removed", result.messages_removed);
585
586                self.context = result.context;
587                self.token_tracker.reset_current_usage();
588
589                let _ = self
590                    .message_tx
591                    .send(AgentMessage::ContextCompactionResult {
592                        summary: result.summary,
593                        messages_removed: result.messages_removed,
594                    })
595                    .await;
596                CompactionOutcome::Compacted
597            }
598            Err(e) => CompactionOutcome::Failed(e.to_string()),
599        }
600    }
601
602    async fn on_tool_execution_event(&mut self, event: ToolExecutionEvent, state: &mut IterationState) {
603        match event {
604            ToolExecutionEvent::Started { tool_id, tool_name } => {
605                tracing::debug!("Tool execution started: {} ({})", tool_name, tool_id);
606            }
607
608            ToolExecutionEvent::Progress { tool_id, progress } => {
609                tracing::debug!(
610                    "Tool progress for {}: {}/{}",
611                    tool_id,
612                    progress.progress,
613                    progress.total.unwrap_or(0.0)
614                );
615
616                if let Some(request) = self.active_requests.get(&tool_id) {
617                    let _ = self
618                        .message_tx
619                        .send(AgentMessage::ToolProgress {
620                            request: request.clone(),
621                            progress: progress.progress,
622                            total: progress.total,
623                            message: progress.message.clone(),
624                        })
625                        .await;
626                }
627            }
628
629            ToolExecutionEvent::Complete { tool_id: _, result, result_meta } => match result {
630                Ok(tool_result) => {
631                    tracing::debug!("Tool result received: {} -> {}", tool_result.name, tool_result.result.len());
632
633                    if state.pending_tool_ids.remove(&tool_result.id) {
634                        self.active_requests.remove(&tool_result.id);
635                        state.completed_tool_calls.push(Ok(tool_result.clone()));
636
637                        let msg = AgentMessage::ToolResult {
638                            result: tool_result,
639                            result_meta,
640                            model_name: self.llm.display_name(),
641                        };
642
643                        if let Err(e) = self.message_tx.send(msg).await {
644                            tracing::warn!("Failed to send ToolCall completion message: {:?}", e);
645                        }
646                    } else {
647                        tracing::debug!("Ignoring stale tool result for id: {}", tool_result.id);
648                    }
649                }
650
651                Err(tool_error) => {
652                    if state.pending_tool_ids.remove(&tool_error.id) {
653                        self.active_requests.remove(&tool_error.id);
654                        state.completed_tool_calls.push(Err(tool_error.clone()));
655
656                        let _ = self
657                            .message_tx
658                            .send(AgentMessage::ToolError { error: tool_error, model_name: self.llm.display_name() })
659                            .await;
660                    }
661                }
662            },
663        }
664    }
665
666    fn update_context(
667        &mut self,
668        message_content: &str,
669        reasoning: AssistantReasoning,
670        completed_tools: Vec<Result<ToolCallResult, ToolCallError>>,
671    ) {
672        self.context.push_assistant_turn(message_content, reasoning, completed_tools);
673    }
674}
675
/// What a single `compact_context` attempt ended up doing.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CompactionOutcome {
    /// The compactor returned an error; carries its text.
    Failed(String),
    /// No compaction config is set, so nothing was attempted.
    SkippedDisabled,
    /// The context was replaced with its compacted form.
    Compacted,
}
682
683pub(crate) struct AutoContinue {
684    max: u32,
685    count: u32,
686}
687
688impl AutoContinue {
689    pub(crate) fn new(max: u32) -> Self {
690        Self { max, count: 0 }
691    }
692
693    fn reset(&mut self) {
694        self.count = 0;
695    }
696
697    fn should_continue(&self, stop_reason: Option<&StopReason>) -> bool {
698        matches!(stop_reason, Some(StopReason::Length)) && self.count < self.max
699    }
700
701    fn advance(&mut self) {
702        self.count += 1;
703    }
704
705    fn count(&self) -> u32 {
706        self.count
707    }
708
709    fn max(&self) -> u32 {
710        self.max
711    }
712}
713
714#[derive(Debug)]
715struct IterationState {
716    current_message_id: Option<String>,
717    message_content: String,
718    reasoning_summary_text: String,
719    encrypted_reasoning: Option<EncryptedReasoningContent>,
720    pending_tool_ids: HashSet<String>,
721    completed_tool_calls: Vec<Result<ToolCallResult, ToolCallError>>,
722    llm_done: bool,
723    stop_reason: Option<StopReason>,
724    retry_attempt: u32,
725}
726
727impl IterationState {
728    fn new() -> Self {
729        Self {
730            current_message_id: None,
731            message_content: String::new(),
732            reasoning_summary_text: String::new(),
733            encrypted_reasoning: None,
734            pending_tool_ids: HashSet::new(),
735            completed_tool_calls: Vec::new(),
736            llm_done: false,
737            stop_reason: None,
738            retry_attempt: 0,
739        }
740    }
741
742    fn on_llm_start(&mut self, message_id: String) {
743        self.current_message_id = Some(message_id);
744        self.message_content.clear();
745        self.reasoning_summary_text.clear();
746        self.encrypted_reasoning = None;
747        self.stop_reason = None;
748    }
749
750    fn is_complete(&self) -> bool {
751        self.llm_done && self.pending_tool_ids.is_empty()
752    }
753}
754
#[cfg(test)]
mod tests {
    use super::*;
    use llm::testing::FakeLlmProvider;
    use tokio::sync::mpsc;

    /// A context whose estimated size reaches the configured 85% threshold of a 100-token
    /// window must be replaced by the compactor's summary during the pre-flight check.
    #[tokio::test]
    async fn test_preflight_compaction_uses_configured_threshold() {
        let summary_script = vec![
            LlmResponse::start("summary"),
            LlmResponse::text("summary"),
            LlmResponse::done(),
        ];
        let llm = Arc::new(FakeLlmProvider::with_single_response(summary_script).with_context_window(Some(100)));

        let oversized_message = ChatMessage::User {
            content: vec![llm::ContentBlock::text("x".repeat(344))],
            timestamp: IsoString::now(),
        };
        let context = Context::new(vec![oversized_message], vec![]);

        // The test drives the agent directly, so the user side of the channel is closed at once.
        let (user_tx, user_rx) = mpsc::channel(1);
        drop(user_tx);
        let (message_tx, _message_rx) = mpsc::channel(8);

        let config = AgentConfig {
            llm,
            context,
            mcp_command_tx: None,
            tool_timeout: Duration::from_secs(1),
            compaction_config: Some(CompactionConfig::with_threshold(0.85)),
            auto_continue: AutoContinue::new(0),
            retry_config: RetryConfig::disabled(),
            context_window: None,
        };
        let mut agent = Agent::new(config, user_rx, message_tx);

        agent.maybe_preflight_compact().await;

        let messages = agent.context.messages();
        let compacted = matches!(
            messages.as_slice(),
            [ChatMessage::Summary { content, .. }] if content == "summary"
        );
        assert!(compacted, "expected context to be compacted, got {:?}", messages);
    }
}