1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4
5use crate::state::TaskState;
6use crate::task_id::TaskId;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct TaskMessage {
11 pub id: TaskId,
13 pub task_name: String,
15 pub queue: String,
17 pub payload: serde_json::Value,
19 pub state: TaskState,
21 pub retries: u32,
23 pub max_retries: u32,
25 pub created_at: DateTime<Utc>,
27 pub updated_at: DateTime<Utc>,
29 pub eta: Option<DateTime<Utc>>,
31 pub headers: HashMap<String, String>,
33
34 #[serde(default, skip_serializing_if = "Option::is_none")]
37 pub parent_id: Option<TaskId>,
38 #[serde(default, skip_serializing_if = "Option::is_none")]
40 pub correlation_id: Option<String>,
41 #[serde(default, skip_serializing_if = "Option::is_none")]
43 pub group_id: Option<String>,
44 #[serde(default, skip_serializing_if = "Option::is_none")]
46 pub group_total: Option<u32>,
47 #[serde(default, skip_serializing_if = "Option::is_none")]
49 pub chord_callback: Option<Box<TaskMessage>>,
50}
51
52impl TaskMessage {
53 pub fn new(
55 task_name: impl Into<String>,
56 queue: impl Into<String>,
57 payload: serde_json::Value,
58 ) -> Self {
59 let now = Utc::now();
60 Self {
61 id: TaskId::new(),
62 task_name: task_name.into(),
63 queue: queue.into(),
64 payload,
65 state: TaskState::Pending,
66 retries: 0,
67 max_retries: 3,
68 created_at: now,
69 updated_at: now,
70 eta: None,
71 headers: HashMap::new(),
72 parent_id: None,
73 correlation_id: None,
74 group_id: None,
75 group_total: None,
76 chord_callback: None,
77 }
78 }
79
80 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
82 self.max_retries = max_retries;
83 self
84 }
85
86 pub fn with_eta(mut self, eta: DateTime<Utc>) -> Self {
88 self.eta = Some(eta);
89 self
90 }
91
92 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
94 self.headers.insert(key.into(), value.into());
95 self
96 }
97
98 pub fn with_parent_id(mut self, parent_id: TaskId) -> Self {
100 self.parent_id = Some(parent_id);
101 self
102 }
103
104 pub fn with_correlation_id(mut self, correlation_id: impl Into<String>) -> Self {
106 self.correlation_id = Some(correlation_id.into());
107 self
108 }
109
110 pub fn with_group(mut self, group_id: impl Into<String>, group_total: u32) -> Self {
112 self.group_id = Some(group_id.into());
113 self.group_total = Some(group_total);
114 self
115 }
116
117 pub fn with_chord_callback(mut self, callback: TaskMessage) -> Self {
119 self.chord_callback = Some(Box::new(callback));
120 self
121 }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field written by `to_string` must come back unchanged from `from_str`.
    #[test]
    fn task_message_serde_roundtrip() {
        let msg = TaskMessage::new(
            "send_email",
            "default",
            serde_json::json!({"to": "a@b.com"}),
        )
        .with_max_retries(5)
        .with_header("trace_id", "abc123");

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: TaskMessage = serde_json::from_str(&json).unwrap();

        assert_eq!(msg.id, deserialized.id);
        assert_eq!(msg.task_name, deserialized.task_name);
        assert_eq!(msg.queue, deserialized.queue);
        assert_eq!(msg.payload, deserialized.payload);
        assert_eq!(msg.state, deserialized.state);
        assert_eq!(msg.retries, deserialized.retries);
        assert_eq!(msg.max_retries, deserialized.max_retries);
        assert_eq!(msg.created_at, deserialized.created_at);
        assert_eq!(msg.updated_at, deserialized.updated_at);
        assert_eq!(msg.eta, deserialized.eta);
        // Inspect the *deserialized* headers; checking `msg` would not exercise serde.
        assert_eq!(
            deserialized.headers.get("trace_id").map(String::as_str),
            Some("abc123")
        );
        assert_eq!(msg.headers, deserialized.headers);
    }

    #[test]
    fn task_message_defaults() {
        let msg = TaskMessage::new("test", "default", serde_json::Value::Null);
        assert_eq!(msg.state, TaskState::Pending);
        assert_eq!(msg.retries, 0);
        assert_eq!(msg.max_retries, 3);
        assert_eq!(msg.created_at, msg.updated_at);
        assert!(msg.eta.is_none());
        assert!(msg.headers.is_empty());
        assert!(msg.parent_id.is_none());
        assert!(msg.correlation_id.is_none());
        assert!(msg.group_id.is_none());
        assert!(msg.group_total.is_none());
        assert!(msg.chord_callback.is_none());
    }

    #[test]
    fn backward_compat_deserialization() {
        // A message in the pre-workflow format: none of the optional workflow keys exist.
        let old_json = serde_json::json!({
            "id": "01234567-89ab-cdef-0123-456789abcdef",
            "task_name": "send_email",
            "queue": "default",
            "payload": {"to": "a@b.com"},
            "state": "pending",
            "retries": 0,
            "max_retries": 3,
            "created_at": "2025-01-01T00:00:00Z",
            "updated_at": "2025-01-01T00:00:00Z",
            "eta": null,
            "headers": {}
        });
        let msg: TaskMessage = serde_json::from_value(old_json).unwrap();
        assert_eq!(msg.task_name, "send_email");
        assert_eq!(msg.queue, "default");
        assert_eq!(msg.retries, 0);
        assert_eq!(msg.max_retries, 3);
        assert!(msg.eta.is_none());
        assert!(msg.headers.is_empty());
        assert!(msg.parent_id.is_none());
        assert!(msg.correlation_id.is_none());
        assert!(msg.group_id.is_none());
        assert!(msg.group_total.is_none());
        assert!(msg.chord_callback.is_none());
    }

    #[test]
    fn workflow_metadata_roundtrip() {
        let callback = TaskMessage::new("callback", "default", serde_json::json!({}));
        let msg = TaskMessage::new("task", "default", serde_json::json!({}))
            .with_parent_id(TaskId::new())
            .with_correlation_id("corr-123")
            .with_group("group-1", 5)
            .with_chord_callback(callback);

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: TaskMessage = serde_json::from_str(&json).unwrap();

        assert_eq!(msg.parent_id, deserialized.parent_id);
        assert_eq!(msg.correlation_id, deserialized.correlation_id);
        assert_eq!(msg.group_id, deserialized.group_id);
        assert_eq!(msg.group_total, deserialized.group_total);

        let restored_callback = deserialized
            .chord_callback
            .as_deref()
            .expect("chord callback should survive the serde roundtrip");
        assert_eq!(restored_callback.task_name, "callback");
        assert_eq!(
            msg.chord_callback.as_deref().map(|cb| &cb.id),
            Some(&restored_callback.id)
        );
    }
}