Skip to main content

aether_core/core/
agent.rs

1use crate::context::{CompactionConfig, Compactor, TokenTracker};
2pub use crate::core::retry_config::RetryConfig;
3use crate::events::{AgentMessage, UserMessage};
4use crate::mcp::run_mcp_task::{McpCommand, ToolExecutionEvent};
5use futures::Stream;
6use llm::types::IsoString;
7use llm::{
8    AssistantReasoning, ChatMessage, Context, EncryptedReasoningContent, LlmError, LlmResponse, StopReason,
9    StreamingModelProvider, TokenUsage, ToolCallError, ToolCallRequest, ToolCallResult,
10};
11use std::collections::{HashMap, HashSet, VecDeque};
12use std::pin::Pin;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::mpsc;
16use tokio::time::sleep;
17use tokio_stream::StreamExt;
18use tokio_stream::StreamMap;
19use tokio_stream::wrappers::ReceiverStream;
20
21/// Internal event type for merging LLM and tool result streams
22#[derive(Debug)]
23enum StreamEvent {
24    Llm(Result<LlmResponse, LlmError>),
25    ToolExecution(ToolExecutionEvent),
26    UserMessage(UserMessage),
27}
28
29type EventStream = Pin<Box<dyn Stream<Item = StreamEvent> + Send>>;
30
31const USER_STREAM_KEY: &str = "user";
32const LLM_STREAM_KEY: &str = "llm";
33
/// Construction parameters for [`Agent`]; every field is moved into the agent by `Agent::new`.
pub(crate) struct AgentConfig {
    /// Model provider used for streaming responses; its context window sizes the token tracker.
    pub llm: Arc<dyn StreamingModelProvider>,
    /// Conversation context the agent starts from.
    pub context: Context,
    /// Channel to the MCP task that executes tool calls; when `None`, tool calls are not sent anywhere.
    pub mcp_command_tx: Option<mpsc::Sender<McpCommand>>,
    /// Timeout forwarded with every `McpCommand::ExecuteTool`.
    pub tool_timeout: Duration,
    /// Context-compaction settings; `None` disables compaction.
    pub compaction_config: Option<CompactionConfig>,
    /// Policy for automatically continuing after a length-limited stop.
    pub auto_continue: AutoContinue,
    /// Backoff/attempt policy for retryable LLM errors.
    pub retry_config: RetryConfig,
}
43
/// Event-driven agent that merges user input, LLM responses and tool results
/// into a single loop (see `Agent::run`).
pub struct Agent {
    /// Active model provider; replaced on `UserMessage::SwitchModel`.
    llm: Arc<dyn StreamingModelProvider>,
    /// Conversation sent to the model on every request.
    context: Context,
    /// Channel to the MCP task; `None` means tool calls cannot be executed.
    mcp_command_tx: Option<mpsc::Sender<McpCommand>>,
    /// Outgoing channel for every message the agent emits.
    message_tx: mpsc::Sender<AgentMessage>,
    /// Merged inputs keyed by `USER_STREAM_KEY`, `LLM_STREAM_KEY`, or a tool-call id.
    streams: StreamMap<String, EventStream>,
    /// Timeout forwarded to the MCP task for each tool call.
    tool_timeout: Duration,
    /// Token usage accounting for the current model's context window.
    token_tracker: TokenTracker,
    /// Compaction settings; `None` disables compaction.
    compaction_config: Option<CompactionConfig>,
    /// Counter bounding consecutive automatic continuations.
    auto_continue: AutoContinue,
    /// Retry policy for retryable LLM errors.
    retry_config: RetryConfig,
    /// Tool calls announced by the LLM whose result has not arrived yet, keyed by call id.
    active_requests: HashMap<String, ToolCallRequest>,
    /// User messages received while busy; added to the context when the current iteration completes.
    queued_user_messages: VecDeque<Vec<llm::ContentBlock>>,
}
58
59impl Agent {
60    pub(crate) fn new(
61        config: AgentConfig,
62        user_message_rx: mpsc::Receiver<UserMessage>,
63        message_tx: mpsc::Sender<AgentMessage>,
64    ) -> Self {
65        let mut streams: StreamMap<String, EventStream> = StreamMap::new();
66        streams.insert(
67            USER_STREAM_KEY.to_string(),
68            Box::pin(ReceiverStream::new(user_message_rx).map(StreamEvent::UserMessage)),
69        );
70
71        let context_limit = config.llm.context_window();
72
73        Self {
74            llm: config.llm,
75            context: config.context,
76            mcp_command_tx: config.mcp_command_tx,
77            message_tx,
78            streams,
79            tool_timeout: config.tool_timeout,
80            token_tracker: TokenTracker::new(context_limit),
81            compaction_config: config.compaction_config,
82            auto_continue: config.auto_continue,
83            retry_config: config.retry_config,
84            active_requests: HashMap::new(),
85            queued_user_messages: VecDeque::new(),
86        }
87    }
88
89    pub fn current_model_display_name(&self) -> String {
90        self.llm.display_name()
91    }
92
93    /// Get a reference to the token tracker
94    pub fn token_tracker(&self) -> &TokenTracker {
95        &self.token_tracker
96    }
97
98    pub async fn run(mut self) {
99        let mut state = IterationState::new();
100
101        while let Some((_, event)) = self.streams.next().await {
102            use UserMessage::{Cancel, ClearContext, SetReasoningEffort, SwitchModel, Text, UpdateTools};
103            match event {
104                StreamEvent::UserMessage(Cancel) => {
105                    self.on_user_cancel(&mut state).await;
106                }
107
108                StreamEvent::UserMessage(ClearContext) => {
109                    self.on_user_clear_context(&mut state).await;
110                }
111
112                StreamEvent::UserMessage(Text { content }) => {
113                    if self.is_busy() {
114                        self.queued_user_messages.push_back(content);
115                    } else {
116                        state = IterationState::new();
117                        self.on_user_text(content);
118                    }
119                }
120
121                StreamEvent::UserMessage(SwitchModel(new_provider)) => {
122                    self.on_switch_model(new_provider).await;
123                }
124
125                StreamEvent::UserMessage(UpdateTools(tools)) => {
126                    self.context.set_tools(tools);
127                }
128
129                StreamEvent::UserMessage(SetReasoningEffort(effort)) => {
130                    self.context.set_reasoning_effort(effort);
131                }
132
133                StreamEvent::Llm(llm_event) => {
134                    self.on_llm_event(llm_event, &mut state).await;
135                }
136
137                StreamEvent::ToolExecution(tool_event) => {
138                    self.on_tool_execution_event(tool_event, &mut state).await;
139                }
140            }
141
142            if state.is_complete() {
143                let Some(id) = state.current_message_id.take() else {
144                    continue;
145                };
146                let iteration = std::mem::replace(&mut state, IterationState::new());
147                self.on_iteration_complete(id, iteration).await;
148            }
149        }
150
151        tracing::debug!("Agent task shutting down - input channel closed");
152    }
153
154    async fn on_iteration_complete(&mut self, id: String, iteration: IterationState) {
155        let IterationState {
156            message_content,
157            reasoning_summary_text,
158            encrypted_reasoning,
159            completed_tool_calls,
160            stop_reason,
161            ..
162        } = iteration;
163        let has_tool_calls = !completed_tool_calls.is_empty();
164        let has_content = !message_content.is_empty() || has_tool_calls;
165        let should_auto_continue = self.auto_continue.should_continue(stop_reason.as_ref());
166
167        if has_content {
168            let reasoning = AssistantReasoning::from_parts(reasoning_summary_text.clone(), encrypted_reasoning);
169            self.update_context(&message_content, reasoning, completed_tool_calls);
170
171            let _ = self
172                .message_tx
173                .send(AgentMessage::Text {
174                    message_id: id.clone(),
175                    chunk: message_content.clone(),
176                    is_complete: true,
177                    model_name: self.llm.display_name(),
178                })
179                .await;
180
181            if !reasoning_summary_text.is_empty() {
182                let _ = self
183                    .message_tx
184                    .send(AgentMessage::Thought {
185                        message_id: id.clone(),
186                        chunk: reasoning_summary_text,
187                        is_complete: true,
188                        model_name: self.llm.display_name(),
189                    })
190                    .await;
191            }
192        }
193
194        let has_queued_text = !self.queued_user_messages.is_empty();
195        if has_queued_text {
196            let content: Vec<_> = self.queued_user_messages.drain(..).flatten().collect();
197            self.context.add_message(ChatMessage::User { content, timestamp: IsoString::now() });
198        }
199
200        if has_queued_text || has_tool_calls {
201            self.auto_continue.reset();
202            self.start_next_turn().await;
203        } else if should_auto_continue {
204            self.auto_continue.advance();
205            tracing::info!(
206                "LLM stopped with {:?}, auto-continuing (attempt {}/{})",
207                stop_reason,
208                self.auto_continue.count(),
209                self.auto_continue.max()
210            );
211
212            let _ = self
213                .message_tx
214                .send(AgentMessage::AutoContinue {
215                    attempt: self.auto_continue.count(),
216                    max_attempts: self.auto_continue.max(),
217                })
218                .await;
219
220            self.inject_continuation_prompt(&message_content, stop_reason.as_ref());
221            self.start_next_turn().await;
222        } else {
223            tracing::debug!("LLM completed turn with stop reason: {:?}", stop_reason);
224            self.auto_continue.reset();
225            if let Err(e) = self.message_tx.send(AgentMessage::Done).await {
226                tracing::warn!("Failed to send Done message: {:?}", e);
227            }
228        }
229    }
230
231    async fn start_next_turn(&mut self) {
232        self.maybe_preflight_compact().await;
233        self.start_llm_stream(None);
234    }
235
236    async fn on_user_cancel(&mut self, state: &mut IterationState) {
237        self.abort_in_flight_work();
238        *state = IterationState::new();
239        let _ = self.message_tx.send(AgentMessage::Cancelled { message: "Processing cancelled".to_string() }).await;
240        let _ = self.message_tx.send(AgentMessage::Done).await;
241    }
242
243    async fn on_user_clear_context(&mut self, state: &mut IterationState) {
244        self.abort_in_flight_work();
245        self.context.clear_conversation();
246        self.token_tracker.reset_current_usage();
247        self.auto_continue.reset();
248        *state = IterationState::new();
249
250        let _ = self.message_tx.send(AgentMessage::ContextCleared).await;
251    }
252
253    fn on_user_text(&mut self, content: Vec<llm::ContentBlock>) {
254        self.context.add_message(ChatMessage::User { content, timestamp: IsoString::now() });
255        self.auto_continue.reset();
256        self.start_llm_stream(None);
257    }
258
259    async fn on_switch_model(&mut self, new_provider: Box<dyn StreamingModelProvider>) {
260        let previous = self.llm.display_name();
261        let new_context_limit = new_provider.context_window();
262        self.llm = Arc::from(new_provider);
263        self.token_tracker.reset_current_usage();
264        self.token_tracker.set_context_limit(new_context_limit);
265        let new = self.llm.display_name();
266        let _ = self.message_tx.send(AgentMessage::ModelSwitched { previous, new }).await;
267
268        let _ = self.message_tx.send(self.context_usage_message()).await;
269    }
270
271    fn start_llm_stream(&mut self, delay: Option<Duration>) {
272        self.streams.remove(LLM_STREAM_KEY);
273        let stream: EventStream = match delay {
274            None => Box::pin(self.llm.stream_response(&self.context).map(StreamEvent::Llm)),
275            Some(delay) => {
276                let llm = Arc::clone(&self.llm);
277                let context = self.context.clone();
278                Box::pin(async_stream::stream! {
279                    sleep(delay).await;
280                    let mut inner = llm.stream_response(&context);
281                    while let Some(item) = inner.next().await {
282                        yield StreamEvent::Llm(item);
283                    }
284                })
285            }
286        };
287        self.streams.insert(LLM_STREAM_KEY.to_string(), stream);
288    }
289
290    async fn on_llm_error(&mut self, error: LlmError, state: &mut IterationState) {
291        if !error.is_retryable() || state.retry_attempt >= self.retry_config.max_attempts {
292            let _ = self.message_tx.send(AgentMessage::Error { message: error.to_string() }).await;
293            return;
294        }
295
296        state.retry_attempt += 1;
297        let delay = self.retry_config.compute_delay(state.retry_attempt);
298        let delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX);
299
300        tracing::warn!(
301            attempt = state.retry_attempt,
302            max_attempts = self.retry_config.max_attempts,
303            delay_ms,
304            error = %error,
305            "Retrying LLM request after transient failure"
306        );
307
308        let _ = self
309            .message_tx
310            .send(AgentMessage::Retrying {
311                attempt: state.retry_attempt,
312                max_attempts: self.retry_config.max_attempts,
313                delay_ms,
314                error: error.to_string(),
315            })
316            .await;
317
318        // The previous stream may have emitted partial tool-call deltas
319        // before interrupting so we drop them to ensure we rebuild tool state
320        self.active_requests.clear();
321        self.start_llm_stream(Some(delay));
322    }
323
324    fn is_busy(&self) -> bool {
325        self.streams.contains_key(LLM_STREAM_KEY) || !self.active_requests.is_empty()
326    }
327
328    fn abort_in_flight_work(&mut self) {
329        self.streams.remove(LLM_STREAM_KEY);
330        for stream_key in self.active_requests.keys().cloned().collect::<Vec<_>>() {
331            self.streams.remove(&stream_key);
332        }
333        self.active_requests.clear();
334        self.queued_user_messages.clear();
335    }
336
337    /// Inject a continuation prompt when the LLM stops due to a resumable reason.
338    fn inject_continuation_prompt(&mut self, previous_response: &str, stop_reason: Option<&StopReason>) {
339        if !previous_response.is_empty() {
340            self.context.add_message(ChatMessage::Assistant {
341                content: previous_response.to_string(),
342                reasoning: AssistantReasoning::default(),
343                timestamp: IsoString::now(),
344                tool_calls: Vec::new(),
345            });
346        }
347
348        let reason = stop_reason.map_or_else(|| "Unknown".to_string(), |reason| format!("{reason:?}"));
349
350        self.context.add_message(ChatMessage::User {
351            content: vec![llm::ContentBlock::text(format!(
352                "<system-notification>The LLM API stopped with reason '{reason}'. Continue from where you left off and finish your task.</system-notification>"
353            ))],
354            timestamp: IsoString::now(),
355        });
356    }
357
358    async fn on_llm_event(&mut self, result: Result<LlmResponse, LlmError>, state: &mut IterationState) {
359        use LlmResponse::{
360            Done, EncryptedReasoning, Error, Reasoning, Start, Text, ToolRequestArg, ToolRequestComplete,
361            ToolRequestStart, Usage,
362        };
363
364        let response = match result {
365            Ok(response) => response,
366            Err(e) => {
367                self.on_llm_error(e, state).await;
368                return;
369            }
370        };
371
372        match response {
373            Start { message_id } => {
374                state.on_llm_start(message_id);
375            }
376
377            Text { chunk } => {
378                self.handle_llm_text(chunk, state).await;
379            }
380
381            Reasoning { chunk } => {
382                state.reasoning_summary_text.push_str(&chunk);
383                if let Some(id) = &state.current_message_id {
384                    let _ = self
385                        .message_tx
386                        .send(AgentMessage::Thought {
387                            message_id: id.clone(),
388                            chunk,
389                            is_complete: false,
390                            model_name: self.llm.display_name(),
391                        })
392                        .await;
393                }
394            }
395
396            EncryptedReasoning { id, content } => {
397                if let Some(model) = self.llm.model() {
398                    state.encrypted_reasoning = Some(EncryptedReasoningContent { id, model, content });
399                }
400            }
401
402            ToolRequestStart { id, name } => {
403                self.handle_tool_request_start(id, name).await;
404            }
405
406            ToolRequestArg { id, chunk } => {
407                self.handle_tool_request_arg(id, chunk).await;
408            }
409
410            ToolRequestComplete { tool_call } => {
411                self.handle_tool_completion(tool_call, state).await;
412            }
413
414            Done { stop_reason } => {
415                state.llm_done = true;
416                state.stop_reason = stop_reason;
417            }
418
419            Error { message } => {
420                let _ = self.message_tx.send(AgentMessage::Error { message }).await;
421            }
422
423            Usage { tokens: sample } => {
424                self.handle_llm_usage(sample).await;
425            }
426        }
427    }
428
429    async fn handle_llm_text(&mut self, chunk: String, state: &mut IterationState) {
430        state.message_content.push_str(&chunk);
431
432        if let Some(id) = &state.current_message_id {
433            let _ = self
434                .message_tx
435                .send(AgentMessage::Text {
436                    message_id: id.clone(),
437                    chunk,
438                    is_complete: false,
439                    model_name: self.llm.display_name(),
440                })
441                .await;
442        }
443    }
444
445    async fn handle_tool_request_start(&mut self, id: String, name: String) {
446        let request = ToolCallRequest { id: id.clone(), name, arguments: String::new() };
447        self.active_requests.insert(id, request.clone());
448
449        let _ = self.message_tx.send(AgentMessage::ToolCall { request, model_name: self.llm.display_name() }).await;
450    }
451
452    async fn handle_tool_request_arg(&mut self, id: String, chunk: String) {
453        let Some(request) = self.active_requests.get_mut(&id) else {
454            return;
455        };
456        request.arguments.push_str(&chunk);
457
458        let _ = self
459            .message_tx
460            .send(AgentMessage::ToolCallUpdate { tool_call_id: id, chunk, model_name: self.llm.display_name() })
461            .await;
462    }
463
464    async fn handle_tool_completion(&mut self, tool_call: ToolCallRequest, state: &mut IterationState) {
465        state.pending_tool_ids.insert(tool_call.id.clone());
466        debug_assert!(
467            self.active_requests.contains_key(&tool_call.id),
468            "tool call {} should already be in active_requests from handle_tool_request_start",
469            tool_call.id
470        );
471
472        let (tx, rx) = mpsc::channel(100);
473        let stream = ReceiverStream::new(rx).map(StreamEvent::ToolExecution);
474        let stream_key = tool_call.id.clone();
475        self.streams.insert(stream_key, Box::pin(stream));
476
477        if let Some(ref mcp_command_tx) = self.mcp_command_tx {
478            let mcp_future =
479                mcp_command_tx.send(McpCommand::ExecuteTool { request: tool_call, timeout: self.tool_timeout, tx });
480            if let Err(e) = mcp_future.await {
481                tracing::warn!("Failed to send tool request to MCP task: {:?}", e);
482            }
483        }
484    }
485
486    async fn handle_llm_usage(&mut self, sample: TokenUsage) {
487        self.token_tracker.record_usage(sample);
488        let ratio_pct = self.token_tracker.usage_ratio().map(|r| r * 100.0);
489        let remaining = self.token_tracker.tokens_remaining();
490        tracing::debug!(?sample, ?ratio_pct, ?remaining, "Token usage");
491
492        let _ = self.message_tx.send(self.context_usage_message()).await;
493
494        self.maybe_compact_context().await;
495    }
496
497    fn context_usage_message(&self) -> AgentMessage {
498        let last = self.token_tracker.last_usage();
499        AgentMessage::ContextUsageUpdate {
500            usage_ratio: self.token_tracker.usage_ratio(),
501            context_limit: self.token_tracker.context_limit(),
502            input_tokens: last.input_tokens,
503            output_tokens: last.output_tokens,
504            cache_read_tokens: last.cache_read_tokens,
505            cache_creation_tokens: last.cache_creation_tokens,
506            reasoning_tokens: last.reasoning_tokens,
507            total_input_tokens: self.token_tracker.total_input_tokens(),
508            total_output_tokens: self.token_tracker.total_output_tokens(),
509            total_cache_read_tokens: self.token_tracker.total_cache_read_tokens(),
510            total_cache_creation_tokens: self.token_tracker.total_cache_creation_tokens(),
511            total_reasoning_tokens: self.token_tracker.total_reasoning_tokens(),
512        }
513    }
514
515    /// Pre-flight check: estimate context size and compact proactively if it would
516    /// overflow before the LLM even sees it. This catches the case where large tool
517    /// results push context past the limit before usage-based compaction can fire.
518    async fn maybe_preflight_compact(&mut self) {
519        let Some(context_limit) = self.token_tracker.context_limit() else {
520            return;
521        };
522        let Some(config) = self.compaction_config.as_ref() else {
523            return;
524        };
525        let estimated = self.context.estimated_token_count();
526        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
527        let threshold = (f64::from(context_limit) * config.threshold).ceil() as u32;
528        if estimated >= threshold {
529            tracing::info!(
530                "Pre-flight compaction triggered: estimated {estimated} tokens >= {:.1}% of {context_limit} limit",
531                config.threshold * 100.0
532            );
533            if let CompactionOutcome::Failed(e) = self.compact_context().await {
534                tracing::warn!("Pre-flight compaction failed: {e}");
535            }
536        }
537    }
538
539    /// Check if compaction is needed and perform it if so.
540    async fn maybe_compact_context(&mut self) {
541        if !self.compaction_config.as_ref().is_some_and(|config| self.token_tracker.should_compact(config.threshold)) {
542            return;
543        }
544
545        if let CompactionOutcome::Failed(error_message) = self.compact_context().await {
546            tracing::warn!("Context compaction failed: {}", error_message);
547        }
548    }
549
550    async fn compact_context(&mut self) -> CompactionOutcome {
551        let Some(ref _config) = self.compaction_config else {
552            tracing::warn!("Context compaction requested but compaction is disabled");
553            return CompactionOutcome::SkippedDisabled;
554        };
555
556        match self.token_tracker.usage_ratio() {
557            Some(usage_ratio) => {
558                tracing::info!(
559                    "Starting context compaction - {} messages, {:.1}% of context limit",
560                    self.context.message_count(),
561                    usage_ratio * 100.0
562                );
563            }
564            None => {
565                tracing::info!(
566                    "Starting context compaction - {} messages (context limit unknown)",
567                    self.context.message_count(),
568                );
569            }
570        }
571
572        let _ = self
573            .message_tx
574            .send(AgentMessage::ContextCompactionStarted { message_count: self.context.message_count() })
575            .await;
576
577        let compactor = Compactor::new(self.llm.clone());
578
579        match compactor.compact(&self.context).await {
580            Ok(result) => {
581                tracing::info!("Context compacted: {} messages removed", result.messages_removed);
582
583                self.context = result.context;
584                self.token_tracker.reset_current_usage();
585
586                let _ = self
587                    .message_tx
588                    .send(AgentMessage::ContextCompactionResult {
589                        summary: result.summary,
590                        messages_removed: result.messages_removed,
591                    })
592                    .await;
593                CompactionOutcome::Compacted
594            }
595            Err(e) => CompactionOutcome::Failed(e.to_string()),
596        }
597    }
598
599    async fn on_tool_execution_event(&mut self, event: ToolExecutionEvent, state: &mut IterationState) {
600        match event {
601            ToolExecutionEvent::Started { tool_id, tool_name } => {
602                tracing::debug!("Tool execution started: {} ({})", tool_name, tool_id);
603            }
604
605            ToolExecutionEvent::Progress { tool_id, progress } => {
606                tracing::debug!(
607                    "Tool progress for {}: {}/{}",
608                    tool_id,
609                    progress.progress,
610                    progress.total.unwrap_or(0.0)
611                );
612
613                if let Some(request) = self.active_requests.get(&tool_id) {
614                    let _ = self
615                        .message_tx
616                        .send(AgentMessage::ToolProgress {
617                            request: request.clone(),
618                            progress: progress.progress,
619                            total: progress.total,
620                            message: progress.message.clone(),
621                        })
622                        .await;
623                }
624            }
625
626            ToolExecutionEvent::Complete { tool_id: _, result, result_meta } => match result {
627                Ok(tool_result) => {
628                    tracing::debug!("Tool result received: {} -> {}", tool_result.name, tool_result.result.len());
629
630                    if state.pending_tool_ids.remove(&tool_result.id) {
631                        self.active_requests.remove(&tool_result.id);
632                        state.completed_tool_calls.push(Ok(tool_result.clone()));
633
634                        let msg = AgentMessage::ToolResult {
635                            result: tool_result,
636                            result_meta,
637                            model_name: self.llm.display_name(),
638                        };
639
640                        if let Err(e) = self.message_tx.send(msg).await {
641                            tracing::warn!("Failed to send ToolCall completion message: {:?}", e);
642                        }
643                    } else {
644                        tracing::debug!("Ignoring stale tool result for id: {}", tool_result.id);
645                    }
646                }
647
648                Err(tool_error) => {
649                    if state.pending_tool_ids.remove(&tool_error.id) {
650                        self.active_requests.remove(&tool_error.id);
651                        state.completed_tool_calls.push(Err(tool_error.clone()));
652
653                        let _ = self
654                            .message_tx
655                            .send(AgentMessage::ToolError { error: tool_error, model_name: self.llm.display_name() })
656                            .await;
657                    }
658                }
659            },
660        }
661    }
662
663    fn update_context(
664        &mut self,
665        message_content: &str,
666        reasoning: AssistantReasoning,
667        completed_tools: Vec<Result<ToolCallResult, ToolCallError>>,
668    ) {
669        self.context.push_assistant_turn(message_content, reasoning, completed_tools);
670    }
671}
672
/// What a single `Agent::compact_context` call ended up doing.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CompactionOutcome {
    /// Compaction ran and failed; carries the rendered error.
    Failed(String),
    /// Nothing happened because no compaction config is set.
    SkippedDisabled,
    /// The context was replaced by its compacted form.
    Compacted,
}
679
680pub(crate) struct AutoContinue {
681    max: u32,
682    count: u32,
683}
684
685impl AutoContinue {
686    pub(crate) fn new(max: u32) -> Self {
687        Self { max, count: 0 }
688    }
689
690    fn reset(&mut self) {
691        self.count = 0;
692    }
693
694    fn should_continue(&self, stop_reason: Option<&StopReason>) -> bool {
695        matches!(stop_reason, Some(StopReason::Length)) && self.count < self.max
696    }
697
698    fn advance(&mut self) {
699        self.count += 1;
700    }
701
702    fn count(&self) -> u32 {
703        self.count
704    }
705
706    fn max(&self) -> u32 {
707        self.max
708    }
709}
710
711#[derive(Debug)]
712struct IterationState {
713    current_message_id: Option<String>,
714    message_content: String,
715    reasoning_summary_text: String,
716    encrypted_reasoning: Option<EncryptedReasoningContent>,
717    pending_tool_ids: HashSet<String>,
718    completed_tool_calls: Vec<Result<ToolCallResult, ToolCallError>>,
719    llm_done: bool,
720    stop_reason: Option<StopReason>,
721    retry_attempt: u32,
722}
723
724impl IterationState {
725    fn new() -> Self {
726        Self {
727            current_message_id: None,
728            message_content: String::new(),
729            reasoning_summary_text: String::new(),
730            encrypted_reasoning: None,
731            pending_tool_ids: HashSet::new(),
732            completed_tool_calls: Vec::new(),
733            llm_done: false,
734            stop_reason: None,
735            retry_attempt: 0,
736        }
737    }
738
739    fn on_llm_start(&mut self, message_id: String) {
740        self.current_message_id = Some(message_id);
741        self.message_content.clear();
742        self.reasoning_summary_text.clear();
743        self.encrypted_reasoning = None;
744        self.stop_reason = None;
745    }
746
747    fn is_complete(&self) -> bool {
748        self.llm_done && self.pending_tool_ids.is_empty()
749    }
750}
751
#[cfg(test)]
mod tests {
    use super::*;
    use llm::testing::FakeLlmProvider;
    use tokio::sync::mpsc;

