ds_api/agent/stream.rs
1//! Agent streaming state machine.
2//!
3//! This module is responsible *only* for scheduling and polling — it does not
4//! contain any business logic. All "do actual work" functions live in
5//! [`executor`][super::executor]:
6//!
7//! ```text
8//! AgentStream::poll_next
9//! │
10//! ├─ Idle → spawn run_summarize future
11//! ├─ Summarizing → poll future → ConnectingStream | FetchingResponse
12//! ├─ FetchingResponse → poll future → YieldingToolCalls | Done (yield Token)
13//! ├─ ConnectingStream → poll future → StreamingChunks
14//! ├─ StreamingChunks → poll inner stream → yield Token | YieldingToolCalls | Done
15//! ├─ YieldingToolCalls → drain queue → ExecutingTools (yield ToolCall per item)
16//! ├─ ExecutingTools → poll future → YieldingToolResults
17//! ├─ YieldingToolResults → drain queue → Idle (yield ToolResult per item)
18//! └─ Done → Poll::Ready(None)
19//! ```
20
21use std::collections::VecDeque;
22use std::pin::Pin;
23use std::task::{Context, Poll};
24
25use futures::{Stream, StreamExt};
26
27use super::executor::{
28 ConnectFuture, ExecFuture, FetchFuture, StreamingData, SummarizeFuture, apply_chunk_delta,
29 connect_stream, execute_tools, fetch_response, finalize_stream, raw_to_tool_call_info,
30 run_summarize,
31};
32use crate::agent::agent_core::{AgentEvent, DeepseekAgent, ToolCallResult};
33use crate::error::ApiError;
34
35// ── State machine ─────────────────────────────────────────────────────────────
36
37/// Drives an agent through one or more API turns, tool-execution rounds, and
38/// summarization passes, emitting [`AgentEvent`]s as a [`Stream`].
39///
40/// Obtain one by calling [`DeepseekAgent::chat`][crate::agent::DeepseekAgent::chat].
41/// Collect it with any `futures::StreamExt` combinator or `while let Some(…)`.
42///
43/// # Example
44///
45/// ```no_run
46/// use futures::StreamExt;
47/// use ds_api::{DeepseekAgent, AgentEvent};
48///
49/// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
50/// let mut stream = DeepseekAgent::new("sk-...")
51/// .with_streaming()
52/// .chat("What is 2 + 2?");
53///
54/// while let Some(event) = stream.next().await {
55/// match event? {
56/// AgentEvent::Token(text) => print!("{text}"),
57/// AgentEvent::ToolCall(info) => println!("\n[calling {}]", info.name),
58/// AgentEvent::ToolResult(res) => println!("[result: {}]", res.result),
59/// }
60/// }
61/// # Ok(())
62/// # }
63/// ```
64pub struct AgentStream {
65 /// The agent is held here whenever no future has taken ownership of it.
66 agent: Option<DeepseekAgent>,
67 state: AgentStreamState,
68}
69
70/// Every variant is self-contained: it either holds the agent directly or stores
71/// a future that will return the agent when it resolves.
72pub(crate) enum AgentStreamState {
73 /// Waiting to start (or restart after tool results are delivered).
74 Idle,
75 /// Running `maybe_summarize` before the next API turn.
76 Summarizing(SummarizeFuture),
77 /// Awaiting a non-streaming API response.
78 FetchingResponse(FetchFuture),
79 /// Awaiting the initial SSE connection.
80 ConnectingStream(ConnectFuture),
81 /// Polling an active SSE stream chunk-by-chunk.
82 StreamingChunks(Box<StreamingData>),
83 /// Yielding individual `ToolCall` preview events before execution starts.
84 /// `raw` contains the wire-format calls needed to kick off [`ExecutingTools`].
85 YieldingToolCalls {
86 pending: VecDeque<crate::agent::agent_core::ToolCallInfo>,
87 raw: Vec<crate::raw::request::message::ToolCall>,
88 },
89 /// Awaiting parallel/sequential tool execution.
90 ExecutingTools(ExecFuture),
91 /// Yielding individual `ToolResult` events after execution completes.
92 YieldingToolResults { pending: VecDeque<ToolCallResult> },
93 /// Terminal state — the stream will never produce another item.
94 Done,
95}
96
97// ── Constructor / accessor ────────────────────────────────────────────────────
98
99impl AgentStream {
100 /// Wrap an agent and start in the [`Idle`][AgentStreamState::Idle] state.
101 pub fn new(agent: DeepseekAgent) -> Self {
102 Self {
103 agent: Some(agent),
104 state: AgentStreamState::Idle,
105 }
106 }
107
108 /// Consume the stream and return the agent.
109 ///
110 /// If the stream finished normally (or was dropped mid-stream), the agent is
111 /// returned so callers can continue the conversation without constructing a
112 /// new one.
113 ///
114 /// Returns `None` only if the agent is currently owned by an in-progress
115 /// future (i.e. the stream was dropped mid-poll, which is very unusual).
116 pub fn into_agent(self) -> Option<DeepseekAgent> {
117 match self.state {
118 AgentStreamState::StreamingChunks(data) => Some(data.agent),
119 _ => self.agent,
120 }
121 }
122}
123
124// ── Stream implementation ─────────────────────────────────────────────────────
125
126impl Stream for AgentStream {
127 type Item = Result<AgentEvent, ApiError>;
128
129 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
130 let this = self.get_mut();
131
132 loop {
133 // ── StreamingChunks is handled first to avoid borrow-checker
134 // conflicts: we need to both poll the inner stream *and* replace
135 // `this.state`, which requires owning the data.
136 if matches!(this.state, AgentStreamState::StreamingChunks(_)) {
137 let mut data = match std::mem::replace(&mut this.state, AgentStreamState::Done) {
138 AgentStreamState::StreamingChunks(d) => d,
139 _ => unreachable!(),
140 };
141
142 match data.stream.poll_next_unpin(cx) {
143 Poll::Pending => {
144 this.state = AgentStreamState::StreamingChunks(data);
145 return Poll::Pending;
146 }
147
148 Poll::Ready(Some(Ok(chunk))) => {
149 let fragment = apply_chunk_delta(&mut data, chunk);
150 this.state = AgentStreamState::StreamingChunks(data);
151 if let Some(text) = fragment {
152 return Poll::Ready(Some(Ok(AgentEvent::Token(text))));
153 }
154 continue;
155 }
156
157 Poll::Ready(Some(Err(e))) => {
158 // Stream errored — salvage the agent and terminate.
159 this.agent = Some(data.agent);
160 // state stays Done (set above via mem::replace)
161 return Poll::Ready(Some(Err(e)));
162 }
163
164 Poll::Ready(None) => {
165 // SSE stream ended — assemble full tool calls from buffers.
166 let raw_tool_calls = finalize_stream(&mut data);
167
168 if raw_tool_calls.is_empty() {
169 this.agent = Some(data.agent);
170 this.state = AgentStreamState::Done;
171 return Poll::Ready(None);
172 }
173
174 let pending = raw_tool_calls
175 .iter()
176 .map(raw_to_tool_call_info)
177 .collect::<VecDeque<_>>();
178 this.agent = Some(data.agent);
179 this.state = AgentStreamState::YieldingToolCalls {
180 pending,
181 raw: raw_tool_calls,
182 };
183 continue;
184 }
185 }
186 }
187
188 // ── All other states ──────────────────────────────────────────────
189 match &mut this.state {
190 AgentStreamState::Done => return Poll::Ready(None),
191
192 AgentStreamState::Idle => {
193 let agent = this.agent.as_mut().expect("agent missing in Idle state");
194 agent.drain_interrupts();
195 let agent = this.agent.take().unwrap();
196 this.state = AgentStreamState::Summarizing(Box::pin(run_summarize(agent)));
197 }
198
199 AgentStreamState::Summarizing(fut) => match fut.as_mut().poll(cx) {
200 Poll::Pending => return Poll::Pending,
201 Poll::Ready(agent) => {
202 this.state = if agent.streaming {
203 AgentStreamState::ConnectingStream(Box::pin(connect_stream(agent)))
204 } else {
205 AgentStreamState::FetchingResponse(Box::pin(fetch_response(agent)))
206 };
207 }
208 },
209
210 AgentStreamState::FetchingResponse(fut) => match fut.as_mut().poll(cx) {
211 Poll::Pending => return Poll::Pending,
212 Poll::Ready((Err(e), agent)) => {
213 this.agent = Some(agent);
214 this.state = AgentStreamState::Done;
215 return Poll::Ready(Some(Err(e)));
216 }
217 Poll::Ready((Ok(fetch), agent)) => {
218 this.agent = Some(agent);
219
220 if fetch.raw_tool_calls.is_empty() {
221 this.state = AgentStreamState::Done;
222 return if let Some(text) = fetch.content {
223 Poll::Ready(Some(Ok(AgentEvent::Token(text))))
224 } else {
225 Poll::Ready(None)
226 };
227 }
228
229 let pending = fetch
230 .raw_tool_calls
231 .iter()
232 .map(raw_to_tool_call_info)
233 .collect::<VecDeque<_>>();
234
235 // Yield any text content before transitioning.
236 let maybe_text = fetch.content.map(AgentEvent::Token);
237 this.state = AgentStreamState::YieldingToolCalls {
238 pending,
239 raw: fetch.raw_tool_calls,
240 };
241
242 if let Some(event) = maybe_text {
243 return Poll::Ready(Some(Ok(event)));
244 }
245 continue;
246 }
247 },
248
249 AgentStreamState::ConnectingStream(fut) => match fut.as_mut().poll(cx) {
250 Poll::Pending => return Poll::Pending,
251 Poll::Ready((Err(e), agent)) => {
252 this.agent = Some(agent);
253 this.state = AgentStreamState::Done;
254 return Poll::Ready(Some(Err(e)));
255 }
256 Poll::Ready((Ok(stream), agent)) => {
257 this.state = AgentStreamState::StreamingChunks(Box::new(StreamingData {
258 stream,
259 agent,
260 content_buf: String::new(),
261 tool_call_bufs: Vec::new(),
262 }));
263 // Loop back to hit the StreamingChunks branch.
264 }
265 },
266
267 AgentStreamState::YieldingToolCalls { pending, raw } => {
268 if let Some(info) = pending.pop_front() {
269 return Poll::Ready(Some(Ok(AgentEvent::ToolCall(info))));
270 }
271 // All previews yielded — begin execution.
272 let agent = this
273 .agent
274 .take()
275 .expect("agent missing in YieldingToolCalls");
276 let raw_calls = std::mem::take(raw);
277 this.state =
278 AgentStreamState::ExecutingTools(Box::pin(execute_tools(agent, raw_calls)));
279 }
280
281 AgentStreamState::ExecutingTools(fut) => match fut.as_mut().poll(cx) {
282 Poll::Pending => return Poll::Pending,
283 Poll::Ready((tools_result, agent)) => {
284 this.agent = Some(agent);
285 this.state = AgentStreamState::YieldingToolResults {
286 pending: tools_result.results.into_iter().collect(),
287 };
288 }
289 },
290
291 AgentStreamState::YieldingToolResults { pending } => {
292 if let Some(result) = pending.pop_front() {
293 return Poll::Ready(Some(Ok(AgentEvent::ToolResult(result))));
294 }
295 // All results delivered — loop back for the next API turn.
296 this.state = AgentStreamState::Idle;
297 }
298
299 // Handled in the dedicated block above; this arm is unreachable
300 // but the compiler cannot verify that without exhaustiveness help.
301 AgentStreamState::StreamingChunks(_) => unreachable!(),
302 }
303 }
304 }
305}