1use crate::context::{CompactionConfig, Compactor, TokenTracker};
2use crate::events::{AgentMessage, UserMessage};
3use crate::mcp::run_mcp_task::{McpCommand, ToolExecutionEvent};
4use futures::Stream;
5use llm::types::IsoString;
6use llm::{
7 AssistantReasoning, ChatMessage, Context, EncryptedReasoningContent, LlmError, LlmResponse, StopReason,
8 StreamingModelProvider, TokenUsage, ToolCallError, ToolCallRequest, ToolCallResult,
9};
10use std::collections::{HashMap, HashSet};
11use std::pin::Pin;
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::mpsc;
15use tokio_stream::StreamExt;
16use tokio_stream::StreamMap;
17use tokio_stream::wrappers::ReceiverStream;
18
/// Internal event type multiplexed through the agent's `StreamMap`: every
/// input source (LLM stream, tool executions, user channel) is adapted into
/// this one enum so `run()` can poll them with a single `next().await`.
#[derive(Debug)]
enum StreamEvent {
    /// A streaming chunk (or error) from the LLM provider.
    Llm(Result<LlmResponse, LlmError>),
    /// Progress/result event from an in-flight MCP tool execution.
    ToolExecution(ToolExecutionEvent),
    /// A message or control command from the user channel.
    UserMessage(UserMessage),
}

/// Boxed, pinned `StreamEvent` stream — the uniform value type stored in the
/// agent's `StreamMap`.
type EventStream = Pin<Box<dyn Stream<Item = StreamEvent> + Send>>;
28
/// Construction-time configuration for [`Agent`], consumed by `Agent::new`.
pub(crate) struct AgentConfig {
    /// Streaming LLM backend.
    pub llm: Arc<dyn StreamingModelProvider>,
    /// Initial conversation context (messages and tool definitions).
    pub context: Context,
    /// Channel to the MCP task used to execute tool calls; `None` means tool
    /// requests are never dispatched.
    pub mcp_command_tx: Option<mpsc::Sender<McpCommand>>,
    /// Timeout applied to each individual tool execution.
    pub tool_timeout: Duration,
    /// Context-compaction settings; `None` disables compaction entirely.
    pub compaction_config: Option<CompactionConfig>,
    /// Policy for auto-continuing when the model stops on its length limit.
    pub auto_continue: AutoContinue,
}
37
/// Event-driven agent: owns the conversation context, drives the LLM stream,
/// dispatches tool calls to the MCP task, and reports progress to the caller
/// via `message_tx`.
pub struct Agent {
    /// Streaming LLM backend (swappable at runtime via `SwitchModel`).
    llm: Arc<dyn StreamingModelProvider>,
    /// Conversation context sent to the LLM on every turn.
    context: Context,
    /// Channel to the MCP task for tool execution; `None` disables tools.
    mcp_command_tx: Option<mpsc::Sender<McpCommand>>,
    /// Outbound channel for `AgentMessage` progress/result notifications.
    message_tx: mpsc::Sender<AgentMessage>,
    /// All input sources, keyed by name: "user", "llm", and one entry per
    /// in-flight tool call (keyed by tool-call id).
    streams: StreamMap<String, EventStream>,
    /// Timeout applied to each tool execution.
    tool_timeout: Duration,
    /// Tracks token usage against the current model's context window.
    token_tracker: TokenTracker,
    /// Compaction settings; `None` disables compaction.
    compaction_config: Option<CompactionConfig>,
    /// Auto-continue policy for length-limited stops.
    auto_continue: AutoContinue,
    /// Tool-call requests currently being streamed/executed, keyed by id.
    active_requests: HashMap<String, ToolCallRequest>,
}
50
impl Agent {
    /// Builds an `Agent` from its configuration, registering the user-message
    /// receiver in the stream map under the "user" key so `run()` can poll it
    /// alongside LLM and tool streams.
    pub(crate) fn new(
        config: AgentConfig,
        user_message_rx: mpsc::Receiver<UserMessage>,
        message_tx: mpsc::Sender<AgentMessage>,
    ) -> Self {
        let mut streams: StreamMap<String, EventStream> = StreamMap::new();
        streams
            .insert("user".to_string(), Box::pin(ReceiverStream::new(user_message_rx).map(StreamEvent::UserMessage)));

        // Read the context window before `config.llm` is moved into `Self`.
        let context_limit = config.llm.context_window();

        Self {
            llm: config.llm,
            context: config.context,
            mcp_command_tx: config.mcp_command_tx,
            message_tx,
            streams,
            tool_timeout: config.tool_timeout,
            token_tracker: TokenTracker::new(context_limit),
            compaction_config: config.compaction_config,
            auto_continue: config.auto_continue,
            active_requests: HashMap::new(),
        }
    }

