// agent_sdk/query.rs
1//! The query function and agent loop implementation.
2//!
3//! This module contains the core `query()` function that creates an async stream
4//! of messages, driving Claude through the agentic loop of prompt → response →
5//! tool calls → tool results → repeat.
6
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Instant;
12
13use futures::stream::FuturesUnordered;
14use futures::{Stream, StreamExt as FuturesStreamExt};
15use serde_json::json;
16use tokio::sync::mpsc;
17use tokio_stream::wrappers::UnboundedReceiverStream;
18use tracing::{debug, error, warn};
19use uuid::Uuid;
20
21use crate::client::{
22    ApiContentBlock, ApiMessage, ApiUsage, CacheControl, ContentDelta, CreateMessageRequest,
23    ImageSource, MessageResponse, StreamEvent as ClientStreamEvent, SystemBlock, ThinkingParam,
24    ToolDefinition,
25};
26use crate::compact;
27use crate::error::{AgentError, Result};
28use crate::hooks::HookRegistry;
29use crate::options::{Options, PermissionMode, ThinkingConfig};
30use crate::permissions::{PermissionEvaluator, PermissionVerdict};
31use crate::provider::LlmProvider;
32use crate::providers::AnthropicProvider;
33use crate::sanitize;
34use crate::session::Session;
35use crate::tools::definitions::get_tool_definitions;
36use crate::tools::executor::{ToolExecutor, ToolResult};
37use crate::types::messages::*;
38
/// Default model to use when none is specified in [`Options`].
const DEFAULT_MODEL: &str = "claude-haiku-4-5";
/// Default `max_tokens` for API responses when [`Options::max_tokens`] is unset.
const DEFAULT_MAX_TOKENS: u32 = 16384;
43
/// A handle to a running query that streams messages.
///
/// Implements `Stream<Item = Result<Message>>` for async iteration.
pub struct Query {
    // Receiving half of the channel that the spawned agent loop writes
    // messages (and any terminal error) into.
    receiver: UnboundedReceiverStream<Result<Message>>,
    // NOTE(review): constructed as `None` in `query()` and never assigned
    // afterward in this file, so `session_id()` currently always returns
    // `None` — confirm whether a later update path exists elsewhere.
    session_id: Option<String>,
    // Cancellation signal shared (via clone) with the background agent-loop
    // task; fired by `interrupt()` / `close()`.
    cancel_token: tokio_util::sync::CancellationToken,
}
52
53impl Query {
54    /// Interrupt the current query.
55    pub async fn interrupt(&self) -> Result<()> {
56        self.cancel_token.cancel();
57        Ok(())
58    }
59
60    /// Get the session ID (available after the init message).
61    pub fn session_id(&self) -> Option<&str> {
62        self.session_id.as_deref()
63    }
64
65    /// Change the permission mode mid-session.
66    pub async fn set_permission_mode(&self, _mode: PermissionMode) -> Result<()> {
67        // TODO: Send control message to the running agent loop
68        Ok(())
69    }
70
71    /// Change the model mid-session.
72    pub async fn set_model(&self, _model: &str) -> Result<()> {
73        // TODO: Send control message to the running agent loop
74        Ok(())
75    }
76
77    /// Close the query and terminate the underlying process.
78    pub fn close(&self) {
79        self.cancel_token.cancel();
80    }
81}
82
83impl Stream for Query {
84    type Item = Result<Message>;
85
86    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
87        Pin::new(&mut self.receiver).poll_next(cx)
88    }
89}
90
91/// Create a query that streams messages from Claude.
92///
93/// This is the primary function for interacting with the Claude Agent SDK.
94/// Returns a [`Query`] stream that yields [`Message`] items as the agent loop
95/// progresses.
96///
97/// # Arguments
98///
99/// * `prompt` - The input prompt string
100/// * `options` - Configuration options for the query
101///
102/// # Example
103///
104/// ```rust,no_run
105/// use agent_sdk::{query, Options, Message};
106/// use tokio_stream::StreamExt;
107///
108/// # async fn example() -> anyhow::Result<()> {
109/// let mut stream = query(
110///     "What files are in this directory?",
111///     Options::builder()
112///         .allowed_tools(vec!["Bash".into(), "Glob".into()])
113///         .build(),
114/// );
115///
116/// while let Some(message) = stream.next().await {
117///     let message = message?;
118///     if let Message::Result(result) = &message {
119///         println!("{}", result.result.as_deref().unwrap_or(""));
120///     }
121/// }
122/// # Ok(())
123/// # }
124/// ```
125pub fn query(prompt: &str, options: Options) -> Query {
126    let (tx, rx) = mpsc::unbounded_channel();
127    let cancel_token = tokio_util::sync::CancellationToken::new();
128    let cancel = cancel_token.clone();
129
130    let prompt = prompt.to_string();
131
132    tokio::spawn(async move {
133        let result = run_agent_loop(prompt, options, tx.clone(), cancel).await;
134        if let Err(e) = result {
135            let _ = tx.send(Err(e));
136        }
137    });
138
139    Query {
140        receiver: UnboundedReceiverStream::new(rx),
141        session_id: None,
142        cancel_token,
143    }
144}
145
146/// The main agent loop.
147///
148/// This implements the core cycle:
149/// 1. Receive prompt
150/// 2. Send to Claude
151/// 3. Process response (text + tool calls)
152/// 4. Execute tools
153/// 5. Feed results back
154/// 6. Repeat until done or limits hit
155async fn run_agent_loop(
156    prompt: String,
157    mut options: Options,
158    tx: mpsc::UnboundedSender<Result<Message>>,
159    cancel: tokio_util::sync::CancellationToken,
160) -> Result<()> {
161    let start_time = Instant::now();
162    let mut api_time_ms: u64 = 0;
163
164    // Resolve working directory
165    let cwd = options.cwd.clone().unwrap_or_else(|| {
166        std::env::current_dir()
167            .unwrap_or_else(|_| PathBuf::from("."))
168            .to_string_lossy()
169            .to_string()
170    });
171
172    // Create or resume session
173    let session = if let Some(ref resume_id) = options.resume {
174        Session::with_id(resume_id, &cwd)
175    } else if options.continue_session {
176        // Find most recent session
177        match crate::session::find_most_recent_session(Some(&cwd)).await? {
178            Some(info) => Session::with_id(&info.session_id, &cwd),
179            None => Session::new(&cwd),
180        }
181    } else {
182        match &options.session_id {
183            Some(id) => Session::with_id(id, &cwd),
184            None => Session::new(&cwd),
185        }
186    };
187
188    let session_id = session.id.clone();
189    let model = options
190        .model
191        .clone()
192        .unwrap_or_else(|| DEFAULT_MODEL.to_string());
193
194    // Build tool definitions (skip tools entirely when output_format is set —
195    // structured-output queries should not use tools).
196    let tool_names: Vec<String> = if options.output_format.is_some() {
197        Vec::new()
198    } else if options.allowed_tools.is_empty() {
199        // Default set of tools
200        vec![
201            "Read".into(),
202            "Write".into(),
203            "Edit".into(),
204            "Bash".into(),
205            "Glob".into(),
206            "Grep".into(),
207        ]
208    } else {
209        options.allowed_tools.clone()
210    };
211
212    let raw_defs: Vec<_> = get_tool_definitions(&tool_names);
213
214    // Combine built-in + custom tool definitions
215    let mut all_defs: Vec<ToolDefinition> = raw_defs
216        .into_iter()
217        .map(|td| ToolDefinition {
218            name: td.name.to_string(),
219            description: td.description.to_string(),
220            input_schema: td.input_schema,
221            cache_control: None,
222        })
223        .collect();
224
225    // Append custom tool definitions
226    for ctd in &options.custom_tool_definitions {
227        all_defs.push(ToolDefinition {
228            name: ctd.name.clone(),
229            description: ctd.description.clone(),
230            input_schema: ctd.input_schema.clone(),
231            cache_control: None,
232        });
233    }
234
235    // Mark the last tool with cache_control so the tools block is cached
236    if let Some(last) = all_defs.last_mut() {
237        last.cache_control = Some(CacheControl::ephemeral());
238    }
239
240    let tool_defs = all_defs;
241
242    // Emit init system message
243    let init_msg = Message::System(SystemMessage {
244        subtype: SystemSubtype::Init,
245        uuid: Uuid::new_v4(),
246        session_id: session_id.clone(),
247        agents: if options.agents.is_empty() {
248            None
249        } else {
250            Some(options.agents.keys().cloned().collect())
251        },
252        claude_code_version: Some(env!("CARGO_PKG_VERSION").to_string()),
253        cwd: Some(cwd.clone()),
254        tools: Some(tool_names.clone()),
255        mcp_servers: if options.mcp_servers.is_empty() {
256            None
257        } else {
258            Some(
259                options
260                    .mcp_servers
261                    .keys()
262                    .map(|name| McpServerStatus {
263                        name: name.clone(),
264                        status: "connected".to_string(),
265                    })
266                    .collect(),
267            )
268        },
269        model: Some(model.clone()),
270        permission_mode: Some(options.permission_mode.to_string()),
271        compact_metadata: None,
272    });
273
274    // Persist and emit init message
275    if options.persist_session {
276        let _ = session
277            .append_message(&serde_json::to_value(&init_msg).unwrap_or_default())
278            .await;
279    }
280    if tx.send(Ok(init_msg)).is_err() {
281        return Ok(());
282    }
283
284    // Initialize LLM provider
285    let provider: Box<dyn LlmProvider> = match options.provider.take() {
286        Some(p) => p,
287        None => Box::new(AnthropicProvider::from_env()?),
288    };
289
290    // Initialize tool executor with optional path boundary
291    let additional_dirs: Vec<PathBuf> = options
292        .additional_directories
293        .iter()
294        .map(PathBuf::from)
295        .collect();
296    let env_blocklist = std::mem::take(&mut options.env_blocklist);
297    let env_inject = std::mem::take(&mut options.env);
298    #[cfg(unix)]
299    let pre_exec_fn = options.pre_exec_fn.take();
300    let mut tool_executor = if additional_dirs.is_empty() {
301        ToolExecutor::new(PathBuf::from(&cwd))
302    } else {
303        ToolExecutor::with_allowed_dirs(PathBuf::from(&cwd), additional_dirs)
304    }
305    .with_env_blocklist(env_blocklist)
306    .with_env_inject(env_inject);
307    #[cfg(unix)]
308    if let Some(f) = pre_exec_fn {
309        tool_executor = tool_executor.with_pre_exec(f);
310    }
311
312    // Build hook registry from options, merging file-discovered hooks
313    let mut hook_registry = HookRegistry::from_map(std::mem::take(&mut options.hooks));
314    if !options.hook_dirs.is_empty() {
315        let dirs: Vec<&std::path::Path> = options.hook_dirs.iter().map(|p| p.as_path()).collect();
316        match crate::hooks::HookDiscovery::discover(&dirs) {
317            Ok(discovered) => hook_registry.merge(discovered),
318            Err(e) => tracing::warn!("Failed to discover hooks from dirs: {}", e),
319        }
320    }
321
322    // Take followup_rx out of options before borrowing options immutably
323    let mut followup_rx = options.followup_rx.take();
324
325    // Initialize permission evaluator
326    let permission_eval = PermissionEvaluator::new(&options);
327
328    // Build the system prompt as SystemBlock(s) with prompt caching
329    let system_prompt: Option<Vec<SystemBlock>> = {
330        let text = match &options.system_prompt {
331            Some(crate::options::SystemPrompt::Custom(s)) => s.clone(),
332            Some(crate::options::SystemPrompt::Preset { append, .. }) => {
333                let base = "You are Claude, an AI assistant. You have access to tools to help accomplish tasks.";
334                match append {
335                    Some(extra) => format!("{}\n\n{}", base, extra),
336                    None => base.to_string(),
337                }
338            }
339            None => "You are Claude, an AI assistant. You have access to tools to help accomplish tasks.".to_string(),
340        };
341        Some(vec![SystemBlock {
342            kind: "text".to_string(),
343            text,
344            cache_control: Some(CacheControl::ephemeral()),
345        }])
346    };
347
348    // Build initial conversation from prompt
349    let mut conversation: Vec<ApiMessage> = Vec::new();
350
351    // Load previous messages if resuming
352    if options.resume.is_some() || options.continue_session {
353        let prev_messages = session.load_messages().await?;
354        for msg_value in prev_messages {
355            if let Some(api_msg) = value_to_api_message(&msg_value) {
356                conversation.push(api_msg);
357            }
358        }
359    }
360
361    // Add the user prompt (with optional image attachments)
362    {
363        let mut content_blocks: Vec<ApiContentBlock> = Vec::new();
364
365        // Add image attachments as Image content blocks
366        for att in &options.attachments {
367            let is_image = matches!(
368                att.mime_type.as_str(),
369                "image/png" | "image/jpeg" | "image/gif" | "image/webp"
370            );
371            if is_image {
372                content_blocks.push(ApiContentBlock::Image {
373                    source: ImageSource {
374                        kind: "base64".to_string(),
375                        media_type: att.mime_type.clone(),
376                        data: att.base64_data.clone(),
377                    },
378                });
379            }
380        }
381
382        // Add the text prompt
383        content_blocks.push(ApiContentBlock::Text {
384            text: prompt.clone(),
385            cache_control: None,
386        });
387
388        conversation.push(ApiMessage {
389            role: "user".to_string(),
390            content: content_blocks,
391        });
392    }
393
394    // Persist user message
395    if options.persist_session {
396        let user_msg = json!({
397            "type": "user",
398            "uuid": Uuid::new_v4().to_string(),
399            "session_id": &session_id,
400            "content": [{"type": "text", "text": &prompt}]
401        });
402        let _ = session.append_message(&user_msg).await;
403    }
404
405    // Agent loop
406    let mut num_turns: u32 = 0;
407    let mut total_usage = Usage::default();
408    let mut total_cost: f64 = 0.0;
409    let mut model_usage: HashMap<String, ModelUsage> = HashMap::new();
410    let mut permission_denials: Vec<PermissionDenial> = Vec::new();
411
412    loop {
413        // Check cancellation
414        if cancel.is_cancelled() {
415            return Err(AgentError::Cancelled);
416        }
417
418        // Check turn limit
419        if let Some(max_turns) = options.max_turns {
420            if num_turns >= max_turns {
421                let result_msg = build_result_message(
422                    ResultSubtype::ErrorMaxTurns,
423                    &session_id,
424                    None,
425                    start_time,
426                    api_time_ms,
427                    num_turns,
428                    total_cost,
429                    &total_usage,
430                    &model_usage,
431                    &permission_denials,
432                );
433                let _ = tx.send(Ok(result_msg));
434                return Ok(());
435            }
436        }
437
438        // Check budget limit
439        if let Some(max_budget) = options.max_budget_usd {
440            if total_cost >= max_budget {
441                let result_msg = build_result_message(
442                    ResultSubtype::ErrorMaxBudgetUsd,
443                    &session_id,
444                    None,
445                    start_time,
446                    api_time_ms,
447                    num_turns,
448                    total_cost,
449                    &total_usage,
450                    &model_usage,
451                    &permission_denials,
452                );
453                let _ = tx.send(Ok(result_msg));
454                return Ok(());
455            }
456        }
457
458        // Drain any followup messages that arrived while we were processing.
459        // These are batched into a single user message appended to the conversation
460        // so the model sees them on the next API call.
461        if let Some(ref mut followup_rx) = followup_rx {
462            let mut followups: Vec<String> = Vec::new();
463            while let Ok(msg) = followup_rx.try_recv() {
464                followups.push(msg);
465            }
466            if !followups.is_empty() {
467                let combined = followups.join("\n\n");
468                debug!(
469                    count = followups.len(),
470                    "Injecting followup messages into agent loop"
471                );
472
473                conversation.push(ApiMessage {
474                    role: "user".to_string(),
475                    content: vec![ApiContentBlock::Text {
476                        text: combined.clone(),
477                        cache_control: None,
478                    }],
479                });
480
481                // Emit a user message so downstream consumers know about the injection
482                let followup_msg = Message::User(UserMessage {
483                    uuid: Some(Uuid::new_v4()),
484                    session_id: session_id.clone(),
485                    content: vec![ContentBlock::Text { text: combined }],
486                    parent_tool_use_id: None,
487                    is_synthetic: false,
488                    tool_use_result: None,
489                });
490
491                if options.persist_session {
492                    let _ = session
493                        .append_message(&serde_json::to_value(&followup_msg).unwrap_or_default())
494                        .await;
495                }
496                if tx.send(Ok(followup_msg)).is_err() {
497                    return Ok(());
498                }
499            }
500        }
501
502        // Set a cache breakpoint on the last content block of the last user
503        // message. This keeps the total breakpoints at 3 (system + tools + last
504        // user turn), well within the API limit of 4.
505        apply_cache_breakpoint(&mut conversation);
506
507        // Build thinking param from options
508        let thinking_param = options.thinking.as_ref().map(|tc| match tc {
509            ThinkingConfig::Adaptive => ThinkingParam {
510                kind: "enabled".into(),
511                budget_tokens: Some(10240),
512            },
513            ThinkingConfig::Disabled => ThinkingParam {
514                kind: "disabled".into(),
515                budget_tokens: None,
516            },
517            ThinkingConfig::Enabled { budget_tokens } => ThinkingParam {
518                kind: "enabled".into(),
519                budget_tokens: Some(*budget_tokens),
520            },
521        });
522
523        // Increase max_tokens when thinking is enabled
524        let base_max_tokens = options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
525        let max_tokens = if let Some(ref tp) = thinking_param {
526            if let Some(budget) = tp.budget_tokens {
527                base_max_tokens.max(budget as u32 + 8192)
528            } else {
529                base_max_tokens
530            }
531        } else {
532            base_max_tokens
533        };
534
535        // Build the API request
536        let use_streaming = options.include_partial_messages;
537        let request = CreateMessageRequest {
538            model: model.clone(),
539            max_tokens,
540            messages: conversation.clone(),
541            system: system_prompt.clone(),
542            tools: if tool_defs.is_empty() {
543                None
544            } else {
545                Some(tool_defs.clone())
546            },
547            stream: use_streaming,
548            metadata: None,
549            thinking: thinking_param,
550        };
551
552        // Call LLM provider
553        let api_start = Instant::now();
554        let response = if use_streaming {
555            // Streaming mode: consume SSE events, emit text deltas, accumulate full response
556            match provider.create_message_stream(&request).await {
557                Ok(mut event_stream) => {
558                    match accumulate_stream(&mut event_stream, &tx, &session_id).await {
559                        Ok(resp) => resp,
560                        Err(e) => {
561                            error!("Stream accumulation failed: {}", e);
562                            let result_msg = build_error_result_message(
563                                &session_id,
564                                &format!("Stream error: {}", e),
565                                start_time,
566                                api_time_ms,
567                                num_turns,
568                                total_cost,
569                                &total_usage,
570                                &model_usage,
571                                &permission_denials,
572                            );
573                            let _ = tx.send(Ok(result_msg));
574                            return Ok(());
575                        }
576                    }
577                }
578                Err(e) => {
579                    error!("API stream call failed: {}", e);
580                    let result_msg = build_error_result_message(
581                        &session_id,
582                        &format!("API error: {}", e),
583                        start_time,
584                        api_time_ms,
585                        num_turns,
586                        total_cost,
587                        &total_usage,
588                        &model_usage,
589                        &permission_denials,
590                    );
591                    let _ = tx.send(Ok(result_msg));
592                    return Ok(());
593                }
594            }
595        } else {
596            // Non-streaming mode: single request/response
597            match provider.create_message(&request).await {
598                Ok(resp) => resp,
599                Err(e) => {
600                    error!("API call failed: {}", e);
601                    let result_msg = build_error_result_message(
602                        &session_id,
603                        &format!("API error: {}", e),
604                        start_time,
605                        api_time_ms,
606                        num_turns,
607                        total_cost,
608                        &total_usage,
609                        &model_usage,
610                        &permission_denials,
611                    );
612                    let _ = tx.send(Ok(result_msg));
613                    return Ok(());
614                }
615            }
616        };
617        api_time_ms += api_start.elapsed().as_millis() as u64;
618
619        // Update usage
620        total_usage.input_tokens += response.usage.input_tokens;
621        total_usage.output_tokens += response.usage.output_tokens;
622        total_usage.cache_creation_input_tokens +=
623            response.usage.cache_creation_input_tokens.unwrap_or(0);
624        total_usage.cache_read_input_tokens += response.usage.cache_read_input_tokens.unwrap_or(0);
625
626        // Estimate cost using provider-specific rates (with cache-aware pricing)
627        let rates = provider.cost_rates(&model);
628        let turn_cost = rates.compute_with_cache(
629            response.usage.input_tokens,
630            response.usage.output_tokens,
631            response.usage.cache_read_input_tokens.unwrap_or(0),
632            response.usage.cache_creation_input_tokens.unwrap_or(0),
633        );
634        total_cost += turn_cost;
635
636        // Update model usage
637        let model_entry = model_usage.entry(model.clone()).or_default();
638        model_entry.input_tokens += response.usage.input_tokens;
639        model_entry.output_tokens += response.usage.output_tokens;
640        model_entry.cost_usd += turn_cost;
641
642        // Convert response to our message types
643        let content_blocks: Vec<ContentBlock> = response
644            .content
645            .iter()
646            .map(api_block_to_content_block)
647            .collect();
648
649        // Emit assistant message
650        let assistant_msg = Message::Assistant(AssistantMessage {
651            uuid: Uuid::new_v4(),
652            session_id: session_id.clone(),
653            content: content_blocks.clone(),
654            model: response.model.clone(),
655            stop_reason: response.stop_reason.clone(),
656            parent_tool_use_id: None,
657            usage: Some(Usage {
658                input_tokens: response.usage.input_tokens,
659                output_tokens: response.usage.output_tokens,
660                cache_creation_input_tokens: response
661                    .usage
662                    .cache_creation_input_tokens
663                    .unwrap_or(0),
664                cache_read_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
665            }),
666            error: None,
667        });
668
669        if options.persist_session {
670            let _ = session
671                .append_message(&serde_json::to_value(&assistant_msg).unwrap_or_default())
672                .await;
673        }
674        if tx.send(Ok(assistant_msg)).is_err() {
675            return Ok(());
676        }
677
678        // Add assistant response to conversation
679        conversation.push(ApiMessage {
680            role: "assistant".to_string(),
681            content: response.content.clone(),
682        });
683
684        // Check if there are tool calls
685        let tool_uses: Vec<_> = response
686            .content
687            .iter()
688            .filter_map(|block| match block {
689                ApiContentBlock::ToolUse { id, name, input } => {
690                    Some((id.clone(), name.clone(), input.clone()))
691                }
692                _ => None,
693            })
694            .collect();
695
696        // If no tool calls, we're done
697        if tool_uses.is_empty() {
698            // Extract final text
699            let final_text = response
700                .content
701                .iter()
702                .filter_map(|block| match block {
703                    ApiContentBlock::Text { text, .. } => Some(text.as_str()),
704                    _ => None,
705                })
706                .collect::<Vec<_>>()
707                .join("");
708
709            let result_msg = build_result_message(
710                ResultSubtype::Success,
711                &session_id,
712                Some(final_text),
713                start_time,
714                api_time_ms,
715                num_turns,
716                total_cost,
717                &total_usage,
718                &model_usage,
719                &permission_denials,
720            );
721
722            if options.persist_session {
723                let _ = session
724                    .append_message(&serde_json::to_value(&result_msg).unwrap_or_default())
725                    .await;
726            }
727            let _ = tx.send(Ok(result_msg));
728            return Ok(());
729        }
730
731        // Execute tool calls
732        num_turns += 1;
733        let mut tool_results: Vec<ApiContentBlock> = Vec::new();
734
735        // Phase 0: Reject hallucinated tool names immediately with a helpful error.
736        // Collect known tool names from the definitions we sent to the model.
737        let known_tool_names: std::collections::HashSet<&str> =
738            tool_defs.iter().map(|td| td.name.as_str()).collect();
739
740        let mut valid_tool_uses: Vec<&(String, String, serde_json::Value)> = Vec::new();
741        for tu in &tool_uses {
742            let (tool_use_id, tool_name, _tool_input) = tu;
743            if known_tool_names.contains(tool_name.as_str()) {
744                valid_tool_uses.push(tu);
745            } else {
746                warn!(tool = %tool_name, "model invoked unknown tool, returning error");
747                let available: Vec<&str> = tool_defs.iter().map(|td| td.name.as_str()).collect();
748                let error_msg = format!(
749                    "Error: '{}' is not a valid tool. You MUST use one of the following tools: {}",
750                    tool_name,
751                    available.join(", ")
752                );
753                let api_block = ApiContentBlock::ToolResult {
754                    tool_use_id: tool_use_id.clone(),
755                    content: json!(error_msg),
756                    is_error: Some(true),
757                    cache_control: None,
758                    name: Some(tool_name.clone()),
759                };
760
761                // Stream the error to the frontend
762                let result_msg = Message::User(UserMessage {
763                    uuid: Some(Uuid::new_v4()),
764                    session_id: session_id.clone(),
765                    content: vec![api_block_to_content_block(&api_block)],
766                    parent_tool_use_id: None,
767                    is_synthetic: true,
768                    tool_use_result: None,
769                });
770                if options.persist_session {
771                    let _ = session
772                        .append_message(&serde_json::to_value(&result_msg).unwrap_or_default())
773                        .await;
774                }
775                if tx.send(Ok(result_msg)).is_err() {
776                    return Ok(());
777                }
778
779                tool_results.push(api_block);
780            }
781        }
782
783        // Phase 1: Evaluate permissions sequentially (may involve user interaction)
784        struct PermittedTool {
785            tool_use_id: String,
786            tool_name: String,
787            actual_input: serde_json::Value,
788        }
789        let mut permitted_tools: Vec<PermittedTool> = Vec::new();
790
791        for (tool_use_id, tool_name, tool_input) in valid_tool_uses.iter().map(|t| &**t) {
792            let verdict = permission_eval
793                .evaluate(tool_name, tool_input, tool_use_id, &session_id, &cwd)
794                .await?;
795
796            let actual_input = match &verdict {
797                PermissionVerdict::AllowWithUpdatedInput(new_input) => new_input.clone(),
798                _ => tool_input.clone(),
799            };
800
801            match verdict {
802                PermissionVerdict::Allow | PermissionVerdict::AllowWithUpdatedInput(_) => {
803                    permitted_tools.push(PermittedTool {
804                        tool_use_id: tool_use_id.clone(),
805                        tool_name: tool_name.clone(),
806                        actual_input,
807                    });
808                }
809                PermissionVerdict::Deny { reason } => {
810                    debug!(tool = %tool_name, reason = %reason, "Tool denied");
811                    permission_denials.push(PermissionDenial {
812                        tool_name: tool_name.clone(),
813                        tool_use_id: tool_use_id.clone(),
814                        tool_input: tool_input.clone(),
815                    });
816
817                    let api_block = ApiContentBlock::ToolResult {
818                        tool_use_id: tool_use_id.clone(),
819                        content: json!(format!("Permission denied: {}", reason)),
820                        is_error: Some(true),
821                        cache_control: None,
822                        name: Some(tool_name.clone()),
823                    };
824
825                    // Stream denial result to frontend immediately
826                    let denial_msg = Message::User(UserMessage {
827                        uuid: Some(Uuid::new_v4()),
828                        session_id: session_id.clone(),
829                        content: vec![api_block_to_content_block(&api_block)],
830                        parent_tool_use_id: None,
831                        is_synthetic: true,
832                        tool_use_result: None,
833                    });
834                    if options.persist_session {
835                        let _ = session
836                            .append_message(&serde_json::to_value(&denial_msg).unwrap_or_default())
837                            .await;
838                    }
839                    if tx.send(Ok(denial_msg)).is_err() {
840                        return Ok(());
841                    }
842
843                    tool_results.push(api_block);
844                }
845            }
846        }
847
848        // Phase 2: Execute permitted tools concurrently, stream results as they complete
849        let mut futs: FuturesUnordered<_> = permitted_tools
850            .iter()
851            .map(|pt| {
852                let handler = &options.external_tool_handler;
853                let executor = &tool_executor;
854                let name = &pt.tool_name;
855                let input = &pt.actual_input;
856                let id = &pt.tool_use_id;
857                async move {
858                    debug!(tool = %name, "Executing tool");
859
860                    let tool_result = if let Some(ref handler) = handler {
861                        let ext_result = handler(name.clone(), input.clone()).await;
862                        if let Some(tr) = ext_result {
863                            tr
864                        } else {
865                            match executor.execute(name, input.clone()).await {
866                                Ok(tr) => tr,
867                                Err(e) => ToolResult {
868                                    content: format!("{}", e),
869                                    is_error: true,
870                                    raw_content: None,
871                                },
872                            }
873                        }
874                    } else {
875                        match executor.execute(name, input.clone()).await {
876                            Ok(tr) => tr,
877                            Err(e) => ToolResult {
878                                content: format!("{}", e),
879                                is_error: true,
880                                raw_content: None,
881                            },
882                        }
883                    };
884                    (id.as_str(), name.as_str(), input, tool_result)
885                }
886            })
887            .collect();
888
889        while let Some((tool_use_id, tool_name, actual_input, mut tool_result)) = futs.next().await
890        {
891            // Sanitize tool result: strip blobs, enforce byte limit.
892            let max_result_bytes = options
893                .max_tool_result_bytes
894                .unwrap_or(sanitize::DEFAULT_MAX_TOOL_RESULT_BYTES);
895            tool_result.content =
896                sanitize::sanitize_tool_result(&tool_result.content, max_result_bytes);
897
898            // Run PostToolUse hooks
899            hook_registry
900                .run_post_tool_use(
901                    tool_name,
902                    actual_input,
903                    &serde_json::to_value(&tool_result.content).unwrap_or_default(),
904                    tool_use_id,
905                    &session_id,
906                    &cwd,
907                )
908                .await;
909
910            let result_content = tool_result
911                .raw_content
912                .unwrap_or_else(|| json!(tool_result.content));
913
914            let api_block = ApiContentBlock::ToolResult {
915                tool_use_id: tool_use_id.to_string(),
916                content: result_content,
917                is_error: if tool_result.is_error {
918                    Some(true)
919                } else {
920                    None
921                },
922                cache_control: None,
923                name: Some(tool_name.to_string()),
924            };
925
926            // Stream this individual result to the frontend immediately
927            let result_msg = Message::User(UserMessage {
928                uuid: Some(Uuid::new_v4()),
929                session_id: session_id.clone(),
930                content: vec![api_block_to_content_block(&api_block)],
931                parent_tool_use_id: None,
932                is_synthetic: true,
933                tool_use_result: None,
934            });
935            if options.persist_session {
936                let _ = session
937                    .append_message(&serde_json::to_value(&result_msg).unwrap_or_default())
938                    .await;
939            }
940            if tx.send(Ok(result_msg)).is_err() {
941                return Ok(());
942            }
943
944            tool_results.push(api_block);
945        }
946
947        // Add all tool results to conversation for the next API call
948        conversation.push(ApiMessage {
949            role: "user".to_string(),
950            content: tool_results,
951        });
952
953        // --- Lightweight pruning (between turns, before full compaction) ---
954        if let Some(context_budget) = options.context_budget {
955            let prune_pct = options
956                .prune_threshold_pct
957                .unwrap_or(compact::DEFAULT_PRUNE_THRESHOLD_PCT);
958            if compact::should_prune(response.usage.input_tokens, context_budget, prune_pct) {
959                let max_chars = options
960                    .prune_tool_result_max_chars
961                    .unwrap_or(compact::DEFAULT_PRUNE_TOOL_RESULT_MAX_CHARS);
962                let min_keep = options.min_keep_messages.unwrap_or(4);
963                let removed = compact::prune_tool_results(&mut conversation, max_chars, min_keep);
964                if removed > 0 {
965                    debug!(
966                        chars_removed = removed,
967                        input_tokens = response.usage.input_tokens,
968                        "Pruned oversized tool results to free context space"
969                    );
970                }
971            }
972        }
973
974        // --- Compaction check (between turns) ---
975        if let Some(context_budget) = options.context_budget {
976            if compact::should_compact(response.usage.input_tokens, context_budget) {
977                let min_keep = options.min_keep_messages.unwrap_or(4);
978                let split_point = compact::find_split_point(&conversation, min_keep);
979                if split_point > 0 {
980                    debug!(
981                        input_tokens = response.usage.input_tokens,
982                        context_budget,
983                        split_point,
984                        "Context budget exceeded, compacting conversation"
985                    );
986
987                    let compaction_model = options
988                        .compaction_model
989                        .as_deref()
990                        .unwrap_or(compact::DEFAULT_COMPACTION_MODEL);
991
992                    // Fire pre-compact handler so the host can persist key facts
993                    if let Some(ref handler) = options.pre_compact_handler {
994                        let msgs_to_compact = conversation[..split_point].to_vec();
995                        handler(msgs_to_compact).await;
996                    }
997
998                    let summary_prompt =
999                        compact::build_summary_prompt(&conversation[..split_point]);
1000
1001                    let summary_max_tokens = options.summary_max_tokens.unwrap_or(4096);
1002                    let compact_provider: &dyn LlmProvider = match &options.compaction_provider {
1003                        Some(cp) => cp.as_ref(),
1004                        None => provider.as_ref(),
1005                    };
1006                    let fallback_provider: Option<&dyn LlmProvider> =
1007                        if options.compaction_provider.is_some() {
1008                            Some(provider.as_ref())
1009                        } else {
1010                            None
1011                        };
1012                    match compact::call_summarizer(
1013                        compact_provider,
1014                        &summary_prompt,
1015                        compaction_model,
1016                        fallback_provider,
1017                        &model,
1018                        summary_max_tokens,
1019                    )
1020                    .await
1021                    {
1022                        Ok(summary) => {
1023                            let pre_tokens = response.usage.input_tokens;
1024                            let messages_compacted = split_point;
1025
1026                            compact::splice_conversation(&mut conversation, split_point, &summary);
1027
1028                            // Emit CompactBoundary system message
1029                            let compact_msg = Message::System(SystemMessage {
1030                                subtype: SystemSubtype::CompactBoundary,
1031                                uuid: Uuid::new_v4(),
1032                                session_id: session_id.clone(),
1033                                agents: None,
1034                                claude_code_version: None,
1035                                cwd: None,
1036                                tools: None,
1037                                mcp_servers: None,
1038                                model: None,
1039                                permission_mode: None,
1040                                compact_metadata: Some(CompactMetadata {
1041                                    trigger: CompactTrigger::Auto,
1042                                    pre_tokens,
1043                                }),
1044                            });
1045
1046                            if options.persist_session {
1047                                let _ = session
1048                                    .append_message(
1049                                        &serde_json::to_value(&compact_msg).unwrap_or_default(),
1050                                    )
1051                                    .await;
1052                            }
1053                            let _ = tx.send(Ok(compact_msg));
1054
1055                            debug!(
1056                                pre_tokens,
1057                                messages_compacted,
1058                                summary_len = summary.len(),
1059                                "Conversation compacted"
1060                            );
1061                        }
1062                        Err(e) => {
1063                            warn!("Compaction failed, continuing without compaction: {}", e);
1064                        }
1065                    }
1066                }
1067            }
1068        }
1069    }
1070}
1071
/// Consume a streaming response, emitting `Message::StreamEvent` for each text
/// delta, and accumulate the full `MessageResponse` for the agent loop.
///
/// Only text deltas are forwarded to `tx` as per-token stream events;
/// tool-use JSON fragments and thinking deltas are accumulated silently and
/// surface in the returned `MessageResponse` content blocks.
///
/// # Errors
///
/// Returns `AgentError::Cancelled` when the receiver side of `tx` has been
/// dropped, `AgentError::Api` when the provider emits an explicit error
/// event, or any error produced by the underlying event stream itself.
async fn accumulate_stream(
    event_stream: &mut std::pin::Pin<
        Box<dyn futures::Stream<Item = Result<ClientStreamEvent>> + Send>,
    >,
    tx: &mpsc::UnboundedSender<Result<Message>>,
    session_id: &str,
) -> Result<MessageResponse> {
    use crate::client::StreamEvent as SE;

    // Accumulated message-level state, filled by MessageStart / MessageDelta.
    let mut message_id = String::new();
    let mut model = String::new();
    let mut role = String::from("assistant");
    let mut content_blocks: Vec<ApiContentBlock> = Vec::new();
    let mut stop_reason: Option<String> = None;
    let mut usage = ApiUsage::default();

    // Track in-progress content blocks by index, using parallel vectors.
    // For text blocks: accumulate text. For tool_use: accumulate JSON string.
    let mut block_texts: Vec<String> = Vec::new();
    let mut block_types: Vec<String> = Vec::new(); // "text", "tool_use", "thinking"
    let mut block_tool_ids: Vec<String> = Vec::new();
    let mut block_tool_names: Vec<String> = Vec::new();

    while let Some(event_result) = FuturesStreamExt::next(event_stream).await {
        let event = event_result?;
        match event {
            SE::MessageStart { message } => {
                // Capture message identity and the initial usage snapshot.
                message_id = message.id;
                model = message.model;
                role = message.role;
                usage = message.usage;
            }
            SE::ContentBlockStart {
                index,
                content_block,
            } => {
                // Ensure vectors are large enough (indices may skip ahead).
                while block_texts.len() <= index {
                    block_texts.push(String::new());
                    block_types.push(String::new());
                    block_tool_ids.push(String::new());
                    block_tool_names.push(String::new());
                }
                match &content_block {
                    ApiContentBlock::Text { .. } => {
                        block_types[index] = "text".to_string();
                    }
                    ApiContentBlock::ToolUse { id, name, input } => {
                        block_types[index] = "tool_use".to_string();
                        block_tool_ids[index] = id.clone();
                        block_tool_names[index] = name.clone();
                        // OpenAI/Ollama streaming delivers the complete input
                        // in ContentBlockStart (not via InputJsonDelta like
                        // Anthropic). Store it so ContentBlockStop can parse it.
                        let input_str = input.to_string();
                        if input_str != "{}" {
                            block_texts[index] = input_str;
                        }
                    }
                    ApiContentBlock::Thinking { .. } => {
                        block_types[index] = "thinking".to_string();
                    }
                    // Other block kinds carry no streamed payload to track.
                    _ => {}
                }
            }
            SE::ContentBlockDelta { index, delta } => {
                // Same defensive growth as ContentBlockStart, in case a delta
                // arrives for an index we have not seen a start event for.
                while block_texts.len() <= index {
                    block_texts.push(String::new());
                    block_types.push(String::new());
                    block_tool_ids.push(String::new());
                    block_tool_names.push(String::new());
                }
                match &delta {
                    ContentDelta::TextDelta { text } => {
                        block_texts[index].push_str(text);
                        // Emit streaming event so downstream consumers get per-token updates
                        let stream_event = Message::StreamEvent(StreamEventMessage {
                            event: serde_json::json!({
                                "type": "content_block_delta",
                                "index": index,
                                "delta": { "type": "text_delta", "text": text }
                            }),
                            parent_tool_use_id: None,
                            uuid: Uuid::new_v4(),
                            session_id: session_id.to_string(),
                        });
                        // Receiver gone means the query was dropped/interrupted.
                        if tx.send(Ok(stream_event)).is_err() {
                            return Err(AgentError::Cancelled);
                        }
                    }
                    ContentDelta::InputJsonDelta { partial_json } => {
                        // Tool input arrives as JSON fragments; parsed at stop.
                        block_texts[index].push_str(partial_json);
                    }
                    ContentDelta::ThinkingDelta { thinking } => {
                        block_texts[index].push_str(thinking);
                    }
                }
            }
            SE::ContentBlockStop { index } => {
                if index < block_types.len() {
                    // Finalize the accumulated buffers into a content block.
                    let block = match block_types[index].as_str() {
                        "text" => ApiContentBlock::Text {
                            text: std::mem::take(&mut block_texts[index]),
                            cache_control: None,
                        },
                        "tool_use" => {
                            // Malformed/empty JSON degrades to an empty object
                            // rather than failing the whole stream.
                            let input: serde_json::Value =
                                serde_json::from_str(&block_texts[index])
                                    .unwrap_or(serde_json::Value::Object(Default::default()));
                            ApiContentBlock::ToolUse {
                                id: std::mem::take(&mut block_tool_ids[index]),
                                name: std::mem::take(&mut block_tool_names[index]),
                                input,
                            }
                        }
                        "thinking" => ApiContentBlock::Thinking {
                            thinking: std::mem::take(&mut block_texts[index]),
                        },
                        // Unknown/untracked block type: nothing to finalize.
                        _ => continue,
                    };
                    // Place blocks at the correct index, padding any gaps with
                    // empty text blocks so ordering is preserved.
                    while content_blocks.len() <= index {
                        content_blocks.push(ApiContentBlock::Text {
                            text: String::new(),
                            cache_control: None,
                        });
                    }
                    content_blocks[index] = block;
                }
            }
            SE::MessageDelta {
                delta,
                usage: delta_usage,
            } => {
                stop_reason = delta.stop_reason;
                // MessageDelta carries output_tokens for the whole message
                usage.output_tokens = delta_usage.output_tokens;
            }
            SE::MessageStop => {
                break;
            }
            SE::Error { error } => {
                return Err(AgentError::Api(error.message));
            }
            SE::Ping => {}
        }
    }

    // Assemble the final response for the agent loop to inspect.
    Ok(MessageResponse {
        id: message_id,
        role,
        content: content_blocks,
        model,
        stop_reason,
        usage,
    })
}
1232
1233/// Apply a single cache breakpoint to the last content block of the last user
1234/// message in the conversation. Clears any previous breakpoints from messages
1235/// so we stay within the API limit of 4 cache_control blocks (system + tools +
1236/// this one = 3 total).
1237fn apply_cache_breakpoint(conversation: &mut [ApiMessage]) {
1238    // First, clear all existing cache_control from messages
1239    for msg in conversation.iter_mut() {
1240        for block in msg.content.iter_mut() {
1241            match block {
1242                ApiContentBlock::Text { cache_control, .. }
1243                | ApiContentBlock::ToolResult { cache_control, .. } => {
1244                    *cache_control = None;
1245                }
1246                ApiContentBlock::Image { .. }
1247                | ApiContentBlock::ToolUse { .. }
1248                | ApiContentBlock::Thinking { .. } => {}
1249            }
1250        }
1251    }
1252
1253    // Set cache_control on the last content block of the last user message
1254    if let Some(last_user) = conversation.iter_mut().rev().find(|m| m.role == "user") {
1255        if let Some(last_block) = last_user.content.last_mut() {
1256            match last_block {
1257                ApiContentBlock::Text { cache_control, .. }
1258                | ApiContentBlock::ToolResult { cache_control, .. } => {
1259                    *cache_control = Some(CacheControl::ephemeral());
1260                }
1261                ApiContentBlock::Image { .. }
1262                | ApiContentBlock::ToolUse { .. }
1263                | ApiContentBlock::Thinking { .. } => {}
1264            }
1265        }
1266    }
1267}
1268
1269/// Convert an API content block to our ContentBlock type.
1270fn api_block_to_content_block(block: &ApiContentBlock) -> ContentBlock {
1271    match block {
1272        ApiContentBlock::Text { text, .. } => ContentBlock::Text { text: text.clone() },
1273        ApiContentBlock::Image { .. } => ContentBlock::Text {
1274            text: "[image]".to_string(),
1275        },
1276        ApiContentBlock::ToolUse { id, name, input } => ContentBlock::ToolUse {
1277            id: id.clone(),
1278            name: name.clone(),
1279            input: input.clone(),
1280        },
1281        ApiContentBlock::ToolResult {
1282            tool_use_id,
1283            content,
1284            is_error,
1285            ..
1286        } => ContentBlock::ToolResult {
1287            tool_use_id: tool_use_id.clone(),
1288            content: content.clone(),
1289            is_error: *is_error,
1290        },
1291        ApiContentBlock::Thinking { thinking } => ContentBlock::Thinking {
1292            thinking: thinking.clone(),
1293        },
1294    }
1295}
1296
1297/// Try to convert a stored JSON value to an API message.
1298fn value_to_api_message(value: &serde_json::Value) -> Option<ApiMessage> {
1299    let msg_type = value.get("type")?.as_str()?;
1300
1301    match msg_type {
1302        "assistant" => {
1303            let content = value.get("content")?;
1304            let blocks = parse_content_blocks(content)?;
1305            Some(ApiMessage {
1306                role: "assistant".to_string(),
1307                content: blocks,
1308            })
1309        }
1310        "user" => {
1311            let content = value.get("content")?;
1312            let blocks = parse_content_blocks(content)?;
1313            Some(ApiMessage {
1314                role: "user".to_string(),
1315                content: blocks,
1316            })
1317        }
1318        _ => None,
1319    }
1320}
1321
1322/// Parse content blocks from a JSON value.
1323fn parse_content_blocks(content: &serde_json::Value) -> Option<Vec<ApiContentBlock>> {
1324    if let Some(text) = content.as_str() {
1325        return Some(vec![ApiContentBlock::Text {
1326            text: text.to_string(),
1327            cache_control: None,
1328        }]);
1329    }
1330
1331    if let Some(blocks) = content.as_array() {
1332        let parsed: Vec<ApiContentBlock> = blocks
1333            .iter()
1334            .filter_map(|b| serde_json::from_value(b.clone()).ok())
1335            .collect();
1336        if !parsed.is_empty() {
1337            return Some(parsed);
1338        }
1339    }
1340
1341    None
1342}
1343
1344/// Build a ResultMessage.
1345#[allow(clippy::too_many_arguments)]
1346fn build_result_message(
1347    subtype: ResultSubtype,
1348    session_id: &str,
1349    result_text: Option<String>,
1350    start_time: Instant,
1351    api_time_ms: u64,
1352    num_turns: u32,
1353    total_cost: f64,
1354    usage: &Usage,
1355    model_usage: &HashMap<String, ModelUsage>,
1356    permission_denials: &[PermissionDenial],
1357) -> Message {
1358    Message::Result(ResultMessage {
1359        subtype,
1360        uuid: Uuid::new_v4(),
1361        session_id: session_id.to_string(),
1362        duration_ms: start_time.elapsed().as_millis() as u64,
1363        duration_api_ms: api_time_ms,
1364        is_error: result_text.is_none(),
1365        num_turns,
1366        result: result_text,
1367        stop_reason: Some("end_turn".to_string()),
1368        total_cost_usd: total_cost,
1369        usage: Some(usage.clone()),
1370        model_usage: model_usage.clone(),
1371        permission_denials: permission_denials.to_vec(),
1372        structured_output: None,
1373        errors: Vec::new(),
1374    })
1375}
1376
1377/// Build an error ResultMessage.
1378#[allow(clippy::too_many_arguments)]
1379fn build_error_result_message(
1380    session_id: &str,
1381    error_msg: &str,
1382    start_time: Instant,
1383    api_time_ms: u64,
1384    num_turns: u32,
1385    total_cost: f64,
1386    usage: &Usage,
1387    model_usage: &HashMap<String, ModelUsage>,
1388    permission_denials: &[PermissionDenial],
1389) -> Message {
1390    Message::Result(ResultMessage {
1391        subtype: ResultSubtype::ErrorDuringExecution,
1392        uuid: Uuid::new_v4(),
1393        session_id: session_id.to_string(),
1394        duration_ms: start_time.elapsed().as_millis() as u64,
1395        duration_api_ms: api_time_ms,
1396        is_error: true,
1397        num_turns,
1398        result: None,
1399        stop_reason: None,
1400        total_cost_usd: total_cost,
1401        usage: Some(usage.clone()),
1402        model_usage: model_usage.clone(),
1403        permission_denials: permission_denials.to_vec(),
1404        structured_output: None,
1405        errors: vec![error_msg.to_string()],
1406    })
1407}
1408
1409#[cfg(test)]
1410mod tests {
1411    use super::*;
1412    use std::sync::atomic::{AtomicUsize, Ordering};
1413    use std::sync::Arc;
1414    use std::time::Duration;
1415
    /// Helper: execute tools concurrently using the same FuturesUnordered pattern
    /// as the production code, collecting (tool_use_id, content, completion_order).
    ///
    /// `handler` returning `None` stands in for "no external handler matched";
    /// such tools produce the literal content `"no handler"`. The third tuple
    /// element is a global completion sequence number (0 = finished first),
    /// letting tests assert relative completion order.
    async fn run_concurrent_tools(
        tools: Vec<(String, String, serde_json::Value)>,
        handler: impl Fn(
            String,
            serde_json::Value,
        ) -> Pin<Box<dyn futures::Future<Output = Option<ToolResult>> + Send>>,
    ) -> Vec<(String, String, usize)> {
        // Shared counter stamps each completion with its finish order.
        let order = Arc::new(AtomicUsize::new(0));
        let handler = Arc::new(handler);

        // Mirrors the production PermittedTool struct so this helper exercises
        // the same shape of code as the real agent loop.
        struct PermittedTool {
            tool_use_id: String,
            tool_name: String,
            actual_input: serde_json::Value,
        }

        let permitted: Vec<PermittedTool> = tools
            .into_iter()
            .map(|(id, name, input)| PermittedTool {
                tool_use_id: id,
                tool_name: name,
                actual_input: input,
            })
            .collect();

        // Build all futures up front; FuturesUnordered polls them concurrently
        // and yields each as soon as it completes, regardless of start order.
        let mut futs: FuturesUnordered<_> = permitted
            .iter()
            .map(|pt| {
                let handler = handler.clone();
                let order = order.clone();
                let name = pt.tool_name.clone();
                let input = pt.actual_input.clone();
                let id = pt.tool_use_id.clone();
                async move {
                    let result = handler(name, input).await;
                    // Stamp the order at completion time, not submission time.
                    let seq = order.fetch_add(1, Ordering::SeqCst);
                    (id, result, seq)
                }
            })
            .collect();

        // Drain results in completion order, like the production loop does.
        let mut results = Vec::new();
        while let Some((id, result, seq)) = futs.next().await {
            let content = result
                .map(|r| r.content)
                .unwrap_or_else(|| "no handler".into());
            results.push((id, content, seq));
        }
        results
    }
1468
1469    #[tokio::test]
1470    async fn concurrent_tools_all_complete() {
1471        let results = run_concurrent_tools(
1472            vec![
1473                ("t1".into(), "Read".into(), json!({"path": "a.txt"})),
1474                ("t2".into(), "Read".into(), json!({"path": "b.txt"})),
1475                ("t3".into(), "Read".into(), json!({"path": "c.txt"})),
1476            ],
1477            |name, input| {
1478                Box::pin(async move {
1479                    let path = input["path"].as_str().unwrap_or("?");
1480                    Some(ToolResult {
1481                        content: format!("{}: {}", name, path),
1482                        is_error: false,
1483                        raw_content: None,
1484                    })
1485                })
1486            },
1487        )
1488        .await;
1489
1490        assert_eq!(results.len(), 3);
1491        let ids: Vec<&str> = results.iter().map(|(id, _, _)| id.as_str()).collect();
1492        assert!(ids.contains(&"t1"));
1493        assert!(ids.contains(&"t2"));
1494        assert!(ids.contains(&"t3"));
1495    }
1496
1497    #[tokio::test]
1498    async fn slow_tool_does_not_block_fast_tools() {
1499        let start = Instant::now();
1500
1501        let results = run_concurrent_tools(
1502            vec![
1503                ("slow".into(), "Bash".into(), json!({})),
1504                ("fast1".into(), "Read".into(), json!({})),
1505                ("fast2".into(), "Read".into(), json!({})),
1506            ],
1507            |name, _input| {
1508                Box::pin(async move {
1509                    if name == "Bash" {
1510                        tokio::time::sleep(Duration::from_millis(200)).await;
1511                        Some(ToolResult {
1512                            content: "slow done".into(),
1513                            is_error: false,
1514                            raw_content: None,
1515                        })
1516                    } else {
1517                        // Fast tools complete immediately
1518                        Some(ToolResult {
1519                            content: "fast done".into(),
1520                            is_error: false,
1521                            raw_content: None,
1522                        })
1523                    }
1524                })
1525            },
1526        )
1527        .await;
1528
1529        let elapsed = start.elapsed();
1530
1531        // All three should complete
1532        assert_eq!(results.len(), 3);
1533
1534        // Fast tools should complete before the slow tool (lower order index)
1535        let slow = results.iter().find(|(id, _, _)| id == "slow").unwrap();
1536        let fast1 = results.iter().find(|(id, _, _)| id == "fast1").unwrap();
1537        let fast2 = results.iter().find(|(id, _, _)| id == "fast2").unwrap();
1538
1539        assert!(fast1.2 < slow.2, "fast1 should complete before slow");
1540        assert!(fast2.2 < slow.2, "fast2 should complete before slow");
1541
1542        // Total time should be ~200ms (concurrent), not ~400ms+ (sequential)
1543        assert!(
1544            elapsed < Duration::from_millis(400),
1545            "elapsed {:?} should be under 400ms (concurrent execution)",
1546            elapsed
1547        );
1548    }
1549
1550    #[tokio::test]
1551    async fn results_streamed_individually_as_they_complete() {
1552        // Simulate the streaming pattern from the production code:
1553        // each tool result is sent to the channel as it completes.
1554        let (tx, mut rx) = mpsc::unbounded_channel::<(String, String)>();
1555
1556        let tools = vec![
1557            ("t_slow".into(), "Slow".into(), json!({})),
1558            ("t_fast".into(), "Fast".into(), json!({})),
1559        ];
1560
1561        struct PT {
1562            tool_use_id: String,
1563            tool_name: String,
1564        }
1565
1566        let permitted: Vec<PT> = tools
1567            .into_iter()
1568            .map(|(id, name, _)| PT {
1569                tool_use_id: id,
1570                tool_name: name,
1571            })
1572            .collect();
1573
1574        let mut futs: FuturesUnordered<_> = permitted
1575            .iter()
1576            .map(|pt| {
1577                let name = pt.tool_name.clone();
1578                let id = pt.tool_use_id.clone();
1579                async move {
1580                    if name == "Slow" {
1581                        tokio::time::sleep(Duration::from_millis(100)).await;
1582                    }
1583                    let result = ToolResult {
1584                        content: format!("{} result", name),
1585                        is_error: false,
1586                        raw_content: None,
1587                    };
1588                    (id, result)
1589                }
1590            })
1591            .collect();
1592
1593        // Process results as they complete (like production code)
1594        while let Some((id, result)) = futs.next().await {
1595            tx.send((id, result.content)).unwrap();
1596        }
1597        drop(tx);
1598
1599        // Collect what was streamed
1600        let mut streamed = Vec::new();
1601        while let Some(item) = rx.recv().await {
1602            streamed.push(item);
1603        }
1604
1605        assert_eq!(streamed.len(), 2);
1606        // Fast should arrive first
1607        assert_eq!(streamed[0].0, "t_fast");
1608        assert_eq!(streamed[0].1, "Fast result");
1609        assert_eq!(streamed[1].0, "t_slow");
1610        assert_eq!(streamed[1].1, "Slow result");
1611    }
1612
1613    #[tokio::test]
1614    async fn error_tool_does_not_prevent_other_tools() {
1615        let results = run_concurrent_tools(
1616            vec![
1617                ("t_ok".into(), "Read".into(), json!({})),
1618                ("t_err".into(), "Fail".into(), json!({})),
1619            ],
1620            |name, _input| {
1621                Box::pin(async move {
1622                    if name == "Fail" {
1623                        Some(ToolResult {
1624                            content: "something went wrong".into(),
1625                            is_error: true,
1626                            raw_content: None,
1627                        })
1628                    } else {
1629                        Some(ToolResult {
1630                            content: "ok".into(),
1631                            is_error: false,
1632                            raw_content: None,
1633                        })
1634                    }
1635                })
1636            },
1637        )
1638        .await;
1639
1640        assert_eq!(results.len(), 2);
1641        let ok = results.iter().find(|(id, _, _)| id == "t_ok").unwrap();
1642        let err = results.iter().find(|(id, _, _)| id == "t_err").unwrap();
1643        assert_eq!(ok.1, "ok");
1644        assert_eq!(err.1, "something went wrong");
1645    }
1646
1647    #[tokio::test]
1648    async fn external_handler_none_falls_through_correctly() {
1649        // When handler returns None for a tool, the production code falls through
1650        // to the built-in executor. Test that the pattern works.
1651        let results = run_concurrent_tools(
1652            vec![
1653                ("t_custom".into(), "MyTool".into(), json!({"x": 1})),
1654                ("t_builtin".into(), "Read".into(), json!({"path": "/tmp"})),
1655            ],
1656            |name, _input| {
1657                Box::pin(async move {
1658                    if name == "MyTool" {
1659                        Some(ToolResult {
1660                            content: "custom handled".into(),
1661                            is_error: false,
1662                            raw_content: None,
1663                        })
1664                    } else {
1665                        // Returns None => would fall through to built-in executor
1666                        None
1667                    }
1668                })
1669            },
1670        )
1671        .await;
1672
1673        assert_eq!(results.len(), 2);
1674        let custom = results.iter().find(|(id, _, _)| id == "t_custom").unwrap();
1675        let builtin = results.iter().find(|(id, _, _)| id == "t_builtin").unwrap();
1676        assert_eq!(custom.1, "custom handled");
1677        assert_eq!(builtin.1, "no handler"); // our test helper treats None as "no handler"
1678    }
1679
1680    #[tokio::test]
1681    async fn single_tool_works_same_as_before() {
1682        let results = run_concurrent_tools(
1683            vec![("t1".into(), "Read".into(), json!({"path": "file.txt"}))],
1684            |_name, _input| {
1685                Box::pin(async move {
1686                    Some(ToolResult {
1687                        content: "file contents".into(),
1688                        is_error: false,
1689                        raw_content: None,
1690                    })
1691                })
1692            },
1693        )
1694        .await;
1695
1696        assert_eq!(results.len(), 1);
1697        assert_eq!(results[0].0, "t1");
1698        assert_eq!(results[0].1, "file contents");
1699        assert_eq!(results[0].2, 0); // first (and only) completion
1700    }
1701
1702    #[tokio::test]
1703    async fn empty_tool_list_produces_no_results() {
1704        let results =
1705            run_concurrent_tools(vec![], |_name, _input| Box::pin(async move { None })).await;
1706
1707        assert_eq!(results.len(), 0);
1708    }
1709
1710    #[tokio::test]
1711    async fn tool_use_ids_preserved_through_concurrent_execution() {
1712        let results = run_concurrent_tools(
1713            vec![
1714                ("toolu_abc123".into(), "Read".into(), json!({})),
1715                ("toolu_def456".into(), "Write".into(), json!({})),
1716                ("toolu_ghi789".into(), "Bash".into(), json!({})),
1717            ],
1718            |name, _input| {
1719                Box::pin(async move {
1720                    // Add varying delays to shuffle completion order
1721                    match name.as_str() {
1722                        "Read" => tokio::time::sleep(Duration::from_millis(30)).await,
1723                        "Write" => tokio::time::sleep(Duration::from_millis(10)).await,
1724                        _ => tokio::time::sleep(Duration::from_millis(50)).await,
1725                    }
1726                    Some(ToolResult {
1727                        content: format!("{} result", name),
1728                        is_error: false,
1729                        raw_content: None,
1730                    })
1731                })
1732            },
1733        )
1734        .await;
1735
1736        assert_eq!(results.len(), 3);
1737
1738        // Regardless of completion order, IDs must match their tools
1739        for (id, content, _) in &results {
1740            match id.as_str() {
1741                "toolu_abc123" => assert_eq!(content, "Read result"),
1742                "toolu_def456" => assert_eq!(content, "Write result"),
1743                "toolu_ghi789" => assert_eq!(content, "Bash result"),
1744                other => panic!("unexpected tool_use_id: {}", other),
1745            }
1746        }
1747    }
1748
1749    #[tokio::test]
1750    async fn concurrent_execution_timing_is_parallel() {
1751        // 5 tools each taking 50ms should complete in ~50ms total, not 250ms
1752        let tools: Vec<_> = (0..5)
1753            .map(|i| (format!("t{}", i), "Tool".into(), json!({})))
1754            .collect();
1755
1756        let start = Instant::now();
1757
1758        let results = run_concurrent_tools(tools, |_name, _input| {
1759            Box::pin(async move {
1760                tokio::time::sleep(Duration::from_millis(50)).await;
1761                Some(ToolResult {
1762                    content: "done".into(),
1763                    is_error: false,
1764                    raw_content: None,
1765                })
1766            })
1767        })
1768        .await;
1769
1770        let elapsed = start.elapsed();
1771
1772        assert_eq!(results.len(), 5);
1773        // Should complete in roughly 50ms, definitely under 200ms
1774        assert!(
1775            elapsed < Duration::from_millis(200),
1776            "5 x 50ms tools took {:?} — should be ~50ms if concurrent",
1777            elapsed
1778        );
1779    }
1780
1781    #[tokio::test]
1782    async fn api_block_to_content_block_preserves_tool_result_fields() {
1783        let block = ApiContentBlock::ToolResult {
1784            tool_use_id: "toolu_abc".into(),
1785            content: json!("result text"),
1786            is_error: Some(true),
1787            cache_control: None,
1788            name: None,
1789        };
1790
1791        let content = api_block_to_content_block(&block);
1792        match content {
1793            ContentBlock::ToolResult {
1794                tool_use_id,
1795                content,
1796                is_error,
1797            } => {
1798                assert_eq!(tool_use_id, "toolu_abc");
1799                assert_eq!(content, json!("result text"));
1800                assert_eq!(is_error, Some(true));
1801            }
1802            _ => panic!("expected ToolResult content block"),
1803        }
1804    }
1805
1806    #[tokio::test]
1807    async fn streamed_messages_each_contain_single_tool_result() {
1808        // Verify that the streaming pattern produces one User message per tool result
1809        let (tx, mut rx) = mpsc::unbounded_channel::<Result<Message>>();
1810        let session_id = "test-session".to_string();
1811
1812        // Simulate what the production code does
1813        let tool_ids = vec!["t1", "t2", "t3"];
1814        for id in &tool_ids {
1815            let api_block = ApiContentBlock::ToolResult {
1816                tool_use_id: id.to_string(),
1817                content: json!(format!("result for {}", id)),
1818                is_error: None,
1819                cache_control: None,
1820                name: None,
1821            };
1822
1823            let result_msg = Message::User(UserMessage {
1824                uuid: Some(Uuid::new_v4()),
1825                session_id: session_id.clone(),
1826                content: vec![api_block_to_content_block(&api_block)],
1827                parent_tool_use_id: None,
1828                is_synthetic: true,
1829                tool_use_result: None,
1830            });
1831            tx.send(Ok(result_msg)).unwrap();
1832        }
1833        drop(tx);
1834
1835        let mut messages = Vec::new();
1836        while let Some(Ok(msg)) = rx.recv().await {
1837            messages.push(msg);
1838        }
1839
1840        assert_eq!(messages.len(), 3, "should have 3 individual messages");
1841
1842        for (i, msg) in messages.iter().enumerate() {
1843            if let Message::User(user) = msg {
1844                assert_eq!(
1845                    user.content.len(),
1846                    1,
1847                    "each message should have exactly 1 content block"
1848                );
1849                assert!(user.is_synthetic);
1850                if let ContentBlock::ToolResult { tool_use_id, .. } = &user.content[0] {
1851                    assert_eq!(tool_use_id, tool_ids[i]);
1852                } else {
1853                    panic!("expected ToolResult block");
1854                }
1855            } else {
1856                panic!("expected User message");
1857            }
1858        }
1859    }
1860
    /// Feeds a canonical Anthropic SSE sequence (message_start → one text
    /// block in three deltas → message_delta → message_stop) through
    /// `accumulate_stream` and checks two things: the accumulated
    /// `MessageResponse` (concatenated text, final stop_reason, merged usage)
    /// and the per-delta `Message::StreamEvent`s emitted on the channel.
    #[tokio::test]
    async fn accumulate_stream_emits_text_deltas_and_builds_response() {
        use crate::client::{
            ApiContentBlock, ApiUsage, ContentDelta, MessageResponse, StreamEvent as SE,
        };

        // Build a fake stream of SSE events.
        // NOTE: event ordering matters — it mirrors the wire protocol
        // (message_start, block start/deltas/stop, message_delta, message_stop).
        let events: Vec<Result<SE>> = vec![
            // Initial envelope: empty content, input-token usage only.
            Ok(SE::MessageStart {
                message: MessageResponse {
                    id: "msg_123".into(),
                    role: "assistant".into(),
                    content: vec![],
                    model: "claude-test".into(),
                    stop_reason: None,
                    usage: ApiUsage {
                        input_tokens: 100,
                        output_tokens: 0,
                        cache_creation_input_tokens: None,
                        cache_read_input_tokens: None,
                    },
                },
            }),
            // One text block assembled from three deltas: "Hello" + " world" + "!".
            Ok(SE::ContentBlockStart {
                index: 0,
                content_block: ApiContentBlock::Text {
                    text: String::new(),
                    cache_control: None,
                },
            }),
            Ok(SE::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::TextDelta {
                    text: "Hello".into(),
                },
            }),
            Ok(SE::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::TextDelta {
                    text: " world".into(),
                },
            }),
            Ok(SE::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::TextDelta { text: "!".into() },
            }),
            Ok(SE::ContentBlockStop { index: 0 }),
            // Trailing message_delta carries the stop reason and output tokens.
            Ok(SE::MessageDelta {
                delta: crate::client::MessageDelta {
                    stop_reason: Some("end_turn".into()),
                },
                usage: ApiUsage {
                    input_tokens: 0,
                    output_tokens: 15,
                    cache_creation_input_tokens: None,
                    cache_read_input_tokens: None,
                },
            }),
            Ok(SE::MessageStop),
        ];

        let stream = futures::stream::iter(events);
        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
            Box::pin(stream);

        let (tx, mut rx) = mpsc::unbounded_channel();

        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
            .await
            .expect("accumulate_stream should succeed");

        // Verify accumulated response: deltas concatenated into one text block,
        // stop_reason and output_tokens taken from the trailing MessageDelta.
        assert_eq!(response.id, "msg_123");
        assert_eq!(response.model, "claude-test");
        assert_eq!(response.stop_reason, Some("end_turn".into()));
        assert_eq!(response.usage.output_tokens, 15);
        assert_eq!(response.content.len(), 1);
        if let ApiContentBlock::Text { text, .. } = &response.content[0] {
            assert_eq!(text, "Hello world!");
        } else {
            panic!("expected Text content block");
        }

        // Verify 3 StreamEvent messages were emitted (one per text delta).
        // try_recv is safe here: everything was sent before accumulate_stream
        // returned, so the channel is already fully populated.
        let mut stream_events = Vec::new();
        while let Ok(msg) = rx.try_recv() {
            stream_events.push(msg.unwrap());
        }
        assert_eq!(stream_events.len(), 3);

        // Verify each is a StreamEvent with the correct text and session id.
        let expected_texts = ["Hello", " world", "!"];
        for (i, msg) in stream_events.iter().enumerate() {
            if let Message::StreamEvent(se) = msg {
                let delta = se.event.get("delta").unwrap();
                let text = delta.get("text").unwrap().as_str().unwrap();
                assert_eq!(text, expected_texts[i]);
                assert_eq!(se.session_id, "test-session");
            } else {
                panic!("expected StreamEvent message at index {}", i);
            }
        }
    }
