celers_core/
revocation.rs

//! Task Revocation
//!
//! This module provides enhanced task revocation capabilities:
//!
//! - **Revoke by ID**: Revoke a specific task by its ID
//! - **Revoke by Pattern**: Revoke tasks whose name matches a pattern (a glob via
//!   `revoke_by_pattern`, or any `PatternMatcher` via `revoke_with_pattern`)
//! - **Bulk Revocation**: Revoke multiple tasks at once
//! - **Persistent Revocation**: Export and re-import revocation state so that
//!   revocations can survive worker restarts
//!
//! # Example
//!
//! ```rust
//! use celers_core::revocation::{RevocationManager, RevocationMode};
//! use uuid::Uuid;
//!
//! let mut manager = RevocationManager::new();
//! let task_id = Uuid::new_v4();
//!
//! // Revoke a single task
//! manager.revoke(task_id, RevocationMode::Terminate);
//!
//! // Revoke all tasks matching a pattern
//! manager.revoke_by_pattern("email.*", RevocationMode::Ignore);
//!
//! // Check if a task is revoked
//! assert!(manager.is_revoked(task_id));
//! ```
28
29use crate::router::PatternMatcher;
30use serde::{Deserialize, Serialize};
31use std::collections::{HashMap, HashSet};
32use std::sync::{Arc, RwLock};
33use std::time::{Duration, SystemTime, UNIX_EPOCH};
34use uuid::Uuid;
35
36/// Revocation mode for how to handle a revoked task
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
38#[serde(rename_all = "lowercase")]
39pub enum RevocationMode {
40    /// Terminate the task if already running
41    #[default]
42    Terminate,
43    /// Ignore the task (don't execute but don't terminate if running)
44    Ignore,
45}
46
47/// A request to revoke a task
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct RevocationRequest {
50    /// Task ID to revoke
51    pub task_id: Uuid,
52    /// Revocation mode
53    pub mode: RevocationMode,
54    /// When the revocation was issued (Unix timestamp)
55    pub timestamp: f64,
56    /// Optional expiration time (Unix timestamp)
57    pub expires: Option<f64>,
58    /// Reason for revocation
59    pub reason: Option<String>,
60    /// Signal to send (for terminate mode)
61    pub signal: Option<String>,
62}
63
64impl RevocationRequest {
65    /// Create a new revocation request
66    #[must_use]
67    pub fn new(task_id: Uuid, mode: RevocationMode) -> Self {
68        Self {
69            task_id,
70            mode,
71            timestamp: current_timestamp(),
72            expires: None,
73            reason: None,
74            signal: None,
75        }
76    }
77
78    /// Set expiration time
79    #[must_use]
80    pub fn with_expiration(mut self, expires_in: Duration) -> Self {
81        self.expires = Some(current_timestamp() + expires_in.as_secs_f64());
82        self
83    }
84
85    /// Set reason for revocation
86    #[must_use]
87    pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
88        self.reason = Some(reason.into());
89        self
90    }
91
92    /// Set signal to send (for terminate mode)
93    #[must_use]
94    pub fn with_signal(mut self, signal: impl Into<String>) -> Self {
95        self.signal = Some(signal.into());
96        self
97    }
98
99    /// Check if this revocation has expired
100    #[inline]
101    #[must_use]
102    pub fn is_expired(&self) -> bool {
103        if let Some(expires) = self.expires {
104            current_timestamp() > expires
105        } else {
106            false
107        }
108    }
109}
110
111/// A pattern-based revocation rule
112#[derive(Debug, Clone)]
113pub struct PatternRevocation {
114    /// Pattern matcher for task names
115    pub pattern: PatternMatcher,
116    /// Revocation mode
117    pub mode: RevocationMode,
118    /// When the revocation was issued
119    pub timestamp: f64,
120    /// Optional expiration time
121    pub expires: Option<f64>,
122    /// Reason for revocation
123    pub reason: Option<String>,
124}
125
126impl PatternRevocation {
127    /// Create a new pattern revocation
128    #[must_use]
129    pub fn new(pattern: PatternMatcher, mode: RevocationMode) -> Self {
130        Self {
131            pattern,
132            mode,
133            timestamp: current_timestamp(),
134            expires: None,
135            reason: None,
136        }
137    }
138
139    /// Set expiration time
140    #[must_use]
141    pub fn with_expiration(mut self, expires_in: Duration) -> Self {
142        self.expires = Some(current_timestamp() + expires_in.as_secs_f64());
143        self
144    }
145
146    /// Set reason for revocation
147    #[must_use]
148    pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
149        self.reason = Some(reason.into());
150        self
151    }
152
153    /// Check if this revocation has expired
154    #[inline]
155    #[must_use]
156    pub fn is_expired(&self) -> bool {
157        if let Some(expires) = self.expires {
158            current_timestamp() > expires
159        } else {
160            false
161        }
162    }
163
164    /// Check if a task name matches this pattern
165    #[inline]
166    #[must_use]
167    pub fn matches(&self, task_name: &str) -> bool {
168        self.pattern.matches(task_name)
169    }
170}
171
172/// Result of a revocation check
173#[derive(Debug, Clone)]
174pub struct RevocationResult {
175    /// Whether the task is revoked
176    pub revoked: bool,
177    /// Revocation mode
178    pub mode: RevocationMode,
179    /// Reason for revocation
180    pub reason: Option<String>,
181    /// Signal to send (for terminate mode)
182    pub signal: Option<String>,
183}
184
185impl RevocationResult {
186    /// Create a result indicating task is not revoked
187    #[must_use]
188    pub fn not_revoked() -> Self {
189        Self {
190            revoked: false,
191            mode: RevocationMode::Ignore,
192            reason: None,
193            signal: None,
194        }
195    }
196
197    /// Create a result indicating task is revoked
198    #[must_use]
199    pub fn revoked(mode: RevocationMode, reason: Option<String>, signal: Option<String>) -> Self {
200        Self {
201            revoked: true,
202            mode,
203            reason,
204            signal,
205        }
206    }
207}
208
209/// Serializable revocation state for persistence
210#[derive(Debug, Clone, Serialize, Deserialize, Default)]
211pub struct RevocationState {
212    /// Revoked task IDs (`task_id` -> request)
213    pub revoked_tasks: HashMap<String, RevocationRequest>,
214    /// Pattern-based revocations (serializable form)
215    pub pattern_revocations: Vec<SerializablePatternRevocation>,
216}
217
218/// Serializable form of `PatternRevocation`
219#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct SerializablePatternRevocation {
221    /// Pattern string (glob format)
222    pub pattern: String,
223    /// Revocation mode
224    pub mode: RevocationMode,
225    /// When the revocation was issued
226    pub timestamp: f64,
227    /// Optional expiration time
228    pub expires: Option<f64>,
229    /// Reason for revocation
230    pub reason: Option<String>,
231}
232
233impl From<&PatternRevocation> for SerializablePatternRevocation {
234    fn from(rev: &PatternRevocation) -> Self {
235        // Extract pattern string (simplified - assumes glob pattern)
236        let pattern = match &rev.pattern {
237            PatternMatcher::Exact(s) => s.clone(),
238            PatternMatcher::Glob(g) => g.pattern().to_string(),
239            PatternMatcher::Regex(r) => r.pattern().to_string(),
240            PatternMatcher::All => "*".to_string(),
241        };
242        Self {
243            pattern,
244            mode: rev.mode,
245            timestamp: rev.timestamp,
246            expires: rev.expires,
247            reason: rev.reason.clone(),
248        }
249    }
250}
251
252impl SerializablePatternRevocation {
253    /// Convert to `PatternRevocation`
254    #[must_use]
255    pub fn into_pattern_revocation(self) -> PatternRevocation {
256        PatternRevocation {
257            pattern: PatternMatcher::glob(&self.pattern),
258            mode: self.mode,
259            timestamp: self.timestamp,
260            expires: self.expires,
261            reason: self.reason,
262        }
263    }
264}
265
266/// Revocation manager for tracking revoked tasks
267#[derive(Debug, Default)]
268pub struct RevocationManager {
269    /// Revoked task IDs
270    revoked_ids: HashMap<Uuid, RevocationRequest>,
271    /// Pattern-based revocations
272    pattern_revocations: Vec<PatternRevocation>,
273    /// Set of currently terminated task IDs
274    terminated: HashSet<Uuid>,
275}
276
277impl RevocationManager {
278    /// Create a new revocation manager
279    #[must_use]
280    pub fn new() -> Self {
281        Self::default()
282    }
283
284    /// Revoke a task by ID
285    pub fn revoke(&mut self, task_id: Uuid, mode: RevocationMode) {
286        let request = RevocationRequest::new(task_id, mode);
287        self.revoked_ids.insert(task_id, request);
288    }
289
290    /// Revoke a task with a full request
291    pub fn revoke_with_request(&mut self, request: RevocationRequest) {
292        self.revoked_ids.insert(request.task_id, request);
293    }
294
295    /// Revoke all tasks matching a pattern
296    pub fn revoke_by_pattern(&mut self, pattern: &str, mode: RevocationMode) {
297        let pattern_rev = PatternRevocation::new(PatternMatcher::glob(pattern), mode);
298        self.pattern_revocations.push(pattern_rev);
299    }
300
301    /// Revoke by pattern with full configuration
302    pub fn revoke_with_pattern(&mut self, revocation: PatternRevocation) {
303        self.pattern_revocations.push(revocation);
304    }
305
306    /// Bulk revoke multiple tasks
307    pub fn bulk_revoke(&mut self, task_ids: &[Uuid], mode: RevocationMode) {
308        for &task_id in task_ids {
309            self.revoke(task_id, mode);
310        }
311    }
312
313    /// Check if a task is revoked (by ID)
314    #[inline]
315    #[must_use]
316    pub fn is_revoked(&self, task_id: Uuid) -> bool {
317        if let Some(request) = self.revoked_ids.get(&task_id) {
318            !request.is_expired()
319        } else {
320            false
321        }
322    }
323
324    /// Check if a task should be revoked (by ID or pattern)
325    #[must_use]
326    pub fn check_revocation(&self, task_id: Uuid, task_name: &str) -> RevocationResult {
327        // Check by ID first
328        if let Some(request) = self.revoked_ids.get(&task_id) {
329            if !request.is_expired() {
330                return RevocationResult::revoked(
331                    request.mode,
332                    request.reason.clone(),
333                    request.signal.clone(),
334                );
335            }
336        }
337
338        // Check by pattern
339        for pattern_rev in &self.pattern_revocations {
340            if !pattern_rev.is_expired() && pattern_rev.matches(task_name) {
341                return RevocationResult::revoked(
342                    pattern_rev.mode,
343                    pattern_rev.reason.clone(),
344                    None,
345                );
346            }
347        }
348
349        RevocationResult::not_revoked()
350    }
351
352    /// Mark a task as terminated
353    pub fn mark_terminated(&mut self, task_id: Uuid) {
354        self.terminated.insert(task_id);
355    }
356
357    /// Check if a task has been terminated
358    #[inline]
359    #[must_use]
360    pub fn is_terminated(&self, task_id: Uuid) -> bool {
361        self.terminated.contains(&task_id)
362    }
363
364    /// Remove revocation for a task ID
365    pub fn unrevoke(&mut self, task_id: Uuid) {
366        self.revoked_ids.remove(&task_id);
367    }
368
369    /// Remove pattern-based revocations matching a pattern string
370    pub fn remove_pattern(&mut self, pattern: &str) {
371        self.pattern_revocations.retain(|p| {
372            if let PatternMatcher::Glob(g) = &p.pattern {
373                g.pattern() != pattern
374            } else {
375                true
376            }
377        });
378    }
379
380    /// Clean up expired revocations
381    pub fn cleanup_expired(&mut self) {
382        self.revoked_ids.retain(|_, request| !request.is_expired());
383        self.pattern_revocations.retain(|rev| !rev.is_expired());
384    }
385
386    /// Get all revoked task IDs
387    #[must_use]
388    pub fn revoked_ids(&self) -> Vec<Uuid> {
389        self.revoked_ids
390            .iter()
391            .filter(|(_, request)| !request.is_expired())
392            .map(|(id, _)| *id)
393            .collect()
394    }
395
396    /// Get count of revoked tasks
397    #[inline]
398    #[must_use]
399    pub fn revoked_count(&self) -> usize {
400        self.revoked_ids
401            .values()
402            .filter(|request| !request.is_expired())
403            .count()
404    }
405
406    /// Clear all revocations
407    pub fn clear(&mut self) {
408        self.revoked_ids.clear();
409        self.pattern_revocations.clear();
410        self.terminated.clear();
411    }
412
413    /// Export state for persistence
414    pub fn export_state(&self) -> RevocationState {
415        let revoked_tasks = self
416            .revoked_ids
417            .iter()
418            .filter(|(_, req)| !req.is_expired())
419            .map(|(id, req)| (id.to_string(), req.clone()))
420            .collect();
421
422        let pattern_revocations = self
423            .pattern_revocations
424            .iter()
425            .filter(|rev| !rev.is_expired())
426            .map(SerializablePatternRevocation::from)
427            .collect();
428
429        RevocationState {
430            revoked_tasks,
431            pattern_revocations,
432        }
433    }
434
435    /// Import state from persistence
436    pub fn import_state(&mut self, state: RevocationState) {
437        for (id_str, request) in state.revoked_tasks {
438            if !request.is_expired() {
439                if let Ok(id) = Uuid::parse_str(&id_str) {
440                    self.revoked_ids.insert(id, request);
441                }
442            }
443        }
444
445        for ser_pattern in state.pattern_revocations {
446            let pattern_rev = ser_pattern.into_pattern_revocation();
447            if !pattern_rev.is_expired() {
448                self.pattern_revocations.push(pattern_rev);
449            }
450        }
451    }
452}
453
454/// Thread-safe revocation manager for workers
455#[derive(Debug, Clone, Default)]
456pub struct WorkerRevocationManager {
457    inner: Arc<RwLock<RevocationManager>>,
458}
459
460impl WorkerRevocationManager {
461    /// Create a new worker revocation manager
462    #[must_use]
463    pub fn new() -> Self {
464        Self::default()
465    }
466
467    /// Revoke a task by ID
468    pub fn revoke(&self, task_id: Uuid, mode: RevocationMode) {
469        if let Ok(mut guard) = self.inner.write() {
470            guard.revoke(task_id, mode);
471        }
472    }
473
474    /// Revoke a task with a full request
475    pub fn revoke_with_request(&self, request: RevocationRequest) {
476        if let Ok(mut guard) = self.inner.write() {
477            guard.revoke_with_request(request);
478        }
479    }
480
481    /// Revoke by pattern
482    pub fn revoke_by_pattern(&self, pattern: &str, mode: RevocationMode) {
483        if let Ok(mut guard) = self.inner.write() {
484            guard.revoke_by_pattern(pattern, mode);
485        }
486    }
487
488    /// Bulk revoke multiple tasks
489    pub fn bulk_revoke(&self, task_ids: &[Uuid], mode: RevocationMode) {
490        if let Ok(mut guard) = self.inner.write() {
491            guard.bulk_revoke(task_ids, mode);
492        }
493    }
494
495    /// Check if a task is revoked by ID
496    #[must_use]
497    pub fn is_revoked(&self, task_id: Uuid) -> bool {
498        if let Ok(guard) = self.inner.read() {
499            guard.is_revoked(task_id)
500        } else {
501            false
502        }
503    }
504
505    /// Check revocation status (by ID and pattern)
506    #[must_use]
507    pub fn check_revocation(&self, task_id: Uuid, task_name: &str) -> RevocationResult {
508        if let Ok(guard) = self.inner.read() {
509            guard.check_revocation(task_id, task_name)
510        } else {
511            RevocationResult::not_revoked()
512        }
513    }
514
515    /// Mark a task as terminated
516    pub fn mark_terminated(&self, task_id: Uuid) {
517        if let Ok(mut guard) = self.inner.write() {
518            guard.mark_terminated(task_id);
519        }
520    }
521
522    /// Check if a task has been terminated
523    #[must_use]
524    pub fn is_terminated(&self, task_id: Uuid) -> bool {
525        if let Ok(guard) = self.inner.read() {
526            guard.is_terminated(task_id)
527        } else {
528            false
529        }
530    }
531
532    /// Remove revocation for a task ID
533    pub fn unrevoke(&self, task_id: Uuid) {
534        if let Ok(mut guard) = self.inner.write() {
535            guard.unrevoke(task_id);
536        }
537    }
538
539    /// Clean up expired revocations
540    pub fn cleanup_expired(&self) {
541        if let Ok(mut guard) = self.inner.write() {
542            guard.cleanup_expired();
543        }
544    }
545
546    /// Get all revoked task IDs
547    #[must_use]
548    pub fn revoked_ids(&self) -> Vec<Uuid> {
549        if let Ok(guard) = self.inner.read() {
550            guard.revoked_ids()
551        } else {
552            Vec::new()
553        }
554    }
555
556    /// Get count of revoked tasks
557    #[must_use]
558    pub fn revoked_count(&self) -> usize {
559        if let Ok(guard) = self.inner.read() {
560            guard.revoked_count()
561        } else {
562            0
563        }
564    }
565
566    /// Export state for persistence
567    #[must_use]
568    pub fn export_state(&self) -> RevocationState {
569        if let Ok(guard) = self.inner.read() {
570            guard.export_state()
571        } else {
572            RevocationState::default()
573        }
574    }
575
576    /// Import state from persistence
577    pub fn import_state(&self, state: RevocationState) {
578        if let Ok(mut guard) = self.inner.write() {
579            guard.import_state(state);
580        }
581    }
582
583    /// Clear all revocations
584    pub fn clear(&self) {
585        if let Ok(mut guard) = self.inner.write() {
586            guard.clear();
587        }
588    }
589}
590
/// Current wall-clock time in seconds since the Unix epoch, with fractional part.
///
/// If the system clock reports a time before the epoch, `0.0` is returned
/// instead of an error.
fn current_timestamp() -> f64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs_f64(),
        Err(_) => 0.0,
    }
}
598
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_revocation_request() {
        let task_id = Uuid::new_v4();
        let request = RevocationRequest::new(task_id, RevocationMode::Terminate);

        assert_eq!(request.task_id, task_id);
        assert_eq!(request.mode, RevocationMode::Terminate);
        assert!(!request.is_expired());
    }

    #[test]
    fn test_revocation_request_with_expiration() {
        let task_id = Uuid::new_v4();
        let request = RevocationRequest::new(task_id, RevocationMode::Terminate)
            .with_expiration(Duration::from_secs(0));

        // Immediately expired
        std::thread::sleep(Duration::from_millis(10));
        assert!(request.is_expired());
    }

    #[test]
    fn test_revocation_manager_basic() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        manager.revoke(task_id, RevocationMode::Terminate);
        assert!(manager.is_revoked(task_id));

        let other_id = Uuid::new_v4();
        assert!(!manager.is_revoked(other_id));
    }

    #[test]
    fn test_revocation_by_pattern() {
        let mut manager = RevocationManager::new();

        manager.revoke_by_pattern("email.*", RevocationMode::Ignore);

        let result = manager.check_revocation(Uuid::new_v4(), "email.send");
        assert!(result.revoked);
        assert_eq!(result.mode, RevocationMode::Ignore);

        let result = manager.check_revocation(Uuid::new_v4(), "sms.send");
        assert!(!result.revoked);
    }

    #[test]
    fn test_pattern_revocation_with_reason() {
        let mut manager = RevocationManager::new();
        manager.revoke_with_pattern(
            PatternRevocation::new(PatternMatcher::glob("reports.*"), RevocationMode::Terminate)
                .with_reason("maintenance"),
        );

        let result = manager.check_revocation(Uuid::new_v4(), "reports.daily");
        assert!(result.revoked);
        assert_eq!(result.mode, RevocationMode::Terminate);
        assert_eq!(result.reason.as_deref(), Some("maintenance"));
        // Pattern revocations never carry a signal.
        assert!(result.signal.is_none());
    }

    #[test]
    fn test_expired_pattern_is_ignored() {
        let mut manager = RevocationManager::new();
        manager.revoke_with_pattern(
            PatternRevocation::new(PatternMatcher::glob("email.*"), RevocationMode::Ignore)
                .with_expiration(Duration::from_secs(0)),
        );
        std::thread::sleep(Duration::from_millis(10));

        assert!(!manager.check_revocation(Uuid::new_v4(), "email.send").revoked);
    }

    #[test]
    fn test_id_revocation_takes_precedence_over_pattern() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        manager.revoke_with_request(
            RevocationRequest::new(task_id, RevocationMode::Terminate).with_reason("by id"),
        );
        manager.revoke_by_pattern("email.*", RevocationMode::Ignore);

        let result = manager.check_revocation(task_id, "email.send");
        assert!(result.revoked);
        assert_eq!(result.mode, RevocationMode::Terminate);
        assert_eq!(result.reason.as_deref(), Some("by id"));
    }

    #[test]
    fn test_signal_is_propagated() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        manager.revoke_with_request(
            RevocationRequest::new(task_id, RevocationMode::Terminate).with_signal("SIGKILL"),
        );

        let result = manager.check_revocation(task_id, "any.task");
        assert!(result.revoked);
        assert_eq!(result.signal.as_deref(), Some("SIGKILL"));
    }

    #[test]
    fn test_bulk_revoke() {
        let mut manager = RevocationManager::new();
        let ids: Vec<Uuid> = (0..5).map(|_| Uuid::new_v4()).collect();

        manager.bulk_revoke(&ids, RevocationMode::Terminate);

        for id in &ids {
            assert!(manager.is_revoked(*id));
        }
    }

    #[test]
    fn test_unrevoke() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        manager.revoke(task_id, RevocationMode::Terminate);
        assert!(manager.is_revoked(task_id));

        manager.unrevoke(task_id);
        assert!(!manager.is_revoked(task_id));
    }

    #[test]
    fn test_cleanup_expired() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        // Add an expired revocation
        let request = RevocationRequest::new(task_id, RevocationMode::Terminate)
            .with_expiration(Duration::from_secs(0));
        std::thread::sleep(Duration::from_millis(10));
        manager.revoke_with_request(request);

        // Add a non-expired revocation
        let other_id = Uuid::new_v4();
        manager.revoke(other_id, RevocationMode::Terminate);

        manager.cleanup_expired();

        assert!(!manager.is_revoked(task_id)); // Expired
        assert!(manager.is_revoked(other_id)); // Not expired
    }

    #[test]
    fn test_export_import_state() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        manager.revoke(task_id, RevocationMode::Terminate);
        manager.revoke_by_pattern("email.*", RevocationMode::Ignore);

        let state = manager.export_state();

        let mut new_manager = RevocationManager::new();
        new_manager.import_state(state);

        assert!(new_manager.is_revoked(task_id));
        let result = new_manager.check_revocation(Uuid::new_v4(), "email.send");
        assert!(result.revoked);
    }

    #[test]
    fn test_import_skips_invalid_ids() {
        let mut state = RevocationState::default();
        state.revoked_tasks.insert(
            "not-a-uuid".to_string(),
            RevocationRequest::new(Uuid::new_v4(), RevocationMode::Terminate),
        );

        let mut manager = RevocationManager::new();
        manager.import_state(state);

        assert_eq!(manager.revoked_count(), 0);
    }

    #[test]
    fn test_worker_revocation_manager() {
        let manager = WorkerRevocationManager::new();
        let task_id = Uuid::new_v4();

        manager.revoke(task_id, RevocationMode::Terminate);
        assert!(manager.is_revoked(task_id));

        manager.mark_terminated(task_id);
        assert!(manager.is_terminated(task_id));
    }

    #[test]
    fn test_worker_manager_clones_share_state() {
        let manager = WorkerRevocationManager::new();
        let handle = manager.clone();
        let task_id = Uuid::new_v4();

        handle.revoke(task_id, RevocationMode::Ignore);

        assert!(manager.is_revoked(task_id));
        assert_eq!(manager.revoked_count(), 1);
    }

    #[test]
    fn test_revocation_state_serialization() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        manager.revoke(task_id, RevocationMode::Terminate);
        manager.revoke_by_pattern("tasks.*", RevocationMode::Ignore);

        let state = manager.export_state();
        let json = serde_json::to_string(&state).unwrap();
        let parsed: RevocationState = serde_json::from_str(&json).unwrap();

        assert!(!parsed.revoked_tasks.is_empty());
        assert!(!parsed.pattern_revocations.is_empty());
    }

    #[test]
    fn test_revocation_with_reason() {
        let mut manager = RevocationManager::new();
        let task_id = Uuid::new_v4();

        let request = RevocationRequest::new(task_id, RevocationMode::Terminate)
            .with_reason("Manual cancellation by user");
        manager.revoke_with_request(request);

        let result = manager.check_revocation(task_id, "any.task");
        assert!(result.revoked);
        assert_eq!(
            result.reason,
            Some("Manual cancellation by user".to_string())
        );
    }

    #[test]
    fn test_revoked_count() {
        let mut manager = RevocationManager::new();

        for _ in 0..5 {
            manager.revoke(Uuid::new_v4(), RevocationMode::Terminate);
        }

        assert_eq!(manager.revoked_count(), 5);
        assert_eq!(manager.revoked_ids().len(), 5);
    }

    #[test]
    fn test_clear() {
        let mut manager = RevocationManager::new();
        let terminated_id = Uuid::new_v4();

        manager.revoke(Uuid::new_v4(), RevocationMode::Terminate);
        manager.revoke_by_pattern("*", RevocationMode::Ignore);
        manager.mark_terminated(terminated_id);

        manager.clear();

        // ID revocations, pattern revocations and the terminated set are all reset.
        assert_eq!(manager.revoked_count(), 0);
        assert!(!manager.check_revocation(Uuid::new_v4(), "any.task").revoked);
        assert!(!manager.is_terminated(terminated_id));
    }
}