1use crate::context::{CompactionConfig, Compactor, TokenTracker};
2pub use crate::core::retry_config::RetryConfig;
3use crate::events::{AgentMessage, UserMessage};
4use crate::mcp::run_mcp_task::{McpCommand, ToolExecutionEvent};
5use futures::Stream;
6use llm::types::IsoString;
7use llm::{
8 AssistantReasoning, ChatMessage, Context, EncryptedReasoningContent, LlmError, LlmResponse, StopReason,
9 StreamingModelProvider, TokenUsage, ToolCallError, ToolCallRequest, ToolCallResult,
10};
11use std::collections::{HashMap, HashSet, VecDeque};
12use std::pin::Pin;
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::mpsc;
16use tokio::time::sleep;
17use tokio_stream::StreamExt;
18use tokio_stream::StreamMap;
19use tokio_stream::wrappers::ReceiverStream;
20
21#[derive(Debug)]
23enum StreamEvent {
24 Llm(Result<LlmResponse, LlmError>),
25 ToolExecution(ToolExecutionEvent),
26 UserMessage(UserMessage),
27}
28
29type EventStream = Pin<Box<dyn Stream<Item = StreamEvent> + Send>>;
30
31const USER_STREAM_KEY: &str = "user";
32const LLM_STREAM_KEY: &str = "llm";
33
34pub(crate) struct AgentConfig {
35 pub llm: Arc<dyn StreamingModelProvider>,
36 pub context: Context,
37 pub mcp_command_tx: Option<mpsc::Sender<McpCommand>>,
38 pub tool_timeout: Duration,
39 pub compaction_config: Option<CompactionConfig>,
40 pub auto_continue: AutoContinue,
41 pub retry_config: RetryConfig,
42}
43
44pub struct Agent {
45 llm: Arc<dyn StreamingModelProvider>,
46 context: Context,
47 mcp_command_tx: Option<mpsc::Sender<McpCommand>>,
48 message_tx: mpsc::Sender<AgentMessage>,
49 streams: StreamMap<String, EventStream>,
50 tool_timeout: Duration,
51 token_tracker: TokenTracker,
52 compaction_config: Option<CompactionConfig>,
53 auto_continue: AutoContinue,
54 retry_config: RetryConfig,
55 active_requests: HashMap<String, ToolCallRequest>,
56 queued_user_messages: VecDeque<Vec<llm::ContentBlock>>,
57}
58
59impl Agent {
60 pub(crate) fn new(
61 config: AgentConfig,
62 user_message_rx: mpsc::Receiver<UserMessage>,
63 message_tx: mpsc::Sender<AgentMessage>,
64 ) -> Self {
65 let mut streams: StreamMap<String, EventStream> = StreamMap::new();
66 streams.insert(
67 USER_STREAM_KEY.to_string(),
68 Box::pin(ReceiverStream::new(user_message_rx).map(StreamEvent::UserMessage)),
69 );
70
71 let context_limit = config.llm.context_window();
72
73 Self {
74 llm: config.llm,
75 context: config.context,
76 mcp_command_tx: config.mcp_command_tx,
77 message_tx,
78 streams,
79 tool_timeout: config.tool_timeout,
80 token_tracker: TokenTracker::new(context_limit),
81 compaction_config: config.compaction_config,
82 auto_continue: config.auto_continue,
83 retry_config: config.retry_config,
84 active_requests: HashMap::new(),
85 queued_user_messages: VecDeque::new(),
86 }
87 }
88
89 pub fn current_model_display_name(&self) -> String {
90 self.llm.display_name()
91 }
92
93 pub fn token_tracker(&self) -> &TokenTracker {
95 &self.token_tracker
96 }
97
98 pub async fn run(mut self) {
99 let mut state = IterationState::new();
100
101 while let Some((_, event)) = self.streams.next().await {
102 use UserMessage::{Cancel, ClearContext, SetReasoningEffort, SwitchModel, Text, UpdateTools};
103 match event {
104 StreamEvent::UserMessage(Cancel) => {
105 self.on_user_cancel(&mut state).await;
106 }
107
108 StreamEvent::UserMessage(ClearContext) => {
109 self.on_user_clear_context(&mut state).await;
110 }
111
112 StreamEvent::UserMessage(Text { content }) => {
113 if self.is_busy() {
114 self.queued_user_messages.push_back(content);
115 } else {
116 state = IterationState::new();
117 self.on_user_text(content);
118 }
119 }
120
121 StreamEvent::UserMessage(SwitchModel(new_provider)) => {
122 self.on_switch_model(new_provider).await;
123 }
124
125 StreamEvent::UserMessage(UpdateTools(tools)) => {
126 self.context.set_tools(tools);
127 }
128
129 StreamEvent::UserMessage(SetReasoningEffort(effort)) => {
130 self.context.set_reasoning_effort(effort);
131 }
132
133 StreamEvent::Llm(llm_event) => {
134 self.on_llm_event(llm_event, &mut state).await;
135 }
136
137 StreamEvent::ToolExecution(tool_event) => {
138 self.on_tool_execution_event(tool_event, &mut state).await;
139 }
140 }
141
142 if state.is_complete() {
143 let Some(id) = state.current_message_id.take() else {
144 continue;
145 };
146 let iteration = std::mem::replace(&mut state, IterationState::new());
147 self.on_iteration_complete(id, iteration).await;
148 }
149 }
150
151 tracing::debug!("Agent task shutting down - input channel closed");
152 }
153
154 async fn on_iteration_complete(&mut self, id: String, iteration: IterationState) {
155 let IterationState {
156 message_content,
157 reasoning_summary_text,
158 encrypted_reasoning,
159 completed_tool_calls,
160 stop_reason,
161 ..
162 } = iteration;
163 let has_tool_calls = !completed_tool_calls.is_empty();
164 let has_content = !message_content.is_empty() || has_tool_calls;
165 let should_auto_continue = self.auto_continue.should_continue(stop_reason.as_ref());
166
167 if has_content {
168 let reasoning = AssistantReasoning::from_parts(reasoning_summary_text.clone(), encrypted_reasoning);
169 self.update_context(&message_content, reasoning, completed_tool_calls);
170
171 let _ = self
172 .message_tx
173 .send(AgentMessage::Text {
174 message_id: id.clone(),
175 chunk: message_content.clone(),
176 is_complete: true,
177 model_name: self.llm.display_name(),
178 })
179 .await;
180
181 if !reasoning_summary_text.is_empty() {
182 let _ = self
183 .message_tx
184 .send(AgentMessage::Thought {
185 message_id: id.clone(),
186 chunk: reasoning_summary_text,
187 is_complete: true,
188 model_name: self.llm.display_name(),
189 })
190 .await;
191 }
192 }
193
194 let has_queued_text = !self.queued_user_messages.is_empty();
195 if has_queued_text {
196 let content: Vec<_> = self.queued_user_messages.drain(..).flatten().collect();
197 self.context.add_message(ChatMessage::User { content, timestamp: IsoString::now() });
198 }
199
200 if has_queued_text || has_tool_calls {
201 self.auto_continue.reset();
202 self.start_next_turn().await;
203 } else if should_auto_continue {
204 self.auto_continue.advance();
205 tracing::info!(
206 "LLM stopped with {:?}, auto-continuing (attempt {}/{})",
207 stop_reason,
208 self.auto_continue.count(),
209 self.auto_continue.max()
210 );
211
212 let _ = self
213 .message_tx
214 .send(AgentMessage::AutoContinue {
215 attempt: self.auto_continue.count(),
216 max_attempts: self.auto_continue.max(),
217 })
218 .await;
219
220 self.inject_continuation_prompt(&message_content, stop_reason.as_ref());
221 self.start_next_turn().await;
222 } else {
223 tracing::debug!("LLM completed turn with stop reason: {:?}", stop_reason);
224 self.auto_continue.reset();
225 if let Err(e) = self.message_tx.send(AgentMessage::Done).await {
226 tracing::warn!("Failed to send Done message: {:?}", e);
227 }
228 }
229 }
230
231 async fn start_next_turn(&mut self) {
232 self.maybe_preflight_compact().await;
233 self.start_llm_stream(None);
234 }
235
236 async fn on_user_cancel(&mut self, state: &mut IterationState) {
237 self.abort_in_flight_work();
238 *state = IterationState::new();
239 let _ = self.message_tx.send(AgentMessage::Cancelled { message: "Processing cancelled".to_string() }).await;
240 let _ = self.message_tx.send(AgentMessage::Done).await;
241 }
242
243 async fn on_user_clear_context(&mut self, state: &mut IterationState) {
244 self.abort_in_flight_work();
245 self.context.clear_conversation();
246 self.token_tracker.reset_current_usage();
247 self.auto_continue.reset();
248 *state = IterationState::new();
249
250 let _ = self.message_tx.send(AgentMessage::ContextCleared).await;
251 }
252
253 fn on_user_text(&mut self, content: Vec<llm::ContentBlock>) {
254 self.context.add_message(ChatMessage::User { content, timestamp: IsoString::now() });
255 self.auto_continue.reset();
256 self.start_llm_stream(None);
257 }
258
259 async fn on_switch_model(&mut self, new_provider: Box<dyn StreamingModelProvider>) {
260 let previous = self.llm.display_name();
261 let new_context_limit = new_provider.context_window();
262 self.llm = Arc::from(new_provider);
263 self.token_tracker.reset_current_usage();
264 self.token_tracker.set_context_limit(new_context_limit);
265 let new = self.llm.display_name();
266 let _ = self.message_tx.send(AgentMessage::ModelSwitched { previous, new }).await;
267
268 let _ = self.message_tx.send(self.context_usage_message()).await;
269 }
270
271 fn start_llm_stream(&mut self, delay: Option<Duration>) {
272 self.streams.remove(LLM_STREAM_KEY);
273 let stream: EventStream = match delay {
274 None => Box::pin(self.llm.stream_response(&self.context).map(StreamEvent::Llm)),
275 Some(delay) => {
276 let llm = Arc::clone(&self.llm);
277 let context = self.context.clone();
278 Box::pin(async_stream::stream! {
279 sleep(delay).await;
280 let mut inner = llm.stream_response(&context);
281 while let Some(item) = inner.next().await {
282 yield StreamEvent::Llm(item);
283 }
284 })
285 }
286 };
287 self.streams.insert(LLM_STREAM_KEY.to_string(), stream);
288 }
289
290 async fn on_llm_error(&mut self, error: LlmError, state: &mut IterationState) {
291 if !error.is_retryable() || state.retry_attempt >= self.retry_config.max_attempts {
292 let _ = self.message_tx.send(AgentMessage::Error { message: error.to_string() }).await;
293 return;
294 }
295
296 state.retry_attempt += 1;
297 let delay = self.retry_config.compute_delay(state.retry_attempt);
298 let delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX);
299
300 tracing::warn!(
301 attempt = state.retry_attempt,
302 max_attempts = self.retry_config.max_attempts,
303 delay_ms,
304 error = %error,
305 "Retrying LLM request after transient failure"
306 );
307
308 let _ = self
309 .message_tx
310 .send(AgentMessage::Retrying {
311 attempt: state.retry_attempt,
312 max_attempts: self.retry_config.max_attempts,
313 delay_ms,
314 error: error.to_string(),
315 })
316 .await;
317
318 self.active_requests.clear();
321 self.start_llm_stream(Some(delay));
322 }
323
324 fn is_busy(&self) -> bool {
325 self.streams.contains_key(LLM_STREAM_KEY) || !self.active_requests.is_empty()
326 }
327
328 fn abort_in_flight_work(&mut self) {
329 self.streams.remove(LLM_STREAM_KEY);
330 for stream_key in self.active_requests.keys().cloned().collect::<Vec<_>>() {
331 self.streams.remove(&stream_key);
332 }
333 self.active_requests.clear();
334 self.queued_user_messages.clear();
335 }
336
337 fn inject_continuation_prompt(&mut self, previous_response: &str, stop_reason: Option<&StopReason>) {
339 if !previous_response.is_empty() {
340 self.context.add_message(ChatMessage::Assistant {
341 content: previous_response.to_string(),
342 reasoning: AssistantReasoning::default(),
343 timestamp: IsoString::now(),
344 tool_calls: Vec::new(),
345 });
346 }
347
348 let reason = stop_reason.map_or_else(|| "Unknown".to_string(), |reason| format!("{reason:?}"));
349
350 self.context.add_message(ChatMessage::User {
351 content: vec![llm::ContentBlock::text(format!(
352 "<system-notification>The LLM API stopped with reason '{reason}'. Continue from where you left off and finish your task.</system-notification>"
353 ))],
354 timestamp: IsoString::now(),
355 });
356 }
357
358 async fn on_llm_event(&mut self, result: Result<LlmResponse, LlmError>, state: &mut IterationState) {
359 use LlmResponse::{
360 Done, EncryptedReasoning, Error, Reasoning, Start, Text, ToolRequestArg, ToolRequestComplete,
361 ToolRequestStart, Usage,
362 };
363
364 let response = match result {
365 Ok(response) => response,
366 Err(e) => {
367 self.on_llm_error(e, state).await;
368 return;
369 }
370 };
371
372 match response {
373 Start { message_id } => {
374 state.on_llm_start(message_id);
375 }
376
377 Text { chunk } => {
378 self.handle_llm_text(chunk, state).await;
379 }
380
381 Reasoning { chunk } => {
382 state.reasoning_summary_text.push_str(&chunk);
383 if let Some(id) = &state.current_message_id {
384 let _ = self
385 .message_tx
386 .send(AgentMessage::Thought {
387 message_id: id.clone(),
388 chunk,
389 is_complete: false,
390 model_name: self.llm.display_name(),
391 })
392 .await;
393 }
394 }
395
396 EncryptedReasoning { id, content } => {
397 if let Some(model) = self.llm.model() {
398 state.encrypted_reasoning = Some(EncryptedReasoningContent { id, model, content });
399 }
400 }
401
402 ToolRequestStart { id, name } => {
403 self.handle_tool_request_start(id, name).await;
404 }
405
406 ToolRequestArg { id, chunk } => {
407 self.handle_tool_request_arg(id, chunk).await;
408 }
409
410 ToolRequestComplete { tool_call } => {
411 self.handle_tool_completion(tool_call, state).await;
412 }
413
414 Done { stop_reason } => {
415 state.llm_done = true;
416 state.stop_reason = stop_reason;
417 }
418
419 Error { message } => {
420 let _ = self.message_tx.send(AgentMessage::Error { message }).await;
421 }
422
423 Usage { tokens: sample } => {
424 self.handle_llm_usage(sample).await;
425 }
426 }
427 }
428
429 async fn handle_llm_text(&mut self, chunk: String, state: &mut IterationState) {
430 state.message_content.push_str(&chunk);
431
432 if let Some(id) = &state.current_message_id {
433 let _ = self
434 .message_tx
435 .send(AgentMessage::Text {
436 message_id: id.clone(),
437 chunk,
438 is_complete: false,
439 model_name: self.llm.display_name(),
440 })
441 .await;
442 }
443 }
444
445 async fn handle_tool_request_start(&mut self, id: String, name: String) {
446 let request = ToolCallRequest { id: id.clone(), name, arguments: String::new() };
447 self.active_requests.insert(id, request.clone());
448
449 let _ = self.message_tx.send(AgentMessage::ToolCall { request, model_name: self.llm.display_name() }).await;
450 }
451
452 async fn handle_tool_request_arg(&mut self, id: String, chunk: String) {
453 let Some(request) = self.active_requests.get_mut(&id) else {
454 return;
455 };
456 request.arguments.push_str(&chunk);
457
458 let _ = self
459 .message_tx
460 .send(AgentMessage::ToolCallUpdate { tool_call_id: id, chunk, model_name: self.llm.display_name() })
461 .await;
462 }
463
464 async fn handle_tool_completion(&mut self, tool_call: ToolCallRequest, state: &mut IterationState) {
465 state.pending_tool_ids.insert(tool_call.id.clone());
466 debug_assert!(
467 self.active_requests.contains_key(&tool_call.id),
468 "tool call {} should already be in active_requests from handle_tool_request_start",
469 tool_call.id
470 );
471
472 let (tx, rx) = mpsc::channel(100);
473 let stream = ReceiverStream::new(rx).map(StreamEvent::ToolExecution);
474 let stream_key = tool_call.id.clone();
475 self.streams.insert(stream_key, Box::pin(stream));
476
477 if let Some(ref mcp_command_tx) = self.mcp_command_tx {
478 let mcp_future =
479 mcp_command_tx.send(McpCommand::ExecuteTool { request: tool_call, timeout: self.tool_timeout, tx });
480 if let Err(e) = mcp_future.await {
481 tracing::warn!("Failed to send tool request to MCP task: {:?}", e);
482 }
483 }
484 }
485
486 async fn handle_llm_usage(&mut self, sample: TokenUsage) {
487 self.token_tracker.record_usage(sample);
488 let ratio_pct = self.token_tracker.usage_ratio().map(|r| r * 100.0);
489 let remaining = self.token_tracker.tokens_remaining();
490 tracing::debug!(?sample, ?ratio_pct, ?remaining, "Token usage");
491
492 let _ = self.message_tx.send(self.context_usage_message()).await;
493
494 self.maybe_compact_context().await;
495 }
496
497 fn context_usage_message(&self) -> AgentMessage {
498 let last = self.token_tracker.last_usage();
499 AgentMessage::ContextUsageUpdate {
500 usage_ratio: self.token_tracker.usage_ratio(),
501 context_limit: self.token_tracker.context_limit(),
502 input_tokens: last.input_tokens,
503 output_tokens: last.output_tokens,
504 cache_read_tokens: last.cache_read_tokens,
505 cache_creation_tokens: last.cache_creation_tokens,
506 reasoning_tokens: last.reasoning_tokens,
507 total_input_tokens: self.token_tracker.total_input_tokens(),
508 total_output_tokens: self.token_tracker.total_output_tokens(),
509 total_cache_read_tokens: self.token_tracker.total_cache_read_tokens(),
510 total_cache_creation_tokens: self.token_tracker.total_cache_creation_tokens(),
511 total_reasoning_tokens: self.token_tracker.total_reasoning_tokens(),
512 }
513 }
514
515 async fn maybe_preflight_compact(&mut self) {
519 let Some(context_limit) = self.token_tracker.context_limit() else {
520 return;
521 };
522 let Some(config) = self.compaction_config.as_ref() else {
523 return;
524 };
525 let estimated = self.context.estimated_token_count();
526 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
527 let threshold = (f64::from(context_limit) * config.threshold).ceil() as u32;
528 if estimated >= threshold {
529 tracing::info!(
530 "Pre-flight compaction triggered: estimated {estimated} tokens >= {:.1}% of {context_limit} limit",
531 config.threshold * 100.0
532 );
533 if let CompactionOutcome::Failed(e) = self.compact_context().await {
534 tracing::warn!("Pre-flight compaction failed: {e}");
535 }
536 }
537 }
538
539 async fn maybe_compact_context(&mut self) {
541 if !self.compaction_config.as_ref().is_some_and(|config| self.token_tracker.should_compact(config.threshold)) {
542 return;
543 }
544
545 if let CompactionOutcome::Failed(error_message) = self.compact_context().await {
546 tracing::warn!("Context compaction failed: {}", error_message);
547 }
548 }
549
550 async fn compact_context(&mut self) -> CompactionOutcome {
551 let Some(ref _config) = self.compaction_config else {
552 tracing::warn!("Context compaction requested but compaction is disabled");
553 return CompactionOutcome::SkippedDisabled;
554 };
555
556 match self.token_tracker.usage_ratio() {
557 Some(usage_ratio) => {
558 tracing::info!(
559 "Starting context compaction - {} messages, {:.1}% of context limit",
560 self.context.message_count(),
561 usage_ratio * 100.0
562 );
563 }
564 None => {
565 tracing::info!(
566 "Starting context compaction - {} messages (context limit unknown)",
567 self.context.message_count(),
568 );
569 }
570 }
571
572 let _ = self
573 .message_tx
574 .send(AgentMessage::ContextCompactionStarted { message_count: self.context.message_count() })
575 .await;
576
577 let compactor = Compactor::new(self.llm.clone());
578
579 match compactor.compact(&self.context).await {
580 Ok(result) => {
581 tracing::info!("Context compacted: {} messages removed", result.messages_removed);
582
583 self.context = result.context;
584 self.token_tracker.reset_current_usage();
585
586 let _ = self
587 .message_tx
588 .send(AgentMessage::ContextCompactionResult {
589 summary: result.summary,
590 messages_removed: result.messages_removed,
591 })
592 .await;
593 CompactionOutcome::Compacted
594 }
595 Err(e) => CompactionOutcome::Failed(e.to_string()),
596 }
597 }
598
599 async fn on_tool_execution_event(&mut self, event: ToolExecutionEvent, state: &mut IterationState) {
600 match event {
601 ToolExecutionEvent::Started { tool_id, tool_name } => {
602 tracing::debug!("Tool execution started: {} ({})", tool_name, tool_id);
603 }
604
605 ToolExecutionEvent::Progress { tool_id, progress } => {
606 tracing::debug!(
607 "Tool progress for {}: {}/{}",
608 tool_id,
609 progress.progress,
610 progress.total.unwrap_or(0.0)
611 );
612
613 if let Some(request) = self.active_requests.get(&tool_id) {
614 let _ = self
615 .message_tx
616 .send(AgentMessage::ToolProgress {
617 request: request.clone(),
618 progress: progress.progress,
619 total: progress.total,
620 message: progress.message.clone(),
621 })
622 .await;
623 }
624 }
625
626 ToolExecutionEvent::Complete { tool_id: _, result, result_meta } => match result {
627 Ok(tool_result) => {
628 tracing::debug!("Tool result received: {} -> {}", tool_result.name, tool_result.result.len());
629
630 if state.pending_tool_ids.remove(&tool_result.id) {
631 self.active_requests.remove(&tool_result.id);
632 state.completed_tool_calls.push(Ok(tool_result.clone()));
633
634 let msg = AgentMessage::ToolResult {
635 result: tool_result,
636 result_meta,
637 model_name: self.llm.display_name(),
638 };
639
640 if let Err(e) = self.message_tx.send(msg).await {
641 tracing::warn!("Failed to send ToolCall completion message: {:?}", e);
642 }
643 } else {
644 tracing::debug!("Ignoring stale tool result for id: {}", tool_result.id);
645 }
646 }
647
648 Err(tool_error) => {
649 if state.pending_tool_ids.remove(&tool_error.id) {
650 self.active_requests.remove(&tool_error.id);
651 state.completed_tool_calls.push(Err(tool_error.clone()));
652
653 let _ = self
654 .message_tx
655 .send(AgentMessage::ToolError { error: tool_error, model_name: self.llm.display_name() })
656 .await;
657 }
658 }
659 },
660 }
661 }
662
663 fn update_context(
664 &mut self,
665 message_content: &str,
666 reasoning: AssistantReasoning,
667 completed_tools: Vec<Result<ToolCallResult, ToolCallError>>,
668 ) {
669 self.context.push_assistant_turn(message_content, reasoning, completed_tools);
670 }
671}
672
/// Result of a call to `Agent::compact_context`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CompactionOutcome {
    /// The context was replaced by its compacted form.
    Compacted,
    /// Nothing was done because no compaction config is set.
    SkippedDisabled,
    /// The compactor returned an error; the payload is the error text.
    Failed(String),
}
679
680pub(crate) struct AutoContinue {
681 max: u32,
682 count: u32,
683}
684
685impl AutoContinue {
686 pub(crate) fn new(max: u32) -> Self {
687 Self { max, count: 0 }
688 }
689
690 fn reset(&mut self) {
691 self.count = 0;
692 }
693
694 fn should_continue(&self, stop_reason: Option<&StopReason>) -> bool {
695 matches!(stop_reason, Some(StopReason::Length)) && self.count < self.max
696 }
697
698 fn advance(&mut self) {
699 self.count += 1;
700 }
701
702 fn count(&self) -> u32 {
703 self.count
704 }
705
706 fn max(&self) -> u32 {
707 self.max
708 }
709}
710
711#[derive(Debug)]
712struct IterationState {
713 current_message_id: Option<String>,
714 message_content: String,
715 reasoning_summary_text: String,
716 encrypted_reasoning: Option<EncryptedReasoningContent>,
717 pending_tool_ids: HashSet<String>,
718 completed_tool_calls: Vec<Result<ToolCallResult, ToolCallError>>,
719 llm_done: bool,
720 stop_reason: Option<StopReason>,
721 retry_attempt: u32,
722}
723
724impl IterationState {
725 fn new() -> Self {
726 Self {
727 current_message_id: None,
728 message_content: String::new(),
729 reasoning_summary_text: String::new(),
730 encrypted_reasoning: None,
731 pending_tool_ids: HashSet::new(),
732 completed_tool_calls: Vec::new(),
733 llm_done: false,
734 stop_reason: None,
735 retry_attempt: 0,
736 }
737 }
738
739 fn on_llm_start(&mut self, message_id: String) {
740 self.current_message_id = Some(message_id);
741 self.message_content.clear();
742 self.reasoning_summary_text.clear();
743 self.encrypted_reasoning = None;
744 self.stop_reason = None;
745 }
746
747 fn is_complete(&self) -> bool {
748 self.llm_done && self.pending_tool_ids.is_empty()
749 }
750}
751
#[cfg(test)]
mod tests {
    use super::*;
    use llm::testing::FakeLlmProvider;
    use tokio::sync::mpsc;

