1use crate::{
2 config::{MAX_RECORD_AGE, SHARED_MEMORY_SIZE},
3 core::models::ProcessTreeInfo,
4 core::shared_map::open_or_create,
5 error::RegistryError,
6 logging::warn,
7 task_record::{TaskRecord, TaskStatus},
8};
9use chrono::{DateTime, Duration, Utc};
10use dashmap::DashMap;
11use parking_lot::Mutex;
12use shared_hashmap::SharedMemoryHashMap;
13use std::sync::Arc;
14
/// A decoded snapshot of one stored task, as returned by [`TaskStorage::entries`].
#[derive(Debug, Clone)]
pub struct RegistryEntry {
    /// Worker process id (parsed from the storage key for the shared-memory backend).
    pub pid: u32,
    /// Raw key the record is stored under — the pid rendered as a string.
    pub key: String,
    /// The task record payload itself.
    pub record: TaskRecord,
}
22
/// Emitted by [`TaskStorage::sweep_stale_entries`] for every record it cleans up.
#[derive(Debug, Clone)]
pub struct CleanupEvent {
    /// Pid of the affected worker. The underscore prefix silences dead-code
    /// warnings in builds that only consume `record`/`reason`; it is still
    /// part of the public API and read by the in-file tests.
    pub _pid: u32,
    /// The record at the moment of cleanup (its `cleanup_reason` field has
    /// already been stamped by the sweep).
    pub record: TaskRecord,
    /// Why the record was swept.
    pub reason: CleanupReason,
}
30
/// Why a task record was removed during a sweep.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CleanupReason {
    /// The worker process was no longer alive while its record still said `Running`.
    ProcessExited,
    /// The record exceeded the maximum allowed age; the worker was asked to terminate.
    Timeout,
    /// The managing process disappeared; the worker was asked to terminate.
    ManagerMissing,
}

impl CleanupReason {
    /// Stable string form persisted into `TaskRecord::cleanup_reason`.
    ///
    /// Centralizes the literals that were previously duplicated inline by
    /// both storage backends; the values must stay byte-identical because
    /// they are written into (shared-memory) storage and read back later.
    pub fn as_str(self) -> &'static str {
        match self {
            CleanupReason::ProcessExited => "process_exited",
            CleanupReason::Timeout => "timeout",
            CleanupReason::ManagerMissing => "manager_missing",
        }
    }
}
38
/// Storage backend for tracking spawned task processes, keyed by pid.
///
/// Implemented by [`InProcessStorage`] (concurrent map inside one process)
/// and [`SharedMemoryStorage`] (cross-process shared memory); callers must
/// be able to treat the two interchangeably.
pub trait TaskStorage: Send + Sync {
    /// Stores `record` under the worker `pid`.
    fn register(&self, pid: u32, record: &TaskRecord) -> Result<(), RegistryError>;

    /// Transitions the task for `pid` to `CompletedButUnread`, attaching the
    /// result payload, exit code, and completion timestamp.
    fn mark_completed(
        &self,
        pid: u32,
        result: Option<String>,
        exit_code: Option<i32>,
        completed_at: DateTime<Utc>,
    ) -> Result<(), RegistryError>;

    /// Returns a snapshot of every stored task.
    fn entries(&self) -> Result<Vec<RegistryEntry>, RegistryError>;

    /// Cleans up records for dead or overdue tasks, returning one
    /// [`CleanupEvent`] per affected record.
    ///
    /// `is_process_alive` probes whether a pid is still running;
    /// `terminate_process` requests a kill (implementations ignore its errors
    /// — termination is best-effort).
    fn sweep_stale_entries<F, G>(
        &self,
        now: DateTime<Utc>,
        is_process_alive: F,
        terminate_process: &G,
    ) -> Result<Vec<CleanupEvent>, RegistryError>
    where
        F: Fn(u32) -> bool,
        G: Fn(u32) -> Result<(), String>;

    /// Drains and returns every task currently in `CompletedButUnread` state.
    fn get_completed_unread_tasks(&self) -> Result<Vec<(u32, TaskRecord)>, RegistryError>;

