mod task_codec;
use crate::backend::option::{RateLimit, RetryPolicy, TaskOptions};
use crate::base::{keys::TaskState, Broker};
use crate::error::{Error, Result};
use crate::inspector::InspectorTrait;
use crate::proto;
use chrono::{DateTime, Utc};
use http::{header::CONTENT_TYPE, HeaderName, HeaderValue};
#[cfg(feature = "json")]
pub use task_codec::JsonPayloadCodec;
#[cfg(feature = "msgpack")]
pub use task_codec::MsgPackPayloadCodec;
pub use task_codec::PayloadCodec;
pub type HeaderMap = http::HeaderMap;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use std::time::Duration;
use tracing::warn;
use uuid::Uuid;
/// Conversion into an [`HeaderMap`], accepted anywhere task headers are set.
pub trait IntoHeaders {
    /// Consume `self` and produce the header map.
    fn into_headers(self) -> HeaderMap;
}
impl IntoHeaders for HeaderMap {
    /// Identity conversion: the map is returned unchanged.
    fn into_headers(self) -> HeaderMap {
        self
    }
}
impl IntoHeaders for &HeaderMap {
    /// Borrowed maps are cloned into an owned `HeaderMap`.
    fn into_headers(self) -> HeaderMap {
        self.clone()
    }
}
impl IntoHeaders for HashMap<String, String> {
    /// Parse string pairs into header names/values; invalid entries are
    /// skipped with a warning (see `headers_from_string_map`).
    fn into_headers(self) -> HeaderMap {
        headers_from_string_map(self)
    }
}
impl IntoHeaders for &HashMap<String, String> {
    /// Clones the map, then parses it like the owned `HashMap` impl.
    fn into_headers(self) -> HeaderMap {
        headers_from_string_map(self.clone())
    }
}
fn headers_from_string_map(headers: HashMap<String, String>) -> HeaderMap {
let mut map = HeaderMap::new();
for (name, value) in headers {
let header_name = match HeaderName::from_bytes(name.as_bytes()) {
Ok(header_name) => header_name,
Err(e) => {
warn!(header = %name, error = %e, "Ignoring invalid header name");
continue;
}
};
let header_value = match HeaderValue::from_str(&value) {
Ok(header_value) => header_value,
Err(e) => {
warn!(header = %header_name, error = %e, "Ignoring invalid header value");
continue;
}
};
map.append(header_name, header_value);
}
map
}
/// Lossy conversion of headers back into a plain `String` map.
pub trait ToHashMap {
    /// Produce a `HashMap` copy; entries whose values cannot be represented
    /// as a `&str` are dropped.
    fn to_hashmap(&self) -> HashMap<String, String>;
}
impl ToHashMap for HeaderMap {
    /// Delegates to `headers_to_string_map`.
    fn to_hashmap(&self) -> HashMap<String, String> {
        headers_to_string_map(self)
    }
}
/// Convert a `HeaderMap` into owned string pairs.
///
/// Values that fail `to_str()` are silently skipped; for multi-valued headers
/// the last convertible value wins (later inserts overwrite earlier ones).
fn headers_to_string_map(headers: &HeaderMap) -> HashMap<String, String> {
    headers
        .iter()
        .filter_map(|(name, value)| {
            value
                .to_str()
                .ok()
                .map(|v| (name.as_str().to_string(), v.to_string()))
        })
        .collect()
}
/// Writes a task's result bytes back to the broker so they can be retrieved later.
#[derive(Clone)]
pub struct ResultWriter {
    // ID of the task this writer belongs to.
    task_id: String,
    // Queue the task lives in; used to address the result in the broker.
    queue: String,
    // Shared broker handle used to persist the result.
    broker: Arc<dyn Broker>,
}
// Manual Debug: `dyn Broker` is not `Debug`, so it is shown as a placeholder.
impl std::fmt::Debug for ResultWriter {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ResultWriter")
            .field("task_id", &self.task_id)
            .field("queue", &self.queue)
            .field("broker", &"<Broker>")
            .finish()
    }
}
impl ResultWriter {
    /// Create a writer bound to a specific task and queue.
    pub fn new(task_id: String, queue: String, broker: Arc<dyn Broker>) -> Self {
        Self {
            task_id,
            queue,
            broker,
        }
    }

