Skip to main content

celers_backend_redis/
types.rs

1//! Core types for the Redis result backend
2//!
3//! This module contains the fundamental types used throughout the backend:
4//! - [`BackendError`] and [`Result`] for error handling
5//! - [`TaskResult`] for task state tracking
6//! - [`ProgressInfo`] for long-running task progress
7//! - [`TaskMeta`] for task metadata storage
8//! - [`TaskTtlConfig`] for per-task TTL configuration
9//! - [`ChordState`] for barrier synchronization
10
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::time::Duration;
15use thiserror::Error;
16use uuid::Uuid;
17
/// Result backend errors
///
/// Error type for the Redis result backend; the crate-local [`Result`] alias
/// uses it as its error type. The `is_*` helpers and
/// [`BackendError::category`] classify an error without matching on variants.
#[derive(Debug, Error)]
pub enum BackendError {
    /// Error returned by the Redis client. Because of `#[from]`, a
    /// `redis::RedisError` converts into this variant automatically (so `?`
    /// works on Redis calls).
    #[error("Redis error: {0}")]
    Redis(#[from] redis::RedisError),

    /// Serialization/deserialization failure; the payload is a description
    /// of the problem.
    #[error("Serialization error: {0}")]
    Serialization(String),

    /// No result was found for the contained task ID.
    #[error("Result not found: {0}")]
    NotFound(Uuid),

    /// Connection-level failure; the payload is a description of the problem.
    #[error("Connection error: {0}")]
    Connection(String),
}
33
34impl BackendError {
35    /// Check if the error is Redis-related
36    pub fn is_redis(&self) -> bool {
37        matches!(self, BackendError::Redis(_))
38    }
39
40    /// Check if the error is serialization-related
41    pub fn is_serialization(&self) -> bool {
42        matches!(self, BackendError::Serialization(_))
43    }
44
45    /// Check if the error is not-found
46    pub fn is_not_found(&self) -> bool {
47        matches!(self, BackendError::NotFound(_))
48    }
49
50    /// Check if the error is connection-related
51    pub fn is_connection(&self) -> bool {
52        matches!(self, BackendError::Connection(_))
53    }
54
55    /// Check if this is a retryable error
56    ///
57    /// Returns true for Redis and connection errors, which are typically transient.
58    /// Returns false for serialization and not-found errors.
59    pub fn is_retryable(&self) -> bool {
60        matches!(self, BackendError::Redis(_) | BackendError::Connection(_))
61    }
62
63    /// Get the error category as a string
64    pub fn category(&self) -> &'static str {
65        match self {
66            BackendError::Redis(_) => "redis",
67            BackendError::Serialization(_) => "serialization",
68            BackendError::NotFound(_) => "not_found",
69            BackendError::Connection(_) => "connection",
70        }
71    }
72}
73
/// Convenience alias for results whose error type is [`BackendError`].
pub type Result<T> = std::result::Result<T, BackendError>;
75
/// Task result state
///
/// The lifecycle state of a task as stored in the result backend.
/// [`TaskResult::is_terminal`] treats `Success`, `Failure` and `Revoked` as
/// terminal; the other variants are considered active.
///
/// NOTE(review): this enum derives `Serialize`/`Deserialize`; renaming (and,
/// for index-based formats, reordering) variants may change the stored
/// representation — confirm the wire format before editing variants.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskResult {
    /// Task is pending execution
    Pending,

    /// Task is currently running
    Started,

    /// Task completed successfully; holds the task's return value as JSON.
    Success(serde_json::Value),

    /// Task failed; holds the error message.
    Failure(String),

    /// Task was revoked/cancelled
    Revoked,

    /// Task retry scheduled; holds the retry count
    /// (exposed via [`TaskResult::retry_count`]).
    Retry(u32),
}
97
98impl TaskResult {
99    /// Check if the task is pending
100    pub fn is_pending(&self) -> bool {
101        matches!(self, TaskResult::Pending)
102    }
103
104    /// Check if the task is started
105    pub fn is_started(&self) -> bool {
106        matches!(self, TaskResult::Started)
107    }
108
109    /// Check if the task succeeded
110    pub fn is_success(&self) -> bool {
111        matches!(self, TaskResult::Success(_))
112    }
113
114    /// Check if the task failed
115    pub fn is_failure(&self) -> bool {
116        matches!(self, TaskResult::Failure(_))
117    }
118
119    /// Check if the task was revoked
120    pub fn is_revoked(&self) -> bool {
121        matches!(self, TaskResult::Revoked)
122    }
123
124    /// Check if the task is being retried
125    pub fn is_retry(&self) -> bool {
126        matches!(self, TaskResult::Retry(_))
127    }
128
129    /// Check if the task is in a terminal state (success, failure, or revoked)
130    pub fn is_terminal(&self) -> bool {
131        matches!(
132            self,
133            TaskResult::Success(_) | TaskResult::Failure(_) | TaskResult::Revoked
134        )
135    }
136
137    /// Check if the task is in an active (non-terminal) state
138    pub fn is_active(&self) -> bool {
139        !self.is_terminal()
140    }
141
142    /// Check if two TaskResult values are of the same variant type
143    ///
144    /// This compares only the variant, ignoring inner values.
145    /// For example, `Success(1)` and `Success(2)` are considered the same.
146    pub fn same_variant(&self, other: &TaskResult) -> bool {
147        matches!(
148            (self, other),
149            (TaskResult::Pending, TaskResult::Pending)
150                | (TaskResult::Started, TaskResult::Started)
151                | (TaskResult::Success(_), TaskResult::Success(_))
152                | (TaskResult::Failure(_), TaskResult::Failure(_))
153                | (TaskResult::Revoked, TaskResult::Revoked)
154                | (TaskResult::Retry(_), TaskResult::Retry(_))
155        )
156    }
157
158    /// Get the success result value if the task succeeded
159    pub fn success_value(&self) -> Option<&serde_json::Value> {
160        match self {
161            TaskResult::Success(value) => Some(value),
162            _ => None,
163        }
164    }
165
166    /// Get the failure error message if the task failed
167    pub fn failure_message(&self) -> Option<&str> {
168        match self {
169            TaskResult::Failure(msg) => Some(msg),
170            _ => None,
171        }
172    }
173
174    /// Get the retry count if the task is being retried
175    pub fn retry_count(&self) -> Option<u32> {
176        match self {
177            TaskResult::Retry(count) => Some(*count),
178            _ => None,
179        }
180    }
181}
182
183impl std::fmt::Display for TaskResult {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        match self {
186            TaskResult::Pending => write!(f, "PENDING"),
187            TaskResult::Started => write!(f, "STARTED"),
188            TaskResult::Success(_) => write!(f, "SUCCESS"),
189            TaskResult::Failure(err) => write!(f, "FAILURE: {}", err),
190            TaskResult::Revoked => write!(f, "REVOKED"),
191            TaskResult::Retry(count) => write!(f, "RETRY({})", count),
192        }
193    }
194}
195
/// Progress information for long-running tasks
///
/// [`ProgressInfo::new`] derives `percent` from `current` and `total`. The
/// fields are public, so `percent` is NOT recomputed if `current` or `total`
/// are later modified directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressInfo {
    /// Current progress value (e.g., items processed)
    pub current: u64,

    /// Total progress value (e.g., total items)
    pub total: u64,

    /// Optional progress message
    pub message: Option<String>,

    /// Progress percentage (0-100); `new` clamps it to at most 100 and uses
    /// 0 when `total` is 0.
    pub percent: f64,

    /// Timestamp of last progress update (`new` sets it to `Utc::now()`)
    pub updated_at: DateTime<Utc>,
}
214
215impl ProgressInfo {
216    /// Create new progress info
217    pub fn new(current: u64, total: u64) -> Self {
218        let percent = if total > 0 {
219            (current as f64 / total as f64 * 100.0).min(100.0)
220        } else {
221            0.0
222        };
223
224        Self {
225            current,
226            total,
227            message: None,
228            percent,
229            updated_at: Utc::now(),
230        }
231    }
232
233    /// Create progress with message
234    pub fn with_message(mut self, message: String) -> Self {
235        self.message = Some(message);
236        self
237    }
238
239    /// Check if the task is complete (100% progress)
240    pub fn is_complete(&self) -> bool {
241        self.percent >= 100.0
242    }
243
244    /// Check if there is a progress message
245    pub fn has_message(&self) -> bool {
246        self.message.is_some()
247    }
248
249    /// Get remaining items to process
250    pub fn remaining(&self) -> u64 {
251        self.total.saturating_sub(self.current)
252    }
253
254    /// Get progress as a fraction (0.0 to 1.0)
255    pub fn fraction(&self) -> f64 {
256        self.percent / 100.0
257    }
258}
259
260impl std::fmt::Display for ProgressInfo {
261    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        write!(f, "{}/{} ({:.1}%)", self.current, self.total, self.percent)?;
263        if let Some(ref msg) = self.message {
264            write!(f, " - {}", msg)?;
265        }
266        Ok(())
267    }
268}
269
/// Task metadata stored in result backend
///
/// One record per task: identity, current [`TaskResult`], lifecycle
/// timestamps and optional execution details. Fields marked
/// `skip_serializing_if` are omitted from the serialized form when empty;
/// fields marked `#[serde(default)]` take their default value when absent on
/// deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskMeta {
    /// Task ID
    pub task_id: Uuid,

    /// Task name
    pub task_name: String,

    /// Task status/result
    pub result: TaskResult,

    /// Timestamp when task was created (`TaskMeta::new` sets it to `Utc::now()`)
    pub created_at: DateTime<Utc>,

    /// Timestamp when task was started (`None` until recorded)
    pub started_at: Option<DateTime<Utc>>,

    /// Timestamp when task completed (`None` until recorded)
    pub completed_at: Option<DateTime<Utc>>,

    /// Worker that executed the task
    ///
    /// NOTE(review): separate from `worker_hostname` below; how the two are
    /// populated is not visible in this file — confirm with the writers.
    pub worker: Option<String>,

    /// Task progress (for long-running tasks); omitted when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub progress: Option<ProgressInfo>,

    /// Version number for result versioning; 0 when missing on deserialization
    #[serde(default)]
    pub version: u32,

    /// Tags for categorizing and filtering tasks; omitted when empty
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,

    /// Custom metadata for flexible key-value storage; omitted when empty
    #[serde(default, skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub metadata: std::collections::HashMap<String, serde_json::Value>,

    /// Worker hostname that executed the task
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub worker_hostname: Option<String>,

    /// Task runtime in milliseconds
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub runtime_ms: Option<u64>,

    /// Peak memory usage in bytes during execution
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub memory_bytes: Option<u64>,

    /// Number of retries before completion
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retries: Option<u32>,

    /// Queue the task was consumed from
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub queue: Option<String>,
}
330
331impl TaskMeta {
332    pub fn new(task_id: Uuid, task_name: String) -> Self {
333        Self {
334            task_id,
335            task_name,
336            result: TaskResult::Pending,
337            created_at: Utc::now(),
338            started_at: None,
339            completed_at: None,
340            worker: None,
341            progress: None,
342            version: 0,
343            tags: Vec::new(),
344            metadata: std::collections::HashMap::new(),
345            worker_hostname: None,
346            runtime_ms: None,
347            memory_bytes: None,
348            retries: None,
349            queue: None,
350        }
351    }
352
353    /// Check if the task has started
354    pub fn has_started(&self) -> bool {
355        self.started_at.is_some()
356    }
357
358    /// Check if the task has completed
359    pub fn has_completed(&self) -> bool {
360        self.completed_at.is_some()
361    }
362
363    /// Check if the task has progress information
364    pub fn has_progress(&self) -> bool {
365        self.progress.is_some()
366    }
367
368    /// Get the task duration if completed
369    pub fn duration(&self) -> Option<chrono::Duration> {
370        match (self.started_at, self.completed_at) {
371            (Some(start), Some(end)) => Some(end - start),
372            _ => None,
373        }
374    }
375
376    /// Get the task age (time since creation)
377    pub fn age(&self) -> chrono::Duration {
378        Utc::now() - self.created_at
379    }
380
381    /// Get the execution time (time since start)
382    pub fn execution_time(&self) -> Option<chrono::Duration> {
383        self.started_at.map(|start| Utc::now() - start)
384    }
385
386    /// Check if the task is in a terminal state
387    pub fn is_terminal(&self) -> bool {
388        self.result.is_terminal()
389    }
390
391    /// Check if the task is in an active state
392    pub fn is_active(&self) -> bool {
393        self.result.is_active()
394    }
395
396    /// Add a tag to this task
397    pub fn add_tag(&mut self, tag: impl Into<String>) {
398        let tag = tag.into();
399        if !self.tags.contains(&tag) {
400            self.tags.push(tag);
401        }
402    }
403
404    /// Remove a tag from this task
405    pub fn remove_tag(&mut self, tag: &str) {
406        self.tags.retain(|t| t != tag);
407    }
408
409    /// Check if this task has a specific tag
410    pub fn has_tag(&self, tag: &str) -> bool {
411        self.tags.iter().any(|t| t == tag)
412    }
413
414    /// Check if this task has any of the specified tags
415    pub fn has_any_tag(&self, tags: &[String]) -> bool {
416        tags.iter().any(|tag| self.has_tag(tag))
417    }
418
419    /// Check if this task has all of the specified tags
420    pub fn has_all_tags(&self, tags: &[String]) -> bool {
421        tags.iter().all(|tag| self.has_tag(tag))
422    }
423
424    /// Set a custom metadata field
425    pub fn set_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
426        self.metadata.insert(key.into(), value);
427    }
428
429    /// Get a custom metadata field
430    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
431        self.metadata.get(key)
432    }
433
434    /// Remove a custom metadata field
435    pub fn remove_metadata(&mut self, key: &str) -> Option<serde_json::Value> {
436        self.metadata.remove(key)
437    }
438
439    /// Check if a custom metadata field exists
440    pub fn has_metadata(&self, key: &str) -> bool {
441        self.metadata.contains_key(key)
442    }
443}
444
445impl std::fmt::Display for TaskMeta {
446    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447        write!(
448            f,
449            "Task[{}] name={} result={}",
450            &self.task_id.to_string()[..8],
451            self.task_name,
452            self.result
453        )?;
454
455        if let Some(worker) = &self.worker {
456            write!(f, " worker={}", worker)?;
457        }
458
459        if let Some(progress) = &self.progress {
460            write!(f, " progress={}", progress)?;
461        }
462
463        Ok(())
464    }
465}
466
/// Per-task-type TTL configuration for result expiration
///
/// Allows setting different TTLs for different task types, with a fallback
/// to a default TTL when no per-task override is configured.
///
/// # Example
/// ```
/// use celers_backend_redis::TaskTtlConfig;
/// use std::time::Duration;
///
/// let mut config = TaskTtlConfig::with_default(Duration::from_secs(3600));
/// config.set_task_ttl("long_running_task", Duration::from_secs(86400));
/// config.set_task_ttl("ephemeral_task", Duration::from_secs(60));
///
/// assert_eq!(config.get_ttl("long_running_task"), Some(Duration::from_secs(86400)));
/// assert_eq!(config.get_ttl("unknown_task"), Some(Duration::from_secs(3600)));
///
/// // The default can be removed again, leaving only the per-task overrides.
/// config.clear_default_ttl();
/// assert_eq!(config.get_ttl("unknown_task"), None);
/// ```
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TaskTtlConfig {
    /// Default TTL for all task results (`None` = no default)
    default_ttl: Option<Duration>,
    /// Per-task-type TTL overrides (task_name -> TTL)
    task_ttls: HashMap<String, Duration>,
}

