Skip to main content

j_agent/agent/
api.rs

1use crate::chat_error::ChatError;
2use crate::llm::{
3    ChatRequest, Content, ContentPart, FunctionCall, ImageUrl, LlmClient, Message, Role,
4    TokenUsage, ToolCall, ToolDefinition,
5};
6use crate::storage::{ChatMessage, MessageRole, ModelProvider, ToolCallItem};
7use crate::util::log::{write_error_log, write_info_log};
8use futures::StreamExt;
9use std::collections::HashSet;
10
11/// 根据 ModelProvider 配置创建 LlmClient
12pub fn create_llm_client(provider: &ModelProvider) -> LlmClient {
13    LlmClient::new(&provider.api_base, &provider.api_key)
14}
15
16/// 将内部 ChatMessage 转换为 llm::Message 格式
17pub fn to_llm_messages(messages: &[ChatMessage]) -> Vec<Message> {
18    messages
19        .iter()
20        .filter_map(|msg| match msg.role {
21            MessageRole::System => Some(Message {
22                role: Role::System,
23                content: Some(Content::Text(msg.content.clone())),
24                name: None,
25                tool_calls: None,
26                tool_call_id: None,
27                reasoning_content: None,
28            }),
29            MessageRole::User => {
30                if let Some(ref images) = msg.images
31                    && !images.is_empty()
32                {
33                    // 多模态消息:Text + ImageUrl(s)
34                    write_info_log(
35                        "to_llm_messages",
36                        &format!(
37                            "构建多模态 user 消息: text_len={}, images_count={}",
38                            msg.content.len(),
39                            images.len()
40                        ),
41                    );
42                    let mut parts = vec![ContentPart::Text {
43                        text: msg.content.clone(),
44                    }];
45                    for img in images {
46                        let data_url = format!("data:{};base64,{}", img.media_type, img.base64);
47                        parts.push(ContentPart::ImageUrl {
48                            image_url: ImageUrl {
49                                url: data_url,
50                                detail: None,
51                            },
52                        });
53                    }
54                    return Some(Message {
55                        role: Role::User,
56                        content: Some(Content::Parts(parts)),
57                        name: None,
58                        tool_calls: None,
59                        tool_call_id: None,
60                        reasoning_content: None,
61                    });
62                }
63                // 纯文本消息
64                Some(Message {
65                    role: Role::User,
66                    content: Some(Content::Text(msg.content.clone())),
67                    name: None,
68                    tool_calls: None,
69                    tool_call_id: None,
70                    reasoning_content: None,
71                })
72            }
73            MessageRole::Assistant => {
74                let content = if msg.content.is_empty() {
75                    None
76                } else {
77                    Some(Content::Text(msg.content.clone()))
78                };
79                let tool_calls = msg.tool_calls.as_ref().map(|tcs| {
80                    tcs.iter()
81                        .map(|tc| ToolCall {
82                            id: tc.id.clone(),
83                            call_type: "function".to_string(),
84                            function: FunctionCall {
85                                name: tc.name.clone(),
86                                arguments: tc.arguments.clone(),
87                            },
88                        })
89                        .collect()
90                });
91                Some(Message {
92                    role: Role::Assistant,
93                    content,
94                    name: None,
95                    tool_calls,
96                    tool_call_id: None,
97                    reasoning_content: msg.reasoning_content.clone(),
98                })
99            }
100            MessageRole::Tool => {
101                let tool_call_id = msg.tool_call_id.clone().unwrap_or_default();
102                if tool_call_id.is_empty() {
103                    write_error_log(
104                        "to_llm_messages",
105                        "跳过 tool_call_id 为空的 tool 消息(旧历史或异常消息),避免 API 报错",
106                    );
107                    return None;
108                }
109                Some(Message {
110                    role: Role::Tool,
111                    content: Some(Content::Text(msg.content.clone())),
112                    name: None,
113                    tool_calls: None,
114                    tool_call_id: Some(tool_call_id),
115                    reasoning_content: None,
116                })
117            }
118        })
119        .collect()
120}
121
122/// 预处理消息数组,保证 assistant tool_calls ↔ tool result 双向配对完整,
123/// 避免 API 报 "tool_call_id not found" 或 "missing tool result" 错误。
124pub fn sanitize_messages(messages: &[ChatMessage]) -> Vec<ChatMessage> {
125    let tool_result_ids: HashSet<String> = messages
126        .iter()
127        .filter(|m| m.role == MessageRole::Tool)
128        .filter_map(|m| m.tool_call_id.clone())
129        .filter(|id| !id.is_empty())
130        .collect();
131
132    let assistant_tool_call_ids: HashSet<String> = messages
133        .iter()
134        .filter(|m| m.role == MessageRole::Assistant)
135        .flat_map(|m| {
136            m.tool_calls
137                .iter()
138                .flatten()
139                .filter(|tc| !tc.id.is_empty())
140                .map(|tc| tc.id.clone())
141        })
142        .collect();
143
144    let mut removed_count = 0usize;
145    let result: Vec<ChatMessage> = messages
146        .iter()
147        .filter_map(|msg| {
148            if msg.role == MessageRole::Tool {
149                let id = msg.tool_call_id.as_deref().unwrap_or("");
150                if id.is_empty() || !assistant_tool_call_ids.contains(id) {
151                    write_error_log(
152                        "sanitize_messages",
153                        &format!(
154                            "移除孤立 tool result tool_call_id={:?}(在 assistant tool_calls 中无对应项)",
155                            msg.tool_call_id
156                        ),
157                    );
158                    removed_count += 1;
159                    return None;
160                }
161            }
162            if msg.role == MessageRole::Assistant
163                && let Some(ref tool_calls) = msg.tool_calls
164            {
165                let valid_tool_calls: Vec<_> = tool_calls
166                    .iter()
167                    .filter(|tool_call| {
168                        !tool_call.id.is_empty() && tool_result_ids.contains(&tool_call.id)
169                    })
170                    .cloned()
171                    .collect();
172                if valid_tool_calls.len() != tool_calls.len() {
173                    let dropped = tool_calls.len() - valid_tool_calls.len();
174                    write_error_log(
175                        "sanitize_messages",
176                        &format!(
177                            "assistant tool_calls 中 {} 个条目无对应 tool result,已移除",
178                            dropped
179                        ),
180                    );
181                    removed_count += dropped;
182                    let mut sanitized_msg = msg.clone();
183                    sanitized_msg.tool_calls = if valid_tool_calls.is_empty() {
184                        None
185                    } else {
186                        Some(valid_tool_calls)
187                    };
188                    return Some(sanitized_msg);
189                }
190            }
191            Some(msg.clone())
192        })
193        .collect();
194
195    if removed_count > 0 {
196        write_info_log(
197            "sanitize_messages",
198            &format!("共清理 {} 个孤立/无效 tool_call 相关条目", removed_count),
199        );
200    }
201    result
202}
203
204/// 后置验证:确保转换后的消息中 tool_call_id 双向一致。
205fn sanitize_llm_messages(messages: &mut Vec<Message>) {
206    let assistant_tool_call_ids: HashSet<String> = messages
207        .iter()
208        .filter(|m| m.role == Role::Assistant)
209        .flat_map(|m| m.tool_calls.iter().flatten().map(|tc| tc.id.clone()))
210        .filter(|id| !id.is_empty())
211        .collect();
212
213    let tool_result_ids: HashSet<String> = messages
214        .iter()
215        .filter(|m| m.role == Role::Tool)
216        .filter_map(|m| m.tool_call_id.clone())
217        .filter(|id| !id.is_empty())
218        .collect();
219
220    let original_len = messages.len();
221
222    messages.retain(|m| {
223        if m.role == Role::Tool {
224            let id = m.tool_call_id.as_deref().unwrap_or("");
225            if !assistant_tool_call_ids.contains(id) {
226                write_error_log(
227                    "sanitize_llm_messages",
228                    &format!(
229                        "移除孤立 tool result (tool_call_id={}):在 assistant tool_calls 中无对应项",
230                        id
231                    ),
232                );
233                return false;
234            }
235        }
236        true
237    });
238
239    for msg in messages.iter_mut() {
240        if msg.role == Role::Assistant
241            && let Some(ref mut tool_calls) = msg.tool_calls
242        {
243            let before = tool_calls.len();
244            tool_calls.retain(|tc| tc.id.is_empty() || tool_result_ids.contains(&tc.id));
245            if tool_calls.len() != before {
246                write_error_log(
247                    "sanitize_llm_messages",
248                    &format!(
249                        "assistant tool_calls 中 {} 个条目无对应 tool result,已移除",
250                        before - tool_calls.len()
251                    ),
252                );
253            }
254            if tool_calls.is_empty() {
255                msg.tool_calls = None;
256            }
257        }
258    }
259
260    let removed_count = original_len - messages.len();
261    if removed_count > 0 {
262        write_info_log(
263            "sanitize_llm_messages",
264            &format!("后置验证:共移除 {} 条孤立消息", removed_count),
265        );
266    }
267}
268
269/// 构建带工具定义的请求
270pub fn build_request_with_tools(
271    provider: &ModelProvider,
272    messages: &[ChatMessage],
273    tools: Vec<ToolDefinition>,
274    system_prompt: Option<&str>,
275) -> Result<ChatRequest, ChatError> {
276    let sanitized_messages = sanitize_messages(messages);
277    let mut llm_messages = Vec::with_capacity(sanitized_messages.len() + 1);
278    if let Some(system_prompt_text) = system_prompt {
279        let trimmed = system_prompt_text.trim();
280        if !trimmed.is_empty() {
281            llm_messages.push(Message {
282                role: Role::System,
283                content: Some(Content::Text(trimmed.to_string())),
284                name: None,
285                tool_calls: None,
286                tool_call_id: None,
287                reasoning_content: None,
288            });
289        }
290    }
291    llm_messages.extend(to_llm_messages(&sanitized_messages));
292
293    // debug: 检查是否有 reasoning_content 被传递
294    for (i, msg) in llm_messages.iter().enumerate() {
295        if msg.reasoning_content.is_some() {
296            write_info_log(
297                "build_request_with_tools",
298                &format!(
299                    "消息[{}] role={:?} 携带 reasoning_content (len={})",
300                    i,
301                    msg.role,
302                    msg.reasoning_content.as_ref().map(|s| s.len()).unwrap_or(0)
303                ),
304            );
305        }
306    }
307
308    sanitize_llm_messages(&mut llm_messages);
309
310    Ok(ChatRequest {
311        model: provider.model.clone(),
312        messages: llm_messages,
313        tools: if tools.is_empty() { None } else { Some(tools) },
314        stream: None,
315        max_tokens: None,
316        extra: serde_json::Map::new(),
317    })
318}
319
320/// 流式调用 API,通过回调逐步输出,返回完整的助手回复内容
321pub async fn call_llm_stream_async(
322    provider: &ModelProvider,
323    messages: &[ChatMessage],
324    system_prompt: Option<&str>,
325    on_chunk: &mut dyn FnMut(&str),
326) -> Result<String, ChatError> {
327    let client = create_llm_client(provider);
328    let mut llm_messages = Vec::with_capacity(messages.len() + 1);
329
330    if let Some(system_prompt_text) = system_prompt {
331        let trimmed = system_prompt_text.trim();
332        if !trimmed.is_empty() {
333            llm_messages.push(Message {
334                role: Role::System,
335                content: Some(Content::Text(trimmed.to_string())),
336                name: None,
337                tool_calls: None,
338                tool_call_id: None,
339                reasoning_content: None,
340            });
341        }
342    }
343    llm_messages.extend(to_llm_messages(messages));
344
345    let request = ChatRequest {
346        model: provider.model.clone(),
347        messages: llm_messages,
348        tools: None,
349        stream: Some(true),
350        max_tokens: None,
351        extra: serde_json::Map::new(),
352    };
353
354    let request_body =
355        serde_json::to_string(&request).unwrap_or_else(|e| format!("序列化request失败: {}", e));
356
357    let mut stream = client.chat_completion_stream(&request).await.map_err(|e| {
358        let err_msg = ChatError::from(e);
359        write_info_log(
360            "call_llm_stream_async API请求 ERROR",
361            &format!("{}\nrequest body:\n{}", err_msg, request_body),
362        );
363        err_msg
364    })?;
365
366    let mut full_content = String::new();
367
368    while let Some(result) = stream.next().await {
369        match result {
370            Ok(response) => {
371                for choice in &response.choices {
372                    if let Some(ref content) = choice.delta.content {
373                        full_content.push_str(content);
374                        on_chunk(content);
375                    }
376                }
377            }
378            Err(e) => {
379                let err = ChatError::from(e);
380                write_info_log(
381                    "call_llm_stream_async 流式响应 ERROR",
382                    &format!(
383                        "{}\n已接收内容长度: {}\nrequest body:\n{}",
384                        err,
385                        full_content.len(),
386                        request_body
387                    ),
388                );
389                return Err(err);
390            }
391        }
392    }
393
394    Ok(full_content)
395}
396
397/// fallback 非流式调用结果
398#[derive(Debug)]
399pub struct FallbackResult {
400    pub content: Option<String>,
401    pub tool_calls: Option<Vec<ToolCallItem>>,
402    pub finish_reason: Option<String>,
403    pub reasoning_content: Option<String>,
404    pub usage: Option<TokenUsage>,
405}
406
407impl FallbackResult {
408    pub fn has_tool_calls(&self) -> bool {
409        self.tool_calls.is_some()
410    }
411}
412
413/// 非流式请求(已内置宽松反序列化:finish_reason 为 String)
414pub async fn call_llm_non_stream(
415    provider: &ModelProvider,
416    request: &ChatRequest,
417) -> Result<FallbackResult, ChatError> {
418    let client = create_llm_client(provider);
419    let request_body =
420        serde_json::to_string(request).unwrap_or_else(|e| format!("序列化request失败: {}", e));
421
422    let response = client.chat_completion(request).await.map_err(|e| {
423        let err = ChatError::from(e);
424        write_error_log(
425            "call_llm_non_stream",
426            &format!("{}\nrequest body:\n{}", err, request_body),
427        );
428        err
429    })?;
430
431    let choice = match response.choices.first() {
432        Some(c) => c,
433        None => {
434            return Ok(FallbackResult {
435                content: None,
436                tool_calls: None,
437                finish_reason: None,
438                reasoning_content: None,
439                usage: response.usage,
440            });
441        }
442    };
443
444    let tool_items = choice.message.tool_calls.as_ref().map(|tool_calls| {
445        tool_calls
446            .iter()
447            .map(|tool_call| {
448                let id = if tool_call.id.is_empty() {
449                    use rand::Rng;
450                    let rand_id = format!("call_{:016x}", rand::thread_rng().r#gen::<u64>());
451                    write_info_log(
452                        "call_llm_non_stream",
453                        &format!(
454                            "tool_call id 为空,已生成随机 id: {} (tool: {})",
455                            rand_id, tool_call.function.name
456                        ),
457                    );
458                    rand_id
459                } else {
460                    tool_call.id.clone()
461                };
462                ToolCallItem {
463                    id,
464                    name: tool_call.function.name.clone(),
465                    arguments: tool_call.function.arguments.clone(),
466                }
467            })
468            .collect()
469    });
470
471    if let Some(ref reason) = choice.finish_reason
472        && !matches!(
473            reason.as_str(),
474            "stop" | "length" | "tool_calls" | "content_filter" | "function_call"
475        )
476    {
477        write_info_log(
478            "call_llm_non_stream",
479            &format!("非标准 finish_reason: {}", reason),
480        );
481    }
482
483    Ok(FallbackResult {
484        content: choice.message.content.clone(),
485        tool_calls: tool_items,
486        finish_reason: choice.finish_reason.clone(),
487        reasoning_content: choice.message.reasoning_content.clone(),
488        usage: response.usage,
489    })
490}
491
492/// 同步包装:创建 tokio runtime 执行异步流式调用
493pub fn call_llm_stream(
494    provider: &ModelProvider,
495    messages: &[ChatMessage],
496    system_prompt: Option<&str>,
497    on_chunk: &mut dyn FnMut(&str),
498) -> Result<String, ChatError> {
499    let rt = tokio::runtime::Runtime::new().map_err(|e| {
500        let err = ChatError::RuntimeFailed(e.to_string());
501        write_info_log("call_llm_stream 创建runtime ERROR", &format!("{}", err));
502        err
503    })?;
504    rt.block_on(call_llm_stream_async(
505        provider,
506        messages,
507        system_prompt,
508        on_chunk,
509    ))
510}
511
512/// 清理 API 响应 body 用于错误消息:剥离 HTML 标签,截断超长内容
513#[allow(dead_code)]
514fn sanitize_api_body(body: &str) -> String {
515    let max_len = crate::constants::API_ERROR_BODY_MAX_LEN;
516    let truncated = &body[..body.len().min(max_len)];
517    let mut result = String::with_capacity(truncated.len());
518    let mut in_tag = false;
519    for ch in truncated.chars() {
520        match ch {
521            '<' => in_tag = true,
522            '>' => in_tag = false,
523            _ if !in_tag => result.push(ch),
524            _ => {}
525        }
526    }
527    result.split_whitespace().collect::<Vec<_>>().join(" ")
528}