1use anyhow::Result;
2use colored::Colorize;
3use futures::StreamExt;
4use tokio::sync::mpsc;
5
6use crate::{
7 agent::{
8 agent::DeepSeekAgent,
9 executor::{execute_tool_cached, execute_tools_parallel},
10 types::{AgentEvent, ApprovalResult, UndoAction},
11 },
12 api::{
13 streaming::StreamParser,
14 types::{Message, ToolCall},
15 },
16 tools::schemas::get_filtered_tools_schemas,
17};
18
19impl DeepSeekAgent {
20 pub async fn chat_stream(
21 &mut self,
22 user_input: String,
23 tx: mpsc::Sender<AgentEvent>,
24 approval_rx: &mut mpsc::Receiver<ApprovalResult>,
25 ) -> Result<()> {
26 self.manage_context();
27 self.reset_cancel();
28 self.tool_cache.clear();
30 let res = self
31 .chat_stream_inner(user_input, tx.clone(), approval_rx)
32 .await;
33
34 if self.is_cancelled() {
37 self.cleanup_aborted_messages();
38 }
39 self.save();
40
41 if self.is_cancelled() {
43 let _ = tx
44 .send(AgentEvent::Aborted {
45 token_usage: self.token_usage.clone(),
46 })
47 .await;
48 }
49
50 res
51 }
52
53 async fn chat_stream_inner(
54 &mut self,
55 user_input: String,
56 tx: mpsc::Sender<AgentEvent>,
57 approval_rx: &mut mpsc::Receiver<ApprovalResult>,
58 ) -> Result<()> {
59 tracing::info!("chat_stream_inner started, input len: {}", user_input.len());
60 if !user_input.is_empty() {
61 self.messages.push(Message {
62 role: "user".to_string(),
63 content: Some(user_input),
64 reasoning_content: None,
65 tool_calls: None,
66 tool_call_id: None,
67 });
68 }
69
70 let mut iteration = 0;
71 while iteration < self.config.max_iterations {
72 if self.is_cancelled() {
73 break;
74 }
75
76 iteration += 1;
77 tracing::info!(
78 "Starting iteration {} of {}",
79 iteration,
80 self.config.max_iterations
81 );
82 let options = crate::api::types::ChatOptions {
83 temperature: self.config.temperature,
84 top_p: self.config.top_p,
85 presence_penalty: self.config.presence_penalty,
86 frequency_penalty: self.config.frequency_penalty,
87 max_tokens: Some(self.config.max_tokens),
88 thinking_enabled: self.config.thinking_enabled,
89 reasoning_effort: self.config.reasoning_effort.clone(),
90 json_mode: self.config.json_mode,
91 };
92
93 let cancel_token = self
94 .cancel_token
95 .lock()
96 .unwrap_or_else(|e| e.into_inner())
97 .clone();
98
99 let response_res = tokio::select! {
100 res = self.client.chat_completions(
101 &self.model,
102 self.messages.clone(),
103 Some(get_filtered_tools_schemas(self.is_git_repo, self.has_github_token)),
104 options,
105 ) => res,
106 _ = cancel_token.cancelled() => {
107 break;
108 }
109 };
110
111 let response = match response_res {
112 Ok(res) => res,
113 Err(e) => {
114 tracing::error!("API Request Failed: {}", e);
115 let _ = tx
116 .send(AgentEvent::Error {
117 content: format!("API Error: {}", e),
118 })
119 .await;
120 break;
121 }
122 };
123
124 let mut full_content = String::new();
125 let mut full_reasoning = String::new();
126 let mut tool_calls: Vec<ToolCall> = Vec::new();
127
128 let mut stream = response.bytes_stream();
129 let mut parser = StreamParser::new();
130 let mut stream_error = None;
131
132 loop {
133 let item_res = tokio::select! {
134 item = stream.next() => item,
135 _ = cancel_token.cancelled() => {
136 break;
137 }
138 };
139
140 let item = match item_res {
141 Some(item) => item,
142 None => break,
143 };
144
145 if self.is_cancelled() {
146 break;
147 }
148
149 match item {
150 Ok(bytes) => {
151 let chunks = parser.parse_chunk(&bytes);
152
153 for chunk in chunks {
154 if let Some(usage) = chunk.usage {
155 self.token_usage.prompt_tokens += usage.prompt_tokens;
156 self.token_usage.completion_tokens += usage.completion_tokens;
157 }
158
159 for choice in chunk.choices {
160 if let Some(reasoning) =
161 choice.delta.reasoning_content.filter(|r| !r.is_empty())
162 {
163 full_reasoning.push_str(&reasoning);
164 if tx
165 .send(AgentEvent::Reasoning { content: reasoning })
166 .await
167 .is_err()
168 {
169 break;
170 }
171 }
172 if let Some(content) =
173 choice.delta.content.filter(|c| !c.is_empty())
174 {
175 full_content.push_str(&content);
176 if tx.send(AgentEvent::Content { content }).await.is_err() {
177 break;
178 }
179 }
180 if let Some(deltas) = choice.delta.tool_calls {
181 for delta in deltas {
182 while tool_calls.len() <= delta.index {
183 tool_calls.push(ToolCall {
184 id: String::new(),
185 r#type: "function".to_string(),
186 function: crate::api::types::FunctionCall {
187 name: String::new(),
188 arguments: String::new(),
189 },
190 });
191 }
192 let tc = &mut tool_calls[delta.index];
193 if let Some(id) = delta.id {
194 tc.id.push_str(&id);
195 }
196 if let Some(f) = delta.function {
197 if let Some(n) = f.name {
198 tc.function.name.push_str(&n);
199 }
200 if let Some(a) = f.arguments {
201 tc.function.arguments.push_str(&a);
202 }
203 }
204 }
205 }
206 }
207 }
208 }
209 Err(e) => {
210 stream_error = Some(format!("Stream Error: {}", e));
211 break;
212 }
213 }
214 }
215
216 if self.is_cancelled() {
217 break;
218 }
219
220 if let Some(err) = stream_error {
221 tracing::error!("Response Stream Error: {}", err);
222 let _ = tx.send(AgentEvent::Error { content: err }).await;
223 break;
224 }
225
226 let assistant_msg = Message {
227 role: "assistant".to_string(),
228 content: if full_content.is_empty() {
229 None
230 } else {
231 Some(full_content.clone())
232 },
233 reasoning_content: if full_reasoning.is_empty() {
234 None
235 } else {
236 Some(full_reasoning.clone())
237 },
238 tool_calls: if tool_calls.is_empty() {
239 None
240 } else {
241 Some(tool_calls.clone())
242 },
243 tool_call_id: None,
244 };
245 self.messages.push(assistant_msg);
246
247 if tool_calls.is_empty() {
248 break;
249 }
250
251 let mut approved_calls: Vec<(usize, &ToolCall)> = Vec::new();
252 let mut denied_results: Vec<(usize, String, String)> = Vec::new();
253
254 for (i, tc) in tool_calls.iter().enumerate() {
255 if self.is_cancelled() {
256 break;
257 }
258 let name = tc.function.name.as_str();
259 let args: serde_json::Map<String, serde_json::Value> =
260 serde_json::from_str(&tc.function.arguments).unwrap_or_default();
261
262 let is_traversal = crate::agent::security::is_path_traversal_arg(&args);
263 let needs_approval = ((crate::agent::security::get_approval_required_tools()
264 .contains(name)
265 || crate::agent::security::is_dangerous_tool(name, &args))
266 && !self.config.debug)
267 || is_traversal;
268
269 let (approved, always) = if needs_approval && (!self.auto_approve || is_traversal) {
270 let approval_name = if is_traversal {
271 format!("path_traversal_warning:{}", tc.function.name)
272 } else {
273 tc.function.name.clone()
274 };
275 if tx
276 .send(AgentEvent::ApprovalRequest {
277 name: approval_name,
278 args: tc.function.arguments.clone(),
279 })
280 .await
281 .is_err()
282 {
283 break;
284 }
285
286 let cancel_token = self
287 .cancel_token
288 .lock()
289 .unwrap_or_else(|e| e.into_inner())
290 .clone();
291
292 tokio::select! {
293 res = approval_rx.recv() => {
294 match res {
295 Some(ApprovalResult::Yes) => (true, false),
296 Some(ApprovalResult::Always) => {
297 if is_traversal {
298 (true, false)
299 } else {
300 (true, true)
301 }
302 }
303 _ => (false, false),
304 }
305 }
306 _ = cancel_token.cancelled() => {
307 (false, false)
308 }
309 }
310 } else {
311 (true, false)
312 };
313
314 if always {
315 self.auto_approve = true;
316 }
317
318 if approved {
319 approved_calls.push((i, tc));
320 } else {
321 denied_results.push((
322 i,
323 tc.id.clone(),
324 "Tool execution denied by user.".to_string(),
325 ));
326 }
327 }
328
329 if self.is_cancelled() {
330 break;
331 }
332
333 if !approved_calls.is_empty() {
334 for (_, tc) in &approved_calls {
335 let _ = tx
336 .send(AgentEvent::ToolStart {
337 name: tc.function.name.clone(),
338 args: tc.function.arguments.clone(),
339 })
340 .await;
341 }
342
343 let tool_inputs: Vec<(String, serde_json::Map<String, serde_json::Value>)> =
344 approved_calls
345 .iter()
346 .map(|(_, tc)| {
347 (
348 tc.function.name.clone(),
349 serde_json::from_str(&tc.function.arguments).unwrap_or_default(),
350 )
351 })
352 .collect();
353
354 let results: Vec<(usize, Result<String>, Vec<UndoAction>)> = if tool_inputs.len()
355 == 1
356 {
357 let (name, args) = tool_inputs
358 .into_iter()
359 .next()
360 .expect("single tool input must exist");
361 let mut temp_undo = Vec::new();
362
363 if name == "execute_shell_command" {
364 if let Some(cmd) = args.get("command").and_then(|v| v.as_str()) {
365 if let Some(stripped) = cmd.strip_prefix("cd ") {
366 let new_dir = stripped.trim().trim_matches('"').trim_matches('\'');
367 let target_path = self.cwd.join(new_dir);
368 if let Ok(validated) = crate::tools::base::validate_path(
369 target_path.to_str().unwrap_or("."),
370 ) {
371 if validated.exists() && validated.is_dir() {
372 self.cwd = validated.clone();
373 let _ = std::env::set_current_dir(&self.cwd);
374 }
375 }
376 }
377 }
378 }
379
380 let has_traversal = crate::agent::security::is_path_traversal_arg(&args);
381 let _guard = crate::tools::base::PathTraversalGuard::new(has_traversal);
382 let (result, _cached) = execute_tool_cached(
383 &name,
384 &args,
385 &mut temp_undo,
386 &mut self.tool_cache,
387 Some(&self.cwd),
388 )
389 .await;
390 vec![(0, result, temp_undo)]
391 } else {
392 let has_traversal = tool_inputs
393 .iter()
394 .any(|(_, args)| crate::agent::security::is_path_traversal_arg(args));
395 let _guard = crate::tools::base::PathTraversalGuard::new(has_traversal);
396 let res = execute_tools_parallel(&tool_inputs, Some(self.cwd.clone())).await;
397 res
398 };
399
400 for (tool_idx, result, undo_actions) in results {
401 self.undo_stack.extend(undo_actions);
402
403 let (_orig_idx, tc) = &approved_calls[tool_idx];
404 let result_str = match result {
405 Ok(res) => res,
406 Err(e) => format!("Error: {}", e),
407 };
408
409 let display_result = Some(if result_str.len() > 500 {
410 let trunc: String = result_str.chars().take(500).collect();
411 format!(
412 "{}\n... (truncated, {} total chars)",
413 trunc,
414 result_str.len()
415 )
416 } else {
417 result_str.clone()
418 });
419
420 let _ = tx
421 .send(AgentEvent::ToolEnd {
422 name: tc.function.name.clone(),
423 result: display_result,
424 })
425 .await;
426
427 if self.is_cancelled() {
428 break;
429 }
430
431 let mut stored_content = result_str;
432 if stored_content.len() > self.config.max_tool_output_chars {
433 let trunc: String = stored_content
434 .chars()
435 .take(self.config.max_tool_output_chars)
436 .collect();
437 stored_content = format!(
438 "{}\n\n... [Output Truncated to {} chars (total {} chars) to save \
439 tokens. Use specific tools or grep/read_local_file with line ranges \
440 if you need to read more.] ...",
441 trunc,
442 self.config.max_tool_output_chars,
443 stored_content.len()
444 );
445 }
446
447 self.messages.push(Message {
448 role: "tool".to_string(),
449 content: Some(stored_content),
450 reasoning_content: None,
451 tool_calls: None,
452 tool_call_id: Some(tc.id.clone()),
453 });
454 }
455 }
456
457 for (_, tool_id, msg) in denied_results {
458 let _ = tx
459 .send(AgentEvent::ToolEnd {
460 name: "denied".to_string(),
461 result: Some(msg.clone()),
462 })
463 .await;
464 self.messages.push(Message {
465 role: "tool".to_string(),
466 content: Some(msg),
467 reasoning_content: None,
468 tool_calls: None,
469 tool_call_id: Some(tool_id),
470 });
471 }
472
473 if self.is_cancelled() {
474 break;
475 }
476 }
477
478 if self.config.show_token_usage {
479 let total = self.token_usage.prompt_tokens + self.token_usage.completion_tokens;
480 let usage_msg = format!(
481 "\n{} [{} {} | {} {} | {} {}]\n",
482 "📊 Token Usage:".bold().blue(),
483 "Prompt:".cyan(),
484 self.token_usage.prompt_tokens.to_string().cyan(),
485 "Completion:".green(),
486 self.token_usage.completion_tokens.to_string().green(),
487 "Total:".yellow(),
488 total.to_string().yellow()
489 );
490 let _ = tx.send(AgentEvent::Content { content: usage_msg }).await;
491 }
492
493 Ok(())
494 }
495}