    /// Display name of the currently active model.
    pub fn current_model_display_name(&self) -> String {
        self.llm.display_name()
    }

    /// Read-only access to the token tracker.
    pub fn token_tracker(&self) -> &TokenTracker {
        &self.token_tracker
    }

    /// Main event loop. Polls all registered streams until the user channel
    /// closes (which drains the `StreamMap` and ends the loop). Each event is
    /// folded into the per-turn `IterationState`; when a turn is complete
    /// (LLM done and no pending tools), the state is swapped out and handed
    /// to `on_iteration_complete`.
    pub async fn run(mut self) {
        let mut state = IterationState::new();

        while let Some((_, event)) = self.streams.next().await {
            use UserMessage::{Cancel, ClearContext, SetReasoningEffort, SwitchModel, Text, UpdateTools};
            match event {
                StreamEvent::UserMessage(Cancel) => {
                    self.on_user_cancel(&mut state).await;
                }

                StreamEvent::UserMessage(ClearContext) => {
                    self.on_user_clear_context(&mut state).await;
                }

                StreamEvent::UserMessage(Text { content }) => {
                    // A new user message starts a fresh turn, discarding any
                    // partial/cancelled state from the previous one.
                    state = IterationState::new();
                    self.on_user_text(content);
                }

                StreamEvent::UserMessage(SwitchModel(new_provider)) => {
                    self.on_switch_model(new_provider).await;
                }

                StreamEvent::UserMessage(UpdateTools(tools)) => {
                    self.context.set_tools(tools);
                }

                StreamEvent::UserMessage(SetReasoningEffort(effort)) => {
                    self.context.set_reasoning_effort(effort);
                }

                StreamEvent::Llm(llm_event) => {
                    // After a cancel, late LLM chunks are dropped silently.
                    if !state.cancelled {
                        self.on_llm_event(llm_event, &mut state).await;
                    }
                }

                StreamEvent::ToolExecution(tool_event) => {
                    // Likewise, tool events from a cancelled turn are ignored.
                    if !state.cancelled {
                        self.on_tool_execution_event(tool_event, &mut state).await;
                    }
                }
            }

            if state.is_complete() {
                // Without a message id there is nothing to report; keep the
                // remaining state and wait for the next event.
                let Some(id) = state.current_message_id.take() else {
                    continue;
                };
                // Swap in a fresh state so the next turn starts clean while we
                // finish processing this one.
                let iteration = std::mem::replace(&mut state, IterationState::new());
                self.on_iteration_complete(id, iteration).await;
            }
        }

        tracing::debug!("Agent task shutting down - input channel closed");
    }

