1use crate::context::{CompactionConfig, Compactor, TokenTracker};
2use crate::core::PromptCache;
3pub use crate::core::retry_config::RetryConfig;
4use crate::events::{AgentCommand, AgentMessage, Command, UserCommand};
5use crate::mcp::run_mcp_task::{McpCommand, ToolExecutionEvent};
6use futures::Stream;
7use llm::types::IsoString;
8use llm::{
9 AssistantReasoning, ChatMessage, Context, EncryptedReasoningContent, LlmError, LlmResponse, StopReason,
10 StreamingModelProvider, TokenUsage, ToolCallError, ToolCallRequest, ToolCallResult,
11};
12use std::collections::{HashMap, HashSet, VecDeque};
13use std::pin::Pin;
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::sync::mpsc;
17use tokio::time::sleep;
18use tokio_stream::StreamExt;
19use tokio_stream::StreamMap;
20use tokio_stream::wrappers::ReceiverStream;
21
22#[derive(Debug)]
24enum StreamEvent {
25 Llm(Result<LlmResponse, LlmError>),
26 ToolExecution(ToolExecutionEvent),
27 Command(Command),
28}
29
30type EventStream = Pin<Box<dyn Stream<Item = StreamEvent> + Send>>;
31
32const USER_STREAM_KEY: &str = "user";
33const LLM_STREAM_KEY: &str = "llm";
34
35pub(crate) struct AgentConfig {
36 pub llm: Arc<dyn StreamingModelProvider>,
37 pub context: Context,
38 pub mcp_command_tx: Option<mpsc::Sender<McpCommand>>,
39 pub tool_timeout: Duration,
40 pub compaction_config: Option<CompactionConfig>,
41 pub auto_continue: AutoContinue,
42 pub retry_config: RetryConfig,
43 pub context_window: Option<u32>,
44 pub prompt_cache: PromptCache,
45}
46
47pub struct Agent {
48 llm: Arc<dyn StreamingModelProvider>,
49 context: Context,
50 mcp_command_tx: Option<mpsc::Sender<McpCommand>>,
51 message_tx: mpsc::Sender<AgentMessage>,
52 streams: StreamMap<String, EventStream>,
53 tool_timeout: Duration,
54 token_tracker: TokenTracker,
55 compaction_config: Option<CompactionConfig>,
56 auto_continue: AutoContinue,
57 retry_config: RetryConfig,
58 active_requests: HashMap<String, ToolCallRequest>,
59 queued_user_messages: VecDeque<Vec<llm::ContentBlock>>,
60 context_window: Option<u32>,
61 prompt_cache: PromptCache,
62}
63
64impl Agent {
65 pub(crate) fn new(
66 config: AgentConfig,
67 command_rx: mpsc::Receiver<Command>,
68 message_tx: mpsc::Sender<AgentMessage>,
69 ) -> Self {
70 let mut streams: StreamMap<String, EventStream> = StreamMap::new();
71 streams
72 .insert(USER_STREAM_KEY.to_string(), Box::pin(ReceiverStream::new(command_rx).map(StreamEvent::Command)));
73
74 let context_limit = config.context_window.or_else(|| config.llm.context_window());
75
76 Self {
77 llm: config.llm,
78 context: config.context,
79 mcp_command_tx: config.mcp_command_tx,
80 message_tx,
81 streams,
82 tool_timeout: config.tool_timeout,
83 token_tracker: TokenTracker::new(context_limit),
84 compaction_config: config.compaction_config,
85 auto_continue: config.auto_continue,
86 retry_config: config.retry_config,
87 active_requests: HashMap::new(),
88 queued_user_messages: VecDeque::new(),
89 context_window: config.context_window,
90 prompt_cache: config.prompt_cache,
91 }
92 }
93
94 pub fn current_model_display_name(&self) -> String {
95 self.llm.display_name()
96 }
97
98 pub fn token_tracker(&self) -> &TokenTracker {
100 &self.token_tracker
101 }
102
103 pub async fn run(mut self) {
104 let mut state = IterationState::new();
105
106 while let Some((_, event)) = self.streams.next().await {
107 match event {
108 StreamEvent::Command(Command::UserCommand(UserCommand::Cancel)) => {
109 self.on_user_cancel(&mut state).await;
110 }
111
112 StreamEvent::Command(Command::UserCommand(UserCommand::ClearContext)) => {
113 self.on_user_clear_context(&mut state).await;
114 }
115
116 StreamEvent::Command(Command::UserCommand(UserCommand::Text { content })) => {
117 if self.is_busy() {
118 self.queued_user_messages.push_back(content);
119 } else {
120 state = IterationState::new();
121 self.on_user_text(content);
122 }
123 }
124
125 StreamEvent::Command(Command::AgentCommand(AgentCommand::SwitchModel(new_provider))) => {
126 self.on_switch_model(new_provider).await;
127 }
128
129 StreamEvent::Command(Command::AgentCommand(AgentCommand::UpdateTools(tools))) => {
130 self.context.set_tools(tools);
131 }
132
133 StreamEvent::Command(Command::AgentCommand(AgentCommand::UpdateMcpInstructions { server, body })) => {
134 self.on_update_instruction(server, body).await;
135 }
136
137 StreamEvent::Command(Command::AgentCommand(AgentCommand::SetReasoningEffort(effort))) => {
138 self.context.set_reasoning_effort(effort);
139 }
140
141 StreamEvent::Command(Command::AgentCommand(AgentCommand::ReplaceConversation(messages))) => {
142 self.on_replace_conversation(messages, &mut state).await;
143 }
144
145 StreamEvent::Llm(llm_event) => {
146 self.on_llm_event(llm_event, &mut state).await;
147 }
148
149 StreamEvent::ToolExecution(tool_event) => {
150 self.on_tool_execution_event(tool_event, &mut state).await;
151 }
152 }
153
154 if state.is_complete() {
155 let Some(id) = state.current_message_id.take() else {
156 continue;
157 };
158 let iteration = std::mem::replace(&mut state, IterationState::new());
159 self.on_iteration_complete(id, iteration).await;
160 }
161 }
162
163 tracing::debug!("Agent task shutting down - input channel closed");
164 }
165
166 async fn on_iteration_complete(&mut self, id: String, iteration: IterationState) {
167 let IterationState {
168 message_content,
169 reasoning_summary_text,
170 encrypted_reasoning,
171 completed_tool_calls,
172 stop_reason,
173 ..
174 } = iteration;
175 let has_tool_calls = !completed_tool_calls.is_empty();
176 let has_content = !message_content.is_empty() || has_tool_calls;
177 let should_auto_continue = self.auto_continue.should_continue(stop_reason.as_ref());
178
179 if has_content {
180 let reasoning = AssistantReasoning::from_parts(reasoning_summary_text.clone(), encrypted_reasoning);
181 self.update_context(&message_content, reasoning, completed_tool_calls);
182
183 let _ = self
184 .message_tx
185 .send(AgentMessage::Text {
186 message_id: id.clone(),
187 chunk: message_content.clone(),
188 is_complete: true,
189 model_name: self.llm.display_name(),
190 })
191 .await;
192
193 if !reasoning_summary_text.is_empty() {
194 let _ = self
195 .message_tx
196 .send(AgentMessage::Thought {
197 message_id: id.clone(),
198 chunk: reasoning_summary_text,
199 is_complete: true,
200 model_name: self.llm.display_name(),
201 })
202 .await;
203 }
204 }
205
206 let has_queued_text = !self.queued_user_messages.is_empty();
207 if has_queued_text {
208 let content: Vec<_> = self.queued_user_messages.drain(..).flatten().collect();
209 self.context.add_message(ChatMessage::User { content, timestamp: IsoString::now() });
210 }
211
212 if has_queued_text || has_tool_calls {
213 self.auto_continue.reset();
214 self.start_next_turn().await;
215 } else if should_auto_continue {
216 self.auto_continue.advance();
217 tracing::info!(
218 "LLM stopped with {:?}, auto-continuing (attempt {}/{})",
219 stop_reason,
220 self.auto_continue.count(),
221 self.auto_continue.max()
222 );
223
224 let _ = self
225 .message_tx
226 .send(AgentMessage::AutoContinue {
227 attempt: self.auto_continue.count(),
228 max_attempts: self.auto_continue.max(),
229 })
230 .await;
231
232 self.inject_continuation_prompt(&message_content, stop_reason.as_ref());
233 self.start_next_turn().await;
234 } else {
235 tracing::debug!("LLM completed turn with stop reason: {:?}", stop_reason);
236 self.auto_continue.reset();
237 if let Err(e) = self.message_tx.send(AgentMessage::Done).await {
238 tracing::warn!("Failed to send Done message: {:?}", e);
239 }
240 }
241 }
242
243 async fn start_next_turn(&mut self) {
244 self.maybe_preflight_compact().await;
245 self.start_llm_stream(None);
246 }
247
248 async fn on_user_cancel(&mut self, state: &mut IterationState) {
249 self.abort_in_flight_work();
250 *state = IterationState::new();
251 let _ = self.message_tx.send(AgentMessage::Cancelled { message: "Processing cancelled".to_string() }).await;
252 let _ = self.message_tx.send(AgentMessage::Done).await;
253 }
254
255 async fn on_user_clear_context(&mut self, state: &mut IterationState) {
256 self.abort_in_flight_work();
257 self.context.clear_conversation();
258 self.token_tracker.reset_current_usage();
259 self.auto_continue.reset();
260 *state = IterationState::new();
261
262 let _ = self.message_tx.send(AgentMessage::ContextCleared).await;
263 }
264
265 async fn on_replace_conversation(&mut self, messages: Vec<ChatMessage>, state: &mut IterationState) {
266 self.abort_in_flight_work();
267 self.context.replace_conversation(messages);
268 self.auto_continue.reset();
269 *state = IterationState::new();
270 let _ = self.message_tx.send(self.context_usage_message()).await;
271 }
272
273 fn on_user_text(&mut self, content: Vec<llm::ContentBlock>) {
274 self.context.add_message(ChatMessage::User { content, timestamp: IsoString::now() });
275 self.auto_continue.reset();
276 self.start_llm_stream(None);
277 }
278
279 async fn on_update_instruction(&mut self, server: String, body: Option<String>) {
280 self.prompt_cache.update_mcp_instruction(server, body);
281 match self.prompt_cache.render().await {
282 Ok(content) => self.context.set_system_content(content),
283 Err(e) => tracing::warn!("Failed to rebuild system prompt after instructions update: {e}"),
284 }
285 }
286
287 async fn on_switch_model(&mut self, new_provider: Box<dyn StreamingModelProvider>) {
288 let previous = self.llm.display_name();
289 let new_context_limit = self.context_window.or_else(|| new_provider.context_window());
290 self.llm = Arc::from(new_provider);
291 self.token_tracker.reset_current_usage();
292 self.token_tracker.set_context_limit(new_context_limit);
293 let new = self.llm.display_name();
294 let _ = self.message_tx.send(AgentMessage::ModelSwitched { previous, new }).await;
295
296 let _ = self.message_tx.send(self.context_usage_message()).await;
297 }
298
299 fn start_llm_stream(&mut self, delay: Option<Duration>) {
300 self.streams.remove(LLM_STREAM_KEY);
301 let stream: EventStream = match delay {
302 None => Box::pin(self.llm.stream_response(&self.context).map(StreamEvent::Llm)),
303 Some(delay) => {
304 let llm = Arc::clone(&self.llm);
305 let context = self.context.clone();
306 Box::pin(async_stream::stream! {
307 sleep(delay).await;
308 let mut inner = llm.stream_response(&context);
309 while let Some(item) = inner.next().await {
310 yield StreamEvent::Llm(item);
311 }
312 })
313 }
314 };
315 self.streams.insert(LLM_STREAM_KEY.to_string(), stream);
316 }
317
318 async fn on_llm_error(&mut self, error: LlmError, state: &mut IterationState) {
319 if !error.is_retryable() || state.retry_attempt >= self.retry_config.max_attempts {
320 let _ = self.message_tx.send(AgentMessage::Error { message: error.to_string() }).await;
321 return;
322 }
323
324 state.retry_attempt += 1;
325 let delay = self.retry_config.compute_delay(state.retry_attempt);
326 let delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX);
327
328 tracing::warn!(
329 attempt = state.retry_attempt,
330 max_attempts = self.retry_config.max_attempts,
331 delay_ms,
332 error = %error,
333 "Retrying LLM request after transient failure"
334 );
335
336 let _ = self
337 .message_tx
338 .send(AgentMessage::Retrying {
339 attempt: state.retry_attempt,
340 max_attempts: self.retry_config.max_attempts,
341 delay_ms,
342 error: error.to_string(),
343 })
344 .await;
345
346 self.active_requests.clear();
349 self.start_llm_stream(Some(delay));
350 }
351
352 fn is_busy(&self) -> bool {
353 self.streams.contains_key(LLM_STREAM_KEY) || !self.active_requests.is_empty()
354 }
355
356 fn abort_in_flight_work(&mut self) {
357 self.streams.remove(LLM_STREAM_KEY);
358 for stream_key in self.active_requests.keys().cloned().collect::<Vec<_>>() {
359 self.streams.remove(&stream_key);
360 }
361 self.active_requests.clear();
362 self.queued_user_messages.clear();
363 }
364
365 fn inject_continuation_prompt(&mut self, previous_response: &str, stop_reason: Option<&StopReason>) {
367 if !previous_response.is_empty() {
368 self.context.add_message(ChatMessage::Assistant {
369 content: previous_response.to_string(),
370 reasoning: AssistantReasoning::default(),
371 timestamp: IsoString::now(),
372 tool_calls: Vec::new(),
373 });
374 }
375
376 let reason = stop_reason.map_or_else(|| "Unknown".to_string(), |reason| format!("{reason:?}"));
377
378 self.context.add_message(ChatMessage::User {
379 content: vec![llm::ContentBlock::text(format!(
380 "<system-notification>The LLM API stopped with reason '{reason}'. Continue from where you left off and finish your task.</system-notification>"
381 ))],
382 timestamp: IsoString::now(),
383 });
384 }
385
386 async fn on_llm_event(&mut self, result: Result<LlmResponse, LlmError>, state: &mut IterationState) {
387 use LlmResponse::{
388 Done, EncryptedReasoning, Error, Reasoning, Start, Text, ToolRequestArg, ToolRequestComplete,
389 ToolRequestStart, Usage,
390 };
391
392 let response = match result {
393 Ok(response) => response,
394 Err(e) => {
395 self.on_llm_error(e, state).await;
396 return;
397 }
398 };
399
400 match response {
401 Start { message_id } => {
402 state.on_llm_start(message_id);
403 }
404
405 Text { chunk } => {
406 self.handle_llm_text(chunk, state).await;
407 }
408
409 Reasoning { chunk } => {
410 state.reasoning_summary_text.push_str(&chunk);
411 if let Some(id) = &state.current_message_id {
412 let _ = self
413 .message_tx
414 .send(AgentMessage::Thought {
415 message_id: id.clone(),
416 chunk,
417 is_complete: false,
418 model_name: self.llm.display_name(),
419 })
420 .await;
421 }
422 }
423
424 EncryptedReasoning { id, content } => {
425 if let Some(model) = self.llm.model() {
426 state.encrypted_reasoning = Some(EncryptedReasoningContent { id, model, content });
427 }
428 }
429
430 ToolRequestStart { id, name } => {
431 self.handle_tool_request_start(id, name).await;
432 }
433
434 ToolRequestArg { id, chunk } => {
435 self.handle_tool_request_arg(id, chunk).await;
436 }
437
438 ToolRequestComplete { tool_call } => {
439 self.handle_tool_completion(tool_call, state).await;
440 }
441
442 Done { stop_reason } => {
443 state.llm_done = true;
444 state.stop_reason = stop_reason;
445 }
446
447 Error { message } => {
448 let _ = self.message_tx.send(AgentMessage::Error { message }).await;
449 }
450
451 Usage { tokens: sample } => {
452 self.handle_llm_usage(sample).await;
453 }
454 }
455 }
456
457 async fn handle_llm_text(&mut self, chunk: String, state: &mut IterationState) {
458 state.message_content.push_str(&chunk);
459
460 if let Some(id) = &state.current_message_id {
461 let _ = self
462 .message_tx
463 .send(AgentMessage::Text {
464 message_id: id.clone(),
465 chunk,
466 is_complete: false,
467 model_name: self.llm.display_name(),
468 })
469 .await;
470 }
471 }
472
473 async fn handle_tool_request_start(&mut self, id: String, name: String) {
474 let request = ToolCallRequest { id: id.clone(), name, arguments: String::new() };
475 self.active_requests.insert(id, request.clone());
476
477 let _ = self.message_tx.send(AgentMessage::ToolCall { request, model_name: self.llm.display_name() }).await;
478 }
479
480 async fn handle_tool_request_arg(&mut self, id: String, chunk: String) {
481 let Some(request) = self.active_requests.get_mut(&id) else {
482 return;
483 };
484 request.arguments.push_str(&chunk);
485
486 let _ = self
487 .message_tx
488 .send(AgentMessage::ToolCallUpdate { tool_call_id: id, chunk, model_name: self.llm.display_name() })
489 .await;
490 }
491
492 async fn handle_tool_completion(&mut self, tool_call: ToolCallRequest, state: &mut IterationState) {
493 state.pending_tool_ids.insert(tool_call.id.clone());
494 debug_assert!(
495 self.active_requests.contains_key(&tool_call.id),
496 "tool call {} should already be in active_requests from handle_tool_request_start",
497 tool_call.id
498 );
499
500 let (tx, rx) = mpsc::channel(100);
501 let stream = ReceiverStream::new(rx).map(StreamEvent::ToolExecution);
502 let stream_key = tool_call.id.clone();
503 self.streams.insert(stream_key, Box::pin(stream));
504
505 if let Some(ref mcp_command_tx) = self.mcp_command_tx {
506 let mcp_future =
507 mcp_command_tx.send(McpCommand::ExecuteTool { request: tool_call, timeout: self.tool_timeout, tx });
508 if let Err(e) = mcp_future.await {
509 tracing::warn!("Failed to send tool request to MCP task: {:?}", e);
510 }
511 }
512 }
513
514 async fn handle_llm_usage(&mut self, sample: TokenUsage) {
515 self.token_tracker.record_usage(sample);
516 let ratio_pct = self.token_tracker.usage_ratio().map(|r| r * 100.0);
517 let remaining = self.token_tracker.tokens_remaining();
518 tracing::debug!(?sample, ?ratio_pct, ?remaining, "Token usage");
519
520 let _ = self.message_tx.send(self.context_usage_message()).await;
521
522 self.maybe_compact_context().await;
523 }
524
525 fn context_usage_message(&self) -> AgentMessage {
526 let last = self.token_tracker.last_usage();
527 AgentMessage::ContextUsageUpdate {
528 usage_ratio: self.token_tracker.usage_ratio(),
529 context_limit: self.token_tracker.context_limit(),
530 input_tokens: last.input_tokens,
531 output_tokens: last.output_tokens,
532 cache_read_tokens: last.cache_read_tokens,
533 cache_creation_tokens: last.cache_creation_tokens,
534 reasoning_tokens: last.reasoning_tokens,
535 total_input_tokens: self.token_tracker.total_input_tokens(),
536 total_output_tokens: self.token_tracker.total_output_tokens(),
537 total_cache_read_tokens: self.token_tracker.total_cache_read_tokens(),
538 total_cache_creation_tokens: self.token_tracker.total_cache_creation_tokens(),
539 total_reasoning_tokens: self.token_tracker.total_reasoning_tokens(),
540 }
541 }
542
543 async fn maybe_preflight_compact(&mut self) {
547 let Some(context_limit) = self.token_tracker.context_limit() else {
548 return;
549 };
550 let Some(config) = self.compaction_config.as_ref() else {
551 return;
552 };
553 let estimated = self.context.estimated_token_count();
554 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
555 let threshold = (f64::from(context_limit) * config.threshold).ceil() as u32;
556 if estimated >= threshold {
557 tracing::info!(
558 "Pre-flight compaction triggered: estimated {estimated} tokens >= {:.1}% of {context_limit} limit",
559 config.threshold * 100.0
560 );
561 if let CompactionOutcome::Failed(e) = self.compact_context().await {
562 tracing::warn!("Pre-flight compaction failed: {e}");
563 }
564 }
565 }
566
567 async fn maybe_compact_context(&mut self) {
569 if !self.compaction_config.as_ref().is_some_and(|config| self.token_tracker.should_compact(config.threshold)) {
570 return;
571 }
572
573 if let CompactionOutcome::Failed(error_message) = self.compact_context().await {
574 tracing::warn!("Context compaction failed: {}", error_message);
575 }
576 }
577
578 async fn compact_context(&mut self) -> CompactionOutcome {
579 let Some(ref _config) = self.compaction_config else {
580 tracing::warn!("Context compaction requested but compaction is disabled");
581 return CompactionOutcome::SkippedDisabled;
582 };
583
584 match self.token_tracker.usage_ratio() {
585 Some(usage_ratio) => {
586 tracing::info!(
587 "Starting context compaction - {} messages, {:.1}% of context limit",
588 self.context.message_count(),
589 usage_ratio * 100.0
590 );
591 }
592 None => {
593 tracing::info!(
594 "Starting context compaction - {} messages (context limit unknown)",
595 self.context.message_count(),
596 );
597 }
598 }
599
600 let _ = self
601 .message_tx
602 .send(AgentMessage::ContextCompactionStarted { message_count: self.context.message_count() })
603 .await;
604
605 let compactor = Compactor::new(self.llm.clone());
606
607 match compactor.compact(&self.context).await {
608 Ok(result) => {
609 tracing::info!("Context compacted: {} messages removed", result.messages_removed);
610
611 self.context = result.context;
612 self.token_tracker.reset_current_usage();
613
614 let _ = self
615 .message_tx
616 .send(AgentMessage::ContextCompactionResult {
617 summary: result.summary,
618 messages_removed: result.messages_removed,
619 })
620 .await;
621 CompactionOutcome::Compacted
622 }
623 Err(e) => CompactionOutcome::Failed(e.to_string()),
624 }
625 }
626
627 async fn on_tool_execution_event(&mut self, event: ToolExecutionEvent, state: &mut IterationState) {
628 match event {
629 ToolExecutionEvent::Started { tool_id, tool_name } => {
630 tracing::debug!("Tool execution started: {} ({})", tool_name, tool_id);
631 }
632
633 ToolExecutionEvent::Progress { tool_id, progress } => {
634 tracing::debug!(
635 "Tool progress for {}: {}/{}",
636 tool_id,
637 progress.progress,
638 progress.total.unwrap_or(0.0)
639 );
640
641 if let Some(request) = self.active_requests.get(&tool_id) {
642 let _ = self
643 .message_tx
644 .send(AgentMessage::ToolProgress {
645 request: request.clone(),
646 progress: progress.progress,
647 total: progress.total,
648 message: progress.message.clone(),
649 })
650 .await;
651 }
652 }
653
654 ToolExecutionEvent::Complete { tool_id: _, result, result_meta } => match result {
655 Ok(tool_result) => {
656 tracing::debug!("Tool result received: {} -> {}", tool_result.name, tool_result.result.len());
657
658 if state.pending_tool_ids.remove(&tool_result.id) {
659 self.active_requests.remove(&tool_result.id);
660 state.completed_tool_calls.push(Ok(tool_result.clone()));
661
662 let msg = AgentMessage::ToolResult {
663 result: tool_result,
664 result_meta,
665 model_name: self.llm.display_name(),
666 };
667
668 if let Err(e) = self.message_tx.send(msg).await {
669 tracing::warn!("Failed to send ToolCall completion message: {:?}", e);
670 }
671 } else {
672 tracing::debug!("Ignoring stale tool result for id: {}", tool_result.id);
673 }
674 }
675
676 Err(tool_error) => {
677 if state.pending_tool_ids.remove(&tool_error.id) {
678 self.active_requests.remove(&tool_error.id);
679 state.completed_tool_calls.push(Err(tool_error.clone()));
680
681 let _ = self
682 .message_tx
683 .send(AgentMessage::ToolError { error: tool_error, model_name: self.llm.display_name() })
684 .await;
685 }
686 }
687 },
688 }
689 }
690
691 fn update_context(
692 &mut self,
693 message_content: &str,
694 reasoning: AssistantReasoning,
695 completed_tools: Vec<Result<ToolCallResult, ToolCallError>>,
696 ) {
697 self.context.push_assistant_turn(message_content, reasoning, completed_tools);
698 }
699}
700
/// Result of one attempt to compact the context.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CompactionOutcome {
    /// The context was replaced by its compacted form.
    Compacted,
    /// Nothing was done because compaction is not configured.
    SkippedDisabled,
    /// The compactor returned an error; the payload is its text.
    Failed(String),
}
707
708pub(crate) struct AutoContinue {
709 max: u32,
710 count: u32,
711}
712
713impl AutoContinue {
714 pub(crate) fn new(max: u32) -> Self {
715 Self { max, count: 0 }
716 }
717
718 fn reset(&mut self) {
719 self.count = 0;
720 }
721
722 fn should_continue(&self, stop_reason: Option<&StopReason>) -> bool {
723 matches!(stop_reason, Some(StopReason::Length)) && self.count < self.max
724 }
725
726 fn advance(&mut self) {
727 self.count += 1;
728 }
729
730 fn count(&self) -> u32 {
731 self.count
732 }
733
734 fn max(&self) -> u32 {
735 self.max
736 }
737}
738
739#[derive(Debug)]
740struct IterationState {
741 current_message_id: Option<String>,
742 message_content: String,
743 reasoning_summary_text: String,
744 encrypted_reasoning: Option<EncryptedReasoningContent>,
745 pending_tool_ids: HashSet<String>,
746 completed_tool_calls: Vec<Result<ToolCallResult, ToolCallError>>,
747 llm_done: bool,
748 stop_reason: Option<StopReason>,
749 retry_attempt: u32,
750}
751
752impl IterationState {
753 fn new() -> Self {
754 Self {
755 current_message_id: None,
756 message_content: String::new(),
757 reasoning_summary_text: String::new(),
758 encrypted_reasoning: None,
759 pending_tool_ids: HashSet::new(),
760 completed_tool_calls: Vec::new(),
761 llm_done: false,
762 stop_reason: None,
763 retry_attempt: 0,
764 }
765 }
766
767 fn on_llm_start(&mut self, message_id: String) {
768 self.current_message_id = Some(message_id);
769 self.message_content.clear();
770 self.reasoning_summary_text.clear();
771 self.encrypted_reasoning = None;
772 self.stop_reason = None;
773 }
774
775 fn is_complete(&self) -> bool {
776 self.llm_done && self.pending_tool_ids.is_empty()
777 }
778}
779
780#[cfg(test)]
781mod tests {
782 use crate::core::{AgentBuilder, Prompt};
783
784 use super::*;
785 use llm::{ContentBlock, testing::FakeLlmProvider};
786 use tokio::sync::mpsc;
787
788 #[tokio::test]
789 async fn replace_conversation_preserves_system_prompt_for_next_request() {
790 let llm = FakeLlmProvider::with_single_response(vec![LlmResponse::start("msg"), LlmResponse::done()]);
791
792 let captured_contexts = llm.captured_contexts();
793 let (tx, mut rx, handle) =
794 AgentBuilder::new(Arc::new(llm)).system_prompt(Prompt::text("original system")).spawn().await.unwrap();
795
796 tx.send(Command::AgentCommand(AgentCommand::ReplaceConversation(vec![
797 ChatMessage::User { content: vec![ContentBlock::text("old user")], timestamp: IsoString::now() },
798 ChatMessage::Assistant {
799 content: "old assistant".to_string(),
800 reasoning: AssistantReasoning::default(),
801 timestamp: IsoString::now(),
802 tool_calls: vec![],
803 },
804 ])))
805 .await
806 .unwrap();
807
808 tx.send(Command::UserCommand(UserCommand::Text { content: vec![ContentBlock::text("new user")] }))
809 .await
810 .unwrap();
811
812 while let Some(message) = rx.recv().await {
813 if matches!(message, AgentMessage::Done) {
814 break;
815 }
816 }
817
818 let contexts = captured_contexts.lock().unwrap();
819 let messages = contexts.last().expect("provider should receive a context").messages();
820 assert!(matches!(messages[0], ChatMessage::System { ref content, .. } if content == "original system"));
821 assert!(
822 matches!(messages[1], ChatMessage::User { ref content, .. } if content == &vec![llm::ContentBlock::text("old user")])
823 );
824 assert!(matches!(messages[2], ChatMessage::Assistant { ref content, .. } if content == "old assistant"));
825 assert!(
826 matches!(messages[3], ChatMessage::User { ref content, .. } if content == &vec![llm::ContentBlock::text("new user")])
827 );
828 handle.abort();
829 }
830
831 #[tokio::test]
832 async fn replace_conversation_preserves_token_usage() {
833 let llm = FakeLlmProvider::new(vec![vec![
834 LlmResponse::start("msg"),
835 LlmResponse::usage(800, 10),
836 LlmResponse::done(),
837 ]])
838 .with_context_window(Some(1000));
839 let (tx, mut rx, handle) = AgentBuilder::new(Arc::new(llm)).spawn().await.unwrap();
840
841 tx.send(Command::UserCommand(UserCommand::Text { content: vec![llm::ContentBlock::text("first user")] }))
842 .await
843 .unwrap();
844
845 while let Some(message) = rx.recv().await {
846 if matches!(message, AgentMessage::Done) {
847 break;
848 }
849 }
850
851 tx.send(Command::AgentCommand(AgentCommand::ReplaceConversation(vec![ChatMessage::User {
852 content: vec![ContentBlock::text("replacement user")],
853 timestamp: IsoString::now(),
854 }])))
855 .await
856 .unwrap();
857
858 let Some(AgentMessage::ContextUsageUpdate { input_tokens, usage_ratio, .. }) = rx.recv().await else {
859 panic!("expected context usage update after conversation replacement");
860 };
861
862 assert_eq!(input_tokens, 800);
863 assert_eq!(usage_ratio, Some(0.8));
864 handle.abort();
865 }
866
867 #[tokio::test]
868 async fn test_preflight_compaction_uses_configured_threshold() {
869 let llm = Arc::new(
870 FakeLlmProvider::with_single_response(vec![
871 LlmResponse::start("summary"),
872 LlmResponse::text("summary"),
873 LlmResponse::done(),
874 ])
875 .with_context_window(Some(100)),
876 );
877 let context = Context::new(
878 vec![ChatMessage::User {
879 content: vec![llm::ContentBlock::text("x".repeat(344))],
880 timestamp: IsoString::now(),
881 }],
882 vec![],
883 );
884 let (user_tx, user_rx) = mpsc::channel(1);
885 let (message_tx, _message_rx) = mpsc::channel(8);
886 drop(user_tx);
887
888 let mut agent = Agent::new(
889 AgentConfig {
890 llm,
891 context,
892 mcp_command_tx: None,
893 tool_timeout: Duration::from_secs(1),
894 compaction_config: Some(CompactionConfig::with_threshold(0.85)),
895 auto_continue: AutoContinue::new(0),
896 retry_config: RetryConfig::disabled(),
897 context_window: None,
898 prompt_cache: PromptCache::new(vec![]),
899 },
900 user_rx,
901 message_tx,
902 );
903
904 agent.maybe_preflight_compact().await;
905
906 assert!(
907 matches!(
908 agent.context.messages().as_slice(),
909 [ChatMessage::Summary { content, .. }] if content == "summary"
910 ),
911 "expected context to be compacted, got {:?}",
912 agent.context.messages()
913 );
914 }
915}