    /// With a 100-token window and a 0.85 threshold, a context holding one 344-character
    /// user message must be compacted down to a single summary message.
    #[tokio::test]
    async fn test_preflight_compaction_uses_configured_threshold() {
        // The fake provider answers the compaction request with the text "summary".
        let summary_script =
            vec![LlmResponse::start("summary"), LlmResponse::text("summary"), LlmResponse::done()];
        let provider =
            Arc::new(FakeLlmProvider::with_single_response(summary_script).with_context_window(Some(100)));

        // Presumably 344 characters estimate to at least 85 tokens - confirm against
        // `Context::estimated_token_count`.
        let oversized_message =
            ChatMessage::User { content: vec![llm::ContentBlock::text("x".repeat(344))], timestamp: IsoString::now() };
        let initial_context = Context::new(vec![oversized_message], vec![]);

        let (user_tx, user_rx) = mpsc::channel(1);
        let (message_tx, _message_rx) = mpsc::channel(8);
        drop(user_tx);

        let config = AgentConfig {
            llm: provider,
            context: initial_context,
            mcp_command_tx: None,
            tool_timeout: Duration::from_secs(1),
            compaction_config: Some(CompactionConfig::with_threshold(0.85)),
            auto_continue: AutoContinue::new(0),
            retry_config: RetryConfig::disabled(),
        };
        let mut agent = Agent::new(config, user_rx, message_tx);

        agent.maybe_preflight_compact().await;

        let compacted = matches!(
            agent.context.messages().as_slice(),
            [ChatMessage::Summary { content, .. }] if content == "summary"
        );
        assert!(compacted, "expected context to be compacted, got {:?}", agent.context.messages());
    }
}