1use std::sync::Arc;
17
18use async_stream::stream;
19use futures::future::join_all;
20use futures::stream::Stream;
21use serde_json::{json, Value};
22
23use crate::client::HttpClient;
24use crate::types::{
25 tool_result_msg, ChatContent, ChatMessage, ChatRequest, FunctionSchema, ToolSchema, UsageInfo,
26};
27
28use super::messages::{ContentBlock, ResultSubtype, SdkMessage, SystemSubtype};
29use super::options::{CompactionConfig, RunOptions};
30use super::permissions::{PermissionDecision, PermissionMode};
31use super::pricing::{map_stop_reason, turn_cost_usd};
32use super::tool::Tool;
33
34pub fn run<H>(
39 http: H,
40 api_key: String,
41 tools: Arc<Vec<Box<dyn Tool>>>,
42 user_prompt: String,
43 opts: RunOptions,
44) -> impl Stream<Item = SdkMessage>
45where
46 H: HttpClient + Send + Sync + 'static,
47{
48 stream! {
49 let session_id = opts
51 .session_id
52 .clone()
53 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
54 yield SdkMessage::System {
55 subtype: SystemSubtype::Init,
56 session_id: session_id.clone(),
57 data: json!({
58 "model": opts.model,
59 "permission_mode": opts.permission_mode,
60 "max_turns": opts.max_turns,
61 "max_budget_usd": opts.max_budget_usd,
62 }),
63 };
64
65 let visible_tools: Vec<&Box<dyn Tool>> = tools
67 .iter()
68 .filter(|t| {
69 let n = t.name();
70 if opts.disallowed_tools.iter().any(|d| d == n) {
71 return false;
72 }
73 if let Some(allow) = &opts.allowed_tools {
74 return allow.iter().any(|a| a == n);
75 }
76 true
77 })
78 .collect();
79
80 let tool_schemas: Vec<ToolSchema> = visible_tools
81 .iter()
82 .map(|t| {
83 let def = t.definition();
84 ToolSchema {
85 r#type: "function".into(),
86 function: FunctionSchema {
87 name: def.name,
88 description: def.description,
89 parameters: def.parameters,
90 },
91 }
92 })
93 .collect();
94
95 let mut messages: Vec<ChatMessage> = Vec::new();
97 if !opts.system_prompt.is_empty() {
98 messages.push(ChatMessage {
99 role: "system".into(),
100 content: ChatContent::Text(opts.system_prompt.clone()),
101 reasoning_content: None,
102 tool_calls: None,
103 tool_call_id: None,
104 name: None,
105 });
106 }
107 messages.push(ChatMessage {
108 role: "user".into(),
109 content: ChatContent::Text(user_prompt),
110 reasoning_content: None,
111 tool_calls: None,
112 tool_call_id: None,
113 name: None,
114 });
115
116 let url = format!("{}/chat/completions", opts.base_url);
117 let mut num_turns: u32 = 0;
118 let mut total_prompt_tokens: u32 = 0;
119 let mut total_completion_tokens: u32 = 0;
120 let mut total_cache_hit_tokens: u32 = 0;
121 let mut total_cache_miss_tokens: u32 = 0;
122 let mut any_cache_stats_seen = false;
123 let mut total_cost: Option<f64> =
124 super::pricing::model_pricing(&opts.model).map(|_| 0.0);
125 let mut last_stop_reason: Option<String> = None;
126 let mut last_turn_prompt_tokens: u32 = 0;
127
128 loop {
129 let request = ChatRequest {
130 model: opts.model.clone(),
131 messages: messages.clone(),
132 tools: if tool_schemas.is_empty() { None } else { Some(tool_schemas.clone()) },
133 tool_choice: if tool_schemas.is_empty() {
134 None
135 } else {
136 Some(json!("auto"))
137 },
138 temperature: Some(opts.effort.temperature()),
139 max_tokens: Some(opts.effort.max_tokens()),
140 stream: Some(false),
141 reasoning_effort: Some(match opts.effort {
142 crate::types::EffortLevel::Max => "max".into(),
143 crate::types::EffortLevel::High => "high".into(),
144 crate::types::EffortLevel::Medium => "medium".into(),
145 crate::types::EffortLevel::Low => "low".into(),
146 }),
147 thinking: Some(json!({"type": "enabled"})),
148 };
149
150 let resp = match http.post_json(&url, &api_key, &request).await {
151 Ok(r) => r,
152 Err(e) => {
153 tracing::warn!(error = %e, "agent loop transport error");
154 yield SdkMessage::Result {
155 subtype: ResultSubtype::ErrorDuringExecution,
156 result: None,
157 total_cost_usd: total_cost,
158 usage: usage_info(total_prompt_tokens, total_completion_tokens, total_cache_hit_tokens, total_cache_miss_tokens, any_cache_stats_seen),
159 num_turns,
160 session_id,
161 stop_reason: last_stop_reason,
162 };
163 return;
164 }
165 };
166
167 if let Some(u) = &resp.usage {
169 last_turn_prompt_tokens = u.prompt_tokens;
170 total_prompt_tokens = total_prompt_tokens.saturating_add(u.prompt_tokens);
171 total_completion_tokens = total_completion_tokens.saturating_add(u.completion_tokens);
172 if let Some(h) = u.prompt_cache_hit_tokens {
173 total_cache_hit_tokens = total_cache_hit_tokens.saturating_add(h);
174 any_cache_stats_seen = true;
175 }
176 if let Some(m) = u.prompt_cache_miss_tokens {
177 total_cache_miss_tokens = total_cache_miss_tokens.saturating_add(m);
178 any_cache_stats_seen = true;
179 }
180 if let (Some(running), Some(turn)) = (
181 total_cost.as_mut(),
182 turn_cost_usd(&opts.model, u),
183 ) {
184 *running += turn;
185 }
186 }
187
188 let Some(choice) = resp.choices.into_iter().next() else {
189 yield SdkMessage::Result {
190 subtype: ResultSubtype::ErrorDuringExecution,
191 result: None,
192 total_cost_usd: total_cost,
193 usage: usage_info(total_prompt_tokens, total_completion_tokens, total_cache_hit_tokens, total_cache_miss_tokens, any_cache_stats_seen),
194 num_turns,
195 session_id,
196 stop_reason: last_stop_reason,
197 };
198 return;
199 };
200
201 let finish_reason = choice.finish_reason.as_deref().unwrap_or("stop");
202 last_stop_reason = map_stop_reason(finish_reason);
203 let assistant_msg = choice.message;
204
205 if finish_reason == "tool_calls" {
206 let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
207
208 let mut content_blocks: Vec<ContentBlock> = Vec::new();
210 let text = assistant_msg.content.as_str();
211 if !text.is_empty() {
212 content_blocks.push(ContentBlock::Text { text: text.to_string() });
213 }
214 let parsed_calls: Vec<(String, String, Value)> = tool_calls
215 .iter()
216 .map(|c| {
217 let args: Value =
218 serde_json::from_str(&c.function.arguments).unwrap_or(json!({}));
219 (c.id.clone(), c.function.name.clone(), args)
220 })
221 .collect();
222 for (id, name, input) in &parsed_calls {
223 content_blocks.push(ContentBlock::ToolUse {
224 id: id.clone(),
225 name: name.clone(),
226 input: input.clone(),
227 });
228 }
229 yield SdkMessage::Assistant {
230 content: content_blocks,
231 stop_reason: last_stop_reason.clone(),
232 };
233
234 messages.push(assistant_msg);
236
237 let mut decisions: Vec<(String, String, Value, PermissionDecision, bool)> =
239 Vec::with_capacity(parsed_calls.len());
240 for (id, name, args) in parsed_calls {
241 let tool_ref = visible_tools.iter().find(|t| t.name() == name);
242 let read_only = tool_ref.map(|t| t.read_only_hint()).unwrap_or(false);
243
244 let mode_decision = opts.permission_mode.evaluate(&name, read_only);
245 let final_decision = match (mode_decision, &opts.pre_tool_hook) {
246 (PermissionDecision::Allow, _) => PermissionDecision::Allow,
247 (PermissionDecision::Deny(r), _) => PermissionDecision::Deny(r),
248 (PermissionDecision::Ask, Some(hook)) => {
249 match hook.check(&name, &args).await {
250 PermissionDecision::Ask => PermissionDecision::Deny(format!(
251 "Tool `{name}` requires approval and the hook returned Ask"
252 )),
253 d => d,
254 }
255 }
256 (PermissionDecision::Ask, None) => {
257 if matches!(opts.permission_mode, PermissionMode::BypassPermissions) {
258 PermissionDecision::Allow
259 } else {
260 PermissionDecision::Deny(format!(
261 "Tool `{name}` not pre-approved and no permission hook configured"
262 ))
263 }
264 }
265 };
266
267 decisions.push((id, name, args, final_decision, read_only));
268 }
269
270 let mut tool_results: Vec<(String, Result<String, String>)> = Vec::new();
272 let mut parallel_idxs: Vec<usize> = Vec::new();
273 let mut sequential_idxs: Vec<usize> = Vec::new();
274 for (i, (_, _, _, d, ro)) in decisions.iter().enumerate() {
275 if matches!(d, PermissionDecision::Allow) {
276 if *ro {
277 parallel_idxs.push(i);
278 } else {
279 sequential_idxs.push(i);
280 }
281 }
282 }
283
284 if !parallel_idxs.is_empty() {
286 let futs = parallel_idxs.iter().map(|&i| {
287 let (id, name, args, _, _) = &decisions[i];
288 let id = id.clone();
289 let name = name.clone();
290 let args = args.clone();
291 let tools = Arc::clone(&tools);
292 async move {
293 let res = match tools.iter().find(|t| t.name() == name) {
294 Some(t) => t.call_json(args).await,
295 None => Err(format!("Unknown tool: {name}")),
296 };
297 (id, res)
298 }
299 });
300 let outs = join_all(futs).await;
301 for (id, res) in outs {
302 tool_results.push((id, res));
303 }
304 }
305
306 for i in sequential_idxs {
308 let (id, name, args, _, _) = &decisions[i];
309 let res = match tools.iter().find(|t| t.name() == *name) {
310 Some(t) => t.call_json(args.clone()).await,
311 None => Err(format!("Unknown tool: {name}")),
312 };
313 tool_results.push((id.clone(), res));
314 }
315
316 for (id, _name, _args, d, _) in &decisions {
318 if let PermissionDecision::Deny(reason) = d {
319 tool_results.push((id.clone(), Err(reason.clone())));
320 }
321 }
322
323 let id_order: Vec<String> = decisions.iter().map(|d| d.0.clone()).collect();
326 tool_results.sort_by_key(|(id, _)| {
327 id_order.iter().position(|x| x == id).unwrap_or(usize::MAX)
328 });
329
330 let mut user_blocks: Vec<ContentBlock> = Vec::with_capacity(tool_results.len());
332 for (call_id, res) in &tool_results {
333 let (content_str, is_error) = match res {
334 Ok(s) => (s.clone(), false),
335 Err(e) => (e.clone(), true),
336 };
337 messages.push(tool_result_msg(call_id, &content_str));
338 user_blocks.push(ContentBlock::ToolResult {
339 tool_use_id: call_id.clone(),
340 content: content_str,
341 is_error,
342 });
343 }
344 yield SdkMessage::User { content: user_blocks };
345
346 num_turns = num_turns.saturating_add(1);
347
348 if let Some(limit) = opts.max_turns {
349 if num_turns >= limit {
350 yield SdkMessage::Result {
351 subtype: ResultSubtype::ErrorMaxTurns,
352 result: None,
353 total_cost_usd: total_cost,
354 usage: usage_info(total_prompt_tokens, total_completion_tokens, total_cache_hit_tokens, total_cache_miss_tokens, any_cache_stats_seen),
355 num_turns,
356 session_id,
357 stop_reason: last_stop_reason,
358 };
359 return;
360 }
361 }
362 if let (Some(budget), Some(cost)) = (opts.max_budget_usd, total_cost) {
363 if cost >= budget {
364 yield SdkMessage::Result {
365 subtype: ResultSubtype::ErrorMaxBudgetUsd,
366 result: None,
367 total_cost_usd: total_cost,
368 usage: usage_info(total_prompt_tokens, total_completion_tokens, total_cache_hit_tokens, total_cache_miss_tokens, any_cache_stats_seen),
369 num_turns,
370 session_id,
371 stop_reason: last_stop_reason,
372 };
373 return;
374 }
375 }
376
377 if let Some(cfg) = opts.compaction.as_ref() {
382 if last_turn_prompt_tokens >= cfg.threshold_prompt_tokens {
383 match compact_history(&http, &api_key, &opts, cfg, &mut messages).await {
384 Ok(outcome) => {
385 if let Some(u) = &outcome.usage {
386 total_prompt_tokens =
387 total_prompt_tokens.saturating_add(u.prompt_tokens);
388 total_completion_tokens = total_completion_tokens
389 .saturating_add(u.completion_tokens);
390 if let Some(h) = u.prompt_cache_hit_tokens {
391 total_cache_hit_tokens =
392 total_cache_hit_tokens.saturating_add(h);
393 any_cache_stats_seen = true;
394 }
395 if let Some(m) = u.prompt_cache_miss_tokens {
396 total_cache_miss_tokens =
397 total_cache_miss_tokens.saturating_add(m);
398 any_cache_stats_seen = true;
399 }
400 if let (Some(running), Some(turn)) = (
401 total_cost.as_mut(),
402 turn_cost_usd(&cfg.compactor_model, u),
403 ) {
404 *running += turn;
405 }
406 }
407 if outcome.rewrote {
408 yield SdkMessage::System {
409 subtype: SystemSubtype::Compact,
410 session_id: session_id.clone(),
411 data: json!({
412 "message_count_after": messages.len(),
413 }),
414 };
415 }
416 }
417 Err(e) => {
418 tracing::warn!(
419 error = %e,
420 "history compaction failed; continuing with full history"
421 );
422 }
423 }
424 }
425 }
426 } else {
427 let text = assistant_msg.content.as_str().to_string();
429 yield SdkMessage::Assistant {
430 content: vec![ContentBlock::Text { text: text.clone() }],
431 stop_reason: last_stop_reason.clone(),
432 };
433 yield SdkMessage::Result {
434 subtype: ResultSubtype::Success,
435 result: Some(text),
436 total_cost_usd: total_cost,
437 usage: usage_info(total_prompt_tokens, total_completion_tokens, total_cache_hit_tokens, total_cache_miss_tokens, any_cache_stats_seen),
438 num_turns,
439 session_id,
440 stop_reason: last_stop_reason,
441 };
442 return;
443 }
444 }
445 }
446}
447
/// Outcome of a single `compact_history` attempt.
struct CompactionOutcome {
    /// Token usage reported by the compactor response; `None` when
    /// `compact_history` returned before sending a request or the response
    /// carried no usage.
    usage: Option<UsageInfo>,
    /// `true` only when the message history was actually replaced by a
    /// summary; `false` for every no-op return.
    rewrote: bool,
}
458
/// Returns `s` unchanged when it is at most `max` bytes long; otherwise
/// returns the longest prefix of at most `max` bytes that ends on a UTF-8
/// character boundary, followed by an ellipsis (`…`).
fn truncate_for_transcript(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }
    // `max < s.len()` here, so every candidate index is in range; index 0 is
    // always a boundary, which makes the fallback unreachable in practice.
    let cut = (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    let mut shortened = String::from(&s[..cut]);
    shortened.push('…');
    shortened
}
472
473async fn compact_history<H>(
482 http: &H,
483 api_key: &str,
484 opts: &RunOptions,
485 cfg: &CompactionConfig,
486 messages: &mut Vec<ChatMessage>,
487) -> crate::error::Result<CompactionOutcome>
488where
489 H: HttpClient + Send + Sync,
490{
491 let head_end = match messages.first().map(|m| m.role.as_str()) {
493 Some("system") => {
494 if matches!(messages.get(1).map(|m| m.role.as_str()), Some("user")) {
495 2
496 } else {
497 1
498 }
499 }
500 Some("user") => 1,
501 _ => {
502 return Ok(CompactionOutcome {
503 usage: None,
504 rewrote: false,
505 })
506 }
507 };
508
509 let assistant_idxs: Vec<usize> = messages
513 .iter()
514 .enumerate()
515 .filter(|(_, m)| m.role == "assistant")
516 .map(|(i, _)| i)
517 .collect();
518 if (assistant_idxs.len() as u32) <= cfg.keep_recent_turns {
519 return Ok(CompactionOutcome {
520 usage: None,
521 rewrote: false,
522 });
523 }
524 let tail_start = assistant_idxs[assistant_idxs.len() - cfg.keep_recent_turns as usize];
525 if tail_start <= head_end {
526 return Ok(CompactionOutcome {
527 usage: None,
528 rewrote: false,
529 });
530 }
531
532 let mut transcript = String::new();
534 for msg in &messages[head_end..tail_start] {
535 let content_text = msg.content.as_str();
536 match msg.role.as_str() {
537 "assistant" => {
538 if !content_text.trim().is_empty() {
539 transcript.push_str(&format!(
540 "[assistant] {}\n",
541 truncate_for_transcript(content_text.trim(), 400)
542 ));
543 }
544 if let Some(calls) = &msg.tool_calls {
545 for c in calls {
546 transcript.push_str(&format!(
547 " [tool_call name={} args={}]\n",
548 c.function.name,
549 truncate_for_transcript(&c.function.arguments, 400)
550 ));
551 }
552 }
553 }
554 "tool" => {
555 let id = msg.tool_call_id.as_deref().unwrap_or("?");
556 transcript.push_str(&format!(
557 " [tool_result id={}] {}\n",
558 id,
559 truncate_for_transcript(content_text, 500)
560 ));
561 }
562 other => {
563 transcript.push_str(&format!(
564 "[{}] {}\n",
565 other,
566 truncate_for_transcript(content_text, 400)
567 ));
568 }
569 }
570 }
571
572 let system_prompt = "You are a conversation-history compactor. Produce a concise structured summary of the conversation segment provided. Preserve: files read or written (with paths), tool calls made (by name and key arguments), test results, decisions reached, and open questions. Drop: verbose tool output, intermediate reasoning, formatting noise. Output prose only — no markdown headers, no lists longer than 5 items. Stay under the model's max_tokens budget.";
573
574 let request = ChatRequest {
575 model: cfg.compactor_model.clone(),
576 messages: vec![
577 crate::types::system_msg(system_prompt),
578 crate::types::user_msg(&format!(
579 "Conversation segment to summarize:\n\n{transcript}"
580 )),
581 ],
582 tools: None,
583 tool_choice: None,
584 temperature: Some(0.2),
585 max_tokens: Some(cfg.max_summary_tokens),
586 stream: Some(false),
587 reasoning_effort: None,
588 thinking: None,
589 };
590
591 let url = format!("{}/chat/completions", opts.base_url);
592 let resp = http.post_json(&url, api_key, &request).await?;
593 let usage = resp.usage.clone();
594
595 let Some(choice) = resp.choices.into_iter().next() else {
596 return Ok(CompactionOutcome {
597 usage,
598 rewrote: false,
599 });
600 };
601 let summary = choice.message.content.as_str().trim().to_string();
602 if summary.is_empty() {
603 return Ok(CompactionOutcome {
604 usage,
605 rewrote: false,
606 });
607 }
608
609 let replacement = ChatMessage {
610 role: "system".into(),
611 content: ChatContent::Text(format!(
612 "[Compacted summary of earlier conversation]\n\n{summary}"
613 )),
614 reasoning_content: None,
615 tool_calls: None,
616 tool_call_id: None,
617 name: None,
618 };
619 messages.splice(head_end..tail_start, std::iter::once(replacement));
620
621 Ok(CompactionOutcome {
622 usage,
623 rewrote: true,
624 })
625}
626
627fn usage_info(
628 prompt: u32,
629 completion: u32,
630 cache_hit: u32,
631 cache_miss: u32,
632 cache_stats_seen: bool,
633) -> Option<UsageInfo> {
634 if prompt == 0 && completion == 0 {
635 None
636 } else {
637 Some(UsageInfo {
638 prompt_tokens: prompt,
639 completion_tokens: completion,
640 total_tokens: prompt.saturating_add(completion),
641 prompt_cache_hit_tokens: cache_stats_seen.then_some(cache_hit),
642 prompt_cache_miss_tokens: cache_stats_seen.then_some(cache_miss),
643 })
644 }
645}
646
#[cfg(test)]
mod tests {
    //! Tests for the agent loop, driven by a scripted in-memory HTTP client.

