// ds_api/agent/stream.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use futures::stream::BoxStream;
5use futures::{Stream, StreamExt};
6use serde_json::Value;
7
8use crate::agent::agent_core::{AgentResponse, DeepseekAgent, ToolCallEvent};
9use crate::conversation::Conversation;
10use crate::error::ApiError;
11use crate::raw::request::message::{FunctionCall, Message, Role, ToolCall, ToolType};
12use crate::raw::ChatCompletionChunk;
13
// ── Internal result types ────────────────────────────────────────────────────

/// Outcome of one non-streaming request (produced by `fetch_response`).
struct FetchResult {
    /// Text of the assistant message, if the reply had any.
    content: Option<String>,
    /// Tool calls the assistant requested; empty when it requested none.
    raw_tool_calls: Vec<ToolCall>,
}

/// Outcome of running a batch of tool calls (produced by `execute_tools`).
struct ToolsResult {
    /// One event per executed call, in the order the calls were requested,
    /// carrying the parsed arguments and the tool's result.
    events: Vec<ToolCallEvent>,
}
24
// ── Streaming accumulator ────────────────────────────────────────────────────

/// A tool call being reassembled from streamed deltas.
///
/// `id` and `name` identify the call. `arguments` is the concatenation of every
/// argument fragment received so far for this call. It may not be valid JSON
/// until the stream has ended.
struct PartialToolCall {
    id: String,
    name: String,
    arguments: String,
}

/// Live state while consuming a streamed chat completion.
struct StreamingData {
    /// Chunk stream returned by the client.
    stream: BoxStream<'static, Result<ChatCompletionChunk, ApiError>>,
    /// The agent is owned here while streaming. It is handed back to
    /// `AgentStream` or moved into tool execution once the stream ends.
    agent: DeepseekAgent,
    /// Every non-empty content fragment received so far, concatenated.
    content_buf: String,
    /// Partial tool calls, indexed by the delta's `index`. `None` marks an
    /// index that no delta has referenced yet.
    tool_call_bufs: Vec<Option<PartialToolCall>>,
}
39
40// ── Type aliases for future outputs ─────────────────────────────────────────
41
42type FetchFuture =
43    Pin<Box<dyn std::future::Future<Output = (Result<FetchResult, ApiError>, DeepseekAgent)> + Send>>;
44
45type ConnectFuture = Pin<
46    Box<
47        dyn std::future::Future<
48                Output = (
49                    Result<BoxStream<'static, Result<ChatCompletionChunk, ApiError>>, ApiError>,
50                    DeepseekAgent,
51                ),
52            > + Send,
53    >,
54>;
55
56type ExecFuture =
57    Pin<Box<dyn std::future::Future<Output = (ToolsResult, DeepseekAgent)> + Send>>;
58
// ── State machine ────────────────────────────────────────────────────────────

/// A [`Stream`] that drives an agent through its request / tool-call loop,
/// yielding [`AgentResponse`]s.
///
/// Each round sends the conversation to the API and forwards the reply. The
/// reply arrives as content fragments in streaming mode, or as a single item
/// otherwise. Any requested tools are then executed and the loop repeats until
/// a reply contains no tool calls. Use [`AgentStream::into_agent`] to take the
/// agent back afterwards.
pub struct AgentStream {
    // Holds the agent whenever no in-flight future or streaming buffer owns it.
    agent: Option<DeepseekAgent>,
    state: AgentStreamState,
}