impl TaskTtlConfig {
    /// Create a new empty TTL configuration (no default, no per-task overrides)
    pub fn new() -> Self {
        Self {
            default_ttl: None,
            task_ttls: HashMap::new(),
        }
    }

    /// Create a TTL configuration with a default TTL for all tasks
    pub fn with_default(ttl: Duration) -> Self {
        Self {
            default_ttl: Some(ttl),
            task_ttls: HashMap::new(),
        }
    }

    /// Set (or replace) a per-task-type TTL override.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, ...),
    /// matching the `impl Into<String>` convention used by `TaskMeta`.
    pub fn set_task_ttl(&mut self, task_name: impl Into<String>, ttl: Duration) {
        self.task_ttls.insert(task_name.into(), ttl);
    }

    /// Get the TTL for a specific task type
    ///
    /// Returns the per-task TTL if one is set, otherwise falls back to
    /// the default TTL. Returns `None` if neither is configured.
    pub fn get_ttl(&self, task_name: &str) -> Option<Duration> {
        self.task_ttls.get(task_name).copied().or(self.default_ttl)
    }

    /// Returns `true` if neither a default TTL nor any override is configured
    pub fn is_empty(&self) -> bool {
        self.default_ttl.is_none() && self.task_ttls.is_empty()
    }