    /// Finalizes a completed turn: persists the assistant output into the
    /// context, emits the final Text/Thought messages, and decides whether to
    /// loop again (tool results to feed back, or auto-continue after a
    /// length-limited stop) or to signal `Done`.
    async fn on_iteration_complete(&mut self, id: String, iteration: IterationState) {
        let IterationState {
            message_content,
            reasoning_summary_text,
            encrypted_reasoning,
            completed_tool_calls,
            stop_reason,
            ..
        } = iteration;
        let has_tool_calls = !completed_tool_calls.is_empty();
        let has_content = !message_content.is_empty() || has_tool_calls;

        // Nothing was produced and no auto-continue applies: end the turn.
        if !has_content && !self.auto_continue.should_continue(stop_reason.as_ref()) {
            let _ = self.message_tx.send(AgentMessage::Done).await;
            return;
        }

        // Record the assistant turn (content + reasoning + tool results) in
        // the context before notifying the caller.
        let reasoning = AssistantReasoning::from_parts(reasoning_summary_text.clone(), encrypted_reasoning);
        self.update_context(&message_content, reasoning, completed_tool_calls);

        let _ = self
            .message_tx
            .send(AgentMessage::Text {
                message_id: id.clone(),
                chunk: message_content.clone(),
                is_complete: true,
                model_name: self.llm.display_name(),
            })
            .await;

        if !reasoning_summary_text.is_empty() {
            let _ = self
                .message_tx
                .send(AgentMessage::Thought {
                    message_id: id.clone(),
                    chunk: reasoning_summary_text,
                    is_complete: true,
                    model_name: self.llm.display_name(),
                })
                .await;
        }

        if has_tool_calls {
            // Tool results must be fed back to the model; tool activity also
            // resets the auto-continue counter.
            self.auto_continue.on_tool_calls();
            self.maybe_preflight_compact().await;
            self.start_llm_stream();
        } else if self.auto_continue.should_continue(stop_reason.as_ref()) {
            self.auto_continue.advance();
            tracing::info!(
                "LLM stopped with {:?}, auto-continuing (attempt {}/{})",
                stop_reason,
                self.auto_continue.count(),
                self.auto_continue.max()
            );

            let _ = self
                .message_tx
                .send(AgentMessage::AutoContinue {
                    attempt: self.auto_continue.count(),
                    max_attempts: self.auto_continue.max(),
                })
                .await;

            self.inject_continuation_prompt(&message_content, stop_reason.as_ref());
            self.maybe_preflight_compact().await;
            self.start_llm_stream();
        } else {
            tracing::debug!("LLM completed turn with stop reason: {:?}", stop_reason);
            self.auto_continue.on_completion();
            if let Err(e) = self.message_tx.send(AgentMessage::Done).await {
                tracing::warn!("Failed to send Done message: {:?}", e);
            }
        }
    }

    /// Handles a user cancel: marks the turn cancelled (so late events are
    /// dropped), tears down the LLM stream, and notifies the caller.
    // NOTE(review): tool streams and `active_requests` are intentionally left
    // in place here — late tool events are filtered by `state.cancelled` and
    // the `pending_tool_ids` staleness check — but confirm this is deliberate
    // versus calling `clear_active_streams()` as `on_user_clear_context` does.
    async fn on_user_cancel(&mut self, state: &mut IterationState) {
        state.cancelled = true;
        self.streams.remove("llm");
        let _ = self.message_tx.send(AgentMessage::Cancelled { message: "Processing cancelled".to_string() }).await;
        let _ = self.message_tx.send(AgentMessage::Done).await;
    }

    /// Handles a context-clear request: drops all active streams and tool
    /// requests, wipes the conversation, resets usage/auto-continue counters,
    /// and starts a fresh iteration state.
    async fn on_user_clear_context(&mut self, state: &mut IterationState) {
        self.clear_active_streams();
        self.active_requests.clear();
        self.context.clear_conversation();
        self.token_tracker.reset_current_usage();
        self.auto_continue.on_completion();
        *state = IterationState::new();

        let _ = self.message_tx.send(AgentMessage::ContextCleared).await;
    }

    /// Appends the user's message to the context and kicks off an LLM stream.
    fn on_user_text(&mut self, content: Vec<llm::ContentBlock>) {
        self.context.add_message(ChatMessage::User { content, timestamp: IsoString::now() });

        self.start_llm_stream();
    }

