Skip to main content

celers_protocol/
types.rs

1//! Protocol types, structs, and enums for Celery messages.
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fmt;
7use uuid::Uuid;
8
/// MIME content type string for JSON-serialized bodies.
pub(crate) const CONTENT_TYPE_JSON: &str = "application/json";
/// MIME content type string for MessagePack-serialized bodies.
///
/// Only compiled when the `msgpack` feature is enabled.
#[cfg(feature = "msgpack")]
pub(crate) const CONTENT_TYPE_MSGPACK: &str = "application/x-msgpack";
/// MIME content type string for opaque binary bodies.
///
/// Only compiled when the `binary` feature is enabled.
#[cfg(feature = "binary")]
pub(crate) const CONTENT_TYPE_BINARY: &str = "application/octet-stream";

/// Content-encoding string for UTF-8 text bodies.
pub(crate) const ENCODING_UTF8: &str = "utf-8";
/// Content-encoding string for raw binary bodies.
pub(crate) const ENCODING_BINARY: &str = "binary";

/// Default value of the `lang` message header.
pub(crate) const DEFAULT_LANG: &str = "rust";
22
/// Validation errors for Celery protocol messages
///
/// Each variant identifies one rule that a message failed; variants that
/// carry data include the offending value (and the limit, where one applies)
/// so callers can report it without re-inspecting the message.
///
/// # Examples
///
/// ```
/// use celers_protocol::{Message, ValidationError};
/// use uuid::Uuid;
///
/// // Create a message with an empty task name
/// let msg = Message::new("".to_string(), Uuid::new_v4(), vec![1, 2, 3]);
///
/// // Validation will fail with a structured error
/// match msg.validate() {
///     Ok(_) => panic!("Should have failed"),
///     Err(ValidationError::EmptyTaskName) => {
///         println!("Task name cannot be empty");
///     }
///     Err(e) => panic!("Unexpected error: {}", e),
/// }
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ValidationError {
    /// Task name is empty
    EmptyTaskName,
    /// Retry count exceeds maximum
    ///
    /// `retries` is the rejected value and `max` is the limit it exceeded.
    RetryLimitExceeded { retries: u32, max: u32 },
    /// ETA is after expiration time
    EtaAfterExpiration,
    /// Invalid delivery mode (must be 1 or 2)
    ///
    /// `mode` is the rejected value.
    InvalidDeliveryMode { mode: u8 },
    /// Invalid priority (must be 0-9)
    ///
    /// `priority` is the rejected value.
    InvalidPriority { priority: u8 },
    /// Content type is empty
    EmptyContentType,
    /// Message body is empty
    EmptyBody,
    /// Message body exceeds size limit
    ///
    /// `size` is the body length in bytes and `max` is the allowed maximum.
    BodyTooLarge { size: usize, max: usize },
}
62
63impl fmt::Display for ValidationError {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        match self {
66            ValidationError::EmptyTaskName => write!(f, "Task name cannot be empty"),
67            ValidationError::RetryLimitExceeded { retries, max } => {
68                write!(f, "Retries ({}) cannot exceed {}", retries, max)
69            }
70            ValidationError::EtaAfterExpiration => {
71                write!(f, "ETA cannot be after expiration time")
72            }
73            ValidationError::InvalidDeliveryMode { mode } => {
74                write!(
75                    f,
76                    "Invalid delivery mode ({}): must be 1 (non-persistent) or 2 (persistent)",
77                    mode
78                )
79            }
80            ValidationError::InvalidPriority { priority } => {
81                write!(
82                    f,
83                    "Invalid priority ({}): must be between 0 and 9",
84                    priority
85                )
86            }
87            ValidationError::EmptyContentType => write!(f, "Content type cannot be empty"),
88            ValidationError::EmptyBody => write!(f, "Message body cannot be empty"),
89            ValidationError::BodyTooLarge { size, max } => {
90                write!(
91                    f,
92                    "Message body too large: {} bytes (max {} bytes)",
93                    size, max
94                )
95            }
96        }
97    }
98}
99
100impl std::error::Error for ValidationError {}
101
/// Protocol version
///
/// `V2` is the `Default`. The derived ordering follows declaration order,
/// so `V2 < V5`.
#[derive(
    Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