    /// Get the default TTL
    pub fn default_ttl(&self) -> Option<Duration> {
        self.default_ttl
    }

    /// Set the default TTL
    pub fn set_default_ttl(&mut self, ttl: Duration) {
        self.default_ttl = Some(ttl);
    }

    /// Remove the default TTL, returning the previous value.
    ///
    /// Afterwards, task types without an override resolve to `None` in
    /// [`TaskTtlConfig::get_ttl`].
    pub fn clear_default_ttl(&mut self) -> Option<Duration> {
        self.default_ttl.take()
    }

    /// Remove the per-task TTL override for a specific task type, returning it
    pub fn remove_task_ttl(&mut self, task_name: &str) -> Option<Duration> {
        self.task_ttls.remove(task_name)
    }

    /// Get the number of per-task TTL overrides
    pub fn task_ttl_count(&self) -> usize {
        self.task_ttls.len()
    }
}
553
/// Chord state (for barrier synchronization)
///
/// Tracks how many tasks of a chord (group) have completed, plus optional
/// callback, timeout, cancellation and retry bookkeeping. Note that `total`
/// is stored independently of `task_ids`; `ChordState::new` does not derive
/// one from the other.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChordState {
    /// Chord ID (group ID)
    pub chord_id: Uuid,

    /// Total number of tasks in chord
    pub total: usize,

    /// Number of completed tasks
    pub completed: usize,

    /// Callback task to execute when chord completes
    pub callback: Option<String>,

    /// Task IDs in the chord
    pub task_ids: Vec<Uuid>,

    /// Chord creation timestamp; the timeout window is measured from here
    /// (and `retry()` resets it to the current time)
    pub created_at: DateTime<Utc>,

    /// Chord timeout (if any); omitted from serialized output when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout: Option<Duration>,

    /// Whether the chord has been cancelled; `false` when missing on deserialization
    #[serde(default)]
    pub cancelled: bool,

    /// Cancellation reason
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cancellation_reason: Option<String>,

    /// Number of retry attempts made so far; 0 when missing on deserialization
    #[serde(default)]
    pub retry_count: u32,

    /// Maximum retry attempts (`None` means retries are not allowed)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_retries: Option<u32>,
}
595
596impl ChordState {
597    /// Create a new chord state
598    pub fn new(chord_id: Uuid, total: usize, task_ids: Vec<Uuid>) -> Self {
599        Self {
600            chord_id,
601            total,
602            completed: 0,
603            callback: None,
604            task_ids,
605            created_at: Utc::now(),
606            timeout: None,
607            cancelled: false,
608            cancellation_reason: None,
609            retry_count: 0,
610            max_retries: None,
611        }
612    }
613
614    /// Set the chord timeout
615    pub fn with_timeout(mut self, timeout: Duration) -> Self {
616        self.timeout = Some(timeout);
617        self
618    }
619
620    /// Set the callback task
621    pub fn with_callback(mut self, callback: String) -> Self {
622        self.callback = Some(callback);
623        self
624    }
625
626    /// Check if the chord is complete (all tasks finished)
627    pub fn is_complete(&self) -> bool {
628        self.completed >= self.total && !self.cancelled
629    }
630
631    /// Check if the chord is cancelled
632    pub fn is_cancelled(&self) -> bool {
633        self.cancelled
634    }
635
636    /// Cancel the chord
637    pub fn cancel(&mut self, reason: Option<String>) {
638        self.cancelled = true;
639        self.cancellation_reason = reason;
640    }
641
642    /// Check if the chord is in a terminal state (complete, cancelled, or timed out)
643    pub fn is_terminal(&self) -> bool {
644        self.is_complete() || self.is_cancelled() || self.is_timed_out()
645    }
646
647    /// Check if the chord has timed out
648    pub fn is_timed_out(&self) -> bool {
649        if let Some(timeout) = self.timeout {
650            let age = Utc::now() - self.created_at;
651            age.num_milliseconds() > timeout.as_millis() as i64
652        } else {
653            false
654        }
655    }
656
657    /// Get the remaining time before timeout
658    pub fn remaining_timeout(&self) -> Option<Duration> {
659        self.timeout.and_then(|timeout| {
660            let age = Utc::now() - self.created_at;
661            let age_ms = age.num_milliseconds().max(0) as u64;
662            let timeout_ms = timeout.as_millis() as u64;
663
664            if age_ms < timeout_ms {
665                Some(Duration::from_millis(timeout_ms - age_ms))
666            } else {
667                None
668            }
669        })
670    }
671
672    /// Get the number of remaining tasks
673    pub fn remaining(&self) -> usize {
674        self.total.saturating_sub(self.completed)
675    }
676
677    /// Get the completion percentage (0.0 to 100.0)
678    pub fn percent_complete(&self) -> f64 {
679        if self.total > 0 {
680            (self.completed as f64 / self.total as f64 * 100.0).min(100.0)
681        } else {
682            0.0
683        }
684    }
685
686    /// Check if the chord has a callback
687    pub fn has_callback(&self) -> bool {
688        self.callback.is_some()
689    }
690
691    /// Check if the chord has a timeout
692    pub fn has_timeout(&self) -> bool {
693        self.timeout.is_some()
694    }
695
696    /// Get the number of tasks in the chord
697    pub fn task_count(&self) -> usize {
698        self.task_ids.len()
699    }
700
701    /// Get the chord age (time since creation)
702    pub fn age(&self) -> chrono::Duration {
703        Utc::now() - self.created_at
704    }
705
706    /// Set maximum retry attempts
707    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
708        self.max_retries = Some(max_retries);
709        self
710    }
711
712    /// Check if the chord can be retried
713    pub fn can_retry(&self) -> bool {
714        if let Some(max_retries) = self.max_retries {
715            self.retry_count < max_retries
716        } else {
717            false
718        }
719    }
720
721    /// Increment the retry count and reset the chord for retry
722    ///
723    /// Returns true if retry is allowed, false if max retries exceeded
724    pub fn retry(&mut self) -> bool {
725        if !self.can_retry() {
726            return false;
727        }
728        self.retry_count += 1;
729        self.completed = 0;
730        self.cancelled = false;
731        self.cancellation_reason = None;
732        self.created_at = Utc::now();
733        true
734    }
735
736    /// Get remaining retry attempts
737    pub fn remaining_retries(&self) -> Option<u32> {
738        self.max_retries
739            .map(|max| max.saturating_sub(self.retry_count))
740    }
741
742    /// Check if this is a retry attempt
743    pub fn is_retry(&self) -> bool {
744        self.retry_count > 0
745    }
746}
747
748impl std::fmt::Display for ChordState {
749    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
750        write!(
751            f,
752            "Chord[{}] {}/{} tasks ({:.1}%)",
753            &self.chord_id.to_string()[..8],
754            self.completed,
755            self.total,
756            self.percent_complete()
757        )?;
758
759        if let Some(ref callback) = self.callback {
760            write!(f, " callback={}", callback)?;
761        }
762
763        if self.is_cancelled() {
764            write!(f, " [CANCELLED")?;
765            if let Some(ref reason) = self.cancellation_reason {
766                write!(f, ": {}", reason)?;
767            }
768            write!(f, "]")?;
769        } else if let Some(timeout) = self.timeout {
770            if self.is_timed_out() {
771                write!(f, " [TIMED OUT]")?;
772            } else if let Some(remaining) = self.remaining_timeout() {
773                write!(f, " timeout={:?} remaining={:?}", timeout, remaining)?;
774            }
775        }
776
777        Ok(())
778    }
779}