    /// Persist `data` as the task's result via the broker.
    ///
    /// Returns the number of bytes written on success.
    ///
    /// # Errors
    /// Propagates any error from the broker's `write_result`.
    pub async fn write(&self, data: &[u8]) -> Result<usize> {
        self
            .broker
            .write_result(&self.queue, &self.task_id, data)
            .await?;
        Ok(data.len())
    }

    /// ID of the task this writer belongs to.
    pub fn task_id(&self) -> &str {
        &self.task_id
    }
}
/// A unit of work: a type tag plus an opaque payload, with delivery options.
#[derive(Clone)]
pub struct Task {
    /// Task type identifier used to route the task to a handler.
    pub task_type: String,
    /// Raw payload bytes (possibly codec-encoded).
    pub payload: Vec<u8>,
    // User-supplied headers; protected headers (codec Content-Type) are
    // re-applied on top whenever headers are replaced.
    headers: HeaderMap,
    /// Queueing/retry/scheduling options.
    pub options: TaskOptions,
    // Set internally before handler execution; `None` on freshly created tasks.
    result_writer: Option<Arc<ResultWriter>>,
    // Set internally before handler execution; `None` on freshly created tasks.
    inspector: Option<Arc<dyn InspectorTrait>>,
    // Content type recorded by `new_with_codec`; enforced as a protected header.
    content_type: Option<HeaderValue>,
}
// Manual Debug: skips `result_writer` and `inspector`, whose trait objects
// are not `Debug`.
impl Debug for Task {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Task")
            .field("task_type", &self.task_type)
            .field("payload", &self.payload)
            .field("headers", &self.headers)
            .field("options", &self.options)
            .field("content_type", &self.content_type)
            .finish()
    }
}
impl Task {
    /// Create a task with the given type and raw payload bytes.
    ///
    /// # Errors
    /// Returns `Error::InvalidTaskType` if `task_type` is empty or whitespace-only.
    pub fn new<T: AsRef<str>>(task_type: T, payload: &[u8]) -> Result<Self> {
        let task_type = task_type.as_ref();
        if task_type.trim().is_empty() {
            return Err(Error::InvalidTaskType {
                task_type: task_type.to_string(),
            });
        }
        Ok(Self {
            task_type: task_type.to_string(),
            payload: payload.to_vec(),
            headers: HeaderMap::new(),
            options: TaskOptions::default(),
            result_writer: None,
            inspector: None,
            content_type: None,
        })
    }

    /// Create a task and attach the given headers in one step.
    pub fn new_with_headers<T: AsRef<str>, H: IntoHeaders>(
        task_type: T,
        payload: &[u8],
        headers: H,
    ) -> Result<Self> {
        Ok(Self::new(task_type, payload)?.with_headers(headers))
    }

    /// Create a task by encoding `payload` with `codec`.
    ///
    /// The codec's content type is recorded and enforced as a protected
    /// `Content-Type` header that user headers cannot override.
    ///
    /// # Errors
    /// Fails if encoding fails or the codec's content type is not a valid
    /// header value.
    pub fn new_with_codec<T: AsRef<str>, P: Serialize, C: PayloadCodec>(
        task_type: T,
        payload: &P,
        codec: &C,
    ) -> Result<Self> {
        let encoded = codec.encode(payload)?;
        let mut task = Self::new(task_type, &encoded)?;
        task.content_type = Some(
            HeaderValue::from_str(codec.content_type())
                .map_err(|e| Error::Serialization(e.to_string()))?,
        );
        task.apply_protected_headers();
        Ok(task)
    }

    /// Create a task with a JSON-encoded payload.
    #[cfg(feature = "json")]
    pub fn new_with_json<T: AsRef<str>, P: Serialize>(task_type: T, payload: &P) -> Result<Self> {
        Self::new_with_codec(task_type, payload, &JsonPayloadCodec)
    }

    /// Create a task with a MessagePack-encoded payload.
    #[cfg(feature = "msgpack")]
    pub fn new_with_msgpack<T: AsRef<str>, P: Serialize>(task_type: T, payload: &P) -> Result<Self> {
        Self::new_with_codec(task_type, payload, &MsgPackPayloadCodec)
    }

    /// Replace all headers; protected headers (codec `Content-Type`) are
    /// re-applied afterwards and therefore cannot be overridden.
    pub fn with_headers<H: IntoHeaders>(mut self, headers: H) -> Self {
        self.headers = headers.into_headers();
        self.apply_protected_headers();
        self
    }

