Skip to main content

agent_code_lib/query/
mod.rs

1//! Query engine: the core agent loop.
2//!
3//! Implements the agentic cycle:
4//!
5//! 1. Auto-compact if context nears the window limit
6//! 2. Microcompact stale tool results
7//! 3. Call LLM with streaming
8//! 4. Accumulate response content blocks
9//! 5. Handle errors (prompt-too-long, rate limits, max-output-tokens)
10//! 6. Extract tool_use blocks
11//! 7. Execute tools (concurrent/serial batching)
12//! 8. Inject tool results into history
13//! 9. Repeat from step 1 until no tool_use or max turns
14
15pub mod source;
16
17use std::path::PathBuf;
18use std::sync::Arc;
19
20use tokio_util::sync::CancellationToken;
21use tracing::{debug, info, warn};
22use uuid::Uuid;
23
24use crate::hooks::{HookEvent, HookRegistry};
25use crate::llm::message::*;
26use crate::llm::provider::{Provider, ProviderError, ProviderRequest};
27use crate::llm::stream::StreamEvent;
28use crate::permissions::PermissionChecker;
29use crate::services::compact::{self, CompactTracking, MAX_OUTPUT_TOKENS_RECOVERY_LIMIT};
30use crate::services::tokens;
31use crate::state::AppState;
32use crate::tools::ToolContext;
33use crate::tools::executor::{execute_tool_calls, extract_tool_calls};
34use crate::tools::registry::ToolRegistry;
35
/// Configuration for the query engine.
///
/// `Default` yields an interactive, non-verbose session whose turn limit is
/// left to the engine's built-in fallback.
#[derive(Debug, Clone, Default)]
pub struct QueryEngineConfig {
    /// Upper bound on agent-loop iterations per user turn; `None` lets the
    /// engine apply its own default.
    pub max_turns: Option<usize>,
    /// Verbose mode flag, forwarded to tools through their execution context.
    pub verbose: bool,
    /// Whether this is a non-interactive (one-shot) session.
    pub unattended: bool,
}
43
44/// The query engine orchestrates the agent loop.
45///
46/// Central coordinator that drives the LLM → tools → LLM cycle.
47/// Manages conversation history, context compaction, tool execution,
48/// error recovery, and hook dispatch. Create via [`QueryEngine::new`].
49pub struct QueryEngine {
50    llm: Arc<dyn Provider>,
51    tools: ToolRegistry,
52    file_cache: Arc<tokio::sync::Mutex<crate::services::file_cache::FileCache>>,
53    permissions: Arc<PermissionChecker>,
54    state: AppState,
55    config: QueryEngineConfig,
56    /// Shared handle so the signal handler always cancels the current token.
57    cancel_shared: Arc<std::sync::Mutex<CancellationToken>>,
58    /// Per-turn cancellation token (cloned from cancel_shared at turn start).
59    cancel: CancellationToken,
60    hooks: HookRegistry,
61    cache_tracker: crate::services::cache_tracking::CacheTracker,
62    denial_tracker: Arc<tokio::sync::Mutex<crate::permissions::tracking::DenialTracker>>,
63    extraction_state: Arc<tokio::sync::Mutex<crate::memory::extraction::ExtractionState>>,
64    session_allows: Arc<tokio::sync::Mutex<std::collections::HashSet<String>>>,
65    permission_prompter: Option<Arc<dyn crate::tools::PermissionPrompter>>,
66    /// Cached system prompt (rebuilt only when inputs change).
67    cached_system_prompt: Option<(u64, String)>, // (hash, prompt)
68}
69
70/// Callback for streaming events to the UI.
71pub trait StreamSink: Send + Sync {
72    fn on_text(&self, text: &str);
73    fn on_tool_start(&self, tool_name: &str, input: &serde_json::Value);
74    fn on_tool_result(&self, tool_name: &str, result: &crate::tools::ToolResult);
75    fn on_thinking(&self, _text: &str) {}
76    fn on_turn_complete(&self, _turn: usize) {}
77    fn on_error(&self, error: &str);
78    fn on_usage(&self, _usage: &Usage) {}
79    fn on_compact(&self, _freed_tokens: u64) {}
80    fn on_warning(&self, _msg: &str) {}
81}
82
83/// A no-op stream sink for non-interactive mode.
84pub struct NullSink;
85impl StreamSink for NullSink {
86    fn on_text(&self, _: &str) {}
87    fn on_tool_start(&self, _: &str, _: &serde_json::Value) {}
88    fn on_tool_result(&self, _: &str, _: &crate::tools::ToolResult) {}
89    fn on_error(&self, _: &str) {}
90}
91
92impl QueryEngine {
93    pub fn new(
94        llm: Arc<dyn Provider>,
95        tools: ToolRegistry,
96        permissions: PermissionChecker,
97        state: AppState,
98        config: QueryEngineConfig,
99    ) -> Self {
100        let cancel = CancellationToken::new();
101        let cancel_shared = Arc::new(std::sync::Mutex::new(cancel.clone()));
102        Self {
103            llm,
104            tools,
105            file_cache: Arc::new(tokio::sync::Mutex::new(
106                crate::services::file_cache::FileCache::new(),
107            )),
108            permissions: Arc::new(permissions),
109            state,
110            config,
111            cancel,
112            cancel_shared,
113            hooks: HookRegistry::new(),
114            cache_tracker: crate::services::cache_tracking::CacheTracker::new(),
115            denial_tracker: Arc::new(tokio::sync::Mutex::new(
116                crate::permissions::tracking::DenialTracker::new(100),
117            )),
118            extraction_state: Arc::new(tokio::sync::Mutex::new(
119                crate::memory::extraction::ExtractionState::new(),
120            )),
121            session_allows: Arc::new(tokio::sync::Mutex::new(std::collections::HashSet::new())),
122            permission_prompter: None,
123            cached_system_prompt: None,
124        }
125    }
126
127    /// Load hooks from configuration into the registry.
128    pub fn load_hooks(&mut self, hook_defs: &[crate::hooks::HookDefinition]) {
129        for def in hook_defs {
130            self.hooks.register(def.clone());
131        }
132        if !hook_defs.is_empty() {
133            tracing::info!("Loaded {} hooks from config", hook_defs.len());
134        }
135    }
136
137    /// Get a reference to the app state.
138    pub fn state(&self) -> &AppState {
139        &self.state
140    }
141
142    /// Get a mutable reference to the app state.
143    pub fn state_mut(&mut self) -> &mut AppState {
144        &mut self.state
145    }
146
147    /// Install a Ctrl+C handler that triggers the cancellation token.
148    /// Call this once at startup. Subsequent Ctrl+C signals during a
149    /// turn will cancel the active operation instead of killing the process.
150    pub fn install_signal_handler(&self) {
151        let shared = self.cancel_shared.clone();
152        tokio::spawn(async move {
153            let mut pending = false;
154            loop {
155                if tokio::signal::ctrl_c().await.is_ok() {
156                    let token = shared.lock().unwrap().clone();
157                    if token.is_cancelled() && pending {
158                        // Second Ctrl+C after cancel — hard exit.
159                        std::process::exit(130);
160                    }
161                    token.cancel();
162                    pending = true;
163                }
164            }
165        });
166    }
167
168    /// Run a single turn: process user input through the full agent loop.
169    pub async fn run_turn(&mut self, user_input: &str) -> crate::error::Result<()> {
170        self.run_turn_with_sink(user_input, &NullSink).await
171    }
172
173    /// Run a turn with a stream sink for real-time UI updates.
174    pub async fn run_turn_with_sink(
175        &mut self,
176        user_input: &str,
177        sink: &dyn StreamSink,
178    ) -> crate::error::Result<()> {
179        // Reset cancellation token for this turn. The shared handle is
180        // updated so the signal handler always cancels the current token.
181        self.cancel = CancellationToken::new();
182        *self.cancel_shared.lock().unwrap() = self.cancel.clone();
183
184        // Add the user message to history.
185        let user_msg = user_message(user_input);
186        self.state.push_message(user_msg);
187
188        let max_turns = self.config.max_turns.unwrap_or(50);
189        let mut compact_tracking = CompactTracking::default();
190        let mut retry_state = crate::llm::retry::RetryState::default();
191        let retry_config = crate::llm::retry::RetryConfig::default();
192        let mut max_output_recovery_count = 0u32;
193
194        // Agent loop: budget check → normalize → compact → call LLM → execute tools → repeat.
195        for turn in 0..max_turns {
196            self.state.turn_count = turn + 1;
197            self.state.is_query_active = true;
198
199            // Budget check before each turn.
200            let budget_config = crate::services::budget::BudgetConfig::default();
201            match crate::services::budget::check_budget(
202                self.state.total_cost_usd,
203                self.state.total_usage.total(),
204                &budget_config,
205            ) {
206                crate::services::budget::BudgetDecision::Stop { message } => {
207                    sink.on_warning(&message);
208                    self.state.is_query_active = false;
209                    return Ok(());
210                }
211                crate::services::budget::BudgetDecision::ContinueWithWarning {
212                    message, ..
213                } => {
214                    sink.on_warning(&message);
215                }
216                crate::services::budget::BudgetDecision::Continue => {}
217            }
218
219            // Normalize messages for API compatibility.
220            crate::llm::normalize::ensure_tool_result_pairing(&mut self.state.messages);
221            crate::llm::normalize::strip_empty_blocks(&mut self.state.messages);
222            crate::llm::normalize::remove_empty_messages(&mut self.state.messages);
223            crate::llm::normalize::cap_document_blocks(&mut self.state.messages, 500_000);
224            crate::llm::normalize::merge_consecutive_user_messages(&mut self.state.messages);
225
226            debug!("Agent turn {}/{}", turn + 1, max_turns);
227
228            let mut model = self.state.config.api.model.clone();
229
230            // Step 1: Auto-compact if context is too large.
231            if compact::should_auto_compact(self.state.history(), &model, &compact_tracking) {
232                let token_count = tokens::estimate_context_tokens(self.state.history());
233                let threshold = compact::auto_compact_threshold(&model);
234                info!("Auto-compact triggered: {token_count} tokens >= {threshold} threshold");
235
236                // Microcompact first: clear stale tool results.
237                let freed = compact::microcompact(&mut self.state.messages, 5);
238                if freed > 0 {
239                    sink.on_compact(freed);
240                    info!("Microcompact freed ~{freed} tokens");
241                }
242
243                // Check if microcompact was enough.
244                let post_mc_tokens = tokens::estimate_context_tokens(self.state.history());
245                if post_mc_tokens >= threshold {
246                    // Full LLM-based compaction: summarize older messages.
247                    info!("Microcompact insufficient, attempting LLM compaction");
248                    match compact::compact_with_llm(
249                        &mut self.state.messages,
250                        &*self.llm,
251                        &model,
252                        self.cancel.clone(),
253                    )
254                    .await
255                    {
256                        Some(removed) => {
257                            info!("LLM compaction removed {removed} messages");
258                            compact_tracking.was_compacted = true;
259                            compact_tracking.consecutive_failures = 0;
260                        }
261                        None => {
262                            compact_tracking.consecutive_failures += 1;
263                            warn!(
264                                "LLM compaction failed (attempt {})",
265                                compact_tracking.consecutive_failures
266                            );
267                            // Fallback: context collapse (snip middle messages).
268                            let effective = compact::effective_context_window(&model);
269                            if let Some(collapse) =
270                                crate::services::context_collapse::collapse_to_budget(
271                                    self.state.history(),
272                                    effective,
273                                )
274                            {
275                                info!(
276                                    "Context collapse snipped {} messages, freed ~{} tokens",
277                                    collapse.snipped_count, collapse.tokens_freed
278                                );
279                                self.state.messages = collapse.api_messages;
280                                sink.on_compact(collapse.tokens_freed);
281                            } else {
282                                // Last resort: aggressive microcompact.
283                                let freed2 = compact::microcompact(&mut self.state.messages, 2);
284                                if freed2 > 0 {
285                                    sink.on_compact(freed2);
286                                }
287                            }
288                        }
289                    }
290                }
291            }
292
293            // Inject compaction reminder if compacted and feature enabled.
294            if compact_tracking.was_compacted && self.state.config.features.compaction_reminders {
295                let reminder = user_message(
296                    "<system-reminder>Context was automatically compacted. \
297                     Earlier messages were summarized. If you need details from \
298                     before compaction, ask the user or re-read the relevant files.</system-reminder>",
299                );
300                self.state.push_message(reminder);
301                compact_tracking.was_compacted = false; // Only remind once per compaction.
302            }
303
304            // Step 2: Check token warning state.
305            let warning = compact::token_warning_state(self.state.history(), &model);
306            if warning.is_blocking {
307                sink.on_warning("Context window nearly full. Consider starting a new session.");
308            } else if warning.is_above_warning {
309                sink.on_warning(&format!("Context {}% remaining", warning.percent_left));
310            }
311
312            // Step 3: Build and send the API request.
313            // Memoize: only rebuild system prompt when inputs change.
314            let prompt_hash = {
315                use std::hash::{Hash, Hasher};
316                let mut h = std::collections::hash_map::DefaultHasher::new();
317                self.state.config.api.model.hash(&mut h);
318                self.state.cwd.hash(&mut h);
319                self.state.config.mcp_servers.len().hash(&mut h);
320                self.tools.all().len().hash(&mut h);
321                h.finish()
322            };
323            let system_prompt = if let Some((cached_hash, ref cached)) = self.cached_system_prompt
324                && cached_hash == prompt_hash
325            {
326                cached.clone()
327            } else {
328                let prompt = build_system_prompt(&self.tools, &self.state);
329                self.cached_system_prompt = Some((prompt_hash, prompt.clone()));
330                prompt
331            };
332            // Use core schemas (deferred tools loaded on demand via ToolSearch).
333            let tool_schemas = self.tools.core_schemas();
334
335            // Escalate max_tokens after a max_output recovery (8k → 64k).
336            let base_tokens = self.state.config.api.max_output_tokens.unwrap_or(16384);
337            let effective_tokens = if max_output_recovery_count > 0 {
338                base_tokens.max(65536) // Escalate to at least 64k after first recovery
339            } else {
340                base_tokens
341            };
342
343            let request = ProviderRequest {
344                messages: self.state.history().to_vec(),
345                system_prompt: system_prompt.clone(),
346                tools: tool_schemas.clone(),
347                model: model.clone(),
348                max_tokens: effective_tokens,
349                temperature: None,
350                enable_caching: self.state.config.features.prompt_caching,
351                tool_choice: Default::default(),
352                metadata: None,
353                cancel: self.cancel.clone(),
354            };
355
356            let mut rx = match self.llm.stream(&request).await {
357                Ok(rx) => {
358                    retry_state.reset();
359                    rx
360                }
361                Err(e) => {
362                    let retryable = match &e {
363                        ProviderError::RateLimited { retry_after_ms } => {
364                            crate::llm::retry::RetryableError::RateLimited {
365                                retry_after: *retry_after_ms,
366                            }
367                        }
368                        ProviderError::Overloaded => crate::llm::retry::RetryableError::Overloaded,
369                        ProviderError::Network(_) => {
370                            crate::llm::retry::RetryableError::StreamInterrupted
371                        }
372                        other => crate::llm::retry::RetryableError::NonRetryable(other.to_string()),
373                    };
374
375                    match retry_state.next_action(&retryable, &retry_config) {
376                        crate::llm::retry::RetryAction::Retry { after } => {
377                            warn!("Retrying in {}ms", after.as_millis());
378                            tokio::time::sleep(after).await;
379                            continue;
380                        }
381                        crate::llm::retry::RetryAction::FallbackModel => {
382                            // Switch to a smaller/cheaper model for this turn.
383                            let fallback = get_fallback_model(&model);
384                            sink.on_warning(&format!("Falling back from {model} to {fallback}"));
385                            model = fallback;
386                            continue;
387                        }
388                        crate::llm::retry::RetryAction::Abort(reason) => {
389                            // Unattended retry: in non-interactive mode, retry
390                            // capacity errors with longer backoff instead of aborting.
391                            if self.config.unattended
392                                && self.state.config.features.unattended_retry
393                                && matches!(
394                                    &e,
395                                    ProviderError::Overloaded | ProviderError::RateLimited { .. }
396                                )
397                            {
398                                warn!("Unattended retry: waiting 30s for capacity");
399                                tokio::time::sleep(std::time::Duration::from_secs(30)).await;
400                                continue;
401                            }
402                            // Before giving up, try reactive compact for size errors.
403                            // Two-stage recovery: context collapse first, then microcompact.
404                            if let ProviderError::RequestTooLarge(body) = &e {
405                                let gap = compact::parse_prompt_too_long_gap(body);
406
407                                // Stage 1: Context collapse (snip middle messages).
408                                let effective = compact::effective_context_window(&model);
409                                if let Some(collapse) =
410                                    crate::services::context_collapse::collapse_to_budget(
411                                        self.state.history(),
412                                        effective,
413                                    )
414                                {
415                                    info!(
416                                        "Reactive collapse: snipped {} messages, freed ~{} tokens",
417                                        collapse.snipped_count, collapse.tokens_freed
418                                    );
419                                    self.state.messages = collapse.api_messages;
420                                    sink.on_compact(collapse.tokens_freed);
421                                    continue;
422                                }
423
424                                // Stage 2: Aggressive microcompact.
425                                let freed = compact::microcompact(&mut self.state.messages, 1);
426                                if freed > 0 {
427                                    sink.on_compact(freed);
428                                    info!(
429                                        "Reactive microcompact freed ~{freed} tokens (gap: {gap:?})"
430                                    );
431                                    continue;
432                                }
433                            }
434                            sink.on_error(&reason);
435                            self.state.is_query_active = false;
436                            return Err(crate::error::Error::Other(e.to_string()));
437                        }
438                    }
439                }
440            };
441
442            // Step 4: Stream response. Start executing read-only tools
443            // as their input completes (streaming tool execution).
444            let mut content_blocks = Vec::new();
445            let mut usage = Usage::default();
446            let mut stop_reason: Option<StopReason> = None;
447            let mut got_error = false;
448            let mut error_text = String::new();
449
450            // Streaming tool handles: tools kicked off during streaming.
451            let mut streaming_tool_handles: Vec<(
452                String,
453                String,
454                tokio::task::JoinHandle<crate::tools::ToolResult>,
455            )> = Vec::new();
456
457            let mut cancelled = false;
458            loop {
459                tokio::select! {
460                    event = rx.recv() => {
461                        match event {
462                            Some(StreamEvent::TextDelta(text)) => {
463                                sink.on_text(&text);
464                            }
465                            Some(StreamEvent::ContentBlockComplete(block)) => {
466                                if let ContentBlock::ToolUse {
467                                    ref id,
468                                    ref name,
469                                    ref input,
470                                } = block
471                                {
472                                    sink.on_tool_start(name, input);
473
474                                    // Start read-only tools immediately during streaming.
475                                    if let Some(tool) = self.tools.get(name)
476                                        && tool.is_read_only()
477                                        && tool.is_concurrency_safe()
478                                    {
479                                        let tool = tool.clone();
480                                        let input = input.clone();
481                                        let cwd = std::path::PathBuf::from(&self.state.cwd);
482                                        let cancel = self.cancel.clone();
483                                        let perm = self.permissions.clone();
484                                        let tool_id = id.clone();
485                                        let tool_name = name.clone();
486
487                                        let handle = tokio::spawn(async move {
488                                            match tool
489                                                .call(
490                                                    input,
491                                                    &ToolContext {
492                                                        cwd,
493                                                        cancel,
494                                                        permission_checker: perm.clone(),
495                                                        verbose: false,
496                                                        plan_mode: false,
497                                                        file_cache: None,
498                                                        denial_tracker: None,
499                                                        task_manager: None,
500                                                        session_allows: None,
501                                                        permission_prompter: None,
502                                                        sandbox: None,
503                                                    },
504                                                )
505                                                .await
506                                            {
507                                                Ok(r) => r,
508                                                Err(e) => crate::tools::ToolResult::error(e.to_string()),
509                                            }
510                                        });
511
512                                        streaming_tool_handles.push((tool_id, tool_name, handle));
513                                    }
514                                }
515                                if let ContentBlock::Thinking { ref thinking, .. } = block {
516                                    sink.on_thinking(thinking);
517                                }
518                                content_blocks.push(block);
519                            }
520                            Some(StreamEvent::Done {
521                                usage: u,
522                                stop_reason: sr,
523                            }) => {
524                                usage = u;
525                                stop_reason = sr;
526                                sink.on_usage(&usage);
527                            }
528                            Some(StreamEvent::Error(msg)) => {
529                                got_error = true;
530                                error_text = msg.clone();
531                                sink.on_error(&msg);
532                            }
533                            Some(_) => {}
534                            None => break,
535                        }
536                    }
537                    _ = self.cancel.cancelled() => {
538                        warn!("Turn cancelled by user");
539                        cancelled = true;
540                        // Abort any in-flight streaming tool handles.
541                        for (_, _, handle) in streaming_tool_handles.drain(..) {
542                            handle.abort();
543                        }
544                        break;
545                    }
546                }
547            }
548
549            if cancelled {
550                sink.on_warning("Cancelled");
551                self.state.is_query_active = false;
552                return Ok(());
553            }
554
555            // Step 5: Record the assistant message.
556            let assistant_msg = Message::Assistant(AssistantMessage {
557                uuid: Uuid::new_v4(),
558                timestamp: chrono::Utc::now().to_rfc3339(),
559                content: content_blocks.clone(),
560                model: Some(model.clone()),
561                usage: Some(usage.clone()),
562                stop_reason: stop_reason.clone(),
563                request_id: None,
564            });
565            self.state.push_message(assistant_msg);
566            self.state.record_usage(&usage, &model);
567
568            // Token budget tracking per turn.
569            if self.state.config.features.token_budget && usage.total() > 0 {
570                let turn_total = usage.input_tokens + usage.output_tokens;
571                if turn_total > 100_000 {
572                    sink.on_warning(&format!(
573                        "High token usage this turn: {} tokens ({}in + {}out)",
574                        turn_total, usage.input_tokens, usage.output_tokens
575                    ));
576                }
577            }
578
579            // Record cache and telemetry.
580            let _cache_event = self.cache_tracker.record(&usage);
581            {
582                let mut span = crate::services::telemetry::api_call_span(
583                    &model,
584                    turn + 1,
585                    &self.state.session_id,
586                );
587                crate::services::telemetry::record_usage(&mut span, &usage);
588                span.finish();
589                tracing::debug!(
590                    "API call: {}ms, {}in/{}out tokens",
591                    span.duration_ms().unwrap_or(0),
592                    usage.input_tokens,
593                    usage.output_tokens,
594                );
595            }
596
597            // Step 6: Handle stream errors.
598            if got_error {
599                // Check if it's a prompt-too-long error in the stream.
600                if error_text.contains("prompt is too long")
601                    || error_text.contains("Prompt is too long")
602                {
603                    let freed = compact::microcompact(&mut self.state.messages, 1);
604                    if freed > 0 {
605                        sink.on_compact(freed);
606                        continue;
607                    }
608                }
609
610                // Check for max-output-tokens hit (partial response).
611                if content_blocks
612                    .iter()
613                    .any(|b| matches!(b, ContentBlock::Text { .. }))
614                    && error_text.contains("max_tokens")
615                    && max_output_recovery_count < MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
616                {
617                    max_output_recovery_count += 1;
618                    info!(
619                        "Max output tokens recovery attempt {}/{}",
620                        max_output_recovery_count, MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
621                    );
622                    let recovery_msg = compact::max_output_recovery_message();
623                    self.state.push_message(recovery_msg);
624                    continue;
625                }
626            }
627
628            // Step 6b: Handle max_tokens stop reason (escalate and continue).
629            if matches!(stop_reason, Some(StopReason::MaxTokens))
630                && !got_error
631                && content_blocks
632                    .iter()
633                    .any(|b| matches!(b, ContentBlock::Text { .. }))
634                && max_output_recovery_count < MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
635            {
636                max_output_recovery_count += 1;
637                info!(
638                    "Max tokens stop reason — recovery attempt {}/{}",
639                    max_output_recovery_count, MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
640                );
641                let recovery_msg = compact::max_output_recovery_message();
642                self.state.push_message(recovery_msg);
643                continue;
644            }
645
646            // Step 7: Extract tool calls from the response.
647            let tool_calls = extract_tool_calls(&content_blocks);
648
649            if tool_calls.is_empty() {
650                // No tools requested — turn is complete.
651                info!("Turn complete (no tool calls)");
652                sink.on_turn_complete(turn + 1);
653                self.state.is_query_active = false;
654
655                // Fire background memory extraction (fire-and-forget).
656                // Only runs if feature enabled and memory directory exists.
657                if self.state.config.features.extract_memories
658                    && crate::memory::ensure_memory_dir().is_some()
659                {
660                    let extraction_messages = self.state.messages.clone();
661                    let extraction_state = self.extraction_state.clone();
662                    let extraction_llm = self.llm.clone();
663                    let extraction_model = model.clone();
664                    tokio::spawn(async move {
665                        crate::memory::extraction::extract_memories_background(
666                            extraction_messages,
667                            extraction_state,
668                            extraction_llm,
669                            extraction_model,
670                        )
671                        .await;
672                    });
673                }
674
675                return Ok(());
676            }
677
678            // Step 8: Execute tool calls with pre/post hooks.
679            info!("Executing {} tool call(s)", tool_calls.len());
680            let cwd = PathBuf::from(&self.state.cwd);
681            let tool_ctx = ToolContext {
682                cwd: cwd.clone(),
683                cancel: self.cancel.clone(),
684                permission_checker: self.permissions.clone(),
685                verbose: self.config.verbose,
686                plan_mode: self.state.plan_mode,
687                file_cache: Some(self.file_cache.clone()),
688                denial_tracker: Some(self.denial_tracker.clone()),
689                task_manager: Some(self.state.task_manager.clone()),
690                session_allows: Some(self.session_allows.clone()),
691                permission_prompter: self.permission_prompter.clone(),
692                sandbox: Some(std::sync::Arc::new(
693                    crate::sandbox::SandboxExecutor::from_session_config(&self.state.config, &cwd),
694                )),
695            };
696
697            // Collect streaming tool results first.
698            let streaming_ids: std::collections::HashSet<String> = streaming_tool_handles
699                .iter()
700                .map(|(id, _, _)| id.clone())
701                .collect();
702
703            let mut streaming_results = Vec::new();
704            for (id, name, handle) in streaming_tool_handles.drain(..) {
705                match handle.await {
706                    Ok(result) => streaming_results.push(crate::tools::executor::ToolCallResult {
707                        tool_use_id: id,
708                        tool_name: name,
709                        result,
710                    }),
711                    Err(e) => streaming_results.push(crate::tools::executor::ToolCallResult {
712                        tool_use_id: id,
713                        tool_name: name,
714                        result: crate::tools::ToolResult::error(format!("Task failed: {e}")),
715                    }),
716                }
717            }
718
719            // Fire pre-tool-use hooks.
720            for call in &tool_calls {
721                self.hooks
722                    .run_hooks(&HookEvent::PreToolUse, Some(&call.name), &call.input)
723                    .await;
724            }
725
726            // Execute remaining tools (ones not started during streaming).
727            let remaining_calls: Vec<_> = tool_calls
728                .iter()
729                .filter(|c| !streaming_ids.contains(&c.id))
730                .cloned()
731                .collect();
732
733            let mut results = streaming_results;
734            if !remaining_calls.is_empty() {
735                let batch_results = execute_tool_calls(
736                    &remaining_calls,
737                    self.tools.all(),
738                    &tool_ctx,
739                    &self.permissions,
740                )
741                .await;
742                results.extend(batch_results);
743            }
744
745            // Step 9: Inject tool results + fire post-tool-use hooks.
746            for result in &results {
747                // Handle plan mode state transitions.
748                if !result.result.is_error {
749                    match result.tool_name.as_str() {
750                        "EnterPlanMode" => {
751                            self.state.plan_mode = true;
752                            info!("Plan mode enabled");
753                        }
754                        "ExitPlanMode" => {
755                            self.state.plan_mode = false;
756                            info!("Plan mode disabled");
757                        }
758                        _ => {}
759                    }
760                }
761
762                sink.on_tool_result(&result.tool_name, &result.result);
763
764                // Fire post-tool-use hooks.
765                self.hooks
766                    .run_hooks(
767                        &HookEvent::PostToolUse,
768                        Some(&result.tool_name),
769                        &serde_json::json!({
770                            "tool": result.tool_name,
771                            "is_error": result.result.is_error,
772                        }),
773                    )
774                    .await;
775
776                let msg = tool_result_message(
777                    &result.tool_use_id,
778                    &result.result.content,
779                    result.result.is_error,
780                );
781                self.state.push_message(msg);
782            }
783
784            // Continue the loop — the model will see the tool results.
785        }
786
787        warn!("Max turns ({max_turns}) reached");
788        sink.on_warning(&format!("Agent stopped after {max_turns} turns"));
789        self.state.is_query_active = false;
790        Ok(())
791    }
792
793    /// Cancel the current operation.
794    pub fn cancel(&self) {
795        self.cancel.cancel();
796    }
797
798    /// Get a cloneable cancel token for use in background tasks.
799    pub fn cancel_token(&self) -> tokio_util::sync::CancellationToken {
800        self.cancel.clone()
801    }
802}
803
/// Get a fallback model (smaller/cheaper) for retry on overload.
///
/// Both matching and substitution are ASCII case-insensitive. A mixed-case
/// id such as `Claude-Opus-4` therefore maps to a different model instead of
/// silently coming back unchanged.
///
/// - any id containing `opus`: every `opus` becomes `sonnet`.
/// - `gpt-5.4*` / `gpt-4.1*`, when not already `mini`/`nano`: `-mini` is
///   inserted right after the family token. Dated snapshots keep a valid
///   shape (`gpt-4.1-2025-04-14` → `gpt-4.1-mini-2025-04-14`).
/// - any id containing `large`: every `large` becomes `small`.
/// - anything else is assumed to already be a small model and is returned
///   as-is.
fn get_fallback_model(current: &str) -> String {
    // ASCII lowercasing keeps byte offsets identical to `current`, which the
    // index-based substitutions below rely on.
    let lower = current.to_ascii_lowercase();

    if lower.contains("opus") {
        // Opus → Sonnet
        return replace_ignore_ascii_case(current, &lower, "opus", "sonnet");
    }

    if !lower.contains("mini") && !lower.contains("nano") {
        let family_end = ["gpt-5.4", "gpt-4.1"]
            .iter()
            .copied()
            .find_map(|family| lower.find(family).map(|start| start + family.len()));
        if let Some(end) = family_end {
            // The small tier is named `<family>-mini[-<snapshot>]`, so the
            // suffix goes straight after the family token, not at the end.
            return format!("{}-mini{}", &current[..end], &current[end..]);
        }
    }

    if lower.contains("large") {
        return replace_ignore_ascii_case(current, &lower, "large", "small");
    }

    // Already a small model, keep it.
    current.to_string()
}

