Skip to main content

deepseek/agent/
loop_runner.rs

//! Claude-Code-shaped streaming agent loop.
//!
//! `run(...)` returns an async stream of [`SdkMessage`]. The loop:
//!
//! 1. Yields `System{Init}` carrying the session id.
//! 2. POSTs to the configured chat-completions endpoint.
//! 3. On `finish_reason == "tool_calls"`: yields one `Assistant` message with
//!    text + tool_use blocks, runs each tool through the permission gate
//!    (read-only tools in parallel, mutating tools sequentially), yields one
//!    `User` message containing all tool_result blocks, and continues.
//! 4. On any other finish reason: yields the final `Assistant` text and a
//!    `Result{Success}` carrying usage, cost, and turn count.
//! 5. Enforces `max_turns` and `max_budget_usd`; transport errors and
//!    responses without any choices yield `Result{ErrorDuringExecution}`.
//! 6. When history compaction is configured and the previous turn's prompt
//!    tokens reached the threshold, summarizes older history and yields
//!    `System{Compact}` whenever the history was actually rewritten.
15
16use std::sync::Arc;
17
18use async_stream::stream;
19use futures::future::join_all;
20use futures::stream::Stream;
21use serde_json::{json, Value};
22
23use crate::client::HttpClient;
24use crate::types::{
25    tool_result_msg, ChatContent, ChatMessage, ChatRequest, FunctionSchema, ToolSchema, UsageInfo,
26};
27
28use super::messages::{ContentBlock, ResultSubtype, SdkMessage, SystemSubtype};
29use super::options::{CompactionConfig, RunOptions};
30use super::permissions::{PermissionDecision, PermissionMode};
31use super::pricing::{map_stop_reason, turn_cost_usd};
32use super::tool::Tool;
33
34/// Run the agent loop and stream `SdkMessage`s in turn order.
35///
36/// `tools` is wrapped in an `Arc` so callers can reuse the same registry
37/// across multiple runs.
38pub fn run<H>(
39    http: H,
40    api_key: String,
41    tools: Arc<Vec<Box<dyn Tool>>>,
42    user_prompt: String,
43    opts: RunOptions,
44) -> impl Stream<Item = SdkMessage>
45where
46    H: HttpClient + Send + Sync + 'static,
47{
48    stream! {
49        // Resolve session and emit init.
50        let session_id = opts
51            .session_id
52            .clone()
53            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
54        yield SdkMessage::System {
55            subtype: SystemSubtype::Init,
56            session_id: session_id.clone(),
57            data: json!({
58                "model": opts.model,
59                "permission_mode": opts.permission_mode,
60                "max_turns": opts.max_turns,
61                "max_budget_usd": opts.max_budget_usd,
62            }),
63        };
64
65        // Hide disallowed/non-allowed tools from the model entirely.
66        let visible_tools: Vec<&Box<dyn Tool>> = tools
67            .iter()
68            .filter(|t| {
69                let n = t.name();
70                if opts.disallowed_tools.iter().any(|d| d == n) {
71                    return false;
72                }
73                if let Some(allow) = &opts.allowed_tools {
74                    return allow.iter().any(|a| a == n);
75                }
76                true
77            })
78            .collect();
79
80        let tool_schemas: Vec<ToolSchema> = visible_tools
81            .iter()
82            .map(|t| {
83                let def = t.definition();
84                ToolSchema {
85                    r#type: "function".into(),
86                    function: FunctionSchema {
87                        name: def.name,
88                        description: def.description,
89                        parameters: def.parameters,
90                    },
91                }
92            })
93            .collect();
94
95        // Conversation history.
96        let mut messages: Vec<ChatMessage> = Vec::new();
97        if !opts.system_prompt.is_empty() {
98            messages.push(ChatMessage {
99                role: "system".into(),
100                content: ChatContent::Text(opts.system_prompt.clone()),
101                reasoning_content: None,
102                tool_calls: None,
103                tool_call_id: None,
104                name: None,
105            });
106        }
107        messages.push(ChatMessage {
108            role: "user".into(),
109            content: ChatContent::Text(user_prompt),
110            reasoning_content: None,
111            tool_calls: None,
112            tool_call_id: None,
113            name: None,
114        });
115
116        let url = format!("{}/chat/completions", opts.base_url);
117        let mut num_turns: u32 = 0;
118        let mut total_prompt_tokens: u32 = 0;
119        let mut total_completion_tokens: u32 = 0;
120        let mut total_cache_hit_tokens: u32 = 0;
121        let mut total_cache_miss_tokens: u32 = 0;
122        let mut any_cache_stats_seen = false;
123        let mut total_cost: Option<f64> =
124            super::pricing::model_pricing(&opts.model).map(|_| 0.0);
125        let mut last_stop_reason: Option<String> = None;
126        let mut last_turn_prompt_tokens: u32 = 0;
127
128        loop {
129            let request = ChatRequest {
130                model: opts.model.clone(),
131                messages: messages.clone(),
132                tools: if tool_schemas.is_empty() { None } else { Some(tool_schemas.clone()) },
133                tool_choice: if tool_schemas.is_empty() {
134                    None
135                } else {
136                    Some(json!("auto"))
137                },
138                temperature: Some(opts.effort.temperature()),
139                max_tokens: Some(opts.effort.max_tokens()),
140                stream: Some(false),
141                reasoning_effort: Some(match opts.effort {
142                    crate::types::EffortLevel::Max => "max".into(),
143                    crate::types::EffortLevel::High => "high".into(),
144                    crate::types::EffortLevel::Medium => "medium".into(),
145                    crate::types::EffortLevel::Low => "low".into(),
146                }),
147                thinking: Some(json!({"type": "enabled"})),
148            };
149
150            let resp = match http.post_json(&url, &api_key, &request).await {
151                Ok(r) => r,
152                Err(e) => {
153                    tracing::warn!(error = %e, "agent loop transport error");
154                    yield SdkMessage::Result {
155                        subtype: ResultSubtype::ErrorDuringExecution,
156                        result: None,
157                        total_cost_usd: total_cost,
158                        usage: usage_info(total_prompt_tokens, total_completion_tokens, total_cache_hit_tokens, total_cache_miss_tokens, any_cache_stats_seen),
159                        num_turns,
160                        session_id,
161                        stop_reason: last_stop_reason,
162                    };
163                    return;
164                }
165            };
166
167            // Accumulate usage / cost from this turn.
168            if let Some(u) = &resp.usage {
169                last_turn_prompt_tokens = u.prompt_tokens;
170                total_prompt_tokens = total_prompt_tokens.saturating_add(u.prompt_tokens);
171                total_completion_tokens = total_completion_tokens.saturating_add(u.completion_tokens);
172                if let Some(h) = u.prompt_cache_hit_tokens {
173                    total_cache_hit_tokens = total_cache_hit_tokens.saturating_add(h);
174                    any_cache_stats_seen = true;
175                }
176                if let Some(m) = u.prompt_cache_miss_tokens {
177                    total_cache_miss_tokens = total_cache_miss_tokens.saturating_add(m);
178                    any_cache_stats_seen = true;
179                }
180                if let (Some(running), Some(turn)) = (
181                    total_cost.as_mut(),
182                    turn_cost_usd(&opts.model, u),
183                ) {
184                    *running += turn;
185                }
186            }
187
188            let Some(choice) = resp.choices.into_iter().next() else {
189                yield SdkMessage::Result {
190                    subtype: ResultSubtype::ErrorDuringExecution,
191                    result: None,
192                    total_cost_usd: total_cost,
193                    usage: usage_info(total_prompt_tokens, total_completion_tokens, total_cache_hit_tokens, total_cache_miss_tokens, any_cache_stats_seen),
194                    num_turns,
195                    session_id,
196                    stop_reason: last_stop_reason,
197                };
198                return;
199            };
200
201            let finish_reason = choice.finish_reason.as_deref().unwrap_or("stop");
202            last_stop_reason = map_stop_reason(finish_reason);
203            let assistant_msg = choice.message;
204
205            if finish_reason == "tool_calls" {
206                let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
207
208                // Build the Assistant SdkMessage (text + tool_use blocks).
209                let mut content_blocks: Vec<ContentBlock> = Vec::new();
210                let text = assistant_msg.content.as_str();
211                if !text.is_empty() {
212                    content_blocks.push(ContentBlock::Text { text: text.to_string() });
213                }
214                let parsed_calls: Vec<(String, String, Value)> = tool_calls
215                    .iter()
216                    .map(|c| {
217                        let args: Value =
218                            serde_json::from_str(&c.function.arguments).unwrap_or(json!({}));
219                        (c.id.clone(), c.function.name.clone(), args)
220                    })
221                    .collect();
222                for (id, name, input) in &parsed_calls {
223                    content_blocks.push(ContentBlock::ToolUse {
224                        id: id.clone(),
225                        name: name.clone(),
226                        input: input.clone(),
227                    });
228                }
229                yield SdkMessage::Assistant {
230                    content: content_blocks,
231                    stop_reason: last_stop_reason.clone(),
232                };
233
234                // Persist assistant turn in history (carries tool_calls).
235                messages.push(assistant_msg);
236
237                // Permission gate.
238                let mut decisions: Vec<(String, String, Value, PermissionDecision, bool)> =
239                    Vec::with_capacity(parsed_calls.len());
240                for (id, name, args) in parsed_calls {
241                    let tool_ref = visible_tools.iter().find(|t| t.name() == name);
242                    let read_only = tool_ref.map(|t| t.read_only_hint()).unwrap_or(false);
243
244                    let mode_decision = opts.permission_mode.evaluate(&name, read_only);
245                    let final_decision = match (mode_decision, &opts.pre_tool_hook) {
246                        (PermissionDecision::Allow, _) => PermissionDecision::Allow,
247                        (PermissionDecision::Deny(r), _) => PermissionDecision::Deny(r),
248                        (PermissionDecision::Ask, Some(hook)) => {
249                            match hook.check(&name, &args).await {
250                                PermissionDecision::Ask => PermissionDecision::Deny(format!(
251                                    "Tool `{name}` requires approval and the hook returned Ask"
252                                )),
253                                d => d,
254                            }
255                        }
256                        (PermissionDecision::Ask, None) => {
257                            if matches!(opts.permission_mode, PermissionMode::BypassPermissions) {
258                                PermissionDecision::Allow
259                            } else {
260                                PermissionDecision::Deny(format!(
261                                    "Tool `{name}` not pre-approved and no permission hook configured"
262                                ))
263                            }
264                        }
265                    };
266
267                    decisions.push((id, name, args, final_decision, read_only));
268                }
269
270                // Partition allowed calls into read-only (parallel) and mutating (sequential).
271                let mut tool_results: Vec<(String, Result<String, String>)> = Vec::new();
272                let mut parallel_idxs: Vec<usize> = Vec::new();
273                let mut sequential_idxs: Vec<usize> = Vec::new();
274                for (i, (_, _, _, d, ro)) in decisions.iter().enumerate() {
275                    if matches!(d, PermissionDecision::Allow) {
276                        if *ro {
277                            parallel_idxs.push(i);
278                        } else {
279                            sequential_idxs.push(i);
280                        }
281                    }
282                }
283
284                // Run parallel set.
285                if !parallel_idxs.is_empty() {
286                    let futs = parallel_idxs.iter().map(|&i| {
287                        let (id, name, args, _, _) = &decisions[i];
288                        let id = id.clone();
289                        let name = name.clone();
290                        let args = args.clone();
291                        let tools = Arc::clone(&tools);
292                        async move {
293                            let res = match tools.iter().find(|t| t.name() == name) {
294                                Some(t) => t.call_json(args).await,
295                                None => Err(format!("Unknown tool: {name}")),
296                            };
297                            (id, res)
298                        }
299                    });
300                    let outs = join_all(futs).await;
301                    for (id, res) in outs {
302                        tool_results.push((id, res));
303                    }
304                }
305
306                // Run sequential set.
307                for i in sequential_idxs {
308                    let (id, name, args, _, _) = &decisions[i];
309                    let res = match tools.iter().find(|t| t.name() == *name) {
310                        Some(t) => t.call_json(args.clone()).await,
311                        None => Err(format!("Unknown tool: {name}")),
312                    };
313                    tool_results.push((id.clone(), res));
314                }
315
316                // Append denials as synthetic error tool_results.
317                for (id, _name, _args, d, _) in &decisions {
318                    if let PermissionDecision::Deny(reason) = d {
319                        tool_results.push((id.clone(), Err(reason.clone())));
320                    }
321                }
322
323                // Re-order results to match the original tool-call order so the
324                // model sees them in the same sequence it requested.
325                let id_order: Vec<String> = decisions.iter().map(|d| d.0.clone()).collect();
326                tool_results.sort_by_key(|(id, _)| {
327                    id_order.iter().position(|x| x == id).unwrap_or(usize::MAX)
328                });
329
330                // Append tool_result messages to history; build user SdkMessage.
331                let mut user_blocks: Vec<ContentBlock> = Vec::with_capacity(tool_results.len());
332                for (call_id, res) in &tool_results {
333                    let (content_str, is_error) = match res {
334                        Ok(s) => (s.clone(), false),
335                        Err(e) => (e.clone(), true),
336                    };
337                    messages.push(tool_result_msg(call_id, &content_str));
338                    user_blocks.push(ContentBlock::ToolResult {
339                        tool_use_id: call_id.clone(),
340                        content: content_str,
341                        is_error,
342                    });
343                }
344                yield SdkMessage::User { content: user_blocks };
345
346                num_turns = num_turns.saturating_add(1);
347
348                if let Some(limit) = opts.max_turns {
349                    if num_turns >= limit {
350                        yield SdkMessage::Result {
351                            subtype: ResultSubtype::ErrorMaxTurns,
352                            result: None,
353                            total_cost_usd: total_cost,
354                            usage: usage_info(total_prompt_tokens, total_completion_tokens, total_cache_hit_tokens, total_cache_miss_tokens, any_cache_stats_seen),
355                            num_turns,
356                            session_id,
357                            stop_reason: last_stop_reason,
358                        };
359                        return;
360                    }
361                }
362                if let (Some(budget), Some(cost)) = (opts.max_budget_usd, total_cost) {
363                    if cost >= budget {
364                        yield SdkMessage::Result {
365                            subtype: ResultSubtype::ErrorMaxBudgetUsd,
366                            result: None,
367                            total_cost_usd: total_cost,
368                            usage: usage_info(total_prompt_tokens, total_completion_tokens, total_cache_hit_tokens, total_cache_miss_tokens, any_cache_stats_seen),
369                            num_turns,
370                            session_id,
371                            stop_reason: last_stop_reason,
372                        };
373                        return;
374                    }
375                }
376
377                // Optional history compaction. Triggered only when the
378                // previous turn's prompt_tokens crossed the configured
379                // threshold; failure is non-fatal and falls through to a
380                // full-history retry on the next iteration.
381                if let Some(cfg) = opts.compaction.as_ref() {
382                    if last_turn_prompt_tokens >= cfg.threshold_prompt_tokens {
383                        match compact_history(&http, &api_key, &opts, cfg, &mut messages).await {
384                            Ok(outcome) => {
385                                if let Some(u) = &outcome.usage {
386                                    total_prompt_tokens =
387                                        total_prompt_tokens.saturating_add(u.prompt_tokens);
388                                    total_completion_tokens = total_completion_tokens
389                                        .saturating_add(u.completion_tokens);
390                                    if let Some(h) = u.prompt_cache_hit_tokens {
391                                        total_cache_hit_tokens =
392                                            total_cache_hit_tokens.saturating_add(h);
393                                        any_cache_stats_seen = true;
394                                    }
395                                    if let Some(m) = u.prompt_cache_miss_tokens {
396                                        total_cache_miss_tokens =
397                                            total_cache_miss_tokens.saturating_add(m);
398                                        any_cache_stats_seen = true;
399                                    }
400                                    if let (Some(running), Some(turn)) = (
401                                        total_cost.as_mut(),
402                                        turn_cost_usd(&cfg.compactor_model, u),
403                                    ) {
404                                        *running += turn;
405                                    }
406                                }
407                                if outcome.rewrote {
408                                    yield SdkMessage::System {
409                                        subtype: SystemSubtype::Compact,
410                                        session_id: session_id.clone(),
411                                        data: json!({
412                                            "message_count_after": messages.len(),
413                                        }),
414                                    };
415                                }
416                            }
417                            Err(e) => {
418                                tracing::warn!(
419                                    error = %e,
420                                    "history compaction failed; continuing with full history"
421                                );
422                            }
423                        }
424                    }
425                }
426            } else {
427                // Final assistant turn — text only.
428                let text = assistant_msg.content.as_str().to_string();
429                yield SdkMessage::Assistant {
430                    content: vec![ContentBlock::Text { text: text.clone() }],
431                    stop_reason: last_stop_reason.clone(),
432                };
433                yield SdkMessage::Result {
434                    subtype: ResultSubtype::Success,
435                    result: Some(text),
436                    total_cost_usd: total_cost,
437                    usage: usage_info(total_prompt_tokens, total_completion_tokens, total_cache_hit_tokens, total_cache_miss_tokens, any_cache_stats_seen),
438                    num_turns,
439                    session_id,
440                    stop_reason: last_stop_reason,
441                };
442                return;
443            }
444        }
445    }
446}
447
448/// Result of a [`compact_history`] call.
449struct CompactionOutcome {
450    /// Usage reported by the compactor API call. `None` if the helper
451    /// short-circuited before making the call (e.g. not enough history to
452    /// compact).
453    usage: Option<UsageInfo>,
454    /// True iff `messages` was actually rewritten. False when the helper
455    /// short-circuited or the model returned an empty summary.
456    rewrote: bool,
457}
458
/// Shorten `s` to at most `max` bytes, cutting on a UTF-8 character boundary
/// and appending `…` whenever anything was removed.
fn truncate_for_transcript(s: &str, max: usize) -> String {
    if s.len() > max {
        // Walk back from `max` to the nearest char boundary; byte 0 always
        // qualifies, so the fallback is never actually taken.
        let cut = (0..=max).rev().find(|&i| s.is_char_boundary(i)).unwrap_or(0);
        return format!("{}…", &s[..cut]);
    }
    s.to_owned()
}
472
473/// Compact the middle of `messages` into a synthetic summary system message,
474/// preserving the system prompt, the initial user message, and the most
475/// recent `cfg.keep_recent_turns` complete turns.
476///
477/// Returns `Ok(CompactionOutcome { rewrote: false, .. })` when there isn't
478/// enough history to compact or when the compactor returned an empty
479/// summary. Transport errors propagate as `Err` and are treated as
480/// non-fatal by the agent loop.
481async fn compact_history<H>(
482    http: &H,
483    api_key: &str,
484    opts: &RunOptions,
485    cfg: &CompactionConfig,
486    messages: &mut Vec<ChatMessage>,
487) -> crate::error::Result<CompactionOutcome>
488where
489    H: HttpClient + Send + Sync,
490{
491    // head_end: past system (if present) + initial user message.
492    let head_end = match messages.first().map(|m| m.role.as_str()) {
493        Some("system") => {
494            if matches!(messages.get(1).map(|m| m.role.as_str()), Some("user")) {
495                2
496            } else {
497                1
498            }
499        }
500        Some("user") => 1,
501        _ => {
502            return Ok(CompactionOutcome {
503                usage: None,
504                rewrote: false,
505            })
506        }
507    };
508
509    // tail_start: index of the (keep_recent_turns)-th-from-end assistant
510    // message. A "turn" begins at an assistant message; tool messages that
511    // follow it belong to the same turn and are kept atomically.
512    let assistant_idxs: Vec<usize> = messages
513        .iter()
514        .enumerate()
515        .filter(|(_, m)| m.role == "assistant")
516        .map(|(i, _)| i)
517        .collect();
518    if (assistant_idxs.len() as u32) <= cfg.keep_recent_turns {
519        return Ok(CompactionOutcome {
520            usage: None,
521            rewrote: false,
522        });
523    }
524    let tail_start = assistant_idxs[assistant_idxs.len() - cfg.keep_recent_turns as usize];
525    if tail_start <= head_end {
526        return Ok(CompactionOutcome {
527            usage: None,
528            rewrote: false,
529        });
530    }
531
532    // Serialize the middle slice into a compact transcript.
533    let mut transcript = String::new();
534    for msg in &messages[head_end..tail_start] {
535        let content_text = msg.content.as_str();
536        match msg.role.as_str() {
537            "assistant" => {
538                if !content_text.trim().is_empty() {
539                    transcript.push_str(&format!(
540                        "[assistant] {}\n",
541                        truncate_for_transcript(content_text.trim(), 400)
542                    ));
543                }
544                if let Some(calls) = &msg.tool_calls {
545                    for c in calls {
546                        transcript.push_str(&format!(
547                            "  [tool_call name={} args={}]\n",
548                            c.function.name,
549                            truncate_for_transcript(&c.function.arguments, 400)
550                        ));
551                    }
552                }
553            }
554            "tool" => {
555                let id = msg.tool_call_id.as_deref().unwrap_or("?");
556                transcript.push_str(&format!(
557                    "  [tool_result id={}] {}\n",
558                    id,
559                    truncate_for_transcript(content_text, 500)
560                ));
561            }
562            other => {
563                transcript.push_str(&format!(
564                    "[{}] {}\n",
565                    other,
566                    truncate_for_transcript(content_text, 400)
567                ));
568            }
569        }
570    }
571
572    let system_prompt = "You are a conversation-history compactor. Produce a concise structured summary of the conversation segment provided. Preserve: files read or written (with paths), tool calls made (by name and key arguments), test results, decisions reached, and open questions. Drop: verbose tool output, intermediate reasoning, formatting noise. Output prose only — no markdown headers, no lists longer than 5 items. Stay under the model's max_tokens budget.";
573
574    let request = ChatRequest {
575        model: cfg.compactor_model.clone(),
576        messages: vec![
577            crate::types::system_msg(system_prompt),
578            crate::types::user_msg(&format!(
579                "Conversation segment to summarize:\n\n{transcript}"
580            )),
581        ],
582        tools: None,
583        tool_choice: None,
584        temperature: Some(0.2),
585        max_tokens: Some(cfg.max_summary_tokens),
586        stream: Some(false),
587        reasoning_effort: None,
588        thinking: None,
589    };
590
591    let url = format!("{}/chat/completions", opts.base_url);
592    let resp = http.post_json(&url, api_key, &request).await?;
593    let usage = resp.usage.clone();
594
595    let Some(choice) = resp.choices.into_iter().next() else {
596        return Ok(CompactionOutcome {
597            usage,
598            rewrote: false,
599        });
600    };
601    let summary = choice.message.content.as_str().trim().to_string();
602    if summary.is_empty() {
603        return Ok(CompactionOutcome {
604            usage,
605            rewrote: false,
606        });
607    }
608
609    let replacement = ChatMessage {
610        role: "system".into(),
611        content: ChatContent::Text(format!(
612            "[Compacted summary of earlier conversation]\n\n{summary}"
613        )),
614        reasoning_content: None,
615        tool_calls: None,
616        tool_call_id: None,
617        name: None,
618    };
619    messages.splice(head_end..tail_start, std::iter::once(replacement));
620
621    Ok(CompactionOutcome {
622        usage,
623        rewrote: true,
624    })
625}
626
627fn usage_info(
628    prompt: u32,
629    completion: u32,
630    cache_hit: u32,
631    cache_miss: u32,
632    cache_stats_seen: bool,
633) -> Option<UsageInfo> {
634    if prompt == 0 && completion == 0 {
635        None
636    } else {
637        Some(UsageInfo {
638            prompt_tokens: prompt,
639            completion_tokens: completion,
640            total_tokens: prompt.saturating_add(completion),
641            prompt_cache_hit_tokens: cache_stats_seen.then_some(cache_hit),
642            prompt_cache_miss_tokens: cache_stats_seen.then_some(cache_miss),
643        })
644    }
645}
646
647#[cfg(test)]
648mod tests {
649    use super::*;
650
651    use std::sync::Mutex;
652
653    use async_trait::async_trait;
654    use futures::StreamExt;
655    use serde_json::json;
656
657    use crate::agent::permissions::PermissionMode;
658    use crate::agent::tool::ToolDefinition;
659    use crate::client::HttpClient;
660    use crate::error::Result as DResult;
661    use crate::types::{
662        ChatContent, ChatMessage, ChatRequest, ChatResponse, Choice, FunctionCall, ToolCall,
663        UsageInfo,
664    };
665
666    /// Returns a queued sequence of [`ChatResponse`] values, panicking if the
667    /// loop calls the API more times than expected. `seen_requests` is shared
668    /// across clones so tests can inspect it after the loop has consumed the
669    /// mock. The queue can also contain `Err` to drive transport-failure
670    /// paths.
671    #[derive(Clone)]
672    struct MockHttp {
673        queue: Arc<Mutex<Vec<DResult<ChatResponse>>>>,
674        seen_requests: Arc<Mutex<Vec<ChatRequest>>>,
675    }
676
677    impl MockHttp {
678        fn new(queue: Vec<ChatResponse>) -> Self {
679            Self {
680                queue: Arc::new(Mutex::new(queue.into_iter().map(Ok).collect())),
681                seen_requests: Arc::new(Mutex::new(Vec::new())),
682            }
683        }
684
685        fn new_with_results(queue: Vec<DResult<ChatResponse>>) -> Self {
686            Self {
687                queue: Arc::new(Mutex::new(queue)),
688                seen_requests: Arc::new(Mutex::new(Vec::new())),
689            }
690        }
691    }
692
693    #[async_trait]
694    impl HttpClient for MockHttp {
695        async fn post_json(
696            &self,
697            _url: &str,
698            _bearer: &str,
699            body: &ChatRequest,
700        ) -> DResult<ChatResponse> {
701            self.seen_requests.lock().unwrap().push(body.clone());
702            let mut q = self.queue.lock().unwrap();
703            assert!(!q.is_empty(), "MockHttp: queue exhausted");
704            q.remove(0)
705        }
706    }
707
708    fn assistant_text(text: &str) -> ChatResponse {
709        ChatResponse {
710            id: "test".into(),
711            choices: vec![Choice {
712                index: 0,
713                message: ChatMessage {
714                    role: "assistant".into(),
715                    content: ChatContent::Text(text.into()),
716                    reasoning_content: None,
717                    tool_calls: None,
718                    tool_call_id: None,
719                    name: None,
720                },
721                finish_reason: Some("stop".into()),
722            }],
723            usage: Some(UsageInfo {
724                prompt_tokens: 10,
725                completion_tokens: 5,
726                total_tokens: 15,
727                ..Default::default()
728            }),
729        }
730    }
731
732    fn assistant_tool_call(id: &str, name: &str, args: serde_json::Value) -> ChatResponse {
733        ChatResponse {
734            id: "test".into(),
735            choices: vec![Choice {
736                index: 0,
737                message: ChatMessage {
738                    role: "assistant".into(),
739                    content: ChatContent::Null,
740                    reasoning_content: None,
741                    tool_calls: Some(vec![ToolCall {
742                        id: id.into(),
743                        r#type: "function".into(),
744                        function: FunctionCall {
745                            name: name.into(),
746                            arguments: args.to_string(),
747                        },
748                    }]),
749                    tool_call_id: None,
750                    name: None,
751                },
752                finish_reason: Some("tool_calls".into()),
753            }],
754            usage: Some(UsageInfo {
755                prompt_tokens: 8,
756                completion_tokens: 4,
757                total_tokens: 12,
758                ..Default::default()
759            }),
760        }
761    }
762
763    /// Minimal tool used by the loop tests — just echoes its args back.
764    struct EchoTool {
765        name: &'static str,
766        read_only: bool,
767    }
768
769    #[async_trait]
770    impl Tool for EchoTool {
771        fn name(&self) -> &str {
772            self.name
773        }
774        fn read_only_hint(&self) -> bool {
775            self.read_only
776        }
777        fn definition(&self) -> ToolDefinition {
778            ToolDefinition {
779                name: self.name.to_string(),
780                description: "echo".into(),
781                parameters: json!({"type":"object"}),
782            }
783        }
784        async fn call_json(&self, args: serde_json::Value) -> std::result::Result<String, String> {
785            Ok(format!("echoed {}", args))
786        }
787    }
788
789    fn tools(items: Vec<(&'static str, bool)>) -> Arc<Vec<Box<dyn Tool>>> {
790        Arc::new(
791            items
792                .into_iter()
793                .map(|(n, ro)| {
794                    Box::new(EchoTool {
795                        name: n,
796                        read_only: ro,
797                    }) as Box<dyn Tool>
798                })
799                .collect(),
800        )
801    }
802
803    async fn collect(
804        http: MockHttp,
805        toolset: Arc<Vec<Box<dyn Tool>>>,
806        prompt: &str,
807        opts: RunOptions,
808    ) -> Vec<SdkMessage> {
809        run(http, "test-key".into(), toolset, prompt.into(), opts)
810            .collect()
811            .await
812    }
813
814    #[tokio::test]
815    async fn text_only_emits_assistant_then_success() {
816        let http = MockHttp::new(vec![assistant_text("hello world")]);
817        let msgs = collect(http, tools(vec![]), "hi", RunOptions::default()).await;
818
819        assert!(matches!(msgs[0], SdkMessage::System { .. }));
820        assert!(matches!(&msgs[1], SdkMessage::Assistant { .. }));
821        match &msgs[2] {
822            SdkMessage::Result {
823                subtype,
824                result: Some(t),
825                num_turns,
826                ..
827            } => {
828                assert_eq!(*subtype, ResultSubtype::Success);
829                assert_eq!(t, "hello world");
830                assert_eq!(*num_turns, 0);
831            }
832            other => panic!("expected Result, got {other:?}"),
833        }
834    }
835
836    #[tokio::test]
837    async fn tool_call_then_text_completes_successfully() {
838        let http = MockHttp::new(vec![
839            assistant_tool_call("c1", "echo_ro", json!({"x": 1})),
840            assistant_text("done"),
841        ]);
842        let msgs = collect(
843            http,
844            tools(vec![("echo_ro", true)]),
845            "hi",
846            RunOptions::default().permission_mode(PermissionMode::BypassPermissions),
847        )
848        .await;
849
850        // System, Assistant(tool_use), User(tool_result), Assistant(text), Result.
851        assert_eq!(msgs.len(), 5, "msgs={msgs:?}");
852        match &msgs[1] {
853            SdkMessage::Assistant { content, .. } => {
854                assert!(matches!(content[0], ContentBlock::ToolUse { .. }));
855            }
856            _ => panic!(),
857        }
858        match &msgs[2] {
859            SdkMessage::User { content } => match &content[0] {
860                ContentBlock::ToolResult {
861                    tool_use_id,
862                    is_error,
863                    ..
864                } => {
865                    assert_eq!(tool_use_id, "c1");
866                    assert!(!is_error);
867                }
868                _ => panic!(),
869            },
870            _ => panic!(),
871        }
872        match &msgs[4] {
873            SdkMessage::Result {
874                subtype, num_turns, ..
875            } => {
876                assert_eq!(*subtype, ResultSubtype::Success);
877                assert_eq!(*num_turns, 1);
878            }
879            _ => panic!(),
880        }
881    }
882
883    #[tokio::test]
884    async fn max_turns_stops_with_error_subtype() {
885        let http = MockHttp::new(vec![
886            assistant_tool_call("c1", "echo_ro", json!({})),
887            assistant_tool_call("c2", "echo_ro", json!({})),
888        ]);
889        let msgs = collect(
890            http,
891            tools(vec![("echo_ro", true)]),
892            "loop",
893            RunOptions::default()
894                .max_turns(1)
895                .permission_mode(PermissionMode::BypassPermissions),
896        )
897        .await;
898        let last = msgs.last().unwrap();
899        match last {
900            SdkMessage::Result {
901                subtype, num_turns, ..
902            } => {
903                assert_eq!(*subtype, ResultSubtype::ErrorMaxTurns);
904                assert_eq!(*num_turns, 1);
905            }
906            _ => panic!("expected Result"),
907        }
908    }
909
910    #[tokio::test]
911    async fn plan_mode_denies_mutating_tool() {
912        // Loop sees a single tool call, plan-mode denies it, then the final
913        // assistant turn says "ok".
914        let http = MockHttp::new(vec![
915            assistant_tool_call("c1", "echo_mut", json!({})),
916            assistant_text("ok"),
917        ]);
918        let msgs = collect(
919            http,
920            tools(vec![("echo_mut", false)]),
921            "do",
922            RunOptions::default().permission_mode(PermissionMode::Plan),
923        )
924        .await;
925        // Find the User(tool_result) message and assert is_error=true.
926        let denied = msgs
927            .iter()
928            .find_map(|m| match m {
929                SdkMessage::User { content } => Some(content.clone()),
930                _ => None,
931            })
932            .expect("expected a User tool_result message");
933        match &denied[0] {
934            ContentBlock::ToolResult {
935                is_error, content, ..
936            } => {
937                assert!(*is_error);
938                assert!(content.contains("Plan mode"), "msg={content}");
939            }
940            _ => panic!(),
941        }
942    }
943
944    #[tokio::test]
945    async fn legacy_builder_prompt_round_trips_text() {
946        // Validates the back-compat `AgentBuilder` → `DeepSeekAgent::prompt`
947        // surface that `crates/research` depends on.
948        use crate::agent::AgentBuilder;
949        let http = MockHttp::new(vec![assistant_text("hello back")]);
950        let agent = AgentBuilder::new(http, "test-key", "deepseek-chat")
951            .preamble("you are a test")
952            .build();
953        let out = agent.prompt("hi".into()).await.expect("prompt ok");
954        assert_eq!(out, "hello back");
955    }
956
957    #[tokio::test]
958    async fn disallowed_tool_is_hidden_from_request() {
959        let http = MockHttp::new(vec![assistant_text("nothing to do")]);
960        let mock = http.clone();
961        let _ = collect(
962            http,
963            tools(vec![("echo_ro", true), ("echo_mut", false)]),
964            "hi",
965            RunOptions::default().disallowed_tools(["echo_mut"]),
966        )
967        .await;
968        let req = &mock.seen_requests.lock().unwrap()[0];
969        let names: Vec<String> = req
970            .tools
971            .as_ref()
972            .map(|s| s.iter().map(|t| t.function.name.clone()).collect())
973            .unwrap_or_default();
974        assert_eq!(names, vec!["echo_ro".to_string()]);
975    }
976
977    /// Build a tool_call response with a custom `prompt_tokens` value so a
978    /// test can drive the compaction trigger threshold.
979    fn assistant_tool_call_with_prompt(
980        id: &str,
981        name: &str,
982        args: serde_json::Value,
983        prompt_tokens: u32,
984    ) -> ChatResponse {
985        let mut r = assistant_tool_call(id, name, args);
986        if let Some(u) = r.usage.as_mut() {
987            u.prompt_tokens = prompt_tokens;
988            u.total_tokens = prompt_tokens.saturating_add(u.completion_tokens);
989        }
990        r
991    }
992
993    fn compaction_cfg() -> CompactionConfig {
994        CompactionConfig {
995            threshold_prompt_tokens: 100,
996            keep_recent_turns: 1,
997            compactor_model: "deepseek-chat".into(),
998            max_summary_tokens: 64,
999        }
1000    }
1001
1002    #[tokio::test]
1003    async fn compaction_triggers_when_prompt_tokens_exceed_threshold() {
1004        // Two tool-call turns each report prompt_tokens above the threshold.
1005        // After the second turn, compaction fires (assistant count > 1),
1006        // the compactor mock returns a summary, then the third main turn
1007        // closes the loop with text.
1008        let queue = vec![
1009            assistant_tool_call_with_prompt("c1", "echo_ro", json!({}), 200),
1010            assistant_tool_call_with_prompt("c2", "echo_ro", json!({}), 200),
1011            assistant_text("summary of earlier turns"),
1012            assistant_text("done"),
1013        ];
1014        let http = MockHttp::new(queue);
1015        let mock = http.clone();
1016        let msgs = collect(
1017            http,
1018            tools(vec![("echo_ro", true)]),
1019            "hi",
1020            RunOptions::default()
1021                .permission_mode(PermissionMode::BypassPermissions)
1022                .compaction(compaction_cfg()),
1023        )
1024        .await;
1025
1026        let seen = mock.seen_requests.lock().unwrap();
1027        assert_eq!(seen.len(), 4, "expected 2 main + 1 compactor + 1 main");
1028
1029        // Third request is the compactor call — different model, no tools,
1030        // no thinking, low max_tokens.
1031        let compactor_req = &seen[2];
1032        assert_eq!(compactor_req.model, "deepseek-chat");
1033        assert!(compactor_req.tools.is_none());
1034        assert!(compactor_req.thinking.is_none());
1035        assert_eq!(compactor_req.max_tokens, Some(64));
1036
1037        // Fourth request (post-compaction main turn) should carry fewer
1038        // messages than the un-compacted history would have produced.
1039        // History before compaction after turn 2 was 5 messages
1040        // (user, asst1, tool1, asst2, tool2). After compaction it should
1041        // be 4 (user, summary_system, asst2, tool2).
1042        let post_compact_req = &seen[3];
1043        assert_eq!(
1044            post_compact_req.messages.len(),
1045            4,
1046            "post-compaction history should be [user, summary, last_assistant, last_tool_result]"
1047        );
1048        assert_eq!(post_compact_req.messages[1].role, "system");
1049        assert!(post_compact_req.messages[1]
1050            .content
1051            .as_str()
1052            .contains("Compacted summary"));
1053
1054        // A System{Compact} event was yielded.
1055        assert!(
1056            msgs.iter().any(|m| matches!(
1057                m,
1058                SdkMessage::System {
1059                    subtype: SystemSubtype::Compact,
1060                    ..
1061                }
1062            )),
1063            "expected a SystemSubtype::Compact event in the stream"
1064        );
1065    }
1066
1067    #[tokio::test]
1068    async fn compaction_preserves_tool_call_pairs() {
1069        // Same shape as the trigger test; assert that every assistant
1070        // message with tool_calls in the post-compaction main request is
1071        // immediately followed by tool-role messages whose tool_call_ids
1072        // match the assistant's tool_calls — the API invariant.
1073        let queue = vec![
1074            assistant_tool_call_with_prompt("c1", "echo_ro", json!({}), 200),
1075            assistant_tool_call_with_prompt("c2", "echo_ro", json!({}), 200),
1076            assistant_text("summary"),
1077            assistant_text("done"),
1078        ];
1079        let http = MockHttp::new(queue);
1080        let mock = http.clone();
1081        let _ = collect(
1082            http,
1083            tools(vec![("echo_ro", true)]),
1084            "hi",
1085            RunOptions::default()
1086                .permission_mode(PermissionMode::BypassPermissions)
1087                .compaction(compaction_cfg()),
1088        )
1089        .await;
1090
1091        let seen = mock.seen_requests.lock().unwrap();
1092        let post_compact = &seen[3];
1093        let msgs = &post_compact.messages;
1094        for (i, m) in msgs.iter().enumerate() {
1095            if m.role == "assistant" {
1096                if let Some(calls) = &m.tool_calls {
1097                    for (offset, call) in calls.iter().enumerate() {
1098                        let follower = msgs.get(i + 1 + offset).unwrap_or_else(|| {
1099                            panic!("assistant tool_call at idx {i} has no follower")
1100                        });
1101                        assert_eq!(follower.role, "tool");
1102                        assert_eq!(
1103                            follower.tool_call_id.as_deref(),
1104                            Some(call.id.as_str()),
1105                            "tool_result id must match assistant's tool_call id"
1106                        );
1107                    }
1108                }
1109            }
1110        }
1111    }
1112
1113    #[tokio::test]
1114    async fn compaction_failure_falls_through() {
1115        // Compactor returns a transport error. The main loop must log a
1116        // warning and continue with the un-compacted history; the run
1117        // still terminates successfully.
1118        let queue: Vec<DResult<ChatResponse>> = vec![
1119            Ok(assistant_tool_call_with_prompt(
1120                "c1",
1121                "echo_ro",
1122                json!({}),
1123                200,
1124            )),
1125            Ok(assistant_tool_call_with_prompt(
1126                "c2",
1127                "echo_ro",
1128                json!({}),
1129                200,
1130            )),
1131            Err(crate::error::DeepSeekError::Api {
1132                status: 500,
1133                body: "boom".into(),
1134            }),
1135            Ok(assistant_text("done")),
1136        ];
1137        let http = MockHttp::new_with_results(queue);
1138        let mock = http.clone();
1139        let msgs = collect(
1140            http,
1141            tools(vec![("echo_ro", true)]),
1142            "hi",
1143            RunOptions::default()
1144                .permission_mode(PermissionMode::BypassPermissions)
1145                .compaction(compaction_cfg()),
1146        )
1147        .await;
1148
1149        // No System{Compact} event was emitted.
1150        assert!(
1151            !msgs.iter().any(|m| matches!(
1152                m,
1153                SdkMessage::System {
1154                    subtype: SystemSubtype::Compact,
1155                    ..
1156                }
1157            )),
1158            "compaction failure must not emit System::Compact"
1159        );
1160
1161        // Run still terminated successfully on the un-compacted history.
1162        let last = msgs.last().unwrap();
1163        assert!(matches!(
1164            last,
1165            SdkMessage::Result {
1166                subtype: ResultSubtype::Success,
1167                ..
1168            }
1169        ));
1170
1171        // The post-failure main request retained the full message history
1172        // (no rewrite happened): user + 2 asst + 2 tool = 5 messages.
1173        let seen = mock.seen_requests.lock().unwrap();
1174        let post_failure = &seen[3];
1175        assert_eq!(
1176            post_failure.messages.len(),
1177            5,
1178            "history must remain un-compacted after a compactor failure"
1179        );
1180    }
1181
1182    #[tokio::test]
1183    async fn compaction_disabled_by_default() {
1184        // Without RunOptions::compaction(...), even with high prompt_tokens
1185        // and many turns, no extra compactor request is observed.
1186        let queue = vec![
1187            assistant_tool_call_with_prompt("c1", "echo_ro", json!({}), 200),
1188            assistant_tool_call_with_prompt("c2", "echo_ro", json!({}), 200),
1189            assistant_text("done"),
1190        ];
1191        let http = MockHttp::new(queue);
1192        let mock = http.clone();
1193        let msgs = collect(
1194            http,
1195            tools(vec![("echo_ro", true)]),
1196            "hi",
1197            RunOptions::default().permission_mode(PermissionMode::BypassPermissions),
1198        )
1199        .await;
1200
1201        // Exactly 3 requests — no compactor call sneaked in.
1202        assert_eq!(mock.seen_requests.lock().unwrap().len(), 3);
1203        assert!(!msgs.iter().any(|m| matches!(
1204            m,
1205            SdkMessage::System {
1206                subtype: SystemSubtype::Compact,
1207                ..
1208            }
1209        )));
1210    }
1211}