1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use futures::Stream;
5use serde_json::Value;
6
7use crate::agent::agent_core::{AgentResponse, DeepseekAgent, ToolCallEvent};
8use crate::conversation::Conversation;
9use crate::raw::request::message::{Message, Role, ToolCall};
10
11struct FetchResult {
13 content: Option<String>,
14 raw_tool_calls: Vec<ToolCall>,
15}
16
17struct ToolsResult {
19 events: Vec<ToolCallEvent>,
20}
21
22pub struct AgentStream {
24 agent: Option<DeepseekAgent>,
25 state: AgentStreamState,
26}
27
28enum AgentStreamState {
29 Idle,
30 FetchingResponse(
32 Pin<Box<dyn std::future::Future<Output = (Option<FetchResult>, DeepseekAgent)> + Send>>,
33 ),
34 ExecutingTools(Pin<Box<dyn std::future::Future<Output = (ToolsResult, DeepseekAgent)> + Send>>),
36 Done,
37}
38
39impl AgentStream {
40 pub fn new(agent: DeepseekAgent) -> Self {
41 Self {
42 agent: Some(agent),
43 state: AgentStreamState::Idle,
44 }
45 }
46}
47
48impl Stream for AgentStream {
49 type Item = AgentResponse;
50
51 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
52 let this = self.get_mut();
53
54 loop {
55 match &mut this.state {
56 AgentStreamState::Done => return Poll::Ready(None),
57
58 AgentStreamState::Idle => {
59 let agent = this.agent.take().expect("agent missing");
60 let fut = Box::pin(fetch_response(agent));
61 this.state = AgentStreamState::FetchingResponse(fut);
62 }
63
64 AgentStreamState::FetchingResponse(fut) => {
65 match fut.as_mut().poll(cx) {
66 Poll::Pending => return Poll::Pending,
67 Poll::Ready((None, agent)) => {
68 this.agent = Some(agent);
69 this.state = AgentStreamState::Done;
70 return Poll::Ready(None);
71 }
72 Poll::Ready((Some(fetch), agent)) => {
73 if fetch.raw_tool_calls.is_empty() {
74 this.agent = Some(agent);
76 this.state = AgentStreamState::Done;
77 return Poll::Ready(Some(AgentResponse {
78 content: fetch.content,
79 tool_calls: vec![],
80 }));
81 } else {
82 let content = fetch.content.clone();
86
87 let raw_calls_owned = fetch.raw_tool_calls;
89
90 let preview_events: Vec<ToolCallEvent> = raw_calls_owned
92 .iter()
93 .map(|tc| ToolCallEvent {
94 id: tc.id.clone(),
95 name: tc.function.name.clone(),
96 args: serde_json::from_str(&tc.function.arguments)
97 .unwrap_or(serde_json::Value::Null),
98 result: serde_json::Value::Null,
99 })
100 .collect();
101
102 let exec_calls = raw_calls_owned.clone();
104
105 let fut = Box::pin(execute_tools(agent, exec_calls));
106 this.state = AgentStreamState::ExecutingTools(fut);
107 return Poll::Ready(Some(AgentResponse {
108 content,
109 tool_calls: preview_events,
110 }));
111 }
112 }
113 }
114 }
115
116 AgentStreamState::ExecutingTools(fut) => {
117 match fut.as_mut().poll(cx) {
118 Poll::Pending => return Poll::Pending,
119 Poll::Ready((results, agent)) => {
120 this.agent = Some(agent);
121 this.state = AgentStreamState::Idle;
123 return Poll::Ready(Some(AgentResponse {
124 content: None,
125 tool_calls: results.events,
126 }));
127 }
128 }
129 }
130 }
131 }
132 }
133}
134
135async fn fetch_response(mut agent: DeepseekAgent) -> (Option<FetchResult>, DeepseekAgent) {
137 let history = agent.conversation.history().clone();
139 let mut req = crate::api::ApiRequest::builder().messages(history);
140
141 for tool in &agent.tools {
143 for raw in tool.raw_tools() {
144 req = req.add_tool(raw);
145 }
146 }
147
148 if !agent.tools.is_empty() {
149 req = req.tool_choice_auto();
150 }
151
152 let resp = match agent.client.send(req).await {
154 Ok(r) => r,
155 Err(_) => return (None, agent),
156 };
157
158 let choice = match resp.choices.into_iter().next() {
159 Some(c) => c,
160 None => return (None, agent),
161 };
162
163 let assistant_msg = choice.message;
164 let content = assistant_msg.content.clone();
165 let raw_tool_calls = assistant_msg.tool_calls.clone().unwrap_or_default();
166
167 agent.conversation.history_mut().push(assistant_msg);
169
170 (
171 Some(FetchResult {
172 content,
173 raw_tool_calls,
174 }),
175 agent,
176 )
177}
178
179async fn execute_tools(
181 mut agent: DeepseekAgent,
182 raw_tool_calls: Vec<ToolCall>,
183) -> (ToolsResult, DeepseekAgent) {
184 let mut events = vec![];
185
186 for tc in raw_tool_calls {
187 let args: Value = serde_json::from_str(&tc.function.arguments).unwrap_or(Value::Null);
188
189 let result = match agent.tool_index.get(&tc.function.name) {
190 Some(&idx) => agent.tools[idx].call(&tc.function.name, args.clone()).await,
191 None => serde_json::json!({ "error": format!("unknown tool: {}", tc.function.name) }),
192 };
193
194 agent.conversation.history_mut().push(Message {
196 role: Role::Tool,
197 content: Some(result.to_string()),
198 tool_call_id: Some(tc.id.clone()),
199 ..Default::default()
200 });
201
202 events.push(ToolCallEvent {
203 id: tc.id,
204 name: tc.function.name,
205 args,
206 result,
207 });
208 }
209
210 (ToolsResult { events }, agent)
211}