pub enum ProtocolVersion {
    /// Protocol version 2 (Celery 4.x+)
    #[default]
    V2,
    /// Protocol version 5 (Celery 5.x+)
    V5,
}
113
114impl std::fmt::Display for ProtocolVersion {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        match self {
117            ProtocolVersion::V2 => write!(f, "v2"),
118            ProtocolVersion::V5 => write!(f, "v5"),
119        }
120    }
121}
122
123impl std::str::FromStr for ProtocolVersion {
124    type Err = String;
125
126    fn from_str(s: &str) -> Result<Self, Self::Err> {
127        match s.to_lowercase().as_str() {
128            "v2" | "2" => Ok(ProtocolVersion::V2),
129            "v5" | "5" => Ok(ProtocolVersion::V5),
130            _ => Err(format!("Invalid protocol version: {}", s)),
131        }
132    }
133}
134
135impl ProtocolVersion {
136    /// Check if this is protocol version 2
137    #[inline]
138    pub const fn is_v2(self) -> bool {
139        matches!(self, ProtocolVersion::V2)
140    }
141
142    /// Check if this is protocol version 5
143    #[inline]
144    pub const fn is_v5(self) -> bool {
145        matches!(self, ProtocolVersion::V5)
146    }
147
148    /// Get the version number as u8
149    #[inline]
150    pub const fn as_u8(self) -> u8 {
151        match self {
152            ProtocolVersion::V2 => 2,
153            ProtocolVersion::V5 => 5,
154        }
155    }
156
157    /// Get the version number as a static string
158    #[inline]
159    pub const fn as_number_str(self) -> &'static str {
160        match self {
161            ProtocolVersion::V2 => "2",
162            ProtocolVersion::V5 => "5",
163        }
164    }
165}
166
/// Content type for serialization
///
/// `Json` is the `Default`. The `MessagePack` and `Binary` variants exist
/// only when the `msgpack` / `binary` features are enabled; without them
/// those content-type strings are carried by `Custom`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ContentType {
    /// JSON serialization
    #[default]
    Json,
    /// MessagePack serialization
    #[cfg(feature = "msgpack")]
    MessagePack,
    /// Binary serialization
    #[cfg(feature = "binary")]
    Binary,
    /// Custom content type
    ///
    /// Holds the content-type string verbatim.
    Custom(String),
}
182
183impl ContentType {
184    #[inline]
185    pub fn as_str(&self) -> &str {
186        match self {
187            ContentType::Json => CONTENT_TYPE_JSON,
188            #[cfg(feature = "msgpack")]
189            ContentType::MessagePack => CONTENT_TYPE_MSGPACK,
190            #[cfg(feature = "binary")]
191            ContentType::Binary => CONTENT_TYPE_BINARY,
192            ContentType::Custom(s) => s,
193        }
194    }
195}
196
197impl std::fmt::Display for ContentType {
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        write!(f, "{}", self.as_str())
200    }
201}
202
203impl std::str::FromStr for ContentType {
204    type Err = String;
205
206    fn from_str(s: &str) -> Result<Self, Self::Err> {
207        match s {
208            CONTENT_TYPE_JSON => Ok(ContentType::Json),
209            #[cfg(feature = "msgpack")]
210            CONTENT_TYPE_MSGPACK => Ok(ContentType::MessagePack),
211            #[cfg(feature = "binary")]
212            CONTENT_TYPE_BINARY => Ok(ContentType::Binary),
213            other => Ok(ContentType::Custom(other.to_string())),
214        }
215    }
216}
217
218impl From<&str> for ContentType {
219    fn from(s: &str) -> Self {
220        match s {
221            CONTENT_TYPE_JSON => ContentType::Json,
222            #[cfg(feature = "msgpack")]
223            CONTENT_TYPE_MSGPACK => ContentType::MessagePack,
224            #[cfg(feature = "binary")]
225            CONTENT_TYPE_BINARY => ContentType::Binary,
226            other => ContentType::Custom(other.to_string()),
227        }
228    }
229}
230
231impl AsRef<str> for ContentType {
232    fn as_ref(&self) -> &str {
233        self.as_str()
234    }
235}
236
/// Content encoding
///
/// `Utf8` is the `Default`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ContentEncoding {
    /// UTF-8 encoding
    #[default]
    Utf8,
    /// Binary encoding
    Binary,
    /// Custom encoding
    ///
    /// Holds the encoding string verbatim.
    Custom(String),
}
248
249impl ContentEncoding {
250    #[inline]
251    pub fn as_str(&self) -> &str {
252        match self {
253            ContentEncoding::Utf8 => ENCODING_UTF8,
254            ContentEncoding::Binary => ENCODING_BINARY,
255            ContentEncoding::Custom(s) => s,
256        }
257    }
258}
259
260impl std::fmt::Display for ContentEncoding {
261    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        write!(f, "{}", self.as_str())
263    }
264}
265
266impl std::str::FromStr for ContentEncoding {
267    type Err = String;
268
269    fn from_str(s: &str) -> Result<Self, Self::Err> {
270        match s {
271            ENCODING_UTF8 => Ok(ContentEncoding::Utf8),
272            ENCODING_BINARY => Ok(ContentEncoding::Binary),
273            other => Ok(ContentEncoding::Custom(other.to_string())),
274        }
275    }
276}
277
278impl From<&str> for ContentEncoding {
279    fn from(s: &str) -> Self {
280        match s {
281            ENCODING_UTF8 => ContentEncoding::Utf8,
282            ENCODING_BINARY => ContentEncoding::Binary,
283            other => ContentEncoding::Custom(other.to_string()),
284        }
285    }
286}
287
288impl AsRef<str> for ContentEncoding {
289    fn as_ref(&self) -> &str {
290        self.as_str()
291    }
292}
293
/// Message headers (Celery protocol)
///
/// `task` and `id` are always present. The optional fields are omitted from
/// serialized output when `None`, and any keys not matching a named field
/// are collected into `extra`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageHeaders {
    /// Task name (e.g., "tasks.add")
    pub task: String,

    /// Task ID (UUID)
    pub id: Uuid,

    /// Programming language ("rust", "py")
    ///
    /// Falls back to the crate default (`"rust"`) when missing on
    /// deserialization.
    #[serde(default = "default_lang")]
    pub lang: String,

    /// Root task ID (for workflow tracking)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub root_id: Option<Uuid>,

    /// Parent task ID (for nested tasks)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent_id: Option<Uuid>,

    /// Group ID (for grouped tasks)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub group: Option<Uuid>,

    /// Retry count; `validate` rejects values above 1000.
    ///
    /// NOTE(review): previously documented as "Maximum retries", but
    /// validation treats this as a count checked against a maximum —
    /// confirm the intended meaning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retries: Option<u32>,

    /// ETA (Estimated Time of Arrival) for delayed tasks
    #[serde(skip_serializing_if = "Option::is_none")]
    pub eta: Option<DateTime<Utc>>,

    /// Task expiration timestamp
    ///
    /// When both `eta` and `expires` are set, `validate` requires
    /// `eta <= expires`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires: Option<DateTime<Utc>>,

    /// Additional custom headers
    ///
    /// Flattened: entries are serialized alongside the named fields rather
    /// than under an `extra` key.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
