1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use futures::stream::BoxStream;
5use futures::{Stream, StreamExt};
6use serde_json::Value;
7
8use crate::agent::agent_core::{AgentResponse, DeepseekAgent, ToolCallEvent};
9use crate::conversation::Conversation;
10use crate::error::ApiError;
11use crate::raw::request::message::{FunctionCall, Message, Role, ToolCall, ToolType};
12use crate::raw::ChatCompletionChunk;
13
/// Outcome of a single non-streaming completion request.
struct FetchResult {
    /// Text content of the assistant reply, if the reply had any.
    content: Option<String>,
    /// Tool calls requested by the assistant; empty when the reply requested none.
    raw_tool_calls: Vec<ToolCall>,
}
20
/// Events produced by executing one batch of requested tool calls.
struct ToolsResult {
    /// One event per executed call, in the order the calls were executed.
    events: Vec<ToolCallEvent>,
}
24
/// A tool call being assembled from streamed fragments that share one delta `index`.
///
/// `Default` yields an empty buffer (all fields empty), the natural starting point
/// before any fragment has been merged in; `Debug` aids diagnosing malformed streams.
#[derive(Debug, Default)]
struct PartialToolCall {
    /// Call id (empty if no fragment supplied one).
    id: String,
    /// Function name (empty if no fragment supplied one).
    name: String,
    /// Concatenated argument text received so far.
    arguments: String,
}
32
/// Everything owned while a streaming completion is being consumed.
struct StreamingData {
    /// Chunk stream of the in-flight streaming request.
    stream: BoxStream<'static, Result<ChatCompletionChunk, ApiError>>,
    /// The agent; while streaming it lives here rather than in `AgentStream::agent`.
    agent: DeepseekAgent,
    /// All non-empty content fragments received so far for this reply, concatenated.
    content_buf: String,
    /// Tool calls being assembled, indexed by the delta's tool-call `index`;
    /// `None` marks an index for which no fragment has arrived yet.
    tool_call_bufs: Vec<Option<PartialToolCall>>,
}
39
40type FetchFuture =
43 Pin<Box<dyn std::future::Future<Output = (Result<FetchResult, ApiError>, DeepseekAgent)> + Send>>;
44
45type ConnectFuture = Pin<
46 Box<
47 dyn std::future::Future<
48 Output = (
49 Result<BoxStream<'static, Result<ChatCompletionChunk, ApiError>>, ApiError>,
50 DeepseekAgent,
51 ),
52 > + Send,
53 >,
54>;
55
56type ExecFuture =
57 Pin<Box<dyn std::future::Future<Output = (ToolsResult, DeepseekAgent)> + Send>>;
58
/// A [`Stream`] that drives a [`DeepseekAgent`] through repeated
/// request → tool execution → follow-up request cycles, yielding
/// [`AgentResponse`] items along the way.
pub struct AgentStream {
    /// The agent whenever no future or stream currently owns it; `None` while the
    /// agent has been moved into the active state.
    agent: Option<DeepseekAgent>,
    /// Current step of the state machine.
    state: AgentStreamState,
}
65
/// States of the [`AgentStream`] state machine.
enum AgentStreamState {
    /// Ready to send the next request; the agent is stored in `AgentStream::agent`.
    Idle,
    /// Waiting for a non-streaming completion.
    FetchingResponse(FetchFuture),
    /// Opening a streaming completion.
    ConnectingStream(ConnectFuture),
    /// Consuming the chunks of a streaming completion.
    StreamingChunks(Box<StreamingData>),
    /// Running the tool calls requested by the last reply.
    ExecutingTools(ExecFuture),
    /// Terminal state: every further poll yields `None`.
    Done,
}
74
75impl AgentStream {
76 pub fn new(agent: DeepseekAgent) -> Self {
77 Self {
78 agent: Some(agent),
79 state: AgentStreamState::Idle,
80 }
81 }
82
83 pub fn into_agent(self) -> Option<DeepseekAgent> {
84 match self.state {
85 AgentStreamState::StreamingChunks(data) => Some(data.agent),
86 _ => self.agent,
87 }
88 }
89}
90
91impl Stream for AgentStream {
92 type Item = Result<AgentResponse, ApiError>;
93
94 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95 let this = self.get_mut();
96
97 loop {
98 if matches!(this.state, AgentStreamState::StreamingChunks(_)) {
101 let mut data =
102 match std::mem::replace(&mut this.state, AgentStreamState::Done) {
103 AgentStreamState::StreamingChunks(d) => d,
104 _ => unreachable!(),
105 };
106
107 match data.stream.poll_next_unpin(cx) {
108 Poll::Pending => {
109 this.state = AgentStreamState::StreamingChunks(data);
110 return Poll::Pending;
111 }
112
113 Poll::Ready(Some(Ok(chunk))) => {
114 let mut fragment: Option<String> = None;
115
116 if let Some(choice) = chunk.choices.into_iter().next() {
117 let delta = choice.delta;
118
119 if let Some(dtcs) = delta.tool_calls {
120 for dtc in dtcs {
121 let idx = dtc.index as usize;
122 if data.tool_call_bufs.len() <= idx {
123 data.tool_call_bufs.resize_with(idx + 1, || None);
124 }
125 let entry = &mut data.tool_call_bufs[idx];
126 if entry.is_none() {
127 *entry = Some(PartialToolCall {
128 id: dtc.id.clone().unwrap_or_default(),
129 name: dtc
130 .function
131 .as_ref()
132 .and_then(|f| f.name.clone())
133 .unwrap_or_default(),
134 arguments: String::new(),
135 });
136 }
137 if let Some(partial) = entry.as_mut() {
138 if let Some(id) = dtc.id {
139 if partial.id.is_empty() {
140 partial.id = id;
141 }
142 }
143 if let Some(func) = dtc.function {
144 if let Some(args) = func.arguments {
145 partial.arguments.push_str(&args);
146 }
147 }
148 }
149 }
150 }
151
152 if let Some(content) = delta.content {
153 if !content.is_empty() {
154 data.content_buf.push_str(&content);
155 fragment = Some(content);
156 }
157 }
158 }
159
160 this.state = AgentStreamState::StreamingChunks(data);
161
162 if let Some(content) = fragment {
163 return Poll::Ready(Some(Ok(AgentResponse {
164 content: Some(content),
165 tool_calls: vec![],
166 })));
167 }
168 continue;
169 }
170
171 Poll::Ready(Some(Err(e))) => {
172 this.agent = Some(data.agent);
174 return Poll::Ready(Some(Err(e)));
175 }
176
177 Poll::Ready(None) => {
178 let raw_tool_calls: Vec<ToolCall> = data
179 .tool_call_bufs
180 .into_iter()
181 .flatten()
182 .map(|p| ToolCall {
183 id: p.id,
184 r#type: ToolType::Function,
185 function: FunctionCall {
186 name: p.name,
187 arguments: p.arguments,
188 },
189 })
190 .collect();
191
192 let assistant_msg = Message {
193 role: Role::Assistant,
194 content: if data.content_buf.is_empty() {
195 None
196 } else {
197 Some(data.content_buf)
198 },
199 tool_calls: if raw_tool_calls.is_empty() {
200 None
201 } else {
202 Some(raw_tool_calls.clone())
203 },
204 ..Default::default()
205 };
206 data.agent.conversation.history_mut().push(assistant_msg);
207
208 if raw_tool_calls.is_empty() {
209 this.agent = Some(data.agent);
210 return Poll::Ready(None);
211 }
212
213 let preview_events = build_preview(&raw_tool_calls);
214 let fut = Box::pin(execute_tools(data.agent, raw_tool_calls));
215 this.state = AgentStreamState::ExecutingTools(fut);
216 return Poll::Ready(Some(Ok(AgentResponse {
217 content: None,
218 tool_calls: preview_events,
219 })));
220 }
221 }
222 }
223
224 match &mut this.state {
225 AgentStreamState::Done => return Poll::Ready(None),
226
227 AgentStreamState::Idle => {
228 let agent = this.agent.take().expect("agent missing");
229 if agent.streaming {
230 let fut = Box::pin(connect_stream(agent));
231 this.state = AgentStreamState::ConnectingStream(fut);
232 } else {
233 let fut = Box::pin(fetch_response(agent));
234 this.state = AgentStreamState::FetchingResponse(fut);
235 }
236 }
237
238 AgentStreamState::FetchingResponse(fut) => {
239 match fut.as_mut().poll(cx) {
240 Poll::Pending => return Poll::Pending,
241 Poll::Ready((Err(e), agent)) => {
242 this.agent = Some(agent);
243 this.state = AgentStreamState::Done;
244 return Poll::Ready(Some(Err(e)));
245 }
246 Poll::Ready((Ok(fetch), agent)) => {
247 if fetch.raw_tool_calls.is_empty() {
248 this.agent = Some(agent);
249 this.state = AgentStreamState::Done;
250 return Poll::Ready(Some(Ok(AgentResponse {
251 content: fetch.content,
252 tool_calls: vec![],
253 })));
254 }
255
256 let content = fetch.content.clone();
257 let raw_calls = fetch.raw_tool_calls;
258 let preview_events = build_preview(&raw_calls);
259 let fut = Box::pin(execute_tools(agent, raw_calls));
260 this.state = AgentStreamState::ExecutingTools(fut);
261 return Poll::Ready(Some(Ok(AgentResponse {
262 content,
263 tool_calls: preview_events,
264 })));
265 }
266 }
267 }
268
269 AgentStreamState::ConnectingStream(fut) => {
270 match fut.as_mut().poll(cx) {
271 Poll::Pending => return Poll::Pending,
272 Poll::Ready((Err(e), agent)) => {
273 this.agent = Some(agent);
274 this.state = AgentStreamState::Done;
275 return Poll::Ready(Some(Err(e)));
276 }
277 Poll::Ready((Ok(stream), agent)) => {
278 this.state =
279 AgentStreamState::StreamingChunks(Box::new(StreamingData {
280 stream,
281 agent,
282 content_buf: String::new(),
283 tool_call_bufs: Vec::new(),
284 }));
285 }
286 }
287 }
288
289 AgentStreamState::ExecutingTools(fut) => {
290 match fut.as_mut().poll(cx) {
291 Poll::Pending => return Poll::Pending,
292 Poll::Ready((results, agent)) => {
293 this.agent = Some(agent);
294 this.state = AgentStreamState::Idle;
295 return Poll::Ready(Some(Ok(AgentResponse {
296 content: None,
297 tool_calls: results.events,
298 })));
299 }
300 }
301 }
302
303 AgentStreamState::StreamingChunks(_) => unreachable!(),
304 }
305 }
306 }
307}
308
309fn build_preview(raw_calls: &[ToolCall]) -> Vec<ToolCallEvent> {
312 raw_calls
313 .iter()
314 .map(|tc| ToolCallEvent {
315 id: tc.id.clone(),
316 name: tc.function.name.clone(),
317 args: serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null),
318 result: Value::Null,
319 })
320 .collect()
321}
322
323fn build_request(agent: &DeepseekAgent) -> crate::api::ApiRequest {
324 let history = agent.conversation.history().clone();
325 let mut req = crate::api::ApiRequest::builder().messages(history);
326 for tool in &agent.tools {
327 for raw in tool.raw_tools() {
328 req = req.add_tool(raw);
329 }
330 }
331 if !agent.tools.is_empty() {
332 req = req.tool_choice_auto();
333 }
334 req
335}
336
337async fn fetch_response(
338 mut agent: DeepseekAgent,
339) -> (Result<FetchResult, ApiError>, DeepseekAgent) {
340 let req = build_request(&agent);
341
342 let resp = match agent.client.send(req).await {
343 Ok(r) => r,
344 Err(e) => return (Err(e), agent),
345 };
346
347 let choice = match resp.choices.into_iter().next() {
348 Some(c) => c,
349 None => return (Err(ApiError::Other("empty response: no choices".into())), agent),
350 };
351
352 let assistant_msg = choice.message;
353 let content = assistant_msg.content.clone();
354 let raw_tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
355 agent.conversation.history_mut().push(assistant_msg);
356
357 (Ok(FetchResult { content, raw_tool_calls }), agent)
358}
359
360async fn connect_stream(
361 agent: DeepseekAgent,
362) -> (
363 Result<BoxStream<'static, Result<ChatCompletionChunk, ApiError>>, ApiError>,
364 DeepseekAgent,
365) {
366 let req = build_request(&agent);
367 match agent.client.clone().into_stream(req).await {
368 Ok(stream) => (Ok(stream), agent),
369 Err(e) => (Err(e), agent),
370 }
371}
372
373async fn execute_tools(
374 mut agent: DeepseekAgent,
375 raw_tool_calls: Vec<ToolCall>,
376) -> (ToolsResult, DeepseekAgent) {
377 let mut events = vec![];
378
379 for tc in raw_tool_calls {
380 let args: Value = serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null);
381
382 let result = match agent.tool_index.get(&tc.function.name) {
383 Some(&idx) => agent.tools[idx].call(&tc.function.name, args.clone()).await,
384 None => {
385 serde_json::json!({ "error": format!("unknown tool: {}", tc.function.name) })
386 }
387 };
388
389 agent.conversation.history_mut().push(Message {
390 role: Role::Tool,
391 content: Some(result.to_string()),
392 tool_call_id: Some(tc.id.clone()),
393 ..Default::default()
394 });
395
396 events.push(ToolCallEvent {
397 id: tc.id,
398 name: tc.function.name,
399 args,
400 result,
401 });
402 }
403
404 (ToolsResult { events }, agent)
405}