/// Internal states of `AgentStream`.
enum AgentStreamState {
    /// Ready to send the next request; `AgentStream::agent` holds the agent.
    Idle,
    /// Awaiting a complete (non-streaming) response; the future owns the agent.
    FetchingResponse(FetchFuture),
    /// Awaiting the chunk stream to open; the future owns the agent.
    ConnectingStream(ConnectFuture),
    /// Consuming streamed chunks; the boxed data owns the agent.
    StreamingChunks(Box<StreamingData>),
    /// Running the tools requested by the last reply; the future owns the agent.
    ExecutingTools(ExecFuture),
    /// Terminal state. Reached when a reply had no tool calls or an error was
    /// yielded.
    Done,
}
74
75impl AgentStream {
76    pub fn new(agent: DeepseekAgent) -> Self {
77        Self {
78            agent: Some(agent),
79            state: AgentStreamState::Idle,
80        }
81    }
82
83    pub fn into_agent(self) -> Option<DeepseekAgent> {
84        match self.state {
85            AgentStreamState::StreamingChunks(data) => Some(data.agent),
86            _ => self.agent,
87        }
88    }
89}
90
91impl Stream for AgentStream {
92    type Item = Result<AgentResponse, ApiError>;
93
94    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
95        let this = self.get_mut();
96
97        loop {
98            // StreamingChunks is handled first to avoid borrow-checker conflicts when
99            // we need to both poll the inner stream and replace `this.state`.
100            if matches!(this.state, AgentStreamState::StreamingChunks(_)) {
101                let mut data =
102                    match std::mem::replace(&mut this.state, AgentStreamState::Done) {
103                        AgentStreamState::StreamingChunks(d) => d,
104                        _ => unreachable!(),
105                    };
106
107                match data.stream.poll_next_unpin(cx) {
108                    Poll::Pending => {
109                        this.state = AgentStreamState::StreamingChunks(data);
110                        return Poll::Pending;
111                    }
112
113                    Poll::Ready(Some(Ok(chunk))) => {
114                        let mut fragment: Option<String> = None;
115
116                        if let Some(choice) = chunk.choices.into_iter().next() {
117                            let delta = choice.delta;
118
119                            if let Some(dtcs) = delta.tool_calls {
120                                for dtc in dtcs {
121                                    let idx = dtc.index as usize;
122                                    if data.tool_call_bufs.len() <= idx {
123                                        data.tool_call_bufs.resize_with(idx + 1, || None);
124                                    }
125                                    let entry = &mut data.tool_call_bufs[idx];
126                                    if entry.is_none() {
127                                        *entry = Some(PartialToolCall {
128                                            id: dtc.id.clone().unwrap_or_default(),
129                                            name: dtc
130                                                .function
131                                                .as_ref()
132                                                .and_then(|f| f.name.clone())
133                                                .unwrap_or_default(),
134                                            arguments: String::new(),
135                                        });
136                                    }
137                                    if let Some(partial) = entry.as_mut() {
138                                        if let Some(id) = dtc.id {
139                                            if partial.id.is_empty() {
140                                                partial.id = id;
141                                            }
142                                        }
143                                        if let Some(func) = dtc.function {
144                                            if let Some(args) = func.arguments {
145                                                partial.arguments.push_str(&args);
146                                            }
147                                        }
148                                    }
149                                }
150                            }
151
152                            if let Some(content) = delta.content {
153                                if !content.is_empty() {
154                                    data.content_buf.push_str(&content);
155                                    fragment = Some(content);
156                                }
157                            }
158                        }
159
160                        this.state = AgentStreamState::StreamingChunks(data);
161
162                        if let Some(content) = fragment {
163                            return Poll::Ready(Some(Ok(AgentResponse {
164                                content: Some(content),
165                                tool_calls: vec![],
166                            })));
167                        }
168                        continue;
169                    }
170
171                    Poll::Ready(Some(Err(e))) => {
172                        // Propagate the stream error; state stays Done.
173                        this.agent = Some(data.agent);
174                        return Poll::Ready(Some(Err(e)));
175                    }
176
177                    Poll::Ready(None) => {
178                        let raw_tool_calls: Vec<ToolCall> = data
179                            .tool_call_bufs
180                            .into_iter()
181                            .flatten()
182                            .map(|p| ToolCall {
183                                id: p.id,
184                                r#type: ToolType::Function,
185                                function: FunctionCall {
186                                    name: p.name,
187                                    arguments: p.arguments,
188                                },
189                            })
190                            .collect();
191
192                        let assistant_msg = Message {
193                            role: Role::Assistant,
194                            content: if data.content_buf.is_empty() {
195                                None
196                            } else {
197                                Some(data.content_buf)
198                            },
199                            tool_calls: if raw_tool_calls.is_empty() {
200                                None
201                            } else {
202                                Some(raw_tool_calls.clone())
203                            },
204                            ..Default::default()
205                        };
206                        data.agent.conversation.history_mut().push(assistant_msg);
207
208                        if raw_tool_calls.is_empty() {
209                            this.agent = Some(data.agent);
210                            return Poll::Ready(None);
211                        }
212
213                        let preview_events = build_preview(&raw_tool_calls);
214                        let fut = Box::pin(execute_tools(data.agent, raw_tool_calls));
215                        this.state = AgentStreamState::ExecutingTools(fut);
216                        return Poll::Ready(Some(Ok(AgentResponse {
217                            content: None,
218                            tool_calls: preview_events,
219                        })));
220                    }
221                }
222            }
223
224            match &mut this.state {
225                AgentStreamState::Done => return Poll::Ready(None),
226
227                AgentStreamState::Idle => {
228                    let agent = this.agent.take().expect("agent missing");
229                    if agent.streaming {
230                        let fut = Box::pin(connect_stream(agent));
231                        this.state = AgentStreamState::ConnectingStream(fut);
232                    } else {
233                        let fut = Box::pin(fetch_response(agent));
234                        this.state = AgentStreamState::FetchingResponse(fut);
235                    }
236                }
237
238                AgentStreamState::FetchingResponse(fut) => {
239                    match fut.as_mut().poll(cx) {
240                        Poll::Pending => return Poll::Pending,
241                        Poll::Ready((Err(e), agent)) => {
242                            this.agent = Some(agent);
243                            this.state = AgentStreamState::Done;
244                            return Poll::Ready(Some(Err(e)));
245                        }
246                        Poll::Ready((Ok(fetch), agent)) => {
247                            if fetch.raw_tool_calls.is_empty() {
248                                this.agent = Some(agent);
249                                this.state = AgentStreamState::Done;
250                                return Poll::Ready(Some(Ok(AgentResponse {
251                                    content: fetch.content,
252                                    tool_calls: vec![],
253                                })));
254                            }
255
256                            let content = fetch.content.clone();
257                            let raw_calls = fetch.raw_tool_calls;
258                            let preview_events = build_preview(&raw_calls);
259                            let fut = Box::pin(execute_tools(agent, raw_calls));
260                            this.state = AgentStreamState::ExecutingTools(fut);
261                            return Poll::Ready(Some(Ok(AgentResponse {
262                                content,
263                                tool_calls: preview_events,
264                            })));
265                        }
266                    }
267                }
268
269                AgentStreamState::ConnectingStream(fut) => {
270                    match fut.as_mut().poll(cx) {
271                        Poll::Pending => return Poll::Pending,
272                        Poll::Ready((Err(e), agent)) => {
273                            this.agent = Some(agent);
274                            this.state = AgentStreamState::Done;
275                            return Poll::Ready(Some(Err(e)));
276                        }
277                        Poll::Ready((Ok(stream), agent)) => {
278                            this.state =
279                                AgentStreamState::StreamingChunks(Box::new(StreamingData {
280                                    stream,
281                                    agent,
282                                    content_buf: String::new(),
283                                    tool_call_bufs: Vec::new(),
284                                }));
285                        }
286                    }
287                }
288
289                AgentStreamState::ExecutingTools(fut) => {
290                    match fut.as_mut().poll(cx) {
291                        Poll::Pending => return Poll::Pending,
292                        Poll::Ready((results, agent)) => {
293                            this.agent = Some(agent);
294                            this.state = AgentStreamState::Idle;
295                            return Poll::Ready(Some(Ok(AgentResponse {
296                                content: None,
297                                tool_calls: results.events,
298                            })));
299                        }
300                    }
301                }
302
303                AgentStreamState::StreamingChunks(_) => unreachable!(),
304            }
305        }
306    }
307}
308
309// ── Helpers ──────────────────────────────────────────────────────────────────
310
311fn build_preview(raw_calls: &[ToolCall]) -> Vec<ToolCallEvent> {
312    raw_calls
313        .iter()
314        .map(|tc| ToolCallEvent {
315            id: tc.id.clone(),
316            name: tc.function.name.clone(),
317            args: serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null),
318            result: Value::Null,
319        })
320        .collect()
321}
322
323fn build_request(agent: &DeepseekAgent) -> crate::api::ApiRequest {
324    let history = agent.conversation.history().clone();
325    let mut req = crate::api::ApiRequest::builder().messages(history);
326    for tool in &agent.tools {
327        for raw in tool.raw_tools() {
328            req = req.add_tool(raw);
329        }
330    }
331    if !agent.tools.is_empty() {
332        req = req.tool_choice_auto();
333    }
334    req
335}
336
337async fn fetch_response(
338    mut agent: DeepseekAgent,
339) -> (Result<FetchResult, ApiError>, DeepseekAgent) {
340    let req = build_request(&agent);
341
342    let resp = match agent.client.send(req).await {
343        Ok(r) => r,
344        Err(e) => return (Err(e), agent),
345    };
346
347    let choice = match resp.choices.into_iter().next() {
348        Some(c) => c,
349        None => return (Err(ApiError::Other("empty response: no choices".into())), agent),
350    };
351
352    let assistant_msg = choice.message;
353    let content = assistant_msg.content.clone();
354    let raw_tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
355    agent.conversation.history_mut().push(assistant_msg);
356
357    (Ok(FetchResult { content, raw_tool_calls }), agent)
358}
359
360async fn connect_stream(
361    agent: DeepseekAgent,
362) -> (
363    Result<BoxStream<'static, Result<ChatCompletionChunk, ApiError>>, ApiError>,
364    DeepseekAgent,
365) {
366    let req = build_request(&agent);
367    match agent.client.clone().into_stream(req).await {
368        Ok(stream) => (Ok(stream), agent),
369        Err(e) => (Err(e), agent),
370    }
371}
372
373async fn execute_tools(
374    mut agent: DeepseekAgent,
375    raw_tool_calls: Vec<ToolCall>,
376) -> (ToolsResult, DeepseekAgent) {
377    let mut events = vec![];
378
379    for tc in raw_tool_calls {
380        let args: Value = serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null);
381
382        let result = match agent.tool_index.get(&tc.function.name) {
383            Some(&idx) => agent.tools[idx].call(&tc.function.name, args.clone()).await,
384            None => {
385                serde_json::json!({ "error": format!("unknown tool: {}", tc.function.name) })
386            }
387        };
388
389        agent.conversation.history_mut().push(Message {
390            role: Role::Tool,
391            content: Some(result.to_string()),
392            tool_call_id: Some(tc.id.clone()),
393            ..Default::default()
394        });
395
396        events.push(ToolCallEvent {
397            id: tc.id,
398            name: tc.function.name,
399            args,
400            result,
401        });
402    }
403
404    (ToolsResult { events }, agent)
405}