    /// Reports whether any task is `Running`; with `filter`, only tasks whose
    /// process tree shares the filter's `root_parent_pid` are considered.
    fn has_running_tasks(&self, filter: Option<&ProcessTreeInfo>) -> Result<bool, RegistryError>;
}
74
/// Task storage kept entirely within the current process, backed by a
/// sharded concurrent map.
#[derive(Debug, Clone)]
pub struct InProcessStorage {
    // Arc-shared so clones of the storage observe the same task set.
    tasks: Arc<DashMap<u32, TaskRecord>>,
}
82
83impl InProcessStorage {
84 pub fn new() -> Self {
85 Self {
86 tasks: Arc::new(DashMap::new()),
87 }
88 }
89}
90
impl Default for InProcessStorage {
    // Same as `InProcessStorage::new()`: an empty task map.
    fn default() -> Self {
        Self::new()
    }
}
96
97impl TaskStorage for InProcessStorage {
98 fn register(&self, pid: u32, record: &TaskRecord) -> Result<(), RegistryError> {
99 self.tasks.insert(pid, record.clone());
100 Ok(())
101 }
102
103 fn mark_completed(
104 &self,
105 pid: u32,
106 result: Option<String>,
107 exit_code: Option<i32>,
108 completed_at: DateTime<Utc>,
109 ) -> Result<(), RegistryError> {
110 if let Some(mut record) = self.tasks.get_mut(&pid) {
111 record.status = TaskStatus::CompletedButUnread;
112 record.result = result;
113 record.exit_code = exit_code;
114 record.completed_at = Some(completed_at);
115 } else {
116 return Err(RegistryError::TaskNotFound(pid));
117 }
118 Ok(())
119 }
120
121 fn entries(&self) -> Result<Vec<RegistryEntry>, RegistryError> {
122 Ok(self
123 .tasks
124 .iter()
125 .map(|entry| RegistryEntry {
126 pid: *entry.key(),
127 key: entry.key().to_string(),
128 record: entry.value().clone(),
129 })
130 .collect())
131 }
132
133 fn sweep_stale_entries<F, G>(
134 &self,
135 now: DateTime<Utc>,
136 is_process_alive: F,
137 terminate_process: &G,
138 ) -> Result<Vec<CleanupEvent>, RegistryError>
139 where
140 F: Fn(u32) -> bool,
141 G: Fn(u32) -> Result<(), String>,
142 {
143 const MAX_AGE_HOURS: i64 = 12;
144 let max_age = Duration::hours(MAX_AGE_HOURS);
145
146 let mut cleanup_events = Vec::new();
147
148 let pids_to_cleanup: Vec<(u32, CleanupReason)> = self
149 .tasks
150 .iter()
151 .filter_map(|entry| {
152 let pid = *entry.key();
153 let record = entry.value();
154
155 if !is_process_alive(pid) {
157 if record.status == TaskStatus::Running {
159 return Some((pid, CleanupReason::ProcessExited));
160 }
161 }
162
163 let age = now.signed_duration_since(record.started_at);
165 if age > max_age {
166 if record.status == TaskStatus::Running && is_process_alive(pid) {
167 let _ = terminate_process(pid);
169 return Some((pid, CleanupReason::Timeout));
170 }
171 }
172
173 None
174 })
175 .collect();
176
177 for (pid, cleanup_reason) in pids_to_cleanup {
178 if let Some(mut record) = self.tasks.get_mut(&pid) {
179 record.status = TaskStatus::CompletedButUnread;
180 record.completed_at = Some(now);
181 record.cleanup_reason = Some(
182 match cleanup_reason {
183 CleanupReason::ProcessExited => "process_exited",
184 CleanupReason::Timeout => "timeout",
185 CleanupReason::ManagerMissing => "manager_missing",
186 }
187 .to_string(),
188 );
189
190 cleanup_events.push(CleanupEvent {
191 _pid: pid,
192 record: record.clone(),
193 reason: cleanup_reason,
194 });
195 }
196 }
197
198 Ok(cleanup_events)
199 }
200
201 fn get_completed_unread_tasks(&self) -> Result<Vec<(u32, TaskRecord)>, RegistryError> {
202 let completed: Vec<(u32, TaskRecord)> = self
203 .tasks
204 .iter()
205 .filter_map(|entry| {
206 let pid = *entry.key();
207 let record = entry.value();
208
209 if record.status == TaskStatus::CompletedButUnread {
210 Some((pid, record.clone()))
212 } else {
213 None
214 }
215 })
216 .collect();
217
218 for (pid, _) in &completed {
220 self.tasks.remove(pid);
221 }
222
223 Ok(completed)
224 }
225
226 fn has_running_tasks(&self, filter: Option<&ProcessTreeInfo>) -> Result<bool, RegistryError> {
227 if let Some(tree_filter) = filter {
228 Ok(self.tasks.iter().any(|entry| {
229 let record = entry.value();
230 record.status == TaskStatus::Running
231 && record
232 .process_tree
233 .as_ref()
234 .map(|tree| tree.root_parent_pid == tree_filter.root_parent_pid)
235 .unwrap_or(false)
236 }))
237 } else {
238 Ok(self
239 .tasks
240 .iter()
241 .any(|entry| entry.value().status == TaskStatus::Running))
242 }
243 }
244}
245
/// Task storage backed by a named shared-memory segment so that separate
/// processes can observe the same task registry.
#[derive(Debug, Clone)]
pub struct SharedMemoryStorage {
    // OS id of the shared-memory segment (e.g. "<pid>_task").
    namespace: String,
    // JSON-serialized records keyed by stringified pid; the Mutex serializes
    // access from within this process.
    map: Arc<Mutex<SharedMemoryHashMap<String, String>>>,
}
253
254impl SharedMemoryStorage {
255 pub fn connect() -> Result<Self, RegistryError> {
258 let pid = std::process::id();
259 Self::connect_for_pid(pid)
260 }
261
262 pub fn connect_for_pid(pid: u32) -> Result<Self, RegistryError> {
265 let namespace = format!("{}_task", pid);
266 Self::connect_with_namespace(namespace)
267 }
268
269 pub fn connect_with_namespace(namespace: String) -> Result<Self, RegistryError> {
271 let map = open_or_create(&namespace, SHARED_MEMORY_SIZE)?;
272 Ok(Self {
273 namespace,
274 map: Arc::new(Mutex::new(map)),
275 })
276 }
277
278 pub fn cleanup(&self) -> Result<(), RegistryError> {
280 use shared_memory::ShmemConf;
281
282 if let Ok(mut shmem) = ShmemConf::new()
284 .os_id(&self.namespace)
285 .size(SHARED_MEMORY_SIZE)
286 .open()
287 {
288 let _ = shmem.set_owner(true);
289 }
290
291 Ok(())
292 }
293
294 fn with_map<T>(
295 &self,
296 f: impl FnOnce(&mut SharedMemoryHashMap<String, String>) -> Result<T, RegistryError>,
297 ) -> Result<T, RegistryError> {
298 let mut guard = self.map.lock();
299 f(&mut guard)
300 }
301
302 fn remove_keys(&self, keys: &[String]) -> Result<(), RegistryError> {
303 if keys.is_empty() {
304 return Ok(());
305 }
306 self.with_map(|map| {
307 for key in keys {
308 map.remove(key);
309 }
310 Ok(())
311 })
312 }
313}
314
315impl TaskStorage for SharedMemoryStorage {
316 fn register(&self, pid: u32, record: &TaskRecord) -> Result<(), RegistryError> {
317 let key = pid.to_string();
318 let value = serde_json::to_string(record)?;
319 self.with_map(|map| {
320 map.try_insert(key.clone(), value)?;
321 Ok(())
322 })
323 }
324
325 fn mark_completed(
326 &self,
327 pid: u32,
328 result: Option<String>,
329 exit_code: Option<i32>,
330 completed_at: DateTime<Utc>,
331 ) -> Result<(), RegistryError> {
332 let key = pid.to_string();
333 self.with_map(move |map| {
334 let existing = map
335 .get(&key)
336 .ok_or_else(|| RegistryError::Map(format!("no task found for pid {pid}")))?;
337 let record: TaskRecord = serde_json::from_str(&existing)?;
338 let updated_record = record.mark_completed(result, exit_code, completed_at);
339 let updated_value = serde_json::to_string(&updated_record)?;
340 let _ = map.insert(key.clone(), updated_value);
341 Ok(())
342 })
343 }
344
345 fn entries(&self) -> Result<Vec<RegistryEntry>, RegistryError> {
346 let snapshot: Vec<(String, String)> = {
347 let guard = self.map.lock();
348 guard.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
349 };
350
351 let mut entries = Vec::new();
352 let mut invalid_keys = Vec::new();
353
354 for (key, value) in snapshot {
355 match key.parse::<u32>() {
356 Ok(pid) => match serde_json::from_str::<TaskRecord>(&value) {
357 Ok(record) => entries.push(RegistryEntry {
358 pid,
359 key: key.clone(),
360 record,
361 }),
362 Err(err) => {
363 warn(format!("failed to parse task record pid={key}: {err}"));
364 invalid_keys.push(key);
365 }
366 },
367 Err(_) => {
368 warn(format!("detected invalid pid key: {key}"));
369 invalid_keys.push(key);
370 }
371 }
372 }
373
374 if !invalid_keys.is_empty() {
375 self.remove_keys(&invalid_keys)?;
376 }
377
378 Ok(entries)
379 }
380
381 fn sweep_stale_entries<F, G>(
382 &self,
383 now: DateTime<Utc>,
384 is_process_alive: F,
385 terminate_process: &G,
386 ) -> Result<Vec<CleanupEvent>, RegistryError>
387 where
388 F: Fn(u32) -> bool,
389 G: Fn(u32) -> Result<(), String>,
390 {
391 let entries = self.entries()?;
392 let mut removals = Vec::new();
393 let mut events = Vec::new();
394
395 for mut entry in entries {
396 let mut should_cleanup = false;
397 let mut cleanup_reason = CleanupReason::ProcessExited;
398
399 if !is_process_alive(entry.pid) {
401 should_cleanup = true;
402 cleanup_reason = CleanupReason::ProcessExited;
403 } else {
404 if let Some(_manager_pid) = entry.record.manager_pid.filter(|&manager_pid| {
406 manager_pid != entry.pid && !is_process_alive(manager_pid)
407 }) {
408 let _ = terminate_process(entry.pid);
409 should_cleanup = true;
410 cleanup_reason = CleanupReason::ManagerMissing;
411 }
412
413 if !should_cleanup {
415 let age = now.signed_duration_since(entry.record.started_at);
416 let max_age = Duration::from_std(MAX_RECORD_AGE).unwrap_or(Duration::zero());
417 if age > max_age {
418 let _ = terminate_process(entry.pid);
419 should_cleanup = true;
420 cleanup_reason = CleanupReason::Timeout;
421 }
422 }
423 }
424
425 if should_cleanup {
426 removals.push(entry.pid.to_string());
427
428 entry.record.cleanup_reason = Some(
430 match cleanup_reason {
431 CleanupReason::ProcessExited => "process_exited",
432 CleanupReason::Timeout => "timeout",
433 CleanupReason::ManagerMissing => "manager_missing",
434 }
435 .to_string(),
436 );
437
438 events.push(CleanupEvent {
439 _pid: entry.pid,
440 record: entry.record,
441 reason: cleanup_reason,
442 });
443 }
444 }
445
446 if !removals.is_empty() {
447 self.remove_keys(&removals)?;
448 }
449
450 Ok(events)
451 }
452
453 fn get_completed_unread_tasks(&self) -> Result<Vec<(u32, TaskRecord)>, RegistryError> {
454 let entries = self.entries()?;
455 let mut completed_pids = Vec::new();
456
457 for entry in &entries {
458 if entry.record.status == TaskStatus::CompletedButUnread {
459 completed_pids.push(entry.pid);
460 }
461 }
462
463 for pid in &completed_pids {
465 let key = pid.to_string();
466 let _ = self.with_map(|map| {
467 map.remove(&key);
468 Ok::<(), RegistryError>(())
469 });
470 }
471
472 let completed_tasks: Vec<(u32, TaskRecord)> = entries
474 .into_iter()
475 .filter(|entry| entry.record.status == TaskStatus::CompletedButUnread)
476 .map(|entry| (entry.pid, entry.record))
477 .collect();
478
479 Ok(completed_tasks)
480 }
481
482 fn has_running_tasks(&self, filter: Option<&ProcessTreeInfo>) -> Result<bool, RegistryError> {
483 let entries = self.entries()?;
484
485 if let Some(tree_filter) = filter {
486 Ok(entries.iter().any(|entry| {
487 entry.record.status == TaskStatus::Running
488 && entry
489 .record
490 .process_tree
491 .as_ref()
492 .map(|tree| tree.root_parent_pid == tree_filter.root_parent_pid)
493 .unwrap_or(false)
494 }))
495 } else {
496 Ok(entries
497 .iter()
498 .any(|entry| entry.record.status == TaskStatus::Running))
499 }
500 }
501}
502
503#[cfg(test)]
504mod tests {
505 use super::*;
506
507 #[test]
508 fn test_in_process_storage_register() {
509 let storage = InProcessStorage::new();
510 let record = TaskRecord::new(
511 Utc::now(),
512 "123".to_string(),
513 "/tmp/test.log".to_string(),
514 Some(100),
515 );
516
517 assert!(storage.register(123, &record).is_ok());
518 let entries = storage.entries().unwrap();
519 assert_eq!(entries.len(), 1);
520 assert_eq!(entries[0].pid, 123);
521 }
522
523 #[test]
524 fn test_in_process_storage_mark_completed() {
525 let storage = InProcessStorage::new();
526 let record = TaskRecord::new(
527 Utc::now(),
528 "456".to_string(),
529 "/tmp/test.log".to_string(),
530 Some(100),
531 );
532
533 storage.register(456, &record).unwrap();
534 storage
535 .mark_completed(456, Some("success".to_string()), Some(0), Utc::now())
536 .unwrap();
537
538 let completed = storage.get_completed_unread_tasks().unwrap();
539 assert_eq!(completed.len(), 1);
540 assert_eq!(completed[0].0, 456);
541 assert_eq!(completed[0].1.result, Some("success".to_string()));
542 }
543
544 #[test]
545 fn test_in_process_storage_sweep_stale() {
546 let storage = InProcessStorage::new();
547 let old_time = Utc::now() - Duration::hours(13);
548 let record = TaskRecord::new(
549 old_time,
550 "789".to_string(),
551 "/tmp/test.log".to_string(),
552 Some(100),
553 );
554
555 storage.register(789, &record).unwrap();
556
557 let is_alive = |_: u32| false;
558 let terminate = |_: u32| Ok(());
559
560 let events = storage
561 .sweep_stale_entries(Utc::now(), is_alive, &terminate)
562 .unwrap();
563
564 assert_eq!(events.len(), 1);
565 assert_eq!(events[0]._pid, 789);
566 }
567
    // Benchmarks the DashMap-backed storage against a coarse-lock baseline.
    // NOTE(review): this `#[cfg(test)]` is redundant — the enclosing `tests`
    // module is already test-gated — but it is harmless.
    #[cfg(test)]
    mod concurrency_tests {
        use super::*;
        use std::collections::HashMap;
        use std::sync::{Arc, Mutex as StdMutex};
        use std::thread;
        use std::time::Instant;

        // Baseline implementation: one global Mutex around a HashMap, so all
        // writer threads serialize on a single lock. Exists only for the
        // performance comparison below.
        #[derive(Debug)]
        struct LegacyInProcessStorage {
            tasks: Arc<StdMutex<HashMap<u32, TaskRecord>>>,
        }

        impl LegacyInProcessStorage {
            fn new() -> Self {
                Self {
                    tasks: Arc::new(StdMutex::new(HashMap::new())),
                }
            }
        }

        impl TaskStorage for LegacyInProcessStorage {
            // Only `register` and `entries` are really exercised by the
            // benchmark; the remaining trait methods are inert stubs.
            fn register(&self, pid: u32, record: &TaskRecord) -> Result<(), RegistryError> {
                let mut tasks = self.tasks.lock().unwrap();
                tasks.insert(pid, record.clone());
                Ok(())
            }

            fn entries(&self) -> Result<Vec<RegistryEntry>, RegistryError> {
                let tasks = self.tasks.lock().unwrap();
                Ok(tasks
                    .iter()
                    .map(|(&pid, record)| RegistryEntry {
                        pid,
                        key: pid.to_string(),
                        record: record.clone(),
                    })
                    .collect())
            }

            // Stub: completion is not part of the benchmark.
            fn mark_completed(
                &self,
                _pid: u32,
                _result: Option<String>,
                _exit_code: Option<i32>,
                _completed_at: DateTime<Utc>,
            ) -> Result<(), RegistryError> {
                Ok(())
            }

            // Stub: sweeping is not part of the benchmark.
            fn sweep_stale_entries<F, G>(
                &self,
                _now: DateTime<Utc>,
                _is_process_alive: F,
                _terminate_process: &G,
            ) -> Result<Vec<CleanupEvent>, RegistryError> {
                Ok(Vec::new())
            }

            // Stub: draining is not part of the benchmark.
            fn get_completed_unread_tasks(&self) -> Result<Vec<(u32, TaskRecord)>, RegistryError> {
                Ok(Vec::new())
            }

            // Stub: running-task queries are not part of the benchmark.
            fn has_running_tasks(
                &self,
                _filter: Option<&ProcessTreeInfo>,
            ) -> Result<bool, RegistryError> {
                Ok(false)
            }
        }

        // Spawns NUM_THREADS writers against each backend, times the total
        // registration throughput, and verifies every record landed.
        #[test]
        fn test_concurrent_performance_comparison() {
            const NUM_THREADS: usize = 8;
            const OPERATIONS_PER_THREAD: usize = 1000;
            const NUM_PIDS: usize = NUM_THREADS * OPERATIONS_PER_THREAD;

            let dashmap_storage = Arc::new(InProcessStorage::new());
            let legacy_storage = Arc::new(LegacyInProcessStorage::new());

            // Round 1: DashMap-backed storage.
            let start = Instant::now();
            let mut handles = Vec::new();

            for thread_id in 0..NUM_THREADS {
                let storage = Arc::clone(&dashmap_storage);
                let handle = thread::spawn(move || {
                    for i in 0..OPERATIONS_PER_THREAD {
                        // Disjoint pid ranges per thread, so no key collisions.
                        let pid = (thread_id * OPERATIONS_PER_THREAD + i) as u32;
                        let record = TaskRecord::new(
                            Utc::now(),
                            format!("cmd_{}", pid),
                            format!("/tmp/log_{}.log", pid),
                            Some(pid),
                        );
                        storage.register(pid, &record).unwrap();
                    }
                });
                handles.push(handle);
            }

            for handle in handles {
                handle.join().unwrap();
            }
            let dashmap_duration = start.elapsed();

            // Round 2: same workload against the Mutex<HashMap> baseline.
            let start = Instant::now();
            let mut handles = Vec::new();

            for thread_id in 0..NUM_THREADS {
                let storage = Arc::clone(&legacy_storage);
                let handle = thread::spawn(move || {
                    for i in 0..OPERATIONS_PER_THREAD {
                        let pid = (thread_id * OPERATIONS_PER_THREAD + i) as u32;
                        let record = TaskRecord::new(
                            Utc::now(),
                            format!("cmd_{}", pid),
                            format!("/tmp/log_{}.log", pid),
                            Some(pid),
                        );
                        storage.register(pid, &record).unwrap();
                    }
                });
                handles.push(handle);
            }

            for handle in handles {
                handle.join().unwrap();
            }
            let legacy_duration = start.elapsed();

            // Correctness check: every registration must be visible.
            let dashmap_entries = dashmap_storage.entries().unwrap();
            let legacy_entries = legacy_storage.entries().unwrap();

            assert_eq!(dashmap_entries.len(), NUM_PIDS);
            assert_eq!(legacy_entries.len(), NUM_PIDS);

            println!("=== 并发性能测试结果 ===");
            println!("DashMap: {:?} ({} 操作)", dashmap_duration, NUM_PIDS);
            println!("Mutex<HashMap>: {:?} ({} 操作)", legacy_duration, NUM_PIDS);

            if dashmap_duration < legacy_duration {
                let speedup =
                    legacy_duration.as_nanos() as f64 / dashmap_duration.as_nanos() as f64;
                println!("DashMap 速度提升: {:.2}x", speedup);
            }

            // NOTE(review): timing-based assertion — potentially flaky on a
            // loaded CI machine; the 2x margin is the mitigation.
            assert!(
                dashmap_duration <= legacy_duration * 2,
                "DashMap performance regression detected"
            );
        }
    }
726}