// deepseek_rust_cli/agent/processor.rs

1use anyhow::Result;
2use colored::Colorize;
3use futures::StreamExt;
4use tokio::sync::mpsc;
5
6use crate::{
7    agent::{
8        agent::DeepSeekAgent,
9        executor::{execute_tool_cached, execute_tools_parallel},
10        types::{AgentEvent, ApprovalResult, UndoAction},
11    },
12    api::{
13        streaming::StreamParser,
14        types::{Message, ToolCall},
15    },
16    tools::schemas::get_filtered_tools_schemas,
17};
18
19impl DeepSeekAgent {
20    pub async fn chat_stream(
21        &mut self,
22        user_input: String,
23        tx: mpsc::Sender<AgentEvent>,
24        approval_rx: &mut mpsc::Receiver<ApprovalResult>,
25    ) -> Result<()> {
26        self.manage_context();
27        self.reset_cancel();
28        // Clear tool cache each request
29        self.tool_cache.clear();
30        let res = self
31            .chat_stream_inner(user_input, tx.clone(), approval_rx)
32            .await;
33
34        // If cancelled, clean up orphaned tool messages BEFORE saving,
35        // otherwise malformed history breaks subsequent API calls.
36        if self.is_cancelled() {
37            self.cleanup_aborted_messages();
38        }
39        self.save();
40
41        // If cancelled, send aborted event
42        if self.is_cancelled() {
43            let _ = tx
44                .send(AgentEvent::Aborted {
45                    token_usage: self.token_usage.clone(),
46                })
47                .await;
48        }
49
50        res
51    }
52
53    async fn chat_stream_inner(
54        &mut self,
55        user_input: String,
56        tx: mpsc::Sender<AgentEvent>,
57        approval_rx: &mut mpsc::Receiver<ApprovalResult>,
58    ) -> Result<()> {
59        tracing::info!("chat_stream_inner started, input len: {}", user_input.len());
60        if !user_input.is_empty() {
61            self.messages.push(Message {
62                role: "user".to_string(),
63                content: Some(user_input),
64                reasoning_content: None,
65                tool_calls: None,
66                tool_call_id: None,
67            });
68        }
69
70        let mut iteration = 0;
71        while iteration < self.config.max_iterations {
72            if self.is_cancelled() {
73                break;
74            }
75
76            iteration += 1;
77            tracing::info!(
78                "Starting iteration {} of {}",
79                iteration,
80                self.config.max_iterations
81            );
82            let options = crate::api::types::ChatOptions {
83                temperature: self.config.temperature,
84                top_p: self.config.top_p,
85                presence_penalty: self.config.presence_penalty,
86                frequency_penalty: self.config.frequency_penalty,
87                max_tokens: Some(self.config.max_tokens),
88                thinking_enabled: self.config.thinking_enabled,
89                reasoning_effort: self.config.reasoning_effort.clone(),
90                json_mode: self.config.json_mode,
91            };
92
93            let cancel_token = self
94                .cancel_token
95                .lock()
96                .unwrap_or_else(|e| e.into_inner())
97                .clone();
98
99            let response_res = tokio::select! {
100                res = self.client.chat_completions(
101                    &self.model,
102                    self.messages.clone(),
103                    Some(get_filtered_tools_schemas(self.is_git_repo, self.has_github_token)),
104                    options,
105                ) => res,
106                _ = cancel_token.cancelled() => {
107                    break;
108                }
109            };
110
111            let response = match response_res {
112                Ok(res) => res,
113                Err(e) => {
114                    tracing::error!("API Request Failed: {}", e);
115                    let _ = tx
116                        .send(AgentEvent::Error {
117                            content: format!("API Error: {}", e),
118                        })
119                        .await;
120                    break;
121                }
122            };
123
124            let mut full_content = String::new();
125            let mut full_reasoning = String::new();
126            let mut tool_calls: Vec<ToolCall> = Vec::new();
127
128            let mut stream = response.bytes_stream();
129            let mut parser = StreamParser::new();
130            let mut stream_error = None;
131
132            loop {
133                let item_res = tokio::select! {
134                    item = stream.next() => item,
135                    _ = cancel_token.cancelled() => {
136                        break;
137                    }
138                };
139
140                let item = match item_res {
141                    Some(item) => item,
142                    None => break,
143                };
144
145                if self.is_cancelled() {
146                    break;
147                }
148
149                match item {
150                    Ok(bytes) => {
151                        let chunks = parser.parse_chunk(&bytes);
152
153                        for chunk in chunks {
154                            if let Some(usage) = chunk.usage {
155                                self.token_usage.prompt_tokens += usage.prompt_tokens;
156                                self.token_usage.completion_tokens += usage.completion_tokens;
157                            }
158
159                            for choice in chunk.choices {
160                                if let Some(reasoning) =
161                                    choice.delta.reasoning_content.filter(|r| !r.is_empty())
162                                {
163                                    full_reasoning.push_str(&reasoning);
164                                    if tx
165                                        .send(AgentEvent::Reasoning { content: reasoning })
166                                        .await
167                                        .is_err()
168                                    {
169                                        break;
170                                    }
171                                }
172                                if let Some(content) =
173                                    choice.delta.content.filter(|c| !c.is_empty())
174                                {
175                                    full_content.push_str(&content);
176                                    if tx.send(AgentEvent::Content { content }).await.is_err() {
177                                        break;
178                                    }
179                                }
180                                if let Some(deltas) = choice.delta.tool_calls {
181                                    for delta in deltas {
182                                        while tool_calls.len() <= delta.index {
183                                            tool_calls.push(ToolCall {
184                                                id: String::new(),
185                                                r#type: "function".to_string(),
186                                                function: crate::api::types::FunctionCall {
187                                                    name: String::new(),
188                                                    arguments: String::new(),
189                                                },
190                                            });
191                                        }
192                                        let tc = &mut tool_calls[delta.index];
193                                        if let Some(id) = delta.id {
194                                            tc.id.push_str(&id);
195                                        }
196                                        if let Some(f) = delta.function {
197                                            if let Some(n) = f.name {
198                                                tc.function.name.push_str(&n);
199                                            }
200                                            if let Some(a) = f.arguments {
201                                                tc.function.arguments.push_str(&a);
202                                            }
203                                        }
204                                    }
205                                }
206                            }
207                        }
208                    }
209                    Err(e) => {
210                        stream_error = Some(format!("Stream Error: {}", e));
211                        break;
212                    }
213                }
214            }
215
216            if self.is_cancelled() {
217                break;
218            }
219
220            if let Some(err) = stream_error {
221                tracing::error!("Response Stream Error: {}", err);
222                let _ = tx.send(AgentEvent::Error { content: err }).await;
223                break;
224            }
225
226            let assistant_msg = Message {
227                role: "assistant".to_string(),
228                content: if full_content.is_empty() {
229                    None
230                } else {
231                    Some(full_content.clone())
232                },
233                reasoning_content: if full_reasoning.is_empty() {
234                    None
235                } else {
236                    Some(full_reasoning.clone())
237                },
238                tool_calls: if tool_calls.is_empty() {
239                    None
240                } else {
241                    Some(tool_calls.clone())
242                },
243                tool_call_id: None,
244            };
245            self.messages.push(assistant_msg);
246
247            if tool_calls.is_empty() {
248                break;
249            }
250
251            let mut approved_calls: Vec<(usize, &ToolCall)> = Vec::new();
252            let mut denied_results: Vec<(usize, String, String)> = Vec::new();
253
254            for (i, tc) in tool_calls.iter().enumerate() {
255                if self.is_cancelled() {
256                    break;
257                }
258                let name = tc.function.name.as_str();
259                let args: serde_json::Map<String, serde_json::Value> =
260                    serde_json::from_str(&tc.function.arguments).unwrap_or_default();
261
262                let is_traversal = crate::agent::security::is_path_traversal_arg(&args);
263                let needs_approval = ((crate::agent::security::get_approval_required_tools()
264                    .contains(name)
265                    || crate::agent::security::is_dangerous_tool(name, &args))
266                    && !self.config.debug)
267                    || is_traversal;
268
269                let (approved, always) = if needs_approval && (!self.auto_approve || is_traversal) {
270                    let approval_name = if is_traversal {
271                        format!("path_traversal_warning:{}", tc.function.name)
272                    } else {
273                        tc.function.name.clone()
274                    };
275                    if tx
276                        .send(AgentEvent::ApprovalRequest {
277                            name: approval_name,
278                            args: tc.function.arguments.clone(),
279                        })
280                        .await
281                        .is_err()
282                    {
283                        break;
284                    }
285
286                    let cancel_token = self
287                        .cancel_token
288                        .lock()
289                        .unwrap_or_else(|e| e.into_inner())
290                        .clone();
291
292                    tokio::select! {
293                        res = approval_rx.recv() => {
294                            match res {
295                                Some(ApprovalResult::Yes) => (true, false),
296                                Some(ApprovalResult::Always) => {
297                                    if is_traversal {
298                                        (true, false)
299                                    } else {
300                                        (true, true)
301                                    }
302                                }
303                                _ => (false, false),
304                            }
305                        }
306                        _ = cancel_token.cancelled() => {
307                            (false, false)
308                        }
309                    }
310                } else {
311                    (true, false)
312                };
313
314                if always {
315                    self.auto_approve = true;
316                }
317
318                if approved {
319                    approved_calls.push((i, tc));
320                } else {
321                    denied_results.push((
322                        i,
323                        tc.id.clone(),
324                        "Tool execution denied by user.".to_string(),
325                    ));
326                }
327            }
328
329            if self.is_cancelled() {
330                break;
331            }
332
333            if !approved_calls.is_empty() {
334                for (_, tc) in &approved_calls {
335                    let _ = tx
336                        .send(AgentEvent::ToolStart {
337                            name: tc.function.name.clone(),
338                            args: tc.function.arguments.clone(),
339                        })
340                        .await;
341                }
342
343                let tool_inputs: Vec<(String, serde_json::Map<String, serde_json::Value>)> =
344                    approved_calls
345                        .iter()
346                        .map(|(_, tc)| {
347                            (
348                                tc.function.name.clone(),
349                                serde_json::from_str(&tc.function.arguments).unwrap_or_default(),
350                            )
351                        })
352                        .collect();
353
354                let results: Vec<(usize, Result<String>, Vec<UndoAction>)> = if tool_inputs.len()
355                    == 1
356                {
357                    let (name, args) = tool_inputs
358                        .into_iter()
359                        .next()
360                        .expect("single tool input must exist");
361                    let mut temp_undo = Vec::new();
362
363                    if name == "execute_shell_command" {
364                        if let Some(cmd) = args.get("command").and_then(|v| v.as_str()) {
365                            if let Some(stripped) = cmd.strip_prefix("cd ") {
366                                let new_dir = stripped.trim().trim_matches('"').trim_matches('\'');
367                                let target_path = self.cwd.join(new_dir);
368                                if let Ok(validated) = crate::tools::base::validate_path(
369                                    target_path.to_str().unwrap_or("."),
370                                ) {
371                                    if validated.exists() && validated.is_dir() {
372                                        self.cwd = validated.clone();
373                                        let _ = std::env::set_current_dir(&self.cwd);
374                                    }
375                                }
376                            }
377                        }
378                    }
379
380                    let has_traversal = crate::agent::security::is_path_traversal_arg(&args);
381                    let _guard = crate::tools::base::PathTraversalGuard::new(has_traversal);
382                    let (result, _cached) = execute_tool_cached(
383                        &name,
384                        &args,
385                        &mut temp_undo,
386                        &mut self.tool_cache,
387                        Some(&self.cwd),
388                    )
389                    .await;
390                    vec![(0, result, temp_undo)]
391                } else {
392                    let has_traversal = tool_inputs
393                        .iter()
394                        .any(|(_, args)| crate::agent::security::is_path_traversal_arg(args));
395                    let _guard = crate::tools::base::PathTraversalGuard::new(has_traversal);
396                    let res = execute_tools_parallel(&tool_inputs, Some(self.cwd.clone())).await;
397                    res
398                };
399
400                for (tool_idx, result, undo_actions) in results {
401                    self.undo_stack.extend(undo_actions);
402
403                    let (_orig_idx, tc) = &approved_calls[tool_idx];
404                    let result_str = match result {
405                        Ok(res) => res,
406                        Err(e) => format!("Error: {}", e),
407                    };
408
409                    let display_result = Some(if result_str.len() > 500 {
410                        let trunc: String = result_str.chars().take(500).collect();
411                        format!(
412                            "{}\n... (truncated, {} total chars)",
413                            trunc,
414                            result_str.len()
415                        )
416                    } else {
417                        result_str.clone()
418                    });
419
420                    let _ = tx
421                        .send(AgentEvent::ToolEnd {
422                            name: tc.function.name.clone(),
423                            result: display_result,
424                        })
425                        .await;
426
427                    if self.is_cancelled() {
428                        break;
429                    }
430
431                    let mut stored_content = result_str;
432                    if stored_content.len() > self.config.max_tool_output_chars {
433                        let trunc: String = stored_content
434                            .chars()
435                            .take(self.config.max_tool_output_chars)
436                            .collect();
437                        stored_content = format!(
438                            "{}\n\n... [Output Truncated to {} chars (total {} chars) to save \
439                             tokens. Use specific tools or grep/read_local_file with line ranges \
440                             if you need to read more.] ...",
441                            trunc,
442                            self.config.max_tool_output_chars,
443                            stored_content.len()
444                        );
445                    }
446
447                    self.messages.push(Message {
448                        role: "tool".to_string(),
449                        content: Some(stored_content),
450                        reasoning_content: None,
451                        tool_calls: None,
452                        tool_call_id: Some(tc.id.clone()),
453                    });
454                }
455            }
456
457            for (_, tool_id, msg) in denied_results {
458                let _ = tx
459                    .send(AgentEvent::ToolEnd {
460                        name: "denied".to_string(),
461                        result: Some(msg.clone()),
462                    })
463                    .await;
464                self.messages.push(Message {
465                    role: "tool".to_string(),
466                    content: Some(msg),
467                    reasoning_content: None,
468                    tool_calls: None,
469                    tool_call_id: Some(tool_id),
470                });
471            }
472
473            if self.is_cancelled() {
474                break;
475            }
476        }
477
478        if self.config.show_token_usage {
479            let total = self.token_usage.prompt_tokens + self.token_usage.completion_tokens;
480            let usage_msg = format!(
481                "\n{} [{} {} | {} {} | {} {}]\n",
482                "📊 Token Usage:".bold().blue(),
483                "Prompt:".cyan(),
484                self.token_usage.prompt_tokens.to_string().cyan(),
485                "Completion:".green(),
486                self.token_usage.completion_tokens.to_string().green(),
487                "Total:".yellow(),
488                total.to_string().yellow()
489            );
490            let _ = tx.send(AgentEvent::Content { content: usage_msg }).await;
491        }
492
493        Ok(())
494    }
495}