1use std::collections::HashMap;
8use std::path::PathBuf;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Instant;
12
13use futures::stream::FuturesUnordered;
14use futures::{Stream, StreamExt as FuturesStreamExt};
15use serde_json::json;
16use tokio::sync::mpsc;
17use tokio_stream::wrappers::UnboundedReceiverStream;
18use tracing::{debug, error, warn};
19use uuid::Uuid;
20
21use crate::client::{
22 ApiContentBlock, ApiMessage, ApiUsage, CacheControl, ContentDelta, CreateMessageRequest,
23 ImageSource, MessageResponse, StreamEvent as ClientStreamEvent, SystemBlock, ThinkingParam,
24 ToolDefinition,
25};
26use crate::compact;
27use crate::error::{AgentError, Result};
28use crate::hooks::HookRegistry;
29use crate::options::{Options, PermissionMode, ThinkingConfig};
30use crate::permissions::{PermissionEvaluator, PermissionVerdict};
31use crate::provider::LlmProvider;
32use crate::providers::AnthropicProvider;
33use crate::sanitize;
34use crate::session::Session;
35use crate::tools::definitions::get_tool_definitions;
36use crate::tools::executor::{ToolExecutor, ToolResult};
37use crate::types::messages::*;
38
/// Model requested from the provider when `Options::model` is not set.
const DEFAULT_MODEL: &str = "claude-haiku-4-5";
/// `max_tokens` used when `Options::max_tokens` is not set; the agent loop may raise it
/// so an extended-thinking budget still leaves room for visible output.
const DEFAULT_MAX_TOKENS: u32 = 16384;
43
/// Handle to a running agent query created by [`query`].
///
/// Implements [`Stream`] over `Result<Message>`: every item is produced by a background
/// task that runs the agent loop and writes into an unbounded channel.
pub struct Query {
    /// Receiving end of the channel the background agent loop sends messages into.
    receiver: UnboundedReceiverStream<Result<Message>>,
    /// Session ID associated with this query, when known.
    session_id: Option<String>,
    /// Token shared with the background loop; cancelling it asks the loop to stop.
    cancel_token: tokio_util::sync::CancellationToken,
}
52
53impl Query {
54 pub async fn interrupt(&self) -> Result<()> {
56 self.cancel_token.cancel();
57 Ok(())
58 }
59
60 pub fn session_id(&self) -> Option<&str> {
62 self.session_id.as_deref()
63 }
64
65 pub async fn set_permission_mode(&self, _mode: PermissionMode) -> Result<()> {
67 Ok(())
69 }
70
71 pub async fn set_model(&self, _model: &str) -> Result<()> {
73 Ok(())
75 }
76
77 pub fn close(&self) {
79 self.cancel_token.cancel();
80 }
81}
82
83impl Stream for Query {
84 type Item = Result<Message>;
85
86 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
87 Pin::new(&mut self.receiver).poll_next(cx)
88 }
89}
90
91pub fn query(prompt: &str, options: Options) -> Query {
126 let (tx, rx) = mpsc::unbounded_channel();
127 let cancel_token = tokio_util::sync::CancellationToken::new();
128 let cancel = cancel_token.clone();
129
130 let prompt = prompt.to_string();
131
132 tokio::spawn(async move {
133 let result = run_agent_loop(prompt, options, tx.clone(), cancel).await;
134 if let Err(e) = result {
135 let _ = tx.send(Err(e));
136 }
137 });
138
139 Query {
140 receiver: UnboundedReceiverStream::new(rx),
141 session_id: None,
142 cancel_token,
143 }
144}
145
/// Runs the agent conversation for one [`query`] call, sending every produced [`Message`]
/// to `tx`.
///
/// Flow:
/// 1. Resolve the working directory, session (resume / continue / explicit id / new),
///    model and tool set, then emit a `SystemSubtype::Init` message.
/// 2. Build the provider, tool executor, hook registry, permission evaluator and system
///    prompt; when resuming or continuing, reload prior conversation from the session.
/// 3. Loop per turn:
///    - honour cancellation and the turn and budget limits;
///    - inject queued follow-up prompts;
///    - call the model, streaming when `include_partial_messages` is set;
///    - record usage and cost;
///    - if the model requested no tools, finish with a `Success` result;
///    - otherwise run the requested tools concurrently, feed their results back, and
///      prune/compact the context when a `context_budget` is configured.
///
/// Returns `Ok(())` after a result message is emitted or once the receiver is dropped.
///
/// # Errors
/// Returns `AgentError::Cancelled` when `cancel` fires between turns. Also propagates
/// failures from session lookup and loading, provider construction, and permission
/// evaluation. API and stream failures are reported as an error result message, not `Err`.
async fn run_agent_loop(
    prompt: String,
    mut options: Options,
    tx: mpsc::UnboundedSender<Result<Message>>,
    cancel: tokio_util::sync::CancellationToken,
) -> Result<()> {
    /// Output tokens reserved on top of an extended-thinking budget.
    const THINKING_HEADROOM_TOKENS: u32 = 8192;
    /// System prompt used when none is configured, and the base of the preset prompt.
    const BASE_SYSTEM_PROMPT: &str =
        "You are Claude, an AI assistant. You have access to tools to help accomplish tasks.";

    let start_time = Instant::now();
    let mut api_time_ms: u64 = 0;

    // Working directory: explicit option, else the process cwd, else ".".
    let cwd = options.cwd.clone().unwrap_or_else(|| {
        std::env::current_dir()
            .unwrap_or_else(|_| PathBuf::from("."))
            .to_string_lossy()
            .to_string()
    });

    // Session precedence:
    // 1. an explicit `resume` id;
    // 2. with `continue_session`, the most recent session for `cwd` (or a new one);
    // 3. an explicit `session_id`;
    // 4. otherwise a fresh session.
    let session = if let Some(ref resume_id) = options.resume {
        Session::with_id(resume_id, &cwd)
    } else if options.continue_session {
        match crate::session::find_most_recent_session(Some(&cwd)).await? {
            Some(info) => Session::with_id(&info.session_id, &cwd),
            None => Session::new(&cwd),
        }
    } else {
        match &options.session_id {
            Some(id) => Session::with_id(id, &cwd),
            None => Session::new(&cwd),
        }
    };

    let session_id = session.id.clone();
    let model = options
        .model
        .clone()
        .unwrap_or_else(|| DEFAULT_MODEL.to_string());

    // No tools are offered when a structured `output_format` is requested; an empty
    // allow-list selects the built-in default set.
    let tool_names: Vec<String> = if options.output_format.is_some() {
        Vec::new()
    } else if options.allowed_tools.is_empty() {
        vec![
            "Read".into(),
            "Write".into(),
            "Edit".into(),
            "Bash".into(),
            "Glob".into(),
            "Grep".into(),
        ]
    } else {
        options.allowed_tools.clone()
    };

    let raw_defs: Vec<_> = get_tool_definitions(&tool_names);
    let mut all_defs: Vec<ToolDefinition> = raw_defs
        .into_iter()
        .map(|td| ToolDefinition {
            name: td.name.to_string(),
            description: td.description.to_string(),
            input_schema: td.input_schema,
            cache_control: None,
        })
        .collect();
    all_defs.extend(
        options
            .custom_tool_definitions
            .iter()
            .map(|ctd| ToolDefinition {
                name: ctd.name.clone(),
                description: ctd.description.clone(),
                input_schema: ctd.input_schema.clone(),
                cache_control: None,
            }),
    );

    // Mark the last tool definition as a prompt-cache breakpoint so the tool list prefix
    // can be cached.
    if let Some(last) = all_defs.last_mut() {
        last.cache_control = Some(CacheControl::ephemeral());
    }
    let tool_defs = all_defs;

    // The tool set never changes during the run, so build the lookup set and the
    // "available tools" list used in unknown-tool errors once.
    let known_tool_names: std::collections::HashSet<&str> =
        tool_defs.iter().map(|td| td.name.as_str()).collect();
    let available_tools = tool_defs
        .iter()
        .map(|td| td.name.as_str())
        .collect::<Vec<_>>()
        .join(", ");

    let init_msg = Message::System(SystemMessage {
        subtype: SystemSubtype::Init,
        uuid: Uuid::new_v4(),
        session_id: session_id.clone(),
        agents: if options.agents.is_empty() {
            None
        } else {
            Some(options.agents.keys().cloned().collect())
        },
        claude_code_version: Some(env!("CARGO_PKG_VERSION").to_string()),
        cwd: Some(cwd.clone()),
        tools: Some(tool_names.clone()),
        mcp_servers: if options.mcp_servers.is_empty() {
            None
        } else {
            Some(
                options
                    .mcp_servers
                    .keys()
                    .map(|name| McpServerStatus {
                        name: name.clone(),
                        status: "connected".to_string(),
                    })
                    .collect(),
            )
        },
        model: Some(model.clone()),
        permission_mode: Some(options.permission_mode.to_string()),
        compact_metadata: None,
    });
    if !persist_and_send(&session, options.persist_session, &tx, init_msg).await {
        return Ok(());
    }

    let provider: Box<dyn LlmProvider> = match options.provider.take() {
        Some(p) => p,
        None => Box::new(AnthropicProvider::from_env()?),
    };

    let additional_dirs: Vec<PathBuf> = options
        .additional_directories
        .iter()
        .map(PathBuf::from)
        .collect();
    let env_blocklist = std::mem::take(&mut options.env_blocklist);
    let env_inject = std::mem::take(&mut options.env);
    #[cfg(unix)]
    let pre_exec_fn = options.pre_exec_fn.take();
    let mut tool_executor = if additional_dirs.is_empty() {
        ToolExecutor::new(PathBuf::from(&cwd))
    } else {
        ToolExecutor::with_allowed_dirs(PathBuf::from(&cwd), additional_dirs)
    }
    .with_env_blocklist(env_blocklist)
    .with_env_inject(env_inject);
    #[cfg(unix)]
    if let Some(f) = pre_exec_fn {
        tool_executor = tool_executor.with_pre_exec(f);
    }

    // Programmatic hooks first, then any discovered from hook directories. A discovery
    // failure is logged and otherwise ignored.
    let mut hook_registry = HookRegistry::from_map(std::mem::take(&mut options.hooks));
    if !options.hook_dirs.is_empty() {
        let dirs: Vec<&std::path::Path> = options.hook_dirs.iter().map(|p| p.as_path()).collect();
        match crate::hooks::HookDiscovery::discover(&dirs) {
            Ok(discovered) => hook_registry.merge(discovered),
            Err(e) => tracing::warn!("Failed to discover hooks from dirs: {}", e),
        }
    }

    let mut followup_rx = options.followup_rx.take();

    let permission_eval = PermissionEvaluator::new(&options);

    let system_text = match &options.system_prompt {
        Some(crate::options::SystemPrompt::Custom(s)) => s.clone(),
        Some(crate::options::SystemPrompt::Preset {
            append: Some(extra),
            ..
        }) => format!("{}\n\n{}", BASE_SYSTEM_PROMPT, extra),
        Some(crate::options::SystemPrompt::Preset { append: None, .. }) | None => {
            BASE_SYSTEM_PROMPT.to_string()
        }
    };
    let system_prompt: Option<Vec<SystemBlock>> = Some(vec![SystemBlock {
        kind: "text".to_string(),
        text: system_text,
        cache_control: Some(CacheControl::ephemeral()),
    }]);

    let mut conversation: Vec<ApiMessage> = Vec::new();

    // Replay the stored transcript when resuming/continuing; entries that do not map to an
    // API message are skipped.
    if options.resume.is_some() || options.continue_session {
        let prev_messages = session.load_messages().await?;
        for msg_value in prev_messages {
            if let Some(api_msg) = value_to_api_message(&msg_value) {
                conversation.push(api_msg);
            }
        }
    }

    // First user turn: supported image attachments, then the prompt text. Attachments of
    // any other MIME type are not sent.
    let mut first_turn: Vec<ApiContentBlock> = options
        .attachments
        .iter()
        .filter(|att| {
            matches!(
                att.mime_type.as_str(),
                "image/png" | "image/jpeg" | "image/gif" | "image/webp"
            )
        })
        .map(|att| ApiContentBlock::Image {
            source: ImageSource {
                kind: "base64".to_string(),
                media_type: att.mime_type.clone(),
                data: att.base64_data.clone(),
            },
        })
        .collect();
    first_turn.push(ApiContentBlock::Text {
        text: prompt.clone(),
        cache_control: None,
    });
    conversation.push(ApiMessage {
        role: "user".to_string(),
        content: first_turn,
    });

    if options.persist_session {
        let user_msg = json!({
            "type": "user",
            "uuid": Uuid::new_v4().to_string(),
            "session_id": &session_id,
            "content": [{"type": "text", "text": &prompt}]
        });
        let _ = session.append_message(&user_msg).await;
    }

    let mut num_turns: u32 = 0;
    let mut total_usage = Usage::default();
    let mut total_cost: f64 = 0.0;
    let mut model_usage: HashMap<String, ModelUsage> = HashMap::new();
    let mut permission_denials: Vec<PermissionDenial> = Vec::new();

    loop {
        // Cancellation is only observed here, between turns.
        if cancel.is_cancelled() {
            return Err(AgentError::Cancelled);
        }

        // The turn limit is checked before the budget limit.
        let limit_subtype = if options.max_turns.map_or(false, |max| num_turns >= max) {
            Some(ResultSubtype::ErrorMaxTurns)
        } else if options.max_budget_usd.map_or(false, |max| total_cost >= max) {
            Some(ResultSubtype::ErrorMaxBudgetUsd)
        } else {
            None
        };
        if let Some(subtype) = limit_subtype {
            let result_msg = build_result_message(
                subtype,
                &session_id,
                None,
                start_time,
                api_time_ms,
                num_turns,
                total_cost,
                &total_usage,
                &model_usage,
                &permission_denials,
            );
            let _ = tx.send(Ok(result_msg));
            return Ok(());
        }

        // Drain follow-up prompts queued since the last turn and inject them as one
        // user message.
        if let Some(ref mut followup_rx) = followup_rx {
            let mut followups: Vec<String> = Vec::new();
            while let Ok(msg) = followup_rx.try_recv() {
                followups.push(msg);
            }
            if !followups.is_empty() {
                let combined = followups.join("\n\n");
                debug!(
                    count = followups.len(),
                    "Injecting followup messages into agent loop"
                );

                conversation.push(ApiMessage {
                    role: "user".to_string(),
                    content: vec![ApiContentBlock::Text {
                        text: combined.clone(),
                        cache_control: None,
                    }],
                });

                let followup_msg = Message::User(UserMessage {
                    uuid: Some(Uuid::new_v4()),
                    session_id: session_id.clone(),
                    content: vec![ContentBlock::Text { text: combined }],
                    parent_tool_use_id: None,
                    is_synthetic: false,
                    tool_use_result: None,
                });
                if !persist_and_send(&session, options.persist_session, &tx, followup_msg).await
                {
                    return Ok(());
                }
            }
        }

        apply_cache_breakpoint(&mut conversation);

        // `Adaptive` is sent as an enabled thinking block with a fixed 10240-token budget.
        let thinking_param = options.thinking.as_ref().map(|tc| match tc {
            ThinkingConfig::Adaptive => ThinkingParam {
                kind: "enabled".into(),
                budget_tokens: Some(10240),
            },
            ThinkingConfig::Disabled => ThinkingParam {
                kind: "disabled".into(),
                budget_tokens: None,
            },
            ThinkingConfig::Enabled { budget_tokens } => ThinkingParam {
                kind: "enabled".into(),
                budget_tokens: Some(*budget_tokens),
            },
        });

        // With a thinking budget, make sure `max_tokens` leaves headroom above it. Use a
        // saturating add so a huge configured budget cannot overflow.
        let base_max_tokens = options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
        let max_tokens = match thinking_param.as_ref().and_then(|tp| tp.budget_tokens) {
            Some(budget) => {
                base_max_tokens.max((budget as u32).saturating_add(THINKING_HEADROOM_TOKENS))
            }
            None => base_max_tokens,
        };

        let use_streaming = options.include_partial_messages;
        let request = CreateMessageRequest {
            model: model.clone(),
            max_tokens,
            messages: conversation.clone(),
            system: system_prompt.clone(),
            tools: if tool_defs.is_empty() {
                None
            } else {
                Some(tool_defs.clone())
            },
            stream: use_streaming,
            metadata: None,
            thinking: thinking_param,
        };

        // Any provider/stream failure ends the run with an error *result message*.
        let api_start = Instant::now();
        let outcome = if use_streaming {
            match provider.create_message_stream(&request).await {
                Ok(mut event_stream) => accumulate_stream(&mut event_stream, &tx, &session_id)
                    .await
                    .map_err(|e| {
                        error!("Stream accumulation failed: {}", e);
                        format!("Stream error: {}", e)
                    }),
                Err(e) => {
                    error!("API stream call failed: {}", e);
                    Err(format!("API error: {}", e))
                }
            }
        } else {
            provider.create_message(&request).await.map_err(|e| {
                error!("API call failed: {}", e);
                format!("API error: {}", e)
            })
        };
        let response = match outcome {
            Ok(resp) => resp,
            Err(failure) => {
                let result_msg = build_error_result_message(
                    &session_id,
                    &failure,
                    start_time,
                    api_time_ms,
                    num_turns,
                    total_cost,
                    &total_usage,
                    &model_usage,
                    &permission_denials,
                );
                let _ = tx.send(Ok(result_msg));
                return Ok(());
            }
        };
        api_time_ms += api_start.elapsed().as_millis() as u64;

        total_usage.input_tokens += response.usage.input_tokens;
        total_usage.output_tokens += response.usage.output_tokens;
        total_usage.cache_creation_input_tokens +=
            response.usage.cache_creation_input_tokens.unwrap_or(0);
        total_usage.cache_read_input_tokens += response.usage.cache_read_input_tokens.unwrap_or(0);

        let rates = provider.cost_rates(&model);
        let turn_cost = rates.compute_with_cache(
            response.usage.input_tokens,
            response.usage.output_tokens,
            response.usage.cache_read_input_tokens.unwrap_or(0),
            response.usage.cache_creation_input_tokens.unwrap_or(0),
        );
        total_cost += turn_cost;

        let model_entry = model_usage.entry(model.clone()).or_default();
        model_entry.input_tokens += response.usage.input_tokens;
        model_entry.output_tokens += response.usage.output_tokens;
        model_entry.cost_usd += turn_cost;

        let content_blocks: Vec<ContentBlock> = response
            .content
            .iter()
            .map(api_block_to_content_block)
            .collect();

        let assistant_msg = Message::Assistant(AssistantMessage {
            uuid: Uuid::new_v4(),
            session_id: session_id.clone(),
            content: content_blocks.clone(),
            model: response.model.clone(),
            stop_reason: response.stop_reason.clone(),
            parent_tool_use_id: None,
            usage: Some(Usage {
                input_tokens: response.usage.input_tokens,
                output_tokens: response.usage.output_tokens,
                cache_creation_input_tokens: response
                    .usage
                    .cache_creation_input_tokens
                    .unwrap_or(0),
                cache_read_input_tokens: response.usage.cache_read_input_tokens.unwrap_or(0),
            }),
            error: None,
        });
        if !persist_and_send(&session, options.persist_session, &tx, assistant_msg).await {
            return Ok(());
        }

        conversation.push(ApiMessage {
            role: "assistant".to_string(),
            content: response.content.clone(),
        });

        let tool_uses: Vec<_> = response
            .content
            .iter()
            .filter_map(|block| match block {
                ApiContentBlock::ToolUse { id, name, input } => {
                    Some((id.clone(), name.clone(), input.clone()))
                }
                _ => None,
            })
            .collect();

        // No tool calls: the concatenated text of this response is the final answer.
        if tool_uses.is_empty() {
            let final_text = response
                .content
                .iter()
                .filter_map(|block| match block {
                    ApiContentBlock::Text { text, .. } => Some(text.as_str()),
                    _ => None,
                })
                .collect::<Vec<_>>()
                .join("");

            let result_msg = build_result_message(
                ResultSubtype::Success,
                &session_id,
                Some(final_text),
                start_time,
                api_time_ms,
                num_turns,
                total_cost,
                &total_usage,
                &model_usage,
                &permission_denials,
            );
            let _ = persist_and_send(&session, options.persist_session, &tx, result_msg).await;
            return Ok(());
        }

        num_turns += 1;
        let mut tool_results: Vec<ApiContentBlock> = Vec::new();

        // Unknown tools get an immediate error result instead of being executed.
        let mut valid_tool_uses: Vec<&(String, String, serde_json::Value)> = Vec::new();
        for tu in &tool_uses {
            let (tool_use_id, tool_name, _tool_input) = tu;
            if known_tool_names.contains(tool_name.as_str()) {
                valid_tool_uses.push(tu);
                continue;
            }

            warn!(tool = %tool_name, "model invoked unknown tool, returning error");
            let error_msg = format!(
                "Error: '{}' is not a valid tool. You MUST use one of the following tools: {}",
                tool_name, available_tools
            );
            let api_block = ApiContentBlock::ToolResult {
                tool_use_id: tool_use_id.clone(),
                content: json!(error_msg),
                is_error: Some(true),
                cache_control: None,
                name: Some(tool_name.clone()),
            };

            let result_msg = Message::User(UserMessage {
                uuid: Some(Uuid::new_v4()),
                session_id: session_id.clone(),
                content: vec![api_block_to_content_block(&api_block)],
                parent_tool_use_id: None,
                is_synthetic: true,
                tool_use_result: None,
            });
            if !persist_and_send(&session, options.persist_session, &tx, result_msg).await {
                return Ok(());
            }

            tool_results.push(api_block);
        }

        /// A tool call that passed permission checks, with its possibly rewritten input.
        struct PermittedTool {
            tool_use_id: String,
            tool_name: String,
            actual_input: serde_json::Value,
        }
        let mut permitted_tools: Vec<PermittedTool> = Vec::new();

        // Permission checks run sequentially; denials become error tool results.
        for (tool_use_id, tool_name, tool_input) in valid_tool_uses.iter().map(|t| &**t) {
            let verdict = permission_eval
                .evaluate(tool_name, tool_input, tool_use_id, &session_id, &cwd)
                .await?;

            let actual_input = match &verdict {
                PermissionVerdict::AllowWithUpdatedInput(new_input) => new_input.clone(),
                _ => tool_input.clone(),
            };

            match verdict {
                PermissionVerdict::Allow | PermissionVerdict::AllowWithUpdatedInput(_) => {
                    permitted_tools.push(PermittedTool {
                        tool_use_id: tool_use_id.clone(),
                        tool_name: tool_name.clone(),
                        actual_input,
                    });
                }
                PermissionVerdict::Deny { reason } => {
                    debug!(tool = %tool_name, reason = %reason, "Tool denied");
                    permission_denials.push(PermissionDenial {
                        tool_name: tool_name.clone(),
                        tool_use_id: tool_use_id.clone(),
                        tool_input: tool_input.clone(),
                    });

                    let api_block = ApiContentBlock::ToolResult {
                        tool_use_id: tool_use_id.clone(),
                        content: json!(format!("Permission denied: {}", reason)),
                        is_error: Some(true),
                        cache_control: None,
                        name: Some(tool_name.clone()),
                    };

                    let denial_msg = Message::User(UserMessage {
                        uuid: Some(Uuid::new_v4()),
                        session_id: session_id.clone(),
                        content: vec![api_block_to_content_block(&api_block)],
                        parent_tool_use_id: None,
                        is_synthetic: true,
                        tool_use_result: None,
                    });
                    if !persist_and_send(&session, options.persist_session, &tx, denial_msg).await
                    {
                        return Ok(());
                    }

                    tool_results.push(api_block);
                }
            }
        }

        // Permitted tools run concurrently; results are handled in completion order.
        let mut futs: FuturesUnordered<_> = permitted_tools
            .iter()
            .map(|pt| {
                let handler = &options.external_tool_handler;
                let executor = &tool_executor;
                let name = &pt.tool_name;
                let input = &pt.actual_input;
                let id = &pt.tool_use_id;
                async move {
                    debug!(tool = %name, "Executing tool");

                    // An external handler may claim the call; `None` from it (or no
                    // handler at all) falls back to the built-in executor.
                    let external = match handler {
                        Some(handler) => handler(name.clone(), input.clone()).await,
                        None => None,
                    };
                    let tool_result = match external {
                        Some(tr) => tr,
                        None => match executor.execute(name, input.clone()).await {
                            Ok(tr) => tr,
                            Err(e) => ToolResult {
                                content: format!("{}", e),
                                is_error: true,
                                raw_content: None,
                            },
                        },
                    };
                    (id.as_str(), name.as_str(), input, tool_result)
                }
            })
            .collect();

        while let Some((tool_use_id, tool_name, actual_input, mut tool_result)) = futs.next().await
        {
            let max_result_bytes = options
                .max_tool_result_bytes
                .unwrap_or(sanitize::DEFAULT_MAX_TOOL_RESULT_BYTES);
            tool_result.content =
                sanitize::sanitize_tool_result(&tool_result.content, max_result_bytes);

            hook_registry
                .run_post_tool_use(
                    tool_name,
                    actual_input,
                    &serde_json::to_value(&tool_result.content).unwrap_or_default(),
                    tool_use_id,
                    &session_id,
                    &cwd,
                )
                .await;

            // Prefer structured raw content when the tool produced it.
            let result_content = tool_result
                .raw_content
                .unwrap_or_else(|| json!(tool_result.content));

            let api_block = ApiContentBlock::ToolResult {
                tool_use_id: tool_use_id.to_string(),
                content: result_content,
                is_error: if tool_result.is_error {
                    Some(true)
                } else {
                    None
                },
                cache_control: None,
                name: Some(tool_name.to_string()),
            };

            let result_msg = Message::User(UserMessage {
                uuid: Some(Uuid::new_v4()),
                session_id: session_id.clone(),
                content: vec![api_block_to_content_block(&api_block)],
                parent_tool_use_id: None,
                is_synthetic: true,
                tool_use_result: None,
            });
            if !persist_and_send(&session, options.persist_session, &tx, result_msg).await {
                return Ok(());
            }

            tool_results.push(api_block);
        }

        conversation.push(ApiMessage {
            role: "user".to_string(),
            content: tool_results,
        });

        // Context management: first prune oversized tool results, then summarise older
        // turns when the budget is exceeded.
        if let Some(context_budget) = options.context_budget {
            // With prompt caching, `input_tokens` counts only the uncached part of the
            // prompt; cache reads and writes are reported separately (and are costed
            // separately above). The prompt actually occupying the context window is the
            // sum of all three.
            let context_tokens = response
                .usage
                .input_tokens
                .saturating_add(response.usage.cache_read_input_tokens.unwrap_or(0))
                .saturating_add(response.usage.cache_creation_input_tokens.unwrap_or(0));

            let prune_pct = options
                .prune_threshold_pct
                .unwrap_or(compact::DEFAULT_PRUNE_THRESHOLD_PCT);
            if compact::should_prune(context_tokens, context_budget, prune_pct) {
                let max_chars = options
                    .prune_tool_result_max_chars
                    .unwrap_or(compact::DEFAULT_PRUNE_TOOL_RESULT_MAX_CHARS);
                let min_keep = options.min_keep_messages.unwrap_or(4);
                let removed = compact::prune_tool_results(&mut conversation, max_chars, min_keep);
                if removed > 0 {
                    debug!(
                        chars_removed = removed,
                        context_tokens,
                        "Pruned oversized tool results to free context space"
                    );
                }
            }

            if compact::should_compact(context_tokens, context_budget) {
                let min_keep = options.min_keep_messages.unwrap_or(4);
                let split_point = compact::find_split_point(&conversation, min_keep);
                if split_point > 0 {
                    debug!(
                        context_tokens,
                        context_budget,
                        split_point,
                        "Context budget exceeded, compacting conversation"
                    );

                    let compaction_model = options
                        .compaction_model
                        .as_deref()
                        .unwrap_or(compact::DEFAULT_COMPACTION_MODEL);

                    // Let the caller observe the messages before they are summarised away.
                    if let Some(ref handler) = options.pre_compact_handler {
                        let msgs_to_compact = conversation[..split_point].to_vec();
                        handler(msgs_to_compact).await;
                    }

                    let summary_prompt =
                        compact::build_summary_prompt(&conversation[..split_point]);

                    // A dedicated compaction provider gets the main provider as fallback.
                    let summary_max_tokens = options.summary_max_tokens.unwrap_or(4096);
                    let compact_provider: &dyn LlmProvider = match &options.compaction_provider {
                        Some(cp) => cp.as_ref(),
                        None => provider.as_ref(),
                    };
                    let fallback_provider: Option<&dyn LlmProvider> =
                        if options.compaction_provider.is_some() {
                            Some(provider.as_ref())
                        } else {
                            None
                        };
                    match compact::call_summarizer(
                        compact_provider,
                        &summary_prompt,
                        compaction_model,
                        fallback_provider,
                        &model,
                        summary_max_tokens,
                    )
                    .await
                    {
                        Ok(summary) => {
                            let pre_tokens = context_tokens;
                            let messages_compacted = split_point;

                            compact::splice_conversation(&mut conversation, split_point, &summary);

                            let compact_msg = Message::System(SystemMessage {
                                subtype: SystemSubtype::CompactBoundary,
                                uuid: Uuid::new_v4(),
                                session_id: session_id.clone(),
                                agents: None,
                                claude_code_version: None,
                                cwd: None,
                                tools: None,
                                mcp_servers: None,
                                model: None,
                                permission_mode: None,
                                compact_metadata: Some(CompactMetadata {
                                    trigger: CompactTrigger::Auto,
                                    pre_tokens,
                                }),
                            });
                            // A dropped receiver is noticed at the next send that checks it.
                            let _ =
                                persist_and_send(&session, options.persist_session, &tx, compact_msg)
                                    .await;

                            debug!(
                                pre_tokens,
                                messages_compacted,
                                summary_len = summary.len(),
                                "Conversation compacted"
                            );
                        }
                        Err(e) => {
                            warn!("Compaction failed, continuing without compaction: {}", e);
                        }
                    }
                }
            }
        }
    }
}

