celers_protocol/
embed.rs

//! Celery embedded body format
//!
//! This module provides support for the Celery Protocol v2 embedded body format.
//! In Protocol v2, the message body is a three-element tuple, serialized as the
//! JSON array `[args, kwargs, embed]`, where:
//!
//! - `args` - Positional arguments (list)
//! - `kwargs` - Keyword arguments (dict)
//! - `embed` - Embedded metadata (callbacks, errbacks, chain, chord, etc.)
//!
//! # Example
//!
//! ```
//! use celers_protocol::embed::{EmbeddedBody, EmbedOptions};
//! use serde_json::json;
//!
//! let body = EmbeddedBody::new()
//!     .with_args(vec![json!(1), json!(2)])
//!     .with_kwarg("debug", json!(true));
//!
//! let encoded = body.encode().unwrap();
//! let decoded = EmbeddedBody::decode(&encoded).unwrap();
//! assert_eq!(decoded.args, vec![json!(1), json!(2)]);
//! ```
24
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use std::collections::HashMap;
28use uuid::Uuid;
29
30/// Callback signature for link/errback
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct CallbackSignature {
33    /// Task name
34    pub task: String,
35
36    /// Task ID (optional, will be generated if not provided)
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub task_id: Option<Uuid>,
39
40    /// Positional arguments
41    #[serde(default)]
42    pub args: Vec<Value>,
43
44    /// Keyword arguments
45    #[serde(default)]
46    pub kwargs: HashMap<String, Value>,
47
48    /// Task options
49    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
50    pub options: HashMap<String, Value>,
51
52    /// Immutable flag (don't append parent result)
53    #[serde(default)]
54    pub immutable: bool,
55
56    /// Subtask type (for internal use)
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub subtask_type: Option<String>,
59}
60
61impl CallbackSignature {
62    /// Create a new callback signature
63    pub fn new(task: impl Into<String>) -> Self {
64        Self {
65            task: task.into(),
66            task_id: None,
67            args: Vec::new(),
68            kwargs: HashMap::new(),
69            options: HashMap::new(),
70            immutable: false,
71            subtask_type: None,
72        }
73    }
74
75    /// Set task ID
76    #[must_use]
77    pub fn with_task_id(mut self, task_id: Uuid) -> Self {
78        self.task_id = Some(task_id);
79        self
80    }
81
82    /// Set positional arguments
83    #[must_use]
84    pub fn with_args(mut self, args: Vec<Value>) -> Self {
85        self.args = args;
86        self
87    }
88
89    /// Add a keyword argument
90    #[must_use]
91    pub fn with_kwarg(mut self, key: impl Into<String>, value: Value) -> Self {
92        self.kwargs.insert(key.into(), value);
93        self
94    }
95
96    /// Set as immutable
97    #[must_use]
98    pub fn immutable(mut self) -> Self {
99        self.immutable = true;
100        self
101    }
102
103    /// Add an option
104    #[must_use]
105    pub fn with_option(mut self, key: impl Into<String>, value: Value) -> Self {
106        self.options.insert(key.into(), value);
107        self
108    }
109}
110
111/// Embed options in the message body
112#[derive(Debug, Clone, Default, Serialize, Deserialize)]
113pub struct EmbedOptions {
114    /// Callbacks to execute on success (link)
115    #[serde(default, skip_serializing_if = "Vec::is_empty")]
116    pub callbacks: Vec<CallbackSignature>,
117
118    /// Callbacks to execute on error (errback)
119    #[serde(default, skip_serializing_if = "Vec::is_empty")]
120    pub errbacks: Vec<CallbackSignature>,
121
122    /// Chain of tasks to execute after this one
123    #[serde(default, skip_serializing_if = "Vec::is_empty")]
124    pub chain: Vec<CallbackSignature>,
125
126    /// Chord callback (executed after group completes)
127    #[serde(skip_serializing_if = "Option::is_none")]
128    pub chord: Option<CallbackSignature>,
129
130    /// Group ID
131    #[serde(skip_serializing_if = "Option::is_none")]
132    pub group: Option<Uuid>,
133
134    /// Parent task ID
135    #[serde(skip_serializing_if = "Option::is_none")]
136    pub parent_id: Option<Uuid>,
137
138    /// Root task ID
139    #[serde(skip_serializing_if = "Option::is_none")]
140    pub root_id: Option<Uuid>,
141
142    /// Additional custom embed fields
143    #[serde(flatten)]
144    pub extra: HashMap<String, Value>,
145}
146
147impl EmbedOptions {
148    /// Create new empty embed options
149    pub fn new() -> Self {
150        Self::default()
151    }
152
153    /// Add a success callback (link)
154    #[must_use]
155    pub fn with_callback(mut self, callback: CallbackSignature) -> Self {
156        self.callbacks.push(callback);
157        self
158    }
159
160    /// Add an error callback (errback)
161    #[must_use]
162    pub fn with_errback(mut self, errback: CallbackSignature) -> Self {
163        self.errbacks.push(errback);
164        self
165    }
166
167    /// Add a chain task
168    #[must_use]
169    pub fn with_chain_task(mut self, task: CallbackSignature) -> Self {
170        self.chain.push(task);
171        self
172    }
173
174    /// Set the chord callback
175    #[must_use]
176    pub fn with_chord(mut self, chord: CallbackSignature) -> Self {
177        self.chord = Some(chord);
178        self
179    }
180
181    /// Set the group ID
182    #[must_use]
183    pub fn with_group(mut self, group: Uuid) -> Self {
184        self.group = Some(group);
185        self
186    }
187
188    /// Set the parent task ID
189    #[must_use]
190    pub fn with_parent(mut self, parent_id: Uuid) -> Self {
191        self.parent_id = Some(parent_id);
192        self
193    }
194
195    /// Set the root task ID
196    #[must_use]
197    pub fn with_root(mut self, root_id: Uuid) -> Self {
198        self.root_id = Some(root_id);
199        self
200    }
201
202    /// Check if there are any callbacks
203    pub fn has_callbacks(&self) -> bool {
204        !self.callbacks.is_empty()
205    }
206
207    /// Check if there are any errbacks
208    pub fn has_errbacks(&self) -> bool {
209        !self.errbacks.is_empty()
210    }
211
212    /// Check if there is a chain
213    pub fn has_chain(&self) -> bool {
214        !self.chain.is_empty()
215    }
216
217    /// Check if there is a chord
218    pub fn has_chord(&self) -> bool {
219        self.chord.is_some()
220    }
221
222    /// Check if this has any workflow elements
223    pub fn has_workflow(&self) -> bool {
224        self.has_callbacks() || self.has_errbacks() || self.has_chain() || self.has_chord()
225    }
226}
227
228/// Complete embedded body format [args, kwargs, embed]
229#[derive(Debug, Clone, Default)]
230pub struct EmbeddedBody {
231    /// Positional arguments
232    pub args: Vec<Value>,
233
234    /// Keyword arguments
235    pub kwargs: HashMap<String, Value>,
236
237    /// Embed options
238    pub embed: EmbedOptions,
239}
240
241impl EmbeddedBody {
242    /// Create a new embedded body
243    pub fn new() -> Self {
244        Self::default()
245    }
246
247    /// Set positional arguments
248    #[must_use]
249    pub fn with_args(mut self, args: Vec<Value>) -> Self {
250        self.args = args;
251        self
252    }
253
254    /// Add a positional argument
255    #[must_use]
256    pub fn with_arg(mut self, arg: Value) -> Self {
257        self.args.push(arg);
258        self
259    }
260
261    /// Set keyword arguments
262    #[must_use]
263    pub fn with_kwargs(mut self, kwargs: HashMap<String, Value>) -> Self {
264        self.kwargs = kwargs;
265        self
266    }
267
268    /// Add a keyword argument
269    #[must_use]
270    pub fn with_kwarg(mut self, key: impl Into<String>, value: Value) -> Self {
271        self.kwargs.insert(key.into(), value);
272        self
273    }
274
275    /// Set embed options
276    #[must_use]
277    pub fn with_embed(mut self, embed: EmbedOptions) -> Self {
278        self.embed = embed;
279        self
280    }
281
282    /// Add a success callback
283    #[must_use]
284    pub fn with_callback(mut self, callback: CallbackSignature) -> Self {
285        self.embed.callbacks.push(callback);
286        self
287    }
288
289    /// Add an error callback
290    #[must_use]
291    pub fn with_errback(mut self, errback: CallbackSignature) -> Self {
292        self.embed.errbacks.push(errback);
293        self
294    }
295
296    /// Encode to JSON bytes (Celery wire format)
297    pub fn encode(&self) -> Result<Vec<u8>, serde_json::Error> {
298        // Convert embed to Value (empty object if no workflow)
299        let embed_value = if self.embed.has_workflow()
300            || self.embed.group.is_some()
301            || self.embed.parent_id.is_some()
302            || self.embed.root_id.is_some()
303        {
304            serde_json::to_value(&self.embed)?
305        } else {
306            Value::Object(serde_json::Map::new())
307        };
308
309        let tuple = (&self.args, &self.kwargs, embed_value);
310
311        serde_json::to_vec(&tuple)
312    }
313
314    /// Decode from JSON bytes
315    pub fn decode(bytes: &[u8]) -> Result<Self, serde_json::Error> {
316        let tuple: (Vec<Value>, HashMap<String, Value>, Value) = serde_json::from_slice(bytes)?;
317
318        let embed: EmbedOptions = if tuple.2.is_object()
319            && !tuple
320                .2
321                .as_object()
322                .expect("value should be an object")
323                .is_empty()
324        {
325            serde_json::from_value(tuple.2)?
326        } else {
327            EmbedOptions::default()
328        };
329
330        Ok(Self {
331            args: tuple.0,
332            kwargs: tuple.1,
333            embed,
334        })
335    }
336
337    /// Encode to JSON string
338    pub fn to_json_string(&self) -> Result<String, serde_json::Error> {
339        let bytes = self.encode()?;
340        Ok(String::from_utf8_lossy(&bytes).to_string())
341    }
342
343    /// Decode from JSON string
344    pub fn from_json_string(s: &str) -> Result<Self, serde_json::Error> {
345        Self::decode(s.as_bytes())
346    }
347}
348
349/// Serialize arguments for display/logging
350pub fn format_args(args: &[Value], kwargs: &HashMap<String, Value>) -> String {
351    let args_str: Vec<String> = args.iter().map(|v| v.to_string()).collect();
352    let kwargs_str: Vec<String> = kwargs.iter().map(|(k, v)| format!("{}={}", k, v)).collect();
353
354    let mut parts = args_str;
355    parts.extend(kwargs_str);
356    parts.join(", ")
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_callback_signature_creation() {
        let callback = CallbackSignature::new("tasks.process")
            .with_args(vec![json!(1), json!(2)])
            .with_kwarg("debug", json!(true))
            .immutable();

        assert_eq!(callback.task, "tasks.process");
        assert_eq!(callback.args, vec![json!(1), json!(2)]);
        assert_eq!(callback.kwargs.get("debug"), Some(&json!(true)));
        assert!(callback.immutable);
    }

    #[test]
    fn test_callback_signature_with_task_id() {
        let task_id = Uuid::new_v4();
        let callback = CallbackSignature::new("tasks.callback").with_task_id(task_id);

        assert_eq!(callback.task_id, Some(task_id));
    }

    #[test]
    fn test_embed_options_callbacks() {
        let callback = CallbackSignature::new("tasks.success_handler");
        let errback = CallbackSignature::new("tasks.error_handler");

        let embed = EmbedOptions::new()
            .with_callback(callback)
            .with_errback(errback);

        assert!(embed.has_callbacks());
        assert!(embed.has_errbacks());
        assert!(embed.has_workflow());
    }

    #[test]
    fn test_embed_options_chain() {
        let task1 = CallbackSignature::new("tasks.step1");
        let task2 = CallbackSignature::new("tasks.step2");

        let embed = EmbedOptions::new()
            .with_chain_task(task1)
            .with_chain_task(task2);

        assert!(embed.has_chain());
        assert_eq!(embed.chain.len(), 2);
    }

    #[test]
    fn test_embed_options_chord() {
        let chord_callback = CallbackSignature::new("tasks.chord_callback");
        let group_id = Uuid::new_v4();

        let embed = EmbedOptions::new()
            .with_chord(chord_callback)
            .with_group(group_id);

        assert!(embed.has_chord());
        assert_eq!(embed.group, Some(group_id));
    }

    #[test]
    fn test_embedded_body_basic() {
        let body = EmbeddedBody::new()
            .with_args(vec![json!(1), json!(2)])
            .with_kwarg("key", json!("value"));

        assert_eq!(body.args, vec![json!(1), json!(2)]);
        assert_eq!(body.kwargs.get("key"), Some(&json!("value")));
    }

    #[test]
    fn test_embedded_body_encode_decode() {
        let body = EmbeddedBody::new()
            .with_args(vec![json!(10), json!(20)])
            .with_kwarg("multiplier", json!(2));

        let encoded = body.encode().unwrap();
        let decoded = EmbeddedBody::decode(&encoded).unwrap();

        assert_eq!(decoded.args, body.args);
        assert_eq!(decoded.kwargs, body.kwargs);
    }

    #[test]
    fn test_embedded_body_with_callbacks() {
        let callback = CallbackSignature::new("tasks.on_success");
        let body = EmbeddedBody::new()
            .with_args(vec![json!("test")])
            .with_callback(callback);

        let encoded = body.encode().unwrap();
        let decoded = EmbeddedBody::decode(&encoded).unwrap();

        assert!(decoded.embed.has_callbacks());
        assert_eq!(decoded.embed.callbacks[0].task, "tasks.on_success");
    }

    #[test]
    fn test_embedded_body_workflow_ids_roundtrip() {
        let group_id = Uuid::new_v4();
        let parent_id = Uuid::new_v4();
        let root_id = Uuid::new_v4();
        let body = EmbeddedBody::new().with_embed(
            EmbedOptions::new()
                .with_group(group_id)
                .with_parent(parent_id)
                .with_root(root_id),
        );

        let decoded = EmbeddedBody::decode(&body.encode().unwrap()).unwrap();

        assert_eq!(decoded.embed.group, Some(group_id));
        assert_eq!(decoded.embed.parent_id, Some(parent_id));
        assert_eq!(decoded.embed.root_id, Some(root_id));
    }

    #[test]
    fn test_embedded_body_wire_format() {
        let body = EmbeddedBody::new()
            .with_args(vec![json!(1), json!(2)])
            .with_kwarg("x", json!(3));

        let json_str = body.to_json_string().unwrap();

        // Should be [args, kwargs, embed] format
        let parsed: Value = serde_json::from_str(&json_str).unwrap();
        assert!(parsed.is_array());

        let arr = parsed.as_array().unwrap();
        assert_eq!(arr.len(), 3);
        assert!(arr[0].is_array()); // args
        assert!(arr[1].is_object()); // kwargs
        assert!(arr[2].is_object()); // embed
    }

    #[test]
    fn test_embedded_body_from_json_string() {
        let json_str = r#"[[1, 2], {"key": "value"}, {}]"#;
        let body = EmbeddedBody::from_json_string(json_str).unwrap();

        assert_eq!(body.args, vec![json!(1), json!(2)]);
        assert_eq!(body.kwargs.get("key"), Some(&json!("value")));
    }

    #[test]
    fn test_embedded_body_decode_null_embed() {
        // A non-object embed element carries no workflow metadata.
        let body = EmbeddedBody::from_json_string(r#"[[1], {}, null]"#).unwrap();

        assert_eq!(body.args, vec![json!(1)]);
        assert!(body.kwargs.is_empty());
        assert!(!body.embed.has_workflow());
    }

    #[test]
    fn test_embedded_body_python_compatibility() {
        // This is the exact format Python Celery uses
        let python_body = r#"[[4, 5], {"debug": true}, {"callbacks": [{"task": "tasks.callback", "args": [], "kwargs": {}, "options": {}, "immutable": false}]}]"#;

        let body = EmbeddedBody::from_json_string(python_body).unwrap();

        assert_eq!(body.args, vec![json!(4), json!(5)]);
        assert_eq!(body.kwargs.get("debug"), Some(&json!(true)));
        assert!(body.embed.has_callbacks());
        assert_eq!(body.embed.callbacks[0].task, "tasks.callback");
    }

    #[test]
    fn test_format_args() {
        let args = vec![json!(1), json!("hello")];
        let mut kwargs = HashMap::new();
        kwargs.insert("x".to_string(), json!(10));
        kwargs.insert("y".to_string(), json!(20));

        let formatted = format_args(&args, &kwargs);

        // Positional arguments come first and in order; both keyword
        // arguments must be present (their relative order is not pinned here).
        assert!(formatted.starts_with("1, \"hello\", "));
        assert!(formatted.contains("x=10"));
        assert!(formatted.contains("y=20"));
    }

    #[test]
    fn test_format_args_empty() {
        assert_eq!(format_args(&[], &HashMap::new()), "");
    }

    #[test]
    fn test_embed_options_workflow_ids() {
        let parent_id = Uuid::new_v4();
        let root_id = Uuid::new_v4();

        let embed = EmbedOptions::new()
            .with_parent(parent_id)
            .with_root(root_id);

        assert_eq!(embed.parent_id, Some(parent_id));
        assert_eq!(embed.root_id, Some(root_id));
    }

    #[test]
    fn test_callback_signature_serialization() {
        let callback = CallbackSignature::new("tasks.test")
            .with_args(vec![json!(1)])
            .with_kwarg("key", json!("val"))
            .with_option("queue", json!("high-priority"));

        let json = serde_json::to_string(&callback).unwrap();
        let decoded: CallbackSignature = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.task, "tasks.test");
        assert_eq!(decoded.args, vec![json!(1)]);
        assert_eq!(decoded.options.get("queue"), Some(&json!("high-priority")));
    }

    #[test]
    fn test_callback_signature_omits_empty_optional_fields() {
        let value = serde_json::to_value(CallbackSignature::new("tasks.bare")).unwrap();
        let obj = value.as_object().unwrap();

        assert_eq!(obj.get("task"), Some(&json!("tasks.bare")));
        assert!(!obj.contains_key("task_id"));
        assert!(!obj.contains_key("options"));
        assert!(!obj.contains_key("subtask_type"));
    }
}