// aster/core/background_tasks.rs

//! Background conversation task manager.
//!
//! Lets a conversation be moved to run in the background.
4
5use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fs::{self, File, OpenOptions};
9use std::io::Write;
10use std::path::PathBuf;
11use std::sync::Arc;
12use uuid::Uuid;
13
14/// 后台对话任务
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct BackgroundTask {
17    /// 任务 ID
18    pub id: String,
19    /// 任务类型
20    pub task_type: String,
21    /// 用户输入
22    pub user_input: String,
23    /// 任务状态
24    pub status: TaskStatus,
25    /// 开始时间(毫秒)
26    pub start_time: u64,
27    /// 结束时间(毫秒)
28    #[serde(default)]
29    pub end_time: Option<u64>,
30    /// 文本输出
31    pub text_output: String,
32    /// 工具调用记录
33    pub tool_calls: Vec<ToolCallRecord>,
34    /// 输出文件路径
35    pub output_file: PathBuf,
36    /// 是否已取消
37    pub cancelled: bool,
38    /// 错误信息
39    #[serde(default)]
40    pub error: Option<String>,
41}
42
43/// 任务状态
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
45#[serde(rename_all = "lowercase")]
46pub enum TaskStatus {
47    Running,
48    Completed,
49    Failed,
50}
51
52/// 工具调用记录
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ToolCallRecord {
55    /// 工具名称
56    pub name: String,
57    /// 输入参数
58    pub input: serde_json::Value,
59    /// 执行结果
60    #[serde(default)]
61    pub result: Option<String>,
62    /// 错误信息
63    #[serde(default)]
64    pub error: Option<String>,
65    /// 时间戳
66    pub timestamp: u64,
67}
68
69/// 任务摘要
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct TaskSummary {
72    pub id: String,
73    pub task_type: String,
74    pub status: TaskStatus,
75    pub user_input: String,
76    pub duration: u64,
77    pub output_preview: String,
78}
79
80/// 任务统计
81#[derive(Debug, Clone, Default, Serialize, Deserialize)]
82pub struct TaskStats {
83    pub total: usize,
84    pub running: usize,
85    pub completed: usize,
86    pub failed: usize,
87}
88
89/// 后台任务管理器
90pub struct BackgroundTaskManager {
91    tasks: RwLock<HashMap<String, BackgroundTask>>,
92    tasks_dir: PathBuf,
93}
94
95impl BackgroundTaskManager {
96    /// 创建新的任务管理器
97    pub fn new() -> Self {
98        let tasks_dir = get_tasks_dir();
99        Self {
100            tasks: RwLock::new(HashMap::new()),
101            tasks_dir,
102        }
103    }
104
105    /// 创建新的后台任务
106    pub fn create_task(&self, user_input: &str) -> BackgroundTask {
107        let task_id = Uuid::new_v4().to_string();
108        let output_file = self.tasks_dir.join(format!("{}.log", task_id));
109        let now = current_timestamp();
110
111        let task = BackgroundTask {
112            id: task_id.clone(),
113            task_type: "conversation".to_string(),
114            user_input: user_input.to_string(),
115            status: TaskStatus::Running,
116            start_time: now,
117            end_time: None,
118            text_output: String::new(),
119            tool_calls: Vec::new(),
120            output_file: output_file.clone(),
121            cancelled: false,
122            error: None,
123        };
124
125        // 写入任务开始信息
126        if let Ok(mut file) = File::create(&output_file) {
127            let _ = writeln!(file, "=== Background Task Started ===");
128            let _ = writeln!(file, "Task ID: {}", task_id);
129            let _ = writeln!(file, "User Input: {}", user_input);
130            let _ = writeln!(file, "Start Time: {}", now);
131            let _ = writeln!(file);
132        }
133
134        self.tasks.write().insert(task_id, task.clone());
135        task
136    }
137
138    /// 追加文本输出
139    pub fn append_text(&self, task_id: &str, text: &str) {
140        let mut tasks = self.tasks.write();
141        if let Some(task) = tasks.get_mut(task_id) {
142            task.text_output.push_str(text);
143
144            // 写入文件
145            if let Ok(mut file) = OpenOptions::new().append(true).open(&task.output_file) {
146                let _ = file.write_all(text.as_bytes());
147            }
148        }
149    }
150
151    /// 添加工具调用记录
152    pub fn add_tool_call(
153        &self,
154        task_id: &str,
155        tool_name: &str,
156        input: serde_json::Value,
157        result: Option<String>,
158        error: Option<String>,
159    ) {
160        let mut tasks = self.tasks.write();
161        if let Some(task) = tasks.get_mut(task_id) {
162            let record = ToolCallRecord {
163                name: tool_name.to_string(),
164                input: input.clone(),
165                result: result.clone(),
166                error: error.clone(),
167                timestamp: current_timestamp(),
168            };
169            task.tool_calls.push(record);
170
171            // 写入文件
172            if let Ok(mut file) = OpenOptions::new().append(true).open(&task.output_file) {
173                let _ = writeln!(file, "\n--- Tool: {} ---", tool_name);
174                let _ = writeln!(
175                    file,
176                    "Input: {}",
177                    serde_json::to_string_pretty(&input).unwrap_or_default()
178                );
179                if let Some(ref r) = result {
180                    let preview = if r.len() > 1000 {
181                        r.get(..1000).unwrap_or(r)
182                    } else {
183                        r
184                    };
185                    let _ = writeln!(file, "Result: {}", preview);
186                }
187                if let Some(ref e) = error {
188                    let _ = writeln!(file, "Error: {}", e);
189                }
190                let _ = writeln!(file);
191            }
192        }
193    }
194
195    /// 完成任务
196    pub fn complete_task(&self, task_id: &str, success: bool, error: Option<String>) {
197        let mut tasks = self.tasks.write();
198        if let Some(task) = tasks.get_mut(task_id) {
199            task.status = if success {
200                TaskStatus::Completed
201            } else {
202                TaskStatus::Failed
203            };
204            task.end_time = Some(current_timestamp());
205            task.error = error.clone();
206
207            // 写入结束信息
208            if let Ok(mut file) = OpenOptions::new().append(true).open(&task.output_file) {
209                let status = if success { "Completed" } else { "Failed" };
210                let _ = writeln!(file, "\n=== Task {} ===", status);
211                let _ = writeln!(file, "End Time: {}", task.end_time.unwrap());
212                let _ = writeln!(
213                    file,
214                    "Duration: {}ms",
215                    task.end_time.unwrap() - task.start_time
216                );
217                if let Some(ref e) = error {
218                    let _ = writeln!(file, "Error: {}", e);
219                }
220            }
221        }
222    }
223
224    /// 取消任务
225    pub fn cancel_task(&self, task_id: &str) -> bool {
226        let mut tasks = self.tasks.write();
227        if let Some(task) = tasks.get_mut(task_id) {
228            task.cancelled = true;
229            drop(tasks);
230            self.complete_task(task_id, false, Some("Task cancelled by user".to_string()));
231            return true;
232        }
233        false
234    }
235
236    /// 获取任务
237    pub fn get_task(&self, task_id: &str) -> Option<BackgroundTask> {
238        self.tasks.read().get(task_id).cloned()
239    }
240
241    /// 获取所有任务
242    pub fn get_all_tasks(&self) -> Vec<BackgroundTask> {
243        self.tasks.read().values().cloned().collect()
244    }
245
246    /// 获取任务摘要列表
247    pub fn get_task_summaries(&self) -> Vec<TaskSummary> {
248        let now = current_timestamp();
249        self.tasks
250            .read()
251            .values()
252            .map(|task| {
253                let input_preview = if task.user_input.len() > 100 {
254                    format!(
255                        "{}...",
256                        task.user_input.get(..100).unwrap_or(&task.user_input)
257                    )
258                } else {
259                    task.user_input.clone()
260                };
261                let output_preview = if task.text_output.len() > 200 {
262                    format!(
263                        "{}...",
264                        task.text_output.get(..200).unwrap_or(&task.text_output)
265                    )
266                } else {
267                    task.text_output.clone()
268                };
269
270                TaskSummary {
271                    id: task.id.clone(),
272                    task_type: task.task_type.clone(),
273                    status: task.status,
274                    user_input: input_preview,
275                    duration: task.end_time.unwrap_or(now) - task.start_time,
276                    output_preview,
277                }
278            })
279            .collect()
280    }
281
282    /// 删除任务
283    pub fn delete_task(&self, task_id: &str) -> bool {
284        let mut tasks = self.tasks.write();
285        if let Some(task) = tasks.remove(task_id) {
286            // 如果任务还在运行,先取消
287            if task.status == TaskStatus::Running {
288                drop(tasks);
289                self.cancel_task(task_id);
290            }
291
292            // 删除输出文件
293            let _ = fs::remove_file(&task.output_file);
294            return true;
295        }
296        false
297    }
298
299    /// 清理已完成的任务
300    pub fn cleanup_completed(&self) -> usize {
301        let task_ids: Vec<String> = self
302            .tasks
303            .read()
304            .iter()
305            .filter(|(_, t)| t.status != TaskStatus::Running)
306            .map(|(id, _)| id.clone())
307            .collect();
308
309        let mut cleaned = 0;
310        for id in task_ids {
311            if self.delete_task(&id) {
312                cleaned += 1;
313            }
314        }
315        cleaned
316    }
317
318    /// 获取任务统计
319    pub fn get_stats(&self) -> TaskStats {
320        let tasks = self.tasks.read();
321        TaskStats {
322            total: tasks.len(),
323            running: tasks
324                .values()
325                .filter(|t| t.status == TaskStatus::Running)
326                .count(),
327            completed: tasks
328                .values()
329                .filter(|t| t.status == TaskStatus::Completed)
330                .count(),
331            failed: tasks
332                .values()
333                .filter(|t| t.status == TaskStatus::Failed)
334                .count(),
335        }
336    }
337
338    /// 检查任务是否已取消
339    pub fn is_cancelled(&self, task_id: &str) -> bool {
340        self.tasks
341            .read()
342            .get(task_id)
343            .map(|t| t.cancelled)
344            .unwrap_or(false)
345    }
346}
347
348impl Default for BackgroundTaskManager {
349    fn default() -> Self {
350        Self::new()
351    }
352}
353
// Helper functions
355
356/// 获取任务目录
357fn get_tasks_dir() -> PathBuf {
358    let dir = dirs::home_dir()
359        .unwrap_or_else(|| PathBuf::from("."))
360        .join(".aster")
361        .join("tasks")
362        .join("conversations");
363
364    if !dir.exists() {
365        let _ = fs::create_dir_all(&dir);
366    }
367
368    dir
369}
370
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
378
379/// 全局任务管理器
380static GLOBAL_MANAGER: once_cell::sync::Lazy<Arc<BackgroundTaskManager>> =
381    once_cell::sync::Lazy::new(|| Arc::new(BackgroundTaskManager::new()));
382
383/// 获取全局任务管理器
384pub fn global_task_manager() -> Arc<BackgroundTaskManager> {
385    GLOBAL_MANAGER.clone()
386}