1use crate::chat_error::ChatError;
2use crate::infra::hook::{HookContext, HookEvent, HookManager};
3use crate::message_types::{PlanDecision, StreamMsg, ToolResultMsg};
4use crate::storage::{
5 ChatMessage, DisplayHint, ImageData, MessageRole, SessionOp, SessionOpKind, ToolCallItem,
6 append_session_op,
7};
8use crate::tools::Tool;
9use crate::tools::compact_tool::CompactTool;
10use crate::util::log::write_info_log;
11use crate::util::safe_lock;
12use std::collections::HashSet;
13use std::env::current_dir;
14use std::mem::take;
15use std::sync::{Arc, Mutex, mpsc};
16use std::time::{SystemTime, UNIX_EPOCH};
17
18pub(super) struct ToolCallContext<'a> {
20 pub(super) stream_msg_sender: &'a mpsc::Sender<StreamMsg>,
21 pub(super) tool_result_receiver: &'a mpsc::Receiver<ToolResultMsg>,
22 pub(super) pending_user_messages: &'a Arc<Mutex<Vec<ChatMessage>>>,
23 pub(super) hook_manager: &'a HookManager,
24 pub(super) disabled_hooks: &'a [String],
25 pub(super) supports_vision: bool,
26 pub(super) display_messages: &'a Arc<Mutex<Vec<ChatMessage>>>,
28 pub(super) context_messages: &'a Arc<Mutex<Vec<ChatMessage>>>,
30 pub(super) streaming_content: &'a Arc<Mutex<String>>,
31 pub(super) session_id: &'a str,
32}
33
34pub(super) struct ToolCallResult {
36 pub(super) compact_requested: bool,
37 pub(super) plan_with_context_clear: Option<String>,
39}
40
41pub(super) fn drain_pending_user_messages(
43 messages: &mut Vec<ChatMessage>,
44 pending_user_messages: &Arc<Mutex<Vec<ChatMessage>>>,
45) {
46 let mut pending = safe_lock(pending_user_messages, "agent::drain_pending");
47 if !pending.is_empty() {
48 for msg in pending.drain(..) {
52 if msg.role == MessageRole::User
53 && msg.content.trim_start().starts_with("<system_reminder>")
54 {
55 continue;
57 }
58 let mut msg = msg;
59 if msg.role == MessageRole::User {
60 msg.content = format!("[User appended] {}", msg.content);
61 }
62 messages.push(msg);
63 }
64 }
65}
66
67pub(super) fn push_both(
79 display: &Arc<Mutex<Vec<ChatMessage>>>,
80 context: &Arc<Mutex<Vec<ChatMessage>>>,
81 msg: ChatMessage,
82) {
83 if let Ok(mut msgs) = display.lock() {
84 msgs.push(msg.clone());
85 }
86 if let Ok(mut msgs) = context.lock() {
87 msgs.push(msg);
88 }
89}
90
91pub(super) fn clear_channels(
93 display: &Arc<Mutex<Vec<ChatMessage>>>,
94 context: &Arc<Mutex<Vec<ChatMessage>>>,
95) {
96 if let Ok(mut msgs) = display.lock() {
97 msgs.clear();
98 }
99 if let Ok(mut msgs) = context.lock() {
100 msgs.clear();
101 }
102}
103
104pub(super) fn sync_context_full(
106 display: &Arc<Mutex<Vec<ChatMessage>>>,
107 context: &Arc<Mutex<Vec<ChatMessage>>>,
108 new_messages: &[ChatMessage],
109) {
110 if let Ok(mut msgs) = context.lock() {
111 msgs.clear();
112 msgs.extend_from_slice(new_messages);
113 }
114 if let Ok(mut msgs) = display.lock() {
115 msgs.clear();
116 msgs.extend_from_slice(new_messages);
117 }
118}
119
120pub(super) fn flush_streaming_as_message(
123 streaming_content: &Arc<Mutex<String>>,
124 streaming_reasoning_content: &Arc<Mutex<String>>,
125 messages: &mut Vec<ChatMessage>,
126 display: &Arc<Mutex<Vec<ChatMessage>>>,
127 context: &Arc<Mutex<Vec<ChatMessage>>>,
128 reasoning_content: Option<String>,
129) {
130 let mut stream_buf = safe_lock(streaming_content, "agent::flush_streaming");
131 if !stream_buf.is_empty() {
132 let mut text_msg = ChatMessage::text(MessageRole::Assistant, take(&mut *stream_buf));
133 text_msg.reasoning_content = reasoning_content;
134 messages.push(text_msg.clone());
135 push_both(display, context, text_msg);
136 }
137 safe_lock(
139 streaming_reasoning_content,
140 "agent::flush_streaming_reasoning",
141 )
142 .clear();
143}
144
145fn log_tool_request(tool_items: &[ToolCallItem]) {
147 let mut log_content = String::new();
148 for item in tool_items {
149 log_content.push_str(&format!("- {}: {}\n", item.name, item.arguments));
150 }
151 write_info_log("工具调用请求", &log_content);
152}
153
154fn log_tool_results(tool_items: &[ToolCallItem], tool_results: &[ToolResultMsg]) {
156 let mut log_content = String::new();
157 for (i, result) in tool_results.iter().enumerate() {
158 let (tool_name, tool_args) = tool_items
159 .get(i)
160 .map(|t| (t.name.as_str(), t.arguments.as_str()))
161 .unwrap_or(("unknown", ""));
162 log_content.push_str(&format!(
163 "- [{}] {}({}): {}\n",
164 result.tool_call_id, tool_name, tool_args, result.result
165 ));
166 }
167 write_info_log("工具调用结果", &log_content);
168}
169
170pub(super) fn process_tool_calls(
175 tool_items: Vec<ToolCallItem>,
176 assistant_text: String,
177 messages: &mut Vec<ChatMessage>,
178 ctx: &ToolCallContext<'_>,
179 reasoning_content: Option<String>,
180) -> Result<ToolCallResult, ChatError> {
181 log_tool_request(&tool_items);
182
183 if !assistant_text.is_empty() {
184 write_info_log("Sprite 回复", &assistant_text);
185 }
186
187 let compact_requested = tool_items.iter().any(|t| t.name == CompactTool {}.name());
189
190 let tool_call_msg = ChatMessage {
193 role: MessageRole::Assistant,
194 content: assistant_text,
195 tool_calls: Some(tool_items.clone()),
196 tool_call_id: None,
197 images: None,
198 reasoning_content,
199 sender_name: None,
200 recipient_name: None,
201 display_hint: DisplayHint::Normal,
202 };
203 messages.push(tool_call_msg.clone());
204 push_both(ctx.display_messages, ctx.context_messages, tool_call_msg);
205 if let Ok(mut stream_buf) = ctx.streaming_content.lock() {
207 stream_buf.clear();
208 }
209
210 if ctx
211 .stream_msg_sender
212 .send(StreamMsg::ToolCallRequest(tool_items.clone()))
213 .is_err()
214 {
215 return Err(ChatError::Other("工具调用通道已断开".to_string()));
216 }
217
218 let mut tool_results: Vec<ToolResultMsg> = Vec::with_capacity(tool_items.len());
219 let mut plan_clear_context: Option<String> = None;
220 let mut channel_broken = false;
221 for _ in &tool_items {
222 if channel_broken {
223 break;
224 }
225 match ctx.tool_result_receiver.recv() {
226 Ok(result) => {
227 if result.plan_decision == PlanDecision::ApproveAndClearContext {
229 plan_clear_context = Some(result.result.clone());
230 }
231 tool_results.push(result);
232 }
233 Err(_) => {
234 channel_broken = true;
235 }
236 }
237 }
238
239 let received_ids: HashSet<String> = tool_results
242 .iter()
243 .map(|r| r.tool_call_id.clone())
244 .collect();
245 for item in &tool_items {
246 if !received_ids.contains(&item.id) {
247 let reason = if channel_broken {
248 "[工具执行中断: 结果通道已断开]"
249 } else {
250 "[工具执行中断: 未收到结果]"
251 };
252 tool_results.push(ToolResultMsg {
253 tool_call_id: item.id.clone(),
254 result: reason.to_string(),
255 is_error: true,
256 images: Vec::new(),
257 plan_decision: PlanDecision::None,
258 });
259 }
260 }
261
262 log_tool_results(&tool_items, &tool_results);
263
264 append_write_ops(&tool_items, &tool_results, ctx.session_id);
266
267 let mut deferred_image_msgs: Vec<ChatMessage> = Vec::with_capacity(tool_results.len());
271
272 for result in tool_results {
273 let mut result_content = result.result;
274 let result_images = result.images;
275
276 let tool_name = tool_items
278 .iter()
279 .find(|t| t.id == result.tool_call_id)
280 .map(|t| t.name.clone());
281
282 if ctx.hook_manager.has_hooks_for(HookEvent::PostToolExecution) {
284 let hook_ctx = HookContext {
285 event: HookEvent::PostToolExecution,
286 tool_name: tool_name.clone(),
287 tool_result: Some(result_content.clone()),
288 session_id: Some(ctx.session_id.to_string()),
289 cwd: current_dir()
290 .map(|p| p.display().to_string())
291 .unwrap_or_else(|_| ".".to_string()),
292 ..Default::default()
293 };
294 if let Some(hook_result) =
295 ctx.hook_manager
296 .execute(HookEvent::PostToolExecution, hook_ctx, ctx.disabled_hooks)
297 && let Some(new_result) = hook_result.tool_result
298 {
299 result_content = new_result;
300 }
301 }
302
303 let tool_msg = ChatMessage {
304 role: MessageRole::Tool,
305 content: result_content,
306 tool_calls: None,
307 tool_call_id: Some(result.tool_call_id.clone()),
308 images: None,
309 reasoning_content: None,
310 sender_name: None,
311 recipient_name: None,
312 display_hint: DisplayHint::Normal,
313 };
314 messages.push(tool_msg.clone());
315 push_both(ctx.display_messages, ctx.context_messages, tool_msg);
316
317 if !result_images.is_empty() {
319 let tool_label = tool_name.as_deref().unwrap_or("unknown");
320 let img_count = result_images.len();
321 write_info_log(
322 "ImageInjection",
323 &format!(
324 "工具 {} 返回了 {} 张图片, supports_vision={}",
325 tool_label, img_count, ctx.supports_vision
326 ),
327 );
328 if ctx.supports_vision {
329 let img_msg = ChatMessage {
330 role: MessageRole::User,
331 content: format!(
332 "[{tool_label} 返回了 {img_count} 张图片,请查看图片内容并继续帮助完成任务]"
333 ),
334 tool_calls: None,
335 tool_call_id: None,
336 images: Some(
337 result_images
338 .into_iter()
339 .map(|img| ImageData {
340 base64: img.base64,
341 media_type: img.media_type,
342 })
343 .collect(),
344 ),
345 reasoning_content: None,
346 sender_name: None,
347 recipient_name: None,
348 display_hint: DisplayHint::Normal,
349 };
350 deferred_image_msgs.push(img_msg);
351 } else {
352 write_info_log(
353 "ImageInjection",
354 &format!(
355 "supports_vision=false,丢弃 {} 返回的 {} 张图片",
356 tool_label, img_count
357 ),
358 );
359 }
360 }
361 }
362
363 if !deferred_image_msgs.is_empty() {
365 write_info_log(
366 "ImageInjection",
367 &format!(
368 "在所有 tool results 之后注入 {} 条图片消息",
369 deferred_image_msgs.len()
370 ),
371 );
372 for img_msg in deferred_image_msgs {
373 messages.push(img_msg);
375 }
376 }
377
378 drain_pending_user_messages(messages, ctx.pending_user_messages);
379
380 Ok(ToolCallResult {
381 compact_requested,
382 plan_with_context_clear: plan_clear_context,
383 })
384}
385
386fn extract_path_from_args(args: &str) -> Option<String> {
388 serde_json::from_str::<serde_json::Value>(args)
389 .ok()
390 .and_then(|v| v.get("path")?.as_str().map(String::from))
391}
392
393fn extract_command_from_args(args: &str) -> Option<String> {
395 serde_json::from_str::<serde_json::Value>(args)
396 .ok()
397 .and_then(|v| v.get("command")?.as_str().map(String::from))
398}
399
400fn append_write_ops(tool_items: &[ToolCallItem], tool_results: &[ToolResultMsg], session_id: &str) {
402 let now_ms = SystemTime::now()
403 .duration_since(UNIX_EPOCH)
404 .unwrap_or_default()
405 .as_millis() as u64;
406
407 for item in tool_items {
408 let is_error = tool_results
409 .iter()
410 .any(|r| r.tool_call_id == item.id && r.is_error);
411
412 let op_kind = match item.name.as_str() {
413 "Edit" => {
414 extract_path_from_args(&item.arguments).map(|path| SessionOpKind::Edit { path })
415 }
416 "Write" => {
417 extract_path_from_args(&item.arguments).map(|path| SessionOpKind::Write { path })
418 }
419 "Shell" => extract_command_from_args(&item.arguments)
420 .map(|cmd| SessionOpKind::Bash { command: cmd }),
421 _ => None,
422 };
423
424 if let Some(op) = op_kind {
425 let _ = append_session_op(
426 session_id,
427 &SessionOp {
428 op,
429 timestamp_ms: now_ms,
430 is_error,
431 },
432 );
433 }
434 }
435}