    /// Swaps the LLM provider, resets current usage against the new model's
    /// context window, and reports the switch plus fresh usage numbers.
    async fn on_switch_model(&mut self, new_provider: Box<dyn StreamingModelProvider>) {
        let previous = self.llm.display_name();
        let new_context_limit = new_provider.context_window();
        self.llm = Arc::from(new_provider);
        self.token_tracker.reset_current_usage();
        self.token_tracker.set_context_limit(new_context_limit);
        let new = self.llm.display_name();
        let _ = self.message_tx.send(AgentMessage::ModelSwitched { previous, new }).await;

        let _ = self.message_tx.send(self.context_usage_message()).await;
    }

    /// (Re)starts the LLM stream for the current context under the "llm" key,
    /// replacing any previous stream.
    fn start_llm_stream(&mut self) {
        self.streams.remove("llm");
        let llm_stream = self.llm.stream_response(&self.context).map(StreamEvent::Llm);
        self.streams.insert("llm".to_string(), Box::pin(llm_stream));
    }

    /// Removes the LLM stream and every per-tool-call stream (tool streams
    /// are keyed by the ids currently in `active_requests`).
    fn clear_active_streams(&mut self) {
        self.streams.remove("llm");
        for stream_key in self.active_requests.keys().cloned().collect::<Vec<_>>() {
            self.streams.remove(&stream_key);
        }
    }

    /// For auto-continue: records the truncated assistant response (if any)
    /// in the context, then injects a synthetic user message telling the
    /// model to resume from where it stopped.
    fn inject_continuation_prompt(&mut self, previous_response: &str, stop_reason: Option<&StopReason>) {
        if !previous_response.is_empty() {
            self.context.add_message(ChatMessage::Assistant {
                content: previous_response.to_string(),
                reasoning: AssistantReasoning::default(),
                timestamp: IsoString::now(),
                tool_calls: Vec::new(),
            });
        }

        let reason = stop_reason.map_or_else(|| "Unknown".to_string(), |reason| format!("{reason:?}"));

        self.context.add_message(ChatMessage::User {
            content: vec![llm::ContentBlock::text(format!(
                "<system-notification>The LLM API stopped with reason '{reason}'. Continue from where you left off and finish your task.</system-notification>"
            ))],
            timestamp: IsoString::now(),
        });
    }

    /// Dispatches one event from the LLM stream into state mutation and/or
    /// outbound messages. Stream errors are reported but do not abort the
    /// event loop.
    async fn on_llm_event(&mut self, result: Result<LlmResponse, LlmError>, state: &mut IterationState) {
        use LlmResponse::{
            Done, EncryptedReasoning, Error, Reasoning, Start, Text, ToolRequestArg, ToolRequestComplete,
            ToolRequestStart, Usage,
        };

        let response = match result {
            Ok(response) => response,
            Err(e) => {
                let _ = self.message_tx.send(AgentMessage::Error { message: e.to_string() }).await;
                return;
            }
        };

        match response {
            Start { message_id } => {
                state.on_llm_start(message_id);
            }

            Text { chunk } => {
                self.handle_llm_text(chunk, state).await;
            }

            Reasoning { chunk } => {
                // Accumulate the full summary for the final Thought message,
                // while also forwarding the incremental chunk.
                state.reasoning_summary_text.push_str(&chunk);
                if let Some(id) = &state.current_message_id {
                    let _ = self
                        .message_tx
                        .send(AgentMessage::Thought {
                            message_id: id.clone(),
                            chunk,
                            is_complete: false,
                            model_name: self.llm.display_name(),
                        })
                        .await;
                }
            }

            EncryptedReasoning { id, content } => {
                // Only retained when the provider exposes a model identifier,
                // since the stored content is tied to that model.
                if let Some(model) = self.llm.model() {
                    state.encrypted_reasoning = Some(EncryptedReasoningContent { id, model, content });
                }
            }

            ToolRequestStart { id, name } => {
                self.handle_tool_request_start(id, name).await;
            }

            ToolRequestArg { id, chunk } => {
                self.handle_tool_request_arg(id, chunk).await;
            }

            ToolRequestComplete { tool_call } => {
                self.handle_tool_completion(tool_call, state).await;
            }

            Done { stop_reason } => {
                state.llm_done = true;
                state.stop_reason = stop_reason;
            }

            Error { message } => {
                let _ = self.message_tx.send(AgentMessage::Error { message }).await;
            }

            Usage { tokens: sample } => {
                self.handle_llm_usage(sample).await;
            }
        }
    }