335
336fn default_lang() -> String {
337    DEFAULT_LANG.to_string()
338}
339
340impl MessageHeaders {
341    pub fn new(task: String, id: Uuid) -> Self {
342        Self {
343            task,
344            id,
345            lang: default_lang(),
346            root_id: None,
347            parent_id: None,
348            group: None,
349            retries: None,
350            eta: None,
351            expires: None,
352            extra: HashMap::new(),
353        }
354    }
355
356    /// Set the language field (builder pattern)
357    #[must_use]
358    pub fn with_lang(mut self, lang: String) -> Self {
359        self.lang = lang;
360        self
361    }
362
363    /// Set the root ID field (builder pattern)
364    #[must_use]
365    pub fn with_root_id(mut self, root_id: Uuid) -> Self {
366        self.root_id = Some(root_id);
367        self
368    }
369
370    /// Set the parent ID field (builder pattern)
371    #[must_use]
372    pub fn with_parent_id(mut self, parent_id: Uuid) -> Self {
373        self.parent_id = Some(parent_id);
374        self
375    }
376
377    /// Set the group field (builder pattern)
378    #[must_use]
379    pub fn with_group(mut self, group: Uuid) -> Self {
380        self.group = Some(group);
381        self
382    }
383
384    /// Set the retries field (builder pattern)
385    #[must_use]
386    pub fn with_retries(mut self, retries: u32) -> Self {
387        self.retries = Some(retries);
388        self
389    }
390
391    /// Set the ETA field (builder pattern)
392    #[must_use]
393    pub fn with_eta(mut self, eta: DateTime<Utc>) -> Self {
394        self.eta = Some(eta);
395        self
396    }
397
398    /// Set the expires field (builder pattern)
399    #[must_use]
400    pub fn with_expires(mut self, expires: DateTime<Utc>) -> Self {
401        self.expires = Some(expires);
402        self
403    }
404
405    /// Validate message headers
406    pub fn validate(&self) -> Result<(), ValidationError> {
407        if self.task.is_empty() {
408            return Err(ValidationError::EmptyTaskName);
409        }
410
411        if let Some(retries) = self.retries {
412            if retries > 1000 {
413                return Err(ValidationError::RetryLimitExceeded { retries, max: 1000 });
414            }
415        }
416
417        // Validate ETA and expiration relationship
418        if let (Some(eta), Some(expires)) = (self.eta, self.expires) {
419            if eta > expires {
420                return Err(ValidationError::EtaAfterExpiration);
421            }
422        }
423
424        Ok(())
425    }
426}
427
/// Message properties (AMQP-like)
///
/// Optional fields are omitted from serialized output when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageProperties {
    /// Correlation ID for RPC-style calls
    #[serde(skip_serializing_if = "Option::is_none")]
    pub correlation_id: Option<String>,

    /// Reply-to queue for results
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reply_to: Option<String>,

    /// Delivery mode (1 = non-persistent, 2 = persistent)
    ///
    /// Falls back to 2 (persistent) when missing on deserialization.
    #[serde(default = "default_delivery_mode")]
    pub delivery_mode: u8,

    /// Priority (0-9, higher = more priority)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<u8>,
}
447
/// Default for `delivery_mode`: persistent delivery.
const fn default_delivery_mode() -> u8 {
    // 2 is the persistent delivery mode.
    const PERSISTENT: u8 = 2;
    PERSISTENT
}
451
452impl Default for MessageProperties {
453    fn default() -> Self {
454        Self {
455            correlation_id: None,
456            reply_to: None,
457            delivery_mode: default_delivery_mode(),
458            priority: None,
459        }
460    }
461}
462
463impl MessageProperties {
464    /// Create new MessageProperties with default values
465    pub fn new() -> Self {
466        Self::default()
467    }
468
469    /// Set correlation ID (builder pattern)
470    #[must_use]
471    pub fn with_correlation_id(mut self, correlation_id: String) -> Self {
472        self.correlation_id = Some(correlation_id);
473        self
474    }
475
476    /// Set reply-to queue (builder pattern)
477    #[must_use]
478    pub fn with_reply_to(mut self, reply_to: String) -> Self {
479        self.reply_to = Some(reply_to);
480        self
481    }
482
483    /// Set delivery mode (builder pattern)
484    #[must_use]
485    pub fn with_delivery_mode(mut self, delivery_mode: u8) -> Self {
486        self.delivery_mode = delivery_mode;
487        self
488    }
489
490    /// Set priority (builder pattern)
491    #[must_use]
492    pub fn with_priority(mut self, priority: u8) -> Self {
493        self.priority = Some(priority);
494        self
495    }
496
497    /// Validate message properties
498    pub fn validate(&self) -> Result<(), ValidationError> {
499        if self.delivery_mode != 1 && self.delivery_mode != 2 {
500            return Err(ValidationError::InvalidDeliveryMode {
501                mode: self.delivery_mode,
502            });
503        }
504
505        if let Some(priority) = self.priority {
506            if priority > 9 {
507                return Err(ValidationError::InvalidPriority { priority });
508            }
509        }
510
511        Ok(())
512    }
513}
514
/// Complete Celery message
///
/// Combines headers, properties and the serialized body with the strings
/// describing how the body is encoded.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Message {
    /// Message headers
    pub headers: MessageHeaders,

    /// Message properties
    pub properties: MessageProperties,

    /// Serialized body (task arguments)
    ///
    /// Written and read as a base64 string via the `serde_bytes_opt` module.
    #[serde(with = "serde_bytes_opt")]
    pub body: Vec<u8>,

    /// Content type
    ///
    /// Serialized under the key `content-type`.
    #[serde(rename = "content-type")]
    pub content_type: String,

    /// Content encoding
    ///
    /// Serialized under the key `content-encoding`.
    #[serde(rename = "content-encoding")]
    pub content_encoding: String,
}
536
537// Custom serde module for optional byte arrays
538mod serde_bytes_opt {
539    use base64::Engine;
540    use serde::de::Error;
541    use serde::{Deserialize, Deserializer, Serializer};
542
543    pub fn serialize<S>(bytes: &Vec<u8>, serializer: S) -> Result<S::Ok, S::Error>
544    where
545        S: Serializer,
546    {
547        // Serialize as base64 string for JSON compatibility
548        let encoded = base64::engine::general_purpose::STANDARD.encode(bytes);
549        serializer.serialize_str(&encoded)
550    }
551
552    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
553    where
554        D: Deserializer<'de>,
555    {
556        let s = String::deserialize(deserializer)?;
557        base64::engine::general_purpose::STANDARD
558            .decode(&s)
559            .map_err(Error::custom)
560    }
561}
562
/// Task arguments (args, kwargs)
///
/// Both fields default to empty when absent on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
pub struct TaskArgs {
    /// Positional arguments
    #[serde(default)]
    pub args: Vec<serde_json::Value>,

