1pub mod source;
16
17use std::path::PathBuf;
18use std::sync::Arc;
19
20use tokio_util::sync::CancellationToken;
21use tracing::{debug, info, warn};
22use uuid::Uuid;
23
24use crate::hooks::{HookEvent, HookRegistry};
25use crate::llm::message::*;
26use crate::llm::provider::{Provider, ProviderError, ProviderRequest};
27use crate::llm::stream::StreamEvent;
28use crate::permissions::PermissionChecker;
29use crate::services::compact::{self, CompactTracking, MAX_OUTPUT_TOKENS_RECOVERY_LIMIT};
30use crate::services::tokens;
31use crate::state::AppState;
32use crate::tools::ToolContext;
33use crate::tools::executor::{execute_tool_calls, extract_tool_calls};
34use crate::tools::registry::ToolRegistry;
35
/// Per-run tunables for a [`QueryEngine`].
pub struct QueryEngineConfig {
    /// Maximum agent turns per query; `None` falls back to the built-in
    /// default of 50 (see `run_turn_with_sink`).
    pub max_turns: Option<usize>,
    /// Forwarded to tool execution contexts for more detailed output.
    pub verbose: bool,
    /// When true (and the feature flag allows it), transient provider
    /// failures (overload / rate limit) sleep and retry instead of aborting.
    pub unattended: bool,
}
43
/// Orchestrates the agent loop for a session: provider streaming, tool
/// execution, permission checks, context compaction, and hooks.
pub struct QueryEngine {
    /// LLM backend, used for the main loop and for LLM-based compaction.
    llm: Arc<dyn Provider>,
    /// Registry of available tools (core + deferred).
    tools: ToolRegistry,
    /// Shared file cache handed to tool executions.
    file_cache: Arc<tokio::sync::Mutex<crate::services::file_cache::FileCache>>,
    permissions: Arc<PermissionChecker>,
    /// Session state: message history, config, usage totals, task manager.
    state: AppState,
    config: QueryEngineConfig,
    /// Shared slot holding the token for the *current* turn, so the Ctrl-C
    /// task (see `install_signal_handler`) always cancels the active turn.
    cancel_shared: Arc<std::sync::Mutex<CancellationToken>>,
    /// Token for the in-flight turn; replaced at the start of each query.
    cancel: CancellationToken,
    hooks: HookRegistry,
    /// Records prompt-cache statistics from per-call usage.
    cache_tracker: crate::services::cache_tracking::CacheTracker,
    denial_tracker: Arc<tokio::sync::Mutex<crate::permissions::tracking::DenialTracker>>,
    /// Shared state for background memory extraction after a turn completes.
    extraction_state: Arc<tokio::sync::Mutex<crate::memory::extraction::ExtractionState>>,
    /// Permission rules granted for this session only.
    session_allows: Arc<tokio::sync::Mutex<std::collections::HashSet<String>>>,
    permission_prompter: Option<Arc<dyn crate::tools::PermissionPrompter>>,
    /// Cached system prompt keyed by a hash of model, cwd, MCP-server count,
    /// and tool count (hash computed in `run_turn_with_sink`).
    cached_system_prompt: Option<(u64, String)>,
}
69
/// Observer for events emitted while a turn runs.
///
/// The required methods cover the essential stream; the rest default to
/// no-ops so sinks implement only what they display.
pub trait StreamSink: Send + Sync {
    /// Incremental assistant text as it streams in.
    fn on_text(&self, text: &str);
    /// A tool call is about to execute with the given JSON input.
    fn on_tool_start(&self, tool_name: &str, input: &serde_json::Value);
    /// A tool call finished (successfully or not).
    fn on_tool_result(&self, tool_name: &str, result: &crate::tools::ToolResult);
    /// Extended-thinking text from the model (default: ignored).
    fn on_thinking(&self, _text: &str) {}
    /// The turn finished without further tool calls (1-based index).
    fn on_turn_complete(&self, _turn: usize) {}
    /// Stream or provider error text.
    fn on_error(&self, error: &str);
    /// Token usage reported at the end of a streamed response.
    fn on_usage(&self, _usage: &Usage) {}
    /// Context compaction freed approximately this many tokens.
    fn on_compact(&self, _freed_tokens: u64) {}
    /// Non-fatal warning (budget, context pressure, model fallback, ...).
    fn on_warning(&self, _msg: &str) {}
}
82
/// A [`StreamSink`] that discards every event; used by `run_turn` when the
/// caller does not need streaming output.
pub struct NullSink;