    /// Accumulates a text chunk into the turn and forwards it as a partial
    /// Text message (once a message id exists).
    async fn handle_llm_text(&mut self, chunk: String, state: &mut IterationState) {
        state.message_content.push_str(&chunk);

        if let Some(id) = &state.current_message_id {
            let _ = self
                .message_tx
                .send(AgentMessage::Text {
                    message_id: id.clone(),
                    chunk,
                    is_complete: false,
                    model_name: self.llm.display_name(),
                })
                .await;
        }
    }

    /// Registers a new tool-call request (arguments arrive incrementally via
    /// `handle_tool_request_arg`) and announces it to the caller.
    async fn handle_tool_request_start(&mut self, id: String, name: String) {
        let request = ToolCallRequest { id: id.clone(), name, arguments: String::new() };
        self.active_requests.insert(id, request.clone());

        let _ = self.message_tx.send(AgentMessage::ToolCall { request, model_name: self.llm.display_name() }).await;
    }

    /// Appends an argument chunk to the matching active request and forwards
    /// the incremental update. Unknown ids are ignored.
    async fn handle_tool_request_arg(&mut self, id: String, chunk: String) {
        let Some(request) = self.active_requests.get_mut(&id) else {
            return;
        };
        request.arguments.push_str(&chunk);

        let _ = self
            .message_tx
            .send(AgentMessage::ToolCallUpdate { tool_call_id: id, chunk, model_name: self.llm.display_name() })
            .await;
    }

    /// A tool call's arguments are complete: mark it pending, register a
    /// result stream keyed by the tool-call id, and dispatch execution to the
    /// MCP task (if one is configured).
    async fn handle_tool_completion(&mut self, tool_call: ToolCallRequest, state: &mut IterationState) {
        state.pending_tool_ids.insert(tool_call.id.clone());
        debug_assert!(
            self.active_requests.contains_key(&tool_call.id),
            "tool call {} should already be in active_requests from handle_tool_request_start",
            tool_call.id
        );

        let (tx, rx) = mpsc::channel(100);
        let stream = ReceiverStream::new(rx).map(StreamEvent::ToolExecution);
        let stream_key = tool_call.id.clone();
        self.streams.insert(stream_key, Box::pin(stream));

        if let Some(ref mcp_command_tx) = self.mcp_command_tx {
            let mcp_future =
                mcp_command_tx.send(McpCommand::ExecuteTool { request: tool_call, timeout: self.tool_timeout, tx });
            if let Err(e) = mcp_future.await {
                tracing::warn!("Failed to send tool request to MCP task: {:?}", e);
            }
        }
    }

    /// Records a usage sample, publishes updated usage numbers, and triggers
    /// usage-based compaction if the configured threshold is crossed.
    async fn handle_llm_usage(&mut self, sample: TokenUsage) {
        self.token_tracker.record_usage(sample);
        let ratio_pct = self.token_tracker.usage_ratio().map(|r| r * 100.0);
        let remaining = self.token_tracker.tokens_remaining();
        tracing::debug!(?sample, ?ratio_pct, ?remaining, "Token usage");

        let _ = self.message_tx.send(self.context_usage_message()).await;

        self.maybe_compact_context().await;
    }

