use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::state::TaskState;
use crate::task_id::TaskId;
/// A serializable description of a single task to be placed on a queue.
///
/// The first eleven fields are always serialized. The remaining optional
/// fields are omitted from the serialized form when `None` and fall back to
/// `None` when missing on deserialization, so payloads that lack them still
/// deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskMessage {
/// Identifier of this task message; `TaskMessage::new` assigns `TaskId::new()`.
pub id: TaskId,
/// Name of the task this message refers to.
pub task_name: String,
/// Name of the queue this message is addressed to.
pub queue: String,
/// Arbitrary JSON payload carried by the message.
pub payload: serde_json::Value,
/// Current state of the task; `TaskMessage::new` starts it as `TaskState::Pending`.
pub state: TaskState,
/// Retry counter; `TaskMessage::new` starts it at 0.
/// NOTE(review): presumably the number of retries already performed — confirm with the worker code.
pub retries: u32,
/// Retry limit; `TaskMessage::new` sets it to 3 and `with_max_retries` overrides it.
pub max_retries: u32,
/// Creation timestamp (UTC); set to the current time by `TaskMessage::new`.
pub created_at: DateTime<Utc>,
/// Last-update timestamp (UTC); `TaskMessage::new` sets it equal to `created_at`.
pub updated_at: DateTime<Utc>,
/// Optional UTC timestamp set through `with_eta`.
/// NOTE(review): presumably the earliest time the task may run — confirm with the scheduler.
pub eta: Option<DateTime<Utc>>,
/// Free-form string key/value headers; empty by default.
pub headers: HashMap<String, String>,
/// Optional id of another task, set through `with_parent_id`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parent_id: Option<TaskId>,
/// Optional correlation identifier, set through `with_correlation_id`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub correlation_id: Option<String>,
/// Optional group identifier; `with_group` sets it together with `group_total`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub group_id: Option<String>,
/// Optional group size; `with_group` sets it together with `group_id`.
/// NOTE(review): presumably the total number of tasks in the group — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub group_total: Option<u32>,
/// Optional nested task message, boxed because the type is recursive.
/// NOTE(review): presumably dispatched once the whole group has finished — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub chord_callback: Option<Box<TaskMessage>>,
/// Optional priority; `with_priority` never stores a value greater than 9.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub priority: Option<u8>,
/// Optional deduplication key, set by `with_dedup_key` or derived by `with_content_dedup`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub dedup_key: Option<String>,
}
/// Constructor and consuming builder methods for [`TaskMessage`].
///
/// Every method takes the message by value and returns the updated message,
/// so calls are meant to be chained. All of them are `#[must_use]`: dropping
/// the returned value would silently discard the configuration.
impl TaskMessage {
    /// Creates a message for `task_name` on `queue` carrying `payload`.
    ///
    /// The id comes from `TaskId::new()`, the state is `TaskState::Pending`,
    /// `retries` is 0, `max_retries` is 3, and `created_at` / `updated_at`
    /// share a single `Utc::now()` reading. Headers are empty and every
    /// optional field is `None`.
    #[must_use]
    pub fn new(
        task_name: impl Into<String>,
        queue: impl Into<String>,
        payload: serde_json::Value,
    ) -> Self {
        // Read the clock once so both timestamps are exactly equal.
        let now = Utc::now();
        Self {
            id: TaskId::new(),
            task_name: task_name.into(),
            queue: queue.into(),
            payload,
            state: TaskState::Pending,
            retries: 0,
            max_retries: 3,
            created_at: now,
            updated_at: now,
            eta: None,
            headers: HashMap::new(),
            parent_id: None,
            correlation_id: None,
            group_id: None,
            group_total: None,
            chord_callback: None,
            priority: None,
            dedup_key: None,
        }
    }

    /// Overrides the retry limit (3 by default).
    #[must_use]
    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Sets the `eta` timestamp.
    #[must_use]
    pub fn with_eta(mut self, eta: DateTime<Utc>) -> Self {
        self.eta = Some(eta);
        self
    }