    /// Keyword arguments
    #[serde(default)]
    pub kwargs: HashMap<String, serde_json::Value>,
}
574
575impl TaskArgs {
576    /// Create a new empty TaskArgs
577    pub fn new() -> Self {
578        Self::default()
579    }
580
581    /// Set all positional arguments at once (builder pattern)
582    #[must_use]
583    pub fn with_args(mut self, args: Vec<serde_json::Value>) -> Self {
584        self.args = args;
585        self
586    }
587
588    /// Set all keyword arguments at once (builder pattern)
589    #[must_use]
590    pub fn with_kwargs(mut self, kwargs: HashMap<String, serde_json::Value>) -> Self {
591        self.kwargs = kwargs;
592        self
593    }
594
595    /// Add a single positional argument
596    pub fn add_arg(&mut self, arg: serde_json::Value) {
597        self.args.push(arg);
598    }
599
600    /// Add a single keyword argument
601    pub fn add_kwarg(&mut self, key: String, value: serde_json::Value) {
602        self.kwargs.insert(key, value);
603    }
604
605    /// Check if both args and kwargs are empty
606    #[inline(always)]
607    pub fn is_empty(&self) -> bool {
608        self.args.is_empty() && self.kwargs.is_empty()
609    }
610
611    /// Get the total number of arguments (positional + keyword)
612    #[inline(always)]
613    pub fn len(&self) -> usize {
614        self.args.len() + self.kwargs.len()
615    }
616
617    /// Check if there are any positional arguments
618    #[inline(always)]
619    pub fn has_args(&self) -> bool {
620        !self.args.is_empty()
621    }
622
623    /// Check if there are any keyword arguments
624    #[inline(always)]
625    pub fn has_kwargs(&self) -> bool {
626        !self.kwargs.is_empty()
627    }
628
629    /// Clear all arguments
630    pub fn clear(&mut self) {
631        self.args.clear();
632        self.kwargs.clear();
633    }
634
635    /// Get a positional argument by index
636    #[inline]
637    pub fn get_arg(&self, index: usize) -> Option<&serde_json::Value> {
638        self.args.get(index)
639    }
640
641    /// Get a keyword argument by key
642    #[inline]
643    pub fn get_kwarg(&self, key: &str) -> Option<&serde_json::Value> {
644        self.kwargs.get(key)
645    }
646
647    /// Create TaskArgs from a JSON string
648    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
649        serde_json::from_str(json)
650    }
651
652    /// Convert TaskArgs to a JSON string
653    pub fn to_json(&self) -> Result<String, serde_json::Error> {
654        serde_json::to_string(self)
655    }
656
657    /// Convert TaskArgs to pretty-printed JSON
658    pub fn to_json_pretty(&self) -> Result<String, serde_json::Error> {
659        serde_json::to_string_pretty(self)
660    }
661}
662
663// Index trait for accessing positional arguments by index
664impl std::ops::Index<usize> for TaskArgs {
665    type Output = serde_json::Value;
666
667    #[inline]
668    fn index(&self, index: usize) -> &Self::Output {
669        &self.args[index]
670    }
671}
672
673// IndexMut trait for mutating positional arguments by index
674impl std::ops::IndexMut<usize> for TaskArgs {
675    #[inline]
676    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
677        &mut self.args[index]
678    }
679}
680
681// Index trait for accessing keyword arguments by string key
682impl std::ops::Index<&str> for TaskArgs {
683    type Output = serde_json::Value;
684
685    #[inline]
686    fn index(&self, key: &str) -> &Self::Output {
687        &self.kwargs[key]
688    }
689}
690
691// IntoIterator for TaskArgs - iterates over positional args
692impl IntoIterator for TaskArgs {
693    type Item = serde_json::Value;
694    type IntoIter = std::vec::IntoIter<serde_json::Value>;
695
696    fn into_iter(self) -> Self::IntoIter {
697        self.args.into_iter()
698    }
699}
700
701// IntoIterator for &TaskArgs - iterates over positional args by reference
702impl<'a> IntoIterator for &'a TaskArgs {
703    type Item = &'a serde_json::Value;
704    type IntoIter = std::slice::Iter<'a, serde_json::Value>;
705
706    fn into_iter(self) -> Self::IntoIter {
707        self.args.iter()
708    }
709}
710
711// Extend trait for TaskArgs - extend with more positional arguments
712impl Extend<serde_json::Value> for TaskArgs {
713    fn extend<T: IntoIterator<Item = serde_json::Value>>(&mut self, iter: T) {
714        self.args.extend(iter);
715    }
716}
717
718// Extend trait for TaskArgs with key-value pairs for kwargs
719impl Extend<(String, serde_json::Value)> for TaskArgs {
720    fn extend<T: IntoIterator<Item = (String, serde_json::Value)>>(&mut self, iter: T) {
721        self.kwargs.extend(iter);
722    }
723}
724
725// FromIterator for TaskArgs - build from iterator of positional args
726impl FromIterator<serde_json::Value> for TaskArgs {
727    fn from_iter<T: IntoIterator<Item = serde_json::Value>>(iter: T) -> Self {
728        Self {
729            args: iter.into_iter().collect(),
730            kwargs: HashMap::new(),
731        }
732    }
733}