Skip to main content

agent_code_lib/query/
mod.rs

1//! Query engine: the core agent loop.
2//!
3//! Implements the agentic cycle:
4//!
5//! 1. Auto-compact if context nears the window limit
6//! 2. Microcompact stale tool results
7//! 3. Call LLM with streaming
8//! 4. Accumulate response content blocks
9//! 5. Handle errors (prompt-too-long, rate limits, max-output-tokens)
10//! 6. Extract tool_use blocks
11//! 7. Execute tools (concurrent/serial batching)
12//! 8. Inject tool results into history
13//! 9. Repeat from step 1 until no tool_use or max turns
14
15pub mod source;
16
17use std::path::PathBuf;
18use std::sync::Arc;
19
20use tokio_util::sync::CancellationToken;
21use tracing::{debug, info, warn};
22use uuid::Uuid;
23
24use crate::hooks::{HookEvent, HookRegistry};
25use crate::llm::message::*;
26use crate::llm::provider::{Provider, ProviderError, ProviderRequest};
27use crate::llm::stream::StreamEvent;
28use crate::permissions::PermissionChecker;
29use crate::services::compact::{self, CompactTracking, MAX_OUTPUT_TOKENS_RECOVERY_LIMIT};
30use crate::services::tokens;
31use crate::state::AppState;
32use crate::tools::ToolContext;
33use crate::tools::executor::{execute_tool_calls, extract_tool_calls};
34use crate::tools::registry::ToolRegistry;
35
/// Configuration for the query engine.
pub struct QueryEngineConfig {
    /// Maximum number of agent-loop turns per user input; `None` falls back
    /// to the engine default of 50.
    pub max_turns: Option<usize>,
    /// Enable verbose output (forwarded to tools via `ToolContext::verbose`).
    pub verbose: bool,
    /// Whether this is a non-interactive (one-shot) session.
    pub unattended: bool,
}
43
/// The query engine orchestrates the agent loop.
///
/// Central coordinator that drives the LLM → tools → LLM cycle.
/// Manages conversation history, context compaction, tool execution,
/// error recovery, and hook dispatch. Create via [`QueryEngine::new`].
pub struct QueryEngine {
    /// LLM provider used for streaming completions and LLM-based compaction.
    llm: Arc<dyn Provider>,
    /// Registry of available tools (lookup by name, schema generation).
    tools: ToolRegistry,
    /// Shared file cache handed to tool executions.
    file_cache: Arc<tokio::sync::Mutex<crate::services::file_cache::FileCache>>,
    /// Permission checker shared with every tool invocation.
    permissions: Arc<PermissionChecker>,
    /// Session state: message history, config, usage totals, turn counters.
    state: AppState,
    config: QueryEngineConfig,
    /// Shared handle so the signal handler always cancels the current token.
    cancel_shared: Arc<std::sync::Mutex<CancellationToken>>,
    /// Per-turn cancellation token (cloned from cancel_shared at turn start).
    cancel: CancellationToken,
    /// Hook registry populated via [`QueryEngine::load_hooks`].
    hooks: HookRegistry,
    /// Records prompt-cache statistics from per-call usage.
    cache_tracker: crate::services::cache_tracking::CacheTracker,
    /// Tracks permission denials; shared with tool executions.
    denial_tracker: Arc<tokio::sync::Mutex<crate::permissions::tracking::DenialTracker>>,
    /// State for background memory extraction spawned at turn completion.
    extraction_state: Arc<tokio::sync::Mutex<crate::memory::extraction::ExtractionState>>,
    /// Per-session allow set — presumably permission identifiers granted for
    /// the session; verify against the permission prompting flow.
    session_allows: Arc<tokio::sync::Mutex<std::collections::HashSet<String>>>,
    /// Optional UI-side prompter for interactive permission requests.
    permission_prompter: Option<Arc<dyn crate::tools::PermissionPrompter>>,
    /// Cached system prompt (rebuilt only when inputs change).
    cached_system_prompt: Option<(u64, String)>, // (hash, prompt)
}
69
/// Callback for streaming events to the UI.
///
/// Methods with empty default bodies are optional; implementors only need
/// to provide the four required callbacks (`on_text`, `on_tool_start`,
/// `on_tool_result`, `on_error`).
pub trait StreamSink: Send + Sync {
    /// Called for each streamed text fragment from the model.
    fn on_text(&self, text: &str);
    /// Called when the model emits a completed `tool_use` block.
    fn on_tool_start(&self, tool_name: &str, input: &serde_json::Value);
    /// Called when a tool invocation finishes with a result.
    fn on_tool_result(&self, tool_name: &str, result: &crate::tools::ToolResult);
    /// Called for streamed thinking/reasoning text. Optional.
    fn on_thinking(&self, _text: &str) {}
    /// Called at the start of each agent turn (1-indexed).
    fn on_turn_start(&self, _turn: usize) {}
    /// Called when a turn completes with no further tool calls. Optional.
    fn on_turn_complete(&self, _turn: usize) {}
    /// Called when the provider or stream reports an error.
    fn on_error(&self, error: &str);
    /// Called with token usage when the response stream finishes. Optional.
    fn on_usage(&self, _usage: &Usage) {}
    /// Called after compaction, with the estimated tokens freed. Optional.
    fn on_compact(&self, _freed_tokens: u64) {}
    /// Called for non-fatal warnings (budget, context pressure, fallbacks). Optional.
    fn on_warning(&self, _msg: &str) {}
}
84
85/// A no-op stream sink for non-interactive mode.
86pub struct NullSink;
87impl StreamSink for NullSink {
88    fn on_text(&self, _: &str) {}
89    fn on_tool_start(&self, _: &str, _: &serde_json::Value) {}
90    fn on_tool_result(&self, _: &str, _: &crate::tools::ToolResult) {}
91    fn on_error(&self, _: &str) {}
92}
93
94impl QueryEngine {
95    pub fn new(
96        llm: Arc<dyn Provider>,
97        tools: ToolRegistry,
98        permissions: PermissionChecker,
99        state: AppState,
100        config: QueryEngineConfig,
101    ) -> Self {
102        let cancel = CancellationToken::new();
103        let cancel_shared = Arc::new(std::sync::Mutex::new(cancel.clone()));
104        Self {
105            llm,
106            tools,
107            file_cache: Arc::new(tokio::sync::Mutex::new(
108                crate::services::file_cache::FileCache::new(),
109            )),
110            permissions: Arc::new(permissions),
111            state,
112            config,
113            cancel,
114            cancel_shared,
115            hooks: HookRegistry::new(),
116            cache_tracker: crate::services::cache_tracking::CacheTracker::new(),
117            denial_tracker: Arc::new(tokio::sync::Mutex::new(
118                crate::permissions::tracking::DenialTracker::new(100),
119            )),
120            extraction_state: Arc::new(tokio::sync::Mutex::new(
121                crate::memory::extraction::ExtractionState::new(),
122            )),
123            session_allows: Arc::new(tokio::sync::Mutex::new(std::collections::HashSet::new())),
124            permission_prompter: None,
125            cached_system_prompt: None,
126        }
127    }
128
129    /// Load hooks from configuration into the registry.
130    pub fn load_hooks(&mut self, hook_defs: &[crate::hooks::HookDefinition]) {
131        for def in hook_defs {
132            self.hooks.register(def.clone());
133        }
134        if !hook_defs.is_empty() {
135            tracing::info!("Loaded {} hooks from config", hook_defs.len());
136        }
137    }
138
    /// Get a reference to the app state (history, config, usage totals).
    pub fn state(&self) -> &AppState {
        &self.state
    }
143
    /// Get a mutable reference to the app state (history, config, usage totals).
    pub fn state_mut(&mut self) -> &mut AppState {
        &mut self.state
    }
148
    /// Install a Ctrl+C handler that triggers the cancellation token.
    /// Call this once at startup. Subsequent Ctrl+C signals during a
    /// turn will cancel the active operation instead of killing the process.
    ///
    /// The spawned task loops forever: each Ctrl+C re-reads the shared
    /// handle so it always acts on the token of the currently running
    /// turn (turns install a fresh token via `cancel_shared`).
    pub fn install_signal_handler(&self) {
        let shared = self.cancel_shared.clone();
        tokio::spawn(async move {
            // `pending` records that at least one Ctrl+C was ever delivered.
            // NOTE(review): it is never reset, so after the first signal the
            // hard-exit condition reduces to `token.is_cancelled()` alone —
            // confirm this cross-turn behavior is intended.
            let mut pending = false;
            loop {
                if tokio::signal::ctrl_c().await.is_ok() {
                    // Clone out of the mutex so the guard is dropped before
                    // any cancel/exit side effects.
                    let token = shared.lock().unwrap().clone();
                    if token.is_cancelled() && pending {
                        // Second Ctrl+C after cancel — hard exit.
                        // 130 = 128 + SIGINT, the conventional interrupt status.
                        std::process::exit(130);
                    }
                    token.cancel();
                    pending = true;
                }
            }
        });
    }
169
    /// Run a single turn: process user input through the full agent loop.
    ///
    /// Convenience wrapper around [`Self::run_turn_with_sink`] using a
    /// [`NullSink`], i.e. no UI streaming.
    pub async fn run_turn(&mut self, user_input: &str) -> crate::error::Result<()> {
        self.run_turn_with_sink(user_input, &NullSink).await
    }
174
175    /// Run a turn with a stream sink for real-time UI updates.
176    pub async fn run_turn_with_sink(
177        &mut self,
178        user_input: &str,
179        sink: &dyn StreamSink,
180    ) -> crate::error::Result<()> {
181        // Reset cancellation token for this turn. The shared handle is
182        // updated so the signal handler always cancels the current token.
183        self.cancel = CancellationToken::new();
184        *self.cancel_shared.lock().unwrap() = self.cancel.clone();
185
186        // Add the user message to history.
187        let user_msg = user_message(user_input);
188        self.state.push_message(user_msg);
189
190        let max_turns = self.config.max_turns.unwrap_or(50);
191        let mut compact_tracking = CompactTracking::default();
192        let mut retry_state = crate::llm::retry::RetryState::default();
193        let retry_config = crate::llm::retry::RetryConfig::default();
194        let mut max_output_recovery_count = 0u32;
195
196        // Agent loop: budget check → normalize → compact → call LLM → execute tools → repeat.
197        for turn in 0..max_turns {
198            self.state.turn_count = turn + 1;
199            self.state.is_query_active = true;
200            sink.on_turn_start(turn + 1);
201
202            // Budget check before each turn.
203            let budget_config = crate::services::budget::BudgetConfig::default();
204            match crate::services::budget::check_budget(
205                self.state.total_cost_usd,
206                self.state.total_usage.total(),
207                &budget_config,
208            ) {
209                crate::services::budget::BudgetDecision::Stop { message } => {
210                    sink.on_warning(&message);
211                    self.state.is_query_active = false;
212                    return Ok(());
213                }
214                crate::services::budget::BudgetDecision::ContinueWithWarning {
215                    message, ..
216                } => {
217                    sink.on_warning(&message);
218                }
219                crate::services::budget::BudgetDecision::Continue => {}
220            }
221
222            // Normalize messages for API compatibility.
223            crate::llm::normalize::ensure_tool_result_pairing(&mut self.state.messages);
224            crate::llm::normalize::strip_empty_blocks(&mut self.state.messages);
225            crate::llm::normalize::remove_empty_messages(&mut self.state.messages);
226            crate::llm::normalize::cap_document_blocks(&mut self.state.messages, 500_000);
227            crate::llm::normalize::merge_consecutive_user_messages(&mut self.state.messages);
228
229            debug!("Agent turn {}/{}", turn + 1, max_turns);
230
231            let mut model = self.state.config.api.model.clone();
232
233            // Step 1: Auto-compact if context is too large.
234            if compact::should_auto_compact(self.state.history(), &model, &compact_tracking) {
235                let token_count = tokens::estimate_context_tokens(self.state.history());
236                let threshold = compact::auto_compact_threshold(&model);
237                info!("Auto-compact triggered: {token_count} tokens >= {threshold} threshold");
238
239                // Microcompact first: clear stale tool results.
240                let freed = compact::microcompact(&mut self.state.messages, 5);
241                if freed > 0 {
242                    sink.on_compact(freed);
243                    info!("Microcompact freed ~{freed} tokens");
244                }
245
246                // Check if microcompact was enough.
247                let post_mc_tokens = tokens::estimate_context_tokens(self.state.history());
248                if post_mc_tokens >= threshold {
249                    // Full LLM-based compaction: summarize older messages.
250                    info!("Microcompact insufficient, attempting LLM compaction");
251                    match compact::compact_with_llm(
252                        &mut self.state.messages,
253                        &*self.llm,
254                        &model,
255                        self.cancel.clone(),
256                    )
257                    .await
258                    {
259                        Some(removed) => {
260                            info!("LLM compaction removed {removed} messages");
261                            compact_tracking.was_compacted = true;
262                            compact_tracking.consecutive_failures = 0;
263                        }
264                        None => {
265                            compact_tracking.consecutive_failures += 1;
266                            warn!(
267                                "LLM compaction failed (attempt {})",
268                                compact_tracking.consecutive_failures
269                            );
270                            // Fallback: context collapse (snip middle messages).
271                            let effective = compact::effective_context_window(&model);
272                            if let Some(collapse) =
273                                crate::services::context_collapse::collapse_to_budget(
274                                    self.state.history(),
275                                    effective,
276                                )
277                            {
278                                info!(
279                                    "Context collapse snipped {} messages, freed ~{} tokens",
280                                    collapse.snipped_count, collapse.tokens_freed
281                                );
282                                self.state.messages = collapse.api_messages;
283                                sink.on_compact(collapse.tokens_freed);
284                            } else {
285                                // Last resort: aggressive microcompact.
286                                let freed2 = compact::microcompact(&mut self.state.messages, 2);
287                                if freed2 > 0 {
288                                    sink.on_compact(freed2);
289                                }
290                            }
291                        }
292                    }
293                }
294            }
295
296            // Inject compaction reminder if compacted and feature enabled.
297            if compact_tracking.was_compacted && self.state.config.features.compaction_reminders {
298                let reminder = user_message(
299                    "<system-reminder>Context was automatically compacted. \
300                     Earlier messages were summarized. If you need details from \
301                     before compaction, ask the user or re-read the relevant files.</system-reminder>",
302                );
303                self.state.push_message(reminder);
304                compact_tracking.was_compacted = false; // Only remind once per compaction.
305            }
306
307            // Step 2: Check token warning state.
308            let warning = compact::token_warning_state(self.state.history(), &model);
309            if warning.is_blocking {
310                sink.on_warning("Context window nearly full. Consider starting a new session.");
311            } else if warning.is_above_warning {
312                sink.on_warning(&format!("Context {}% remaining", warning.percent_left));
313            }
314
315            // Step 3: Build and send the API request.
316            // Memoize: only rebuild system prompt when inputs change.
317            let prompt_hash = {
318                use std::hash::{Hash, Hasher};
319                let mut h = std::collections::hash_map::DefaultHasher::new();
320                self.state.config.api.model.hash(&mut h);
321                self.state.cwd.hash(&mut h);
322                self.state.config.mcp_servers.len().hash(&mut h);
323                self.tools.all().len().hash(&mut h);
324                h.finish()
325            };
326            let system_prompt = if let Some((cached_hash, ref cached)) = self.cached_system_prompt
327                && cached_hash == prompt_hash
328            {
329                cached.clone()
330            } else {
331                let prompt = build_system_prompt(&self.tools, &self.state);
332                self.cached_system_prompt = Some((prompt_hash, prompt.clone()));
333                prompt
334            };
335            // Use core schemas (deferred tools loaded on demand via ToolSearch).
336            let tool_schemas = self.tools.core_schemas();
337
338            // Escalate max_tokens after a max_output recovery (8k → 64k).
339            let base_tokens = self.state.config.api.max_output_tokens.unwrap_or(16384);
340            let effective_tokens = if max_output_recovery_count > 0 {
341                base_tokens.max(65536) // Escalate to at least 64k after first recovery
342            } else {
343                base_tokens
344            };
345
346            let request = ProviderRequest {
347                messages: self.state.history().to_vec(),
348                system_prompt: system_prompt.clone(),
349                tools: tool_schemas.clone(),
350                model: model.clone(),
351                max_tokens: effective_tokens,
352                temperature: None,
353                enable_caching: self.state.config.features.prompt_caching,
354                tool_choice: Default::default(),
355                metadata: None,
356                cancel: self.cancel.clone(),
357            };
358
359            let mut rx = match self.llm.stream(&request).await {
360                Ok(rx) => {
361                    retry_state.reset();
362                    rx
363                }
364                Err(e) => {
365                    let retryable = match &e {
366                        ProviderError::RateLimited { retry_after_ms } => {
367                            crate::llm::retry::RetryableError::RateLimited {
368                                retry_after: *retry_after_ms,
369                            }
370                        }
371                        ProviderError::Overloaded => crate::llm::retry::RetryableError::Overloaded,
372                        ProviderError::Network(_) => {
373                            crate::llm::retry::RetryableError::StreamInterrupted
374                        }
375                        other => crate::llm::retry::RetryableError::NonRetryable(other.to_string()),
376                    };
377
378                    match retry_state.next_action(&retryable, &retry_config) {
379                        crate::llm::retry::RetryAction::Retry { after } => {
380                            warn!("Retrying in {}ms", after.as_millis());
381                            tokio::time::sleep(after).await;
382                            continue;
383                        }
384                        crate::llm::retry::RetryAction::FallbackModel => {
385                            // Switch to a smaller/cheaper model for this turn.
386                            let fallback = get_fallback_model(&model);
387                            sink.on_warning(&format!("Falling back from {model} to {fallback}"));
388                            model = fallback;
389                            continue;
390                        }
391                        crate::llm::retry::RetryAction::Abort(reason) => {
392                            // Unattended retry: in non-interactive mode, retry
393                            // capacity errors with longer backoff instead of aborting.
394                            if self.config.unattended
395                                && self.state.config.features.unattended_retry
396                                && matches!(
397                                    &e,
398                                    ProviderError::Overloaded | ProviderError::RateLimited { .. }
399                                )
400                            {
401                                warn!("Unattended retry: waiting 30s for capacity");
402                                tokio::time::sleep(std::time::Duration::from_secs(30)).await;
403                                continue;
404                            }
405                            // Before giving up, try reactive compact for size errors.
406                            // Two-stage recovery: context collapse first, then microcompact.
407                            if let ProviderError::RequestTooLarge(body) = &e {
408                                let gap = compact::parse_prompt_too_long_gap(body);
409
410                                // Stage 1: Context collapse (snip middle messages).
411                                let effective = compact::effective_context_window(&model);
412                                if let Some(collapse) =
413                                    crate::services::context_collapse::collapse_to_budget(
414                                        self.state.history(),
415                                        effective,
416                                    )
417                                {
418                                    info!(
419                                        "Reactive collapse: snipped {} messages, freed ~{} tokens",
420                                        collapse.snipped_count, collapse.tokens_freed
421                                    );
422                                    self.state.messages = collapse.api_messages;
423                                    sink.on_compact(collapse.tokens_freed);
424                                    continue;
425                                }
426
427                                // Stage 2: Aggressive microcompact.
428                                let freed = compact::microcompact(&mut self.state.messages, 1);
429                                if freed > 0 {
430                                    sink.on_compact(freed);
431                                    info!(
432                                        "Reactive microcompact freed ~{freed} tokens (gap: {gap:?})"
433                                    );
434                                    continue;
435                                }
436                            }
437                            sink.on_error(&reason);
438                            self.state.is_query_active = false;
439                            return Err(crate::error::Error::Other(e.to_string()));
440                        }
441                    }
442                }
443            };
444
445            // Step 4: Stream response. Start executing read-only tools
446            // as their input completes (streaming tool execution).
447            let mut content_blocks = Vec::new();
448            let mut usage = Usage::default();
449            let mut stop_reason: Option<StopReason> = None;
450            let mut got_error = false;
451            let mut error_text = String::new();
452
453            // Streaming tool handles: tools kicked off during streaming.
454            let mut streaming_tool_handles: Vec<(
455                String,
456                String,
457                tokio::task::JoinHandle<crate::tools::ToolResult>,
458            )> = Vec::new();
459
460            let mut cancelled = false;
461            loop {
462                tokio::select! {
463                    event = rx.recv() => {
464                        match event {
465                            Some(StreamEvent::TextDelta(text)) => {
466                                sink.on_text(&text);
467                            }
468                            Some(StreamEvent::ContentBlockComplete(block)) => {
469                                if let ContentBlock::ToolUse {
470                                    ref id,
471                                    ref name,
472                                    ref input,
473                                } = block
474                                {
475                                    sink.on_tool_start(name, input);
476
477                                    // Start read-only tools immediately during streaming.
478                                    if let Some(tool) = self.tools.get(name)
479                                        && tool.is_read_only()
480                                        && tool.is_concurrency_safe()
481                                    {
482                                        let tool = tool.clone();
483                                        let input = input.clone();
484                                        let cwd = std::path::PathBuf::from(&self.state.cwd);
485                                        let cancel = self.cancel.clone();
486                                        let perm = self.permissions.clone();
487                                        let tool_id = id.clone();
488                                        let tool_name = name.clone();
489
490                                        let handle = tokio::spawn(async move {
491                                            match tool
492                                                .call(
493                                                    input,
494                                                    &ToolContext {
495                                                        cwd,
496                                                        cancel,
497                                                        permission_checker: perm.clone(),
498                                                        verbose: false,
499                                                        plan_mode: false,
500                                                        file_cache: None,
501                                                        denial_tracker: None,
502                                                        task_manager: None,
503                                                        session_allows: None,
504                                                        permission_prompter: None,
505                                                        sandbox: None,
506                                                    },
507                                                )
508                                                .await
509                                            {
510                                                Ok(r) => r,
511                                                Err(e) => crate::tools::ToolResult::error(e.to_string()),
512                                            }
513                                        });
514
515                                        streaming_tool_handles.push((tool_id, tool_name, handle));
516                                    }
517                                }
518                                if let ContentBlock::Thinking { ref thinking, .. } = block {
519                                    sink.on_thinking(thinking);
520                                }
521                                content_blocks.push(block);
522                            }
523                            Some(StreamEvent::Done {
524                                usage: u,
525                                stop_reason: sr,
526                            }) => {
527                                usage = u;
528                                stop_reason = sr;
529                                sink.on_usage(&usage);
530                            }
531                            Some(StreamEvent::Error(msg)) => {
532                                got_error = true;
533                                error_text = msg.clone();
534                                sink.on_error(&msg);
535                            }
536                            Some(_) => {}
537                            None => break,
538                        }
539                    }
540                    _ = self.cancel.cancelled() => {
541                        warn!("Turn cancelled by user");
542                        cancelled = true;
543                        // Abort any in-flight streaming tool handles.
544                        for (_, _, handle) in streaming_tool_handles.drain(..) {
545                            handle.abort();
546                        }
547                        break;
548                    }
549                }
550            }
551
552            if cancelled {
553                sink.on_warning("Cancelled");
554                self.state.is_query_active = false;
555                return Ok(());
556            }
557
558            // Step 5: Record the assistant message.
559            let assistant_msg = Message::Assistant(AssistantMessage {
560                uuid: Uuid::new_v4(),
561                timestamp: chrono::Utc::now().to_rfc3339(),
562                content: content_blocks.clone(),
563                model: Some(model.clone()),
564                usage: Some(usage.clone()),
565                stop_reason: stop_reason.clone(),
566                request_id: None,
567            });
568            self.state.push_message(assistant_msg);
569            self.state.record_usage(&usage, &model);
570
571            // Token budget tracking per turn.
572            if self.state.config.features.token_budget && usage.total() > 0 {
573                let turn_total = usage.input_tokens + usage.output_tokens;
574                if turn_total > 100_000 {
575                    sink.on_warning(&format!(
576                        "High token usage this turn: {} tokens ({}in + {}out)",
577                        turn_total, usage.input_tokens, usage.output_tokens
578                    ));
579                }
580            }
581
582            // Record cache and telemetry.
583            let _cache_event = self.cache_tracker.record(&usage);
584            {
585                let mut span = crate::services::telemetry::api_call_span(
586                    &model,
587                    turn + 1,
588                    &self.state.session_id,
589                );
590                crate::services::telemetry::record_usage(&mut span, &usage);
591                span.finish();
592                tracing::debug!(
593                    "API call: {}ms, {}in/{}out tokens",
594                    span.duration_ms().unwrap_or(0),
595                    usage.input_tokens,
596                    usage.output_tokens,
597                );
598            }
599
600            // Step 6: Handle stream errors.
601            if got_error {
602                // Check if it's a prompt-too-long error in the stream.
603                if error_text.contains("prompt is too long")
604                    || error_text.contains("Prompt is too long")
605                {
606                    let freed = compact::microcompact(&mut self.state.messages, 1);
607                    if freed > 0 {
608                        sink.on_compact(freed);
609                        continue;
610                    }
611                }
612
613                // Check for max-output-tokens hit (partial response).
614                if content_blocks
615                    .iter()
616                    .any(|b| matches!(b, ContentBlock::Text { .. }))
617                    && error_text.contains("max_tokens")
618                    && max_output_recovery_count < MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
619                {
620                    max_output_recovery_count += 1;
621                    info!(
622                        "Max output tokens recovery attempt {}/{}",
623                        max_output_recovery_count, MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
624                    );
625                    let recovery_msg = compact::max_output_recovery_message();
626                    self.state.push_message(recovery_msg);
627                    continue;
628                }
629            }
630
631            // Step 6b: Handle max_tokens stop reason (escalate and continue).
632            if matches!(stop_reason, Some(StopReason::MaxTokens))
633                && !got_error
634                && content_blocks
635                    .iter()
636                    .any(|b| matches!(b, ContentBlock::Text { .. }))
637                && max_output_recovery_count < MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
638            {
639                max_output_recovery_count += 1;
640                info!(
641                    "Max tokens stop reason — recovery attempt {}/{}",
642                    max_output_recovery_count, MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
643                );
644                let recovery_msg = compact::max_output_recovery_message();
645                self.state.push_message(recovery_msg);
646                continue;
647            }
648
649            // Step 7: Extract tool calls from the response.
650            let tool_calls = extract_tool_calls(&content_blocks);
651
652            if tool_calls.is_empty() {
653                // No tools requested — turn is complete.
654                info!("Turn complete (no tool calls)");
655                sink.on_turn_complete(turn + 1);
656                self.state.is_query_active = false;
657
658                // Fire background memory extraction (fire-and-forget).
659                // Only runs if feature enabled and memory directory exists.
660                if self.state.config.features.extract_memories
661                    && crate::memory::ensure_memory_dir().is_some()
662                {
663                    let extraction_messages = self.state.messages.clone();
664                    let extraction_state = self.extraction_state.clone();
665                    let extraction_llm = self.llm.clone();
666                    let extraction_model = model.clone();
667                    tokio::spawn(async move {
668                        crate::memory::extraction::extract_memories_background(
669                            extraction_messages,
670                            extraction_state,
671                            extraction_llm,
672                            extraction_model,
673                        )
674                        .await;
675                    });
676                }
677
678                return Ok(());
679            }
680
681            // Step 8: Execute tool calls with pre/post hooks.
682            info!("Executing {} tool call(s)", tool_calls.len());
683            let cwd = PathBuf::from(&self.state.cwd);
684            let tool_ctx = ToolContext {
685                cwd: cwd.clone(),
686                cancel: self.cancel.clone(),
687                permission_checker: self.permissions.clone(),
688                verbose: self.config.verbose,
689                plan_mode: self.state.plan_mode,
690                file_cache: Some(self.file_cache.clone()),
691                denial_tracker: Some(self.denial_tracker.clone()),
692                task_manager: Some(self.state.task_manager.clone()),
693                session_allows: Some(self.session_allows.clone()),
694                permission_prompter: self.permission_prompter.clone(),
695                sandbox: Some(std::sync::Arc::new(
696                    crate::sandbox::SandboxExecutor::from_session_config(&self.state.config, &cwd),
697                )),
698            };
699
700            // Collect streaming tool results first.
701            let streaming_ids: std::collections::HashSet<String> = streaming_tool_handles
702                .iter()
703                .map(|(id, _, _)| id.clone())
704                .collect();
705
706            let mut streaming_results = Vec::new();
707            for (id, name, handle) in streaming_tool_handles.drain(..) {
708                match handle.await {
709                    Ok(result) => streaming_results.push(crate::tools::executor::ToolCallResult {
710                        tool_use_id: id,
711                        tool_name: name,
712                        result,
713                    }),
714                    Err(e) => streaming_results.push(crate::tools::executor::ToolCallResult {
715                        tool_use_id: id,
716                        tool_name: name,
717                        result: crate::tools::ToolResult::error(format!("Task failed: {e}")),
718                    }),
719                }
720            }
721
722            // Fire pre-tool-use hooks.
723            for call in &tool_calls {
724                self.hooks
725                    .run_hooks(&HookEvent::PreToolUse, Some(&call.name), &call.input)
726                    .await;
727            }
728
729            // Execute remaining tools (ones not started during streaming).
730            let remaining_calls: Vec<_> = tool_calls
731                .iter()
732                .filter(|c| !streaming_ids.contains(&c.id))
733                .cloned()
734                .collect();
735
736            let mut results = streaming_results;
737            if !remaining_calls.is_empty() {
738                let batch_results = execute_tool_calls(
739                    &remaining_calls,
740                    self.tools.all(),
741                    &tool_ctx,
742                    &self.permissions,
743                )
744                .await;
745                results.extend(batch_results);
746            }
747
748            // Step 9: Inject tool results + fire post-tool-use hooks.
749            for result in &results {
750                // Handle plan mode state transitions.
751                if !result.result.is_error {
752                    match result.tool_name.as_str() {
753                        "EnterPlanMode" => {
754                            self.state.plan_mode = true;
755                            info!("Plan mode enabled");
756                        }
757                        "ExitPlanMode" => {
758                            self.state.plan_mode = false;
759                            info!("Plan mode disabled");
760                        }
761                        _ => {}
762                    }
763                }
764
765                sink.on_tool_result(&result.tool_name, &result.result);
766
767                // Fire post-tool-use hooks.
768                self.hooks
769                    .run_hooks(
770                        &HookEvent::PostToolUse,
771                        Some(&result.tool_name),
772                        &serde_json::json!({
773                            "tool": result.tool_name,
774                            "is_error": result.result.is_error,
775                        }),
776                    )
777                    .await;
778
779                let msg = tool_result_message(
780                    &result.tool_use_id,
781                    &result.result.content,
782                    result.result.is_error,
783                );
784                self.state.push_message(msg);
785            }
786
787            // Continue the loop — the model will see the tool results.
788        }
789
790        warn!("Max turns ({max_turns}) reached");
791        sink.on_warning(&format!("Agent stopped after {max_turns} turns"));
792        self.state.is_query_active = false;
793        Ok(())
794    }
795
    /// Cancel the current operation.
    ///
    /// Signals the engine's cancellation token. The streaming loop and any
    /// provider task holding a clone of this token observe the signal and
    /// wind down; calling this when nothing is running is a harmless no-op.
    pub fn cancel(&self) {
        self.cancel.cancel();
    }
800
801    /// Get a cloneable cancel token for use in background tasks.
802    pub fn cancel_token(&self) -> tokio_util::sync::CancellationToken {
803        self.cancel.clone()
804    }
805}
806
/// Get a fallback model (smaller/cheaper) for retry on overload.
///
/// Matching is ASCII case-insensitive and the substring swap preserves the
/// surrounding casing of the model name. If no downgrade rule applies (the
/// model is already small), the input is returned unchanged.
fn get_fallback_model(current: &str) -> String {
    let lower = current.to_ascii_lowercase();
    if lower.contains("opus") {
        // Opus → Sonnet
        replace_all_ignore_ascii_case(current, "opus", "sonnet")
    } else if (lower.contains("gpt-5.4") || lower.contains("gpt-4.1"))
        && !lower.contains("mini")
        && !lower.contains("nano")
    {
        format!("{current}-mini")
    } else if lower.contains("large") {
        replace_all_ignore_ascii_case(current, "large", "small")
    } else {
        // Already a small model, keep it.
        current.to_string()
    }
}

