1use crate::router::PatternMatcher;
30use serde::{Deserialize, Serialize};
31use std::collections::{HashMap, HashSet};
32use std::sync::{Arc, RwLock};
33use std::time::{Duration, SystemTime, UNIX_EPOCH};
34use uuid::Uuid;
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
38#[serde(rename_all = "lowercase")]
39pub enum RevocationMode {
40 #[default]
42 Terminate,
43 Ignore,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct RevocationRequest {
50 pub task_id: Uuid,
52 pub mode: RevocationMode,
54 pub timestamp: f64,
56 pub expires: Option<f64>,
58 pub reason: Option<String>,
60 pub signal: Option<String>,
62}
63
64impl RevocationRequest {
65 #[must_use]
67 pub fn new(task_id: Uuid, mode: RevocationMode) -> Self {
68 Self {
69 task_id,
70 mode,
71 timestamp: current_timestamp(),
72 expires: None,
73 reason: None,
74 signal: None,
75 }
76 }
77
78 #[must_use]
80 pub fn with_expiration(mut self, expires_in: Duration) -> Self {
81 self.expires = Some(current_timestamp() + expires_in.as_secs_f64());
82 self
83 }
84
85 #[must_use]
87 pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
88 self.reason = Some(reason.into());
89 self
90 }
91
92 #[must_use]
94 pub fn with_signal(mut self, signal: impl Into<String>) -> Self {
95 self.signal = Some(signal.into());
96 self
97 }
98
99 #[inline]
101 #[must_use]
102 pub fn is_expired(&self) -> bool {
103 if let Some(expires) = self.expires {
104 current_timestamp() > expires
105 } else {
106 false
107 }
108 }
109}
110
111#[derive(Debug, Clone)]
113pub struct PatternRevocation {
114 pub pattern: PatternMatcher,
116 pub mode: RevocationMode,
118 pub timestamp: f64,
120 pub expires: Option<f64>,
122 pub reason: Option<String>,
124}
125
126impl PatternRevocation {
127 #[must_use]
129 pub fn new(pattern: PatternMatcher, mode: RevocationMode) -> Self {
130 Self {
131 pattern,
132 mode,
133 timestamp: current_timestamp(),
134 expires: None,
135 reason: None,
136 }
137 }
138
139 #[must_use]
141 pub fn with_expiration(mut self, expires_in: Duration) -> Self {
142 self.expires = Some(current_timestamp() + expires_in.as_secs_f64());
143 self
144 }
145
146 #[must_use]
148 pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
149 self.reason = Some(reason.into());
150 self
151 }
152
153 #[inline]
155 #[must_use]
156 pub fn is_expired(&self) -> bool {
157 if let Some(expires) = self.expires {
158 current_timestamp() > expires
159 } else {
160 false
161 }
162 }
163
164 #[inline]
166 #[must_use]
167 pub fn matches(&self, task_name: &str) -> bool {
168 self.pattern.matches(task_name)
169 }
170}
171
172#[derive(Debug, Clone)]
174pub struct RevocationResult {
175 pub revoked: bool,
177 pub mode: RevocationMode,
179 pub reason: Option<String>,
181 pub signal: Option<String>,
183}
184
185impl RevocationResult {
186 #[must_use]
188 pub fn not_revoked() -> Self {
189 Self {
190 revoked: false,
191 mode: RevocationMode::Ignore,
192 reason: None,
193 signal: None,
194 }
195 }
196
197 #[must_use]
199 pub fn revoked(mode: RevocationMode, reason: Option<String>, signal: Option<String>) -> Self {
200 Self {
201 revoked: true,
202 mode,
203 reason,
204 signal,
205 }
206 }
207}
208
209#[derive(Debug, Clone, Serialize, Deserialize, Default)]
211pub struct RevocationState {
212 pub revoked_tasks: HashMap<String, RevocationRequest>,
214 pub pattern_revocations: Vec<SerializablePatternRevocation>,
216}
217
218#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct SerializablePatternRevocation {
221 pub pattern: String,
223 pub mode: RevocationMode,
225 pub timestamp: f64,
227 pub expires: Option<f64>,
229 pub reason: Option<String>,
231}
232
233impl From<&PatternRevocation> for SerializablePatternRevocation {
234 fn from(rev: &PatternRevocation) -> Self {
235 let pattern = match &rev.pattern {
237 PatternMatcher::Exact(s) => s.clone(),
238 PatternMatcher::Glob(g) => g.pattern().to_string(),
239 PatternMatcher::Regex(r) => r.pattern().to_string(),
240 PatternMatcher::All => "*".to_string(),
241 };
242 Self {
243 pattern,
244 mode: rev.mode,
245 timestamp: rev.timestamp,
246 expires: rev.expires,
247 reason: rev.reason.clone(),
248 }
249 }
250}
251
252impl SerializablePatternRevocation {
253 #[must_use]
255 pub fn into_pattern_revocation(self) -> PatternRevocation {
256 PatternRevocation {
257 pattern: PatternMatcher::glob(&self.pattern),
258 mode: self.mode,
259 timestamp: self.timestamp,
260 expires: self.expires,
261 reason: self.reason,
262 }
263 }
264}
265
266#[derive(Debug, Default)]
268pub struct RevocationManager {
269 revoked_ids: HashMap<Uuid, RevocationRequest>,
271 pattern_revocations: Vec<PatternRevocation>,
273 terminated: HashSet<Uuid>,
275}
276
277impl RevocationManager {
278 #[must_use]
280 pub fn new() -> Self {
281 Self::default()
282 }
283
284 pub fn revoke(&mut self, task_id: Uuid, mode: RevocationMode) {
286 let request = RevocationRequest::new(task_id, mode);
287 self.revoked_ids.insert(task_id, request);
288 }
289
290 pub fn revoke_with_request(&mut self, request: RevocationRequest) {
292 self.revoked_ids.insert(request.task_id, request);
293 }
294
295 pub fn revoke_by_pattern(&mut self, pattern: &str, mode: RevocationMode) {
297 let pattern_rev = PatternRevocation::new(PatternMatcher::glob(pattern), mode);
298 self.pattern_revocations.push(pattern_rev);
299 }
300
301 pub fn revoke_with_pattern(&mut self, revocation: PatternRevocation) {
303 self.pattern_revocations.push(revocation);
304 }
305
306 pub fn bulk_revoke(&mut self, task_ids: &[Uuid], mode: RevocationMode) {
308 for &task_id in task_ids {
309 self.revoke(task_id, mode);
310 }
311 }
312
313 #[inline]
315 #[must_use]
316 pub fn is_revoked(&self, task_id: Uuid) -> bool {
317 if let Some(request) = self.revoked_ids.get(&task_id) {
318 !request.is_expired()
319 } else {
320 false
321 }
322 }
323
324 #[must_use]
326 pub fn check_revocation(&self, task_id: Uuid, task_name: &str) -> RevocationResult {
327 if let Some(request) = self.revoked_ids.get(&task_id) {
329 if !request.is_expired() {
330 return RevocationResult::revoked(
331 request.mode,
332 request.reason.clone(),
333 request.signal.clone(),
334 );
335 }
336 }
337
338 for pattern_rev in &self.pattern_revocations {
340 if !pattern_rev.is_expired() && pattern_rev.matches(task_name) {
341 return RevocationResult::revoked(
342 pattern_rev.mode,
343 pattern_rev.reason.clone(),
344 None,
345 );
346 }
347 }
348
349 RevocationResult::not_revoked()
350 }
351
352 pub fn mark_terminated(&mut self, task_id: Uuid) {
354 self.terminated.insert(task_id);
355 }
356
357 #[inline]
359 #[must_use]
360 pub fn is_terminated(&self, task_id: Uuid) -> bool {
361 self.terminated.contains(&task_id)
362 }
363
364 pub fn unrevoke(&mut self, task_id: Uuid) {
366 self.revoked_ids.remove(&task_id);
367 }
368
369 pub fn remove_pattern(&mut self, pattern: &str) {
371 self.pattern_revocations.retain(|p| {
372 if let PatternMatcher::Glob(g) = &p.pattern {
373 g.pattern() != pattern
374 } else {
375 true
376 }
377 });
378 }
379
380 pub fn cleanup_expired(&mut self) {
382 self.revoked_ids.retain(|_, request| !request.is_expired());
383 self.pattern_revocations.retain(|rev| !rev.is_expired());
384 }
385
386 #[must_use]
388 pub fn revoked_ids(&self) -> Vec<Uuid> {
389 self.revoked_ids
390 .iter()
391 .filter(|(_, request)| !request.is_expired())
392 .map(|(id, _)| *id)
393 .collect()
394 }
395
396 #[inline]
398 #[must_use]
399 pub fn revoked_count(&self) -> usize {
400 self.revoked_ids
401 .values()
402 .filter(|request| !request.is_expired())
403 .count()
404 }
405
406 pub fn clear(&mut self) {
408 self.revoked_ids.clear();
409 self.pattern_revocations.clear();
410 self.terminated.clear();
411 }
412
413 pub fn export_state(&self) -> RevocationState {
415 let revoked_tasks = self
416 .revoked_ids
417 .iter()
418 .filter(|(_, req)| !req.is_expired())
419 .map(|(id, req)| (id.to_string(), req.clone()))
420 .collect();
421
422 let pattern_revocations = self
423 .pattern_revocations
424 .iter()
425 .filter(|rev| !rev.is_expired())
426 .map(SerializablePatternRevocation::from)
427 .collect();
428
429 RevocationState {
430 revoked_tasks,
431 pattern_revocations,
432 }
433 }
434
435 pub fn import_state(&mut self, state: RevocationState) {
437 for (id_str, request) in state.revoked_tasks {
438 if !request.is_expired() {
439 if let Ok(id) = Uuid::parse_str(&id_str) {
440 self.revoked_ids.insert(id, request);
441 }
442 }
443 }
444
445 for ser_pattern in state.pattern_revocations {
446 let pattern_rev = ser_pattern.into_pattern_revocation();
447 if !pattern_rev.is_expired() {
448 self.pattern_revocations.push(pattern_rev);
449 }
450 }
451 }
452}
453
454#[derive(Debug, Clone, Default)]
456pub struct WorkerRevocationManager {
457 inner: Arc<RwLock<RevocationManager>>,
458}
459
460impl WorkerRevocationManager {
461 #[must_use]
463 pub fn new() -> Self {
464 Self::default()
465 }
466
467 pub fn revoke(&self, task_id: Uuid, mode: RevocationMode) {
469 if let Ok(mut guard) = self.inner.write() {
470 guard.revoke(task_id, mode);
471 }
472 }
473
474 pub fn revoke_with_request(&self, request: RevocationRequest) {
476 if let Ok(mut guard) = self.inner.write() {
477 guard.revoke_with_request(request);
478 }
479 }
480
481 pub fn revoke_by_pattern(&self, pattern: &str, mode: RevocationMode) {
483 if let Ok(mut guard) = self.inner.write() {
484 guard.revoke_by_pattern(pattern, mode);
485 }
486 }
487
488 pub fn bulk_revoke(&self, task_ids: &[Uuid], mode: RevocationMode) {
490 if let Ok(mut guard) = self.inner.write() {
491 guard.bulk_revoke(task_ids, mode);
492 }
493 }
494
495 #[must_use]
497 pub fn is_revoked(&self, task_id: Uuid) -> bool {
498 if let Ok(guard) = self.inner.read() {
499 guard.is_revoked(task_id)
500 } else {
501 false
502 }
503 }
504
505 #[must_use]
507 pub fn check_revocation(&self, task_id: Uuid, task_name: &str) -> RevocationResult {
508 if let Ok(guard) = self.inner.read() {
509 guard.check_revocation(task_id, task_name)
510 } else {
511 RevocationResult::not_revoked()
512 }
513 }
514
515 pub fn mark_terminated(&self, task_id: Uuid) {
517 if let Ok(mut guard) = self.inner.write() {
518 guard.mark_terminated(task_id);
519 }
520 }
521
522 #[must_use]
524 pub fn is_terminated(&self, task_id: Uuid) -> bool {
525 if let Ok(guard) = self.inner.read() {
526 guard.is_terminated(task_id)
527 } else {
528 false
529 }
530 }
531
532 pub fn unrevoke(&self, task_id: Uuid) {
534 if let Ok(mut guard) = self.inner.write() {
535 guard.unrevoke(task_id);
536 }
537 }
538
539 pub fn cleanup_expired(&self) {
541 if let Ok(mut guard) = self.inner.write() {
542 guard.cleanup_expired();
543 }
544 }
545
546 #[must_use]
548 pub fn revoked_ids(&self) -> Vec<Uuid> {
549 if let Ok(guard) = self.inner.read() {
550 guard.revoked_ids()
551 } else {
552 Vec::new()
553 }
554 }
555
556 #[must_use]
558 pub fn revoked_count(&self) -> usize {
559 if let Ok(guard) = self.inner.read() {
560 guard.revoked_count()
561 } else {
562 0
563 }
564 }
565
566 #[must_use]
568 pub fn export_state(&self) -> RevocationState {
569 if let Ok(guard) = self.inner.read() {
570 guard.export_state()
571 } else {
572 RevocationState::default()
573 }
574 }
575
576 pub fn import_state(&self, state: RevocationState) {
578 if let Ok(mut guard) = self.inner.write() {
579 guard.import_state(state);
580 }
581 }
582
583 pub fn clear(&self) {
585 if let Ok(mut guard) = self.inner.write() {
586 guard.clear();
587 }
588 }
589}
590
/// Returns the current wall-clock time as fractional seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0.0` instead of an error.
fn current_timestamp() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs_f64(),
        Err(_) => 0.0,
    }
}
598
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a task revocation whose expiry is already one second in the past, so
    /// expiration can be exercised deterministically without sleeping.
    fn expired_request(task_id: Uuid) -> RevocationRequest {
        let mut request = RevocationRequest::new(task_id, RevocationMode::Terminate);
        request.expires = Some(request.timestamp - 1.0);
        request
    }

    #[test]
    fn test_revocation_request() {
        let task_id = Uuid::new_v4();
        let request = RevocationRequest::new(task_id, RevocationMode::Terminate);

        assert_eq!(request.task_id, task_id);
        assert_eq!(request.mode, RevocationMode::Terminate);
        assert!(request.expires.is_none());
        assert!(!request.is_expired());
    }

    #[test]
    fn test_revocation_request_with_expiration() {
        let task_id = Uuid::new_v4();
        let request = RevocationRequest::new(task_id, RevocationMode::Terminate)
            .with_expiration(Duration::from_secs(3600));

        assert!(matches!(request.expires, Some(expires) if expires > request.timestamp));
        assert!(!request.is_expired());
        assert!(expired_request(task_id).is_expired());
    }

    #[test]
    fn test_revocation_manager_basic() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        manager.revoke(task_id, RevocationMode::Terminate);
        assert!(manager.is_revoked(task_id));

        let other_id = Uuid::new_v4();
        assert!(!manager.is_revoked(other_id));
    }

    #[test]
    fn test_revocation_by_pattern() {
        let mut manager = RevocationManager::new();

        manager.revoke_by_pattern("email.*", RevocationMode::Ignore);

        let result = manager.check_revocation(Uuid::new_v4(), "email.send");
        assert!(result.revoked);
        assert_eq!(result.mode, RevocationMode::Ignore);

        let result = manager.check_revocation(Uuid::new_v4(), "sms.send");
        assert!(!result.revoked);
    }

    #[test]
    fn test_id_revocation_takes_precedence_over_pattern() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        manager.revoke_by_pattern("email.*", RevocationMode::Ignore);
        manager.revoke_with_request(
            RevocationRequest::new(task_id, RevocationMode::Terminate).with_signal("SIGKILL"),
        );

        let result = manager.check_revocation(task_id, "email.send");
        assert!(result.revoked);
        assert_eq!(result.mode, RevocationMode::Terminate);
        assert_eq!(result.signal.as_deref(), Some("SIGKILL"));
    }

    #[test]
    fn test_expired_pattern_is_ignored() {
        let mut manager = RevocationManager::new();
        let mut revocation =
            PatternRevocation::new(PatternMatcher::glob("email.*"), RevocationMode::Ignore);
        revocation.expires = Some(revocation.timestamp - 1.0);
        manager.revoke_with_pattern(revocation);

        assert!(!manager.check_revocation(Uuid::new_v4(), "email.send").revoked);
    }

    #[test]
    fn test_remove_pattern() {
        let mut manager = RevocationManager::new();

        manager.revoke_by_pattern("email.*", RevocationMode::Ignore);
        manager.remove_pattern("email.*");

        assert!(!manager.check_revocation(Uuid::new_v4(), "email.send").revoked);
    }

    #[test]
    fn test_bulk_revoke() {
        let mut manager = RevocationManager::new();
        let ids: Vec<Uuid> = (0..5).map(|_| Uuid::new_v4()).collect();

        manager.bulk_revoke(&ids, RevocationMode::Terminate);

        for id in &ids {
            assert!(manager.is_revoked(*id));
        }
    }

    #[test]
    fn test_unrevoke() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        manager.revoke(task_id, RevocationMode::Terminate);
        assert!(manager.is_revoked(task_id));

        manager.unrevoke(task_id);
        assert!(!manager.is_revoked(task_id));
    }

    #[test]
    fn test_cleanup_expired() {
        let mut manager = RevocationManager::new();
        let expired_id = Uuid::new_v4();
        let active_id = Uuid::new_v4();

        manager.revoke_with_request(expired_request(expired_id));
        manager.revoke(active_id, RevocationMode::Terminate);

        manager.cleanup_expired();

        assert!(!manager.is_revoked(expired_id));
        assert!(manager.is_revoked(active_id));
        assert_eq!(manager.revoked_ids(), vec![active_id]);
    }

    #[test]
    fn test_export_import_state() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        manager.revoke(task_id, RevocationMode::Terminate);
        manager.revoke_by_pattern("email.*", RevocationMode::Ignore);

        let state = manager.export_state();

        let mut new_manager = RevocationManager::new();
        new_manager.import_state(state);

        assert!(new_manager.is_revoked(task_id));
        let result = new_manager.check_revocation(Uuid::new_v4(), "email.send");
        assert!(result.revoked);
    }

    #[test]
    fn test_export_skips_expired() {
        let mut manager = RevocationManager::new();
        manager.revoke_with_request(expired_request(Uuid::new_v4()));

        let state = manager.export_state();

        assert!(state.revoked_tasks.is_empty());
        assert!(state.pattern_revocations.is_empty());
    }

    #[test]
    fn test_worker_revocation_manager() {
        let manager = WorkerRevocationManager::new();
        let task_id = Uuid::new_v4();

        manager.revoke(task_id, RevocationMode::Terminate);
        assert!(manager.is_revoked(task_id));

        manager.mark_terminated(task_id);
        assert!(manager.is_terminated(task_id));

        // Clones are handles onto the same shared state.
        let handle = manager.clone();
        let other_id = Uuid::new_v4();
        handle.revoke(other_id, RevocationMode::Ignore);
        assert!(manager.is_revoked(other_id));
        assert_eq!(manager.revoked_count(), 2);
    }

    #[test]
    fn test_revocation_state_serialization() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        manager.revoke(task_id, RevocationMode::Terminate);
        manager.revoke_by_pattern("tasks.*", RevocationMode::Ignore);

        let state = manager.export_state();
        let json = serde_json::to_string(&state).unwrap();
        let parsed: RevocationState = serde_json::from_str(&json).unwrap();

        assert!(!parsed.revoked_tasks.is_empty());
        assert!(!parsed.pattern_revocations.is_empty());
    }

    #[test]
    fn test_revocation_with_reason() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        let request = RevocationRequest::new(task_id, RevocationMode::Terminate)
            .with_reason("Manual cancellation by user");
        manager.revoke_with_request(request);

        let result = manager.check_revocation(task_id, "any.task");
        assert!(result.revoked);
        assert_eq!(
            result.reason,
            Some("Manual cancellation by user".to_string())
        );
    }

    #[test]
    fn test_revoked_count() {
        let mut manager = RevocationManager::new();

        for _ in 0..5 {
            manager.revoke(Uuid::new_v4(), RevocationMode::Terminate);
        }

        assert_eq!(manager.revoked_count(), 5);
        assert_eq!(manager.revoked_ids().len(), 5);
    }

    #[test]
    fn test_clear() {
        let mut manager = RevocationManager::new();
        let terminated_id = Uuid::new_v4();

        manager.revoke(Uuid::new_v4(), RevocationMode::Terminate);
        manager.revoke_by_pattern("*", RevocationMode::Ignore);
        manager.mark_terminated(terminated_id);

        manager.clear();

        assert_eq!(manager.revoked_count(), 0);
        assert!(!manager.is_terminated(terminated_id));
        assert!(!manager.check_revocation(Uuid::new_v4(), "any.task").revoked);
    }
}