Skip to main content

kojin_core/
message.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5use crate::state::TaskState;
6use crate::task_id::TaskId;
7
8/// A task message that flows through the broker.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct TaskMessage {
11    /// Unique task identifier.
12    pub id: TaskId,
13    /// Registered task name (e.g., "send_email").
14    pub task_name: String,
15    /// Target queue name.
16    pub queue: String,
17    /// Serialized task payload.
18    pub payload: serde_json::Value,
19    /// Current task state.
20    pub state: TaskState,
21    /// Current retry count.
22    pub retries: u32,
23    /// Maximum allowed retries.
24    pub max_retries: u32,
25    /// When the message was created.
26    pub created_at: DateTime<Utc>,
27    /// When the message was last updated.
28    pub updated_at: DateTime<Utc>,
29    /// Optional ETA — earliest time the task should execute.
30    pub eta: Option<DateTime<Utc>>,
31    /// Arbitrary headers for middleware / tracing propagation.
32    pub headers: HashMap<String, String>,
33
34    // -- Workflow metadata (Phase 2) --
35    /// Parent task ID for workflow tracking.
36    #[serde(default, skip_serializing_if = "Option::is_none")]
37    pub parent_id: Option<TaskId>,
38    /// Correlation ID for tracing an entire workflow.
39    #[serde(default, skip_serializing_if = "Option::is_none")]
40    pub correlation_id: Option<String>,
41    /// Group ID this task belongs to.
42    #[serde(default, skip_serializing_if = "Option::is_none")]
43    pub group_id: Option<String>,
44    /// Total number of tasks in the group.
45    #[serde(default, skip_serializing_if = "Option::is_none")]
46    pub group_total: Option<u32>,
47    /// Chord callback to enqueue when all group members complete.
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    pub chord_callback: Option<Box<TaskMessage>>,
50}
51
52impl TaskMessage {
53    /// Create a new task message with defaults.
54    pub fn new(
55        task_name: impl Into<String>,
56        queue: impl Into<String>,
57        payload: serde_json::Value,
58    ) -> Self {
59        let now = Utc::now();
60        Self {
61            id: TaskId::new(),
62            task_name: task_name.into(),
63            queue: queue.into(),
64            payload,
65            state: TaskState::Pending,
66            retries: 0,
67            max_retries: 3,
68            created_at: now,
69            updated_at: now,
70            eta: None,
71            headers: HashMap::new(),
72            parent_id: None,
73            correlation_id: None,
74            group_id: None,
75            group_total: None,
76            chord_callback: None,
77        }
78    }
79
80    /// Set max retries.
81    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
82        self.max_retries = max_retries;
83        self
84    }
85
86    /// Set ETA.
87    pub fn with_eta(mut self, eta: DateTime<Utc>) -> Self {
88        self.eta = Some(eta);
89        self
90    }
91
92    /// Add a header.
93    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
94        self.headers.insert(key.into(), value.into());
95        self
96    }
97
98    /// Set parent task ID for workflow tracking.
99    pub fn with_parent_id(mut self, parent_id: TaskId) -> Self {
100        self.parent_id = Some(parent_id);
101        self
102    }
103
104    /// Set correlation ID for tracing an entire workflow.
105    pub fn with_correlation_id(mut self, correlation_id: impl Into<String>) -> Self {
106        self.correlation_id = Some(correlation_id.into());
107        self
108    }
109
110    /// Set group metadata.
111    pub fn with_group(mut self, group_id: impl Into<String>, group_total: u32) -> Self {
112        self.group_id = Some(group_id.into());
113        self.group_total = Some(group_total);
114        self
115    }
116
117    /// Set chord callback.
118    pub fn with_chord_callback(mut self, callback: TaskMessage) -> Self {
119        self.chord_callback = Some(Box::new(callback));
120        self
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn task_message_serde_roundtrip() {
        let msg = TaskMessage::new(
            "send_email",
            "default",
            serde_json::json!({"to": "a@b.com"}),
        )
        .with_max_retries(5)
        .with_header("trace_id", "abc123");

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: TaskMessage = serde_json::from_str(&json).unwrap();

        assert_eq!(msg.id, deserialized.id);
        assert_eq!(msg.task_name, deserialized.task_name);
        assert_eq!(msg.queue, deserialized.queue);
        assert_eq!(msg.payload, deserialized.payload);
        assert_eq!(msg.state, deserialized.state);
        assert_eq!(msg.retries, deserialized.retries);
        assert_eq!(msg.max_retries, deserialized.max_retries);
        assert_eq!(msg.created_at, deserialized.created_at);
        assert_eq!(msg.updated_at, deserialized.updated_at);
        assert_eq!(msg.eta, deserialized.eta);
        // Assert on the *deserialized* headers: checking `msg` would pass even if
        // headers were dropped during the round trip.
        assert_eq!(
            deserialized.headers.get("trace_id").map(String::as_str),
            Some("abc123")
        );
        assert_eq!(msg.headers, deserialized.headers);
    }

    #[test]
    fn task_message_defaults() {
        let msg = TaskMessage::new("test", "default", serde_json::Value::Null);
        assert_eq!(msg.state, TaskState::Pending);
        assert_eq!(msg.retries, 0);
        assert_eq!(msg.max_retries, 3);
        assert!(msg.eta.is_none());
        assert!(msg.headers.is_empty());
        assert!(msg.parent_id.is_none());
        assert!(msg.correlation_id.is_none());
        assert!(msg.group_id.is_none());
        assert!(msg.group_total.is_none());
        assert!(msg.chord_callback.is_none());
    }

    #[test]
    fn backward_compat_deserialization() {
        // Simulate a v0.1.0 message without workflow fields
        let old_json = serde_json::json!({
            "id": "01234567-89ab-cdef-0123-456789abcdef",
            "task_name": "send_email",
            "queue": "default",
            "payload": {"to": "a@b.com"},
            "state": "pending",
            "retries": 0,
            "max_retries": 3,
            "created_at": "2025-01-01T00:00:00Z",
            "updated_at": "2025-01-01T00:00:00Z",
            "eta": null,
            "headers": {}
        });
        let msg: TaskMessage = serde_json::from_value(old_json).unwrap();
        assert_eq!(msg.task_name, "send_email");
        assert_eq!(msg.queue, "default");
        assert_eq!(msg.payload, serde_json::json!({"to": "a@b.com"}));
        assert_eq!(msg.retries, 0);
        assert_eq!(msg.max_retries, 3);
        assert!(msg.eta.is_none());
        assert!(msg.headers.is_empty());
        assert!(msg.parent_id.is_none());
        assert!(msg.correlation_id.is_none());
        assert!(msg.group_id.is_none());
        assert!(msg.group_total.is_none());
        assert!(msg.chord_callback.is_none());
    }

    #[test]
    fn unset_workflow_fields_are_not_serialized() {
        // `skip_serializing_if = "Option::is_none"` must keep the wire format
        // identical to v0.1.0 when no workflow metadata is set.
        let msg = TaskMessage::new("test", "default", serde_json::Value::Null);
        let value = serde_json::to_value(&msg).unwrap();
        let obj = value
            .as_object()
            .expect("TaskMessage should serialize as a JSON object");
        for key in [
            "parent_id",
            "correlation_id",
            "group_id",
            "group_total",
            "chord_callback",
        ] {
            assert!(!obj.contains_key(key), "unexpected key {key} in {value}");
        }
    }

    #[test]
    fn workflow_metadata_roundtrip() {
        let callback = TaskMessage::new("callback", "default", serde_json::json!({}));
        let msg = TaskMessage::new("task", "default", serde_json::json!({}))
            .with_parent_id(TaskId::new())
            .with_correlation_id("corr-123")
            .with_group("group-1", 5)
            .with_chord_callback(callback);

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: TaskMessage = serde_json::from_str(&json).unwrap();

        assert_eq!(msg.parent_id, deserialized.parent_id);
        assert_eq!(msg.correlation_id, deserialized.correlation_id);
        assert_eq!(msg.group_id, deserialized.group_id);
        assert_eq!(msg.group_total, deserialized.group_total);

        let original_callback = msg
            .chord_callback
            .as_deref()
            .expect("callback was set via with_chord_callback");
        let restored_callback = deserialized
            .chord_callback
            .as_deref()
            .expect("chord callback lost during round trip");
        assert_eq!(restored_callback.id, original_callback.id);
        assert_eq!(restored_callback.task_name, "callback");
    }
}