impl StreamSink for NullSink {
    fn on_text(&self, _: &str) {}
    fn on_tool_start(&self, _: &str, _: &serde_json::Value) {}
    fn on_tool_result(&self, _: &str, _: &crate::tools::ToolResult) {}
    fn on_error(&self, _: &str) {}
}
91
92impl QueryEngine {
    /// Builds an engine with default internal services (empty file cache,
    /// empty hook registry, fresh trackers) and a new cancellation token.
    ///
    /// `permission_prompter` is left unset here; interactive front ends
    /// presumably install one separately — not visible in this file.
    pub fn new(
        llm: Arc<dyn Provider>,
        tools: ToolRegistry,
        permissions: PermissionChecker,
        state: AppState,
        config: QueryEngineConfig,
    ) -> Self {
        // The shared slot and the engine's own token start out as clones of
        // the same token; run_turn_with_sink refreshes both per query.
        let cancel = CancellationToken::new();
        let cancel_shared = Arc::new(std::sync::Mutex::new(cancel.clone()));
        Self {
            llm,
            tools,
            file_cache: Arc::new(tokio::sync::Mutex::new(
                crate::services::file_cache::FileCache::new(),
            )),
            permissions: Arc::new(permissions),
            state,
            config,
            cancel,
            cancel_shared,
            hooks: HookRegistry::new(),
            cache_tracker: crate::services::cache_tracking::CacheTracker::new(),
            // Capacity 100: bounds how much denial history is retained.
            denial_tracker: Arc::new(tokio::sync::Mutex::new(
                crate::permissions::tracking::DenialTracker::new(100),
            )),
            extraction_state: Arc::new(tokio::sync::Mutex::new(
                crate::memory::extraction::ExtractionState::new(),
            )),
            session_allows: Arc::new(tokio::sync::Mutex::new(std::collections::HashSet::new())),
            permission_prompter: None,
            cached_system_prompt: None,
        }
    }
126
127 pub fn load_hooks(&mut self, hook_defs: &[crate::hooks::HookDefinition]) {
129 for def in hook_defs {
130 self.hooks.register(def.clone());
131 }
132 if !hook_defs.is_empty() {
133 tracing::info!("Loaded {} hooks from config", hook_defs.len());
134 }
135 }
136
    /// Shared view of the session state.
    pub fn state(&self) -> &AppState {
        &self.state
    }
141
    /// Mutable access to the session state (history, config, flags).
    pub fn state_mut(&mut self) -> &mut AppState {
        &mut self.state
    }
146
147 pub fn install_signal_handler(&self) {
151 let shared = self.cancel_shared.clone();
152 tokio::spawn(async move {
153 let mut pending = false;
154 loop {
155 if tokio::signal::ctrl_c().await.is_ok() {
156 let token = shared.lock().unwrap().clone();
157 if token.is_cancelled() && pending {
158 std::process::exit(130);
160 }
161 token.cancel();
162 pending = true;
163 }
164 }
165 });
166 }
167
    /// Runs one query without streaming output, by delegating to
    /// `run_turn_with_sink` with a [`NullSink`].
    pub async fn run_turn(&mut self, user_input: &str) -> crate::error::Result<()> {
        self.run_turn_with_sink(user_input, &NullSink).await
    }
172
173 pub async fn run_turn_with_sink(
175 &mut self,
176 user_input: &str,
177 sink: &dyn StreamSink,
178 ) -> crate::error::Result<()> {
179 self.cancel = CancellationToken::new();
182 *self.cancel_shared.lock().unwrap() = self.cancel.clone();
183
184 let user_msg = user_message(user_input);
186 self.state.push_message(user_msg);
187
188 let max_turns = self.config.max_turns.unwrap_or(50);
189 let mut compact_tracking = CompactTracking::default();
190 let mut retry_state = crate::llm::retry::RetryState::default();
191 let retry_config = crate::llm::retry::RetryConfig::default();
192 let mut max_output_recovery_count = 0u32;
193
194 for turn in 0..max_turns {
196 self.state.turn_count = turn + 1;
197 self.state.is_query_active = true;
198
199 let budget_config = crate::services::budget::BudgetConfig::default();
201 match crate::services::budget::check_budget(
202 self.state.total_cost_usd,
203 self.state.total_usage.total(),
204 &budget_config,
205 ) {
206 crate::services::budget::BudgetDecision::Stop { message } => {
207 sink.on_warning(&message);
208 self.state.is_query_active = false;
209 return Ok(());
210 }
211 crate::services::budget::BudgetDecision::ContinueWithWarning {
212 message, ..
213 } => {
214 sink.on_warning(&message);
215 }
216 crate::services::budget::BudgetDecision::Continue => {}
217 }
218
219 crate::llm::normalize::ensure_tool_result_pairing(&mut self.state.messages);
221 crate::llm::normalize::strip_empty_blocks(&mut self.state.messages);
222 crate::llm::normalize::remove_empty_messages(&mut self.state.messages);
223 crate::llm::normalize::cap_document_blocks(&mut self.state.messages, 500_000);
224 crate::llm::normalize::merge_consecutive_user_messages(&mut self.state.messages);
225
226 debug!("Agent turn {}/{}", turn + 1, max_turns);
227
228 let mut model = self.state.config.api.model.clone();
229
230 if compact::should_auto_compact(self.state.history(), &model, &compact_tracking) {
232 let token_count = tokens::estimate_context_tokens(self.state.history());
233 let threshold = compact::auto_compact_threshold(&model);
234 info!("Auto-compact triggered: {token_count} tokens >= {threshold} threshold");
235
236 let freed = compact::microcompact(&mut self.state.messages, 5);
238 if freed > 0 {
239 sink.on_compact(freed);
240 info!("Microcompact freed ~{freed} tokens");
241 }
242
243 let post_mc_tokens = tokens::estimate_context_tokens(self.state.history());
245 if post_mc_tokens >= threshold {
246 info!("Microcompact insufficient, attempting LLM compaction");
248 match compact::compact_with_llm(
249 &mut self.state.messages,
250 &*self.llm,
251 &model,
252 self.cancel.clone(),
253 )
254 .await
255 {
256 Some(removed) => {
257 info!("LLM compaction removed {removed} messages");
258 compact_tracking.was_compacted = true;
259 compact_tracking.consecutive_failures = 0;
260 }
261 None => {
262 compact_tracking.consecutive_failures += 1;
263 warn!(
264 "LLM compaction failed (attempt {})",
265 compact_tracking.consecutive_failures
266 );
267 let effective = compact::effective_context_window(&model);
269 if let Some(collapse) =
270 crate::services::context_collapse::collapse_to_budget(
271 self.state.history(),
272 effective,
273 )
274 {
275 info!(
276 "Context collapse snipped {} messages, freed ~{} tokens",
277 collapse.snipped_count, collapse.tokens_freed
278 );
279 self.state.messages = collapse.api_messages;
280 sink.on_compact(collapse.tokens_freed);
281 } else {
282 let freed2 = compact::microcompact(&mut self.state.messages, 2);
284 if freed2 > 0 {
285 sink.on_compact(freed2);
286 }
287 }
288 }
289 }
290 }
291 }
292
293 if compact_tracking.was_compacted && self.state.config.features.compaction_reminders {
295 let reminder = user_message(
296 "<system-reminder>Context was automatically compacted. \
297 Earlier messages were summarized. If you need details from \
298 before compaction, ask the user or re-read the relevant files.</system-reminder>",
299 );
300 self.state.push_message(reminder);
301 compact_tracking.was_compacted = false; }
303
304 let warning = compact::token_warning_state(self.state.history(), &model);
306 if warning.is_blocking {
307 sink.on_warning("Context window nearly full. Consider starting a new session.");
308 } else if warning.is_above_warning {
309 sink.on_warning(&format!("Context {}% remaining", warning.percent_left));
310 }
311
312 let prompt_hash = {
315 use std::hash::{Hash, Hasher};
316 let mut h = std::collections::hash_map::DefaultHasher::new();
317 self.state.config.api.model.hash(&mut h);
318 self.state.cwd.hash(&mut h);
319 self.state.config.mcp_servers.len().hash(&mut h);
320 self.tools.all().len().hash(&mut h);
321 h.finish()
322 };
323 let system_prompt = if let Some((cached_hash, ref cached)) = self.cached_system_prompt
324 && cached_hash == prompt_hash
325 {
326 cached.clone()
327 } else {
328 let prompt = build_system_prompt(&self.tools, &self.state);
329 self.cached_system_prompt = Some((prompt_hash, prompt.clone()));
330 prompt
331 };
332 let tool_schemas = self.tools.core_schemas();
334
335 let base_tokens = self.state.config.api.max_output_tokens.unwrap_or(16384);
337 let effective_tokens = if max_output_recovery_count > 0 {
338 base_tokens.max(65536) } else {
340 base_tokens
341 };
342
343 let request = ProviderRequest {
344 messages: self.state.history().to_vec(),
345 system_prompt: system_prompt.clone(),
346 tools: tool_schemas.clone(),
347 model: model.clone(),
348 max_tokens: effective_tokens,
349 temperature: None,
350 enable_caching: self.state.config.features.prompt_caching,
351 tool_choice: Default::default(),
352 metadata: None,
353 cancel: self.cancel.clone(),
354 };
355
356 let mut rx = match self.llm.stream(&request).await {
357 Ok(rx) => {
358 retry_state.reset();
359 rx
360 }
361 Err(e) => {
362 let retryable = match &e {
363 ProviderError::RateLimited { retry_after_ms } => {
364 crate::llm::retry::RetryableError::RateLimited {
365 retry_after: *retry_after_ms,
366 }
367 }
368 ProviderError::Overloaded => crate::llm::retry::RetryableError::Overloaded,
369 ProviderError::Network(_) => {
370 crate::llm::retry::RetryableError::StreamInterrupted
371 }
372 other => crate::llm::retry::RetryableError::NonRetryable(other.to_string()),
373 };
374
375 match retry_state.next_action(&retryable, &retry_config) {
376 crate::llm::retry::RetryAction::Retry { after } => {
377 warn!("Retrying in {}ms", after.as_millis());
378 tokio::time::sleep(after).await;
379 continue;
380 }
381 crate::llm::retry::RetryAction::FallbackModel => {
382 let fallback = get_fallback_model(&model);
384 sink.on_warning(&format!("Falling back from {model} to {fallback}"));
385 model = fallback;
386 continue;
387 }
388 crate::llm::retry::RetryAction::Abort(reason) => {
389 if self.config.unattended
392 && self.state.config.features.unattended_retry
393 && matches!(
394 &e,
395 ProviderError::Overloaded | ProviderError::RateLimited { .. }
396 )
397 {
398 warn!("Unattended retry: waiting 30s for capacity");
399 tokio::time::sleep(std::time::Duration::from_secs(30)).await;
400 continue;
401 }
402 if let ProviderError::RequestTooLarge(body) = &e {
405 let gap = compact::parse_prompt_too_long_gap(body);
406
407 let effective = compact::effective_context_window(&model);
409 if let Some(collapse) =
410 crate::services::context_collapse::collapse_to_budget(
411 self.state.history(),
412 effective,
413 )
414 {
415 info!(
416 "Reactive collapse: snipped {} messages, freed ~{} tokens",
417 collapse.snipped_count, collapse.tokens_freed
418 );
419 self.state.messages = collapse.api_messages;
420 sink.on_compact(collapse.tokens_freed);
421 continue;
422 }
423
424 let freed = compact::microcompact(&mut self.state.messages, 1);
426 if freed > 0 {
427 sink.on_compact(freed);
428 info!(
429 "Reactive microcompact freed ~{freed} tokens (gap: {gap:?})"
430 );
431 continue;
432 }
433 }
434 sink.on_error(&reason);
435 self.state.is_query_active = false;
436 return Err(crate::error::Error::Other(e.to_string()));
437 }
438 }
439 }
440 };
441
442 let mut content_blocks = Vec::new();
445 let mut usage = Usage::default();
446 let mut stop_reason: Option<StopReason> = None;
447 let mut got_error = false;
448 let mut error_text = String::new();
449
450 let mut streaming_tool_handles: Vec<(
452 String,
453 String,
454 tokio::task::JoinHandle<crate::tools::ToolResult>,
455 )> = Vec::new();
456
457 let mut cancelled = false;
458 loop {
459 tokio::select! {
460 event = rx.recv() => {
461 match event {
462 Some(StreamEvent::TextDelta(text)) => {
463 sink.on_text(&text);
464 }
465 Some(StreamEvent::ContentBlockComplete(block)) => {
466 if let ContentBlock::ToolUse {
467 ref id,
468 ref name,
469 ref input,
470 } = block
471 {
472 sink.on_tool_start(name, input);
473
474 if let Some(tool) = self.tools.get(name)
476 && tool.is_read_only()
477 && tool.is_concurrency_safe()
478 {
479 let tool = tool.clone();
480 let input = input.clone();
481 let cwd = std::path::PathBuf::from(&self.state.cwd);
482 let cancel = self.cancel.clone();
483 let perm = self.permissions.clone();
484 let tool_id = id.clone();
485 let tool_name = name.clone();
486
487 let handle = tokio::spawn(async move {
488 match tool
489 .call(
490 input,
491 &ToolContext {
492 cwd,
493 cancel,
494 permission_checker: perm.clone(),
495 verbose: false,
496 plan_mode: false,
497 file_cache: None,
498 denial_tracker: None,
499 task_manager: None,
500 session_allows: None,
501 permission_prompter: None,
502 sandbox: None,
503 },
504 )
505 .await
506 {
507 Ok(r) => r,
508 Err(e) => crate::tools::ToolResult::error(e.to_string()),
509 }
510 });
511
512 streaming_tool_handles.push((tool_id, tool_name, handle));
513 }
514 }
515 if let ContentBlock::Thinking { ref thinking, .. } = block {
516 sink.on_thinking(thinking);
517 }
518 content_blocks.push(block);
519 }
520 Some(StreamEvent::Done {
521 usage: u,
522 stop_reason: sr,
523 }) => {
524 usage = u;
525 stop_reason = sr;
526 sink.on_usage(&usage);
527 }
528 Some(StreamEvent::Error(msg)) => {
529 got_error = true;
530 error_text = msg.clone();
531 sink.on_error(&msg);
532 }
533 Some(_) => {}
534 None => break,
535 }
536 }
537 _ = self.cancel.cancelled() => {
538 warn!("Turn cancelled by user");
539 cancelled = true;
540 for (_, _, handle) in streaming_tool_handles.drain(..) {
542 handle.abort();
543 }
544 break;
545 }
546 }
547 }
548
549 if cancelled {
550 sink.on_warning("Cancelled");
551 self.state.is_query_active = false;
552 return Ok(());
553 }
554
555 let assistant_msg = Message::Assistant(AssistantMessage {
557 uuid: Uuid::new_v4(),
558 timestamp: chrono::Utc::now().to_rfc3339(),
559 content: content_blocks.clone(),
560 model: Some(model.clone()),
561 usage: Some(usage.clone()),
562 stop_reason: stop_reason.clone(),
563 request_id: None,
564 });
565 self.state.push_message(assistant_msg);
566 self.state.record_usage(&usage, &model);
567
568 if self.state.config.features.token_budget && usage.total() > 0 {
570 let turn_total = usage.input_tokens + usage.output_tokens;
571 if turn_total > 100_000 {
572 sink.on_warning(&format!(
573 "High token usage this turn: {} tokens ({}in + {}out)",
574 turn_total, usage.input_tokens, usage.output_tokens
575 ));
576 }
577 }
578
579 let _cache_event = self.cache_tracker.record(&usage);
581 {
582 let mut span = crate::services::telemetry::api_call_span(
583 &model,
584 turn + 1,
585 &self.state.session_id,
586 );
587 crate::services::telemetry::record_usage(&mut span, &usage);
588 span.finish();
589 tracing::debug!(
590 "API call: {}ms, {}in/{}out tokens",
591 span.duration_ms().unwrap_or(0),
592 usage.input_tokens,
593 usage.output_tokens,
594 );
595 }
596
597 if got_error {
599 if error_text.contains("prompt is too long")
601 || error_text.contains("Prompt is too long")
602 {
603 let freed = compact::microcompact(&mut self.state.messages, 1);
604 if freed > 0 {
605 sink.on_compact(freed);
606 continue;
607 }
608 }
609
610 if content_blocks
612 .iter()
613 .any(|b| matches!(b, ContentBlock::Text { .. }))
614 && error_text.contains("max_tokens")
615 && max_output_recovery_count < MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
616 {
617 max_output_recovery_count += 1;
618 info!(
619 "Max output tokens recovery attempt {}/{}",
620 max_output_recovery_count, MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
621 );
622 let recovery_msg = compact::max_output_recovery_message();
623 self.state.push_message(recovery_msg);
624 continue;
625 }
626 }
627
628 if matches!(stop_reason, Some(StopReason::MaxTokens))
630 && !got_error
631 && content_blocks
632 .iter()
633 .any(|b| matches!(b, ContentBlock::Text { .. }))
634 && max_output_recovery_count < MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
635 {
636 max_output_recovery_count += 1;
637 info!(
638 "Max tokens stop reason — recovery attempt {}/{}",
639 max_output_recovery_count, MAX_OUTPUT_TOKENS_RECOVERY_LIMIT
640 );
641 let recovery_msg = compact::max_output_recovery_message();
642 self.state.push_message(recovery_msg);
643 continue;
644 }
645
646 let tool_calls = extract_tool_calls(&content_blocks);
648
649 if tool_calls.is_empty() {
650 info!("Turn complete (no tool calls)");
652 sink.on_turn_complete(turn + 1);
653 self.state.is_query_active = false;
654
655 if self.state.config.features.extract_memories
658 && crate::memory::ensure_memory_dir().is_some()
659 {
660 let extraction_messages = self.state.messages.clone();
661 let extraction_state = self.extraction_state.clone();
662 let extraction_llm = self.llm.clone();
663 let extraction_model = model.clone();
664 tokio::spawn(async move {
665 crate::memory::extraction::extract_memories_background(
666 extraction_messages,
667 extraction_state,
668 extraction_llm,
669 extraction_model,
670 )
671 .await;
672 });
673 }
674
675 return Ok(());
676 }
677
678 info!("Executing {} tool call(s)", tool_calls.len());
680 let cwd = PathBuf::from(&self.state.cwd);
681 let tool_ctx = ToolContext {
682 cwd: cwd.clone(),
683 cancel: self.cancel.clone(),
684 permission_checker: self.permissions.clone(),
685 verbose: self.config.verbose,
686 plan_mode: self.state.plan_mode,
687 file_cache: Some(self.file_cache.clone()),
688 denial_tracker: Some(self.denial_tracker.clone()),
689 task_manager: Some(self.state.task_manager.clone()),
690 session_allows: Some(self.session_allows.clone()),
691 permission_prompter: self.permission_prompter.clone(),
692 sandbox: Some(std::sync::Arc::new(
693 crate::sandbox::SandboxExecutor::from_session_config(&self.state.config, &cwd),
694 )),
695 };
696
697 let streaming_ids: std::collections::HashSet<String> = streaming_tool_handles
699 .iter()
700 .map(|(id, _, _)| id.clone())
701 .collect();
702
703 let mut streaming_results = Vec::new();
704 for (id, name, handle) in streaming_tool_handles.drain(..) {
705 match handle.await {
706 Ok(result) => streaming_results.push(crate::tools::executor::ToolCallResult {
707 tool_use_id: id,
708 tool_name: name,
709 result,
710 }),
711 Err(e) => streaming_results.push(crate::tools::executor::ToolCallResult {
712 tool_use_id: id,
713 tool_name: name,
714 result: crate::tools::ToolResult::error(format!("Task failed: {e}")),
715 }),
716 }
717 }
718
719 for call in &tool_calls {
721 self.hooks
722 .run_hooks(&HookEvent::PreToolUse, Some(&call.name), &call.input)
723 .await;
724 }
725
726 let remaining_calls: Vec<_> = tool_calls
728 .iter()
729 .filter(|c| !streaming_ids.contains(&c.id))
730 .cloned()
731 .collect();
732
733 let mut results = streaming_results;
734 if !remaining_calls.is_empty() {
735 let batch_results = execute_tool_calls(
736 &remaining_calls,
737 self.tools.all(),
738 &tool_ctx,
739 &self.permissions,
740 )
741 .await;
742 results.extend(batch_results);
743 }
744
745 for result in &results {
747 if !result.result.is_error {
749 match result.tool_name.as_str() {
750 "EnterPlanMode" => {
751 self.state.plan_mode = true;
752 info!("Plan mode enabled");
753 }
754 "ExitPlanMode" => {
755 self.state.plan_mode = false;
756 info!("Plan mode disabled");
757 }
758 _ => {}
759 }
760 }
761
762 sink.on_tool_result(&result.tool_name, &result.result);
763
764 self.hooks
766 .run_hooks(
767 &HookEvent::PostToolUse,
768 Some(&result.tool_name),
769 &serde_json::json!({
770 "tool": result.tool_name,
771 "is_error": result.result.is_error,
772 }),
773 )
774 .await;
775
776 let msg = tool_result_message(
777 &result.tool_use_id,
778 &result.result.content,
779 result.result.is_error,
780 );
781 self.state.push_message(msg);
782 }
783
784 }
786
787 warn!("Max turns ({max_turns}) reached");
788 sink.on_warning(&format!("Agent stopped after {max_turns} turns"));
789 self.state.is_query_active = false;
790 Ok(())
791 }
792
    /// Cancels the in-flight turn, if any (no-op when nothing is running).
    pub fn cancel(&self) {
        self.cancel.cancel();
    }
797
    /// Clone of the cancellation token active at call time. Note a fresh
    /// token is installed per query, so this clone does not track later turns.
    pub fn cancel_token(&self) -> tokio_util::sync::CancellationToken {
        self.cancel.clone()
    }
802}
803
/// Picks a cheaper / more-available model to fall back to when `current` is
/// overloaded or rate limited: opus → sonnet, full-size gpt-5.4/gpt-4.1 →
/// its `-mini` variant, large → small. Returns `current` unchanged when no
/// rule applies.
///
/// Fix: detection was case-insensitive but the substitution used the
/// case-sensitive `str::replace`, so mixed-case names (e.g. "Claude-Opus-4")
/// matched a branch yet came back unchanged, defeating the fallback. The
/// replacement is now ASCII case-insensitive and positional.
fn get_fallback_model(current: &str) -> String {
    // Case-insensitive replacement of every occurrence of `from` (which must
    // be non-empty, lowercase ASCII). Model identifiers are ASCII, so byte
    // offsets into the ASCII-lowercased copy are valid offsets into `s`.
    fn replace_ci(s: &str, from: &str, to: &str) -> String {
        let mut out = String::with_capacity(s.len());
        let mut rest = s;
        while let Some(i) = rest.to_ascii_lowercase().find(from) {
            out.push_str(&rest[..i]);
            out.push_str(to);
            rest = &rest[i + from.len()..];
        }
        out.push_str(rest);
        out
    }

    let lower = current.to_ascii_lowercase();
    if lower.contains("opus") {
        // Anthropic: opus tier → sonnet tier.
        replace_ci(current, "opus", "sonnet")
    } else if (lower.contains("gpt-5.4") || lower.contains("gpt-4.1"))
        && !lower.contains("mini")
        && !lower.contains("nano")
    {
        // OpenAI: full-size model → its -mini variant.
        format!("{current}-mini")
    } else if lower.contains("large") {
        // Mistral-style naming: large → small.
        replace_ci(current, "large", "small")
    } else {
        current.to_string()
    }
}
822
/// Assembles the full system prompt for one API call from: identity, runtime
/// environment, relevant memory, enabled tool docs, user-invocable skills,
/// behavioral policy text, MCP server list, deferred tool names, and output
/// formatting rules. Pure string building; reads env vars and the filesystem
/// only to describe the environment (SHELL, `.git` presence) and to load
/// memory/skills.
pub fn build_system_prompt(tools: &ToolRegistry, state: &AppState) -> String {
    let mut prompt = String::new();

    // Identity preamble.
    prompt.push_str(
        "You are an AI coding agent. You help users with software engineering tasks \
         by reading, writing, and searching code. Use the tools available to you to \
         accomplish tasks.\n\n",
    );

    // Runtime environment facts.
    let shell = std::env::var("SHELL").unwrap_or_else(|_| "bash".to_string());
    let is_git = std::path::Path::new(&state.cwd).join(".git").exists();
    prompt.push_str(&format!(
        "# Environment\n\
         - Working directory: {}\n\
         - Platform: {}\n\
         - Shell: {shell}\n\
         - Git repository: {}\n\n",
        state.cwd,
        std::env::consts::OS,
        if is_git { "yes" } else { "no" },
    ));

    // Memory: load the base context, then pull in entries relevant to the
    // text of the last five user messages.
    let mut memory = crate::memory::MemoryContext::load(Some(std::path::Path::new(&state.cwd)));

    let recent_text: String = state
        .messages
        .iter()
        .rev()
        .take(5)
        .filter_map(|m| match m {
            crate::llm::message::Message::User(u) => Some(
                u.content
                    .iter()
                    .filter_map(|b| b.as_text())
                    .collect::<Vec<_>>()
                    .join(" "),
            ),
            _ => None,
        })
        .collect::<Vec<_>>()
        .join(" ");

    if !recent_text.is_empty() {
        memory.load_relevant(&recent_text);
    }

    let memory_section = memory.to_system_prompt_section();
    if !memory_section.is_empty() {
        prompt.push_str(&memory_section);
    }

    // Per-tool documentation for every enabled tool.
    prompt.push_str("# Available Tools\n\n");
    for tool in tools.all() {
        if tool.is_enabled() {
            prompt.push_str(&format!("## {}\n{}\n\n", tool.name(), tool.prompt()));
        }
    }

    // User-invocable skills discovered from the working directory.
    let skills = crate::skills::SkillRegistry::load_all(Some(std::path::Path::new(&state.cwd)));
    let invocable = skills.user_invocable();
    if !invocable.is_empty() {
        prompt.push_str("# Available Skills\n\n");
        for skill in invocable {
            let desc = skill.metadata.description.as_deref().unwrap_or("");
            let when = skill.metadata.when_to_use.as_deref().unwrap_or("");
            prompt.push_str(&format!("- `/{}`", skill.name));
            if !desc.is_empty() {
                prompt.push_str(&format!(": {desc}"));
            }
            if !when.is_empty() {
                prompt.push_str(&format!(" (use when: {when})"));
            }
            prompt.push('\n');
        }
        prompt.push('\n');
    }

    // Static behavioral policy: tool preferences, coding rules, git safety,
    // commit/PR procedures, action safety, response style, memory protocol.
    prompt.push_str(
        "# Using tools\n\n\
         Use dedicated tools instead of shell commands when available:\n\
         - File search: Glob (not find or ls)\n\
         - Content search: Grep (not grep or rg)\n\
         - Read files: FileRead (not cat/head/tail)\n\
         - Edit files: FileEdit (not sed/awk)\n\
         - Write files: FileWrite (not echo/cat with redirect)\n\
         - Reserve Bash for system commands and operations that require shell execution.\n\
         - Break complex tasks into steps. Use multiple tool calls in parallel when independent.\n\
         - Use the Agent tool for complex multi-step research or tasks that benefit from isolation.\n\n\
         # Working with code\n\n\
         - Read files before editing them. Understand existing code before suggesting changes.\n\
         - Prefer editing existing files over creating new ones to avoid file bloat.\n\
         - Only make changes that were requested. Don't add features, refactor, add comments, \
         or make \"improvements\" beyond the ask.\n\
         - Don't add error handling for scenarios that can't happen. Don't design for \
         hypothetical future requirements.\n\
         - When referencing code, include file_path:line_number.\n\
         - Be careful not to introduce security vulnerabilities (command injection, XSS, SQL injection, \
         OWASP top 10). If you notice insecure code you wrote, fix it immediately.\n\
         - Don't add docstrings, comments, or type annotations to code you didn't change.\n\
         - Three similar lines of code is better than a premature abstraction.\n\n\
         # Git safety protocol\n\n\
         - NEVER update the git config.\n\
         - NEVER run destructive git commands (push --force, reset --hard, checkout ., restore ., \
         clean -f, branch -D) unless the user explicitly requests them.\n\
         - NEVER skip hooks (--no-verify, --no-gpg-sign) unless the user explicitly requests it.\n\
         - NEVER force push to main/master. Warn the user if they request it.\n\
         - Always create NEW commits rather than amending, unless the user explicitly requests amend. \
         After hook failure, the commit did NOT happen — amend would modify the PREVIOUS commit.\n\
         - When staging files, prefer adding specific files by name rather than git add -A or git add ., \
         which can accidentally include sensitive files.\n\
         - NEVER commit changes unless the user explicitly asks.\n\n\
         # Committing changes\n\n\
         When the user asks to commit:\n\
         1. Run git status and git diff to see all changes.\n\
         2. Run git log --oneline -5 to match the repository's commit message style.\n\
         3. Draft a concise (1-2 sentence) commit message focusing on \"why\" not \"what\".\n\
         4. Do not commit files that likely contain secrets (.env, credentials.json).\n\
         5. Stage specific files, create the commit.\n\
         6. If pre-commit hook fails, fix the issue and create a NEW commit.\n\
         7. When creating commits, include a co-author attribution line at the end of the message.\n\n\
         # Creating pull requests\n\n\
         When the user asks to create a PR:\n\
         1. Run git status, git diff, and git log to understand all changes on the branch.\n\
         2. Analyze ALL commits (not just the latest) that will be in the PR.\n\
         3. Draft a short title (under 70 chars) and detailed body with summary and test plan.\n\
         4. Push to remote with -u flag if needed, then create PR using gh pr create.\n\
         5. Return the PR URL when done.\n\n\
         # Executing actions safely\n\n\
         Consider the reversibility and blast radius of every action:\n\
         - Freely take local, reversible actions (editing files, running tests).\n\
         - For hard-to-reverse or shared-state actions, confirm with the user first:\n\
         - Destructive: deleting files/branches, dropping tables, rm -rf, overwriting uncommitted changes.\n\
         - Hard to reverse: force-pushing, git reset --hard, amending published commits.\n\
         - Visible to others: pushing code, creating/commenting on PRs/issues, sending messages.\n\
         - When you encounter an obstacle, do not use destructive actions as a shortcut. \
         Identify root causes and fix underlying issues.\n\
         - If you discover unexpected state (unfamiliar files, branches, config), investigate \
         before deleting or overwriting — it may be the user's in-progress work.\n\n\
         # Response style\n\n\
         - Be concise. Lead with the answer or action, not the reasoning.\n\
         - Skip filler, preamble, and unnecessary transitions.\n\
         - Don't restate what the user said.\n\
         - If you can say it in one sentence, don't use three.\n\
         - Focus output on: decisions that need input, status updates, and errors that change the plan.\n\
         - When referencing GitHub issues or PRs, use owner/repo#123 format.\n\
         - Only use emojis if the user explicitly requests it.\n\n\
         # Memory\n\n\
         You can save information across sessions by writing memory files.\n\
         - Save to: ~/.config/agent-code/memory/ (one .md file per topic)\n\
         - Each file needs YAML frontmatter: name, description, type (user/feedback/project/reference)\n\
         - After writing a file, update MEMORY.md with a one-line pointer\n\
         - Memory types: user (role, preferences), feedback (corrections, confirmations), \
         project (decisions, deadlines), reference (external resources)\n\
         - Do NOT store: code patterns, git history, debugging solutions, anything derivable from code\n\
         - Memory is a hint — always verify against current state before acting on it\n",
    );

    // Worked tool-usage patterns, error recovery, and common workflows.
    prompt.push_str(
        "# Tool usage patterns\n\n\
         Common patterns for effective tool use:\n\n\
         **Read before edit**: Always read a file before editing it. This ensures you \
         understand the current state and can make targeted changes.\n\
         ```\n\
         1. FileRead file_path → understand structure\n\
         2. FileEdit old_string, new_string → targeted change\n\
         ```\n\n\
         **Search then act**: Use Glob to find files, Grep to find content, then read/edit.\n\
         ```\n\
         1. Glob **/*.rs → find Rust files\n\
         2. Grep pattern path → find specific code\n\
         3. FileRead → read the match\n\
         4. FileEdit → make the change\n\
         ```\n\n\
         **Parallel tool calls**: When you need to read multiple independent files or run \
         independent searches, make all the tool calls in one response. Don't serialize \
         independent operations.\n\n\
         **Test after change**: After editing code, run tests to verify the change works.\n\
         ```\n\
         1. FileEdit → make change\n\
         2. Bash cargo test / pytest / npm test → verify\n\
         3. If tests fail, read the error, fix, re-test\n\
         ```\n\n\
         # Error recovery\n\n\
         When something goes wrong:\n\
         - **Tool not found**: Use ToolSearch to find the right tool name.\n\
         - **Permission denied**: Explain why the action is needed, ask the user to approve.\n\
         - **File not found**: Use Glob to find the correct path. Check for typos.\n\
         - **Edit failed (not unique)**: Provide more surrounding context in old_string, \
         or use replace_all=true if renaming.\n\
         - **Command failed**: Read the full error message. Don't retry the same command. \
         Diagnose the root cause first.\n\
         - **Context too large**: The system will auto-compact. If you need specific \
         information from before compaction, re-read the relevant files.\n\
         - **Rate limited**: The system will auto-retry with backoff. Just wait.\n\n\
         # Common workflows\n\n\
         **Bug fix**: Read the failing test → read the source code it tests → \
         identify the bug → fix it → run the test → confirm it passes.\n\n\
         **New feature**: Read existing patterns in the codebase → create or edit files → \
         add tests → run tests → update docs if needed.\n\n\
         **Code review**: Read the diff → identify issues (bugs, security, style) → \
         report findings with file:line references.\n\n\
         **Refactor**: Search for all usages of the symbol → plan the changes → \
         edit each file → run tests to verify nothing broke.\n\n",
    );

    // Connected MCP servers, with transport inferred from config shape
    // (command ⇒ stdio, url ⇒ sse).
    if !state.config.mcp_servers.is_empty() {
        prompt.push_str("# MCP Servers\n\n");
        prompt.push_str(
            "Connected MCP servers provide additional tools. MCP tools are prefixed \
             with `mcp__{server}__{tool}`. Use them like any other tool.\n\n",
        );
        for (name, entry) in &state.config.mcp_servers {
            let transport = if entry.command.is_some() {
                "stdio"
            } else if entry.url.is_some() {
                "sse"
            } else {
                "unknown"
            };
            prompt.push_str(&format!("- **{name}** ({transport})\n"));
        }
        prompt.push('\n');
    }

    // Tools registered but not loaded by default.
    let deferred = tools.deferred_names();
    if !deferred.is_empty() {
        prompt.push_str("# Deferred Tools\n\n");
        prompt.push_str(
            "These tools are available but not loaded by default. \
             Use ToolSearch to load them when needed:\n",
        );
        for name in &deferred {
            prompt.push_str(&format!("- {name}\n"));
        }
        prompt.push('\n');
    }

    // Task management and output formatting rules.
    prompt.push_str(
        "# Task management\n\n\
         - Use TaskCreate to break complex work into trackable steps.\n\
         - Mark tasks as in_progress when starting, completed when done.\n\
         - Use the Agent tool to spawn subagents for parallel independent work.\n\
         - Use EnterPlanMode/ExitPlanMode for read-only exploration before making changes.\n\
         - Use EnterWorktree/ExitWorktree for isolated changes in git worktrees.\n\n\
         # Output formatting\n\n\
         - All text output is displayed to the user. Use GitHub-flavored markdown.\n\
         - Use fenced code blocks with language hints for code: ```rust, ```python, etc.\n\
         - Use inline `code` for file names, function names, and short code references.\n\
         - Use tables for structured comparisons.\n\
         - Use bullet lists for multiple items.\n\
         - Keep paragraphs short (2-3 sentences).\n\
         - Never output raw HTML or complex formatting — stick to standard markdown.\n",
    );

    prompt
}
1090
1091#[cfg(test)]
1092mod tests {
1093 use super::*;
1094
1095 #[test]
1099 fn cancel_shared_propagates_to_current_token() {
1100 let root = CancellationToken::new();
1101 let shared = Arc::new(std::sync::Mutex::new(root.clone()));
1102
1103 let turn1 = CancellationToken::new();
1105 *shared.lock().unwrap() = turn1.clone();
1106
1107 shared.lock().unwrap().cancel();
1109 assert!(turn1.is_cancelled());
1110
1111 let turn2 = CancellationToken::new();
1113 *shared.lock().unwrap() = turn2.clone();
1114 assert!(!turn2.is_cancelled());
1115
1116 shared.lock().unwrap().cancel();
1118 assert!(turn2.is_cancelled());
1119 }
1120
1121 #[tokio::test]
1124 async fn stream_loop_responds_to_cancellation() {
1125 let cancel = CancellationToken::new();
1126 let (tx, mut rx) = tokio::sync::mpsc::channel::<StreamEvent>(10);
1127
1128 tx.send(StreamEvent::TextDelta("hello".into()))
1130 .await
1131 .unwrap();
1132
1133 let cancel2 = cancel.clone();
1134 tokio::spawn(async move {
1135 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1137 cancel2.cancel();
1138 });
1139
1140 let mut events_received = 0u32;
1141 let mut cancelled = false;
1142
1143 loop {
1144 tokio::select! {
1145 event = rx.recv() => {
1146 match event {
1147 Some(_) => events_received += 1,
1148 None => break,
1149 }
1150 }
1151 _ = cancel.cancelled() => {
1152 cancelled = true;
1153 break;
1154 }
1155 }
1156 }
1157
1158 assert!(cancelled, "Loop should have been cancelled");
1159 assert_eq!(
1160 events_received, 1,
1161 "Should have received exactly one event before cancel"
1162 );
1163 }
1164
1165 use crate::llm::provider::{Provider, ProviderError, ProviderRequest};
1175
1176 struct HangingProvider;
1179
1180 #[async_trait::async_trait]
1181 impl Provider for HangingProvider {
1182 fn name(&self) -> &str {
1183 "hanging-mock"
1184 }
1185
1186 async fn stream(
1187 &self,
1188 _request: &ProviderRequest,
1189 ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
1190 let (tx, rx) = tokio::sync::mpsc::channel(4);
1191 tokio::spawn(async move {
1192 let _ = tx.send(StreamEvent::TextDelta("thinking...".into())).await;
1193 let _tx_holder = tx;
1195 std::future::pending::<()>().await;
1196 });
1197 Ok(rx)
1198 }
1199 }
1200
1201 struct CancelAwareHangingProvider {
1207 exit_flag: Arc<std::sync::atomic::AtomicBool>,
1208 }
1209
1210 #[async_trait::async_trait]
1211 impl Provider for CancelAwareHangingProvider {
1212 fn name(&self) -> &str {
1213 "cancel-aware-mock"
1214 }
1215
1216 async fn stream(
1217 &self,
1218 request: &ProviderRequest,
1219 ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
1220 let (tx, rx) = tokio::sync::mpsc::channel(4);
1221 let cancel = request.cancel.clone();
1222 let exit_flag = self.exit_flag.clone();
1223 tokio::spawn(async move {
1224 let _ = tx.send(StreamEvent::TextDelta("thinking...".into())).await;
1225 tokio::select! {
1229 biased;
1230 _ = cancel.cancelled() => {
1231 exit_flag.store(true, std::sync::atomic::Ordering::SeqCst);
1232 }
1233 _ = std::future::pending::<()>() => unreachable!(),
1234 }
1235 });
1236 Ok(rx)
1237 }
1238 }
1239
1240 struct CompletingProvider;
1242
1243 #[async_trait::async_trait]
1244 impl Provider for CompletingProvider {
1245 fn name(&self) -> &str {
1246 "completing-mock"
1247 }
1248
1249 async fn stream(
1250 &self,
1251 _request: &ProviderRequest,
1252 ) -> Result<tokio::sync::mpsc::Receiver<StreamEvent>, ProviderError> {
1253 let (tx, rx) = tokio::sync::mpsc::channel(8);
1254 tokio::spawn(async move {
1255 let _ = tx.send(StreamEvent::TextDelta("hello".into())).await;
1256 let _ = tx
1257 .send(StreamEvent::ContentBlockComplete(ContentBlock::Text {
1258 text: "hello".into(),
1259 }))
1260 .await;
1261 let _ = tx
1262 .send(StreamEvent::Done {
1263 usage: Usage::default(),
1264 stop_reason: Some(StopReason::EndTurn),
1265 })
1266 .await;
1267 });
1269 Ok(rx)
1270 }
1271 }
1272
1273 fn build_engine(llm: Arc<dyn Provider>) -> QueryEngine {
1274 use crate::config::Config;
1275 use crate::permissions::PermissionChecker;
1276 use crate::state::AppState;
1277 use crate::tools::registry::ToolRegistry;
1278
1279 let config = Config::default();
1280 let permissions = PermissionChecker::from_config(&config.permissions);
1281 let state = AppState::new(config);
1282
1283 QueryEngine::new(
1284 llm,
1285 ToolRegistry::default_tools(),
1286 permissions,
1287 state,
1288 QueryEngineConfig {
1289 max_turns: Some(1),
1290 verbose: false,
1291 unattended: true,
1292 },
1293 )
1294 }
1295
1296 fn schedule_cancel(engine: &QueryEngine, delay_ms: u64) {
1299 let shared = engine.cancel_shared.clone();
1300 tokio::spawn(async move {
1301 tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
1302 shared.lock().unwrap().cancel();
1303 });
1304 }
1305
1306 #[tokio::test]
1309 async fn run_turn_with_sink_interrupts_on_cancel() {
1310 let mut engine = build_engine(Arc::new(HangingProvider));
1311 schedule_cancel(&engine, 100);
1312
1313 let result = tokio::time::timeout(
1314 std::time::Duration::from_secs(5),
1315 engine.run_turn_with_sink("test input", &NullSink),
1316 )
1317 .await;
1318
1319 assert!(
1320 result.is_ok(),
1321 "run_turn_with_sink should return promptly on cancel, not hang"
1322 );
1323 assert!(
1324 result.unwrap().is_ok(),
1325 "cancelled turn should return Ok(()), not an error"
1326 );
1327 assert!(
1328 !engine.state().is_query_active,
1329 "is_query_active should be reset after cancel"
1330 );
1331 }
1332
1333 #[tokio::test]
1338 async fn cancel_works_across_multiple_turns() {
1339 let mut engine = build_engine(Arc::new(HangingProvider));
1340
1341 schedule_cancel(&engine, 80);
1343 let r1 = tokio::time::timeout(
1344 std::time::Duration::from_secs(5),
1345 engine.run_turn_with_sink("turn 1", &NullSink),
1346 )
1347 .await;
1348 assert!(r1.is_ok(), "turn 1 should cancel promptly");
1349 assert!(!engine.state().is_query_active);
1350
1351 schedule_cancel(&engine, 80);
1355 let r2 = tokio::time::timeout(
1356 std::time::Duration::from_secs(5),
1357 engine.run_turn_with_sink("turn 2", &NullSink),
1358 )
1359 .await;
1360 assert!(
1361 r2.is_ok(),
1362 "turn 2 should also cancel promptly — regression would hang here"
1363 );
1364 assert!(!engine.state().is_query_active);
1365
1366 schedule_cancel(&engine, 80);
1368 let r3 = tokio::time::timeout(
1369 std::time::Duration::from_secs(5),
1370 engine.run_turn_with_sink("turn 3", &NullSink),
1371 )
1372 .await;
1373 assert!(r3.is_ok(), "turn 3 should still be cancellable");
1374 assert!(!engine.state().is_query_active);
1375 }
1376
1377 #[tokio::test]
1381 async fn cancel_does_not_poison_next_turn() {
1382 let mut engine = build_engine(Arc::new(HangingProvider));
1384 schedule_cancel(&engine, 80);
1385 let _ = tokio::time::timeout(
1386 std::time::Duration::from_secs(5),
1387 engine.run_turn_with_sink("turn 1", &NullSink),
1388 )
1389 .await
1390 .expect("turn 1 should cancel");
1391
1392 let mut engine2 = build_engine(Arc::new(CompletingProvider));
1398
1399 engine2.cancel_shared.lock().unwrap().cancel();
1402
1403 let result = tokio::time::timeout(
1404 std::time::Duration::from_secs(5),
1405 engine2.run_turn_with_sink("hello", &NullSink),
1406 )
1407 .await;
1408
1409 assert!(result.is_ok(), "completing turn should not hang");
1410 assert!(
1411 result.unwrap().is_ok(),
1412 "turn should succeed — the stale cancel flag must be cleared on turn start"
1413 );
1414 assert!(
1416 engine2.state().messages.len() >= 2,
1417 "normal turn should push both user and assistant messages"
1418 );
1419 }
1420
1421 #[tokio::test]
1424 async fn cancel_before_first_event_interrupts_cleanly() {
1425 let mut engine = build_engine(Arc::new(HangingProvider));
1426 schedule_cancel(&engine, 1);
1429
1430 let result = tokio::time::timeout(
1431 std::time::Duration::from_secs(5),
1432 engine.run_turn_with_sink("immediate", &NullSink),
1433 )
1434 .await;
1435
1436 assert!(result.is_ok(), "early cancel should not hang");
1437 assert!(result.unwrap().is_ok());
1438 assert!(!engine.state().is_query_active);
1439 }
1440
1441 #[tokio::test]
1443 async fn cancelled_turn_emits_warning_to_sink() {
1444 use std::sync::Mutex;
1445
1446 struct CapturingSink {
1448 warnings: Mutex<Vec<String>>,
1449 }
1450
1451 impl StreamSink for CapturingSink {
1452 fn on_text(&self, _: &str) {}
1453 fn on_tool_start(&self, _: &str, _: &serde_json::Value) {}
1454 fn on_tool_result(&self, _: &str, _: &crate::tools::ToolResult) {}
1455 fn on_error(&self, _: &str) {}
1456 fn on_warning(&self, msg: &str) {
1457 self.warnings.lock().unwrap().push(msg.to_string());
1458 }
1459 }
1460
1461 let sink = CapturingSink {
1462 warnings: Mutex::new(Vec::new()),
1463 };
1464
1465 let mut engine = build_engine(Arc::new(HangingProvider));
1466 schedule_cancel(&engine, 100);
1467
1468 let _ = tokio::time::timeout(
1469 std::time::Duration::from_secs(5),
1470 engine.run_turn_with_sink("test", &sink),
1471 )
1472 .await
1473 .expect("should not hang");
1474
1475 let warnings = sink.warnings.lock().unwrap();
1476 assert!(
1477 warnings.iter().any(|w| w.contains("Cancelled")),
1478 "expected 'Cancelled' warning in sink, got: {:?}",
1479 *warnings
1480 );
1481 }
1482
1483 #[tokio::test]
1492 async fn provider_stream_task_observes_cancellation() {
1493 use std::sync::atomic::{AtomicBool, Ordering};
1494
1495 let exit_flag = Arc::new(AtomicBool::new(false));
1496 let provider = Arc::new(CancelAwareHangingProvider {
1497 exit_flag: exit_flag.clone(),
1498 });
1499 let mut engine = build_engine(provider);
1500 schedule_cancel(&engine, 50);
1501
1502 let result = tokio::time::timeout(
1503 std::time::Duration::from_secs(2),
1504 engine.run_turn_with_sink("test input", &NullSink),
1505 )
1506 .await;
1507 assert!(result.is_ok(), "engine should exit promptly on cancel");
1508
1509 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1514
1515 assert!(
1516 exit_flag.load(Ordering::SeqCst),
1517 "provider's streaming task should have observed cancel via \
1518 ProviderRequest::cancel and exited; if this flag is false, \
1519 the token is being dropped somewhere in query::mod.rs or the \
1520 provider is ignoring it"
1521 );
1522 }
1523}