    /// Replace the full option set.
    pub fn with_options(mut self, options: TaskOptions) -> Self {
        self.options = options;
        self
    }

    /// Set the destination queue.
    pub fn with_queue<T: AsRef<str>>(mut self, queue: T) -> Self {
        self.options.queue = queue.as_ref().to_string();
        self
    }

    /// Set the maximum retry count; negative values are clamped to 0.
    pub fn with_max_retry(mut self, max_retry: i32) -> Self {
        self.options.max_retry = max_retry.max(0);
        self
    }

    /// Set a per-execution timeout.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.options.timeout = Some(timeout);
        self
    }

    /// Set an absolute deadline for the task.
    pub fn with_deadline(mut self, deadline: DateTime<Utc>) -> Self {
        self.options.deadline = Some(deadline);
        self
    }

    /// Set the TTL during which an identical task counts as a duplicate.
    pub fn with_unique_ttl(mut self, ttl: Duration) -> Self {
        self.options.unique_ttl = Some(ttl);
        self
    }

    /// Assign the task to an aggregation group.
    pub fn with_group<T: AsRef<str>>(mut self, group: T) -> Self {
        self.options.group = Some(group.as_ref().to_string());
        self
    }

    /// Set the retry back-off policy.
    pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
        self.options.retry_policy = Some(policy);
        self
    }

    /// Set a rate limit for this task.
    pub fn with_rate_limit(mut self, rate_limit: RateLimit) -> Self {
        self.options.rate_limit = Some(rate_limit);
        self
    }

    /// Pin an explicit task ID instead of letting one be generated.
    pub fn with_task_id<T: AsRef<str>>(mut self, id: T) -> Self {
        self.options.task_id = Some(id.as_ref().to_string());
        self
    }

    /// Schedule the task to run at an absolute time.
    pub fn with_process_at(mut self, when: DateTime<Utc>) -> Self {
        self.options.process_at = Some(when);
        self
    }

    /// Schedule the task to run after a relative delay.
    pub fn with_process_in(mut self, delay: Duration) -> Self {
        self.options.process_in = Some(delay);
        self
    }

    /// Keep the completed task around for this long before deletion.
    pub fn with_retention(mut self, retention: Duration) -> Self {
        self.options.retention = Some(retention);
        self
    }

    /// Set the grace period before an aggregation group is flushed.
    pub fn with_group_grace_period(mut self, grace: Duration) -> Self {
        self.options.group_grace_period = Some(grace);
        self
    }

    /// Task type identifier.
    pub fn get_type(&self) -> &str {
        &self.task_type
    }

    /// Destination queue name.
    pub fn get_queue(&self) -> &str {
        &self.options.queue
    }

    /// Raw payload bytes.
    pub fn get_payload(&self) -> &[u8] {
        &self.payload
    }

    /// Headers attached to the task.
    pub fn get_headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// Headers attached to the task (idiomatic alias of `get_headers`).
    pub fn headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// Headers as a `String` map with protected headers enforced; used when
    /// serializing the task for the broker.
    pub(crate) fn resolved_headers(&self) -> HashMap<String, String> {
        self.resolved_header_map().to_hashmap()
    }

    /// Clone of the headers with protected headers (re-)applied on top.
    pub(crate) fn resolved_header_map(&self) -> HeaderMap {
        let mut headers = self.headers.clone();
        self.apply_protected_headers_to(&mut headers);
        headers
    }

    /// Result writer, if one has been attached; `None` on new tasks.
    pub fn result_writer(&self) -> Option<&Arc<ResultWriter>> {
        self.result_writer.as_ref()
    }

    /// Inspector handle, if one has been attached; `None` on new tasks.
    pub fn inspector(&self) -> Option<&Arc<dyn InspectorTrait>> {
        self.inspector.as_ref()
    }

    /// Attach a result writer (crate-internal).
    pub(crate) fn with_result_writer(mut self, writer: Arc<ResultWriter>) -> Self {
        self.result_writer = Some(writer);
        self
    }

    /// Attach an inspector (crate-internal).
    pub(crate) fn with_inspector(mut self, inspector: Arc<dyn InspectorTrait>) -> Self {
        self.inspector = Some(inspector);
        self
    }

    /// Decode the payload as JSON.
    #[cfg(feature = "json")]
    pub fn get_payload_with_json<T: DeserializeOwned>(&self) -> Result<T> {
        JsonPayloadCodec.decode(&self.payload)
    }

    /// Decode the payload as MessagePack.
    #[cfg(feature = "msgpack")]
    pub fn get_payload_with_msgpack<T: DeserializeOwned>(&self) -> Result<T> {
        MsgPackPayloadCodec.decode(&self.payload)
    }

    /// Decode the payload with an arbitrary codec.
    pub fn get_payload_with_codec<T: DeserializeOwned, C: PayloadCodec>(
        &self,
        codec: &C,
    ) -> Result<T> {
        codec.decode(&self.payload)
    }

    // Re-apply protected headers (currently just the codec `Content-Type`)
    // onto this task's own header map, overwriting any user-set value.
    fn apply_protected_headers(&mut self) {
        if let Some(content_type) = &self.content_type {
            self.headers.insert(CONTENT_TYPE, content_type.clone());
        }
    }

    // Same as `apply_protected_headers`, but onto an arbitrary header map.
    fn apply_protected_headers_to(&self, headers: &mut HeaderMap) {
        if let Some(content_type) = &self.content_type {
            headers.insert(CONTENT_TYPE, content_type.clone());
        }
    }
}
// Equality considers only the task's data (type, payload, headers, options,
// content type); runtime attachments (`result_writer`, `inspector`) are ignored.
impl PartialEq for Task {
    fn eq(&self, other: &Self) -> bool {
        self.task_type == other.task_type
            && self.payload == other.payload
            && self.headers == other.headers
            && self.options == other.options
            && self.content_type == other.content_type
    }
}
/// Snapshot of a task's data and runtime state, as returned by inspection APIs.
#[derive(Debug, Clone, PartialEq)]
pub struct TaskInfo {
    /// Task ID.
    pub id: String,
    /// Queue the task belongs to.
    pub queue: String,
    /// Task type identifier.
    pub task_type: String,
    /// Raw payload bytes.
    pub payload: Vec<u8>,
    /// Headers attached to the task.
    pub headers: HeaderMap,
    /// Current lifecycle state.
    pub state: TaskState,
    /// Maximum number of retries allowed.
    pub max_retry: i32,
    /// Number of retries performed so far.
    pub retried: i32,
    /// Message from the last failure, if any.
    pub last_err: Option<String>,
    /// Time of the last failure, if any.
    pub last_failed_at: Option<DateTime<Utc>>,
    /// Per-execution timeout, if set.
    pub timeout: Option<Duration>,
    /// Absolute deadline, if set.
    pub deadline: Option<DateTime<Utc>>,
    /// Aggregation group, if any.
    pub group: Option<String>,
    /// Next scheduled processing time, if known.
    pub next_process_at: Option<DateTime<Utc>>,
    /// Whether the task is orphaned (always `false` when built via `from_proto`).
    pub is_orphaned: bool,
    /// Retention period after completion, if set.
    pub retention: Option<Duration>,
    /// Completion time, if completed.
    pub completed_at: Option<DateTime<Utc>>,
    /// Stored result bytes, if any.
    pub result: Option<Vec<u8>>,
}
impl TaskInfo {
    // The wire format uses 0 / "" as "absent"; these helpers centralize the
    // sentinel-to-Option conversions that `from_proto` repeats for each field.

