1use std::collections::VecDeque;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures::stream::BoxStream;
6use futures::{Stream, StreamExt};
7use serde_json::Value;
8
9use crate::agent::agent_core::{AgentEvent, DeepseekAgent, ToolCallInfo, ToolCallResult};
10use crate::error::ApiError;
11use crate::raw::ChatCompletionChunk;
12use crate::raw::request::message::{FunctionCall, Message, Role, ToolCall, ToolType};
13
14type SummarizeFuture = Pin<Box<dyn std::future::Future<Output = DeepseekAgent> + Send>>;
15
16struct FetchResult {
19 content: Option<String>,
20 raw_tool_calls: Vec<ToolCall>,
21}
22
23struct ToolsResult {
24 results: Vec<ToolCallResult>,
25}
26
/// Accumulator for one tool call whose pieces arrive spread over several streamed chunks.
struct PartialToolCall {
    /// Call id; stays empty until some delta supplies it.
    id: String,
    /// Name of the function the model wants to call.
    name: String,
    /// Argument fragments received so far, concatenated; parsed as JSON once complete.
    arguments: String,
}
34
35struct StreamingData {
36 stream: BoxStream<'static, Result<ChatCompletionChunk, ApiError>>,
37 agent: DeepseekAgent,
38 content_buf: String,
39 tool_call_bufs: Vec<Option<PartialToolCall>>,
40}
41
42type FetchFuture = Pin<
45 Box<dyn std::future::Future<Output = (Result<FetchResult, ApiError>, DeepseekAgent)> + Send>,
46>;
47
48type ConnectFuture = Pin<
49 Box<
50 dyn std::future::Future<
51 Output = (
52 Result<BoxStream<'static, Result<ChatCompletionChunk, ApiError>>, ApiError>,
53 DeepseekAgent,
54 ),
55 > + Send,
56 >,
57>;
58
59type ExecFuture = Pin<Box<dyn std::future::Future<Output = (ToolsResult, DeepseekAgent)> + Send>>;
60
61pub struct AgentStream {
64 agent: Option<DeepseekAgent>,
66 state: AgentStreamState,
67}
68
69enum AgentStreamState {
70 Idle,
71 Summarizing(SummarizeFuture),
73 FetchingResponse(FetchFuture),
74 ConnectingStream(ConnectFuture),
75 StreamingChunks(Box<StreamingData>),
76 YieldingToolCalls {
79 pending: VecDeque<ToolCallInfo>,
80 raw: Vec<ToolCall>,
81 },
82 ExecutingTools(ExecFuture),
83 YieldingToolResults {
85 pending: VecDeque<ToolCallResult>,
86 },
87 Done,
88}
89
90impl AgentStream {
91 pub fn new(agent: DeepseekAgent) -> Self {
92 Self {
93 agent: Some(agent),
94 state: AgentStreamState::Idle,
95 }
96 }
97
98 pub fn into_agent(self) -> Option<DeepseekAgent> {
99 match self.state {
100 AgentStreamState::StreamingChunks(data) => Some(data.agent),
101 _ => self.agent,
102 }
103 }
104}
105
106impl Stream for AgentStream {
107 type Item = Result<AgentEvent, ApiError>;
108
109 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
110 let this = self.get_mut();
111
112 loop {
113 if matches!(this.state, AgentStreamState::StreamingChunks(_)) {
117 let mut data = match std::mem::replace(&mut this.state, AgentStreamState::Done) {
118 AgentStreamState::StreamingChunks(d) => d,
119 _ => unreachable!(),
120 };
121
122 match data.stream.poll_next_unpin(cx) {
123 Poll::Pending => {
124 this.state = AgentStreamState::StreamingChunks(data);
125 return Poll::Pending;
126 }
127
128 Poll::Ready(Some(Ok(chunk))) => {
129 let mut fragment: Option<String> = None;
130
131 if let Some(choice) = chunk.choices.into_iter().next() {
132 let delta = choice.delta;
133
134 if let Some(dtcs) = delta.tool_calls {
135 for dtc in dtcs {
136 let idx = dtc.index as usize;
137 if data.tool_call_bufs.len() <= idx {
138 data.tool_call_bufs.resize_with(idx + 1, || None);
139 }
140 let entry = &mut data.tool_call_bufs[idx];
141 if entry.is_none() {
142 *entry = Some(PartialToolCall {
143 id: dtc.id.clone().unwrap_or_default(),
144 name: dtc
145 .function
146 .as_ref()
147 .and_then(|f| f.name.clone())
148 .unwrap_or_default(),
149 arguments: String::new(),
150 });
151 }
152 if let Some(partial) = entry.as_mut() {
153 if let Some(id) = dtc.id
154 && partial.id.is_empty()
155 {
156 partial.id = id;
157 }
158 if let Some(func) = dtc.function
159 && let Some(args) = func.arguments
160 {
161 partial.arguments.push_str(&args);
162 }
163 }
164 }
165 }
166
167 if let Some(content) = delta.content
168 && !content.is_empty()
169 {
170 data.content_buf.push_str(&content);
171 fragment = Some(content);
172 }
173 }
174
175 this.state = AgentStreamState::StreamingChunks(data);
176
177 if let Some(text) = fragment {
178 return Poll::Ready(Some(Ok(AgentEvent::Token(text))));
179 }
180 continue;
181 }
182
183 Poll::Ready(Some(Err(e))) => {
184 this.agent = Some(data.agent);
185 return Poll::Ready(Some(Err(e)));
187 }
188
189 Poll::Ready(None) => {
190 let raw_tool_calls: Vec<ToolCall> = data
192 .tool_call_bufs
193 .into_iter()
194 .flatten()
195 .map(|p| ToolCall {
196 id: p.id,
197 r#type: ToolType::Function,
198 function: FunctionCall {
199 name: p.name,
200 arguments: p.arguments,
201 },
202 })
203 .collect();
204
205 let assistant_msg = Message {
206 role: Role::Assistant,
207 content: if data.content_buf.is_empty() {
208 None
209 } else {
210 Some(data.content_buf)
211 },
212 tool_calls: if raw_tool_calls.is_empty() {
213 None
214 } else {
215 Some(raw_tool_calls.clone())
216 },
217 ..Default::default()
218 };
219 data.agent.conversation.history_mut().push(assistant_msg);
220
221 if raw_tool_calls.is_empty() {
222 this.agent = Some(data.agent);
223 this.state = AgentStreamState::Done;
224 return Poll::Ready(None);
225 }
226
227 let pending = raw_tool_calls
228 .iter()
229 .map(raw_to_tool_call_info)
230 .collect::<VecDeque<_>>();
231 this.agent = Some(data.agent);
232 this.state = AgentStreamState::YieldingToolCalls {
233 pending,
234 raw: raw_tool_calls,
235 };
236 continue;
237 }
238 }
239 }
240
241 match &mut this.state {
242 AgentStreamState::Done => return Poll::Ready(None),
243
244 AgentStreamState::Idle => {
245 let agent = this.agent.take().expect("agent missing");
246 let fut = Box::pin(run_summarize(agent));
247 this.state = AgentStreamState::Summarizing(fut);
248 }
249
250 AgentStreamState::Summarizing(fut) => match fut.as_mut().poll(cx) {
251 Poll::Pending => return Poll::Pending,
252 Poll::Ready(agent) => {
253 if agent.streaming {
254 let fut = Box::pin(connect_stream(agent));
255 this.state = AgentStreamState::ConnectingStream(fut);
256 } else {
257 let fut = Box::pin(fetch_response(agent));
258 this.state = AgentStreamState::FetchingResponse(fut);
259 }
260 }
261 },
262
263 AgentStreamState::FetchingResponse(fut) => match fut.as_mut().poll(cx) {
264 Poll::Pending => return Poll::Pending,
265 Poll::Ready((Err(e), agent)) => {
266 this.agent = Some(agent);
267 this.state = AgentStreamState::Done;
268 return Poll::Ready(Some(Err(e)));
269 }
270 Poll::Ready((Ok(fetch), agent)) => {
271 this.agent = Some(agent);
272
273 if fetch.raw_tool_calls.is_empty() {
274 this.state = AgentStreamState::Done;
275 return if let Some(text) = fetch.content {
276 Poll::Ready(Some(Ok(AgentEvent::Token(text))))
277 } else {
278 Poll::Ready(None)
279 };
280 }
281
282 let pending = fetch
283 .raw_tool_calls
284 .iter()
285 .map(raw_to_tool_call_info)
286 .collect::<VecDeque<_>>();
287 this.state = AgentStreamState::YieldingToolCalls {
288 pending,
289 raw: fetch.raw_tool_calls,
290 };
291
292 if let Some(text) = fetch.content {
293 return Poll::Ready(Some(Ok(AgentEvent::Token(text))));
294 }
295 continue;
296 }
297 },
298
299 AgentStreamState::ConnectingStream(fut) => match fut.as_mut().poll(cx) {
300 Poll::Pending => return Poll::Pending,
301 Poll::Ready((Err(e), agent)) => {
302 this.agent = Some(agent);
303 this.state = AgentStreamState::Done;
304 return Poll::Ready(Some(Err(e)));
305 }
306 Poll::Ready((Ok(stream), agent)) => {
307 this.state = AgentStreamState::StreamingChunks(Box::new(StreamingData {
308 stream,
309 agent,
310 content_buf: String::new(),
311 tool_call_bufs: Vec::new(),
312 }));
313 }
314 },
315
316 AgentStreamState::YieldingToolCalls { pending, raw } => {
317 if let Some(info) = pending.pop_front() {
318 return Poll::Ready(Some(Ok(AgentEvent::ToolCall(info))));
319 }
320 let agent = this.agent.take().expect("agent missing");
322 let raw_calls = std::mem::take(raw);
323 let fut = Box::pin(execute_tools(agent, raw_calls));
324 this.state = AgentStreamState::ExecutingTools(fut);
325 }
326
327 AgentStreamState::ExecutingTools(fut) => match fut.as_mut().poll(cx) {
328 Poll::Pending => return Poll::Pending,
329 Poll::Ready((tools_result, agent)) => {
330 this.agent = Some(agent);
331 this.state = AgentStreamState::YieldingToolResults {
332 pending: tools_result.results.into_iter().collect(),
333 };
334 }
335 },
336
337 AgentStreamState::YieldingToolResults { pending } => {
338 if let Some(result) = pending.pop_front() {
339 return Poll::Ready(Some(Ok(AgentEvent::ToolResult(result))));
340 }
341 this.state = AgentStreamState::Idle;
343 }
344
345 AgentStreamState::StreamingChunks(_) => unreachable!(),
346 }
347 }
348 }
349}
350
351fn raw_to_tool_call_info(tc: &ToolCall) -> ToolCallInfo {
354 ToolCallInfo {
355 id: tc.id.clone(),
356 name: tc.function.name.clone(),
357 args: serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null),
358 }
359}
360
361async fn run_summarize(mut agent: DeepseekAgent) -> DeepseekAgent {
363 agent.conversation.maybe_summarize().await;
364 agent
365}
366
367fn build_request(agent: &DeepseekAgent) -> crate::api::ApiRequest {
368 let history = agent.conversation.history().to_vec();
369 let mut req = crate::api::ApiRequest::builder().messages(history);
370 for tool in &agent.tools {
371 for raw in tool.raw_tools() {
372 req = req.add_tool(raw);
373 }
374 }
375 if !agent.tools.is_empty() {
376 req = req.tool_choice_auto();
377 }
378 req
379}
380
381async fn fetch_response(
382 mut agent: DeepseekAgent,
383) -> (Result<FetchResult, ApiError>, DeepseekAgent) {
384 let req = build_request(&agent);
385
386 let resp = match agent.conversation.client.send(req).await {
387 Ok(r) => r,
388 Err(e) => return (Err(e), agent),
389 };
390
391 let choice = match resp.choices.into_iter().next() {
392 Some(c) => c,
393 None => {
394 return (
395 Err(ApiError::Other("empty response: no choices".into())),
396 agent,
397 );
398 }
399 };
400
401 let assistant_msg = choice.message;
402 let content = assistant_msg.content.clone();
403 let raw_tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
404 agent.conversation.history_mut().push(assistant_msg);
405
406 (
407 Ok(FetchResult {
408 content,
409 raw_tool_calls,
410 }),
411 agent,
412 )
413}
414
415async fn connect_stream(
416 agent: DeepseekAgent,
417) -> (
418 Result<BoxStream<'static, Result<ChatCompletionChunk, ApiError>>, ApiError>,
419 DeepseekAgent,
420) {
421 let req = build_request(&agent);
422 match agent.conversation.client.clone().into_stream(req).await {
423 Ok(stream) => (Ok(stream), agent),
424 Err(e) => (Err(e), agent),
425 }
426}
427
428async fn execute_tools(
429 mut agent: DeepseekAgent,
430 raw_tool_calls: Vec<ToolCall>,
431) -> (ToolsResult, DeepseekAgent) {
432 let mut results = vec![];
433
434 for tc in raw_tool_calls {
435 let args: Value = serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null);
436
437 let result = match agent.tool_index.get(&tc.function.name) {
438 Some(&idx) => agent.tools[idx].call(&tc.function.name, args.clone()).await,
439 None => {
440 serde_json::json!({ "error": format!("unknown tool: {}", tc.function.name) })
441 }
442 };
443
444 agent.conversation.history_mut().push(Message {
445 role: Role::Tool,
446 content: Some(result.to_string()),
447 tool_call_id: Some(tc.id.clone()),
448 ..Default::default()
449 });
450
451 results.push(ToolCallResult {
452 id: tc.id,
453 name: tc.function.name,
454 args,
455 result,
456 });
457 }
458
459 (ToolsResult { results }, agent)
460}