1964
1965    #[tokio::test]
1966    async fn accumulate_stream_handles_tool_use() {
1967        use crate::client::{
1968            ApiContentBlock, ApiUsage, ContentDelta, MessageResponse, StreamEvent as SE,
1969        };
1970
1971        let events: Vec<Result<SE>> = vec![
1972            Ok(SE::MessageStart {
1973                message: MessageResponse {
1974                    id: "msg_456".into(),
1975                    role: "assistant".into(),
1976                    content: vec![],
1977                    model: "claude-test".into(),
1978                    stop_reason: None,
1979                    usage: ApiUsage::default(),
1980                },
1981            }),
1982            // Text block
1983            Ok(SE::ContentBlockStart {
1984                index: 0,
1985                content_block: ApiContentBlock::Text {
1986                    text: String::new(),
1987                    cache_control: None,
1988                },
1989            }),
1990            Ok(SE::ContentBlockDelta {
1991                index: 0,
1992                delta: ContentDelta::TextDelta {
1993                    text: "Let me check.".into(),
1994                },
1995            }),
1996            Ok(SE::ContentBlockStop { index: 0 }),
1997            // Tool use block
1998            Ok(SE::ContentBlockStart {
1999                index: 1,
2000                content_block: ApiContentBlock::ToolUse {
2001                    id: "toolu_abc".into(),
2002                    name: "Read".into(),
2003                    input: serde_json::json!({}),
2004                },
2005            }),
2006            Ok(SE::ContentBlockDelta {
2007                index: 1,
2008                delta: ContentDelta::InputJsonDelta {
2009                    partial_json: r#"{"path":"/tmp/f.txt"}"#.into(),
2010                },
2011            }),
2012            Ok(SE::ContentBlockStop { index: 1 }),
2013            Ok(SE::MessageDelta {
2014                delta: crate::client::MessageDelta {
2015                    stop_reason: Some("tool_use".into()),
2016                },
2017                usage: ApiUsage {
2018                    input_tokens: 0,
2019                    output_tokens: 20,
2020                    ..Default::default()
2021                },
2022            }),
2023            Ok(SE::MessageStop),
2024        ];
2025
2026        let stream = futures::stream::iter(events);
2027        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
2028            Box::pin(stream);
2029
2030        let (tx, _rx) = mpsc::unbounded_channel();
2031        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
2032            .await
2033            .expect("should succeed");
2034
2035        assert_eq!(response.content.len(), 2);
2036        if let ApiContentBlock::Text { text, .. } = &response.content[0] {
2037            assert_eq!(text, "Let me check.");
2038        } else {
2039            panic!("expected Text block at index 0");
2040        }
2041        if let ApiContentBlock::ToolUse { id, name, input } = &response.content[1] {
2042            assert_eq!(id, "toolu_abc");
2043            assert_eq!(name, "Read");
2044            assert_eq!(input["path"], "/tmp/f.txt");
2045        } else {
2046            panic!("expected ToolUse block at index 1");
2047        }
2048        assert_eq!(response.stop_reason, Some("tool_use".into()));
2049    }
2050
2051    /// OpenAI/Ollama streaming delivers the complete tool input inside
2052    /// `ContentBlockStart` (no `InputJsonDelta` follows). Verify that
2053    /// `accumulate_stream` preserves that input instead of defaulting to `{}`.
2054    #[tokio::test]
2055    async fn accumulate_stream_preserves_openai_tool_input() {
2056        use crate::client::{ApiContentBlock, ApiUsage, StreamEvent as SE};
2057
2058        let events: Vec<Result<SE>> = vec![
2059            Ok(SE::MessageStart {
2060                message: MessageResponse {
2061                    id: "msg_oai".into(),
2062                    role: "assistant".into(),
2063                    content: vec![],
2064                    model: "qwen3:8b".into(),
2065                    stop_reason: None,
2066                    usage: ApiUsage::default(),
2067                },
2068            }),
2069            // Tool use with full input in ContentBlockStart (OpenAI/Ollama pattern)
2070            Ok(SE::ContentBlockStart {
2071                index: 0,
2072                content_block: ApiContentBlock::ToolUse {
2073                    id: "call_123".into(),
2074                    name: "Bash".into(),
2075                    input: serde_json::json!({"command": "ls -la", "timeout": 5000}),
2076                },
2077            }),
2078            // No InputJsonDelta — OpenAI/Ollama doesn't send one
2079            Ok(SE::ContentBlockStop { index: 0 }),
2080            Ok(SE::MessageDelta {
2081                delta: crate::client::MessageDelta {
2082                    stop_reason: Some("tool_use".into()),
2083                },
2084                usage: ApiUsage {
2085                    input_tokens: 0,
2086                    output_tokens: 10,
2087                    ..Default::default()
2088                },
2089            }),
2090            Ok(SE::MessageStop),
2091        ];
2092
2093        let stream = futures::stream::iter(events);
2094        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
2095            Box::pin(stream);
2096
2097        let (tx, _rx) = mpsc::unbounded_channel();
2098        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
2099            .await
2100            .expect("should succeed");
2101
2102        assert_eq!(response.content.len(), 1);
2103        if let ApiContentBlock::ToolUse { id, name, input } = &response.content[0] {
2104            assert_eq!(id, "call_123");
2105            assert_eq!(name, "Bash");
2106            assert_eq!(input["command"], "ls -la");
2107            assert_eq!(input["timeout"], 5000);
2108        } else {
2109            panic!("expected ToolUse block");
2110        }
2111    }
2112}