Skip to main content

j_cli/command/chat/
api.rs

1use super::model::{ChatMessage, ModelProvider};
2use async_openai::{
3    Client,
4    config::OpenAIConfig,
5    types::chat::{
6        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
7        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
8        ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
9        ChatCompletionRequestUserMessageArgs, ChatCompletionTools, CreateChatCompletionRequest,
10        CreateChatCompletionRequestArgs, FunctionCall,
11    },
12};
13use futures::StreamExt;
14
15/// 根据 ModelProvider 配置创建 async-openai Client
16pub fn create_openai_client(provider: &ModelProvider) -> Client<OpenAIConfig> {
17    let config = OpenAIConfig::new()
18        .with_api_key(&provider.api_key)
19        .with_api_base(&provider.api_base);
20    Client::with_config(config)
21}
22
23/// 将内部 ChatMessage 转换为 async-openai 的请求消息格式
24pub fn to_openai_messages(messages: &[ChatMessage]) -> Vec<ChatCompletionRequestMessage> {
25    messages
26        .iter()
27        .filter_map(|msg| match msg.role.as_str() {
28            "system" => ChatCompletionRequestSystemMessageArgs::default()
29                .content(msg.content.as_str())
30                .build()
31                .ok()
32                .map(ChatCompletionRequestMessage::System),
33            "user" => ChatCompletionRequestUserMessageArgs::default()
34                .content(msg.content.as_str())
35                .build()
36                .ok()
37                .map(ChatCompletionRequestMessage::User),
38            "assistant" => {
39                let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
40                if !msg.content.is_empty() {
41                    builder.content(msg.content.as_str());
42                }
43                if let Some(ref tool_calls) = msg.tool_calls {
44                    let tc_list: Vec<ChatCompletionMessageToolCalls> = tool_calls
45                        .iter()
46                        .map(|tc| {
47                            ChatCompletionMessageToolCalls::Function(
48                                ChatCompletionMessageToolCall {
49                                    id: tc.id.clone(),
50                                    function: FunctionCall {
51                                        name: tc.name.clone(),
52                                        arguments: tc.arguments.clone(),
53                                    },
54                                },
55                            )
56                        })
57                        .collect();
58                    builder.tool_calls(tc_list);
59                }
60                builder
61                    .build()
62                    .ok()
63                    .map(ChatCompletionRequestMessage::Assistant)
64            }
65            "tool" => {
66                let tool_call_id = msg.tool_call_id.clone().unwrap_or_default();
67                ChatCompletionRequestToolMessageArgs::default()
68                    .content(msg.content.as_str())
69                    .tool_call_id(tool_call_id)
70                    .build()
71                    .ok()
72                    .map(ChatCompletionRequestMessage::Tool)
73            }
74            _ => None,
75        })
76        .collect()
77}
78
79/// 构建带工具定义的请求
80pub fn build_request_with_tools(
81    provider: &ModelProvider,
82    messages: &[ChatMessage],
83    tools: Vec<ChatCompletionTools>,
84    system_prompt: Option<&str>,
85) -> Result<CreateChatCompletionRequest, String> {
86    let mut openai_messages = Vec::new();
87    if let Some(sys) = system_prompt {
88        let trimmed = sys.trim();
89        if !trimmed.is_empty() {
90            if let Ok(msg) = ChatCompletionRequestSystemMessageArgs::default()
91                .content(trimmed)
92                .build()
93            {
94                openai_messages.push(ChatCompletionRequestMessage::System(msg));
95            }
96        }
97    }
98    openai_messages.extend(to_openai_messages(messages));
99    let mut builder = CreateChatCompletionRequestArgs::default();
100    builder.model(&provider.model).messages(openai_messages);
101    if !tools.is_empty() {
102        builder.tools(tools);
103    }
104    builder.build().map_err(|e| format!("构建请求失败: {}", e))
105}
106
107/// 使用 async-openai 流式调用 API,通过回调逐步输出
108/// 返回完整的助手回复内容
109pub async fn call_openai_stream_async(
110    provider: &ModelProvider,
111    messages: &[ChatMessage],
112    system_prompt: Option<&str>,
113    on_chunk: &mut dyn FnMut(&str),
114) -> Result<String, String> {
115    let client = create_openai_client(provider);
116    let mut openai_messages = Vec::new();
117    if let Some(sys) = system_prompt {
118        let trimmed = sys.trim();
119        if !trimmed.is_empty() {
120            if let Ok(msg) = ChatCompletionRequestSystemMessageArgs::default()
121                .content(trimmed)
122                .build()
123            {
124                openai_messages.push(ChatCompletionRequestMessage::System(msg));
125            }
126        }
127    }
128    openai_messages.extend(to_openai_messages(messages));
129
130    let request = CreateChatCompletionRequestArgs::default()
131        .model(&provider.model)
132        .messages(openai_messages)
133        .build()
134        .map_err(|e| format!("构建请求失败: {}", e))?;
135
136    let mut stream = client
137        .chat()
138        .create_stream(request)
139        .await
140        .map_err(|e| format!("API 请求失败: {}", e))?;
141
142    let mut full_content = String::new();
143
144    while let Some(result) = stream.next().await {
145        match result {
146            Ok(response) => {
147                for choice in &response.choices {
148                    if let Some(ref content) = choice.delta.content {
149                        full_content.push_str(content);
150                        on_chunk(content);
151                    }
152                }
153            }
154            Err(e) => {
155                return Err(format!("流式响应错误: {}", e));
156            }
157        }
158    }
159
160    Ok(full_content)
161}
162
163/// 同步包装:创建 tokio runtime 执行异步流式调用
164pub fn call_openai_stream(
165    provider: &ModelProvider,
166    messages: &[ChatMessage],
167    system_prompt: Option<&str>,
168    on_chunk: &mut dyn FnMut(&str),
169) -> Result<String, String> {
170    let rt = tokio::runtime::Runtime::new().map_err(|e| format!("创建异步运行时失败: {}", e))?;
171    rt.block_on(call_openai_stream_async(
172        provider,
173        messages,
174        system_prompt,
175        on_chunk,
176    ))
177}