    /// Pre-flight compaction compares the *estimated* context size against
    /// `CompactionConfig::threshold` times the model's context window and, when
    /// exceeded, replaces the conversation with the compactor's summary.
    #[tokio::test]
    async fn test_preflight_compaction_uses_configured_threshold() {
        // Scripted reply consumed by the compactor when it summarizes the context.
        let summary_reply = vec![LlmResponse::start("summary"), LlmResponse::text("summary"), LlmResponse::done()];
        let provider =
            Arc::new(FakeLlmProvider::with_single_response(summary_reply).with_context_window(Some(100)));

        // A single oversized user message; its estimated token count is expected
        // to land at or above 85% of the 100-token window (estimator not shown here).
        let oversized_message = ChatMessage::User {
            content: vec![llm::ContentBlock::text("x".repeat(344))],
            timestamp: IsoString::now(),
        };
        let context = Context::new(vec![oversized_message], vec![]);

        // The event loop is never run, so the user channel can be closed up front.
        let (user_tx, user_rx) = mpsc::channel(1);
        drop(user_tx);
        let (message_tx, _message_rx) = mpsc::channel(8);

        let config = AgentConfig {
            llm: provider,
            context,
            mcp_command_tx: None,
            tool_timeout: Duration::from_secs(1),
            compaction_config: Some(CompactionConfig::with_threshold(0.85)),
            auto_continue: AutoContinue::new(0),
            retry_config: RetryConfig::disabled(),
        };
        let mut agent = Agent::new(config, user_rx, message_tx);

        agent.maybe_preflight_compact().await;

        let messages = agent.context.messages();
        assert!(
            matches!(
                messages.as_slice(),
                [ChatMessage::Summary { content, .. }] if content == "summary"
            ),
            "expected context to be compacted, got {:?}",
            messages
        );
    }
}