1use std::sync::Arc;
17
18use async_stream::stream;
19use futures::future::join_all;
20use futures::stream::Stream;
21use serde_json::{json, Value};
22
23use crate::client::HttpClient;
24use crate::types::{
25 tool_result_msg, ChatContent, ChatMessage, ChatRequest, FunctionSchema, ToolSchema, UsageInfo,
26};
27
28use super::messages::{ContentBlock, ResultSubtype, SdkMessage, SystemSubtype};
29use super::options::RunOptions;
30use super::permissions::{PermissionDecision, PermissionMode};
31use super::pricing::{map_stop_reason, turn_cost_usd};
32use super::tool::Tool;
33
34pub fn run<H>(
39 http: H,
40 api_key: String,
41 tools: Arc<Vec<Box<dyn Tool>>>,
42 user_prompt: String,
43 opts: RunOptions,
44) -> impl Stream<Item = SdkMessage>
45where
46 H: HttpClient + Send + Sync + 'static,
47{
48 stream! {
49 let session_id = opts
51 .session_id
52 .clone()
53 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
54 yield SdkMessage::System {
55 subtype: SystemSubtype::Init,
56 session_id: session_id.clone(),
57 data: json!({
58 "model": opts.model,
59 "permission_mode": opts.permission_mode,
60 "max_turns": opts.max_turns,
61 "max_budget_usd": opts.max_budget_usd,
62 }),
63 };
64
65 let visible_tools: Vec<&Box<dyn Tool>> = tools
67 .iter()
68 .filter(|t| {
69 let n = t.name();
70 if opts.disallowed_tools.iter().any(|d| d == n) {
71 return false;
72 }
73 if let Some(allow) = &opts.allowed_tools {
74 return allow.iter().any(|a| a == n);
75 }
76 true
77 })
78 .collect();
79
80 let tool_schemas: Vec<ToolSchema> = visible_tools
81 .iter()
82 .map(|t| {
83 let def = t.definition();
84 ToolSchema {
85 r#type: "function".into(),
86 function: FunctionSchema {
87 name: def.name,
88 description: def.description,
89 parameters: def.parameters,
90 },
91 }
92 })
93 .collect();
94
95 let mut messages: Vec<ChatMessage> = Vec::new();
97 if !opts.system_prompt.is_empty() {
98 messages.push(ChatMessage {
99 role: "system".into(),
100 content: ChatContent::Text(opts.system_prompt.clone()),
101 reasoning_content: None,
102 tool_calls: None,
103 tool_call_id: None,
104 name: None,
105 });
106 }
107 messages.push(ChatMessage {
108 role: "user".into(),
109 content: ChatContent::Text(user_prompt),
110 reasoning_content: None,
111 tool_calls: None,
112 tool_call_id: None,
113 name: None,
114 });
115
116 let url = format!("{}/chat/completions", opts.base_url);
117 let mut num_turns: u32 = 0;
118 let mut total_prompt_tokens: u32 = 0;
119 let mut total_completion_tokens: u32 = 0;
120 let mut total_cost: Option<f64> = turn_cost_usd(&opts.model, 0, 0).map(|_| 0.0);
121 let mut last_stop_reason: Option<String> = None;
122
123 loop {
124 let request = ChatRequest {
125 model: opts.model.clone(),
126 messages: messages.clone(),
127 tools: if tool_schemas.is_empty() { None } else { Some(tool_schemas.clone()) },
128 tool_choice: if tool_schemas.is_empty() {
129 None
130 } else {
131 Some(json!("auto"))
132 },
133 temperature: Some(opts.effort.temperature()),
134 max_tokens: Some(opts.effort.max_tokens()),
135 stream: Some(false),
136 reasoning_effort: Some(match opts.effort {
137 crate::types::EffortLevel::Max => "max".into(),
138 crate::types::EffortLevel::High => "high".into(),
139 crate::types::EffortLevel::Medium => "medium".into(),
140 crate::types::EffortLevel::Low => "low".into(),
141 }),
142 thinking: Some(json!({"type": "enabled"})),
143 };
144
145 let resp = match http.post_json(&url, &api_key, &request).await {
146 Ok(r) => r,
147 Err(e) => {
148 tracing::warn!(error = %e, "agent loop transport error");
149 yield SdkMessage::Result {
150 subtype: ResultSubtype::ErrorDuringExecution,
151 result: None,
152 total_cost_usd: total_cost,
153 usage: usage_info(total_prompt_tokens, total_completion_tokens),
154 num_turns,
155 session_id,
156 stop_reason: last_stop_reason,
157 };
158 return;
159 }
160 };
161
162 if let Some(u) = &resp.usage {
164 total_prompt_tokens = total_prompt_tokens.saturating_add(u.prompt_tokens);
165 total_completion_tokens = total_completion_tokens.saturating_add(u.completion_tokens);
166 if let (Some(running), Some(turn)) = (
167 total_cost.as_mut(),
168 turn_cost_usd(&opts.model, u.prompt_tokens, u.completion_tokens),
169 ) {
170 *running += turn;
171 }
172 }
173
174 let Some(choice) = resp.choices.into_iter().next() else {
175 yield SdkMessage::Result {
176 subtype: ResultSubtype::ErrorDuringExecution,
177 result: None,
178 total_cost_usd: total_cost,
179 usage: usage_info(total_prompt_tokens, total_completion_tokens),
180 num_turns,
181 session_id,
182 stop_reason: last_stop_reason,
183 };
184 return;
185 };
186
187 let finish_reason = choice.finish_reason.as_deref().unwrap_or("stop");
188 last_stop_reason = map_stop_reason(finish_reason);
189 let assistant_msg = choice.message;
190
191 if finish_reason == "tool_calls" {
192 let tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
193
194 let mut content_blocks: Vec<ContentBlock> = Vec::new();
196 let text = assistant_msg.content.as_str();
197 if !text.is_empty() {
198 content_blocks.push(ContentBlock::Text { text: text.to_string() });
199 }
200 let parsed_calls: Vec<(String, String, Value)> = tool_calls
201 .iter()
202 .map(|c| {
203 let args: Value =
204 serde_json::from_str(&c.function.arguments).unwrap_or(json!({}));
205 (c.id.clone(), c.function.name.clone(), args)
206 })
207 .collect();
208 for (id, name, input) in &parsed_calls {
209 content_blocks.push(ContentBlock::ToolUse {
210 id: id.clone(),
211 name: name.clone(),
212 input: input.clone(),
213 });
214 }
215 yield SdkMessage::Assistant {
216 content: content_blocks,
217 stop_reason: last_stop_reason.clone(),
218 };
219
220 messages.push(assistant_msg);
222
223 let mut decisions: Vec<(String, String, Value, PermissionDecision, bool)> =
225 Vec::with_capacity(parsed_calls.len());
226 for (id, name, args) in parsed_calls {
227 let tool_ref = visible_tools.iter().find(|t| t.name() == name);
228 let read_only = tool_ref.map(|t| t.read_only_hint()).unwrap_or(false);
229
230 let mode_decision = opts.permission_mode.evaluate(&name, read_only);
231 let final_decision = match (mode_decision, &opts.pre_tool_hook) {
232 (PermissionDecision::Allow, _) => PermissionDecision::Allow,
233 (PermissionDecision::Deny(r), _) => PermissionDecision::Deny(r),
234 (PermissionDecision::Ask, Some(hook)) => {
235 match hook.check(&name, &args).await {
236 PermissionDecision::Ask => PermissionDecision::Deny(format!(
237 "Tool `{name}` requires approval and the hook returned Ask"
238 )),
239 d => d,
240 }
241 }
242 (PermissionDecision::Ask, None) => {
243 if matches!(opts.permission_mode, PermissionMode::BypassPermissions) {
244 PermissionDecision::Allow
245 } else {
246 PermissionDecision::Deny(format!(
247 "Tool `{name}` not pre-approved and no permission hook configured"
248 ))
249 }
250 }
251 };
252
253 decisions.push((id, name, args, final_decision, read_only));
254 }
255
256 let mut tool_results: Vec<(String, Result<String, String>)> = Vec::new();
258 let mut parallel_idxs: Vec<usize> = Vec::new();
259 let mut sequential_idxs: Vec<usize> = Vec::new();
260 for (i, (_, _, _, d, ro)) in decisions.iter().enumerate() {
261 if matches!(d, PermissionDecision::Allow) {
262 if *ro {
263 parallel_idxs.push(i);
264 } else {
265 sequential_idxs.push(i);
266 }
267 }
268 }
269
270 if !parallel_idxs.is_empty() {
272 let futs = parallel_idxs.iter().map(|&i| {
273 let (id, name, args, _, _) = &decisions[i];
274 let id = id.clone();
275 let name = name.clone();
276 let args = args.clone();
277 let tools = Arc::clone(&tools);
278 async move {
279 let res = match tools.iter().find(|t| t.name() == name) {
280 Some(t) => t.call_json(args).await,
281 None => Err(format!("Unknown tool: {name}")),
282 };
283 (id, res)
284 }
285 });
286 let outs = join_all(futs).await;
287 for (id, res) in outs {
288 tool_results.push((id, res));
289 }
290 }
291
292 for i in sequential_idxs {
294 let (id, name, args, _, _) = &decisions[i];
295 let res = match tools.iter().find(|t| t.name() == *name) {
296 Some(t) => t.call_json(args.clone()).await,
297 None => Err(format!("Unknown tool: {name}")),
298 };
299 tool_results.push((id.clone(), res));
300 }
301
302 for (id, _name, _args, d, _) in &decisions {
304 if let PermissionDecision::Deny(reason) = d {
305 tool_results.push((id.clone(), Err(reason.clone())));
306 }
307 }
308
309 let id_order: Vec<String> = decisions.iter().map(|d| d.0.clone()).collect();
312 tool_results.sort_by_key(|(id, _)| {
313 id_order.iter().position(|x| x == id).unwrap_or(usize::MAX)
314 });
315
316 let mut user_blocks: Vec<ContentBlock> = Vec::with_capacity(tool_results.len());
318 for (call_id, res) in &tool_results {
319 let (content_str, is_error) = match res {
320 Ok(s) => (s.clone(), false),
321 Err(e) => (e.clone(), true),
322 };
323 messages.push(tool_result_msg(call_id, &content_str));
324 user_blocks.push(ContentBlock::ToolResult {
325 tool_use_id: call_id.clone(),
326 content: content_str,
327 is_error,
328 });
329 }
330 yield SdkMessage::User { content: user_blocks };
331
332 num_turns = num_turns.saturating_add(1);
333
334 if let Some(limit) = opts.max_turns {
335 if num_turns >= limit {
336 yield SdkMessage::Result {
337 subtype: ResultSubtype::ErrorMaxTurns,
338 result: None,
339 total_cost_usd: total_cost,
340 usage: usage_info(total_prompt_tokens, total_completion_tokens),
341 num_turns,
342 session_id,
343 stop_reason: last_stop_reason,
344 };
345 return;
346 }
347 }
348 if let (Some(budget), Some(cost)) = (opts.max_budget_usd, total_cost) {
349 if cost >= budget {
350 yield SdkMessage::Result {
351 subtype: ResultSubtype::ErrorMaxBudgetUsd,
352 result: None,
353 total_cost_usd: total_cost,
354 usage: usage_info(total_prompt_tokens, total_completion_tokens),
355 num_turns,
356 session_id,
357 stop_reason: last_stop_reason,
358 };
359 return;
360 }
361 }
362 } else {
363 let text = assistant_msg.content.as_str().to_string();
365 yield SdkMessage::Assistant {
366 content: vec![ContentBlock::Text { text: text.clone() }],
367 stop_reason: last_stop_reason.clone(),
368 };
369 yield SdkMessage::Result {
370 subtype: ResultSubtype::Success,
371 result: Some(text),
372 total_cost_usd: total_cost,
373 usage: usage_info(total_prompt_tokens, total_completion_tokens),
374 num_turns,
375 session_id,
376 stop_reason: last_stop_reason,
377 };
378 return;
379 }
380 }
381 }
382}
383
384fn usage_info(prompt: u32, completion: u32) -> Option<UsageInfo> {
385 if prompt == 0 && completion == 0 {
386 None
387 } else {
388 Some(UsageInfo {
389 prompt_tokens: prompt,
390 completion_tokens: completion,
391 total_tokens: prompt.saturating_add(completion),
392 })
393 }
394}
395
#[cfg(test)]
mod tests {
    //! Tests for the agent loop, driven by a scripted in-memory HTTP client.