    /// `0` → `None`; otherwise a UTC timestamp (epoch on out-of-range input).
    fn opt_timestamp(secs: i64) -> Option<DateTime<Utc>> {
        if secs == 0 {
            None
        } else {
            Some(DateTime::from_timestamp(secs, 0).unwrap_or_default())
        }
    }

    /// `0` → `None`; otherwise a duration in whole seconds.
    fn opt_duration(secs: i64) -> Option<Duration> {
        if secs == 0 {
            None
        } else {
            // NOTE(review): a negative value wraps through the `as u64` cast,
            // matching the prior behavior — confirm upstream never sends one.
            Some(Duration::from_secs(secs as u64))
        }
    }

    /// `""` → `None`; otherwise an owned copy of the string.
    fn opt_string(s: &str) -> Option<String> {
        if s.is_empty() {
            None
        } else {
            Some(s.to_string())
        }
    }

    /// Build a `TaskInfo` from the wire-format message plus runtime state.
    pub fn from_proto(
        msg: &proto::TaskMessage,
        state: TaskState,
        next_process_at: Option<DateTime<Utc>>,
        result: Option<Vec<u8>>,
    ) -> Self {
        Self {
            id: msg.id.clone(),
            queue: msg.queue.clone(),
            task_type: msg.r#type.clone(),
            payload: msg.payload.clone(),
            headers: msg.headers.clone().into_headers(),
            state,
            max_retry: msg.retry,
            retried: msg.retried,
            last_err: Self::opt_string(&msg.error_msg),
            last_failed_at: Self::opt_timestamp(msg.last_failed_at),
            timeout: Self::opt_duration(msg.timeout),
            deadline: Self::opt_timestamp(msg.deadline),
            group: Self::opt_string(&msg.group_key),
            next_process_at,
            is_orphaned: false,
            retention: Self::opt_duration(msg.retention),
            completed_at: Self::opt_timestamp(msg.completed_at),
            result,
        }
    }

