deepseek/agent/
builder.rs1use std::sync::Arc;
6
7use futures::StreamExt;
8
9use crate::client::HttpClient;
10use crate::error::{DeepSeekError, Result};
11
12use super::loop_runner::run;
13use super::messages::{ResultSubtype, SdkMessage};
14use super::options::RunOptions;
15use super::tool::Tool;
16
17pub struct AgentBuilder<H: HttpClient> {
19 http: H,
20 api_key: String,
21 model: String,
22 preamble: String,
23 tools: Vec<Box<dyn Tool>>,
24 base_url: String,
25 worker_id: String,
26 options: RunOptions,
27}
28
29impl<H: HttpClient> AgentBuilder<H> {
30 pub fn new(http: H, api_key: impl Into<String>, model: impl Into<String>) -> Self {
31 let model = model.into();
32 let mut options = RunOptions::new(model.clone());
33 options.base_url = "https://api.deepseek.com/v1".into();
34 Self {
35 http,
36 api_key: api_key.into(),
37 model,
38 preamble: String::new(),
39 tools: Vec::new(),
40 base_url: "https://api.deepseek.com".into(),
41 worker_id: String::new(),
42 options,
43 }
44 }
45
46 pub fn preamble(mut self, p: &str) -> Self {
47 self.preamble = p.into();
48 self
49 }
50
51 pub fn tool(mut self, t: impl Tool + 'static) -> Self {
52 self.tools.push(Box::new(t));
53 self
54 }
55
56 pub fn base_url(mut self, url: &str) -> Self {
60 let trimmed = url.trim_end_matches('/').to_string();
61 self.base_url = trimmed.clone();
62 self.options.base_url = if trimmed.ends_with("/v1") {
64 trimmed
65 } else {
66 format!("{trimmed}/v1")
67 };
68 self
69 }
70
71 pub fn worker_id(mut self, id: impl Into<String>) -> Self {
72 self.worker_id = id.into();
73 self
74 }
75
76 pub fn options(mut self, opts: RunOptions) -> Self {
80 self.options = opts;
81 self
82 }
83
84 pub fn build(self) -> DeepSeekAgent<H> {
85 let mut options = self.options;
86 options.model = self.model.clone();
87 options.system_prompt = self.preamble;
88 DeepSeekAgent {
89 http: self.http,
90 api_key: self.api_key,
91 tools: Arc::new(self.tools),
92 options,
93 worker_id: self.worker_id,
94 }
95 }
96}
97
98pub struct DeepSeekAgent<H: HttpClient> {
102 http: H,
103 api_key: String,
104 tools: Arc<Vec<Box<dyn Tool>>>,
105 options: RunOptions,
106 #[allow(dead_code)]
107 worker_id: String,
108}
109
110impl<H: HttpClient + Clone + Send + Sync + 'static> DeepSeekAgent<H> {
111 pub fn run(&self, user_prompt: String) -> impl futures::Stream<Item = SdkMessage> + use<H> {
113 run(
114 self.http.clone(),
115 self.api_key.clone(),
116 Arc::clone(&self.tools),
117 user_prompt,
118 self.options.clone_for_run(),
119 )
120 }
121
122 pub async fn prompt(&self, user_prompt: String) -> Result<String> {
125 let mut stream = Box::pin(self.run(user_prompt));
126 let mut last_text: Option<String> = None;
127 while let Some(msg) = stream.next().await {
128 match msg {
129 SdkMessage::Result {
130 subtype: ResultSubtype::Success,
131 result: Some(text),
132 ..
133 } => return Ok(text),
134 SdkMessage::Result { subtype, .. } => {
135 return Err(DeepSeekError::Other(format!(
136 "agent stopped with subtype {subtype:?}"
137 )));
138 }
139 SdkMessage::Assistant { content, .. } => {
140 if let Some(t) = content.iter().find_map(|b| match b {
141 super::messages::ContentBlock::Text { text } => Some(text.clone()),
142 _ => None,
143 }) {
144 last_text = Some(t);
145 }
146 }
147 _ => {}
148 }
149 }
150 last_text.ok_or_else(|| {
151 DeepSeekError::Other("agent stream ended without a Result message".into())
152 })
153 }
154}
155
156impl RunOptions {
157 fn clone_for_run(&self) -> RunOptions {
160 RunOptions {
161 model: self.model.clone(),
162 system_prompt: self.system_prompt.clone(),
163 allowed_tools: self.allowed_tools.clone(),
164 disallowed_tools: self.disallowed_tools.clone(),
165 max_turns: self.max_turns,
166 max_budget_usd: self.max_budget_usd,
167 effort: self.effort.clone(),
168 permission_mode: self.permission_mode,
169 pre_tool_hook: self.pre_tool_hook.clone(),
170 session_id: self.session_id.clone(),
171 base_url: self.base_url.clone(),
172 compaction: self.compaction.clone(),
173 }
174 }
175}