    /// Inserts a header, replacing any existing value stored under `key`.
    #[must_use]
    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key.into(), value.into());
        self
    }

    /// Sets the parent task id.
    #[must_use]
    pub fn with_parent_id(mut self, parent_id: TaskId) -> Self {
        self.parent_id = Some(parent_id);
        self
    }

    /// Sets the correlation id.
    #[must_use]
    pub fn with_correlation_id(mut self, correlation_id: impl Into<String>) -> Self {
        self.correlation_id = Some(correlation_id.into());
        self
    }

    /// Sets the group id and the group total together.
    #[must_use]
    pub fn with_group(mut self, group_id: impl Into<String>, group_total: u32) -> Self {
        self.group_id = Some(group_id.into());
        self.group_total = Some(group_total);
        self
    }

    /// Attaches `callback` as the chord callback, boxing it.
    #[must_use]
    pub fn with_chord_callback(mut self, callback: TaskMessage) -> Self {
        self.chord_callback = Some(Box::new(callback));
        self
    }

    /// Sets the priority, clamping it to at most 9.
    ///
    /// When `priority` exceeds 9 a `tracing` warning is emitted and 9 is
    /// stored instead; the call never fails.
    #[must_use]
    pub fn with_priority(mut self, priority: u8) -> Self {
        if priority > 9 {
            tracing::warn!(
                requested = priority,
                clamped = 9,
                "priority clamped to max value 9"
            );
        }
        self.priority = Some(priority.min(9));
        self
    }

    /// Sets an explicit deduplication key, replacing any previous one.
    #[must_use]
    pub fn with_dedup_key(mut self, key: impl Into<String>) -> Self {
        self.dedup_key = Some(key.into());
        self
    }

    /// Derives the deduplication key from the message content.
    ///
    /// The key is `content:` followed by the unpadded lowercase-hex 64-bit
    /// FNV-1a hash of `"<task_name>:<payload>"`, where the payload is rendered
    /// through its `Display` implementation. Only the task name and payload
    /// take part; the queue and all other fields are ignored. Any previously
    /// set key is replaced.
    #[must_use]
    pub fn with_content_dedup(mut self) -> Self {
        let input = format!("{}:{}", self.task_name, self.payload);
        self.dedup_key = Some(format!("content:{:x}", fnv1a_64(input.as_bytes())));
        self
    }
}
/// Computes the 64-bit FNV-1a hash of `data`.
///
/// Starting from the 64-bit offset basis, each input byte is XORed into the
/// running hash, which is then multiplied by the 64-bit FNV prime with
/// wrapping arithmetic. An empty slice returns the offset basis unchanged.
fn fnv1a_64(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0100_0000_01b3;

    data.iter().fold(OFFSET_BASIS, |acc, &octet| {
        // XOR first, multiply second: that order is what makes this FNV-1a.
        (acc ^ u64::from(octet)).wrapping_mul(PRIME)
    })
}
/// Unit tests for `TaskMessage`: defaults, builder behaviour and serde
/// round-trips (including payloads that lack the optional fields).
#[cfg(test)]
mod tests {
    use super::*;

    /// Required fields and headers survive a JSON round-trip.
    #[test]
    fn task_message_serde_roundtrip() {
        let msg = TaskMessage::new(
            "send_email",
            "default",
            serde_json::json!({"to": "a@b.com"}),
        )
        .with_max_retries(5)
        .with_header("trace_id", "abc123");

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: TaskMessage = serde_json::from_str(&json).unwrap();

        assert_eq!(msg.id, deserialized.id);
        assert_eq!(msg.task_name, deserialized.task_name);
        assert_eq!(msg.queue, deserialized.queue);
        assert_eq!(msg.max_retries, deserialized.max_retries);
        // Check the deserialized copy, not the original, so the header
        // round-trip is actually exercised.
        assert_eq!(
            deserialized.headers.get("trace_id").map(String::as_str),
            Some("abc123")
        );
    }

