agentic_warden/
storage.rs

1use crate::{
2    config::{MAX_RECORD_AGE, SHARED_MEMORY_SIZE},
3    core::models::ProcessTreeInfo,
4    core::shared_map::open_or_create,
5    error::RegistryError,
6    logging::warn,
7    task_record::{TaskRecord, TaskStatus},
8};
9use chrono::{DateTime, Duration, Utc};
10use dashmap::DashMap;
11use parking_lot::Mutex;
12use shared_hashmap::SharedMemoryHashMap;
13use std::sync::Arc;
14
/// A single entry of the task registry, as returned by [`TaskStorage::entries`].
#[derive(Debug, Clone)]
pub struct RegistryEntry {
    /// Process id the task was registered under.
    pub pid: u32,
    /// Storage key of the entry. Both implementations in this file use the
    /// textual form of the pid as the key.
    pub key: String,
    /// The task record stored for this pid.
    pub record: TaskRecord,
}
22
/// Cleanup event emitted by [`TaskStorage::sweep_stale_entries`] for every
/// task the sweep decided to clean up.
#[derive(Debug, Clone)]
pub struct CleanupEvent {
    /// Pid of the task that was cleaned up.
    pub _pid: u32,
    /// The task record as it looked when it was cleaned up
    /// (with `cleanup_reason` already filled in).
    pub record: TaskRecord,
    /// Why the task was cleaned up.
    pub reason: CleanupReason,
}
30
/// Reason a task was cleaned up by a sweep.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CleanupReason {
    /// The task's process is no longer alive.
    ProcessExited,
    /// The task's record is older than the maximum allowed age.
    Timeout,
    /// The task's manager process is no longer alive.
    ManagerMissing,
}
38
/// Unified interface for task storage.
///
/// Two implementations are provided: a cross-process one
/// (`SharedMemoryStorage`) and an in-process one (`InProcessStorage`).
pub trait TaskStorage: Send + Sync {
    /// Registers a new task record under `pid`.
    fn register(&self, pid: u32, record: &TaskRecord) -> Result<(), RegistryError>;

    /// Marks the task registered under `pid` as completed, recording its
    /// `result`, `exit_code` and completion time.
    ///
    /// # Errors
    /// Implementations return an error when no task is stored for `pid`.
    fn mark_completed(
        &self,
        pid: u32,
        result: Option<String>,
        exit_code: Option<i32>,
        completed_at: DateTime<Utc>,
    ) -> Result<(), RegistryError>;

    /// Returns a snapshot of all task entries currently stored.
    fn entries(&self) -> Result<Vec<RegistryEntry>, RegistryError>;

    /// Cleans up stale tasks and returns one [`CleanupEvent`] per task cleaned up.
    ///
    /// * `now` - reference time used to compute the age of a record.
    /// * `is_process_alive` - returns whether the process with the given pid is alive.
    /// * `terminate_process` - asked to terminate a process that should no
    ///   longer be running; its result is treated as best-effort.
    fn sweep_stale_entries<F, G>(
        &self,
        now: DateTime<Utc>,
        is_process_alive: F,
        terminate_process: &G,
    ) -> Result<Vec<CleanupEvent>, RegistryError>
    where
        F: Fn(u32) -> bool,
        G: Fn(u32) -> Result<(), String>;

    /// Returns the tasks that are completed but not yet read.
    ///
    /// Both implementations in this file remove the returned tasks from the
    /// storage, so each task is handed out by this call and then forgotten.
    fn get_completed_unread_tasks(&self) -> Result<Vec<(u32, TaskRecord)>, RegistryError>;