    use super::*;

    use std::sync::Mutex;

    use async_trait::async_trait;
    use futures::StreamExt;
    use serde_json::json;

    use crate::agent::permissions::PermissionMode;
    use crate::agent::tool::ToolDefinition;
    use crate::client::HttpClient;
    use crate::error::Result as DResult;
    use crate::types::{
        ChatContent, ChatMessage, ChatRequest, ChatResponse, Choice, FunctionCall, ToolCall,
        UsageInfo,
    };

    /// Scripted `HttpClient`: hands out queued responses front-to-back and
    /// records every request body it receives. Clones share both the queue
    /// and the request log.
    #[derive(Clone)]
    struct MockHttp {
        /// Responses (or errors) still to be returned, in order.
        queue: Arc<Mutex<Vec<DResult<ChatResponse>>>>,
        /// Every request body passed to `post_json`, in call order.
        seen_requests: Arc<Mutex<Vec<ChatRequest>>>,
    }

    impl MockHttp {
        /// Builds a mock whose queued responses are all successful.
        fn new(queue: Vec<ChatResponse>) -> Self {
            Self {
                queue: Arc::new(Mutex::new(queue.into_iter().map(Ok).collect())),
                seen_requests: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Builds a mock from explicit results, allowing queued errors.
        fn new_with_results(queue: Vec<DResult<ChatResponse>>) -> Self {
            Self {
                queue: Arc::new(Mutex::new(queue)),
                seen_requests: Arc::new(Mutex::new(Vec::new())),
            }
        }
    }

    #[async_trait]
    impl HttpClient for MockHttp {
        /// Records `body`, then pops the next queued result. Panics when the
        /// queue is empty, i.e. when the code under test sends more requests
        /// than the test scripted.
        async fn post_json(
            &self,
            _url: &str,
            _bearer: &str,
            body: &ChatRequest,
        ) -> DResult<ChatResponse> {
            self.seen_requests.lock().unwrap().push(body.clone());
            let mut q = self.queue.lock().unwrap();
            assert!(!q.is_empty(), "MockHttp: queue exhausted");
            q.remove(0)
        }
    }

    /// A final assistant answer: plain text, finish reason `"stop"`,
    /// usage of 10 prompt / 5 completion tokens.
    fn assistant_text(text: &str) -> ChatResponse {
        ChatResponse {
            id: "test".into(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: "assistant".into(),
                    content: ChatContent::Text(text.into()),
                    reasoning_content: None,
                    tool_calls: None,
                    tool_call_id: None,
                    name: None,
                },
                finish_reason: Some("stop".into()),
            }],
            usage: Some(UsageInfo {
                prompt_tokens: 10,
                completion_tokens: 5,
                total_tokens: 15,
                ..Default::default()
            }),
        }
    }