    /// Serialize back into the wire format; `None` maps to the zero/empty sentinels.
    pub fn to_proto(&self) -> proto::TaskMessage {
        proto::TaskMessage {
            id: self.id.clone(),
            r#type: self.task_type.clone(),
            payload: self.payload.clone(),
            queue: self.queue.clone(),
            retry: self.max_retry,
            retried: self.retried,
            error_msg: self.last_err.clone().unwrap_or_default(),
            last_failed_at: self.last_failed_at.map(|dt| dt.timestamp()).unwrap_or(0),
            timeout: self.timeout.map(|d| d.as_secs() as i64).unwrap_or(0),
            deadline: self.deadline.map(|dt| dt.timestamp()).unwrap_or(0),
            // `unique_key` is not tracked on `TaskInfo`, so it is always empty here.
            unique_key: String::new(),
            group_key: self.group.clone().unwrap_or_default(),
            retention: self.retention.map(|d| d.as_secs() as i64).unwrap_or(0),
            completed_at: self.completed_at.map(|dt| dt.timestamp()).unwrap_or(0),
            headers: self.headers.to_hashmap(),
        }
    }
}
/// Aggregate per-queue counters plus daily history.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct QueueStats {
    /// Queue name (normalized from its Redis key form by `QueueStats::new`).
    pub name: String,
    /// Number of tasks currently being processed.
    pub active: i64,
    /// Number of tasks waiting to be processed.
    pub pending: i64,
    /// Number of tasks scheduled for the future.
    pub scheduled: i64,
    /// Number of tasks awaiting retry.
    pub retry: i64,
    /// Number of archived (dead) tasks.
    pub archived: i64,
    /// Number of completed tasks retained.
    pub completed: i64,
    /// Number of tasks waiting in aggregation groups.
    pub aggregating: i64,
    /// Per-day processed/failed history.
    pub daily_stats: Vec<DailyStats>,
}
impl QueueStats {
    /// Build a `QueueStats`, normalizing `name` from its Redis key form.
    ///
    /// Queue keys may embed the queue name in a Redis cluster hash tag, e.g.
    /// `asynq:{my_queue}:pending`; the tag content (`my_queue`) becomes the
    /// queue name. Names without a complete `{...}` tag get any `}:`-suffix
    /// and stray braces stripped.
    #[allow(clippy::too_many_arguments)]
    pub fn new<N: Into<String>>(
        name: N,
        active: i64,
        pending: i64,
        scheduled: i64,
        retry: i64,
        archived: i64,
        completed: i64,
        aggregating: i64,
        daily_stats: Vec<DailyStats>,
    ) -> Self {
        QueueStats {
            name: Self::normalize_name(name.into()),
            active,
            pending,
            scheduled,
            retry,
            archived,
            completed,
            aggregating,
            daily_stats,
        }
    }