    /// `TaskMessage::new` fills in the documented defaults.
    #[test]
    fn task_message_defaults() {
        let msg = TaskMessage::new("test", "default", serde_json::Value::Null);
        assert_eq!(msg.state, TaskState::Pending);
        assert_eq!(msg.retries, 0);
        assert_eq!(msg.max_retries, 3);
        assert!(msg.eta.is_none());
        assert!(msg.headers.is_empty());
        assert!(msg.parent_id.is_none());
        assert!(msg.correlation_id.is_none());
        assert!(msg.group_id.is_none());
        assert!(msg.group_total.is_none());
        assert!(msg.chord_callback.is_none());
        assert!(msg.priority.is_none());
        assert!(msg.dedup_key.is_none());
    }

    /// JSON without any of the optional fields still deserializes, with
    /// every optional field left as `None`.
    #[test]
    fn backward_compat_deserialization() {
        let old_json = serde_json::json!({
            "id": "01234567-89ab-cdef-0123-456789abcdef",
            "task_name": "send_email",
            "queue": "default",
            "payload": {"to": "a@b.com"},
            "state": "pending",
            "retries": 0,
            "max_retries": 3,
            "created_at": "2025-01-01T00:00:00Z",
            "updated_at": "2025-01-01T00:00:00Z",
            "eta": null,
            "headers": {}
        });
        let msg: TaskMessage = serde_json::from_value(old_json).unwrap();
        assert_eq!(msg.task_name, "send_email");
        assert!(msg.parent_id.is_none());
        assert!(msg.correlation_id.is_none());
        assert!(msg.group_id.is_none());
        assert!(msg.group_total.is_none());
        assert!(msg.chord_callback.is_none());
        assert!(msg.priority.is_none());
        assert!(msg.dedup_key.is_none());
    }

    /// Priority and an explicit dedup key survive a JSON round-trip.
    #[test]
    fn priority_and_dedup_roundtrip() {
        let msg = TaskMessage::new("task", "default", serde_json::json!({"x": 1}))
            .with_priority(5)
            .with_dedup_key("my-key");

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: TaskMessage = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.priority, Some(5));
        assert_eq!(deserialized.dedup_key.as_deref(), Some("my-key"));
    }

    /// Priorities above 9 are stored as 9.
    #[test]
    fn priority_clamped_to_9() {
        let msg = TaskMessage::new("task", "default", serde_json::Value::Null).with_priority(20);
        assert_eq!(msg.priority, Some(9));
    }

    /// Content dedup keys match for identical name and payload, and differ
    /// when the task name differs.
    #[test]
    fn content_dedup_deterministic() {
        let msg1 = TaskMessage::new("task", "q", serde_json::json!({"a": 1})).with_content_dedup();
        let msg2 = TaskMessage::new("task", "q", serde_json::json!({"a": 1})).with_content_dedup();
        assert_eq!(msg1.dedup_key, msg2.dedup_key);

        let msg3 =
            TaskMessage::new("other_task", "q", serde_json::json!({"a": 1})).with_content_dedup();
        assert_ne!(msg1.dedup_key, msg3.dedup_key);
    }

    /// Parent, correlation, group and chord-callback metadata survive a
    /// JSON round-trip.
    #[test]
    fn workflow_metadata_roundtrip() {
        let callback = TaskMessage::new("callback", "default", serde_json::json!({}));
        let msg = TaskMessage::new("task", "default", serde_json::json!({}))
            .with_parent_id(TaskId::new())
            .with_correlation_id("corr-123")
            .with_group("group-1", 5)
            .with_chord_callback(callback);

        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: TaskMessage = serde_json::from_str(&json).unwrap();

        assert_eq!(msg.parent_id, deserialized.parent_id);
        assert_eq!(msg.correlation_id, deserialized.correlation_id);
        assert_eq!(msg.group_id, deserialized.group_id);
        assert_eq!(msg.group_total, deserialized.group_total);
        assert!(deserialized.chord_callback.is_some());
        assert_eq!(deserialized.chord_callback.unwrap().task_name, "callback");
    }
}