    /// Builds a `ContextUsageUpdate` message from the tracker's latest sample
    /// and cumulative totals.
    fn context_usage_message(&self) -> AgentMessage {
        let last = self.token_tracker.last_usage();
        AgentMessage::ContextUsageUpdate {
            usage_ratio: self.token_tracker.usage_ratio(),
            context_limit: self.token_tracker.context_limit(),
            input_tokens: last.input_tokens,
            output_tokens: last.output_tokens,
            cache_read_tokens: last.cache_read_tokens,
            cache_creation_tokens: last.cache_creation_tokens,
            reasoning_tokens: last.reasoning_tokens,
            total_input_tokens: self.token_tracker.total_input_tokens(),
            total_output_tokens: self.token_tracker.total_output_tokens(),
            total_cache_read_tokens: self.token_tracker.total_cache_read_tokens(),
            total_cache_creation_tokens: self.token_tracker.total_cache_creation_tokens(),
            total_reasoning_tokens: self.token_tracker.total_reasoning_tokens(),
        }
    }

    /// Before starting another LLM stream: compacts if the context's
    /// *estimated* token count already exceeds the configured fraction of the
    /// context limit. No-op when the limit or compaction config is absent.
    async fn maybe_preflight_compact(&mut self) {
        let Some(context_limit) = self.token_tracker.context_limit() else {
            return;
        };
        let Some(config) = self.compaction_config.as_ref() else {
            return;
        };
        let estimated = self.context.estimated_token_count();
        // Truncation/sign-loss are acceptable: the product is a token count
        // derived from a u32 limit and a fractional threshold.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let threshold = (f64::from(context_limit) * config.threshold).ceil() as u32;
        if estimated >= threshold {
            tracing::info!(
                "Pre-flight compaction triggered: estimated {estimated} tokens >= {:.1}% of {context_limit} limit",
                config.threshold * 100.0
            );
            if let CompactionOutcome::Failed(e) = self.compact_context().await {
                tracing::warn!("Pre-flight compaction failed: {e}");
            }
        }
    }

    /// Compacts based on *reported* usage (tracker's ratio vs. the configured
    /// threshold), as opposed to the pre-flight estimate above.
    async fn maybe_compact_context(&mut self) {
        if !self.compaction_config.as_ref().is_some_and(|config| self.token_tracker.should_compact(config.threshold)) {
            return;
        }

        if let CompactionOutcome::Failed(error_message) = self.compact_context().await {
            tracing::warn!("Context compaction failed: {}", error_message);
        }
    }

    /// Runs the compactor over the current context. On success, replaces the
    /// context with the compacted one, resets current usage, and reports the
    /// summary; failures are returned for the caller to log.
    async fn compact_context(&mut self) -> CompactionOutcome {
        let Some(ref _config) = self.compaction_config else {
            tracing::warn!("Context compaction requested but compaction is disabled");
            return CompactionOutcome::SkippedDisabled;
        };

        match self.token_tracker.usage_ratio() {
            Some(usage_ratio) => {
                tracing::info!(
                    "Starting context compaction - {} messages, {:.1}% of context limit",
                    self.context.message_count(),
                    usage_ratio * 100.0
                );
            }
            None => {
                tracing::info!(
                    "Starting context compaction - {} messages (context limit unknown)",
                    self.context.message_count(),
                );
            }
        }

        let _ = self
            .message_tx
            .send(AgentMessage::ContextCompactionStarted { message_count: self.context.message_count() })
            .await;

        let compactor = Compactor::new(self.llm.clone());

        match compactor.compact(&self.context).await {
            Ok(result) => {
                tracing::info!("Context compacted: {} messages removed", result.messages_removed);

                self.context = result.context;
                self.token_tracker.reset_current_usage();

                let _ = self
                    .message_tx
                    .send(AgentMessage::ContextCompactionResult {
                        summary: result.summary,
                        messages_removed: result.messages_removed,
                    })
                    .await;
                CompactionOutcome::Compacted
            }
            Err(e) => CompactionOutcome::Failed(e.to_string()),
        }
    }