    /// Checks whether any task is still running. When `filter` is given, only
    /// tasks whose process tree has the same `root_parent_pid` are considered.
    fn has_running_tasks(&self, filter: Option<&ProcessTreeInfo>) -> Result<bool, RegistryError>;
}
74
/// In-process task storage (thread-safe).
///
/// Used for tasks started via MCP; not shared across processes.
/// Backed by a `DashMap` for concurrent access. Clones share the same
/// underlying map through the `Arc`.
#[derive(Debug, Clone)]
pub struct InProcessStorage {
    /// Task records keyed by pid.
    tasks: Arc<DashMap<u32, TaskRecord>>,
}
82
83impl InProcessStorage {
84    pub fn new() -> Self {
85        Self {
86            tasks: Arc::new(DashMap::new()),
87        }
88    }
89}
90
91impl Default for InProcessStorage {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97impl TaskStorage for InProcessStorage {
98    fn register(&self, pid: u32, record: &TaskRecord) -> Result<(), RegistryError> {
99        self.tasks.insert(pid, record.clone());
100        Ok(())
101    }
102
103    fn mark_completed(
104        &self,
105        pid: u32,
106        result: Option<String>,
107        exit_code: Option<i32>,
108        completed_at: DateTime<Utc>,
109    ) -> Result<(), RegistryError> {
110        if let Some(mut record) = self.tasks.get_mut(&pid) {
111            record.status = TaskStatus::CompletedButUnread;
112            record.result = result;
113            record.exit_code = exit_code;
114            record.completed_at = Some(completed_at);
115        } else {
116            return Err(RegistryError::TaskNotFound(pid));
117        }
118        Ok(())
119    }
120
121    fn entries(&self) -> Result<Vec<RegistryEntry>, RegistryError> {
122        Ok(self
123            .tasks
124            .iter()
125            .map(|entry| RegistryEntry {
126                pid: *entry.key(),
127                key: entry.key().to_string(),
128                record: entry.value().clone(),
129            })
130            .collect())
131    }
132
133    fn sweep_stale_entries<F, G>(
134        &self,
135        now: DateTime<Utc>,
136        is_process_alive: F,
137        terminate_process: &G,
138    ) -> Result<Vec<CleanupEvent>, RegistryError>
139    where
140        F: Fn(u32) -> bool,
141        G: Fn(u32) -> Result<(), String>,
142    {
143        const MAX_AGE_HOURS: i64 = 12;
144        let max_age = Duration::hours(MAX_AGE_HOURS);
145
146        let mut cleanup_events = Vec::new();
147
148        let pids_to_cleanup: Vec<(u32, CleanupReason)> = self
149            .tasks
150            .iter()
151            .filter_map(|entry| {
152                let pid = *entry.key();
153                let record = entry.value();
154
155                // 如果进程已不存在
156                if !is_process_alive(pid) {
157                    // 如果任务未标记完成,补标记
158                    if record.status == TaskStatus::Running {
159                        return Some((pid, CleanupReason::ProcessExited));
160                    }
161                }
162
163                // 如果记录太旧(超过12小时)
164                let age = now.signed_duration_since(record.started_at);
165                if age > max_age {
166                    if record.status == TaskStatus::Running && is_process_alive(pid) {
167                        // 尝试终止
168                        let _ = terminate_process(pid);
169                        return Some((pid, CleanupReason::Timeout));
170                    }
171                }
172
173                None
174            })
175            .collect();
176
177        for (pid, cleanup_reason) in pids_to_cleanup {
178            if let Some(mut record) = self.tasks.get_mut(&pid) {
179                record.status = TaskStatus::CompletedButUnread;
180                record.completed_at = Some(now);
181                record.cleanup_reason = Some(
182                    match cleanup_reason {
183                        CleanupReason::ProcessExited => "process_exited",
184                        CleanupReason::Timeout => "timeout",
185                        CleanupReason::ManagerMissing => "manager_missing",
186                    }
187                    .to_string(),
188                );
189
190                cleanup_events.push(CleanupEvent {
191                    _pid: pid,
192                    record: record.clone(),
193                    reason: cleanup_reason,
194                });
195            }
196        }
197
198        Ok(cleanup_events)
199    }
200
201    fn get_completed_unread_tasks(&self) -> Result<Vec<(u32, TaskRecord)>, RegistryError> {
202        let completed: Vec<(u32, TaskRecord)> = self
203            .tasks
204            .iter()
205            .filter_map(|entry| {
206                let pid = *entry.key();
207                let record = entry.value();
208
209                if record.status == TaskStatus::CompletedButUnread {
210                    // 标记为已读(从映射中移除)
211                    Some((pid, record.clone()))
212                } else {
213                    None
214                }
215            })
216            .collect();
217
218        // 移除已读的任务
219        for (pid, _) in &completed {
220            self.tasks.remove(pid);
221        }
222
223        Ok(completed)
224    }
225
226    fn has_running_tasks(&self, filter: Option<&ProcessTreeInfo>) -> Result<bool, RegistryError> {
227        if let Some(tree_filter) = filter {
228            Ok(self.tasks.iter().any(|entry| {
229                let record = entry.value();
230                record.status == TaskStatus::Running
231                    && record
232                        .process_tree
233                        .as_ref()
234                        .map(|tree| tree.root_parent_pid == tree_filter.root_parent_pid)
235                        .unwrap_or(false)
236            }))
237        } else {
238            Ok(self
239                .tasks
240                .iter()
241                .any(|entry| entry.value().status == TaskStatus::Running))
242        }
243    }
244}
245
/// Cross-process task storage backed by shared memory.
///
/// Used for tasks started from the CLI; supports sharing across processes.
/// Clones share the same mapped handle through the `Arc`.
#[derive(Debug, Clone)]
pub struct SharedMemoryStorage {
    /// Name of the shared-memory segment this storage is connected to.
    namespace: String,
    /// The shared map (key and value are both strings), guarded by a mutex
    /// for access from within this process.
    map: Arc<Mutex<SharedMemoryHashMap<String, String>>>,
}
253
254impl SharedMemoryStorage {
255    /// 连接到当前进程的共享内存
256    /// 使用当前进程PID作为命名空间: {PID}_task
257    pub fn connect() -> Result<Self, RegistryError> {
258        let pid = std::process::id();
259        Self::connect_for_pid(pid)
260    }
261
262    /// 连接到指定PID的共享内存
263    /// 使用格式: {pid}_task
264    pub fn connect_for_pid(pid: u32) -> Result<Self, RegistryError> {
265        let namespace = format!("{}_task", pid);
266        Self::connect_with_namespace(namespace)
267    }
268
269    /// 使用指定的命名空间连接
270    pub fn connect_with_namespace(namespace: String) -> Result<Self, RegistryError> {
271        let map = open_or_create(&namespace, SHARED_MEMORY_SIZE)?;
272        Ok(Self {
273            namespace,
274            map: Arc::new(Mutex::new(map)),
275        })
276    }
277
278    /// 删除共享内存(用于进程结束时清理)
279    pub fn cleanup(&self) -> Result<(), RegistryError> {
280        use shared_memory::ShmemConf;
281
282        // 尝试删除共享内存
283        if let Ok(mut shmem) = ShmemConf::new()
284            .os_id(&self.namespace)
285            .size(SHARED_MEMORY_SIZE)
286            .open()
287        {
288            let _ = shmem.set_owner(true);
289        }
290
291        Ok(())
292    }
293
294    fn with_map<T>(
295        &self,
296        f: impl FnOnce(&mut SharedMemoryHashMap<String, String>) -> Result<T, RegistryError>,
297    ) -> Result<T, RegistryError> {
298        let mut guard = self.map.lock();
299        f(&mut guard)
300    }
301
302    fn remove_keys(&self, keys: &[String]) -> Result<(), RegistryError> {
303        if keys.is_empty() {
304            return Ok(());
305        }
306        self.with_map(|map| {
307            for key in keys {
308                map.remove(key);
309            }
310            Ok(())
311        })
312    }
313}
314
315impl TaskStorage for SharedMemoryStorage {
316    fn register(&self, pid: u32, record: &TaskRecord) -> Result<(), RegistryError> {
317        let key = pid.to_string();
318        let value = serde_json::to_string(record)?;
319        self.with_map(|map| {
320            map.try_insert(key.clone(), value)?;
321            Ok(())
322        })
323    }
324
325    fn mark_completed(
326        &self,
327        pid: u32,
328        result: Option<String>,
329        exit_code: Option<i32>,
330        completed_at: DateTime<Utc>,
331    ) -> Result<(), RegistryError> {
332        let key = pid.to_string();
333        self.with_map(move |map| {
334            let existing = map
335                .get(&key)
336                .ok_or_else(|| RegistryError::Map(format!("no task found for pid {pid}")))?;
337            let record: TaskRecord = serde_json::from_str(&existing)?;
338            let updated_record = record.mark_completed(result, exit_code, completed_at);
339            let updated_value = serde_json::to_string(&updated_record)?;
340            let _ = map.insert(key.clone(), updated_value);
341            Ok(())
342        })
343    }
344
345    fn entries(&self) -> Result<Vec<RegistryEntry>, RegistryError> {
346        let snapshot: Vec<(String, String)> = {
347            let guard = self.map.lock();
348            guard.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
349        };
350
351        let mut entries = Vec::new();
352        let mut invalid_keys = Vec::new();
353
354        for (key, value) in snapshot {
355            match key.parse::<u32>() {
356                Ok(pid) => match serde_json::from_str::<TaskRecord>(&value) {
357                    Ok(record) => entries.push(RegistryEntry {
358                        pid,
359                        key: key.clone(),
360                        record,
361                    }),
362                    Err(err) => {
363                        warn(format!("failed to parse task record pid={key}: {err}"));
364                        invalid_keys.push(key);
365                    }
366                },
367                Err(_) => {
368                    warn(format!("detected invalid pid key: {key}"));
369                    invalid_keys.push(key);
370                }
371            }
372        }
373
374        if !invalid_keys.is_empty() {
375            self.remove_keys(&invalid_keys)?;
376        }
377
378        Ok(entries)
379    }
380
381    fn sweep_stale_entries<F, G>(
382        &self,
383        now: DateTime<Utc>,
384        is_process_alive: F,
385        terminate_process: &G,
386    ) -> Result<Vec<CleanupEvent>, RegistryError>
387    where
388        F: Fn(u32) -> bool,
389        G: Fn(u32) -> Result<(), String>,
390    {
391        let entries = self.entries()?;
392        let mut removals = Vec::new();
393        let mut events = Vec::new();
394
395        for mut entry in entries {
396            let mut should_cleanup = false;
397            let mut cleanup_reason = CleanupReason::ProcessExited;
398
399            // 检查进程是否存活
400            if !is_process_alive(entry.pid) {
401                should_cleanup = true;
402                cleanup_reason = CleanupReason::ProcessExited;
403            } else {
404                // 检查manager进程
405                if let Some(_manager_pid) = entry.record.manager_pid.filter(|&manager_pid| {
406                    manager_pid != entry.pid && !is_process_alive(manager_pid)
407                }) {
408                    let _ = terminate_process(entry.pid);
409                    should_cleanup = true;
410                    cleanup_reason = CleanupReason::ManagerMissing;
411                }
412
413                // 检查是否超时
414                if !should_cleanup {
415                    let age = now.signed_duration_since(entry.record.started_at);
416                    let max_age = Duration::from_std(MAX_RECORD_AGE).unwrap_or(Duration::zero());
417                    if age > max_age {
418                        let _ = terminate_process(entry.pid);
419                        should_cleanup = true;
420                        cleanup_reason = CleanupReason::Timeout;
421                    }
422                }
423            }
424
425            if should_cleanup {
426                removals.push(entry.pid.to_string());
427
428                // Update record with cleanup reason
429                entry.record.cleanup_reason = Some(
430                    match cleanup_reason {
431                        CleanupReason::ProcessExited => "process_exited",
432                        CleanupReason::Timeout => "timeout",
433                        CleanupReason::ManagerMissing => "manager_missing",
434                    }
435                    .to_string(),
436                );
437
438                events.push(CleanupEvent {
439                    _pid: entry.pid,
440                    record: entry.record,
441                    reason: cleanup_reason,
442                });
443            }
444        }
445
446        if !removals.is_empty() {
447            self.remove_keys(&removals)?;
448        }
449
450        Ok(events)
451    }
452
453    fn get_completed_unread_tasks(&self) -> Result<Vec<(u32, TaskRecord)>, RegistryError> {
454        let entries = self.entries()?;
455        let mut completed_pids = Vec::new();
456
457        for entry in &entries {
458            if entry.record.status == TaskStatus::CompletedButUnread {
459                completed_pids.push(entry.pid);
460            }
461        }
462
463        // 从共享内存中删除已完成的任务
464        for pid in &completed_pids {
465            let key = pid.to_string();
466            let _ = self.with_map(|map| {
467                map.remove(&key);
468                Ok::<(), RegistryError>(())
469            });
470        }
471
472        // 返回已完成的任务
473        let completed_tasks: Vec<(u32, TaskRecord)> = entries
474            .into_iter()
475            .filter(|entry| entry.record.status == TaskStatus::CompletedButUnread)
476            .map(|entry| (entry.pid, entry.record))
477            .collect();
478
479        Ok(completed_tasks)
480    }
481
482    fn has_running_tasks(&self, filter: Option<&ProcessTreeInfo>) -> Result<bool, RegistryError> {
483        let entries = self.entries()?;
484
485        if let Some(tree_filter) = filter {
486            Ok(entries.iter().any(|entry| {
487                entry.record.status == TaskStatus::Running
488                    && entry
489                        .record
490                        .process_tree
491                        .as_ref()
492                        .map(|tree| tree.root_parent_pid == tree_filter.root_parent_pid)
493                        .unwrap_or(false)
494            }))
495        } else {
496            Ok(entries
497                .iter()
498                .any(|entry| entry.record.status == TaskStatus::Running))
499        }
500    }
501}
502
/// Unit tests for `InProcessStorage`, plus a concurrency comparison against a
/// `Mutex<HashMap>`-based variant.
#[cfg(test)]
mod tests {
    use super::*;

    /// Registering a task makes it show up in `entries()`.
    #[test]
    fn test_in_process_storage_register() {
        let storage = InProcessStorage::new();
        let record = TaskRecord::new(
            Utc::now(),
            "123".to_string(),
            "/tmp/test.log".to_string(),
            Some(100),
        );

        assert!(storage.register(123, &record).is_ok());
        let entries = storage.entries().unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].pid, 123);
    }

    /// A task marked completed is returned by `get_completed_unread_tasks`
    /// together with its result.
    #[test]
    fn test_in_process_storage_mark_completed() {
        let storage = InProcessStorage::new();
        let record = TaskRecord::new(
            Utc::now(),
            "456".to_string(),
            "/tmp/test.log".to_string(),
            Some(100),
        );

        storage.register(456, &record).unwrap();
        storage
            .mark_completed(456, Some("success".to_string()), Some(0), Utc::now())
            .unwrap();

        let completed = storage.get_completed_unread_tasks().unwrap();
        assert_eq!(completed.len(), 1);
        assert_eq!(completed[0].0, 456);
        assert_eq!(completed[0].1.result, Some("success".to_string()));
    }

    /// A 13-hour-old task whose process is reported dead yields exactly one
    /// cleanup event.
    #[test]
    fn test_in_process_storage_sweep_stale() {
        let storage = InProcessStorage::new();
        let old_time = Utc::now() - Duration::hours(13);
        let record = TaskRecord::new(
            old_time,
            "789".to_string(),
            "/tmp/test.log".to_string(),
            Some(100),
        );

        storage.register(789, &record).unwrap();

        // Every process is reported as dead; termination always succeeds.
        let is_alive = |_: u32| false;
        let terminate = |_: u32| Ok(());

        let events = storage
            .sweep_stale_entries(Utc::now(), is_alive, &terminate)
            .unwrap();

        assert_eq!(events.len(), 1);
        assert_eq!(events[0]._pid, 789);
    }

    #[cfg(test)]
    mod concurrency_tests {
        use super::*;
        use std::collections::HashMap;
        use std::sync::{Arc, Mutex as StdMutex};
        use std::thread;
        use std::time::Instant;

        // For performance comparison: a legacy variant of InProcessStorage
        // backed by a Mutex<HashMap>.
        #[derive(Debug)]
        struct LegacyInProcessStorage {
            tasks: Arc<StdMutex<HashMap<u32, TaskRecord>>>,
        }

        impl LegacyInProcessStorage {
            fn new() -> Self {
                Self {
                    tasks: Arc::new(StdMutex::new(HashMap::new())),
                }
            }
        }

        impl TaskStorage for LegacyInProcessStorage {
            fn register(&self, pid: u32, record: &TaskRecord) -> Result<(), RegistryError> {
                let mut tasks = self.tasks.lock().unwrap();
                tasks.insert(pid, record.clone());
                Ok(())
            }

            fn entries(&self) -> Result<Vec<RegistryEntry>, RegistryError> {
                let tasks = self.tasks.lock().unwrap();
                Ok(tasks
                    .iter()
                    .map(|(&pid, record)| RegistryEntry {
                        pid,
                        key: pid.to_string(),
                        record: record.clone(),
                    })
                    .collect())
            }

            // Simplified stubs of the remaining trait methods; the performance
            // test below only calls `register` and `entries`.
            fn mark_completed(
                &self,
                _pid: u32,
                _result: Option<String>,
                _exit_code: Option<i32>,
                _completed_at: DateTime<Utc>,
            ) -> Result<(), RegistryError> {
                Ok(())
            }

            fn sweep_stale_entries<F, G>(
                &self,
                _now: DateTime<Utc>,
                _is_process_alive: F,
                _terminate_process: &G,
            ) -> Result<Vec<CleanupEvent>, RegistryError> {
                Ok(Vec::new())
            }

            fn get_completed_unread_tasks(&self) -> Result<Vec<(u32, TaskRecord)>, RegistryError> {
                Ok(Vec::new())
            }

            fn has_running_tasks(
                &self,
                _filter: Option<&ProcessTreeInfo>,
            ) -> Result<bool, RegistryError> {
                Ok(false)
            }
        }

        /// Registers tasks from several threads into both storages, checks
        /// that none was lost and compares the elapsed wall-clock time.
        #[test]
        fn test_concurrent_performance_comparison() {
            const NUM_THREADS: usize = 8;
            const OPERATIONS_PER_THREAD: usize = 1000;
            const NUM_PIDS: usize = NUM_THREADS * OPERATIONS_PER_THREAD;

            let dashmap_storage = Arc::new(InProcessStorage::new());
            let legacy_storage = Arc::new(LegacyInProcessStorage::new());

            // Measure the DashMap-backed storage.
            let start = Instant::now();
            let mut handles = Vec::new();

            for thread_id in 0..NUM_THREADS {
                let storage = Arc::clone(&dashmap_storage);
                let handle = thread::spawn(move || {
                    for i in 0..OPERATIONS_PER_THREAD {
                        // Each thread works on its own disjoint pid range.
                        let pid = (thread_id * OPERATIONS_PER_THREAD + i) as u32;
                        let record = TaskRecord::new(
                            Utc::now(),
                            format!("cmd_{}", pid),
                            format!("/tmp/log_{}.log", pid),
                            Some(pid),
                        );
                        storage.register(pid, &record).unwrap();
                    }
                });
                handles.push(handle);
            }

            for handle in handles {
                handle.join().unwrap();
            }
            let dashmap_duration = start.elapsed();

            // Measure the Mutex<HashMap>-backed storage.
            let start = Instant::now();
            let mut handles = Vec::new();

            for thread_id in 0..NUM_THREADS {
                let storage = Arc::clone(&legacy_storage);
                let handle = thread::spawn(move || {
                    for i in 0..OPERATIONS_PER_THREAD {
                        let pid = (thread_id * OPERATIONS_PER_THREAD + i) as u32;
                        let record = TaskRecord::new(
                            Utc::now(),
                            format!("cmd_{}", pid),
                            format!("/tmp/log_{}.log", pid),
                            Some(pid),
                        );
                        storage.register(pid, &record).unwrap();
                    }
                });
                handles.push(handle);
            }

            for handle in handles {
                handle.join().unwrap();
            }
            let legacy_duration = start.elapsed();

            // Verify the results: every registration must be present.
            let dashmap_entries = dashmap_storage.entries().unwrap();
            let legacy_entries = legacy_storage.entries().unwrap();

            assert_eq!(dashmap_entries.len(), NUM_PIDS);
            assert_eq!(legacy_entries.len(), NUM_PIDS);

            println!("=== 并发性能测试结果 ===");
            println!("DashMap:   {:?} ({} 操作)", dashmap_duration, NUM_PIDS);
            println!("Mutex<HashMap>: {:?} ({} 操作)", legacy_duration, NUM_PIDS);

            if dashmap_duration < legacy_duration {
                let speedup =
                    legacy_duration.as_nanos() as f64 / dashmap_duration.as_nanos() as f64;
                println!("DashMap 速度提升: {:.2}x", speedup);
            }

            // DashMap should be faster, or at least not much slower.
            // NOTE(review): this asserts on wall-clock time, so it may fail
            // spuriously on a loaded machine — consider whether it belongs in
            // a benchmark rather than a unit test.
            assert!(
                dashmap_duration <= legacy_duration * 2,
                "DashMap performance regression detected"
            );
        }
    }
}