Skip to main content

j_cli/command/chat/
model.rs

1use super::theme::ThemeName;
2use crate::config::YamlConfig;
3use crate::error;
4use serde::{Deserialize, Serialize};
5use std::fs;
6use std::path::PathBuf;
7
8// ========== 数据结构 ==========
9
10/// 单个模型提供方配置
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ModelProvider {
13    /// 显示名称(如 "GPT-4o", "DeepSeek-V3")
14    pub name: String,
15    /// API Base URL(如 "https://api.openai.com/v1")
16    pub api_base: String,
17    /// API Key
18    pub api_key: String,
19    /// 模型名称(如 "gpt-4o", "deepseek-chat")
20    pub model: String,
21}
22
23/// Agent 配置
24#[derive(Debug, Clone, Serialize, Deserialize, Default)]
25pub struct AgentConfig {
26    /// 模型提供方列表
27    #[serde(default)]
28    pub providers: Vec<ModelProvider>,
29    /// 当前选中的 provider 索引
30    #[serde(default)]
31    pub active_index: usize,
32    /// 系统提示词(可选)
33    #[serde(default)]
34    pub system_prompt: Option<String>,
35    /// 是否使用流式输出(默认 true,设为 false 则等回复完整后再显示)
36    #[serde(default = "default_stream_mode")]
37    pub stream_mode: bool,
38    /// 发送给 API 的历史消息数量限制(默认 20 条,避免 token 消耗过大)
39    #[serde(default = "default_max_history_messages")]
40    pub max_history_messages: usize,
41    /// 主题名称(dark / light / midnight)
42    #[serde(default)]
43    pub theme: ThemeName,
44    /// 是否启用工具调用(默认关闭)
45    #[serde(default)]
46    pub tools_enabled: bool,
47    /// 工具调用最大轮数(默认 10,防止无限循环)
48    #[serde(default = "default_max_tool_rounds")]
49    pub max_tool_rounds: usize,
50    /// 回复风格(可选)
51    #[serde(default)]
52    pub style: Option<String>,
53}
54
55fn default_max_history_messages() -> usize {
56    20
57}
58
59/// 默认流式输出
60fn default_stream_mode() -> bool {
61    true
62}
63
64/// 默认工具调用最大轮数
65fn default_max_tool_rounds() -> usize {
66    10
67}
68
69/// 单次工具调用请求(序列化到历史记录)
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct ToolCallItem {
72    pub id: String,
73    pub name: String,
74    pub arguments: String,
75}
76
77/// 对话消息
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct ChatMessage {
80    pub role: String, // "user" | "assistant" | "system" | "tool"
81    /// 消息内容(tool_call 类消息可为空)
82    #[serde(default)]
83    pub content: String,
84    /// LLM 发起的工具调用列表(仅 assistant 角色且有 tool_calls 时非 None)
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub tool_calls: Option<Vec<ToolCallItem>>,
87    /// 工具执行结果对应的 tool_call_id(仅 tool 角色时非 None)
88    #[serde(skip_serializing_if = "Option::is_none")]
89    pub tool_call_id: Option<String>,
90}
91
92impl ChatMessage {
93    /// 创建普通文本消息
94    pub fn text(role: impl Into<String>, content: impl Into<String>) -> Self {
95        Self {
96            role: role.into(),
97            content: content.into(),
98            tool_calls: None,
99            tool_call_id: None,
100        }
101    }
102}
103
104/// 对话会话
105#[derive(Debug, Clone, Serialize, Deserialize, Default)]
106pub struct ChatSession {
107    pub messages: Vec<ChatMessage>,
108}
109
110// ========== 文件路径 ==========
111
112/// 获取 agent 数据目录: ~/.jdata/agent/data/
113pub fn agent_data_dir() -> PathBuf {
114    let dir = YamlConfig::data_dir().join("agent").join("data");
115    let _ = fs::create_dir_all(&dir);
116    dir
117}
118
119/// 获取 agent 配置文件路径
120pub fn agent_config_path() -> PathBuf {
121    agent_data_dir().join("agent_config.json")
122}
123
124/// 获取对话历史文件路径
125pub fn chat_history_path() -> PathBuf {
126    agent_data_dir().join("chat_history.json")
127}
128
129/// 获取系统提示词文件路径
130pub fn system_prompt_path() -> PathBuf {
131    agent_data_dir().join("system_prompt.md")
132}
133
134/// 获取回复风格文件路径
135pub fn style_path() -> PathBuf {
136    agent_data_dir().join("style.md")
137}
138
139// ========== 配置读写 ==========
140
141/// 加载 Agent 配置
142pub fn load_agent_config() -> AgentConfig {
143    let path = agent_config_path();
144    if !path.exists() {
145        return AgentConfig::default();
146    }
147    match fs::read_to_string(&path) {
148        Ok(content) => serde_json::from_str(&content).unwrap_or_else(|e| {
149            error!("❌ 解析 agent_config.json 失败: {}", e);
150            AgentConfig::default()
151        }),
152        Err(e) => {
153            error!("❌ 读取 agent_config.json 失败: {}", e);
154            AgentConfig::default()
155        }
156    }
157}
158
159/// 保存 Agent 配置
160pub fn save_agent_config(config: &AgentConfig) -> bool {
161    let path = agent_config_path();
162    if let Some(parent) = path.parent() {
163        let _ = fs::create_dir_all(parent);
164    }
165    // system_prompt 和 style 统一存放在独立文件,不再写入 agent_config.json
166    let mut config_to_save = config.clone();
167    config_to_save.system_prompt = None;
168    config_to_save.style = None;
169    match serde_json::to_string_pretty(&config_to_save) {
170        Ok(json) => match fs::write(&path, json) {
171            Ok(_) => true,
172            Err(e) => {
173                error!("❌ 保存 agent_config.json 失败: {}", e);
174                false
175            }
176        },
177        Err(e) => {
178            error!("❌ 序列化 agent 配置失败: {}", e);
179            false
180        }
181    }
182}
183
184/// 加载对话历史
185pub fn load_chat_session() -> ChatSession {
186    let path = chat_history_path();
187    if !path.exists() {
188        return ChatSession::default();
189    }
190    match fs::read_to_string(&path) {
191        Ok(content) => serde_json::from_str(&content).unwrap_or_else(|_| ChatSession::default()),
192        Err(_) => ChatSession::default(),
193    }
194}
195
196/// 保存对话历史
197pub fn save_chat_session(session: &ChatSession) -> bool {
198    let path = chat_history_path();
199    if let Some(parent) = path.parent() {
200        let _ = fs::create_dir_all(parent);
201    }
202    match serde_json::to_string_pretty(session) {
203        Ok(json) => fs::write(&path, json).is_ok(),
204        Err(_) => false,
205    }
206}
207
208/// 加载系统提示词(来自独立文件)
209pub fn load_system_prompt() -> Option<String> {
210    let path = system_prompt_path();
211    if !path.exists() {
212        return None;
213    }
214    match fs::read_to_string(path) {
215        Ok(content) => {
216            let trimmed = content.trim();
217            if trimmed.is_empty() {
218                None
219            } else {
220                Some(trimmed.to_string())
221            }
222        }
223        Err(e) => {
224            error!("❌ 读取 system_prompt.md 失败: {}", e);
225            None
226        }
227    }
228}
229
230/// 保存系统提示词到独立文件(空字符串会删除文件)
231pub fn save_system_prompt(prompt: &str) -> bool {
232    let path = system_prompt_path();
233    if let Some(parent) = path.parent() {
234        let _ = fs::create_dir_all(parent);
235    }
236
237    let trimmed = prompt.trim();
238    if trimmed.is_empty() {
239        return match fs::remove_file(&path) {
240            Ok(_) => true,
241            Err(e) if e.kind() == std::io::ErrorKind::NotFound => true,
242            Err(e) => {
243                error!("❌ 删除 system_prompt.md 失败: {}", e);
244                false
245            }
246        };
247    }
248
249    match fs::write(path, trimmed) {
250        Ok(_) => true,
251        Err(e) => {
252            error!("❌ 保存 system_prompt.md 失败: {}", e);
253            false
254        }
255    }
256}
257
258/// 加载回复风格(来自独立文件)
259pub fn load_style() -> Option<String> {
260    let path = style_path();
261    if !path.exists() {
262        return None;
263    }
264    match fs::read_to_string(path) {
265        Ok(content) => {
266            let trimmed = content.trim();
267            if trimmed.is_empty() {
268                None
269            } else {
270                Some(trimmed.to_string())
271            }
272        }
273        Err(e) => {
274            error!("❌ 读取 style.md 失败: {}", e);
275            None
276        }
277    }
278}
279
280/// 保存回复风格到独立文件(空字符串会删除文件)
281pub fn save_style(style: &str) -> bool {
282    let path = style_path();
283    if let Some(parent) = path.parent() {
284        let _ = fs::create_dir_all(parent);
285    }
286
287    let trimmed = style.trim();
288    if trimmed.is_empty() {
289        return match fs::remove_file(&path) {
290            Ok(_) => true,
291            Err(e) if e.kind() == std::io::ErrorKind::NotFound => true,
292            Err(e) => {
293                error!("❌ 删除 style.md 失败: {}", e);
294                false
295            }
296        };
297    }
298
299    match fs::write(path, trimmed) {
300        Ok(_) => true,
301        Err(e) => {
302            error!("❌ 保存 style.md 失败: {}", e);
303            false
304        }
305    }
306}