/// Replace every ASCII-case-insensitive occurrence of `needle` in `haystack`.
///
/// `lower` must be `haystack.to_ascii_lowercase()` and `needle` must be
/// lowercase ASCII. Matches are located in `lower`, and the same byte ranges
/// are spliced out of `haystack`. This is sound because ASCII lowercasing
/// never changes byte lengths, and ASCII bytes always sit on char boundaries.
fn replace_ignore_ascii_case(haystack: &str, lower: &str, needle: &str, replacement: &str) -> String {
    let mut out = String::with_capacity(haystack.len());
    let mut copied_up_to = 0;
    for (start, _) in lower.match_indices(needle) {
        out.push_str(&haystack[copied_up_to..start]);
        out.push_str(replacement);
        copied_up_to = start + needle.len();
    }
    out.push_str(&haystack[copied_up_to..]);
    out
}
822
823/// Build the system prompt from tool definitions, app state, and memory.
824pub fn build_system_prompt(tools: &ToolRegistry, state: &AppState) -> String {
825    let mut prompt = String::new();
826
827    prompt.push_str(
828        "You are an AI coding agent. You help users with software engineering tasks \
829         by reading, writing, and searching code. Use the tools available to you to \
830         accomplish tasks.\n\n",
831    );
832
833    // Environment context.
834    let shell = std::env::var("SHELL").unwrap_or_else(|_| "bash".to_string());
835    let is_git = std::path::Path::new(&state.cwd).join(".git").exists();
836    prompt.push_str(&format!(
837        "# Environment\n\
838         - Working directory: {}\n\
839         - Platform: {}\n\
840         - Shell: {shell}\n\
841         - Git repository: {}\n\n",
842        state.cwd,
843        std::env::consts::OS,
844        if is_git { "yes" } else { "no" },
845    ));
846
847    // Inject memory context (project + user + on-demand relevant).
848    let mut memory = crate::memory::MemoryContext::load(Some(std::path::Path::new(&state.cwd)));
849
850    // On-demand: surface relevant memories based on recent conversation.
851    let recent_text: String = state
852        .messages
853        .iter()
854        .rev()
855        .take(5)
856        .filter_map(|m| match m {
857            crate::llm::message::Message::User(u) => Some(
858                u.content
859                    .iter()
860                    .filter_map(|b| b.as_text())
861                    .collect::<Vec<_>>()
862                    .join(" "),
863            ),
864            _ => None,
865        })
866        .collect::<Vec<_>>()
867        .join(" ");
868
869    if !recent_text.is_empty() {
870        memory.load_relevant(&recent_text);
871    }
872
873    let memory_section = memory.to_system_prompt_section();
874    if !memory_section.is_empty() {
875        prompt.push_str(&memory_section);
876    }
877
878    // Tool documentation.
879    prompt.push_str("# Available Tools\n\n");
880    for tool in tools.all() {
881        if tool.is_enabled() {
882            prompt.push_str(&format!("## {}\n{}\n\n", tool.name(), tool.prompt()));
883        }
884    }
885
886    // Available skills.
887    let skills = crate::skills::SkillRegistry::load_all(Some(std::path::Path::new(&state.cwd)));
888    let invocable = skills.user_invocable();
889    if !invocable.is_empty() {
890        prompt.push_str("# Available Skills\n\n");
891        for skill in invocable {
892            let desc = skill.metadata.description.as_deref().unwrap_or("");
893            let when = skill.metadata.when_to_use.as_deref().unwrap_or("");
894            prompt.push_str(&format!("- `/{}`", skill.name));
895            if !desc.is_empty() {
896                prompt.push_str(&format!(": {desc}"));
897            }
898            if !when.is_empty() {
899                prompt.push_str(&format!(" (use when: {when})"));
900            }
901            prompt.push('\n');
902        }
903        prompt.push('\n');
904    }
905
906    // Guidelines and safety framework.
907    prompt.push_str(
908        "# Using tools\n\n\
909         Use dedicated tools instead of shell commands when available:\n\
910         - File search: Glob (not find or ls)\n\
911         - Content search: Grep (not grep or rg)\n\
912         - Read files: FileRead (not cat/head/tail)\n\
913         - Edit files: FileEdit (not sed/awk)\n\
914         - Write files: FileWrite (not echo/cat with redirect)\n\
915         - Reserve Bash for system commands and operations that require shell execution.\n\
916         - Break complex tasks into steps. Use multiple tool calls in parallel when independent.\n\
917         - Use the Agent tool for complex multi-step research or tasks that benefit from isolation.\n\n\
918         # Working with code\n\n\
919         - Read files before editing them. Understand existing code before suggesting changes.\n\
920         - Prefer editing existing files over creating new ones to avoid file bloat.\n\
921         - Only make changes that were requested. Don't add features, refactor, add comments, \
922           or make \"improvements\" beyond the ask.\n\
923         - Don't add error handling for scenarios that can't happen. Don't design for \
924           hypothetical future requirements.\n\
925         - When referencing code, include file_path:line_number.\n\
926         - Be careful not to introduce security vulnerabilities (command injection, XSS, SQL injection, \
927           OWASP top 10). If you notice insecure code you wrote, fix it immediately.\n\
928         - Don't add docstrings, comments, or type annotations to code you didn't change.\n\
929         - Three similar lines of code is better than a premature abstraction.\n\n\
930         # Git safety protocol\n\n\
931         - NEVER update the git config.\n\
932         - NEVER run destructive git commands (push --force, reset --hard, checkout ., restore ., \
933           clean -f, branch -D) unless the user explicitly requests them.\n\
934         - NEVER skip hooks (--no-verify, --no-gpg-sign) unless the user explicitly requests it.\n\
935         - NEVER force push to main/master. Warn the user if they request it.\n\
936         - Always create NEW commits rather than amending, unless the user explicitly requests amend. \
937           After hook failure, the commit did NOT happen — amend would modify the PREVIOUS commit.\n\
938         - When staging files, prefer adding specific files by name rather than git add -A or git add ., \
939           which can accidentally include sensitive files.\n\
940         - NEVER commit changes unless the user explicitly asks.\n\n\
941         # Committing changes\n\n\
942         When the user asks to commit:\n\
943         1. Run git status and git diff to see all changes.\n\
944         2. Run git log --oneline -5 to match the repository's commit message style.\n\
945         3. Draft a concise (1-2 sentence) commit message focusing on \"why\" not \"what\".\n\
946         4. Do not commit files that likely contain secrets (.env, credentials.json).\n\
947         5. Stage specific files, create the commit.\n\
948         6. If pre-commit hook fails, fix the issue and create a NEW commit.\n\
949         7. When creating commits, include a co-author attribution line at the end of the message.\n\n\
950         # Creating pull requests\n\n\
951         When the user asks to create a PR:\n\
952         1. Run git status, git diff, and git log to understand all changes on the branch.\n\
953         2. Analyze ALL commits (not just the latest) that will be in the PR.\n\
954         3. Draft a short title (under 70 chars) and detailed body with summary and test plan.\n\
955         4. Push to remote with -u flag if needed, then create PR using gh pr create.\n\
956         5. Return the PR URL when done.\n\n\
957         # Executing actions safely\n\n\
958         Consider the reversibility and blast radius of every action:\n\
959         - Freely take local, reversible actions (editing files, running tests).\n\
960         - For hard-to-reverse or shared-state actions, confirm with the user first:\n\
961           - Destructive: deleting files/branches, dropping tables, rm -rf, overwriting uncommitted changes.\n\
962           - Hard to reverse: force-pushing, git reset --hard, amending published commits.\n\
963           - Visible to others: pushing code, creating/commenting on PRs/issues, sending messages.\n\
964         - When you encounter an obstacle, do not use destructive actions as a shortcut. \
965           Identify root causes and fix underlying issues.\n\
966         - If you discover unexpected state (unfamiliar files, branches, config), investigate \
967           before deleting or overwriting — it may be the user's in-progress work.\n\n\
968         # Response style\n\n\
969         - Be concise. Lead with the answer or action, not the reasoning.\n\
970         - Skip filler, preamble, and unnecessary transitions.\n\
971         - Don't restate what the user said.\n\
972         - If you can say it in one sentence, don't use three.\n\
973         - Focus output on: decisions that need input, status updates, and errors that change the plan.\n\
974         - When referencing GitHub issues or PRs, use owner/repo#123 format.\n\
975         - Only use emojis if the user explicitly requests it.\n\n\
976         # Memory\n\n\
977         You can save information across sessions by writing memory files.\n\
978         - Save to: ~/.config/agent-code/memory/ (one .md file per topic)\n\
979         - Each file needs YAML frontmatter: name, description, type (user/feedback/project/reference)\n\
980         - After writing a file, update MEMORY.md with a one-line pointer\n\
981         - Memory types: user (role, preferences), feedback (corrections, confirmations), \
982           project (decisions, deadlines), reference (external resources)\n\
983         - Do NOT store: code patterns, git history, debugging solutions, anything derivable from code\n\
984         - Memory is a hint — always verify against current state before acting on it\n",
985    );
986
987    // Detailed tool usage examples and workflow patterns.
988    prompt.push_str(
989        "# Tool usage patterns\n\n\
990         Common patterns for effective tool use:\n\n\
991         **Read before edit**: Always read a file before editing it. This ensures you \
992         understand the current state and can make targeted changes.\n\
993         ```\n\
994         1. FileRead file_path → understand structure\n\
995         2. FileEdit old_string, new_string → targeted change\n\
996         ```\n\n\
997         **Search then act**: Use Glob to find files, Grep to find content, then read/edit.\n\
998         ```\n\
999         1. Glob **/*.rs → find Rust files\n\
1000         2. Grep pattern path → find specific code\n\
1001         3. FileRead → read the match\n\
1002         4. FileEdit → make the change\n\
1003         ```\n\n\
1004         **Parallel tool calls**: When you need to read multiple independent files or run \
1005         independent searches, make all the tool calls in one response. Don't serialize \
1006         independent operations.\n\n\
1007         **Test after change**: After editing code, run tests to verify the change works.\n\
1008         ```\n\
1009         1. FileEdit → make change\n\
1010         2. Bash cargo test / pytest / npm test → verify\n\
1011         3. If tests fail, read the error, fix, re-test\n\
1012         ```\n\n\
1013         # Error recovery\n\n\
1014         When something goes wrong:\n\
1015         - **Tool not found**: Use ToolSearch to find the right tool name.\n\
1016         - **Permission denied**: Explain why the action is needed, ask the user to approve.\n\
1017         - **File not found**: Use Glob to find the correct path. Check for typos.\n\
1018         - **Edit failed (not unique)**: Provide more surrounding context in old_string, \
1019           or use replace_all=true if renaming.\n\
1020         - **Command failed**: Read the full error message. Don't retry the same command. \
1021           Diagnose the root cause first.\n\
1022         - **Context too large**: The system will auto-compact. If you need specific \
1023           information from before compaction, re-read the relevant files.\n\
1024         - **Rate limited**: The system will auto-retry with backoff. Just wait.\n\n\
1025         # Common workflows\n\n\
1026         **Bug fix**: Read the failing test → read the source code it tests → \
1027         identify the bug → fix it → run the test → confirm it passes.\n\n\
1028         **New feature**: Read existing patterns in the codebase → create or edit files → \
1029         add tests → run tests → update docs if needed.\n\n\
1030         **Code review**: Read the diff → identify issues (bugs, security, style) → \
1031         report findings with file:line references.\n\n\
1032         **Refactor**: Search for all usages of the symbol → plan the changes → \
1033         edit each file → run tests to verify nothing broke.\n\n",
1034    );
1035
1036    // MCP server instructions (dynamic, per-server).
1037    if !state.config.mcp_servers.is_empty() {
1038        prompt.push_str("# MCP Servers\n\n");
1039        prompt.push_str(
1040            "Connected MCP servers provide additional tools. MCP tools are prefixed \
1041             with `mcp__{server}__{tool}`. Use them like any other tool.\n\n",
1042        );
1043        for (name, entry) in &state.config.mcp_servers {
1044            let transport = if entry.command.is_some() {
1045                "stdio"
1046            } else if entry.url.is_some() {
1047                "sse"
1048            } else {
1049                "unknown"
1050            };
1051            prompt.push_str(&format!("- **{name}** ({transport})\n"));
1052        }
1053        prompt.push('\n');
1054    }
1055
1056    // Deferred tools listing.
1057    let deferred = tools.deferred_names();
1058    if !deferred.is_empty() {
1059        prompt.push_str("# Deferred Tools\n\n");
1060        prompt.push_str(
1061            "These tools are available but not loaded by default. \
1062             Use ToolSearch to load them when needed:\n",
1063        );
1064        for name in &deferred {
1065            prompt.push_str(&format!("- {name}\n"));
1066        }
1067        prompt.push('\n');
1068    }
1069
1070    // Task management guidance.
1071    prompt.push_str(
1072        "# Task management\n\n\
1073         - Use TaskCreate to break complex work into trackable steps.\n\
1074         - Mark tasks as in_progress when starting, completed when done.\n\
1075         - Use the Agent tool to spawn subagents for parallel independent work.\n\
1076         - Use EnterPlanMode/ExitPlanMode for read-only exploration before making changes.\n\
1077         - Use EnterWorktree/ExitWorktree for isolated changes in git worktrees.\n\n\
1078         # Output formatting\n\n\
1079         - All text output is displayed to the user. Use GitHub-flavored markdown.\n\
1080         - Use fenced code blocks with language hints for code: ```rust, ```python, etc.\n\
1081         - Use inline `code` for file names, function names, and short code references.\n\
1082         - Use tables for structured comparisons.\n\
1083         - Use bullet lists for multiple items.\n\
1084         - Keep paragraphs short (2-3 sentences).\n\
1085         - Never output raw HTML or complex formatting — stick to standard markdown.\n",
1086    );
1087
1088    prompt
1089}
1090
1091#[cfg(test)]
1092mod tests {
1093    use super::*;
1094
1095    /// Verify that cancelling via the shared handle cancels the current
1096    /// turn's token (regression: the signal handler previously held a
1097    /// stale clone that couldn't cancel subsequent turns).
1098    #[test]
1099    fn cancel_shared_propagates_to_current_token() {
1100        let root = CancellationToken::new();
1101        let shared = Arc::new(std::sync::Mutex::new(root.clone()));
1102
1103        // Simulate turn reset: create a new token and update the shared handle.
1104        let turn1 = CancellationToken::new();
1105        *shared.lock().unwrap() = turn1.clone();
1106
1107        // Cancelling via the shared handle should cancel turn1.
1108        shared.lock().unwrap().cancel();
1109        assert!(turn1.is_cancelled());
1110
1111        // New turn: replace the token. The old cancellation shouldn't affect it.
1112        let turn2 = CancellationToken::new();
1113        *shared.lock().unwrap() = turn2.clone();
1114        assert!(!turn2.is_cancelled());
1115
1116        // Cancelling via shared should cancel turn2.
1117        shared.lock().unwrap().cancel();
1118        assert!(turn2.is_cancelled());
1119    }
1120
1121    /// Verify that the streaming loop breaks on cancellation by simulating
1122    /// the select pattern used in run_turn_with_sink.
1123    #[tokio::test]
1124    async fn stream_loop_responds_to_cancellation() {
1125        let cancel = CancellationToken::new();
1126        let (tx, mut rx) = tokio::sync::mpsc::channel::<StreamEvent>(10);
1127
1128        // Simulate a slow stream: send one event, then cancel before more arrive.
1129        tx.send(StreamEvent::TextDelta("hello".into()))
1130            .await
1131            .unwrap();
1132
1133        let cancel2 = cancel.clone();
1134        tokio::spawn(async move {
1135            // Small delay, then cancel.
1136            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1137            cancel2.cancel();
1138        });
1139
1140        let mut events_received = 0u32;
1141        let mut cancelled = false;
1142
1143        loop {
1144            tokio::select! {
1145                event = rx.recv() => {
1146                    match event {
1147                        Some(_) => events_received += 1,
1148                        None => break,
1149                    }
1150                }
1151                _ = cancel.cancelled() => {
1152                    cancelled = true;
1153                    break;
1154                }
1155            }
1156        }
1157
1158        assert!(cancelled, "Loop should have been cancelled");
1159        assert_eq!(
1160            events_received, 1,
1161            "Should have received exactly one event before cancel"
1162        );
1163    }
1164
1165    // ------------------------------------------------------------------
1166    // End-to-end regression tests for #103.
1167    //
1168    // These tests build a real QueryEngine with a mock Provider and
1169    // exercise run_turn_with_sink directly, verifying that cancellation
1170    // actually interrupts the streaming loop (not just the select!
1171    // pattern in isolation).
1172    // ------------------------------------------------------------------
1173
1174    use crate::llm::provider::{Provider, ProviderError, ProviderRequest};
1175
1176    /// A provider whose stream yields one TextDelta and then hangs forever.
1177    /// Simulates the real bug: a slow LLM response the user wants to interrupt.
1178    struct HangingProvider;
1179
1180    #[async_trait::async_trait]
1181    impl Provider for HangingProvider {
1182        fn name(&self) -> &str {
1183            "hanging-mock"
1184        }
1185
1186        async fn stream(
1187            &self,
1188            _request: &ProviderRequest,
1189        ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
1190            let (tx, rx) = tokio::sync::mpsc::channel(4);
1191            tokio::spawn(async move {
1192                let _ = tx.send(StreamEvent::TextDelta("thinking...".into())).await;
1193                // Hang forever without closing the channel or sending Done.
1194                let _tx_holder = tx;
1195                std::future::pending::<()>().await;
1196            });
1197            Ok(rx)
1198        }
1199    }
1200
1201    /// A provider whose spawned streaming task honors `ProviderRequest::cancel`.
1202    /// When it exits cleanly via cancellation it flips `exit_flag` to true.
1203    /// Used to prove that the cancel token reaches the provider's own task
1204    /// (the anthropic/openai/azure SSE loops), not just the query engine's
1205    /// outer select.
1206    struct CancelAwareHangingProvider {
1207        exit_flag: Arc<std::sync::atomic::AtomicBool>,
1208    }
1209
1210    #[async_trait::async_trait]
1211    impl Provider for CancelAwareHangingProvider {
1212        fn name(&self) -> &str {
1213            "cancel-aware-mock"
1214        }
1215
1216        async fn stream(
1217            &self,
1218            request: &ProviderRequest,
1219        ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
1220            let (tx, rx) = tokio::sync::mpsc::channel(4);
1221            let cancel = request.cancel.clone();
1222            let exit_flag = self.exit_flag.clone();
1223            tokio::spawn(async move {
1224                let _ = tx.send(StreamEvent::TextDelta("thinking...".into())).await;
1225                // Mirror the real provider pattern: race the "next SSE chunk"
1226                // future against the cancel token. A real provider's
1227                // `byte_stream.next().await` would sit where `pending` does.
1228                tokio::select! {
1229                    biased;
1230                    _ = cancel.cancelled() => {
1231                        exit_flag.store(true, std::sync::atomic::Ordering::SeqCst);
1232                    }
1233                    _ = std::future::pending::<()>() => unreachable!(),
1234                }
1235            });
1236            Ok(rx)
1237        }
1238    }
1239
1240    /// A provider that completes a turn normally: emits text and a Done event.
1241    struct CompletingProvider;
1242
1243    #[async_trait::async_trait]
1244    impl Provider for CompletingProvider {
1245        fn name(&self) -> &str {
1246            "completing-mock"
1247        }
1248
1249        async fn stream(
1250            &self,
1251            _request: &ProviderRequest,
1252        ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
1253            let (tx, rx) = tokio::sync::mpsc::channel(8);
1254            tokio::spawn(async move {
1255                let _ = tx.send(StreamEvent::TextDelta("hello".into())).await;
1256                let _ = tx
1257                    .send(StreamEvent::ContentBlockComplete(ContentBlock::Text {
1258                        text: "hello".into(),
1259                    }))
1260                    .await;
1261                let _ = tx
1262                    .send(StreamEvent::Done {
1263                        usage: Usage::default(),
1264                        stop_reason: Some(StopReason::EndTurn),
1265                    })
1266                    .await;
1267                // Channel closes when tx drops.
1268            });
1269            Ok(rx)
1270        }
1271    }
1272
1273    fn build_engine(llm: Arc<dyn Provider>) -> QueryEngine {
1274        use crate::config::Config;
1275        use crate::permissions::PermissionChecker;
1276        use crate::state::AppState;
1277        use crate::tools::registry::ToolRegistry;
1278
1279        let config = Config::default();
1280        let permissions = PermissionChecker::from_config(&config.permissions);
1281        let state = AppState::new(config);
1282
1283        QueryEngine::new(
1284            llm,
1285            ToolRegistry::default_tools(),
1286            permissions,
1287            state,
1288            QueryEngineConfig {
1289                max_turns: Some(1),
1290                verbose: false,
1291                unattended: true,
1292            },
1293        )
1294    }
1295
1296    /// Schedule a cancellation after `delay_ms` via the shared handle
1297    /// (same path the signal handler uses).
1298    fn schedule_cancel(engine: &QueryEngine, delay_ms: u64) {
1299        let shared = engine.cancel_shared.clone();
1300        tokio::spawn(async move {
1301            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
1302            shared.lock().unwrap().cancel();
1303        });
1304    }
1305
1306    /// Builds a mock provider whose stream yields one TextDelta and then hangs.
1307    /// Verifies the turn returns promptly once cancel fires.
1308    #[tokio::test]
1309    async fn run_turn_with_sink_interrupts_on_cancel() {
1310        let mut engine = build_engine(Arc::new(HangingProvider));
1311        schedule_cancel(&engine, 100);
1312
1313        let result = tokio::time::timeout(
1314            std::time::Duration::from_secs(5),
1315            engine.run_turn_with_sink("test input", &NullSink),
1316        )
1317        .await;
1318
1319        assert!(
1320            result.is_ok(),
1321            "run_turn_with_sink should return promptly on cancel, not hang"
1322        );
1323        assert!(
1324            result.unwrap().is_ok(),
1325            "cancelled turn should return Ok(()), not an error"
1326        );
1327        assert!(
1328            !engine.state().is_query_active,
1329            "is_query_active should be reset after cancel"
1330        );
1331    }
1332
1333    /// Regression test for the original #103 bug: the signal handler held
1334    /// a stale clone of the cancellation token, so Ctrl+C only worked on
1335    /// the *first* turn. This test cancels turn 1, then runs turn 2 and
1336    /// verifies it is ALSO cancellable via the same shared handle.
1337    #[tokio::test]
1338    async fn cancel_works_across_multiple_turns() {
1339        let mut engine = build_engine(Arc::new(HangingProvider));
1340
1341        // Turn 1: cancel mid-stream.
1342        schedule_cancel(&engine, 80);
1343        let r1 = tokio::time::timeout(
1344            std::time::Duration::from_secs(5),
1345            engine.run_turn_with_sink("turn 1", &NullSink),
1346        )
1347        .await;
1348        assert!(r1.is_ok(), "turn 1 should cancel promptly");
1349        assert!(!engine.state().is_query_active);
1350
1351        // Turn 2: cancel again via the same shared handle.
1352        // With the pre-fix stale-token bug, the handle would be pointing
1353        // at turn 1's already-used token and turn 2 would hang forever.
1354        schedule_cancel(&engine, 80);
1355        let r2 = tokio::time::timeout(
1356            std::time::Duration::from_secs(5),
1357            engine.run_turn_with_sink("turn 2", &NullSink),
1358        )
1359        .await;
1360        assert!(
1361            r2.is_ok(),
1362            "turn 2 should also cancel promptly — regression would hang here"
1363        );
1364        assert!(!engine.state().is_query_active);
1365
1366        // Turn 3: one more for good measure.
1367        schedule_cancel(&engine, 80);
1368        let r3 = tokio::time::timeout(
1369            std::time::Duration::from_secs(5),
1370            engine.run_turn_with_sink("turn 3", &NullSink),
1371        )
1372        .await;
1373        assert!(r3.is_ok(), "turn 3 should still be cancellable");
1374        assert!(!engine.state().is_query_active);
1375    }
1376
1377    /// Verifies that a previously-cancelled token does not poison subsequent
1378    /// turns. A fresh run_turn_with_sink on the same engine should complete
1379    /// normally even after a prior cancel.
1380    #[tokio::test]
1381    async fn cancel_does_not_poison_next_turn() {
1382        // Turn 1: hangs and gets cancelled.
1383        let mut engine = build_engine(Arc::new(HangingProvider));
1384        schedule_cancel(&engine, 80);
1385        let _ = tokio::time::timeout(
1386            std::time::Duration::from_secs(5),
1387            engine.run_turn_with_sink("turn 1", &NullSink),
1388        )
1389        .await
1390        .expect("turn 1 should cancel");
1391
1392        // Swap the provider to one that completes normally by rebuilding
1393        // the engine (we can't swap llm on an existing engine, so this
1394        // simulates the isolated "fresh turn" behavior). The key property
1395        // being tested is that the per-turn cancel reset correctly
1396        // initializes a non-cancelled token.
1397        let mut engine2 = build_engine(Arc::new(CompletingProvider));
1398
1399        // Pre-cancel engine2 to simulate a leftover cancelled state, then
1400        // verify run_turn_with_sink still works because it resets the token.
1401        engine2.cancel_shared.lock().unwrap().cancel();
1402
1403        let result = tokio::time::timeout(
1404            std::time::Duration::from_secs(5),
1405            engine2.run_turn_with_sink("hello", &NullSink),
1406        )
1407        .await;
1408
1409        assert!(result.is_ok(), "completing turn should not hang");
1410        assert!(
1411            result.unwrap().is_ok(),
1412            "turn should succeed — the stale cancel flag must be cleared on turn start"
1413        );
1414        // Message history should contain the user + assistant messages.
1415        assert!(
1416            engine2.state().messages.len() >= 2,
1417            "normal turn should push both user and assistant messages"
1418        );
1419    }
1420
1421    /// Verifies that cancelling BEFORE any event arrives still interrupts
1422    /// the turn cleanly (edge case: cancellation races with the first recv).
1423    #[tokio::test]
1424    async fn cancel_before_first_event_interrupts_cleanly() {
1425        let mut engine = build_engine(Arc::new(HangingProvider));
1426        // Very short delay so cancel likely fires before or during the
1427        // first event. The test is tolerant of either ordering.
1428        schedule_cancel(&engine, 1);
1429
1430        let result = tokio::time::timeout(
1431            std::time::Duration::from_secs(5),
1432            engine.run_turn_with_sink("immediate", &NullSink),
1433        )
1434        .await;
1435
1436        assert!(result.is_ok(), "early cancel should not hang");
1437        assert!(result.unwrap().is_ok());
1438        assert!(!engine.state().is_query_active);
1439    }
1440
1441    /// Verifies the sink receives cancellation feedback via on_warning.
1442    #[tokio::test]
1443    async fn cancelled_turn_emits_warning_to_sink() {
1444        use std::sync::Mutex;
1445
1446        /// Captures sink events for assertion.
1447        struct CapturingSink {
1448            warnings: Mutex<Vec<String>>,
1449        }
1450
1451        impl StreamSink for CapturingSink {
1452            fn on_text(&self, _: &str) {}
1453            fn on_tool_start(&self, _: &str, _: &serde_json::Value) {}
1454            fn on_tool_result(&self, _: &str, _: &crate::tools::ToolResult) {}
1455            fn on_error(&self, _: &str) {}
1456            fn on_warning(&self, msg: &str) {
1457                self.warnings.lock().unwrap().push(msg.to_string());
1458            }
1459        }
1460
1461        let sink = CapturingSink {
1462            warnings: Mutex::new(Vec::new()),
1463        };
1464
1465        let mut engine = build_engine(Arc::new(HangingProvider));
1466        schedule_cancel(&engine, 100);
1467
1468        let _ = tokio::time::timeout(
1469            std::time::Duration::from_secs(5),
1470            engine.run_turn_with_sink("test", &sink),
1471        )
1472        .await
1473        .expect("should not hang");
1474
1475        let warnings = sink.warnings.lock().unwrap();
1476        assert!(
1477            warnings.iter().any(|w| w.contains("Cancelled")),
1478            "expected 'Cancelled' warning in sink, got: {:?}",
1479            *warnings
1480        );
1481    }
1482
1483    /// Regression test for the #125-followup: the cancellation token must
1484    /// reach the *provider's own spawned streaming task*, not just the query
1485    /// engine's outer select loop. This is what was broken — the existing
1486    /// `HangingProvider` test above passed even while Escape-interrupt was
1487    /// completely dead in production, because that provider ignores the
1488    /// token and the query loop exits cleanly on its own. This test fails
1489    /// if `ProviderRequest::cancel` is dropped on the floor anywhere between
1490    /// `query::mod.rs` and the provider's spawn.
1491    #[tokio::test]
1492    async fn provider_stream_task_observes_cancellation() {
1493        use std::sync::atomic::{AtomicBool, Ordering};
1494
1495        let exit_flag = Arc::new(AtomicBool::new(false));
1496        let provider = Arc::new(CancelAwareHangingProvider {
1497            exit_flag: exit_flag.clone(),
1498        });
1499        let mut engine = build_engine(provider);
1500        schedule_cancel(&engine, 50);
1501
1502        let result = tokio::time::timeout(
1503            std::time::Duration::from_secs(2),
1504            engine.run_turn_with_sink("test input", &NullSink),
1505        )
1506        .await;
1507        assert!(result.is_ok(), "engine should exit promptly on cancel");
1508
1509        // Give the provider's spawned task a moment to observe the cancel
1510        // after the engine returns. In release builds this is effectively
1511        // instantaneous; the sleep just makes the test tolerant of scheduler
1512        // variance on slow CI runners.
1513        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1514
1515        assert!(
1516            exit_flag.load(Ordering::SeqCst),
1517            "provider's streaming task should have observed cancel via \
1518             ProviderRequest::cancel and exited; if this flag is false, \
1519             the token is being dropped somewhere in query::mod.rs or the \
1520             provider is ignoring it"
1521        );
1522    }
1523}