Skip to main content

agent_code_lib/query/
mod.rs

1//! Query engine: the core agent loop.
2//!
3//! Implements the agentic cycle:
4//!
5//! 1. Auto-compact if context nears the window limit
6//! 2. Microcompact stale tool results
7//! 3. Call LLM with streaming
8//! 4. Accumulate response content blocks
9//! 5. Handle errors (prompt-too-long, rate limits, max-output-tokens)
10//! 6. Extract tool_use blocks
11//! 7. Execute tools (concurrent/serial batching)
12//! 8. Inject tool results into history
13//! 9. Repeat from step 1 until no tool_use or max turns
14
15pub mod source;
16
17use std::path::PathBuf;
18use std::sync::Arc;
19
20use tokio_util::sync::CancellationToken;
21use tracing::{debug, info, warn};
22use uuid::Uuid;
23
24use crate::hooks::{HookEvent, HookRegistry};
25use crate::llm::message::*;
26use crate::llm::provider::{Provider, ProviderError, ProviderRequest};
27use crate::llm::stream::StreamEvent;
28use crate::permissions::PermissionChecker;
29use crate::services::compact::{self, CompactTracking, MAX_OUTPUT_TOKENS_RECOVERY_LIMIT};
30use crate::services::tokens;
31use crate::state::AppState;
32use crate::tools::ToolContext;
33use crate::tools::executor::{execute_tool_calls, extract_tool_calls};
34use crate::tools::registry::ToolRegistry;
35
/// Configuration for the query engine.
pub struct QueryEngineConfig {
    /// Maximum number of agent-loop turns per user input.
    /// `None` falls back to a default of 50 in the turn loop.
    pub max_turns: Option<usize>,
    /// Enables verbose tool output (forwarded to the tool execution context).
    pub verbose: bool,
    /// Whether this is a non-interactive (one-shot) session.
    pub unattended: bool,
}
43
/// The query engine orchestrates the agent loop.
///
/// Central coordinator that drives the LLM → tools → LLM cycle.
/// Manages conversation history, context compaction, tool execution,
/// error recovery, and hook dispatch. Create via [`QueryEngine::new`].
pub struct QueryEngine {
    /// LLM provider used for streaming completions and LLM-based compaction.
    llm: Arc<dyn Provider>,
    /// Registry of available tools; core schemas are sent with each request.
    tools: ToolRegistry,
    /// Shared file cache handed to tools via `ToolContext`.
    file_cache: Arc<tokio::sync::Mutex<crate::services::file_cache::FileCache>>,
    /// Permission checker shared with tool execution.
    permissions: Arc<PermissionChecker>,
    /// Conversation history, config, usage totals, and session flags.
    state: AppState,
    /// Engine-level settings (max turns, verbosity, unattended mode).
    config: QueryEngineConfig,
    /// Shared handle so the signal handler always cancels the current token.
    cancel_shared: Arc<std::sync::Mutex<CancellationToken>>,
    /// Per-turn cancellation token (cloned from cancel_shared at turn start).
    cancel: CancellationToken,
    /// Hooks registered from configuration (see `load_hooks`).
    hooks: HookRegistry,
    /// Records usage after each API call for prompt-cache tracking.
    cache_tracker: crate::services::cache_tracking::CacheTracker,
    /// Shared denial tracker (capacity 100) passed to tools via `ToolContext`.
    denial_tracker: Arc<tokio::sync::Mutex<crate::permissions::tracking::DenialTracker>>,
    /// State shared with the background memory-extraction task.
    extraction_state: Arc<tokio::sync::Mutex<crate::memory::extraction::ExtractionState>>,
    /// Session-scoped allow entries passed to tools.
    /// NOTE(review): presumably permission-rule keys — confirm against consumers.
    session_allows: Arc<tokio::sync::Mutex<std::collections::HashSet<String>>>,
    /// Optional interactive prompter for permission requests
    /// (`None` until one is installed; forwarded to `ToolContext`).
    permission_prompter: Option<Arc<dyn crate::tools::PermissionPrompter>>,
    /// Cached system prompt (rebuilt only when inputs change).
    cached_system_prompt: Option<(u64, String)>, // (hash, prompt)
}
69
/// Callback for streaming events to the UI.
///
/// Methods with `_`-prefixed parameters have no-op default bodies, so an
/// implementor only needs to handle text, tool lifecycle, and error events.
pub trait StreamSink: Send + Sync {
    /// Called for each streamed text delta from the model.
    fn on_text(&self, text: &str);
    /// Called when a tool_use block completes (just before the tool runs).
    fn on_tool_start(&self, tool_name: &str, input: &serde_json::Value);
    /// Called with the result of a finished tool call.
    fn on_tool_result(&self, tool_name: &str, result: &crate::tools::ToolResult);
    /// Called for each completed thinking block (default: no-op).
    fn on_thinking(&self, _text: &str) {}
    /// Called when a turn completes with no further tool calls (default: no-op).
    fn on_turn_complete(&self, _turn: usize) {}
    /// Called when a stream or provider error is surfaced to the user.
    fn on_error(&self, error: &str);
    /// Called with token usage after each API response (default: no-op).
    fn on_usage(&self, _usage: &Usage) {}
    /// Called after compaction with the approximate tokens freed (default: no-op).
    fn on_compact(&self, _freed_tokens: u64) {}
    /// Called for non-fatal warnings, e.g. budget or context pressure (default: no-op).
    fn on_warning(&self, _msg: &str) {}
}
82
83/// A no-op stream sink for non-interactive mode.
84pub struct NullSink;
85impl StreamSink for NullSink {
86    fn on_text(&self, _: &str) {}
87    fn on_tool_start(&self, _: &str, _: &serde_json::Value) {}
88    fn on_tool_result(&self, _: &str, _: &crate::tools::ToolResult) {}
89    fn on_error(&self, _: &str) {}
90}
91
92impl QueryEngine {
93    pub fn new(
94        llm: Arc<dyn Provider>,
95        tools: ToolRegistry,
96        permissions: PermissionChecker,
97        state: AppState,
98        config: QueryEngineConfig,
99    ) -> Self {
100        let cancel = CancellationToken::new();
101        let cancel_shared = Arc::new(std::sync::Mutex::new(cancel.clone()));
102        Self {
103            llm,
104            tools,
105            file_cache: Arc::new(tokio::sync::Mutex::new(
106                crate::services::file_cache::FileCache::new(),
107            )),
108            permissions: Arc::new(permissions),
109            state,
110            config,
111            cancel,
112            cancel_shared,
113            hooks: HookRegistry::new(),
114            cache_tracker: crate::services::cache_tracking::CacheTracker::new(),
115            denial_tracker: Arc::new(tokio::sync::Mutex::new(
116                crate::permissions::tracking::DenialTracker::new(100),
117            )),
118            extraction_state: Arc::new(tokio::sync::Mutex::new(
119                crate::memory::extraction::ExtractionState::new(),
120            )),
121            session_allows: Arc::new(tokio::sync::Mutex::new(std::collections::HashSet::new())),
122            permission_prompter: None,
123            cached_system_prompt: None,
124        }
125    }
126
127    /// Load hooks from configuration into the registry.
128    pub fn load_hooks(&mut self, hook_defs: &[crate::hooks::HookDefinition]) {
129        for def in hook_defs {
130            self.hooks.register(def.clone());
131        }
132        if !hook_defs.is_empty() {
133            tracing::info!("Loaded {} hooks from config", hook_defs.len());
134        }
135    }
136
    /// Get a reference to the app state.
    ///
    /// Read-only access to conversation history, configuration, and
    /// usage totals accumulated by the engine.
    pub fn state(&self) -> &AppState {
        &self.state
    }
141
    /// Get a mutable reference to the app state.
    ///
    /// Allows callers to adjust configuration or history between turns;
    /// the engine itself mutates state through this field during a turn.
    pub fn state_mut(&mut self) -> &mut AppState {
        &mut self.state
    }
146
    /// Install a Ctrl+C handler that triggers the cancellation token.
    /// Call this once at startup. Subsequent Ctrl+C signals during a
    /// turn will cancel the active operation instead of killing the process.
    pub fn install_signal_handler(&self) {
        // Clone the shared handle so the spawned task always reaches the
        // token of the *current* turn (a fresh token is swapped into
        // `cancel_shared` at the start of every turn).
        let shared = self.cancel_shared.clone();
        tokio::spawn(async move {
            // Tracks that we have cancelled at least once.
            // NOTE(review): `pending` is never reset to false; this relies on
            // each turn installing a fresh (uncancelled) token so the hard-exit
            // branch only fires on a repeat Ctrl+C against an already-cancelled
            // token — confirm this lifetime is intentional.
            let mut pending = false;
            loop {
                // Wait for the next Ctrl+C; on failure, just keep looping.
                if tokio::signal::ctrl_c().await.is_ok() {
                    // Snapshot the active token. The std mutex guard is dropped
                    // at the end of this expression — never held across an await.
                    let token = shared.lock().unwrap().clone();
                    if token.is_cancelled() && pending {
                        // Second Ctrl+C after cancel — hard exit.
                        // 130 is the conventional SIGINT exit code (128 + 2).
                        std::process::exit(130);
                    }
                    token.cancel();
                    pending = true;
                }
            }
        });
    }
167
    /// Run a single turn: process user input through the full agent loop.
    ///
    /// Convenience wrapper around [`Self::run_turn_with_sink`] using
    /// [`NullSink`], so no streaming events are surfaced (non-interactive use).
    pub async fn run_turn(&mut self, user_input: &str) -> crate::error::Result<()> {
        self.run_turn_with_sink(user_input, &NullSink).await
    }
172
173    /// Run a turn with a stream sink for real-time UI updates.
174    pub async fn run_turn_with_sink(
175        &mut self,
176        user_input: &str,
177        sink: &dyn StreamSink,
178    ) -> crate::error::Result<()> {
179        // Reset cancellation token for this turn. The shared handle is
180        // updated so the signal handler always cancels the current token.
181        self.cancel = CancellationToken::new();
182        *self.cancel_shared.lock().unwrap() = self.cancel.clone();
183
184        // Add the user message to history.
185        let user_msg = user_message(user_input);
186        self.state.push_message(user_msg);
187
188        let max_turns = self.config.max_turns.unwrap_or(50);
189        let mut compact_tracking = CompactTracking::default();
190        let mut retry_state = crate::llm::retry::RetryState::default();
191        let retry_config = crate::llm::retry::RetryConfig::default();
192        let mut max_output_recovery_count = 0u32;
193
194        // Agent loop: budget check → normalize → compact → call LLM → execute tools → repeat.
195        for turn in 0..max_turns {
196            self.state.turn_count = turn + 1;
197            self.state.is_query_active = true;
198
199            // Budget check before each turn.
200            let budget_config = crate::services::budget::BudgetConfig::default();
201            match crate::services::budget::check_budget(
202                self.state.total_cost_usd,
203                self.state.total_usage.total(),
204                &budget_config,
205            ) {
206                crate::services::budget::BudgetDecision::Stop { message } => {
207                    sink.on_warning(&message);
208                    self.state.is_query_active = false;
209                    return Ok(());
210                }
211                crate::services::budget::BudgetDecision::ContinueWithWarning {
212                    message, ..
213                } => {
214                    sink.on_warning(&message);
215                }
216                crate::services::budget::BudgetDecision::Continue => {}
217            }
218
219            // Normalize messages for API compatibility.
220            crate::llm::normalize::ensure_tool_result_pairing(&mut self.state.messages);
221            crate::llm::normalize::strip_empty_blocks(&mut self.state.messages);
222            crate::llm::normalize::remove_empty_messages(&mut self.state.messages);
223            crate::llm::normalize::cap_document_blocks(&mut self.state.messages, 500_000);
224            crate::llm::normalize::merge_consecutive_user_messages(&mut self.state.messages);
225
226            debug!("Agent turn {}/{}", turn + 1, max_turns);
227
228            let mut model = self.state.config.api.model.clone();
229
230            // Step 1: Auto-compact if context is too large.
231            if compact::should_auto_compact(self.state.history(), &model, &compact_tracking) {
232                let token_count = tokens::estimate_context_tokens(self.state.history());
233                let threshold = compact::auto_compact_threshold(&model);
234                info!("Auto-compact triggered: {token_count} tokens >= {threshold} threshold");
235
236                // Microcompact first: clear stale tool results.
237                let freed = compact::microcompact(&mut self.state.messages, 5);
238                if freed > 0 {
239                    sink.on_compact(freed);
240                    info!("Microcompact freed ~{freed} tokens");
241                }
242
243                // Check if microcompact was enough.
244                let post_mc_tokens = tokens::estimate_context_tokens(self.state.history());
245                if post_mc_tokens >= threshold {
246                    // Full LLM-based compaction: summarize older messages.
247                    info!("Microcompact insufficient, attempting LLM compaction");
248                    match compact::compact_with_llm(&mut self.state.messages, &*self.llm, &model)
249                        .await
250                    {
251                        Some(removed) => {
252                            info!("LLM compaction removed {removed} messages");
253                            compact_tracking.was_compacted = true;
254                            compact_tracking.consecutive_failures = 0;
255                        }
256                        None => {
257                            compact_tracking.consecutive_failures += 1;
258                            warn!(
259                                "LLM compaction failed (attempt {})",
260                                compact_tracking.consecutive_failures
261                            );
262                            // Fallback: context collapse (snip middle messages).
263                            let effective = compact::effective_context_window(&model);
264                            if let Some(collapse) =
265                                crate::services::context_collapse::collapse_to_budget(
266                                    self.state.history(),
267                                    effective,
268                                )
269                            {
270                                info!(
271                                    "Context collapse snipped {} messages, freed ~{} tokens",
272                                    collapse.snipped_count, collapse.tokens_freed
273                                );
274                                self.state.messages = collapse.api_messages;
275                                sink.on_compact(collapse.tokens_freed);
276                            } else {
277                                // Last resort: aggressive microcompact.
278                                let freed2 = compact::microcompact(&mut self.state.messages, 2);
279                                if freed2 > 0 {
280                                    sink.on_compact(freed2);
281                                }
282                            }
283                        }
284                    }
285                }
286            }
287
288            // Inject compaction reminder if compacted and feature enabled.
289            if compact_tracking.was_compacted && self.state.config.features.compaction_reminders {
290                let reminder = user_message(
291                    "<system-reminder>Context was automatically compacted. \
292                     Earlier messages were summarized. If you need details from \
293                     before compaction, ask the user or re-read the relevant files.</system-reminder>",
294                );
295                self.state.push_message(reminder);
296                compact_tracking.was_compacted = false; // Only remind once per compaction.
297            }
298
299            // Step 2: Check token warning state.
300            let warning = compact::token_warning_state(self.state.history(), &model);
301            if warning.is_blocking {
302                sink.on_warning("Context window nearly full. Consider starting a new session.");
303            } else if warning.is_above_warning {
304                sink.on_warning(&format!("Context {}% remaining", warning.percent_left));
305            }
306
307            // Step 3: Build and send the API request.
308            // Memoize: only rebuild system prompt when inputs change.
309            let prompt_hash = {
310                use std::hash::{Hash, Hasher};
311                let mut h = std::collections::hash_map::DefaultHasher::new();
312                self.state.config.api.model.hash(&mut h);
313                self.state.cwd.hash(&mut h);
314                self.state.config.mcp_servers.len().hash(&mut h);
315                self.tools.all().len().hash(&mut h);
316                h.finish()
317            };
318            let system_prompt = if let Some((cached_hash, ref cached)) = self.cached_system_prompt
319                && cached_hash == prompt_hash
320            {
321                cached.clone()
322            } else {
323                let prompt = build_system_prompt(&self.tools, &self.state);
324                self.cached_system_prompt = Some((prompt_hash, prompt.clone()));
325                prompt
326            };
327            // Use core schemas (deferred tools loaded on demand via ToolSearch).
328            let tool_schemas = self.tools.core_schemas();
329
330            // Escalate max_tokens after a max_output recovery (8k → 64k).
331            let base_tokens = self.state.config.api.max_output_tokens.unwrap_or(16384);
332            let effective_tokens = if max_output_recovery_count > 0 {
333                base_tokens.max(65536) // Escalate to at least 64k after first recovery
334            } else {
335                base_tokens
336            };
337
338            let request = ProviderRequest {
339                messages: self.state.history().to_vec(),
340                system_prompt: system_prompt.clone(),
341                tools: tool_schemas.clone(),
342                model: model.clone(),
343                max_tokens: effective_tokens,
344                temperature: None,
345                enable_caching: self.state.config.features.prompt_caching,
346                tool_choice: Default::default(),
347                metadata: None,
348            };
349
350            let mut rx = match self.llm.stream(&request).await {
351                Ok(rx) => {
352                    retry_state.reset();
353                    rx
354                }
355                Err(e) => {
356                    let retryable = match &e {
357                        ProviderError::RateLimited { retry_after_ms } => {
358                            crate::llm::retry::RetryableError::RateLimited {
359                                retry_after: *retry_after_ms,
360                            }
361                        }
362                        ProviderError::Overloaded => crate::llm::retry::RetryableError::Overloaded,
363                        ProviderError::Network(_) => {
364                            crate::llm::retry::RetryableError::StreamInterrupted
365                        }
366                        other => crate::llm::retry::RetryableError::NonRetryable(other.to_string()),
367                    };
368
369                    match retry_state.next_action(&retryable, &retry_config) {
370                        crate::llm::retry::RetryAction::Retry { after } => {
371                            warn!("Retrying in {}ms", after.as_millis());
372                            tokio::time::sleep(after).await;
373                            continue;
374                        }
375                        crate::llm::retry::RetryAction::FallbackModel => {
376                            // Switch to a smaller/cheaper model for this turn.
377                            let fallback = get_fallback_model(&model);
378                            sink.on_warning(&format!("Falling back from {model} to {fallback}"));
379                            model = fallback;
380                            continue;
381                        }
382                        crate::llm::retry::RetryAction::Abort(reason) => {
383                            // Unattended retry: in non-interactive mode, retry
384                            // capacity errors with longer backoff instead of aborting.
385                            if self.config.unattended
386                                && self.state.config.features.unattended_retry
387                                && matches!(
388                                    &e,
389                                    ProviderError::Overloaded | ProviderError::RateLimited { .. }
390                                )
391                            {
392                                warn!("Unattended retry: waiting 30s for capacity");
393                                tokio::time::sleep(std::time::Duration::from_secs(30)).await;
394                                continue;
395                            }
396                            // Before giving up, try reactive compact for size errors.
397                            // Two-stage recovery: context collapse first, then microcompact.
398                            if let ProviderError::RequestTooLarge(body) = &e {
399                                let gap = compact::parse_prompt_too_long_gap(body);
400
401                                // Stage 1: Context collapse (snip middle messages).
402                                let effective = compact::effective_context_window(&model);
403                                if let Some(collapse) =
404                                    crate::services::context_collapse::collapse_to_budget(
405                                        self.state.history(),
406                                        effective,
407                                    )
408                                {
409                                    info!(
410                                        "Reactive collapse: snipped {} messages, freed ~{} tokens",
411                                        collapse.snipped_count, collapse.tokens_freed
412                                    );
413                                    self.state.messages = collapse.api_messages;
414                                    sink.on_compact(collapse.tokens_freed);
415                                    continue;
416                                }
417
418                                // Stage 2: Aggressive microcompact.
419                                let freed = compact::microcompact(&mut self.state.messages, 1);
420                                if freed > 0 {
421                                    sink.on_compact(freed);
422                                    info!(
423                                        "Reactive microcompact freed ~{freed} tokens (gap: {gap:?})"
424                                    );
425                                    continue;
426                                }
427                            }
428                            sink.on_error(&reason);
429                            self.state.is_query_active = false;
430                            return Err(crate::error::Error::Other(e.to_string()));
431                        }
432                    }
433                }
434            };
435
436            // Step 4: Stream response. Start executing read-only tools
437            // as their input completes (streaming tool execution).
438            let mut content_blocks = Vec::new();
439            let mut usage = Usage::default();
440            let mut stop_reason: Option<StopReason> = None;
441            let mut got_error = false;
442            let mut error_text = String::new();
443
444            // Streaming tool handles: tools kicked off during streaming.
445            let mut streaming_tool_handles: Vec<(
446                String,
447                String,
448                tokio::task::JoinHandle<crate::tools::ToolResult>,
449            )> = Vec::new();
450
451            let mut cancelled = false;
452            loop {
453                tokio::select! {
454                    event = rx.recv() => {
455                        match event {
456                            Some(StreamEvent::TextDelta(text)) => {
457                                sink.on_text(&text);
458                            }
459                            Some(StreamEvent::ContentBlockComplete(block)) => {
460                                if let ContentBlock::ToolUse {
461                                    ref id,
462                                    ref name,
463                                    ref input,
464                                } = block
465                                {
466                                    sink.on_tool_start(name, input);
467
468                                    // Start read-only tools immediately during streaming.
469                                    if let Some(tool) = self.tools.get(name)
470                                        && tool.is_read_only()
471                                        && tool.is_concurrency_safe()
472                                    {
473                                        let tool = tool.clone();
474                                        let input = input.clone();
475                                        let cwd = std::path::PathBuf::from(&self.state.cwd);
476                                        let cancel = self.cancel.clone();
477                                        let perm = self.permissions.clone();
478                                        let tool_id = id.clone();
479                                        let tool_name = name.clone();
480
481                                        let handle = tokio::spawn(async move {
482                                            match tool
483                                                .call(
484                                                    input,
485                                                    &ToolContext {
486                                                        cwd,
487                                                        cancel,
488                                                        permission_checker: perm.clone(),
489                                                        verbose: false,
490                                                        plan_mode: false,
491                                                        file_cache: None,
492                                                        denial_tracker: None,
493                                                        task_manager: None,
494                                                        session_allows: None,
495                                                        permission_prompter: None,
496                                                        sandbox: None,
497                                                    },
498                                                )
499                                                .await
500                                            {
501                                                Ok(r) => r,
502                                                Err(e) => crate::tools::ToolResult::error(e.to_string()),
503                                            }
504                                        });
505
506                                        streaming_tool_handles.push((tool_id, tool_name, handle));
507                                    }
508                                }
509                                if let ContentBlock::Thinking { ref thinking, .. } = block {
510                                    sink.on_thinking(thinking);
511                                }
512                                content_blocks.push(block);
513                            }
514                            Some(StreamEvent::Done {
515                                usage: u,
516                                stop_reason: sr,
517                            }) => {
518                                usage = u;
519                                stop_reason = sr;
520                                sink.on_usage(&usage);
521                            }
522                            Some(StreamEvent::Error(msg)) => {
523                                got_error = true;
524                                error_text = msg.clone();
525                                sink.on_error(&msg);
526                            }
527                            Some(_) => {}
528                            None => break,
529                        }
530                    }
531                    _ = self.cancel.cancelled() => {
532                        warn!("Turn cancelled by user");
533                        cancelled = true;
534                        // Abort any in-flight streaming tool handles.
535                        for (_, _, handle) in streaming_tool_handles.drain(..) {
536                            handle.abort();
537                        }
538                        break;
539                    }
540                }
541            }
542
543            if cancelled {
544                sink.on_warning("Cancelled");
545                self.state.is_query_active = false;
546                return Ok(());
547            }
548
549            // Step 5: Record the assistant message.
550            let assistant_msg = Message::Assistant(AssistantMessage {
551                uuid: Uuid::new_v4(),
552                timestamp: chrono::Utc::now().to_rfc3339(),
553                content: content_blocks.clone(),
554                model: Some(model.clone()),
555                usage: Some(usage.clone()),
556                stop_reason: stop_reason.clone(),
557                request_id: None,
558            });
559            self.state.push_message(assistant_msg);
560            self.state.record_usage(&usage, &model);
561
562            // Token budget tracking per turn.
563            if self.state.config.features.token_budget && usage.total() > 0 {
564                let turn_total = usage.input_tokens + usage.output_tokens;
565                if turn_total > 100_000 {
566                    sink.on_warning(&format!(
567                        "High token usage this turn: {} tokens ({}in + {}out)",
568                        turn_total, usage.input_tokens, usage.output_tokens
569                    ));
570                }
571            }
572
573            // Record cache and telemetry.
574            let _cache_event = self.cache_tracker.record(&usage);
575            {
576                let mut span = crate::services::telemetry::api_call_span(
577                    &model,
578                    turn + 1,
579                    &self.state.session_id,
580                );
581                crate::services::telemetry::record_usage(&mut span, &usage);
582                span.finish();
583                tracing::debug!(
584                    "API call: {}ms, {}in/{}out tokens",
585                    span.duration_ms().unwrap_or(0),
586                    usage.input_tokens,
587                    usage.output_tokens,
588                );
589            }
590
591            // Step 6: Handle stream errors.
592            if got_error {
593                // Check if it's a prompt-too-long error in the stream.
594                if error_text.contains("prompt is too long")
595                    || error_text.contains("Prompt is too long")
596                {
597                    let freed = compact::microcompact(&mut self.state.messages, 1);
598                    if freed > 0 {
599                        sink.on_compact(freed);
600                        continue;
601                    }
602                }
603
604                // Check for max-output-tokens hit (partial response).
605                if content_blocks
606                    .iter()
607                    .any(|b| matches!(b, ContentBlock::Text { .. }))
608                    && error_text.contains("max_tokens")
609                    && max_output_recovery_count < MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
610                {
611                    max_output_recovery_count += 1;
612                    info!(
613                        "Max output tokens recovery attempt {}/{}",
614                        max_output_recovery_count, MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
615                    );
616                    let recovery_msg = compact::max_output_recovery_message();
617                    self.state.push_message(recovery_msg);
618                    continue;
619                }
620            }
621
622            // Step 6b: Handle max_tokens stop reason (escalate and continue).
623            if matches!(stop_reason, Some(StopReason::MaxTokens))
624                && !got_error
625                && content_blocks
626                    .iter()
627                    .any(|b| matches!(b, ContentBlock::Text { .. }))
628                && max_output_recovery_count < MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
629            {
630                max_output_recovery_count += 1;
631                info!(
632                    "Max tokens stop reason — recovery attempt {}/{}",
633                    max_output_recovery_count, MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
634                );
635                let recovery_msg = compact::max_output_recovery_message();
636                self.state.push_message(recovery_msg);
637                continue;
638            }
639
640            // Step 7: Extract tool calls from the response.
641            let tool_calls = extract_tool_calls(&content_blocks);
642
643            if tool_calls.is_empty() {
644                // No tools requested — turn is complete.
645                info!("Turn complete (no tool calls)");
646                sink.on_turn_complete(turn + 1);
647                self.state.is_query_active = false;
648
649                // Fire background memory extraction (fire-and-forget).
650                // Only runs if feature enabled and memory directory exists.
651                if self.state.config.features.extract_memories
652                    && crate::memory::ensure_memory_dir().is_some()
653                {
654                    let extraction_messages = self.state.messages.clone();
655                    let extraction_state = self.extraction_state.clone();
656                    let extraction_llm = self.llm.clone();
657                    let extraction_model = model.clone();
658                    tokio::spawn(async move {
659                        crate::memory::extraction::extract_memories_background(
660                            extraction_messages,
661                            extraction_state,
662                            extraction_llm,
663                            extraction_model,
664                        )
665                        .await;
666                    });
667                }
668
669                return Ok(());
670            }
671
672            // Step 8: Execute tool calls with pre/post hooks.
673            info!("Executing {} tool call(s)", tool_calls.len());
674            let cwd = PathBuf::from(&self.state.cwd);
675            let tool_ctx = ToolContext {
676                cwd: cwd.clone(),
677                cancel: self.cancel.clone(),
678                permission_checker: self.permissions.clone(),
679                verbose: self.config.verbose,
680                plan_mode: self.state.plan_mode,
681                file_cache: Some(self.file_cache.clone()),
682                denial_tracker: Some(self.denial_tracker.clone()),
683                task_manager: Some(self.state.task_manager.clone()),
684                session_allows: Some(self.session_allows.clone()),
685                permission_prompter: self.permission_prompter.clone(),
686                sandbox: Some(std::sync::Arc::new(
687                    crate::sandbox::SandboxExecutor::from_session_config(&self.state.config, &cwd),
688                )),
689            };
690
691            // Collect streaming tool results first.
692            let streaming_ids: std::collections::HashSet<String> = streaming_tool_handles
693                .iter()
694                .map(|(id, _, _)| id.clone())
695                .collect();
696
697            let mut streaming_results = Vec::new();
698            for (id, name, handle) in streaming_tool_handles.drain(..) {
699                match handle.await {
700                    Ok(result) => streaming_results.push(crate::tools::executor::ToolCallResult {
701                        tool_use_id: id,
702                        tool_name: name,
703                        result,
704                    }),
705                    Err(e) => streaming_results.push(crate::tools::executor::ToolCallResult {
706                        tool_use_id: id,
707                        tool_name: name,
708                        result: crate::tools::ToolResult::error(format!("Task failed: {e}")),
709                    }),
710                }
711            }
712
713            // Fire pre-tool-use hooks.
714            for call in &tool_calls {
715                self.hooks
716                    .run_hooks(&HookEvent::PreToolUse, Some(&call.name), &call.input)
717                    .await;
718            }
719
720            // Execute remaining tools (ones not started during streaming).
721            let remaining_calls: Vec<_> = tool_calls
722                .iter()
723                .filter(|c| !streaming_ids.contains(&c.id))
724                .cloned()
725                .collect();
726
727            let mut results = streaming_results;
728            if !remaining_calls.is_empty() {
729                let batch_results = execute_tool_calls(
730                    &remaining_calls,
731                    self.tools.all(),
732                    &tool_ctx,
733                    &self.permissions,
734                )
735                .await;
736                results.extend(batch_results);
737            }
738
739            // Step 9: Inject tool results + fire post-tool-use hooks.
740            for result in &results {
741                // Handle plan mode state transitions.
742                if !result.result.is_error {
743                    match result.tool_name.as_str() {
744                        "EnterPlanMode" => {
745                            self.state.plan_mode = true;
746                            info!("Plan mode enabled");
747                        }
748                        "ExitPlanMode" => {
749                            self.state.plan_mode = false;
750                            info!("Plan mode disabled");
751                        }
752                        _ => {}
753                    }
754                }
755
756                sink.on_tool_result(&result.tool_name, &result.result);
757
758                // Fire post-tool-use hooks.
759                self.hooks
760                    .run_hooks(
761                        &HookEvent::PostToolUse,
762                        Some(&result.tool_name),
763                        &serde_json::json!({
764                            "tool": result.tool_name,
765                            "is_error": result.result.is_error,
766                        }),
767                    )
768                    .await;
769
770                let msg = tool_result_message(
771                    &result.tool_use_id,
772                    &result.result.content,
773                    result.result.is_error,
774                );
775                self.state.push_message(msg);
776            }
777
778            // Continue the loop — the model will see the tool results.
779        }
780
781        warn!("Max turns ({max_turns}) reached");
782        sink.on_warning(&format!("Agent stopped after {max_turns} turns"));
783        self.state.is_query_active = false;
784        Ok(())
785    }
786
    /// Cancel the current operation.
    ///
    /// Fires the engine's cancellation token. Clones of this token are
    /// handed to tool execution via `ToolContext` (and to callers via
    /// [`QueryEngine::cancel_token`]), so in-flight work that watches the
    /// token can stop promptly. Idempotent: cancelling twice is harmless.
    pub fn cancel(&self) {
        self.cancel.cancel();
    }
791
792    /// Get a cloneable cancel token for use in background tasks.
793    pub fn cancel_token(&self) -> tokio_util::sync::CancellationToken {
794        self.cancel.clone()
795    }
796}
797
/// Get a fallback model (smaller/cheaper) for retry on overload.
///
/// Known downgrades, checked in order:
/// - `opus` → `sonnet`
/// - `gpt-5.4` / `gpt-4.1` (when not already `mini`/`nano`) → `{model}-mini`
/// - `large` → `small`
///
/// Anything unrecognized is returned unchanged.
///
/// Bug fix: matching was case-insensitive but the substring replacement
/// was case-sensitive, so e.g. "Claude-Opus" passed the check yet came
/// back unmodified. The splice below replaces at the byte offset found in
/// the lowercased copy instead — valid because ASCII lowercasing
/// preserves byte offsets.
fn get_fallback_model(current: &str) -> String {
    let lower = current.to_ascii_lowercase();

    // Replace the first case-insensitive occurrence of `needle` with
    // `replacement`, preserving the surrounding original casing.
    let replace_ci = |needle: &str, replacement: &str| -> Option<String> {
        lower.find(needle).map(|pos| {
            format!(
                "{}{replacement}{}",
                &current[..pos],
                &current[pos + needle.len()..]
            )
        })
    };

    if let Some(fallback) = replace_ci("opus", "sonnet") {
        // Opus → Sonnet
        fallback
    } else if (lower.contains("gpt-5.4") || lower.contains("gpt-4.1"))
        && !lower.contains("mini")
        && !lower.contains("nano")
    {
        format!("{current}-mini")
    } else if let Some(fallback) = replace_ci("large", "small") {
        fallback
    } else {
        // Already a small model, keep it.
        current.to_string()
    }
}
816
/// Build the system prompt from tool definitions, app state, and memory.
///
/// Assembles the prompt in a fixed order: identity preamble, environment
/// info (cwd, OS, shell, git-repo flag), memory context (project + user +
/// on-demand relevant), per-tool documentation for enabled tools,
/// user-invocable skills, static guideline sections (tool use, code, git
/// safety, commits, PRs, safe execution, response style, memory rules),
/// tool-usage patterns, connected MCP servers, deferred tools, and
/// task-management guidance. Read-only with respect to `state`; returns
/// the complete prompt string.
pub fn build_system_prompt(tools: &ToolRegistry, state: &AppState) -> String {
    let mut prompt = String::new();

    // Identity preamble — always first.
    prompt.push_str(
        "You are an AI coding agent. You help users with software engineering tasks \
         by reading, writing, and searching code. Use the tools available to you to \
         accomplish tasks.\n\n",
    );

    // Environment context.
    let shell = std::env::var("SHELL").unwrap_or_else(|_| "bash".to_string());
    let is_git = std::path::Path::new(&state.cwd).join(".git").exists();
    prompt.push_str(&format!(
        "# Environment\n\
         - Working directory: {}\n\
         - Platform: {}\n\
         - Shell: {shell}\n\
         - Git repository: {}\n\n",
        state.cwd,
        std::env::consts::OS,
        if is_git { "yes" } else { "no" },
    ));

    // Inject memory context (project + user + on-demand relevant).
    let mut memory = crate::memory::MemoryContext::load(Some(std::path::Path::new(&state.cwd)));

    // On-demand: surface relevant memories based on recent conversation.
    // Note: this scans the 5 most recent messages of ANY role and keeps
    // only the user ones, so fewer than 5 user messages may contribute.
    let recent_text: String = state
        .messages
        .iter()
        .rev()
        .take(5)
        .filter_map(|m| match m {
            crate::llm::message::Message::User(u) => Some(
                u.content
                    .iter()
                    .filter_map(|b| b.as_text())
                    .collect::<Vec<_>>()
                    .join(" "),
            ),
            _ => None,
        })
        .collect::<Vec<_>>()
        .join(" ");

    if !recent_text.is_empty() {
        memory.load_relevant(&recent_text);
    }

    // Rendered memory section is omitted entirely when empty.
    let memory_section = memory.to_system_prompt_section();
    if !memory_section.is_empty() {
        prompt.push_str(&memory_section);
    }

    // Tool documentation (only tools currently enabled are listed).
    prompt.push_str("# Available Tools\n\n");
    for tool in tools.all() {
        if tool.is_enabled() {
            prompt.push_str(&format!("## {}\n{}\n\n", tool.name(), tool.prompt()));
        }
    }

    // Available skills.
    let skills = crate::skills::SkillRegistry::load_all(Some(std::path::Path::new(&state.cwd)));
    let invocable = skills.user_invocable();
    if !invocable.is_empty() {
        prompt.push_str("# Available Skills\n\n");
        for skill in invocable {
            let desc = skill.metadata.description.as_deref().unwrap_or("");
            let when = skill.metadata.when_to_use.as_deref().unwrap_or("");
            prompt.push_str(&format!("- `/{}`", skill.name));
            if !desc.is_empty() {
                prompt.push_str(&format!(": {desc}"));
            }
            if !when.is_empty() {
                prompt.push_str(&format!(" (use when: {when})"));
            }
            prompt.push('\n');
        }
        prompt.push('\n');
    }

    // Guidelines and safety framework.
    prompt.push_str(
        "# Using tools\n\n\
         Use dedicated tools instead of shell commands when available:\n\
         - File search: Glob (not find or ls)\n\
         - Content search: Grep (not grep or rg)\n\
         - Read files: FileRead (not cat/head/tail)\n\
         - Edit files: FileEdit (not sed/awk)\n\
         - Write files: FileWrite (not echo/cat with redirect)\n\
         - Reserve Bash for system commands and operations that require shell execution.\n\
         - Break complex tasks into steps. Use multiple tool calls in parallel when independent.\n\
         - Use the Agent tool for complex multi-step research or tasks that benefit from isolation.\n\n\
         # Working with code\n\n\
         - Read files before editing them. Understand existing code before suggesting changes.\n\
         - Prefer editing existing files over creating new ones to avoid file bloat.\n\
         - Only make changes that were requested. Don't add features, refactor, add comments, \
           or make \"improvements\" beyond the ask.\n\
         - Don't add error handling for scenarios that can't happen. Don't design for \
           hypothetical future requirements.\n\
         - When referencing code, include file_path:line_number.\n\
         - Be careful not to introduce security vulnerabilities (command injection, XSS, SQL injection, \
           OWASP top 10). If you notice insecure code you wrote, fix it immediately.\n\
         - Don't add docstrings, comments, or type annotations to code you didn't change.\n\
         - Three similar lines of code is better than a premature abstraction.\n\n\
         # Git safety protocol\n\n\
         - NEVER update the git config.\n\
         - NEVER run destructive git commands (push --force, reset --hard, checkout ., restore ., \
           clean -f, branch -D) unless the user explicitly requests them.\n\
         - NEVER skip hooks (--no-verify, --no-gpg-sign) unless the user explicitly requests it.\n\
         - NEVER force push to main/master. Warn the user if they request it.\n\
         - Always create NEW commits rather than amending, unless the user explicitly requests amend. \
           After hook failure, the commit did NOT happen — amend would modify the PREVIOUS commit.\n\
         - When staging files, prefer adding specific files by name rather than git add -A or git add ., \
           which can accidentally include sensitive files.\n\
         - NEVER commit changes unless the user explicitly asks.\n\n\
         # Committing changes\n\n\
         When the user asks to commit:\n\
         1. Run git status and git diff to see all changes.\n\
         2. Run git log --oneline -5 to match the repository's commit message style.\n\
         3. Draft a concise (1-2 sentence) commit message focusing on \"why\" not \"what\".\n\
         4. Do not commit files that likely contain secrets (.env, credentials.json).\n\
         5. Stage specific files, create the commit.\n\
         6. If pre-commit hook fails, fix the issue and create a NEW commit.\n\
         7. When creating commits, include a co-author attribution line at the end of the message.\n\n\
         # Creating pull requests\n\n\
         When the user asks to create a PR:\n\
         1. Run git status, git diff, and git log to understand all changes on the branch.\n\
         2. Analyze ALL commits (not just the latest) that will be in the PR.\n\
         3. Draft a short title (under 70 chars) and detailed body with summary and test plan.\n\
         4. Push to remote with -u flag if needed, then create PR using gh pr create.\n\
         5. Return the PR URL when done.\n\n\
         # Executing actions safely\n\n\
         Consider the reversibility and blast radius of every action:\n\
         - Freely take local, reversible actions (editing files, running tests).\n\
         - For hard-to-reverse or shared-state actions, confirm with the user first:\n\
           - Destructive: deleting files/branches, dropping tables, rm -rf, overwriting uncommitted changes.\n\
           - Hard to reverse: force-pushing, git reset --hard, amending published commits.\n\
           - Visible to others: pushing code, creating/commenting on PRs/issues, sending messages.\n\
         - When you encounter an obstacle, do not use destructive actions as a shortcut. \
           Identify root causes and fix underlying issues.\n\
         - If you discover unexpected state (unfamiliar files, branches, config), investigate \
           before deleting or overwriting — it may be the user's in-progress work.\n\n\
         # Response style\n\n\
         - Be concise. Lead with the answer or action, not the reasoning.\n\
         - Skip filler, preamble, and unnecessary transitions.\n\
         - Don't restate what the user said.\n\
         - If you can say it in one sentence, don't use three.\n\
         - Focus output on: decisions that need input, status updates, and errors that change the plan.\n\
         - When referencing GitHub issues or PRs, use owner/repo#123 format.\n\
         - Only use emojis if the user explicitly requests it.\n\n\
         # Memory\n\n\
         You can save information across sessions by writing memory files.\n\
         - Save to: ~/.config/agent-code/memory/ (one .md file per topic)\n\
         - Each file needs YAML frontmatter: name, description, type (user/feedback/project/reference)\n\
         - After writing a file, update MEMORY.md with a one-line pointer\n\
         - Memory types: user (role, preferences), feedback (corrections, confirmations), \
           project (decisions, deadlines), reference (external resources)\n\
         - Do NOT store: code patterns, git history, debugging solutions, anything derivable from code\n\
         - Memory is a hint — always verify against current state before acting on it\n",
    );

    // Detailed tool usage examples and workflow patterns.
    prompt.push_str(
        "# Tool usage patterns\n\n\
         Common patterns for effective tool use:\n\n\
         **Read before edit**: Always read a file before editing it. This ensures you \
         understand the current state and can make targeted changes.\n\
         ```\n\
         1. FileRead file_path → understand structure\n\
         2. FileEdit old_string, new_string → targeted change\n\
         ```\n\n\
         **Search then act**: Use Glob to find files, Grep to find content, then read/edit.\n\
         ```\n\
         1. Glob **/*.rs → find Rust files\n\
         2. Grep pattern path → find specific code\n\
         3. FileRead → read the match\n\
         4. FileEdit → make the change\n\
         ```\n\n\
         **Parallel tool calls**: When you need to read multiple independent files or run \
         independent searches, make all the tool calls in one response. Don't serialize \
         independent operations.\n\n\
         **Test after change**: After editing code, run tests to verify the change works.\n\
         ```\n\
         1. FileEdit → make change\n\
         2. Bash cargo test / pytest / npm test → verify\n\
         3. If tests fail, read the error, fix, re-test\n\
         ```\n\n\
         # Error recovery\n\n\
         When something goes wrong:\n\
         - **Tool not found**: Use ToolSearch to find the right tool name.\n\
         - **Permission denied**: Explain why the action is needed, ask the user to approve.\n\
         - **File not found**: Use Glob to find the correct path. Check for typos.\n\
         - **Edit failed (not unique)**: Provide more surrounding context in old_string, \
           or use replace_all=true if renaming.\n\
         - **Command failed**: Read the full error message. Don't retry the same command. \
           Diagnose the root cause first.\n\
         - **Context too large**: The system will auto-compact. If you need specific \
           information from before compaction, re-read the relevant files.\n\
         - **Rate limited**: The system will auto-retry with backoff. Just wait.\n\n\
         # Common workflows\n\n\
         **Bug fix**: Read the failing test → read the source code it tests → \
         identify the bug → fix it → run the test → confirm it passes.\n\n\
         **New feature**: Read existing patterns in the codebase → create or edit files → \
         add tests → run tests → update docs if needed.\n\n\
         **Code review**: Read the diff → identify issues (bugs, security, style) → \
         report findings with file:line references.\n\n\
         **Refactor**: Search for all usages of the symbol → plan the changes → \
         edit each file → run tests to verify nothing broke.\n\n",
    );

    // MCP server instructions (dynamic, per-server).
    if !state.config.mcp_servers.is_empty() {
        prompt.push_str("# MCP Servers\n\n");
        prompt.push_str(
            "Connected MCP servers provide additional tools. MCP tools are prefixed \
             with `mcp__{server}__{tool}`. Use them like any other tool.\n\n",
        );
        for (name, entry) in &state.config.mcp_servers {
            // Transport is inferred from which config field is populated.
            let transport = if entry.command.is_some() {
                "stdio"
            } else if entry.url.is_some() {
                "sse"
            } else {
                "unknown"
            };
            prompt.push_str(&format!("- **{name}** ({transport})\n"));
        }
        prompt.push('\n');
    }

    // Deferred tools listing.
    let deferred = tools.deferred_names();
    if !deferred.is_empty() {
        prompt.push_str("# Deferred Tools\n\n");
        prompt.push_str(
            "These tools are available but not loaded by default. \
             Use ToolSearch to load them when needed:\n",
        );
        for name in &deferred {
            prompt.push_str(&format!("- {name}\n"));
        }
        prompt.push('\n');
    }

    // Task management guidance.
    prompt.push_str(
        "# Task management\n\n\
         - Use TaskCreate to break complex work into trackable steps.\n\
         - Mark tasks as in_progress when starting, completed when done.\n\
         - Use the Agent tool to spawn subagents for parallel independent work.\n\
         - Use EnterPlanMode/ExitPlanMode for read-only exploration before making changes.\n\
         - Use EnterWorktree/ExitWorktree for isolated changes in git worktrees.\n\n\
         # Output formatting\n\n\
         - All text output is displayed to the user. Use GitHub-flavored markdown.\n\
         - Use fenced code blocks with language hints for code: ```rust, ```python, etc.\n\
         - Use inline `code` for file names, function names, and short code references.\n\
         - Use tables for structured comparisons.\n\
         - Use bullet lists for multiple items.\n\
         - Keep paragraphs short (2-3 sentences).\n\
         - Never output raw HTML or complex formatting — stick to standard markdown.\n",
    );

    prompt
}
1084
1085#[cfg(test)]
1086mod tests {
1087    use super::*;
1088
1089    /// Verify that cancelling via the shared handle cancels the current
1090    /// turn's token (regression: the signal handler previously held a
1091    /// stale clone that couldn't cancel subsequent turns).
1092    #[test]
1093    fn cancel_shared_propagates_to_current_token() {
1094        let root = CancellationToken::new();
1095        let shared = Arc::new(std::sync::Mutex::new(root.clone()));
1096
1097        // Simulate turn reset: create a new token and update the shared handle.
1098        let turn1 = CancellationToken::new();
1099        *shared.lock().unwrap() = turn1.clone();
1100
1101        // Cancelling via the shared handle should cancel turn1.
1102        shared.lock().unwrap().cancel();
1103        assert!(turn1.is_cancelled());
1104
1105        // New turn: replace the token. The old cancellation shouldn't affect it.
1106        let turn2 = CancellationToken::new();
1107        *shared.lock().unwrap() = turn2.clone();
1108        assert!(!turn2.is_cancelled());
1109
1110        // Cancelling via shared should cancel turn2.
1111        shared.lock().unwrap().cancel();
1112        assert!(turn2.is_cancelled());
1113    }
1114
1115    /// Verify that the streaming loop breaks on cancellation by simulating
1116    /// the select pattern used in run_turn_with_sink.
1117    #[tokio::test]
1118    async fn stream_loop_responds_to_cancellation() {
1119        let cancel = CancellationToken::new();
1120        let (tx, mut rx) = tokio::sync::mpsc::channel::<StreamEvent>(10);
1121
1122        // Simulate a slow stream: send one event, then cancel before more arrive.
1123        tx.send(StreamEvent::TextDelta("hello".into()))
1124            .await
1125            .unwrap();
1126
1127        let cancel2 = cancel.clone();
1128        tokio::spawn(async move {
1129            // Small delay, then cancel.
1130            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1131            cancel2.cancel();
1132        });
1133
1134        let mut events_received = 0u32;
1135        let mut cancelled = false;
1136
1137        loop {
1138            tokio::select! {
1139                event = rx.recv() => {
1140                    match event {
1141                        Some(_) => events_received += 1,
1142                        None => break,
1143                    }
1144                }
1145                _ = cancel.cancelled() => {
1146                    cancelled = true;
1147                    break;
1148                }
1149            }
1150        }
1151
1152        assert!(cancelled, "Loop should have been cancelled");
1153        assert_eq!(
1154            events_received, 1,
1155            "Should have received exactly one event before cancel"
1156        );
1157    }
1158
1159    // ------------------------------------------------------------------
1160    // End-to-end regression tests for #103.
1161    //
1162    // These tests build a real QueryEngine with a mock Provider and
1163    // exercise run_turn_with_sink directly, verifying that cancellation
1164    // actually interrupts the streaming loop (not just the select!
1165    // pattern in isolation).
1166    // ------------------------------------------------------------------
1167
1168    use crate::llm::provider::{Provider, ProviderError, ProviderRequest};
1169
    /// A provider whose stream yields one TextDelta and then hangs forever.
    /// Simulates the real bug: a slow LLM response the user wants to interrupt.
    struct HangingProvider;

    #[async_trait::async_trait]
    impl Provider for HangingProvider {
        fn name(&self) -> &str {
            "hanging-mock"
        }

        /// Returns a receiver immediately; a spawned task emits a single
        /// TextDelta and then parks forever without ever sending Done.
        async fn stream(
            &self,
            _request: &ProviderRequest,
        ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
            let (tx, rx) = tokio::sync::mpsc::channel(4);
            tokio::spawn(async move {
                let _ = tx.send(StreamEvent::TextDelta("thinking...".into())).await;
                // Hang forever without closing the channel or sending Done.
                // Binding `tx` here keeps the sender alive, so the receiver
                // never observes a closed channel (recv() returning None).
                let _tx_holder = tx;
                std::future::pending::<()>().await;
            });
            Ok(rx)
        }
    }
1194
    /// A provider that completes a turn normally: emits text and a Done event.
    struct CompletingProvider;

    #[async_trait::async_trait]
    impl Provider for CompletingProvider {
        fn name(&self) -> &str {
            "completing-mock"
        }

        /// Emits a happy-path sequence: one TextDelta, the completed text
        /// content block, then Done with an EndTurn stop reason, after
        /// which the channel closes.
        async fn stream(
            &self,
            _request: &ProviderRequest,
        ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
            let (tx, rx) = tokio::sync::mpsc::channel(8);
            tokio::spawn(async move {
                let _ = tx.send(StreamEvent::TextDelta("hello".into())).await;
                let _ = tx
                    .send(StreamEvent::ContentBlockComplete(ContentBlock::Text {
                        text: "hello".into(),
                    }))
                    .await;
                let _ = tx
                    .send(StreamEvent::Done {
                        usage: Usage::default(),
                        stop_reason: Some(StopReason::EndTurn),
                    })
                    .await;
                // Channel closes when tx drops.
            });
            Ok(rx)
        }
    }
1227
1228    fn build_engine(llm: Arc<dyn Provider>) -> QueryEngine {
1229        use crate::config::Config;
1230        use crate::permissions::PermissionChecker;
1231        use crate::state::AppState;
1232        use crate::tools::registry::ToolRegistry;
1233
1234        let config = Config::default();
1235        let permissions = PermissionChecker::from_config(&config.permissions);
1236        let state = AppState::new(config);
1237
1238        QueryEngine::new(
1239            llm,
1240            ToolRegistry::default_tools(),
1241            permissions,
1242            state,
1243            QueryEngineConfig {
1244                max_turns: Some(1),
1245                verbose: false,
1246                unattended: true,
1247            },
1248        )
1249    }
1250
1251    /// Schedule a cancellation after `delay_ms` via the shared handle
1252    /// (same path the signal handler uses).
1253    fn schedule_cancel(engine: &QueryEngine, delay_ms: u64) {
1254        let shared = engine.cancel_shared.clone();
1255        tokio::spawn(async move {
1256            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
1257            shared.lock().unwrap().cancel();
1258        });
1259    }
1260
1261    /// Builds a mock provider whose stream yields one TextDelta and then hangs.
1262    /// Verifies the turn returns promptly once cancel fires.
1263    #[tokio::test]
1264    async fn run_turn_with_sink_interrupts_on_cancel() {
1265        let mut engine = build_engine(Arc::new(HangingProvider));
1266        schedule_cancel(&engine, 100);
1267
1268        let result = tokio::time::timeout(
1269            std::time::Duration::from_secs(5),
1270            engine.run_turn_with_sink("test input", &NullSink),
1271        )
1272        .await;
1273
1274        assert!(
1275            result.is_ok(),
1276            "run_turn_with_sink should return promptly on cancel, not hang"
1277        );
1278        assert!(
1279            result.unwrap().is_ok(),
1280            "cancelled turn should return Ok(()), not an error"
1281        );
1282        assert!(
1283            !engine.state().is_query_active,
1284            "is_query_active should be reset after cancel"
1285        );
1286    }
1287
1288    /// Regression test for the original #103 bug: the signal handler held
1289    /// a stale clone of the cancellation token, so Ctrl+C only worked on
1290    /// the *first* turn. This test cancels turn 1, then runs turn 2 and
1291    /// verifies it is ALSO cancellable via the same shared handle.
1292    #[tokio::test]
1293    async fn cancel_works_across_multiple_turns() {
1294        let mut engine = build_engine(Arc::new(HangingProvider));
1295
1296        // Turn 1: cancel mid-stream.
1297        schedule_cancel(&engine, 80);
1298        let r1 = tokio::time::timeout(
1299            std::time::Duration::from_secs(5),
1300            engine.run_turn_with_sink("turn 1", &NullSink),
1301        )
1302        .await;
1303        assert!(r1.is_ok(), "turn 1 should cancel promptly");
1304        assert!(!engine.state().is_query_active);
1305
1306        // Turn 2: cancel again via the same shared handle.
1307        // With the pre-fix stale-token bug, the handle would be pointing
1308        // at turn 1's already-used token and turn 2 would hang forever.
1309        schedule_cancel(&engine, 80);
1310        let r2 = tokio::time::timeout(
1311            std::time::Duration::from_secs(5),
1312            engine.run_turn_with_sink("turn 2", &NullSink),
1313        )
1314        .await;
1315        assert!(
1316            r2.is_ok(),
1317            "turn 2 should also cancel promptly — regression would hang here"
1318        );
1319        assert!(!engine.state().is_query_active);
1320
1321        // Turn 3: one more for good measure.
1322        schedule_cancel(&engine, 80);
1323        let r3 = tokio::time::timeout(
1324            std::time::Duration::from_secs(5),
1325            engine.run_turn_with_sink("turn 3", &NullSink),
1326        )
1327        .await;
1328        assert!(r3.is_ok(), "turn 3 should still be cancellable");
1329        assert!(!engine.state().is_query_active);
1330    }
1331
1332    /// Verifies that a previously-cancelled token does not poison subsequent
1333    /// turns. A fresh run_turn_with_sink on the same engine should complete
1334    /// normally even after a prior cancel.
1335    #[tokio::test]
1336    async fn cancel_does_not_poison_next_turn() {
1337        // Turn 1: hangs and gets cancelled.
1338        let mut engine = build_engine(Arc::new(HangingProvider));
1339        schedule_cancel(&engine, 80);
1340        let _ = tokio::time::timeout(
1341            std::time::Duration::from_secs(5),
1342            engine.run_turn_with_sink("turn 1", &NullSink),
1343        )
1344        .await
1345        .expect("turn 1 should cancel");
1346
1347        // Swap the provider to one that completes normally by rebuilding
1348        // the engine (we can't swap llm on an existing engine, so this
1349        // simulates the isolated "fresh turn" behavior). The key property
1350        // being tested is that the per-turn cancel reset correctly
1351        // initializes a non-cancelled token.
1352        let mut engine2 = build_engine(Arc::new(CompletingProvider));
1353
1354        // Pre-cancel engine2 to simulate a leftover cancelled state, then
1355        // verify run_turn_with_sink still works because it resets the token.
1356        engine2.cancel_shared.lock().unwrap().cancel();
1357
1358        let result = tokio::time::timeout(
1359            std::time::Duration::from_secs(5),
1360            engine2.run_turn_with_sink("hello", &NullSink),
1361        )
1362        .await;
1363
1364        assert!(result.is_ok(), "completing turn should not hang");
1365        assert!(
1366            result.unwrap().is_ok(),
1367            "turn should succeed — the stale cancel flag must be cleared on turn start"
1368        );
1369        // Message history should contain the user + assistant messages.
1370        assert!(
1371            engine2.state().messages.len() >= 2,
1372            "normal turn should push both user and assistant messages"
1373        );
1374    }
1375
1376    /// Verifies that cancelling BEFORE any event arrives still interrupts
1377    /// the turn cleanly (edge case: cancellation races with the first recv).
1378    #[tokio::test]
1379    async fn cancel_before_first_event_interrupts_cleanly() {
1380        let mut engine = build_engine(Arc::new(HangingProvider));
1381        // Very short delay so cancel likely fires before or during the
1382        // first event. The test is tolerant of either ordering.
1383        schedule_cancel(&engine, 1);
1384
1385        let result = tokio::time::timeout(
1386            std::time::Duration::from_secs(5),
1387            engine.run_turn_with_sink("immediate", &NullSink),
1388        )
1389        .await;
1390
1391        assert!(result.is_ok(), "early cancel should not hang");
1392        assert!(result.unwrap().is_ok());
1393        assert!(!engine.state().is_query_active);
1394    }
1395
1396    /// Verifies the sink receives cancellation feedback via on_warning.
1397    #[tokio::test]
1398    async fn cancelled_turn_emits_warning_to_sink() {
1399        use std::sync::Mutex;
1400
1401        /// Captures sink events for assertion.
1402        struct CapturingSink {
1403            warnings: Mutex<Vec<String>>,
1404        }
1405
1406        impl StreamSink for CapturingSink {
1407            fn on_text(&self, _: &str) {}
1408            fn on_tool_start(&self, _: &str, _: &serde_json::Value) {}
1409            fn on_tool_result(&self, _: &str, _: &crate::tools::ToolResult) {}
1410            fn on_error(&self, _: &str) {}
1411            fn on_warning(&self, msg: &str) {
1412                self.warnings.lock().unwrap().push(msg.to_string());
1413            }
1414        }
1415
1416        let sink = CapturingSink {
1417            warnings: Mutex::new(Vec::new()),
1418        };
1419
1420        let mut engine = build_engine(Arc::new(HangingProvider));
1421        schedule_cancel(&engine, 100);
1422
1423        let _ = tokio::time::timeout(
1424            std::time::Duration::from_secs(5),
1425            engine.run_turn_with_sink("test", &sink),
1426        )
1427        .await
1428        .expect("should not hang");
1429
1430        let warnings = sink.warnings.lock().unwrap();
1431        assert!(
1432            warnings.iter().any(|w| w.contains("Cancelled")),
1433            "expected 'Cancelled' warning in sink, got: {:?}",
1434            *warnings
1435        );
1436    }
1437}