    /// Extract the queue name from a (possibly hash-tagged) Redis key.
    ///
    /// Fix: the previous code truncated at `"}:"` *before* looking for the
    /// `{...}` hash tag, so keys like `asynq:{q}:pending` came out as
    /// `asynq:{q` instead of `q`. Extracting a complete tag first resolves
    /// that; plain names and bare `{q}` inputs behave as before.
    fn normalize_name(raw: String) -> String {
        // A complete "{...}" hash tag wins: its content is the queue name.
        if let Some(start) = raw.find('{') {
            if let Some(rel_end) = raw[start + 1..].find('}') {
                return raw[start + 1..start + 1 + rel_end].to_string();
            }
        }
        // No complete tag: drop any key suffix after "}:", then stray braces.
        let mut n = raw;
        if let Some(idx) = n.find("}:") {
            n.truncate(idx);
        }
        n.trim_matches(|c| c == '{' || c == '}').to_string()
    }
}
/// Processed/failed counts for one queue on one day.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DailyStats {
    /// Queue name.
    pub queue: String,
    /// Tasks processed that day.
    pub processed: i64,
    /// Tasks failed that day.
    pub failed: i64,
    /// The day these counts belong to.
    pub date: DateTime<Utc>,
}
/// Point-in-time snapshot of a queue's size, state breakdown, and throughput.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct QueueInfo {
    /// Queue name.
    pub queue: String,
    /// Memory used by the queue, in bytes.
    pub memory_usage: i64,
    /// Approximate wait time of the oldest pending task.
    pub latency: Duration,
    /// Total number of tasks in the queue.
    pub size: i32,
    /// Number of aggregation groups.
    pub groups: i32,
    /// Tasks waiting to be processed.
    pub pending: i32,
    /// Tasks currently being processed.
    pub active: i32,
    /// Tasks scheduled for the future.
    pub scheduled: i32,
    /// Tasks awaiting retry.
    pub retry: i32,
    /// Archived (dead) tasks.
    pub archived: i32,
    /// Completed tasks retained.
    pub completed: i32,
    /// Tasks waiting in aggregation groups.
    pub aggregating: i32,
    /// Tasks processed in the current period.
    pub processed: i32,
    /// Tasks failed in the current period.
    pub failed: i32,
    /// Tasks processed since queue creation.
    pub processed_total: i32,
    /// Tasks failed since queue creation.
    pub failed_total: i32,
    /// Whether the queue is paused.
    pub paused: bool,
    /// When this snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