/// Replace every occurrence of `needle` in `haystack`, matching ASCII
/// case-insensitively. `needle` must be lowercase ASCII.
///
/// Fixes the inconsistency where the branch *check* was done on the
/// lowercased name but the *replacement* was case-sensitive, so a name like
/// "Claude-Opus-4" matched the opus branch yet came back unmodified.
fn replace_all_ignore_ascii_case(haystack: &str, needle: &str, replacement: &str) -> String {
    // `to_ascii_lowercase` preserves byte offsets (one byte per ASCII char),
    // so indices found in `lower` are valid for slicing `haystack`.
    let lower = haystack.to_ascii_lowercase();
    let mut out = String::with_capacity(haystack.len());
    let mut start = 0;
    while let Some(pos) = lower[start..].find(needle) {
        let hit = start + pos;
        out.push_str(&haystack[start..hit]);
        out.push_str(replacement);
        start = hit + needle.len();
    }
    out.push_str(&haystack[start..]);
    out
}
825
/// Build the system prompt from tool definitions, app state, and memory.
///
/// Sections are appended in a fixed order: role preamble, environment info,
/// memory context, per-tool documentation, user-invocable skills, the
/// guidelines/safety framework, tool-usage patterns, MCP server listing,
/// deferred tools, and task-management/output-formatting rules. Returns the
/// fully assembled prompt string. Rebuilt per call, so sections that depend
/// on `state` (cwd, messages, config) reflect the current session.
pub fn build_system_prompt(tools: &ToolRegistry, state: &AppState) -> String {
    let mut prompt = String::new();

    // Role preamble.
    prompt.push_str(
        "You are an AI coding agent. You help users with software engineering tasks \
         by reading, writing, and searching code. Use the tools available to you to \
         accomplish tasks.\n\n",
    );

    // Environment context.
    let shell = std::env::var("SHELL").unwrap_or_else(|_| "bash".to_string());
    // NOTE: `.git` can be a file (worktrees/submodules); `exists()` covers both.
    let is_git = std::path::Path::new(&state.cwd).join(".git").exists();
    prompt.push_str(&format!(
        "# Environment\n\
         - Working directory: {}\n\
         - Platform: {}\n\
         - Shell: {shell}\n\
         - Git repository: {}\n\n",
        state.cwd,
        std::env::consts::OS,
        if is_git { "yes" } else { "no" },
    ));

    // Inject memory context (project + user + on-demand relevant).
    let mut memory = crate::memory::MemoryContext::load(Some(std::path::Path::new(&state.cwd)));

    // On-demand: surface relevant memories based on recent conversation.
    // Only the text blocks of the last five *user* messages are considered;
    // assistant and tool messages are skipped.
    let recent_text: String = state
        .messages
        .iter()
        .rev()
        .take(5)
        .filter_map(|m| match m {
            crate::llm::message::Message::User(u) => Some(
                u.content
                    .iter()
                    .filter_map(|b| b.as_text())
                    .collect::<Vec<_>>()
                    .join(" "),
            ),
            _ => None,
        })
        .collect::<Vec<_>>()
        .join(" ");

    if !recent_text.is_empty() {
        memory.load_relevant(&recent_text);
    }

    let memory_section = memory.to_system_prompt_section();
    if !memory_section.is_empty() {
        prompt.push_str(&memory_section);
    }

    // Tool documentation. Disabled tools are omitted entirely.
    prompt.push_str("# Available Tools\n\n");
    for tool in tools.all() {
        if tool.is_enabled() {
            prompt.push_str(&format!("## {}\n{}\n\n", tool.name(), tool.prompt()));
        }
    }

    // Available skills (only the user-invocable subset is advertised).
    let skills = crate::skills::SkillRegistry::load_all(Some(std::path::Path::new(&state.cwd)));
    let invocable = skills.user_invocable();
    if !invocable.is_empty() {
        prompt.push_str("# Available Skills\n\n");
        for skill in invocable {
            let desc = skill.metadata.description.as_deref().unwrap_or("");
            let when = skill.metadata.when_to_use.as_deref().unwrap_or("");
            prompt.push_str(&format!("- `/{}`", skill.name));
            if !desc.is_empty() {
                prompt.push_str(&format!(": {desc}"));
            }
            if !when.is_empty() {
                prompt.push_str(&format!(" (use when: {when})"));
            }
            prompt.push('\n');
        }
        prompt.push('\n');
    }

    // Guidelines and safety framework (static text).
    prompt.push_str(
        "# Using tools\n\n\
         Use dedicated tools instead of shell commands when available:\n\
         - File search: Glob (not find or ls)\n\
         - Content search: Grep (not grep or rg)\n\
         - Read files: FileRead (not cat/head/tail)\n\
         - Edit files: FileEdit (not sed/awk)\n\
         - Write files: FileWrite (not echo/cat with redirect)\n\
         - Reserve Bash for system commands and operations that require shell execution.\n\
         - Break complex tasks into steps. Use multiple tool calls in parallel when independent.\n\
         - Use the Agent tool for complex multi-step research or tasks that benefit from isolation.\n\n\
         # Working with code\n\n\
         - Read files before editing them. Understand existing code before suggesting changes.\n\
         - Prefer editing existing files over creating new ones to avoid file bloat.\n\
         - Only make changes that were requested. Don't add features, refactor, add comments, \
           or make \"improvements\" beyond the ask.\n\
         - Don't add error handling for scenarios that can't happen. Don't design for \
           hypothetical future requirements.\n\
         - When referencing code, include file_path:line_number.\n\
         - Be careful not to introduce security vulnerabilities (command injection, XSS, SQL injection, \
           OWASP top 10). If you notice insecure code you wrote, fix it immediately.\n\
         - Don't add docstrings, comments, or type annotations to code you didn't change.\n\
         - Three similar lines of code is better than a premature abstraction.\n\n\
         # Git safety protocol\n\n\
         - NEVER update the git config.\n\
         - NEVER run destructive git commands (push --force, reset --hard, checkout ., restore ., \
           clean -f, branch -D) unless the user explicitly requests them.\n\
         - NEVER skip hooks (--no-verify, --no-gpg-sign) unless the user explicitly requests it.\n\
         - NEVER force push to main/master. Warn the user if they request it.\n\
         - Always create NEW commits rather than amending, unless the user explicitly requests amend. \
           After hook failure, the commit did NOT happen — amend would modify the PREVIOUS commit.\n\
         - When staging files, prefer adding specific files by name rather than git add -A or git add ., \
           which can accidentally include sensitive files.\n\
         - NEVER commit changes unless the user explicitly asks.\n\n\
         # Committing changes\n\n\
         When the user asks to commit:\n\
         1. Run git status and git diff to see all changes.\n\
         2. Run git log --oneline -5 to match the repository's commit message style.\n\
         3. Draft a concise (1-2 sentence) commit message focusing on \"why\" not \"what\".\n\
         4. Do not commit files that likely contain secrets (.env, credentials.json).\n\
         5. Stage specific files, create the commit.\n\
         6. If pre-commit hook fails, fix the issue and create a NEW commit.\n\
         7. When creating commits, include a co-author attribution line at the end of the message.\n\n\
         # Creating pull requests\n\n\
         When the user asks to create a PR:\n\
         1. Run git status, git diff, and git log to understand all changes on the branch.\n\
         2. Analyze ALL commits (not just the latest) that will be in the PR.\n\
         3. Draft a short title (under 70 chars) and detailed body with summary and test plan.\n\
         4. Push to remote with -u flag if needed, then create PR using gh pr create.\n\
         5. Return the PR URL when done.\n\n\
         # Executing actions safely\n\n\
         Consider the reversibility and blast radius of every action:\n\
         - Freely take local, reversible actions (editing files, running tests).\n\
         - For hard-to-reverse or shared-state actions, confirm with the user first:\n\
           - Destructive: deleting files/branches, dropping tables, rm -rf, overwriting uncommitted changes.\n\
           - Hard to reverse: force-pushing, git reset --hard, amending published commits.\n\
           - Visible to others: pushing code, creating/commenting on PRs/issues, sending messages.\n\
         - When you encounter an obstacle, do not use destructive actions as a shortcut. \
           Identify root causes and fix underlying issues.\n\
         - If you discover unexpected state (unfamiliar files, branches, config), investigate \
           before deleting or overwriting — it may be the user's in-progress work.\n\n\
         # Response style\n\n\
         - Be concise. Lead with the answer or action, not the reasoning.\n\
         - Skip filler, preamble, and unnecessary transitions.\n\
         - Don't restate what the user said.\n\
         - If you can say it in one sentence, don't use three.\n\
         - Focus output on: decisions that need input, status updates, and errors that change the plan.\n\
         - When referencing GitHub issues or PRs, use owner/repo#123 format.\n\
         - Only use emojis if the user explicitly requests it.\n\n\
         # Memory\n\n\
         You can save information across sessions by writing memory files.\n\
         - Save to: ~/.config/agent-code/memory/ (one .md file per topic)\n\
         - Each file needs YAML frontmatter: name, description, type (user/feedback/project/reference)\n\
         - After writing a file, update MEMORY.md with a one-line pointer\n\
         - Memory types: user (role, preferences), feedback (corrections, confirmations), \
           project (decisions, deadlines), reference (external resources)\n\
         - Do NOT store: code patterns, git history, debugging solutions, anything derivable from code\n\
         - Memory is a hint — always verify against current state before acting on it\n",
    );

    // Detailed tool usage examples and workflow patterns (static text).
    prompt.push_str(
        "# Tool usage patterns\n\n\
         Common patterns for effective tool use:\n\n\
         **Read before edit**: Always read a file before editing it. This ensures you \
         understand the current state and can make targeted changes.\n\
         ```\n\
         1. FileRead file_path → understand structure\n\
         2. FileEdit old_string, new_string → targeted change\n\
         ```\n\n\
         **Search then act**: Use Glob to find files, Grep to find content, then read/edit.\n\
         ```\n\
         1. Glob **/*.rs → find Rust files\n\
         2. Grep pattern path → find specific code\n\
         3. FileRead → read the match\n\
         4. FileEdit → make the change\n\
         ```\n\n\
         **Parallel tool calls**: When you need to read multiple independent files or run \
         independent searches, make all the tool calls in one response. Don't serialize \
         independent operations.\n\n\
         **Test after change**: After editing code, run tests to verify the change works.\n\
         ```\n\
         1. FileEdit → make change\n\
         2. Bash cargo test / pytest / npm test → verify\n\
         3. If tests fail, read the error, fix, re-test\n\
         ```\n\n\
         # Error recovery\n\n\
         When something goes wrong:\n\
         - **Tool not found**: Use ToolSearch to find the right tool name.\n\
         - **Permission denied**: Explain why the action is needed, ask the user to approve.\n\
         - **File not found**: Use Glob to find the correct path. Check for typos.\n\
         - **Edit failed (not unique)**: Provide more surrounding context in old_string, \
           or use replace_all=true if renaming.\n\
         - **Command failed**: Read the full error message. Don't retry the same command. \
           Diagnose the root cause first.\n\
         - **Context too large**: The system will auto-compact. If you need specific \
           information from before compaction, re-read the relevant files.\n\
         - **Rate limited**: The system will auto-retry with backoff. Just wait.\n\n\
         # Common workflows\n\n\
         **Bug fix**: Read the failing test → read the source code it tests → \
         identify the bug → fix it → run the test → confirm it passes.\n\n\
         **New feature**: Read existing patterns in the codebase → create or edit files → \
         add tests → run tests → update docs if needed.\n\n\
         **Code review**: Read the diff → identify issues (bugs, security, style) → \
         report findings with file:line references.\n\n\
         **Refactor**: Search for all usages of the symbol → plan the changes → \
         edit each file → run tests to verify nothing broke.\n\n",
    );

    // MCP server instructions (dynamic, per-server).
    // NOTE: the `{server}`/`{tool}` braces below are literal text (push_str,
    // not format!) — they document the tool-name template for the model.
    if !state.config.mcp_servers.is_empty() {
        prompt.push_str("# MCP Servers\n\n");
        prompt.push_str(
            "Connected MCP servers provide additional tools. MCP tools are prefixed \
             with `mcp__{server}__{tool}`. Use them like any other tool.\n\n",
        );
        for (name, entry) in &state.config.mcp_servers {
            // Transport is inferred from which config field is set: a command
            // implies stdio; a URL implies SSE.
            let transport = if entry.command.is_some() {
                "stdio"
            } else if entry.url.is_some() {
                "sse"
            } else {
                "unknown"
            };
            prompt.push_str(&format!("- **{name}** ({transport})\n"));
        }
        prompt.push('\n');
    }

    // Deferred tools listing (present but not loaded; loadable via ToolSearch).
    let deferred = tools.deferred_names();
    if !deferred.is_empty() {
        prompt.push_str("# Deferred Tools\n\n");
        prompt.push_str(
            "These tools are available but not loaded by default. \
             Use ToolSearch to load them when needed:\n",
        );
        for name in &deferred {
            prompt.push_str(&format!("- {name}\n"));
        }
        prompt.push('\n');
    }

    // Task management guidance and output formatting rules (static text).
    prompt.push_str(
        "# Task management\n\n\
         - Use TaskCreate to break complex work into trackable steps.\n\
         - Mark tasks as in_progress when starting, completed when done.\n\
         - Use the Agent tool to spawn subagents for parallel independent work.\n\
         - Use EnterPlanMode/ExitPlanMode for read-only exploration before making changes.\n\
         - Use EnterWorktree/ExitWorktree for isolated changes in git worktrees.\n\n\
         # Output formatting\n\n\
         - All text output is displayed to the user. Use GitHub-flavored markdown.\n\
         - Use fenced code blocks with language hints for code: ```rust, ```python, etc.\n\
         - Use inline `code` for file names, function names, and short code references.\n\
         - Use tables for structured comparisons.\n\
         - Use bullet lists for multiple items.\n\
         - Keep paragraphs short (2-3 sentences).\n\
         - Never output raw HTML or complex formatting — stick to standard markdown.\n",
    );

    prompt
}
1093
1094#[cfg(test)]
1095mod tests {
1096    use super::*;
1097
1098    /// Verify that cancelling via the shared handle cancels the current
1099    /// turn's token (regression: the signal handler previously held a
1100    /// stale clone that couldn't cancel subsequent turns).
1101    #[test]
1102    fn cancel_shared_propagates_to_current_token() {
1103        let root = CancellationToken::new();
1104        let shared = Arc::new(std::sync::Mutex::new(root.clone()));
1105
1106        // Simulate turn reset: create a new token and update the shared handle.
1107        let turn1 = CancellationToken::new();
1108        *shared.lock().unwrap() = turn1.clone();
1109
1110        // Cancelling via the shared handle should cancel turn1.
1111        shared.lock().unwrap().cancel();
1112        assert!(turn1.is_cancelled());
1113
1114        // New turn: replace the token. The old cancellation shouldn't affect it.
1115        let turn2 = CancellationToken::new();
1116        *shared.lock().unwrap() = turn2.clone();
1117        assert!(!turn2.is_cancelled());
1118
1119        // Cancelling via shared should cancel turn2.
1120        shared.lock().unwrap().cancel();
1121        assert!(turn2.is_cancelled());
1122    }
1123
1124    /// Verify that the streaming loop breaks on cancellation by simulating
1125    /// the select pattern used in run_turn_with_sink.
1126    #[tokio::test]
1127    async fn stream_loop_responds_to_cancellation() {
1128        let cancel = CancellationToken::new();
1129        let (tx, mut rx) = tokio::sync::mpsc::channel::<StreamEvent>(10);
1130
1131        // Simulate a slow stream: send one event, then cancel before more arrive.
1132        tx.send(StreamEvent::TextDelta("hello".into()))
1133            .await
1134            .unwrap();
1135
1136        let cancel2 = cancel.clone();
1137        tokio::spawn(async move {
1138            // Small delay, then cancel.
1139            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1140            cancel2.cancel();
1141        });
1142
1143        let mut events_received = 0u32;
1144        let mut cancelled = false;
1145
1146        loop {
1147            tokio::select! {
1148                event = rx.recv() => {
1149                    match event {
1150                        Some(_) => events_received += 1,
1151                        None => break,
1152                    }
1153                }
1154                _ = cancel.cancelled() => {
1155                    cancelled = true;
1156                    break;
1157                }
1158            }
1159        }
1160
1161        assert!(cancelled, "Loop should have been cancelled");
1162        assert_eq!(
1163            events_received, 1,
1164            "Should have received exactly one event before cancel"
1165        );
1166    }
1167
1168    // ------------------------------------------------------------------
1169    // End-to-end regression tests for #103.
1170    //
1171    // These tests build a real QueryEngine with a mock Provider and
1172    // exercise run_turn_with_sink directly, verifying that cancellation
1173    // actually interrupts the streaming loop (not just the select!
1174    // pattern in isolation).
1175    // ------------------------------------------------------------------
1176
1177    use crate::llm::provider::{Provider, ProviderError, ProviderRequest};
1178
1179    /// A provider whose stream yields one TextDelta and then hangs forever.
1180    /// Simulates the real bug: a slow LLM response the user wants to interrupt.
1181    struct HangingProvider;
1182
1183    #[async_trait::async_trait]
1184    impl Provider for HangingProvider {
1185        fn name(&self) -> &str {
1186            "hanging-mock"
1187        }
1188
1189        async fn stream(
1190            &self,
1191            _request: &ProviderRequest,
1192        ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
1193            let (tx, rx) = tokio::sync::mpsc::channel(4);
1194            tokio::spawn(async move {
1195                let _ = tx.send(StreamEvent::TextDelta("thinking...".into())).await;
1196                // Hang forever without closing the channel or sending Done.
1197                let _tx_holder = tx;
1198                std::future::pending::<()>().await;
1199            });
1200            Ok(rx)
1201        }
1202    }
1203
    /// A provider whose spawned streaming task honors `ProviderRequest::cancel`.
    /// When it exits cleanly via cancellation it flips `exit_flag` to true.
    /// Used to prove that the cancel token reaches the provider's own task
    /// (the anthropic/openai/azure SSE loops), not just the query engine's
    /// outer select.
    struct CancelAwareHangingProvider {
        // Flipped to true by the spawned task when it observes cancellation.
        exit_flag: Arc<std::sync::atomic::AtomicBool>,
    }

    #[async_trait::async_trait]
    impl Provider for CancelAwareHangingProvider {
        fn name(&self) -> &str {
            "cancel-aware-mock"
        }

        async fn stream(
            &self,
            request: &ProviderRequest,
        ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
            let (tx, rx) = tokio::sync::mpsc::channel(4);
            let cancel = request.cancel.clone();
            let exit_flag = self.exit_flag.clone();
            tokio::spawn(async move {
                let _ = tx.send(StreamEvent::TextDelta("thinking...".into())).await;
                // Mirror the real provider pattern: race the "next SSE chunk"
                // future against the cancel token. A real provider's
                // `byte_stream.next().await` would sit where `pending` does.
                // `biased;` makes select! poll the cancel arm first, so the
                // cancellation can never lose the race to the pending arm.
                tokio::select! {
                    biased;
                    _ = cancel.cancelled() => {
                        // SeqCst so the test's later load reliably observes
                        // this store across threads.
                        exit_flag.store(true, std::sync::atomic::Ordering::SeqCst);
                    }
                    _ = std::future::pending::<()>() => unreachable!(),
                }
            });
            Ok(rx)
        }
    }
1242
1243    /// A provider that completes a turn normally: emits text and a Done event.
1244    struct CompletingProvider;
1245
1246    #[async_trait::async_trait]
1247    impl Provider for CompletingProvider {
1248        fn name(&self) -> &str {
1249            "completing-mock"
1250        }
1251
1252        async fn stream(
1253            &self,
1254            _request: &ProviderRequest,
1255        ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
1256            let (tx, rx) = tokio::sync::mpsc::channel(8);
1257            tokio::spawn(async move {
1258                let _ = tx.send(StreamEvent::TextDelta("hello".into())).await;
1259                let _ = tx
1260                    .send(StreamEvent::ContentBlockComplete(ContentBlock::Text {
1261                        text: "hello".into(),
1262                    }))
1263                    .await;
1264                let _ = tx
1265                    .send(StreamEvent::Done {
1266                        usage: Usage::default(),
1267                        stop_reason: Some(StopReason::EndTurn),
1268                    })
1269                    .await;
1270                // Channel closes when tx drops.
1271            });
1272            Ok(rx)
1273        }
1274    }
1275
1276    fn build_engine(llm: Arc<dyn Provider>) -> QueryEngine {
1277        use crate::config::Config;
1278        use crate::permissions::PermissionChecker;
1279        use crate::state::AppState;
1280        use crate::tools::registry::ToolRegistry;
1281
1282        let config = Config::default();
1283        let permissions = PermissionChecker::from_config(&config.permissions);
1284        let state = AppState::new(config);
1285
1286        QueryEngine::new(
1287            llm,
1288            ToolRegistry::default_tools(),
1289            permissions,
1290            state,
1291            QueryEngineConfig {
1292                max_turns: Some(1),
1293                verbose: false,
1294                unattended: true,
1295            },
1296        )
1297    }
1298
1299    /// Schedule a cancellation after `delay_ms` via the shared handle
1300    /// (same path the signal handler uses).
1301    fn schedule_cancel(engine: &QueryEngine, delay_ms: u64) {
1302        let shared = engine.cancel_shared.clone();
1303        tokio::spawn(async move {
1304            tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
1305            shared.lock().unwrap().cancel();
1306        });
1307    }
1308
1309    /// Builds a mock provider whose stream yields one TextDelta and then hangs.
1310    /// Verifies the turn returns promptly once cancel fires.
1311    #[tokio::test]
1312    async fn run_turn_with_sink_interrupts_on_cancel() {
1313        let mut engine = build_engine(Arc::new(HangingProvider));
1314        schedule_cancel(&engine, 100);
1315
1316        let result = tokio::time::timeout(
1317            std::time::Duration::from_secs(5),
1318            engine.run_turn_with_sink("test input", &NullSink),
1319        )
1320        .await;
1321
1322        assert!(
1323            result.is_ok(),
1324            "run_turn_with_sink should return promptly on cancel, not hang"
1325        );
1326        assert!(
1327            result.unwrap().is_ok(),
1328            "cancelled turn should return Ok(()), not an error"
1329        );
1330        assert!(
1331            !engine.state().is_query_active,
1332            "is_query_active should be reset after cancel"
1333        );
1334    }
1335
1336    /// Regression test for the original #103 bug: the signal handler held
1337    /// a stale clone of the cancellation token, so Ctrl+C only worked on
1338    /// the *first* turn. This test cancels turn 1, then runs turn 2 and
1339    /// verifies it is ALSO cancellable via the same shared handle.
1340    #[tokio::test]
1341    async fn cancel_works_across_multiple_turns() {
1342        let mut engine = build_engine(Arc::new(HangingProvider));
1343
1344        // Turn 1: cancel mid-stream.
1345        schedule_cancel(&engine, 80);
1346        let r1 = tokio::time::timeout(
1347            std::time::Duration::from_secs(5),
1348            engine.run_turn_with_sink("turn 1", &NullSink),
1349        )
1350        .await;
1351        assert!(r1.is_ok(), "turn 1 should cancel promptly");
1352        assert!(!engine.state().is_query_active);
1353
1354        // Turn 2: cancel again via the same shared handle.
1355        // With the pre-fix stale-token bug, the handle would be pointing
1356        // at turn 1's already-used token and turn 2 would hang forever.
1357        schedule_cancel(&engine, 80);
1358        let r2 = tokio::time::timeout(
1359            std::time::Duration::from_secs(5),
1360            engine.run_turn_with_sink("turn 2", &NullSink),
1361        )
1362        .await;
1363        assert!(
1364            r2.is_ok(),
1365            "turn 2 should also cancel promptly — regression would hang here"
1366        );
1367        assert!(!engine.state().is_query_active);
1368
1369        // Turn 3: one more for good measure.
1370        schedule_cancel(&engine, 80);
1371        let r3 = tokio::time::timeout(
1372            std::time::Duration::from_secs(5),
1373            engine.run_turn_with_sink("turn 3", &NullSink),
1374        )
1375        .await;
1376        assert!(r3.is_ok(), "turn 3 should still be cancellable");
1377        assert!(!engine.state().is_query_active);
1378    }
1379
1380    /// Verifies that a previously-cancelled token does not poison subsequent
1381    /// turns. A fresh run_turn_with_sink on the same engine should complete
1382    /// normally even after a prior cancel.
1383    #[tokio::test]
1384    async fn cancel_does_not_poison_next_turn() {
1385        // Turn 1: hangs and gets cancelled.
1386        let mut engine = build_engine(Arc::new(HangingProvider));
1387        schedule_cancel(&engine, 80);
1388        let _ = tokio::time::timeout(
1389            std::time::Duration::from_secs(5),
1390            engine.run_turn_with_sink("turn 1", &NullSink),
1391        )
1392        .await
1393        .expect("turn 1 should cancel");
1394
1395        // Swap the provider to one that completes normally by rebuilding
1396        // the engine (we can't swap llm on an existing engine, so this
1397        // simulates the isolated "fresh turn" behavior). The key property
1398        // being tested is that the per-turn cancel reset correctly
1399        // initializes a non-cancelled token.
1400        let mut engine2 = build_engine(Arc::new(CompletingProvider));
1401
1402        // Pre-cancel engine2 to simulate a leftover cancelled state, then
1403        // verify run_turn_with_sink still works because it resets the token.
1404        engine2.cancel_shared.lock().unwrap().cancel();
1405
1406        let result = tokio::time::timeout(
1407            std::time::Duration::from_secs(5),
1408            engine2.run_turn_with_sink("hello", &NullSink),
1409        )
1410        .await;
1411
1412        assert!(result.is_ok(), "completing turn should not hang");
1413        assert!(
1414            result.unwrap().is_ok(),
1415            "turn should succeed — the stale cancel flag must be cleared on turn start"
1416        );
1417        // Message history should contain the user + assistant messages.
1418        assert!(
1419            engine2.state().messages.len() >= 2,
1420            "normal turn should push both user and assistant messages"
1421        );
1422    }
1423
1424    /// Verifies that cancelling BEFORE any event arrives still interrupts
1425    /// the turn cleanly (edge case: cancellation races with the first recv).
1426    #[tokio::test]
1427    async fn cancel_before_first_event_interrupts_cleanly() {
1428        let mut engine = build_engine(Arc::new(HangingProvider));
1429        // Very short delay so cancel likely fires before or during the
1430        // first event. The test is tolerant of either ordering.
1431        schedule_cancel(&engine, 1);
1432
1433        let result = tokio::time::timeout(
1434            std::time::Duration::from_secs(5),
1435            engine.run_turn_with_sink("immediate", &NullSink),
1436        )
1437        .await;
1438
1439        assert!(result.is_ok(), "early cancel should not hang");
1440        assert!(result.unwrap().is_ok());
1441        assert!(!engine.state().is_query_active);
1442    }
1443
1444    /// Verifies the sink receives cancellation feedback via on_warning.
1445    #[tokio::test]
1446    async fn cancelled_turn_emits_warning_to_sink() {
1447        use std::sync::Mutex;
1448
1449        /// Captures sink events for assertion.
1450        struct CapturingSink {
1451            warnings: Mutex<Vec<String>>,
1452        }
1453
1454        impl StreamSink for CapturingSink {
1455            fn on_text(&self, _: &str) {}
1456            fn on_tool_start(&self, _: &str, _: &serde_json::Value) {}
1457            fn on_tool_result(&self, _: &str, _: &crate::tools::ToolResult) {}
1458            fn on_error(&self, _: &str) {}
1459            fn on_warning(&self, msg: &str) {
1460                self.warnings.lock().unwrap().push(msg.to_string());
1461            }
1462        }
1463
1464        let sink = CapturingSink {
1465            warnings: Mutex::new(Vec::new()),
1466        };
1467
1468        let mut engine = build_engine(Arc::new(HangingProvider));
1469        schedule_cancel(&engine, 100);
1470
1471        let _ = tokio::time::timeout(
1472            std::time::Duration::from_secs(5),
1473            engine.run_turn_with_sink("test", &sink),
1474        )
1475        .await
1476        .expect("should not hang");
1477
1478        let warnings = sink.warnings.lock().unwrap();
1479        assert!(
1480            warnings.iter().any(|w| w.contains("Cancelled")),
1481            "expected 'Cancelled' warning in sink, got: {:?}",
1482            *warnings
1483        );
1484    }
1485
1486    /// Regression test for the #125-followup: the cancellation token must
1487    /// reach the *provider's own spawned streaming task*, not just the query
1488    /// engine's outer select loop. This is what was broken — the existing
1489    /// `HangingProvider` test above passed even while Escape-interrupt was
1490    /// completely dead in production, because that provider ignores the
1491    /// token and the query loop exits cleanly on its own. This test fails
1492    /// if `ProviderRequest::cancel` is dropped on the floor anywhere between
1493    /// `query::mod.rs` and the provider's spawn.
1494    #[tokio::test]
1495    async fn provider_stream_task_observes_cancellation() {
1496        use std::sync::atomic::{AtomicBool, Ordering};
1497
1498        let exit_flag = Arc::new(AtomicBool::new(false));
1499        let provider = Arc::new(CancelAwareHangingProvider {
1500            exit_flag: exit_flag.clone(),
1501        });
1502        let mut engine = build_engine(provider);
1503        schedule_cancel(&engine, 50);
1504
1505        let result = tokio::time::timeout(
1506            std::time::Duration::from_secs(2),
1507            engine.run_turn_with_sink("test input", &NullSink),
1508        )
1509        .await;
1510        assert!(result.is_ok(), "engine should exit promptly on cancel");
1511
1512        // Give the provider's spawned task a moment to observe the cancel
1513        // after the engine returns. In release builds this is effectively
1514        // instantaneous; the sleep just makes the test tolerant of scheduler
1515        // variance on slow CI runners.
1516        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1517
1518        assert!(
1519            exit_flag.load(Ordering::SeqCst),
1520            "provider's streaming task should have observed cancel via \
1521             ProviderRequest::cancel and exited; if this flag is false, \
1522             the token is being dropped somewhere in query::mod.rs or the \
1523             provider is ignoring it"
1524        );
1525    }
1526}