/// Appends `msg` to the session transcript when `persist` is set (best effort; write errors
/// are ignored), then forwards it to the consumer.
///
/// Returns `false` when the receiving end of `tx` has been dropped.
async fn persist_and_send(
    session: &Session,
    persist: bool,
    tx: &mpsc::UnboundedSender<Result<Message>>,
    msg: Message,
) -> bool {
    if persist {
        let _ = session
            .append_message(&serde_json::to_value(&msg).unwrap_or_default())
            .await;
    }
    tx.send(Ok(msg)).is_ok()
}
1071
1072async fn accumulate_stream(
1075 event_stream: &mut std::pin::Pin<
1076 Box<dyn futures::Stream<Item = Result<ClientStreamEvent>> + Send>,
1077 >,
1078 tx: &mpsc::UnboundedSender<Result<Message>>,
1079 session_id: &str,
1080) -> Result<MessageResponse> {
1081 use crate::client::StreamEvent as SE;
1082
1083 let mut message_id = String::new();
1085 let mut model = String::new();
1086 let mut role = String::from("assistant");
1087 let mut content_blocks: Vec<ApiContentBlock> = Vec::new();
1088 let mut stop_reason: Option<String> = None;
1089 let mut usage = ApiUsage::default();
1090
1091 let mut block_texts: Vec<String> = Vec::new();
1094 let mut block_types: Vec<String> = Vec::new(); let mut block_tool_ids: Vec<String> = Vec::new();
1096 let mut block_tool_names: Vec<String> = Vec::new();
1097
1098 while let Some(event_result) = FuturesStreamExt::next(event_stream).await {
1099 let event = event_result?;
1100 match event {
1101 SE::MessageStart { message } => {
1102 message_id = message.id;
1103 model = message.model;
1104 role = message.role;
1105 usage = message.usage;
1106 }
1107 SE::ContentBlockStart {
1108 index,
1109 content_block,
1110 } => {
1111 while block_texts.len() <= index {
1113 block_texts.push(String::new());
1114 block_types.push(String::new());
1115 block_tool_ids.push(String::new());
1116 block_tool_names.push(String::new());
1117 }
1118 match &content_block {
1119 ApiContentBlock::Text { .. } => {
1120 block_types[index] = "text".to_string();
1121 }
1122 ApiContentBlock::ToolUse { id, name, input } => {
1123 block_types[index] = "tool_use".to_string();
1124 block_tool_ids[index] = id.clone();
1125 block_tool_names[index] = name.clone();
1126 let input_str = input.to_string();
1130 if input_str != "{}" {
1131 block_texts[index] = input_str;
1132 }
1133 }
1134 ApiContentBlock::Thinking { .. } => {
1135 block_types[index] = "thinking".to_string();
1136 }
1137 _ => {}
1138 }
1139 }
1140 SE::ContentBlockDelta { index, delta } => {
1141 while block_texts.len() <= index {
1142 block_texts.push(String::new());
1143 block_types.push(String::new());
1144 block_tool_ids.push(String::new());
1145 block_tool_names.push(String::new());
1146 }
1147 match &delta {
1148 ContentDelta::TextDelta { text } => {
1149 block_texts[index].push_str(text);
1150 let stream_event = Message::StreamEvent(StreamEventMessage {
1152 event: serde_json::json!({
1153 "type": "content_block_delta",
1154 "index": index,
1155 "delta": { "type": "text_delta", "text": text }
1156 }),
1157 parent_tool_use_id: None,
1158 uuid: Uuid::new_v4(),
1159 session_id: session_id.to_string(),
1160 });
1161 if tx.send(Ok(stream_event)).is_err() {
1162 return Err(AgentError::Cancelled);
1163 }
1164 }
1165 ContentDelta::InputJsonDelta { partial_json } => {
1166 block_texts[index].push_str(partial_json);
1167 }
1168 ContentDelta::ThinkingDelta { thinking } => {
1169 block_texts[index].push_str(thinking);
1170 }
1171 }
1172 }
1173 SE::ContentBlockStop { index } => {
1174 if index < block_types.len() {
1175 let block = match block_types[index].as_str() {
1176 "text" => ApiContentBlock::Text {
1177 text: std::mem::take(&mut block_texts[index]),
1178 cache_control: None,
1179 },
1180 "tool_use" => {
1181 let input: serde_json::Value =
1182 serde_json::from_str(&block_texts[index])
1183 .unwrap_or(serde_json::Value::Object(Default::default()));
1184 ApiContentBlock::ToolUse {
1185 id: std::mem::take(&mut block_tool_ids[index]),
1186 name: std::mem::take(&mut block_tool_names[index]),
1187 input,
1188 }
1189 }
1190 "thinking" => ApiContentBlock::Thinking {
1191 thinking: std::mem::take(&mut block_texts[index]),
1192 },
1193 _ => continue,
1194 };
1195 while content_blocks.len() <= index {
1197 content_blocks.push(ApiContentBlock::Text {
1198 text: String::new(),
1199 cache_control: None,
1200 });
1201 }
1202 content_blocks[index] = block;
1203 }
1204 }
1205 SE::MessageDelta {
1206 delta,
1207 usage: delta_usage,
1208 } => {
1209 stop_reason = delta.stop_reason;
1210 usage.output_tokens = delta_usage.output_tokens;
1212 }
1213 SE::MessageStop => {
1214 break;
1215 }
1216 SE::Error { error } => {
1217 return Err(AgentError::Api(error.message));
1218 }
1219 SE::Ping => {}
1220 }
1221 }
1222
1223 Ok(MessageResponse {
1224 id: message_id,
1225 role,
1226 content: content_blocks,
1227 model,
1228 stop_reason,
1229 usage,
1230 })
1231}
1232
1233fn apply_cache_breakpoint(conversation: &mut [ApiMessage]) {
1238 for msg in conversation.iter_mut() {
1240 for block in msg.content.iter_mut() {
1241 match block {
1242 ApiContentBlock::Text { cache_control, .. }
1243 | ApiContentBlock::ToolResult { cache_control, .. } => {
1244 *cache_control = None;
1245 }
1246 ApiContentBlock::Image { .. }
1247 | ApiContentBlock::ToolUse { .. }
1248 | ApiContentBlock::Thinking { .. } => {}
1249 }
1250 }
1251 }
1252
1253 if let Some(last_user) = conversation.iter_mut().rev().find(|m| m.role == "user") {
1255 if let Some(last_block) = last_user.content.last_mut() {
1256 match last_block {
1257 ApiContentBlock::Text { cache_control, .. }
1258 | ApiContentBlock::ToolResult { cache_control, .. } => {
1259 *cache_control = Some(CacheControl::ephemeral());
1260 }
1261 ApiContentBlock::Image { .. }
1262 | ApiContentBlock::ToolUse { .. }
1263 | ApiContentBlock::Thinking { .. } => {}
1264 }
1265 }
1266 }
1267}
1268
1269fn api_block_to_content_block(block: &ApiContentBlock) -> ContentBlock {
1271 match block {
1272 ApiContentBlock::Text { text, .. } => ContentBlock::Text { text: text.clone() },
1273 ApiContentBlock::Image { .. } => ContentBlock::Text {
1274 text: "[image]".to_string(),
1275 },
1276 ApiContentBlock::ToolUse { id, name, input } => ContentBlock::ToolUse {
1277 id: id.clone(),
1278 name: name.clone(),
1279 input: input.clone(),
1280 },
1281 ApiContentBlock::ToolResult {
1282 tool_use_id,
1283 content,
1284 is_error,
1285 ..
1286 } => ContentBlock::ToolResult {
1287 tool_use_id: tool_use_id.clone(),
1288 content: content.clone(),
1289 is_error: *is_error,
1290 },
1291 ApiContentBlock::Thinking { thinking } => ContentBlock::Thinking {
1292 thinking: thinking.clone(),
1293 },
1294 }
1295}
1296
1297fn value_to_api_message(value: &serde_json::Value) -> Option<ApiMessage> {
1299 let msg_type = value.get("type")?.as_str()?;
1300
1301 match msg_type {
1302 "assistant" => {
1303 let content = value.get("content")?;
1304 let blocks = parse_content_blocks(content)?;
1305 Some(ApiMessage {
1306 role: "assistant".to_string(),
1307 content: blocks,
1308 })
1309 }
1310 "user" => {
1311 let content = value.get("content")?;
1312 let blocks = parse_content_blocks(content)?;
1313 Some(ApiMessage {
1314 role: "user".to_string(),
1315 content: blocks,
1316 })
1317 }
1318 _ => None,
1319 }
1320}
1321
1322fn parse_content_blocks(content: &serde_json::Value) -> Option<Vec<ApiContentBlock>> {
1324 if let Some(text) = content.as_str() {
1325 return Some(vec![ApiContentBlock::Text {
1326 text: text.to_string(),
1327 cache_control: None,
1328 }]);
1329 }
1330
1331 if let Some(blocks) = content.as_array() {
1332 let parsed: Vec<ApiContentBlock> = blocks
1333 .iter()
1334 .filter_map(|b| serde_json::from_value(b.clone()).ok())
1335 .collect();
1336 if !parsed.is_empty() {
1337 return Some(parsed);
1338 }
1339 }
1340
1341 None
1342}
1343
1344#[allow(clippy::too_many_arguments)]
1346fn build_result_message(
1347 subtype: ResultSubtype,
1348 session_id: &str,
1349 result_text: Option<String>,
1350 start_time: Instant,
1351 api_time_ms: u64,
1352 num_turns: u32,
1353 total_cost: f64,
1354 usage: &Usage,
1355 model_usage: &HashMap<String, ModelUsage>,
1356 permission_denials: &[PermissionDenial],
1357) -> Message {
1358 Message::Result(ResultMessage {
1359 subtype,
1360 uuid: Uuid::new_v4(),
1361 session_id: session_id.to_string(),
1362 duration_ms: start_time.elapsed().as_millis() as u64,
1363 duration_api_ms: api_time_ms,
1364 is_error: result_text.is_none(),
1365 num_turns,
1366 result: result_text,
1367 stop_reason: Some("end_turn".to_string()),
1368 total_cost_usd: total_cost,
1369 usage: Some(usage.clone()),
1370 model_usage: model_usage.clone(),
1371 permission_denials: permission_denials.to_vec(),
1372 structured_output: None,
1373 errors: Vec::new(),
1374 })
1375}
1376
1377#[allow(clippy::too_many_arguments)]
1379fn build_error_result_message(
1380 session_id: &str,
1381 error_msg: &str,
1382 start_time: Instant,
1383 api_time_ms: u64,
1384 num_turns: u32,
1385 total_cost: f64,
1386 usage: &Usage,
1387 model_usage: &HashMap<String, ModelUsage>,
1388 permission_denials: &[PermissionDenial],
1389) -> Message {
1390 Message::Result(ResultMessage {
1391 subtype: ResultSubtype::ErrorDuringExecution,
1392 uuid: Uuid::new_v4(),
1393 session_id: session_id.to_string(),
1394 duration_ms: start_time.elapsed().as_millis() as u64,
1395 duration_api_ms: api_time_ms,
1396 is_error: true,
1397 num_turns,
1398 result: None,
1399 stop_reason: None,
1400 total_cost_usd: total_cost,
1401 usage: Some(usage.clone()),
1402 model_usage: model_usage.clone(),
1403 permission_denials: permission_denials.to_vec(),
1404 structured_output: None,
1405 errors: vec![error_msg.to_string()],
1406 })
1407}
1408
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    /// Runs every `(tool_use_id, tool_name, input)` through `handler` concurrently.
    ///
    /// All calls are pushed into a `FuturesUnordered`, and results are collected in
    /// completion order. Each entry of the returned vector is
    /// `(tool_use_id, content, completion_seq)`, where `completion_seq` is the 0-based
    /// position at which that call finished. A handler that returns `None` is reported
    /// with the content `"no handler"`.
    async fn run_concurrent_tools(
        tools: Vec<(String, String, serde_json::Value)>,
        handler: impl Fn(
            String,
            serde_json::Value,
        ) -> Pin<Box<dyn futures::Future<Output = Option<ToolResult>> + Send>>,
    ) -> Vec<(String, String, usize)> {
        let order = Arc::new(AtomicUsize::new(0));
        let handler = Arc::new(handler);

        struct PermittedTool {
            tool_use_id: String,
            tool_name: String,
            actual_input: serde_json::Value,
        }

        let permitted: Vec<PermittedTool> = tools
            .into_iter()
            .map(|(id, name, input)| PermittedTool {
                tool_use_id: id,
                tool_name: name,
                actual_input: input,
            })
            .collect();

        let mut futs: FuturesUnordered<_> = permitted
            .iter()
            .map(|pt| {
                let handler = handler.clone();
                let order = order.clone();
                let name = pt.tool_name.clone();
                let input = pt.actual_input.clone();
                let id = pt.tool_use_id.clone();
                async move {
                    let result = handler(name, input).await;
                    // Stamp the completion order so tests can assert which tool finished first.
                    let seq = order.fetch_add(1, Ordering::SeqCst);
                    (id, result, seq)
                }
            })
            .collect();

        let mut results = Vec::new();
        while let Some((id, result, seq)) = futs.next().await {
            let content = result
                .map(|r| r.content)
                .unwrap_or_else(|| "no handler".into());
            results.push((id, content, seq));
        }
        results
    }

    #[tokio::test]
    async fn concurrent_tools_all_complete() {
        let results = run_concurrent_tools(
            vec![
                ("t1".into(), "Read".into(), json!({"path": "a.txt"})),
                ("t2".into(), "Read".into(), json!({"path": "b.txt"})),
                ("t3".into(), "Read".into(), json!({"path": "c.txt"})),
            ],
            |name, input| {
                Box::pin(async move {
                    let path = input["path"].as_str().unwrap_or("?");
                    Some(ToolResult {
                        content: format!("{}: {}", name, path),
                        is_error: false,
                        raw_content: None,
                    })
                })
            },
        )
        .await;

        assert_eq!(results.len(), 3);
        let ids: Vec<&str> = results.iter().map(|(id, _, _)| id.as_str()).collect();
        assert!(ids.contains(&"t1"));
        assert!(ids.contains(&"t2"));
        assert!(ids.contains(&"t3"));
    }

    #[tokio::test]
    async fn slow_tool_does_not_block_fast_tools() {
        let start = Instant::now();

        let results = run_concurrent_tools(
            vec![
                ("slow".into(), "Bash".into(), json!({})),
                ("fast1".into(), "Read".into(), json!({})),
                ("fast2".into(), "Read".into(), json!({})),
            ],
            |name, _input| {
                Box::pin(async move {
                    if name == "Bash" {
                        tokio::time::sleep(Duration::from_millis(200)).await;
                        Some(ToolResult {
                            content: "slow done".into(),
                            is_error: false,
                            raw_content: None,
                        })
                    } else {
                        // The fast tools also take real time, so the elapsed-time check
                        // below can tell concurrent (~200ms) from sequential (~400ms).
                        tokio::time::sleep(Duration::from_millis(100)).await;
                        Some(ToolResult {
                            content: "fast done".into(),
                            is_error: false,
                            raw_content: None,
                        })
                    }
                })
            },
        )
        .await;

        let elapsed = start.elapsed();

        assert_eq!(results.len(), 3);

        let slow = results.iter().find(|(id, _, _)| id == "slow").unwrap();
        let fast1 = results.iter().find(|(id, _, _)| id == "fast1").unwrap();
        let fast2 = results.iter().find(|(id, _, _)| id == "fast2").unwrap();

        assert!(fast1.2 < slow.2, "fast1 should complete before slow");
        assert!(fast2.2 < slow.2, "fast2 should complete before slow");

        // Sequential execution would take 200 + 100 + 100 = 400ms.
        assert!(
            elapsed < Duration::from_millis(350),
            "elapsed {:?} should be under 350ms (concurrent execution)",
            elapsed
        );
    }

    #[tokio::test]
    async fn results_streamed_individually_as_they_complete() {
        // Each finished tool is forwarded on its own as soon as it completes,
        // rather than being batched after all tools finish.
        let (tx, mut rx) = mpsc::unbounded_channel::<(String, String)>();

        let tools = vec![
            ("t_slow".into(), "Slow".into(), json!({})),
            ("t_fast".into(), "Fast".into(), json!({})),
        ];

        struct PT {
            tool_use_id: String,
            tool_name: String,
        }

        let permitted: Vec<PT> = tools
            .into_iter()
            .map(|(id, name, _)| PT {
                tool_use_id: id,
                tool_name: name,
            })
            .collect();

        let mut futs: FuturesUnordered<_> = permitted
            .iter()
            .map(|pt| {
                let name = pt.tool_name.clone();
                let id = pt.tool_use_id.clone();
                async move {
                    if name == "Slow" {
                        tokio::time::sleep(Duration::from_millis(100)).await;
                    }
                    let result = ToolResult {
                        content: format!("{} result", name),
                        is_error: false,
                        raw_content: None,
                    };
                    (id, result)
                }
            })
            .collect();

        while let Some((id, result)) = futs.next().await {
            tx.send((id, result.content)).unwrap();
        }
        drop(tx);

        let mut streamed = Vec::new();
        while let Some(item) = rx.recv().await {
            streamed.push(item);
        }

        assert_eq!(streamed.len(), 2);
        // The fast tool was submitted second but must be delivered first.
        assert_eq!(streamed[0].0, "t_fast");
        assert_eq!(streamed[0].1, "Fast result");
        assert_eq!(streamed[1].0, "t_slow");
        assert_eq!(streamed[1].1, "Slow result");
    }

    #[tokio::test]
    async fn error_tool_does_not_prevent_other_tools() {
        let results = run_concurrent_tools(
            vec![
                ("t_ok".into(), "Read".into(), json!({})),
                ("t_err".into(), "Fail".into(), json!({})),
            ],
            |name, _input| {
                Box::pin(async move {
                    if name == "Fail" {
                        Some(ToolResult {
                            content: "something went wrong".into(),
                            is_error: true,
                            raw_content: None,
                        })
                    } else {
                        Some(ToolResult {
                            content: "ok".into(),
                            is_error: false,
                            raw_content: None,
                        })
                    }
                })
            },
        )
        .await;

        assert_eq!(results.len(), 2);
        let ok = results.iter().find(|(id, _, _)| id == "t_ok").unwrap();
        let err = results.iter().find(|(id, _, _)| id == "t_err").unwrap();
        assert_eq!(ok.1, "ok");
        assert_eq!(err.1, "something went wrong");
    }

    #[tokio::test]
    async fn external_handler_none_falls_through_correctly() {
        // A handler returning `None` must not swallow the call; the harness reports
        // it as "no handler" instead.
        let results = run_concurrent_tools(
            vec![
                ("t_custom".into(), "MyTool".into(), json!({"x": 1})),
                ("t_builtin".into(), "Read".into(), json!({"path": "/tmp"})),
            ],
            |name, _input| {
                Box::pin(async move {
                    if name == "MyTool" {
                        Some(ToolResult {
                            content: "custom handled".into(),
                            is_error: false,
                            raw_content: None,
                        })
                    } else {
                        None
                    }
                })
            },
        )
        .await;

        assert_eq!(results.len(), 2);
        let custom = results.iter().find(|(id, _, _)| id == "t_custom").unwrap();
        let builtin = results.iter().find(|(id, _, _)| id == "t_builtin").unwrap();
        assert_eq!(custom.1, "custom handled");
        assert_eq!(builtin.1, "no handler");
    }

    #[tokio::test]
    async fn single_tool_works_same_as_before() {
        let results = run_concurrent_tools(
            vec![("t1".into(), "Read".into(), json!({"path": "file.txt"}))],
            |_name, _input| {
                Box::pin(async move {
                    Some(ToolResult {
                        content: "file contents".into(),
                        is_error: false,
                        raw_content: None,
                    })
                })
            },
        )
        .await;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "t1");
        assert_eq!(results[0].1, "file contents");
        assert_eq!(results[0].2, 0);
    }

    #[tokio::test]
    async fn empty_tool_list_produces_no_results() {
        let results =
            run_concurrent_tools(vec![], |_name, _input| Box::pin(async move { None })).await;

        assert_eq!(results.len(), 0);
    }

    #[tokio::test]
    async fn tool_use_ids_preserved_through_concurrent_execution() {
        let results = run_concurrent_tools(
            vec![
                ("toolu_abc123".into(), "Read".into(), json!({})),
                ("toolu_def456".into(), "Write".into(), json!({})),
                ("toolu_ghi789".into(), "Bash".into(), json!({})),
            ],
            |name, _input| {
                Box::pin(async move {
                    // Staggered delays so completion order differs from submission order.
                    match name.as_str() {
                        "Read" => tokio::time::sleep(Duration::from_millis(30)).await,
                        "Write" => tokio::time::sleep(Duration::from_millis(10)).await,
                        _ => tokio::time::sleep(Duration::from_millis(50)).await,
                    }
                    Some(ToolResult {
                        content: format!("{} result", name),
                        is_error: false,
                        raw_content: None,
                    })
                })
            },
        )
        .await;

        assert_eq!(results.len(), 3);

        for (id, content, _) in &results {
            match id.as_str() {
                "toolu_abc123" => assert_eq!(content, "Read result"),
                "toolu_def456" => assert_eq!(content, "Write result"),
                "toolu_ghi789" => assert_eq!(content, "Bash result"),
                other => panic!("unexpected tool_use_id: {}", other),
            }
        }
    }

    #[tokio::test]
    async fn concurrent_execution_timing_is_parallel() {
        let tools: Vec<_> = (0..5)
            .map(|i| (format!("t{}", i), "Tool".into(), json!({})))
            .collect();

        let start = Instant::now();

        let results = run_concurrent_tools(tools, |_name, _input| {
            Box::pin(async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                Some(ToolResult {
                    content: "done".into(),
                    is_error: false,
                    raw_content: None,
                })
            })
        })
        .await;

        let elapsed = start.elapsed();

        assert_eq!(results.len(), 5);
        // Sequential execution would take 5 x 50ms = 250ms.
        assert!(
            elapsed < Duration::from_millis(200),
            "5 x 50ms tools took {:?} — should be ~50ms if concurrent",
            elapsed
        );
    }

    #[test]
    fn api_block_to_content_block_preserves_tool_result_fields() {
        let block = ApiContentBlock::ToolResult {
            tool_use_id: "toolu_abc".into(),
            content: json!("result text"),
            is_error: Some(true),
            cache_control: None,
            name: None,
        };

        let content = api_block_to_content_block(&block);
        match content {
            ContentBlock::ToolResult {
                tool_use_id,
                content,
                is_error,
            } => {
                assert_eq!(tool_use_id, "toolu_abc");
                assert_eq!(content, json!("result text"));
                assert_eq!(is_error, Some(true));
            }
            _ => panic!("expected ToolResult content block"),
        }
    }

    #[tokio::test]
    async fn streamed_messages_each_contain_single_tool_result() {
        let (tx, mut rx) = mpsc::unbounded_channel::<Result<Message>>();
        let session_id = "test-session".to_string();

        let tool_ids = vec!["t1", "t2", "t3"];
        // One synthetic user message per tool result, sent in order.
        for id in &tool_ids {
            let api_block = ApiContentBlock::ToolResult {
                tool_use_id: id.to_string(),
                content: json!(format!("result for {}", id)),
                is_error: None,
                cache_control: None,
                name: None,
            };

            let result_msg = Message::User(UserMessage {
                uuid: Some(Uuid::new_v4()),
                session_id: session_id.clone(),
                content: vec![api_block_to_content_block(&api_block)],
                parent_tool_use_id: None,
                is_synthetic: true,
                tool_use_result: None,
            });
            tx.send(Ok(result_msg)).unwrap();
        }
        drop(tx);

        let mut messages = Vec::new();
        while let Some(Ok(msg)) = rx.recv().await {
            messages.push(msg);
        }

        assert_eq!(messages.len(), 3, "should have 3 individual messages");

        for (i, msg) in messages.iter().enumerate() {
            if let Message::User(user) = msg {
                assert_eq!(
                    user.content.len(),
                    1,
                    "each message should have exactly 1 content block"
                );
                assert!(user.is_synthetic);
                if let ContentBlock::ToolResult { tool_use_id, .. } = &user.content[0] {
                    assert_eq!(tool_use_id, tool_ids[i]);
                } else {
                    panic!("expected ToolResult block");
                }
            } else {
                panic!("expected User message");
            }
        }
    }

    #[tokio::test]
    async fn accumulate_stream_emits_text_deltas_and_builds_response() {
        use crate::client::{
            ApiContentBlock, ApiUsage, ContentDelta, MessageResponse, StreamEvent as SE,
        };

        // A minimal single-text-block stream split across three deltas.
        let events: Vec<Result<SE>> = vec![
            Ok(SE::MessageStart {
                message: MessageResponse {
                    id: "msg_123".into(),
                    role: "assistant".into(),
                    content: vec![],
                    model: "claude-test".into(),
                    stop_reason: None,
                    usage: ApiUsage {
                        input_tokens: 100,
                        output_tokens: 0,
                        cache_creation_input_tokens: None,
                        cache_read_input_tokens: None,
                    },
                },
            }),
            Ok(SE::ContentBlockStart {
                index: 0,
                content_block: ApiContentBlock::Text {
                    text: String::new(),
                    cache_control: None,
                },
            }),
            Ok(SE::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::TextDelta {
                    text: "Hello".into(),
                },
            }),
            Ok(SE::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::TextDelta {
                    text: " world".into(),
                },
            }),
            Ok(SE::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::TextDelta { text: "!".into() },
            }),
            Ok(SE::ContentBlockStop { index: 0 }),
            Ok(SE::MessageDelta {
                delta: crate::client::MessageDelta {
                    stop_reason: Some("end_turn".into()),
                },
                usage: ApiUsage {
                    input_tokens: 0,
                    output_tokens: 15,
                    cache_creation_input_tokens: None,
                    cache_read_input_tokens: None,
                },
            }),
            Ok(SE::MessageStop),
        ];

        let stream = futures::stream::iter(events);
        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
            Box::pin(stream);

        let (tx, mut rx) = mpsc::unbounded_channel();

        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
            .await
            .expect("accumulate_stream should succeed");

        assert_eq!(response.id, "msg_123");
        assert_eq!(response.model, "claude-test");
        assert_eq!(response.stop_reason, Some("end_turn".into()));
        assert_eq!(response.usage.output_tokens, 15);
        assert_eq!(response.content.len(), 1);
        if let ApiContentBlock::Text { text, .. } = &response.content[0] {
            assert_eq!(text, "Hello world!");
        } else {
            panic!("expected Text content block");
        }

        // Every text delta should also have been forwarded as its own stream event.
        let mut stream_events = Vec::new();
        while let Ok(msg) = rx.try_recv() {
            stream_events.push(msg.unwrap());
        }
        assert_eq!(stream_events.len(), 3);

        let expected_texts = ["Hello", " world", "!"];
        for (i, msg) in stream_events.iter().enumerate() {
            if let Message::StreamEvent(se) = msg {
                let delta = se.event.get("delta").unwrap();
                let text = delta.get("text").unwrap().as_str().unwrap();
                assert_eq!(text, expected_texts[i]);
                assert_eq!(se.session_id, "test-session");
            } else {
                panic!("expected StreamEvent message at index {}", i);
            }
        }
    }

    #[tokio::test]
    async fn accumulate_stream_handles_tool_use() {
        use crate::client::{
            ApiContentBlock, ApiUsage, ContentDelta, MessageResponse, StreamEvent as SE,
        };

        let events: Vec<Result<SE>> = vec![
            Ok(SE::MessageStart {
                message: MessageResponse {
                    id: "msg_456".into(),
                    role: "assistant".into(),
                    content: vec![],
                    model: "claude-test".into(),
                    stop_reason: None,
                    usage: ApiUsage::default(),
                },
            }),
            // Block 0: leading assistant text.
            Ok(SE::ContentBlockStart {
                index: 0,
                content_block: ApiContentBlock::Text {
                    text: String::new(),
                    cache_control: None,
                },
            }),
            Ok(SE::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::TextDelta {
                    text: "Let me check.".into(),
                },
            }),
            Ok(SE::ContentBlockStop { index: 0 }),
            // Block 1: tool call whose input arrives as a streamed JSON fragment.
            Ok(SE::ContentBlockStart {
                index: 1,
                content_block: ApiContentBlock::ToolUse {
                    id: "toolu_abc".into(),
                    name: "Read".into(),
                    input: serde_json::json!({}),
                },
            }),
            Ok(SE::ContentBlockDelta {
                index: 1,
                delta: ContentDelta::InputJsonDelta {
                    partial_json: r#"{"path":"/tmp/f.txt"}"#.into(),
                },
            }),
            Ok(SE::ContentBlockStop { index: 1 }),
            Ok(SE::MessageDelta {
                delta: crate::client::MessageDelta {
                    stop_reason: Some("tool_use".into()),
                },
                usage: ApiUsage {
                    input_tokens: 0,
                    output_tokens: 20,
                    ..Default::default()
                },
            }),
            Ok(SE::MessageStop),
        ];

        let stream = futures::stream::iter(events);
        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
            Box::pin(stream);

        let (tx, _rx) = mpsc::unbounded_channel();
        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
            .await
            .expect("should succeed");

        assert_eq!(response.content.len(), 2);
        if let ApiContentBlock::Text { text, .. } = &response.content[0] {
            assert_eq!(text, "Let me check.");
        } else {
            panic!("expected Text block at index 0");
        }
        if let ApiContentBlock::ToolUse { id, name, input } = &response.content[1] {
            assert_eq!(id, "toolu_abc");
            assert_eq!(name, "Read");
            assert_eq!(input["path"], "/tmp/f.txt");
        } else {
            panic!("expected ToolUse block at index 1");
        }
        assert_eq!(response.stop_reason, Some("tool_use".into()));
    }

    /// The tool input is supplied entirely in `ContentBlockStart`, with no
    /// `InputJsonDelta` events (the test name suggests an OpenAI-compatible provider).
    /// That initial input must survive accumulation unchanged.
    #[tokio::test]
    async fn accumulate_stream_preserves_openai_tool_input() {
        use crate::client::{ApiContentBlock, ApiUsage, StreamEvent as SE};

        let events: Vec<Result<SE>> = vec![
            Ok(SE::MessageStart {
                message: MessageResponse {
                    id: "msg_oai".into(),
                    role: "assistant".into(),
                    content: vec![],
                    model: "qwen3:8b".into(),
                    stop_reason: None,
                    usage: ApiUsage::default(),
                },
            }),
            Ok(SE::ContentBlockStart {
                index: 0,
                content_block: ApiContentBlock::ToolUse {
                    id: "call_123".into(),
                    name: "Bash".into(),
                    input: serde_json::json!({"command": "ls -la", "timeout": 5000}),
                },
            }),
            Ok(SE::ContentBlockStop { index: 0 }),
            Ok(SE::MessageDelta {
                delta: crate::client::MessageDelta {
                    stop_reason: Some("tool_use".into()),
                },
                usage: ApiUsage {
                    input_tokens: 0,
                    output_tokens: 10,
                    ..Default::default()
                },
            }),
            Ok(SE::MessageStop),
        ];

        let stream = futures::stream::iter(events);
        let mut boxed_stream: std::pin::Pin<Box<dyn futures::Stream<Item = Result<SE>> + Send>> =
            Box::pin(stream);

        let (tx, _rx) = mpsc::unbounded_channel();
        let response = accumulate_stream(&mut boxed_stream, &tx, "test-session")
            .await
            .expect("should succeed");

        assert_eq!(response.content.len(), 1);
        if let ApiContentBlock::ToolUse { id, name, input } = &response.content[0] {
            assert_eq!(id, "call_123");
            assert_eq!(name, "Bash");
            assert_eq!(input["command"], "ls -la");
            assert_eq!(input["timeout"], 5000);
        } else {
            panic!("expected ToolUse block");
        }
    }
}