    /// Handles progress and completion events from a tool execution stream.
    /// Results whose id is no longer in `pending_tool_ids` (e.g. after a
    /// cleared context) are treated as stale and dropped.
    async fn on_tool_execution_event(&mut self, event: ToolExecutionEvent, state: &mut IterationState) {
        match event {
            ToolExecutionEvent::Started { tool_id, tool_name } => {
                tracing::debug!("Tool execution started: {} ({})", tool_name, tool_id);
            }

            ToolExecutionEvent::Progress { tool_id, progress } => {
                tracing::debug!(
                    "Tool progress for {}: {}/{}",
                    tool_id,
                    progress.progress,
                    progress.total.unwrap_or(0.0)
                );

                if let Some(request) = self.active_requests.get(&tool_id) {
                    let _ = self
                        .message_tx
                        .send(AgentMessage::ToolProgress {
                            request: request.clone(),
                            progress: progress.progress,
                            total: progress.total,
                            message: progress.message.clone(),
                        })
                        .await;
                }
            }

            ToolExecutionEvent::Complete { tool_id: _, result, result_meta } => match result {
                Ok(tool_result) => {
                    tracing::debug!("Tool result received: {} -> {}", tool_result.name, tool_result.result.len());

                    if state.pending_tool_ids.remove(&tool_result.id) {
                        self.active_requests.remove(&tool_result.id);
                        state.completed_tool_calls.push(Ok(tool_result.clone()));

                        let msg = AgentMessage::ToolResult {
                            result: tool_result,
                            result_meta,
                            model_name: self.llm.display_name(),
                        };

                        if let Err(e) = self.message_tx.send(msg).await {
                            tracing::warn!("Failed to send ToolCall completion message: {:?}", e);
                        }
                    } else {
                        tracing::debug!("Ignoring stale tool result for id: {}", tool_result.id);
                    }
                }

                Err(tool_error) => {
                    if state.pending_tool_ids.remove(&tool_error.id) {
                        self.active_requests.remove(&tool_error.id);
                        state.completed_tool_calls.push(Err(tool_error.clone()));

                        let _ = self
                            .message_tx
                            .send(AgentMessage::ToolError { error: tool_error, model_name: self.llm.display_name() })
                            .await;
                    }
                }
            },
        }
    }