    /// An assistant turn requesting a single tool call: null content,
    /// finish reason `"tool_calls"`, usage of 8 prompt / 4 completion tokens.
    fn assistant_tool_call(id: &str, name: &str, args: serde_json::Value) -> ChatResponse {
        ChatResponse {
            id: "test".into(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: "assistant".into(),
                    content: ChatContent::Null,
                    reasoning_content: None,
                    tool_calls: Some(vec![ToolCall {
                        id: id.into(),
                        r#type: "function".into(),
                        function: FunctionCall {
                            name: name.into(),
                            arguments: args.to_string(),
                        },
                    }]),
                    tool_call_id: None,
                    name: None,
                },
                finish_reason: Some("tool_calls".into()),
            }],
            usage: Some(UsageInfo {
                prompt_tokens: 8,
                completion_tokens: 4,
                total_tokens: 12,
                ..Default::default()
            }),
        }
    }

    /// Minimal tool that echoes its arguments back as a string.
    struct EchoTool {
        /// Name the tool is registered under.
        name: &'static str,
        /// Value reported by `read_only_hint`.
        read_only: bool,
    }

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            self.name
        }
        fn read_only_hint(&self) -> bool {
            self.read_only
        }
        fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: self.name.to_string(),
                description: "echo".into(),
                parameters: json!({"type":"object"}),
            }
        }
        /// Always succeeds with `"echoed <args>"`.
        async fn call_json(&self, args: serde_json::Value) -> std::result::Result<String, String> {
            Ok(format!("echoed {}", args))
        }
    }

    /// Builds a tool registry of `EchoTool`s from `(name, read_only)` pairs.
    fn tools(items: Vec<(&'static str, bool)>) -> Arc<Vec<Box<dyn Tool>>> {
        Arc::new(
            items
                .into_iter()
                .map(|(n, ro)| {
                    Box::new(EchoTool {
                        name: n,
                        read_only: ro,
                    }) as Box<dyn Tool>
                })
                .collect(),
        )
    }

    /// Runs the agent loop to completion and returns every streamed message.
    async fn collect(
        http: MockHttp,
        toolset: Arc<Vec<Box<dyn Tool>>>,
        prompt: &str,
        opts: RunOptions,
    ) -> Vec<SdkMessage> {
        run(http, "test-key".into(), toolset, prompt.into(), opts)
            .collect()
            .await
    }

    /// A plain text answer yields System, Assistant, then a successful Result
    /// with zero turns (turns only count tool-call rounds).
    #[tokio::test]
    async fn text_only_emits_assistant_then_success() {
        let http = MockHttp::new(vec![assistant_text("hello world")]);
        let msgs = collect(http, tools(vec![]), "hi", RunOptions::default()).await;

        assert!(matches!(msgs[0], SdkMessage::System { .. }));
        assert!(matches!(&msgs[1], SdkMessage::Assistant { .. }));
        match &msgs[2] {
            SdkMessage::Result {
                subtype,
                result: Some(t),
                num_turns,
                ..
            } => {
                assert_eq!(*subtype, ResultSubtype::Success);
                assert_eq!(t, "hello world");
                assert_eq!(*num_turns, 0);
            }
            other => panic!("expected Result, got {other:?}"),
        }
    }

    /// One tool round followed by a text answer: the stream is
    /// System, Assistant(tool_use), User(tool_result), Assistant(text), Result.
    #[tokio::test]
    async fn tool_call_then_text_completes_successfully() {
        let http = MockHttp::new(vec![
            assistant_tool_call("c1", "echo_ro", json!({"x": 1})),
            assistant_text("done"),
        ]);
        let msgs = collect(
            http,
            tools(vec![("echo_ro", true)]),
            "hi",
            RunOptions::default().permission_mode(PermissionMode::BypassPermissions),
        )
        .await;

        assert_eq!(msgs.len(), 5, "msgs={msgs:?}");
        // msgs[1]: the assistant's tool-use block.
        match &msgs[1] {
            SdkMessage::Assistant { content, .. } => {
                assert!(matches!(content[0], ContentBlock::ToolUse { .. }));
            }
            _ => panic!(),
        }
        // msgs[2]: the matching, non-error tool result.
        match &msgs[2] {
            SdkMessage::User { content } => match &content[0] {
                ContentBlock::ToolResult {
                    tool_use_id,
                    is_error,
                    ..
                } => {
                    assert_eq!(tool_use_id, "c1");
                    assert!(!is_error);
                }
                _ => panic!(),
            },
            _ => panic!(),
        }
        // msgs[4]: success after exactly one tool round.
        match &msgs[4] {
            SdkMessage::Result {
                subtype, num_turns, ..
            } => {
                assert_eq!(*subtype, ResultSubtype::Success);
                assert_eq!(*num_turns, 1);
            }
            _ => panic!(),
        }
    }

    /// With `max_turns(1)` the loop stops after the first tool round with
    /// `ErrorMaxTurns`; the second queued response is never requested.
    #[tokio::test]
    async fn max_turns_stops_with_error_subtype() {
        let http = MockHttp::new(vec![
            assistant_tool_call("c1", "echo_ro", json!({})),
            assistant_tool_call("c2", "echo_ro", json!({})),
        ]);
        let msgs = collect(
            http,
            tools(vec![("echo_ro", true)]),
            "loop",
            RunOptions::default()
                .max_turns(1)
                .permission_mode(PermissionMode::BypassPermissions),
        )
        .await;
        let last = msgs.last().unwrap();
        match last {
            SdkMessage::Result {
                subtype, num_turns, ..
            } => {
                assert_eq!(*subtype, ResultSubtype::ErrorMaxTurns);
                assert_eq!(*num_turns, 1);
            }
            _ => panic!("expected Result"),
        }
    }

    /// In `Plan` mode a non-read-only tool call is answered with an error
    /// tool result whose text mentions "Plan mode".
    #[tokio::test]
    async fn plan_mode_denies_mutating_tool() {
        let http = MockHttp::new(vec![
            assistant_tool_call("c1", "echo_mut", json!({})),
            assistant_text("ok"),
        ]);
        let msgs = collect(
            http,
            tools(vec![("echo_mut", false)]),
            "do",
            RunOptions::default().permission_mode(PermissionMode::Plan),
        )
        .await;
        // First User message in the stream carries the tool results.
        let denied = msgs
            .iter()
            .find_map(|m| match m {
                SdkMessage::User { content } => Some(content.clone()),
                _ => None,
            })
            .expect("expected a User tool_result message");
        match &denied[0] {
            ContentBlock::ToolResult {
                is_error, content, ..
            } => {
                assert!(*is_error);
                assert!(content.contains("Plan mode"), "msg={content}");
            }
            _ => panic!(),
        }
    }

    /// `AgentBuilder::prompt` returns the assistant's text for a one-shot
    /// exchange.
    #[tokio::test]
    async fn legacy_builder_prompt_round_trips_text() {
        use crate::agent::AgentBuilder;
        let http = MockHttp::new(vec![assistant_text("hello back")]);
        let agent = AgentBuilder::new(http, "test-key", "deepseek-chat")
            .preamble("you are a test")
            .build();
        let out = agent.prompt("hi".into()).await.expect("prompt ok");
        assert_eq!(out, "hello back");
    }

    /// A tool listed in `disallowed_tools` is not advertised in the request's
    /// tool schemas.
    #[tokio::test]
    async fn disallowed_tool_is_hidden_from_request() {
        let http = MockHttp::new(vec![assistant_text("nothing to do")]);
        // Keep a handle that shares the request log with the moved client.
        let mock = http.clone();
        let _ = collect(
            http,
            tools(vec![("echo_ro", true), ("echo_mut", false)]),
            "hi",
            RunOptions::default().disallowed_tools(["echo_mut"]),
        )
        .await;
        let req = &mock.seen_requests.lock().unwrap()[0];
        let names: Vec<String> = req
            .tools
            .as_ref()
            .map(|s| s.iter().map(|t| t.function.name.clone()).collect())
            .unwrap_or_default();
        assert_eq!(names, vec!["echo_ro".to_string()]);
    }

    /// Like `assistant_tool_call`, but with `prompt_tokens` overridden (and
    /// `total_tokens` recomputed) so tests can cross the compaction threshold.
    fn assistant_tool_call_with_prompt(
        id: &str,
        name: &str,
        args: serde_json::Value,
        prompt_tokens: u32,
    ) -> ChatResponse {
        let mut r = assistant_tool_call(id, name, args);
        if let Some(u) = r.usage.as_mut() {
            u.prompt_tokens = prompt_tokens;
            u.total_tokens = prompt_tokens.saturating_add(u.completion_tokens);
        }
        r
    }

    /// Compaction settings used by the tests below: trigger at 100 prompt
    /// tokens and keep only the most recent assistant turn.
    fn compaction_cfg() -> CompactionConfig {
        CompactionConfig {
            threshold_prompt_tokens: 100,
            keep_recent_turns: 1,
            compactor_model: "deepseek-chat".into(),
            max_summary_tokens: 64,
        }
    }

    /// Two tool rounds reporting 200 prompt tokens (threshold 100). After the
    /// first round there is only one assistant turn, so nothing is compacted
    /// and no request is sent; after the second round the compactor is
    /// called (third request) and the fourth request sees the rewritten
    /// history.
    #[tokio::test]
    async fn compaction_triggers_when_prompt_tokens_exceed_threshold() {
        let queue = vec![
            assistant_tool_call_with_prompt("c1", "echo_ro", json!({}), 200),
            assistant_tool_call_with_prompt("c2", "echo_ro", json!({}), 200),
            assistant_text("summary of earlier turns"),
            assistant_text("done"),
        ];
        let http = MockHttp::new(queue);
        let mock = http.clone();
        let msgs = collect(
            http,
            tools(vec![("echo_ro", true)]),
            "hi",
            RunOptions::default()
                .permission_mode(PermissionMode::BypassPermissions)
                .compaction(compaction_cfg()),
        )
        .await;

        let seen = mock.seen_requests.lock().unwrap();
        assert_eq!(seen.len(), 4, "expected 2 main + 1 compactor + 1 main");

        // The compactor request is a plain, tool-less completion.
        let compactor_req = &seen[2];
        assert_eq!(compactor_req.model, "deepseek-chat");
        assert!(compactor_req.tools.is_none());
        assert!(compactor_req.thinking.is_none());
        assert_eq!(compactor_req.max_tokens, Some(64));

        // The next main request carries the compacted history.
        let post_compact_req = &seen[3];
        assert_eq!(
            post_compact_req.messages.len(),
            4,
            "post-compaction history should be [user, summary, last_assistant, last_tool_result]"
        );
        assert_eq!(post_compact_req.messages[1].role, "system");
        assert!(post_compact_req.messages[1]
            .content
            .as_str()
            .contains("Compacted summary"));

        assert!(
            msgs.iter().any(|m| matches!(
                m,
                SdkMessage::System {
                    subtype: SystemSubtype::Compact,
                    ..
                }
            )),
            "expected a SystemSubtype::Compact event in the stream"
        );
    }

    /// After compaction every remaining assistant tool call must still be
    /// followed by its own `tool` result message.
    #[tokio::test]
    async fn compaction_preserves_tool_call_pairs() {
        let queue = vec![
            assistant_tool_call_with_prompt("c1", "echo_ro", json!({}), 200),
            assistant_tool_call_with_prompt("c2", "echo_ro", json!({}), 200),
            assistant_text("summary"),
            assistant_text("done"),
        ];
        let http = MockHttp::new(queue);
        let mock = http.clone();
        let _ = collect(
            http,
            tools(vec![("echo_ro", true)]),
            "hi",
            RunOptions::default()
                .permission_mode(PermissionMode::BypassPermissions)
                .compaction(compaction_cfg()),
        )
        .await;

        let seen = mock.seen_requests.lock().unwrap();
        // Request index 3 is the first main request after compaction.
        let post_compact = &seen[3];
        let msgs = &post_compact.messages;
        for (i, m) in msgs.iter().enumerate() {
            if m.role == "assistant" {
                if let Some(calls) = &m.tool_calls {
                    // The n-th call's result is expected n+1 slots after the
                    // assistant message.
                    for (offset, call) in calls.iter().enumerate() {
                        let follower = msgs.get(i + 1 + offset).unwrap_or_else(|| {
                            panic!("assistant tool_call at idx {i} has no follower")
                        });
                        assert_eq!(follower.role, "tool");
                        assert_eq!(
                            follower.tool_call_id.as_deref(),
                            Some(call.id.as_str()),
                            "tool_result id must match assistant's tool_call id"
                        );
                    }
                }
            }
        }
    }

    /// A failing compactor request (third queued result is an error) is
    /// non-fatal: no Compact event, the run still succeeds, and the next
    /// request carries the full, un-compacted history.
    #[tokio::test]
    async fn compaction_failure_falls_through() {
        let queue: Vec<DResult<ChatResponse>> = vec![
            Ok(assistant_tool_call_with_prompt(
                "c1",
                "echo_ro",
                json!({}),
                200,
            )),
            Ok(assistant_tool_call_with_prompt(
                "c2",
                "echo_ro",
                json!({}),
                200,
            )),
            Err(crate::error::DeepSeekError::Api {
                status: 500,
                body: "boom".into(),
            }),
            Ok(assistant_text("done")),
        ];
        let http = MockHttp::new_with_results(queue);
        let mock = http.clone();
        let msgs = collect(
            http,
            tools(vec![("echo_ro", true)]),
            "hi",
            RunOptions::default()
                .permission_mode(PermissionMode::BypassPermissions)
                .compaction(compaction_cfg()),
        )
        .await;

        assert!(
            !msgs.iter().any(|m| matches!(
                m,
                SdkMessage::System {
                    subtype: SystemSubtype::Compact,
                    ..
                }
            )),
            "compaction failure must not emit System::Compact"
        );

        let last = msgs.last().unwrap();
        assert!(matches!(
            last,
            SdkMessage::Result {
                subtype: ResultSubtype::Success,
                ..
            }
        ));

        let seen = mock.seen_requests.lock().unwrap();
        // user + 2 × (assistant tool call + tool result) = 5 messages.
        let post_failure = &seen[3];
        assert_eq!(
            post_failure.messages.len(),
            5,
            "history must remain un-compacted after a compactor failure"
        );
    }

    /// Without a compaction config no compactor request is sent (exactly
    /// three requests) and no Compact event is emitted, even though the
    /// reported prompt tokens are high.
    #[tokio::test]
    async fn compaction_disabled_by_default() {
        let queue = vec![
            assistant_tool_call_with_prompt("c1", "echo_ro", json!({}), 200),
            assistant_tool_call_with_prompt("c2", "echo_ro", json!({}), 200),
            assistant_text("done"),
        ];
        let http = MockHttp::new(queue);
        let mock = http.clone();
        let msgs = collect(
            http,
            tools(vec![("echo_ro", true)]),
            "hi",
            RunOptions::default().permission_mode(PermissionMode::BypassPermissions),
        )
        .await;

        assert_eq!(mock.seen_requests.lock().unwrap().len(), 3);
        assert!(!msgs.iter().any(|m| matches!(
            m,
            SdkMessage::System {
                subtype: SystemSubtype::Compact,
                ..
            }
        )));
    }
}