    use super::*;

    use std::collections::VecDeque;
    use std::sync::Mutex;

    use async_trait::async_trait;
    use futures::StreamExt;
    use serde_json::json;

    use crate::agent::permissions::PermissionMode;
    use crate::agent::tool::ToolDefinition;
    use crate::client::HttpClient;
    use crate::error::Result as DResult;
    use crate::types::{
        ChatContent, ChatMessage, ChatRequest, ChatResponse, Choice, FunctionCall, ToolCall,
        UsageInfo,
    };

    /// Scripted HTTP client: hands out queued responses in FIFO order and
    /// records every request body it receives. Clones share the same state,
    /// so a test can keep one clone to inspect what the loop sent.
    #[derive(Clone)]
    struct MockHttp {
        queue: Arc<Mutex<VecDeque<ChatResponse>>>,
        seen_requests: Arc<Mutex<Vec<ChatRequest>>>,
    }

    impl MockHttp {
        /// Builds a mock that will answer with `queue`, first element first.
        fn new(queue: Vec<ChatResponse>) -> Self {
            Self {
                queue: Arc::new(Mutex::new(VecDeque::from(queue))),
                seen_requests: Arc::new(Mutex::new(Vec::new())),
            }
        }
    }

    #[async_trait]
    impl HttpClient for MockHttp {
        async fn post_json(
            &self,
            _url: &str,
            _bearer: &str,
            body: &ChatRequest,
        ) -> DResult<ChatResponse> {
            self.seen_requests.lock().unwrap().push(body.clone());
            // Pop first and release the lock before failing, so an exhausted
            // queue does not also poison the mutex.
            let next = self.queue.lock().unwrap().pop_front();
            Ok(next.expect("MockHttp: queue exhausted"))
        }
    }