    /// Appends the completed assistant turn (text, reasoning, tool results)
    /// to the conversation context.
    fn update_context(
        &mut self,
        message_content: &str,
        reasoning: AssistantReasoning,
        completed_tools: Vec<Result<ToolCallResult, ToolCallError>>,
    ) {
        self.context.push_assistant_turn(message_content, reasoning, completed_tools);
    }
}
602
/// Result of a `compact_context` attempt.
#[derive(Debug, Clone, PartialEq, Eq)]
enum CompactionOutcome {
    /// Compaction ran and the context was replaced with the compacted one.
    Compacted,
    /// Compaction was requested but no `CompactionConfig` is set.
    SkippedDisabled,
    /// The compactor returned an error; the message is for logging.
    Failed(String),
}
609
/// Counter governing automatic continuation after length-limited stops.
pub(crate) struct AutoContinue {
    /// Maximum number of consecutive auto-continue attempts.
    max: u32,
    /// Attempts used since the last reset (tool call or normal completion).
    count: u32,
}
614
615impl AutoContinue {
616 pub(crate) fn new(max: u32) -> Self {
617 Self { max, count: 0 }
618 }
619
620 fn on_tool_calls(&mut self) {
621 self.count = 0;
622 }
623
624 fn on_completion(&mut self) {
625 self.count = 0;
626 }
627
628 fn should_continue(&self, stop_reason: Option<&StopReason>) -> bool {
629 matches!(stop_reason, Some(StopReason::Length)) && self.count < self.max
630 }
631
632 fn advance(&mut self) {
633 self.count += 1;
634 }
635
636 fn count(&self) -> u32 {
637 self.count
638 }
639
640 fn max(&self) -> u32 {
641 self.max
642 }
643}
644
/// Mutable state accumulated over one agent turn (one LLM stream plus any
/// tool executions it requested).
#[derive(Debug)]
struct IterationState {
    /// Id from the LLM `Start` event; taken when the turn completes.
    current_message_id: Option<String>,
    /// Accumulated assistant text for this turn.
    message_content: String,
    /// Accumulated reasoning-summary text for this turn.
    reasoning_summary_text: String,
    /// Provider-encrypted reasoning payload, if one was emitted.
    encrypted_reasoning: Option<EncryptedReasoningContent>,
    /// Tool-call ids dispatched but not yet resolved.
    pending_tool_ids: HashSet<String>,
    /// Resolved tool calls (success or error) in arrival order.
    completed_tool_calls: Vec<Result<ToolCallResult, ToolCallError>>,
    /// Set when the LLM stream reports `Done`.
    llm_done: bool,
    /// Stop reason from the LLM `Done` event.
    stop_reason: Option<StopReason>,
    /// Set by user cancel; subsequent LLM/tool events are dropped.
    cancelled: bool,
}
657
658impl IterationState {
659 fn new() -> Self {
660 Self {
661 current_message_id: None,
662 message_content: String::new(),
663 reasoning_summary_text: String::new(),
664 encrypted_reasoning: None,
665 pending_tool_ids: HashSet::new(),
666 completed_tool_calls: Vec::new(),
667 llm_done: false,
668 stop_reason: None,
669 cancelled: false,
670 }
671 }
672
673 fn on_llm_start(&mut self, message_id: String) {
674 self.current_message_id = Some(message_id);
675 self.message_content.clear();
676 self.reasoning_summary_text.clear();
677 self.encrypted_reasoning = None;
678 self.stop_reason = None;
679 }
680
681 fn is_complete(&self) -> bool {
682 self.llm_done && self.pending_tool_ids.is_empty()
683 }
684}
685
#[cfg(test)]
mod tests {
    use super::*;
    use llm::testing::FakeLlmProvider;
    use tokio::sync::mpsc;

    /// Verifies that `maybe_preflight_compact` honors the configured
    /// threshold: with a 100-token window and a 0.85 threshold, a context
    /// whose estimated size exceeds 85 tokens must be replaced by the fake
    /// provider's "summary" message.
    #[tokio::test]
    async fn test_preflight_compaction_uses_configured_threshold() {
        // The fake provider acts as the compactor's summarizer, streaming a
        // single "summary" response; context window fixed at 100 tokens.
        let llm = Arc::new(
            FakeLlmProvider::with_single_response(vec![
                LlmResponse::start("summary"),
                LlmResponse::text("summary"),
                LlmResponse::done(),
            ])
            .with_context_window(Some(100)),
        );
        // 344 chars of padding — sized so the estimated token count lands
        // above the 85-token threshold.
        let context = Context::new(
            vec![ChatMessage::User {
                content: vec![llm::ContentBlock::text("x".repeat(344))],
                timestamp: IsoString::now(),
            }],
            vec![],
        );
        let (user_tx, user_rx) = mpsc::channel(1);
        let (message_tx, _message_rx) = mpsc::channel(8);
        // Close the user channel immediately; the agent loop is never run.
        drop(user_tx);

        let mut agent = Agent::new(
            AgentConfig {
                llm,
                context,
                mcp_command_tx: None,
                tool_timeout: Duration::from_secs(1),
                compaction_config: Some(CompactionConfig::with_threshold(0.85)),
                auto_continue: AutoContinue::new(0),
            },
            user_rx,
            message_tx,
        );

        agent.maybe_preflight_compact().await;

        // After compaction the entire conversation collapses to one Summary
        // message containing the fake provider's output.
        assert!(
            matches!(
                agent.context.messages().as_slice(),
                [ChatMessage::Summary { content, .. }] if content == "summary"
            ),
            "expected context to be compacted, got {:?}",
            agent.context.messages()
        );
    }
}