/// Build the deduplication key for a (queue, task type, payload) triple.
///
/// Delegates to `crate::base::keys::unique_key`; identical inputs always
/// produce the same key.
pub fn generate_unique_key(queue: &str, task_type: &str, payload: &[u8]) -> String {
    crate::base::keys::unique_key(queue, task_type, payload)
}
/// Generate a fresh random task ID (UUID v4, hyphenated string form).
pub fn generate_task_id() -> String {
    Uuid::new_v4().to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::option::RetryPolicy;
    use crate::base::constants::DEFAULT_QUEUE_NAME;

    // Basic construction populates type/payload and the default queue.
    #[test]
    fn test_task_creation() {
        let task = Task::new("test_task", b"test payload").unwrap();
        assert_eq!(task.task_type, "test_task");
        assert_eq!(task.payload, b"test payload");
        assert_eq!(task.options.queue, DEFAULT_QUEUE_NAME);
    }

    // Builder methods set queue and max_retry on the options.
    #[test]
    fn test_task_with_options() {
        let task = Task::new("test_task", b"test payload")
            .unwrap()
            .with_queue("custom_queue")
            .with_max_retry(10);
        assert_eq!(task.options.queue, "custom_queue");
        assert_eq!(task.options.max_retry, 10);
    }

    // `with_headers` replaces (not merges) the header map.
    #[test]
    fn test_with_headers_replaces_headers() {
        let mut headers = HeaderMap::new();
        headers.insert("x-trace-id", HeaderValue::from_static("trace-1"));
        let task = Task::new("test_task", b"test payload")
            .unwrap()
            .with_headers(headers.clone());
        assert_eq!(task.headers(), &headers);
    }

    // JSON payload round-trips through encode/decode.
    #[cfg(feature = "json")]
    #[test]
    fn test_task_json_payload() {
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct TestPayload {
            message: String,
            count: i32,
        }
        let payload = TestPayload {
            message: "test".to_string(),
            count: 42,
        };
        let task = Task::new_with_json("test_task", &payload).unwrap();
        let decoded: TestPayload = task.get_payload_with_json().unwrap();
        assert_eq!(decoded, payload);
    }

    // Minimal custom codec used to exercise the codec extension point below.
    #[derive(Debug, Clone, Copy)]
    struct PrefixJsonCodec;
    impl PayloadCodec for PrefixJsonCodec {
        fn content_type(&self) -> &'static str {
            "application/x-prefix-json"
        }
        fn encode<T: Serialize>(&self, value: &T) -> Result<Vec<u8>> {
            let mut encoded = b"PREFIX:".to_vec();
            encoded.extend(serde_json::to_vec(value).map_err(|e| Error::Serialization(e.to_string()))?);
            Ok(encoded)
        }
        fn decode<T: DeserializeOwned>(&self, bytes: &[u8]) -> Result<T> {
            let body = bytes
                .strip_prefix(b"PREFIX:")
                .ok_or_else(|| Error::Deserialization("missing PREFIX header".to_string()))?;
            serde_json::from_slice(body).map_err(|e| Error::Deserialization(e.to_string()))
        }
    }

    // A custom codec round-trips and stamps its content type as a header.
    #[test]
    fn test_task_custom_codec_round_trip() {
        #[derive(Serialize, Deserialize, Debug, PartialEq)]
        struct TestPayload {
            name: String,
            count: i32,
        }
        let payload = TestPayload {
            name: "custom".to_string(),
            count: 7,
        };
        let task = Task::new_with_codec("custom_task", &payload, &PrefixJsonCodec).unwrap();
        assert_eq!(
            task.headers().get(CONTENT_TYPE),
            Some(&HeaderValue::from_static("application/x-prefix-json"))
        );
        let decoded: TestPayload = task.get_payload_with_codec(&PrefixJsonCodec).unwrap();
        assert_eq!(decoded, payload);
    }

    // Protected Content-Type survives a `with_headers` that tries to override it.
    #[cfg(feature = "json")]
    #[test]
    fn test_json_content_type_cannot_be_overridden_by_with_headers() {
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain"));
        headers.insert("x-custom", HeaderValue::from_static("v1"));
        let task = Task::new_with_json("test_task", &serde_json::json!({"ok": true}))
            .unwrap()
            .with_headers(headers);
        assert_eq!(
            task.headers().get(CONTENT_TYPE),
            Some(&HeaderValue::from_static("application/json"))
        );
        assert_eq!(
            task.headers().get("x-custom"),
            Some(&HeaderValue::from_static("v1"))
        );
    }

    // `resolved_header_map` also enforces the protected Content-Type.
    #[cfg(feature = "json")]
    #[test]
    fn test_resolved_headers_enforces_json_content_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain"));
        let task = Task::new_with_json("test_task", &serde_json::json!({"ok": true}))
            .unwrap()
            .with_headers(headers);
        let resolved = task.resolved_header_map();
        assert_eq!(
            resolved.get(CONTENT_TYPE),
            Some(&HeaderValue::from_static("application/json"))
        );
    }

    // TaskState parses from / renders to its string form.
    #[test]
    fn test_task_state_conversion() {
        assert_eq!("active".parse::<TaskState>(), Ok(TaskState::Active));
        assert_eq!("pending".parse::<TaskState>(), Ok(TaskState::Pending));
        assert!("invalid".parse::<TaskState>().is_err());
        assert_eq!(TaskState::Active.as_str(), "active");
        assert_eq!(TaskState::Pending.as_str(), "pending");
    }

    // Unique keys are deterministic per input and differ across queues.
    #[test]
    fn test_unique_key_generation() {
        let key1 = generate_unique_key("queue1", "task_type", b"payload");
        let key2 = generate_unique_key("queue1", "task_type", b"payload");
        let key3 = generate_unique_key("queue2", "task_type", b"payload");
        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    // Generated task IDs are unique, valid UUIDs.
    #[test]
    fn test_task_id_generation() {
        let id1 = generate_task_id();
        let id2 = generate_task_id();
        assert_ne!(id1, id2);
        assert!(Uuid::parse_str(&id1).is_ok());
        assert!(Uuid::parse_str(&id2).is_ok());
    }

    // Fixed policy returns the same delay regardless of attempt number.
    #[test]
    fn test_retry_policy_fixed() {
        let policy = RetryPolicy::Fixed(Duration::from_secs(30));
        assert_eq!(policy.calculate_delay(0), Duration::from_secs(30));
        assert_eq!(policy.calculate_delay(5), Duration::from_secs(30));
    }

    // Exponential policy doubles per attempt and is capped at max_delay.
    #[test]
    fn test_retry_policy_exponential() {
        let policy = RetryPolicy::Exponential {
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(300),
            multiplier: 2.0,
            jitter: false,
        };
        assert_eq!(policy.calculate_delay(0), Duration::from_secs(1));
        assert_eq!(policy.calculate_delay(1), Duration::from_secs(2));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(4));
        let delay = policy.calculate_delay(10);
        assert_eq!(delay, Duration::from_secs(300));
    }

    // Linear policy adds `step` per attempt and is capped at max_delay.
    #[test]
    fn test_retry_policy_linear() {
        let policy = RetryPolicy::Linear {
            base_delay: Duration::from_secs(10),
            max_delay: Duration::from_secs(100),
            step: Duration::from_secs(5),
        };
        assert_eq!(policy.calculate_delay(0), Duration::from_secs(10));
        assert_eq!(policy.calculate_delay(1), Duration::from_secs(15));
        assert_eq!(policy.calculate_delay(2), Duration::from_secs(20));
        let delay = policy.calculate_delay(100);
        assert_eq!(delay, Duration::from_secs(100));
    }

    // Each rate-limit scope produces its expected Redis key.
    #[test]
    fn test_rate_limit_key_generation() {
        let rate_limit = RateLimit::per_task_type(Duration::from_secs(60), 10);
        let key = rate_limit.generate_key("email:send", "high_priority");
        assert_eq!(key, "asynq:ratelimit:task:email:send");
        let rate_limit = RateLimit::per_queue(Duration::from_secs(60), 10);
        let key = rate_limit.generate_key("email:send", "high_priority");
        assert_eq!(key, "asynq:ratelimit:queue:high_priority");
        let rate_limit = RateLimit::custom("custom_key", Duration::from_secs(60), 10);
        let key = rate_limit.generate_key("email:send", "high_priority");
        assert_eq!(key, "asynq:ratelimit:custom:custom_key");
    }

    // `with_retry_policy` stores the policy on the options.
    #[test]
    fn test_task_with_retry_policy() {
        let retry_policy = RetryPolicy::default_exponential();
        let task = Task::new("test:task", b"payload")
            .unwrap()
            .with_retry_policy(retry_policy.clone());
        assert_eq!(task.options.retry_policy, Some(retry_policy));
    }

    // `with_rate_limit` stores the limit on the options.
    #[test]
    fn test_task_with_rate_limit() {
        let rate_limit = RateLimit::per_task_type(Duration::from_secs(60), 100);
        let task = Task::new("test:task", b"payload")
            .unwrap()
            .with_rate_limit(rate_limit.clone());
        assert_eq!(task.options.rate_limit, Some(rate_limit));
    }

    // Fresh tasks carry no result writer.
    #[test]
    fn test_task_result_writer_none_on_new_task() {
        let task = Task::new("test:task", b"payload").unwrap();
        assert!(task.result_writer().is_none());
    }

    // Integration test: writes a result through a real Redis broker.
    // Skips (returns early) when Redis is unavailable.
    #[tokio::test]
    async fn test_result_writer_functionality() {
        use crate::backend::{RedisBroker, RedisConnectionType};
        let redis_url =
            std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://localhost:6379".to_string());
        let redis_config = match RedisConnectionType::single(redis_url) {
            Ok(config) => config,
            Err(_) => {
                println!("Skipping test: Redis not available");
                return;
            }
        };
        let broker = match RedisBroker::new(redis_config).await {
            Ok(broker) => Arc::new(broker),
            Err(_) => {
                println!("Skipping test: Could not connect to Redis");
                return;
            }
        };
        let task_id = generate_task_id();
        let queue = "test_queue";
        let result_writer = ResultWriter::new(task_id.clone(), queue.to_string(), broker.clone());
        assert_eq!(result_writer.task_id(), task_id);
        let result_data = b"test result data";
        let bytes_written = result_writer.write(result_data).await.unwrap();
        assert_eq!(bytes_written, result_data.len());
    }

    // Duplicate sanity check: a plain task has no result writer attached.
    #[test]
    fn test_task_with_result_writer() {
        let task = Task::new("test:task", b"payload").unwrap();
        assert!(task.result_writer().is_none());
    }
}