Skip to main content

deepseek/agent/
loop_runner.rs

1//! Claude-Code-shaped streaming agent loop.
2//!
3//! `run(...)` returns an async stream of [`SdkMessage`]. The loop:
4//!
5//! 1. Yields `System{Init}` carrying the session id.
6//! 2. POSTs to the configured chat-completions endpoint.
7//! 3. On `finish_reason == "tool_calls"`: yields one `Assistant` message with
8//!    text + tool_use blocks, runs each tool through the permission gate
9//!    (read-only tools in parallel, mutating tools sequentially), yields one
10//!    `User` message containing all tool_result blocks, and continues.
11//! 4. On any other finish reason: yields the final `Assistant` text and a
12//!    `Result{Success}` carrying usage, cost, and turn count.
13//! 5. Enforces `max_turns` and `max_budget_usd`; transport errors yield
14//!    `Result{ErrorDuringExecution}`.
15
16use std::sync::Arc;
17
18use async_stream::stream;
19use futures::future::join_all;
20use futures::stream::Stream;
21use serde_json::{json, Value};
22
23use crate::client::HttpClient;
24use crate::types::{
25    tool_result_msg, ChatContent, ChatMessage, ChatRequest, FunctionSchema, ToolSchema, UsageInfo,
26};
27
28use super::messages::{ContentBlock, ResultSubtype, SdkMessage, SystemSubtype};
29use super::options::RunOptions;
30use super::permissions::{PermissionDecision, PermissionMode};
31use super::pricing::{map_stop_reason, turn_cost_usd};
32use super::tool::Tool;
33
34/// Run the agent loop and stream `SdkMessage`s in turn order.
35///
36/// `tools` is wrapped in an `Arc` so callers can reuse the same registry
37/// across multiple runs.
38pub fn run<H>(
39    http: H,
40    api_key: String,
41    tools: Arc<Vec<Box<dyn Tool>>>,
42    user_prompt: String,
43    opts: RunOptions,
44) -> impl Stream<Item = SdkMessage>
45where
46    H: HttpClient + Send + Sync + 'static,
47{
48    stream! {
49        // Resolve session and emit init.
50        let session_id = opts
51            .session_id
52            .clone()
53            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
54        yield SdkMessage::System {
55            subtype: SystemSubtype::Init,
56            session_id: session_id.clone(),
57            data: json!({
58                "model": opts.model,
59                "permission_mode": opts.permission_mode,
60                "max_turns": opts.max_turns,
61                "max_budget_usd": opts.max_budget_usd,
62            }),
63        };
64
65        // Hide disallowed/non-allowed tools from the model entirely.
66        let visible_tools: Vec<&Box<dyn Tool>> = tools
67            .iter()
68            .filter(|t| {
69                let n = t.name();
70                if opts.disallowed_tools.iter().any(|d| d == n) {
71                    return false;
72                }
73                if let Some(allow) = &opts.allowed_tools {
74                    return allow.iter().any(|a| a == n);
75                }
76                true
77            })
78            .collect();
79
80        let tool_schemas: Vec<ToolSchema> = visible_tools
81            .iter()
82            .map(|t| {
83                let def = t.definition();
84                ToolSchema {
85                    r#type: "function".into(),
86                    function: FunctionSchema {
87                        name: def.name,
88                        description: def.description,
89                        parameters: def.parameters,
90                    },
91                }
92            })
93            .collect();
94
95        // Conversation history.
96        let mut messages: Vec<ChatMessage> = Vec::new();
97        if !opts.system_prompt.is_empty() {
98            messages.push(ChatMessage {
99                role: "system".into(),
100                content: ChatContent::Text(opts.system_prompt.clone()),
101                reasoning_content: None,
102                tool_calls: None,
103                tool_call_id: None,
104                name: None,
105            });
106        }
107        messages.push(ChatMessage {
108            role: "user".into(),
109            content: ChatContent::Text(user_prompt),
110            reasoning_content: None,
111            tool_calls: None,
112            tool_call_id: None,
113            name: None,
114        });
115
116        let url = format!("{}/chat/completions", opts.base_url);
117        let mut num_turns: u32 = 0;
118        let mut total_prompt_tokens: u32 = 0;
119        let mut total_completion_tokens: u32 = 0;
120        let mut total_cost: Option<f64> = turn_cost_usd(&opts.model, 0, 0).map(|_| 0.0);
121        let mut last_stop_reason: Option<String> = None;
122
123        loop {
124            let request = ChatRequest {
125                model: opts.model.clone(),
126                messages: messages.clone(),
127                tools: if tool_schemas.is_empty() { None } else { Some(tool_schemas.clone()) },
128                tool_choice: if tool_schemas.is_empty() {
129                    None
130                } else {
131                    Some(json!("auto"))
132                },
133                temperature: Some(opts.effort.temperature()),
134                max_tokens: Some(opts.effort.max_tokens()),
135                stream: Some(false),
136                reasoning_effort: Some(match opts.effort {
137                    crate::types::EffortLevel::Max => "max".into(),
138                    crate::types::EffortLevel::High => "high".into(),
139                    crate::types::EffortLevel::Medium => "medium".into(),
140                    crate::types::EffortLevel::Low => "low".into(),
141                }),
142                thinking: Some(json!({"type": "enabled"})),
143            };
144
145            let resp = match http.post_json(&url, &api_key, &request).await {
146                Ok(r) => r,
147                Err(e) => {
148                    tracing::warn!(error = %e, "agent loop transport error");
149                    yield SdkMessage::Result {
150                        subtype: ResultSubtype::ErrorDuringExecution,
151                        result: None,
152                        total_cost_usd: total_cost,
153                        usage: usage_info(total_prompt_tokens, total_completion_tokens),
154                        num_turns,
155                        session_id,
156                        stop_reason: last_stop_reason,
157                    };
158                    return;
159                }
160            };
161
162            // Accumulate usage / cost from this turn.
163            if let Some(u) = &resp.usage {
164                total_prompt_tokens = total_prompt_tokens.saturating_add(u.prompt_tokens);
165                total_completion_tokens = total_completion_tokens.saturating_add(u.completion_tokens);
166                if let (Some(running), Some(turn)) = (
167                    total_cost.as_mut(),
168                    turn_cost_usd(&opts.model, u.prompt_tokens, u.completion_tokens),
169                ) {
170                    *running += turn;
171                }
172            }
173
174            let Some(choice) = resp.choices.into_iter().next() else {
175                yield SdkMessage::Result {
176                    subtype: ResultSubtype::ErrorDuringExecution,
177                    result: None,
178                    total_cost_usd: total_cost,
179                    usage: usage_info(total_prompt_tokens, total_completion_tokens),
180                    num_turns,
181                    session_id,
182                    stop_reason: last_stop_reason,
183                };
184                return;
185            };
186
187            let finish_reason = choice.finish_reason.as_deref().unwrap_or("stop");
188            last_stop_reason = map_stop_reason(finish_reason);
189            let assistant_msg = choice.message;
190
191            if finish_reason == "tool_calls" {
192                let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
193
194                // Build the Assistant SdkMessage (text + tool_use blocks).
195                let mut content_blocks: Vec<ContentBlock> = Vec::new();
196                let text = assistant_msg.content.as_str();
197                if !text.is_empty() {
198                    content_blocks.push(ContentBlock::Text { text: text.to_string() });
199                }
200                let parsed_calls: Vec<(String, String, Value)> = tool_calls
201                    .iter()
202                    .map(|c| {
203                        let args: Value =
204                            serde_json::from_str(&c.function.arguments).unwrap_or(json!({}));
205                        (c.id.clone(), c.function.name.clone(), args)
206                    })
207                    .collect();
208                for (id, name, input) in &parsed_calls {
209                    content_blocks.push(ContentBlock::ToolUse {
210                        id: id.clone(),
211                        name: name.clone(),
212                        input: input.clone(),
213                    });
214                }
215                yield SdkMessage::Assistant {
216                    content: content_blocks,
217                    stop_reason: last_stop_reason.clone(),
218                };
219
220                // Persist assistant turn in history (carries tool_calls).
221                messages.push(assistant_msg);
222
223                // Permission gate.
224                let mut decisions: Vec<(String, String, Value, PermissionDecision, bool)> =
225                    Vec::with_capacity(parsed_calls.len());
226                for (id, name, args) in parsed_calls {
227                    let tool_ref = visible_tools.iter().find(|t| t.name() == name);
228                    let read_only = tool_ref.map(|t| t.read_only_hint()).unwrap_or(false);
229
230                    let mode_decision = opts.permission_mode.evaluate(&name, read_only);
231                    let final_decision = match (mode_decision, &opts.pre_tool_hook) {
232                        (PermissionDecision::Allow, _) => PermissionDecision::Allow,
233                        (PermissionDecision::Deny(r), _) => PermissionDecision::Deny(r),
234                        (PermissionDecision::Ask, Some(hook)) => {
235                            match hook.check(&name, &args).await {
236                                PermissionDecision::Ask => PermissionDecision::Deny(format!(
237                                    "Tool `{name}` requires approval and the hook returned Ask"
238                                )),
239                                d => d,
240                            }
241                        }
242                        (PermissionDecision::Ask, None) => {
243                            if matches!(opts.permission_mode, PermissionMode::BypassPermissions) {
244                                PermissionDecision::Allow
245                            } else {
246                                PermissionDecision::Deny(format!(
247                                    "Tool `{name}` not pre-approved and no permission hook configured"
248                                ))
249                            }
250                        }
251                    };
252
253                    decisions.push((id, name, args, final_decision, read_only));
254                }
255
256                // Partition allowed calls into read-only (parallel) and mutating (sequential).
257                let mut tool_results: Vec<(String, Result<String, String>)> = Vec::new();
258                let mut parallel_idxs: Vec<usize> = Vec::new();
259                let mut sequential_idxs: Vec<usize> = Vec::new();
260                for (i, (_, _, _, d, ro)) in decisions.iter().enumerate() {
261                    if matches!(d, PermissionDecision::Allow) {
262                        if *ro {
263                            parallel_idxs.push(i);
264                        } else {
265                            sequential_idxs.push(i);
266                        }
267                    }
268                }
269
270                // Run parallel set.
271                if !parallel_idxs.is_empty() {
272                    let futs = parallel_idxs.iter().map(|&i| {
273                        let (id, name, args, _, _) = &decisions[i];
274                        let id = id.clone();
275                        let name = name.clone();
276                        let args = args.clone();
277                        let tools = Arc::clone(&tools);
278                        async move {
279                            let res = match tools.iter().find(|t| t.name() == name) {
280                                Some(t) => t.call_json(args).await,
281                                None => Err(format!("Unknown tool: {name}")),
282                            };
283                            (id, res)
284                        }
285                    });
286                    let outs = join_all(futs).await;
287                    for (id, res) in outs {
288                        tool_results.push((id, res));
289                    }
290                }
291
292                // Run sequential set.
293                for i in sequential_idxs {
294                    let (id, name, args, _, _) = &decisions[i];
295                    let res = match tools.iter().find(|t| t.name() == *name) {
296                        Some(t) => t.call_json(args.clone()).await,
297                        None => Err(format!("Unknown tool: {name}")),
298                    };
299                    tool_results.push((id.clone(), res));
300                }
301
302                // Append denials as synthetic error tool_results.
303                for (id, _name, _args, d, _) in &decisions {
304                    if let PermissionDecision::Deny(reason) = d {
305                        tool_results.push((id.clone(), Err(reason.clone())));
306                    }
307                }
308
309                // Re-order results to match the original tool-call order so the
310                // model sees them in the same sequence it requested.
311                let id_order: Vec<String> = decisions.iter().map(|d| d.0.clone()).collect();
312                tool_results.sort_by_key(|(id, _)| {
313                    id_order.iter().position(|x| x == id).unwrap_or(usize::MAX)
314                });
315
316                // Append tool_result messages to history; build user SdkMessage.
317                let mut user_blocks: Vec<ContentBlock> = Vec::with_capacity(tool_results.len());
318                for (call_id, res) in &tool_results {
319                    let (content_str, is_error) = match res {
320                        Ok(s) => (s.clone(), false),
321                        Err(e) => (e.clone(), true),
322                    };
323                    messages.push(tool_result_msg(call_id, &content_str));
324                    user_blocks.push(ContentBlock::ToolResult {
325                        tool_use_id: call_id.clone(),
326                        content: content_str,
327                        is_error,
328                    });
329                }
330                yield SdkMessage::User { content: user_blocks };
331
332                num_turns = num_turns.saturating_add(1);
333
334                if let Some(limit) = opts.max_turns {
335                    if num_turns >= limit {
336                        yield SdkMessage::Result {
337                            subtype: ResultSubtype::ErrorMaxTurns,
338                            result: None,
339                            total_cost_usd: total_cost,
340                            usage: usage_info(total_prompt_tokens, total_completion_tokens),
341                            num_turns,
342                            session_id,
343                            stop_reason: last_stop_reason,
344                        };
345                        return;
346                    }
347                }
348                if let (Some(budget), Some(cost)) = (opts.max_budget_usd, total_cost) {
349                    if cost >= budget {
350                        yield SdkMessage::Result {
351                            subtype: ResultSubtype::ErrorMaxBudgetUsd,
352                            result: None,
353                            total_cost_usd: total_cost,
354                            usage: usage_info(total_prompt_tokens, total_completion_tokens),
355                            num_turns,
356                            session_id,
357                            stop_reason: last_stop_reason,
358                        };
359                        return;
360                    }
361                }
362            } else {
363                // Final assistant turn — text only.
364                let text = assistant_msg.content.as_str().to_string();
365                yield SdkMessage::Assistant {
366                    content: vec![ContentBlock::Text { text: text.clone() }],
367                    stop_reason: last_stop_reason.clone(),
368                };
369                yield SdkMessage::Result {
370                    subtype: ResultSubtype::Success,
371                    result: Some(text),
372                    total_cost_usd: total_cost,
373                    usage: usage_info(total_prompt_tokens, total_completion_tokens),
374                    num_turns,
375                    session_id,
376                    stop_reason: last_stop_reason,
377                };
378                return;
379            }
380        }
381    }
382}
383
384fn usage_info(prompt: u32, completion: u32) -> Option<UsageInfo> {
385    if prompt == 0 && completion == 0 {
386        None
387    } else {
388        Some(UsageInfo {
389            prompt_tokens: prompt,
390            completion_tokens: completion,
391            total_tokens: prompt.saturating_add(completion),
392        })
393    }
394}
395
396#[cfg(test)]
397mod tests {
398    use super::*;
399
400    use std::sync::Mutex;
401
402    use async_trait::async_trait;
403    use futures::StreamExt;
404    use serde_json::json;
405
406    use crate::agent::permissions::PermissionMode;
407    use crate::agent::tool::ToolDefinition;
408    use crate::client::HttpClient;
409    use crate::error::Result as DResult;
410    use crate::types::{
411        ChatContent, ChatMessage, ChatRequest, ChatResponse, Choice, FunctionCall, ToolCall,
412        UsageInfo,
413    };
414
415    /// Returns a queued sequence of [`ChatResponse`] values, panicking if the
416    /// loop calls the API more times than expected. `seen_requests` is shared
417    /// across clones so tests can inspect it after the loop has consumed the
418    /// mock.
419    #[derive(Clone)]
420    struct MockHttp {
421        queue: Arc<Mutex<Vec<ChatResponse>>>,
422        seen_requests: Arc<Mutex<Vec<ChatRequest>>>,
423    }
424
425    impl MockHttp {
426        fn new(queue: Vec<ChatResponse>) -> Self {
427            Self {
428                queue: Arc::new(Mutex::new(queue)),
429                seen_requests: Arc::new(Mutex::new(Vec::new())),
430            }
431        }
432    }
433
434    #[async_trait]
435    impl HttpClient for MockHttp {
436        async fn post_json(
437            &self,
438            _url: &str,
439            _bearer: &str,
440            body: &ChatRequest,
441        ) -> DResult<ChatResponse> {
442            self.seen_requests.lock().unwrap().push(body.clone());
443            let mut q = self.queue.lock().unwrap();
444            assert!(!q.is_empty(), "MockHttp: queue exhausted");
445            Ok(q.remove(0))
446        }
447    }
448
449    fn assistant_text(text: &str) -> ChatResponse {
450        ChatResponse {
451            id: "test".into(),
452            choices: vec![Choice {
453                index: 0,
454                message: ChatMessage {
455                    role: "assistant".into(),
456                    content: ChatContent::Text(text.into()),
457                    reasoning_content: None,
458                    tool_calls: None,
459                    tool_call_id: None,
460                    name: None,
461                },
462                finish_reason: Some("stop".into()),
463            }],
464            usage: Some(UsageInfo {
465                prompt_tokens: 10,
466                completion_tokens: 5,
467                total_tokens: 15,
468            }),
469        }
470    }
471
472    fn assistant_tool_call(id: &str, name: &str, args: serde_json::Value) -> ChatResponse {
473        ChatResponse {
474            id: "test".into(),
475            choices: vec![Choice {
476                index: 0,
477                message: ChatMessage {
478                    role: "assistant".into(),
479                    content: ChatContent::Null,
480                    reasoning_content: None,
481                    tool_calls: Some(vec![ToolCall {
482                        id: id.into(),
483                        r#type: "function".into(),
484                        function: FunctionCall {
485                            name: name.into(),
486                            arguments: args.to_string(),
487                        },
488                    }]),
489                    tool_call_id: None,
490                    name: None,
491                },
492                finish_reason: Some("tool_calls".into()),
493            }],
494            usage: Some(UsageInfo {
495                prompt_tokens: 8,
496                completion_tokens: 4,
497                total_tokens: 12,
498            }),
499        }
500    }
501
502    /// Minimal tool used by the loop tests — just echoes its args back.
503    struct EchoTool {
504        name: &'static str,
505        read_only: bool,
506    }
507
508    #[async_trait]
509    impl Tool for EchoTool {
510        fn name(&self) -> &str {
511            self.name
512        }
513        fn read_only_hint(&self) -> bool {
514            self.read_only
515        }
516        fn definition(&self) -> ToolDefinition {
517            ToolDefinition {
518                name: self.name.to_string(),
519                description: "echo".into(),
520                parameters: json!({"type":"object"}),
521            }
522        }
523        async fn call_json(&self, args: serde_json::Value) -> std::result::Result<String, String> {
524            Ok(format!("echoed {}", args))
525        }
526    }
527
528    fn tools(items: Vec<(&'static str, bool)>) -> Arc<Vec<Box<dyn Tool>>> {
529        Arc::new(
530            items
531                .into_iter()
532                .map(|(n, ro)| {
533                    Box::new(EchoTool {
534                        name: n,
535                        read_only: ro,
536                    }) as Box<dyn Tool>
537                })
538                .collect(),
539        )
540    }
541
542    async fn collect(
543        http: MockHttp,
544        toolset: Arc<Vec<Box<dyn Tool>>>,
545        prompt: &str,
546        opts: RunOptions,
547    ) -> Vec<SdkMessage> {
548        run(http, "test-key".into(), toolset, prompt.into(), opts)
549            .collect()
550            .await
551    }
552
553    #[tokio::test]
554    async fn text_only_emits_assistant_then_success() {
555        let http = MockHttp::new(vec![assistant_text("hello world")]);
556        let msgs = collect(http, tools(vec![]), "hi", RunOptions::default()).await;
557
558        assert!(matches!(msgs[0], SdkMessage::System { .. }));
559        assert!(matches!(&msgs[1], SdkMessage::Assistant { .. }));
560        match &msgs[2] {
561            SdkMessage::Result {
562                subtype,
563                result: Some(t),
564                num_turns,
565                ..
566            } => {
567                assert_eq!(*subtype, ResultSubtype::Success);
568                assert_eq!(t, "hello world");
569                assert_eq!(*num_turns, 0);
570            }
571            other => panic!("expected Result, got {other:?}"),
572        }
573    }
574
575    #[tokio::test]
576    async fn tool_call_then_text_completes_successfully() {
577        let http = MockHttp::new(vec![
578            assistant_tool_call("c1", "echo_ro", json!({"x": 1})),
579            assistant_text("done"),
580        ]);
581        let msgs = collect(
582            http,
583            tools(vec![("echo_ro", true)]),
584            "hi",
585            RunOptions::default().permission_mode(PermissionMode::BypassPermissions),
586        )
587        .await;
588
589        // System, Assistant(tool_use), User(tool_result), Assistant(text), Result.
590        assert_eq!(msgs.len(), 5, "msgs={msgs:?}");
591        match &msgs[1] {
592            SdkMessage::Assistant { content, .. } => {
593                assert!(matches!(content[0], ContentBlock::ToolUse { .. }));
594            }
595            _ => panic!(),
596        }
597        match &msgs[2] {
598            SdkMessage::User { content } => match &content[0] {
599                ContentBlock::ToolResult {
600                    tool_use_id,
601                    is_error,
602                    ..
603                } => {
604                    assert_eq!(tool_use_id, "c1");
605                    assert!(!is_error);
606                }
607                _ => panic!(),
608            },
609            _ => panic!(),
610        }
611        match &msgs[4] {
612            SdkMessage::Result {
613                subtype, num_turns, ..
614            } => {
615                assert_eq!(*subtype, ResultSubtype::Success);
616                assert_eq!(*num_turns, 1);
617            }
618            _ => panic!(),
619        }
620    }
621
622    #[tokio::test]
623    async fn max_turns_stops_with_error_subtype() {
624        let http = MockHttp::new(vec![
625            assistant_tool_call("c1", "echo_ro", json!({})),
626            assistant_tool_call("c2", "echo_ro", json!({})),
627        ]);
628        let msgs = collect(
629            http,
630            tools(vec![("echo_ro", true)]),
631            "loop",
632            RunOptions::default()
633                .max_turns(1)
634                .permission_mode(PermissionMode::BypassPermissions),
635        )
636        .await;
637        let last = msgs.last().unwrap();
638        match last {
639            SdkMessage::Result {
640                subtype, num_turns, ..
641            } => {
642                assert_eq!(*subtype, ResultSubtype::ErrorMaxTurns);
643                assert_eq!(*num_turns, 1);
644            }
645            _ => panic!("expected Result"),
646        }
647    }
648
649    #[tokio::test]
650    async fn plan_mode_denies_mutating_tool() {
651        // Loop sees a single tool call, plan-mode denies it, then the final
652        // assistant turn says "ok".
653        let http = MockHttp::new(vec![
654            assistant_tool_call("c1", "echo_mut", json!({})),
655            assistant_text("ok"),
656        ]);
657        let msgs = collect(
658            http,
659            tools(vec![("echo_mut", false)]),
660            "do",
661            RunOptions::default().permission_mode(PermissionMode::Plan),
662        )
663        .await;
664        // Find the User(tool_result) message and assert is_error=true.
665        let denied = msgs
666            .iter()
667            .find_map(|m| match m {
668                SdkMessage::User { content } => Some(content.clone()),
669                _ => None,
670            })
671            .expect("expected a User tool_result message");
672        match &denied[0] {
673            ContentBlock::ToolResult {
674                is_error, content, ..
675            } => {
676                assert!(*is_error);
677                assert!(content.contains("Plan mode"), "msg={content}");
678            }
679            _ => panic!(),
680        }
681    }
682
683    #[tokio::test]
684    async fn legacy_builder_prompt_round_trips_text() {
685        // Validates the back-compat `AgentBuilder` → `DeepSeekAgent::prompt`
686        // surface that `crates/research` depends on.
687        use crate::agent::AgentBuilder;
688        let http = MockHttp::new(vec![assistant_text("hello back")]);
689        let agent = AgentBuilder::new(http, "test-key", "deepseek-chat")
690            .preamble("you are a test")
691            .build();
692        let out = agent.prompt("hi".into()).await.expect("prompt ok");
693        assert_eq!(out, "hello back");
694    }
695
696    #[tokio::test]
697    async fn disallowed_tool_is_hidden_from_request() {
698        let http = MockHttp::new(vec![assistant_text("nothing to do")]);
699        let mock = http.clone();
700        let _ = collect(
701            http,
702            tools(vec![("echo_ro", true), ("echo_mut", false)]),
703            "hi",
704            RunOptions::default().disallowed_tools(["echo_mut"]),
705        )
706        .await;
707        let req = &mock.seen_requests.lock().unwrap()[0];
708        let names: Vec<String> = req
709            .tools
710            .as_ref()
711            .map(|s| s.iter().map(|t| t.function.name.clone()).collect())
712            .unwrap_or_default();
713        assert_eq!(names, vec!["echo_ro".to_string()]);
714    }
715}