    /// A response whose single choice is plain assistant text finishing with `"stop"`.
    fn assistant_text(text: &str) -> ChatResponse {
        ChatResponse {
            id: "test".into(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: "assistant".into(),
                    content: ChatContent::Text(text.into()),
                    reasoning_content: None,
                    tool_calls: None,
                    tool_call_id: None,
                    name: None,
                },
                finish_reason: Some("stop".into()),
            }],
            usage: Some(UsageInfo {
                prompt_tokens: 10,
                completion_tokens: 5,
                total_tokens: 15,
            }),
        }
    }

    /// A response whose single choice requests one tool call and finishes
    /// with `"tool_calls"`.
    fn assistant_tool_call(id: &str, name: &str, args: serde_json::Value) -> ChatResponse {
        ChatResponse {
            id: "test".into(),
            choices: vec![Choice {
                index: 0,
                message: ChatMessage {
                    role: "assistant".into(),
                    content: ChatContent::Null,
                    reasoning_content: None,
                    tool_calls: Some(vec![ToolCall {
                        id: id.into(),
                        r#type: "function".into(),
                        function: FunctionCall {
                            name: name.into(),
                            arguments: args.to_string(),
                        },
                    }]),
                    tool_call_id: None,
                    name: None,
                },
                finish_reason: Some("tool_calls".into()),
            }],
            usage: Some(UsageInfo {
                prompt_tokens: 8,
                completion_tokens: 4,
                total_tokens: 12,
            }),
        }
    }

    /// Minimal tool that echoes its arguments back as text.
    struct EchoTool {
        name: &'static str,
        read_only: bool,
    }

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            self.name
        }
        fn read_only_hint(&self) -> bool {
            self.read_only
        }
        fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: self.name.to_string(),
                description: "echo".into(),
                parameters: json!({"type":"object"}),
            }
        }
        async fn call_json(&self, args: serde_json::Value) -> std::result::Result<String, String> {
            Ok(format!("echoed {args}"))
        }
    }

    /// Builds a tool registry from `(name, read_only)` pairs.
    fn tools(items: Vec<(&'static str, bool)>) -> Arc<Vec<Box<dyn Tool>>> {
        Arc::new(
            items
                .into_iter()
                .map(|(n, ro)| {
                    Box::new(EchoTool {
                        name: n,
                        read_only: ro,
                    }) as Box<dyn Tool>
                })
                .collect(),
        )
    }

    /// Runs the agent loop to completion and gathers every emitted message.
    async fn collect(
        http: MockHttp,
        toolset: Arc<Vec<Box<dyn Tool>>>,
        prompt: &str,
        opts: RunOptions,
    ) -> Vec<SdkMessage> {
        run(http, "test-key".into(), toolset, prompt.into(), opts)
            .collect()
            .await
    }

    #[tokio::test]
    async fn text_only_emits_assistant_then_success() {
        let http = MockHttp::new(vec![assistant_text("hello world")]);
        let msgs = collect(http, tools(vec![]), "hi", RunOptions::default()).await;

        // System init, assistant text, final result.
        assert_eq!(msgs.len(), 3, "msgs={msgs:?}");
        assert!(matches!(msgs[0], SdkMessage::System { .. }));
        assert!(matches!(&msgs[1], SdkMessage::Assistant { .. }));
        match &msgs[2] {
            SdkMessage::Result {
                subtype,
                result: Some(t),
                num_turns,
                ..
            } => {
                assert_eq!(*subtype, ResultSubtype::Success);
                assert_eq!(t, "hello world");
                assert_eq!(*num_turns, 0);
            }
            other => panic!("expected Result, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn tool_call_then_text_completes_successfully() {
        let http = MockHttp::new(vec![
            assistant_tool_call("c1", "echo_ro", json!({"x": 1})),
            assistant_text("done"),
        ]);
        let msgs = collect(
            http,
            tools(vec![("echo_ro", true)]),
            "hi",
            RunOptions::default().permission_mode(PermissionMode::BypassPermissions),
        )
        .await;

        // System, assistant tool use, user tool result, assistant text, result.
        assert_eq!(msgs.len(), 5, "msgs={msgs:?}");
        match &msgs[1] {
            SdkMessage::Assistant { content, .. } => {
                assert!(matches!(content[0], ContentBlock::ToolUse { .. }));
            }
            other => panic!("expected Assistant with ToolUse, got {other:?}"),
        }
        match &msgs[2] {
            SdkMessage::User { content } => match &content[0] {
                ContentBlock::ToolResult {
                    tool_use_id,
                    is_error,
                    ..
                } => {
                    assert_eq!(tool_use_id, "c1");
                    assert!(!is_error);
                }
                other => panic!("expected ToolResult block, got {other:?}"),
            },
            other => panic!("expected User, got {other:?}"),
        }
        assert!(
            matches!(&msgs[3], SdkMessage::Assistant { .. }),
            "msgs={msgs:?}"
        );
        match &msgs[4] {
            SdkMessage::Result {
                subtype, num_turns, ..
            } => {
                assert_eq!(*subtype, ResultSubtype::Success);
                assert_eq!(*num_turns, 1);
            }
            other => panic!("expected Result, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn max_turns_stops_with_error_subtype() {
        // The second response must never be requested: the limit trips first.
        let http = MockHttp::new(vec![
            assistant_tool_call("c1", "echo_ro", json!({})),
            assistant_tool_call("c2", "echo_ro", json!({})),
        ]);
        let msgs = collect(
            http,
            tools(vec![("echo_ro", true)]),
            "loop",
            RunOptions::default()
                .max_turns(1)
                .permission_mode(PermissionMode::BypassPermissions),
        )
        .await;
        let last = msgs.last().expect("stream emitted no messages");
        match last {
            SdkMessage::Result {
                subtype, num_turns, ..
            } => {
                assert_eq!(*subtype, ResultSubtype::ErrorMaxTurns);
                assert_eq!(*num_turns, 1);
            }
            other => panic!("expected Result, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn plan_mode_denies_mutating_tool() {
        let http = MockHttp::new(vec![
            assistant_tool_call("c1", "echo_mut", json!({})),
            assistant_text("ok"),
        ]);
        let msgs = collect(
            http,
            tools(vec![("echo_mut", false)]),
            "do",
            RunOptions::default().permission_mode(PermissionMode::Plan),
        )
        .await;
        // The first User message carries the tool results of the denied call.
        let denied = msgs
            .iter()
            .find_map(|m| match m {
                SdkMessage::User { content } => Some(content.clone()),
                _ => None,
            })
            .expect("expected a User tool_result message");
        match &denied[0] {
            ContentBlock::ToolResult {
                is_error, content, ..
            } => {
                assert!(*is_error);
                assert!(content.contains("Plan mode"), "msg={content}");
            }
            other => panic!("expected ToolResult block, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn legacy_builder_prompt_round_trips_text() {
        use crate::agent::AgentBuilder;
        let http = MockHttp::new(vec![assistant_text("hello back")]);
        let agent = AgentBuilder::new(http, "test-key", "deepseek-chat")
            .preamble("you are a test")
            .build();
        let out = agent.prompt("hi".into()).await.expect("prompt ok");
        assert_eq!(out, "hello back");
    }

    #[tokio::test]
    async fn disallowed_tool_is_hidden_from_request() {
        let http = MockHttp::new(vec![assistant_text("nothing to do")]);
        // Keep a clone: it shares the recorded requests with the one moved
        // into the agent loop.
        let mock = http.clone();
        let _ = collect(
            http,
            tools(vec![("echo_ro", true), ("echo_mut", false)]),
            "hi",
            RunOptions::default().disallowed_tools(["echo_mut"]),
        )
        .await;
        let seen = mock.seen_requests.lock().unwrap();
        let req = &seen[0];
        let names: Vec<String> = req
            .tools
            .as_ref()
            .map(|s| s.iter().map(|t| t.function.name.clone()).collect())
            .unwrap_or_default();
        assert_eq!(names, vec!["echo_ro".to_string()]);
    }
}
712