Skip to main content

ds_api/agent/
stream.rs

1use std::collections::VecDeque;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures::stream::BoxStream;
6use futures::{Stream, StreamExt};
7use serde_json::Value;
8
9use crate::agent::agent_core::{AgentEvent, DeepseekAgent, ToolCallInfo, ToolCallResult};
10use crate::error::ApiError;
11use crate::raw::ChatCompletionChunk;
12use crate::raw::request::message::{FunctionCall, Message, Role, ToolCall, ToolType};
13
14type SummarizeFuture = Pin<Box<dyn std::future::Future<Output = DeepseekAgent> + Send>>;
15
16// ── Internal result types ────────────────────────────────────────────────────
17
18struct FetchResult {
19    content: Option<String>,
20    raw_tool_calls: Vec<ToolCall>,
21}
22
23struct ToolsResult {
24    results: Vec<ToolCallResult>,
25}
26
27// ── Streaming accumulator ────────────────────────────────────────────────────
28
/// A tool call being assembled from streamed deltas.
struct PartialToolCall {
    /// Call id; empty until a delta supplies one.
    id: String,
    /// Function name; empty until a delta supplies one.
    name: String,
    /// Argument text, built by appending each delta's fragment in arrival order.
    arguments: String,
}
34
35struct StreamingData {
36    stream: BoxStream<'static, Result<ChatCompletionChunk, ApiError>>,
37    agent: DeepseekAgent,
38    content_buf: String,
39    tool_call_bufs: Vec<Option<PartialToolCall>>,
40}
41
42// ── Type aliases for future outputs ─────────────────────────────────────────
43
44type FetchFuture = Pin<
45    Box<dyn std::future::Future<Output = (Result<FetchResult, ApiError>, DeepseekAgent)> + Send>,
46>;
47
48type ConnectFuture = Pin<
49    Box<
50        dyn std::future::Future<
51                Output = (
52                    Result<BoxStream<'static, Result<ChatCompletionChunk, ApiError>>, ApiError>,
53                    DeepseekAgent,
54                ),
55            > + Send,
56    >,
57>;
58
59type ExecFuture = Pin<Box<dyn std::future::Future<Output = (ToolsResult, DeepseekAgent)> + Send>>;
60
61// ── State machine ────────────────────────────────────────────────────────────
62
63pub struct AgentStream {
64    /// The agent is held here when no future has ownership of it.
65    agent: Option<DeepseekAgent>,
66    state: AgentStreamState,
67}
68
69enum AgentStreamState {
70    Idle,
71    /// Running `maybe_summarize` before the next API turn.
72    Summarizing(SummarizeFuture),
73    FetchingResponse(FetchFuture),
74    ConnectingStream(ConnectFuture),
75    StreamingChunks(Box<StreamingData>),
76    /// Yielding individual `ToolCall` events before kicking off execution.
77    /// Carries the queued previews and the raw calls needed to execute them.
78    YieldingToolCalls {
79        pending: VecDeque<ToolCallInfo>,
80        raw: Vec<ToolCall>,
81    },
82    ExecutingTools(ExecFuture),
83    /// Yielding individual `ToolResult` events after execution completes.
84    YieldingToolResults {
85        pending: VecDeque<ToolCallResult>,
86    },
87    Done,
88}
89
90impl AgentStream {
91    pub fn new(agent: DeepseekAgent) -> Self {
92        Self {
93            agent: Some(agent),
94            state: AgentStreamState::Idle,
95        }
96    }
97
98    pub fn into_agent(self) -> Option<DeepseekAgent> {
99        match self.state {
100            AgentStreamState::StreamingChunks(data) => Some(data.agent),
101            _ => self.agent,
102        }
103    }
104}
105
106impl Stream for AgentStream {
107    type Item = Result<AgentEvent, ApiError>;
108
109    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
110        let this = self.get_mut();
111
112        loop {
113            // ── StreamingChunks is extracted first to avoid borrow-checker
114            //    conflicts when we need to both poll the inner stream and
115            //    replace `this.state`.
116            if matches!(this.state, AgentStreamState::StreamingChunks(_)) {
117                let mut data = match std::mem::replace(&mut this.state, AgentStreamState::Done) {
118                    AgentStreamState::StreamingChunks(d) => d,
119                    _ => unreachable!(),
120                };
121
122                match data.stream.poll_next_unpin(cx) {
123                    Poll::Pending => {
124                        this.state = AgentStreamState::StreamingChunks(data);
125                        return Poll::Pending;
126                    }
127
128                    Poll::Ready(Some(Ok(chunk))) => {
129                        let mut fragment: Option<String> = None;
130
131                        if let Some(choice) = chunk.choices.into_iter().next() {
132                            let delta = choice.delta;
133
134                            if let Some(dtcs) = delta.tool_calls {
135                                for dtc in dtcs {
136                                    let idx = dtc.index as usize;
137                                    if data.tool_call_bufs.len() <= idx {
138                                        data.tool_call_bufs.resize_with(idx + 1, || None);
139                                    }
140                                    let entry = &mut data.tool_call_bufs[idx];
141                                    if entry.is_none() {
142                                        *entry = Some(PartialToolCall {
143                                            id: dtc.id.clone().unwrap_or_default(),
144                                            name: dtc
145                                                .function
146                                                .as_ref()
147                                                .and_then(|f| f.name.clone())
148                                                .unwrap_or_default(),
149                                            arguments: String::new(),
150                                        });
151                                    }
152                                    if let Some(partial) = entry.as_mut() {
153                                        if let Some(id) = dtc.id
154                                            && partial.id.is_empty()
155                                        {
156                                            partial.id = id;
157                                        }
158                                        if let Some(func) = dtc.function
159                                            && let Some(args) = func.arguments
160                                        {
161                                            partial.arguments.push_str(&args);
162                                        }
163                                    }
164                                }
165                            }
166
167                            if let Some(content) = delta.content
168                                && !content.is_empty()
169                            {
170                                data.content_buf.push_str(&content);
171                                fragment = Some(content);
172                            }
173                        }
174
175                        this.state = AgentStreamState::StreamingChunks(data);
176
177                        if let Some(text) = fragment {
178                            return Poll::Ready(Some(Ok(AgentEvent::Token(text))));
179                        }
180                        continue;
181                    }
182
183                    Poll::Ready(Some(Err(e))) => {
184                        this.agent = Some(data.agent);
185                        // state stays Done
186                        return Poll::Ready(Some(Err(e)));
187                    }
188
189                    Poll::Ready(None) => {
190                        // SSE stream finished — assemble raw tool calls from buffers.
191                        let raw_tool_calls: Vec<ToolCall> = data
192                            .tool_call_bufs
193                            .into_iter()
194                            .flatten()
195                            .map(|p| ToolCall {
196                                id: p.id,
197                                r#type: ToolType::Function,
198                                function: FunctionCall {
199                                    name: p.name,
200                                    arguments: p.arguments,
201                                },
202                            })
203                            .collect();
204
205                        let assistant_msg = Message {
206                            role: Role::Assistant,
207                            content: if data.content_buf.is_empty() {
208                                None
209                            } else {
210                                Some(data.content_buf)
211                            },
212                            tool_calls: if raw_tool_calls.is_empty() {
213                                None
214                            } else {
215                                Some(raw_tool_calls.clone())
216                            },
217                            ..Default::default()
218                        };
219                        data.agent.conversation.history_mut().push(assistant_msg);
220
221                        if raw_tool_calls.is_empty() {
222                            this.agent = Some(data.agent);
223                            this.state = AgentStreamState::Done;
224                            return Poll::Ready(None);
225                        }
226
227                        let pending = raw_tool_calls
228                            .iter()
229                            .map(raw_to_tool_call_info)
230                            .collect::<VecDeque<_>>();
231                        this.agent = Some(data.agent);
232                        this.state = AgentStreamState::YieldingToolCalls {
233                            pending,
234                            raw: raw_tool_calls,
235                        };
236                        continue;
237                    }
238                }
239            }
240
241            match &mut this.state {
242                AgentStreamState::Done => return Poll::Ready(None),
243
244                AgentStreamState::Idle => {
245                    let agent = this.agent.take().expect("agent missing");
246                    let fut = Box::pin(run_summarize(agent));
247                    this.state = AgentStreamState::Summarizing(fut);
248                }
249
250                AgentStreamState::Summarizing(fut) => match fut.as_mut().poll(cx) {
251                    Poll::Pending => return Poll::Pending,
252                    Poll::Ready(agent) => {
253                        if agent.streaming {
254                            let fut = Box::pin(connect_stream(agent));
255                            this.state = AgentStreamState::ConnectingStream(fut);
256                        } else {
257                            let fut = Box::pin(fetch_response(agent));
258                            this.state = AgentStreamState::FetchingResponse(fut);
259                        }
260                    }
261                },
262
263                AgentStreamState::FetchingResponse(fut) => match fut.as_mut().poll(cx) {
264                    Poll::Pending => return Poll::Pending,
265                    Poll::Ready((Err(e), agent)) => {
266                        this.agent = Some(agent);
267                        this.state = AgentStreamState::Done;
268                        return Poll::Ready(Some(Err(e)));
269                    }
270                    Poll::Ready((Ok(fetch), agent)) => {
271                        this.agent = Some(agent);
272
273                        if fetch.raw_tool_calls.is_empty() {
274                            this.state = AgentStreamState::Done;
275                            return if let Some(text) = fetch.content {
276                                Poll::Ready(Some(Ok(AgentEvent::Token(text))))
277                            } else {
278                                Poll::Ready(None)
279                            };
280                        }
281
282                        let pending = fetch
283                            .raw_tool_calls
284                            .iter()
285                            .map(raw_to_tool_call_info)
286                            .collect::<VecDeque<_>>();
287                        this.state = AgentStreamState::YieldingToolCalls {
288                            pending,
289                            raw: fetch.raw_tool_calls,
290                        };
291
292                        if let Some(text) = fetch.content {
293                            return Poll::Ready(Some(Ok(AgentEvent::Token(text))));
294                        }
295                        continue;
296                    }
297                },
298
299                AgentStreamState::ConnectingStream(fut) => match fut.as_mut().poll(cx) {
300                    Poll::Pending => return Poll::Pending,
301                    Poll::Ready((Err(e), agent)) => {
302                        this.agent = Some(agent);
303                        this.state = AgentStreamState::Done;
304                        return Poll::Ready(Some(Err(e)));
305                    }
306                    Poll::Ready((Ok(stream), agent)) => {
307                        this.state = AgentStreamState::StreamingChunks(Box::new(StreamingData {
308                            stream,
309                            agent,
310                            content_buf: String::new(),
311                            tool_call_bufs: Vec::new(),
312                        }));
313                    }
314                },
315
316                AgentStreamState::YieldingToolCalls { pending, raw } => {
317                    if let Some(info) = pending.pop_front() {
318                        return Poll::Ready(Some(Ok(AgentEvent::ToolCall(info))));
319                    }
320                    // All previews yielded — execute the tools.
321                    let agent = this.agent.take().expect("agent missing");
322                    let raw_calls = std::mem::take(raw);
323                    let fut = Box::pin(execute_tools(agent, raw_calls));
324                    this.state = AgentStreamState::ExecutingTools(fut);
325                }
326
327                AgentStreamState::ExecutingTools(fut) => match fut.as_mut().poll(cx) {
328                    Poll::Pending => return Poll::Pending,
329                    Poll::Ready((tools_result, agent)) => {
330                        this.agent = Some(agent);
331                        this.state = AgentStreamState::YieldingToolResults {
332                            pending: tools_result.results.into_iter().collect(),
333                        };
334                    }
335                },
336
337                AgentStreamState::YieldingToolResults { pending } => {
338                    if let Some(result) = pending.pop_front() {
339                        return Poll::Ready(Some(Ok(AgentEvent::ToolResult(result))));
340                    }
341                    // All results yielded — loop back for another API turn.
342                    this.state = AgentStreamState::Idle;
343                }
344
345                AgentStreamState::StreamingChunks(_) => unreachable!(),
346            }
347        }
348    }
349}
350
351// ── Helpers ───────────────────────────────────────────────────────────────────
352
353fn raw_to_tool_call_info(tc: &ToolCall) -> ToolCallInfo {
354    ToolCallInfo {
355        id: tc.id.clone(),
356        name: tc.function.name.clone(),
357        args: serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null),
358    }
359}
360
361/// Run `maybe_summarize` on the agent's conversation and return the agent.
362async fn run_summarize(mut agent: DeepseekAgent) -> DeepseekAgent {
363    agent.conversation.maybe_summarize().await;
364    agent
365}
366
367fn build_request(agent: &DeepseekAgent) -> crate::api::ApiRequest {
368    let history = agent.conversation.history().to_vec();
369    let mut req = crate::api::ApiRequest::builder().messages(history);
370    for tool in &agent.tools {
371        for raw in tool.raw_tools() {
372            req = req.add_tool(raw);
373        }
374    }
375    if !agent.tools.is_empty() {
376        req = req.tool_choice_auto();
377    }
378    req
379}
380
381async fn fetch_response(
382    mut agent: DeepseekAgent,
383) -> (Result<FetchResult, ApiError>, DeepseekAgent) {
384    let req = build_request(&agent);
385
386    let resp = match agent.conversation.client.send(req).await {
387        Ok(r) => r,
388        Err(e) => return (Err(e), agent),
389    };
390
391    let choice = match resp.choices.into_iter().next() {
392        Some(c) => c,
393        None => {
394            return (
395                Err(ApiError::Other("empty response: no choices".into())),
396                agent,
397            );
398        }
399    };
400
401    let assistant_msg = choice.message;
402    let content = assistant_msg.content.clone();
403    let raw_tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
404    agent.conversation.history_mut().push(assistant_msg);
405
406    (
407        Ok(FetchResult {
408            content,
409            raw_tool_calls,
410        }),
411        agent,
412    )
413}
414
415async fn connect_stream(
416    agent: DeepseekAgent,
417) -> (
418    Result<BoxStream<'static, Result<ChatCompletionChunk, ApiError>>, ApiError>,
419    DeepseekAgent,
420) {
421    let req = build_request(&agent);
422    match agent.conversation.client.clone().into_stream(req).await {
423        Ok(stream) => (Ok(stream), agent),
424        Err(e) => (Err(e), agent),
425    }
426}
427
428async fn execute_tools(
429    mut agent: DeepseekAgent,
430    raw_tool_calls: Vec<ToolCall>,
431) -> (ToolsResult, DeepseekAgent) {
432    let mut results = vec![];
433
434    for tc in raw_tool_calls {
435        let args: Value = serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null);
436
437        let result = match agent.tool_index.get(&tc.function.name) {
438            Some(&idx) => agent.tools[idx].call(&tc.function.name, args.clone()).await,
439            None => {
440                serde_json::json!({ "error": format!("unknown tool: {}", tc.function.name) })
441            }
442        };
443
444        agent.conversation.history_mut().push(Message {
445            role: Role::Tool,
446            content: Some(result.to_string()),
447            tool_call_id: Some(tc.id.clone()),
448            ..Default::default()
449        });
450
451        results.push(ToolCallResult {
452            id: tc.id,
453            name: tc.function.name,
454            args,
455            result,
456        });
457    }
458
459    (ToolsResult { results }, agent)
460}