use crate::base::Broker;
use crate::error::Error;
use crate::inspector::InspectorTrait;
use crate::server::Handler;
use crate::task::Task;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::{mpsc, Semaphore};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
/// Construction parameters consumed by `Processor::new`.
pub struct ProcessorParams {
/// Broker handle; the processor dequeues tasks from it and reports task outcomes to it.
pub broker: Arc<dyn Broker>,
/// Inspector handle attached to every task before it is passed to the handler.
pub inspector: Arc<dyn InspectorTrait>,
/// Queue name -> priority. `Processor::new` raises priorities below 1 to 1.
pub queues: HashMap<String, i32>,
/// Number of permits in the worker semaphore, i.e. the cap on concurrently running tasks.
pub concurrency: usize,
/// When true, a fixed descending-priority queue order is computed once at construction;
/// otherwise the order is re-drawn at random (weighted by priority) before every dequeue.
pub strict_priority: bool,
/// How long the dequeue loop sleeps after the broker reports that no task is available.
pub task_check_interval: Duration,
/// Time budget that `Processor::shutdown` works with before it signals an abort.
pub shutdown_timeout: Duration,
/// Optional sink for "worker started" / "worker finished" events; `None` disables them.
pub worker_event_sender: Option<crate::components::heartbeat::WorkerEventSender>,
}
/// Shared registry mapping a task id to the `CancellationToken` of its running worker.
///
/// Cloning is cheap: every clone points at the same mutex-protected map.
#[derive(Clone)]
pub struct CancellationMap {
// Task id -> token. Guarded by a std (blocking) mutex; the methods visible in this
// file never hold the guard across an `.await`.
tasks: Arc<Mutex<HashMap<String, CancellationToken>>>,
}
impl CancellationMap {
    /// Creates an empty registry.
    pub fn new() -> Self {
        let tasks = Arc::new(Mutex::new(HashMap::new()));
        Self { tasks }
    }

    /// Registers `token` under `task_id`, replacing any token previously stored for that id.
    ///
    /// Does nothing if the internal mutex is poisoned.
    pub fn add(&self, task_id: String, token: CancellationToken) {
        let Ok(mut registry) = self.tasks.lock() else {
            return;
        };
        registry.insert(task_id, token);
    }

    /// Forgets the token stored for `task_id`, if any.
    ///
    /// Does nothing if the internal mutex is poisoned.
    pub fn remove(&self, task_id: &str) {
        let Ok(mut registry) = self.tasks.lock() else {
            return;
        };
        registry.remove(task_id);
    }

    /// Cancels the token registered for `task_id`.
    ///
    /// Returns `true` when a token was found and cancelled, `false` when the id is
    /// unknown or the internal mutex is poisoned. The entry itself is left in the map.
    pub fn cancel(&self, task_id: &str) -> bool {
        tracing::info!("canceling task {}", task_id);
        let Ok(registry) = self.tasks.lock() else {
            return false;
        };
        match registry.get(task_id) {
            Some(token) => {
                token.cancel();
                true
            }
            None => false,
        }
    }

    /// Number of registered tokens; reports `0` if the internal mutex is poisoned.
    pub fn len(&self) -> usize {
        self.tasks
            .lock()
            .map(|registry| registry.len())
            .unwrap_or(0)
    }

    /// Returns `true` when `len()` is zero.
    pub fn is_empty(&self) -> bool {
        let registered = self.len();
        registered == 0
    }
}
impl Default for CancellationMap {
fn default() -> Self {
Self::new()
}
}
/// Pulls tasks from the broker and runs them on spawned tokio tasks, bounded by a semaphore.
pub struct Processor {
// Source of tasks and sink for their outcome (done / complete / requeue / archive).
broker: Arc<dyn Broker>,
// Attached to every task before the handler sees it.
inspector: Arc<dyn InspectorTrait>,
// Queue name -> priority, already clamped to >= 1 by `normalize_queues`.
queue_config: HashMap<String, i32>,
// `Some(order)` only when strict priority was requested; fixed descending-priority order.
ordered_queues: Option<Vec<String>>,
// Sleep between polls when the broker has no task.
task_check_interval: Duration,
// Time budget used by `shutdown`.
shutdown_timeout: Duration,
// One permit per concurrently running worker.
sema: Arc<Semaphore>,
// Cleared by `stop`; the dequeue loop exits when it observes `false`.
running: Arc<AtomicBool>,
// Sender half of the quit channel; consumed by `stop`.
quit_tx: Option<mpsc::Sender<()>>,
// Receiver half of the quit channel; moved into the dequeue loop by `start`.
quit_rx: Option<mpsc::Receiver<()>>,
// Sender half of the abort channel. NOTE(review): `new` drops the matching receiver,
// so nothing in this file ever observes an abort signal.
abort_tx: Option<mpsc::Sender<()>>,
// Join handle of the dequeue loop, set by `start` and taken by `shutdown`.
handle: Option<JoinHandle<()>>,
// Count of workers currently executing a task.
active_workers: Arc<AtomicUsize>,
// Task id -> cancellation token for every in-flight task.
cancellations: CancellationMap,
// Optional sink for worker started/finished events.
worker_event_sender: Option<crate::components::heartbeat::WorkerEventSender>,
}
impl Processor {
pub fn new(params: ProcessorParams) -> Self {
let queues = normalize_queues(params.queues);
let ordered_queues = if params.strict_priority {
Some(sort_by_priority(&queues))
} else {
None
};
let (quit_tx, quit_rx) = mpsc::channel(1);
let (abort_tx, _abort_rx) = mpsc::channel(1);
Self {
broker: params.broker,
inspector: params.inspector,
queue_config: queues,
ordered_queues,
task_check_interval: params.task_check_interval,
shutdown_timeout: params.shutdown_timeout,
sema: Arc::new(Semaphore::new(params.concurrency)),
running: Arc::new(AtomicBool::new(false)),
quit_tx: Some(quit_tx),
quit_rx: Some(quit_rx),
abort_tx: Some(abort_tx),
handle: None,
active_workers: Arc::new(AtomicUsize::new(0)),
cancellations: CancellationMap::new(),
worker_event_sender: params.worker_event_sender,
}
}
/// Returns a handle to the registry of in-flight task cancellation tokens.
///
/// The returned value shares state with the processor, so `cancel` called on it
/// reaches the running workers.
pub fn cancellations(&self) -> CancellationMap {
    let shared = &self.cancellations;
    shared.clone()
}
pub fn start<H>(&mut self, handler: Arc<H>)
where
H: Handler + 'static,
{
self.running.store(true, Ordering::SeqCst);
let broker = Arc::clone(&self.broker);
let inspector = Arc::clone(&self.inspector);
let running = Arc::clone(&self.running);
let sema = Arc::clone(&self.sema);
let queue_config = self.queue_config.clone();
let ordered_queues = self.ordered_queues.clone();
let task_check_interval = self.task_check_interval;
let active_workers = Arc::clone(&self.active_workers);
let cancelations = self.cancellations.clone();
let worker_event_sender = self.worker_event_sender.clone();
if let Some(mut quit_rx) = self.quit_rx.take() {
let handle = tokio::spawn(async move {
loop {
if quit_rx.try_recv().is_ok() {
tracing::debug!("Processor received quit signal");
break;
}
if !running.load(Ordering::SeqCst) {
break;
}
let permit = match sema.clone().try_acquire_owned() {
Ok(permit) => permit,
Err(_) => {
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
}
};
let queues = get_queues(&queue_config, ordered_queues.as_ref());
match broker.dequeue(&queues).await {
Ok(Some(task_msg)) => {
active_workers.fetch_add(1, Ordering::Relaxed);
let handler = Arc::clone(&handler);
let broker = Arc::clone(&broker);
let inspector = Arc::clone(&inspector);
let active_workers = Arc::clone(&active_workers);
let cancelations = cancelations.clone();
let worker_event_sender = worker_event_sender.clone();
tokio::spawn(async move {
let _permit = permit;
let task_id = task_msg.id.clone();
if let Some(ref sender) = worker_event_sender {
let worker_info = crate::components::heartbeat::WorkerInfoEntry {
msg: task_msg.clone(),
started: std::time::SystemTime::now(),
deadline: if task_msg.deadline > 0 {
std::time::UNIX_EPOCH
+ std::time::Duration::from_secs(task_msg.deadline as u64)
} else {
std::time::SystemTime::now() + std::time::Duration::from_secs(3600)
},
};
if let Err(e) = sender.send_started(worker_info).await {
tracing::warn!("Failed to send worker starting event: {}", e);
}
}
let mut task = match Task::new_with_headers(
&task_msg.r#type,
&task_msg.payload,
task_msg.headers.clone(),
) {
Ok(task) => task,
Err(e) => {
tracing::error!("Failed to create task: {}", e);
if let Some(ref sender) = worker_event_sender {
if let Err(e) = sender.send_finished(task_id.clone()).await {
tracing::warn!("Failed to send worker finished event: {}", e);
}
}
active_workers.fetch_sub(1, Ordering::Relaxed);
return;
}
};
let result_writer = Arc::new(crate::task::ResultWriter::new(
task_msg.id.clone(),
task_msg.queue.clone(),
broker.clone(),
));
task = task.with_result_writer(result_writer);
task = task.with_inspector(inspector);
let cancel_token = CancellationToken::new();
cancelations.add(task_id.clone(), cancel_token.clone());
let timeout_duration = calculate_task_timeout(&task_msg);
let result = if let Some(timeout) = timeout_duration {
tokio::select! {
result = handler.process_task(task.clone()) => result,
_ = tokio::time::sleep(timeout) => {
tracing::warn!("Task {} timed out after {:?}", task_id, timeout);
Err(Error::other("Task execution timeout"))
}
_ = cancel_token.cancelled() => {
tracing::info!("Task {} was cancelled", task_id);
Ok(())
}
}
} else {
tokio::select! {
result = handler.process_task(task.clone()) => result,
_ = cancel_token.cancelled() => {
tracing::info!("Task {} was cancelled", task_id);
Err(Error::other("Task cancelled"))
}
}
};
cancelations.remove(&task_id);
match result {
Ok(()) => {
if task_msg.retention == 0 {
if let Err(e) = broker.done(&task_msg).await {
tracing::error!("Failed to mark task as done: {}", e);
}
} else if let Err(e) = broker.mark_as_complete(&task_msg).await {
tracing::error!("Failed to mark task as complete: {}", e);
}
}
Err(e) => {
if should_retry_task(&task_msg, &e) {
let retry_delay =
calculate_retry_delay(task_msg.retried, task.options.retry_policy.as_ref());
let retry_at = chrono::Utc::now()
+ chrono::Duration::seconds(retry_delay.as_secs() as i64);
if let Err(e) = broker.requeue(&task_msg, retry_at, &e.to_string()).await {
tracing::error!("Failed to requeue task: {}", e);
}
} else {
if let Err(e) = broker.archive(&task_msg, &e.to_string()).await {
tracing::error!("Failed to archive task: {}", e);
}
}
}
}
if let Some(ref sender) = worker_event_sender {
if let Err(e) = sender.send_finished(task_id.clone()).await {
tracing::warn!("Failed to send worker finished event: {}", e);
}
}
active_workers.fetch_sub(1, Ordering::Relaxed);
});
}
Ok(None) => {
drop(permit); tokio::time::sleep(task_check_interval).await;
}
Err(e) => {
tracing::error!("Dequeue error: {}", e);
drop(permit); tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
tracing::debug!("Processor loop exited");
});
self.handle = Some(handle);
} else {
self.handle = None;
};
}
/// Asks the dequeue loop to exit without waiting for it.
///
/// Clears the `running` flag and, the first time it is called, sends one message on the
/// quit channel (best effort — a full or closed channel is ignored). Later calls only
/// clear the flag because the sender has already been consumed.
pub fn stop(&mut self) {
    self.running.store(false, Ordering::SeqCst);
    let Some(quit) = self.quit_tx.take() else {
        return;
    };
    let _ = quit.try_send(());
}
pub async fn shutdown(&mut self) {
self.stop();
let abort_tx = self.abort_tx.clone();
let shutdown_timeout = self.shutdown_timeout;
tokio::spawn(async move {
tokio::time::sleep(shutdown_timeout).await;
if let Some(tx) = abort_tx {
let _ = tx.send(()).await;
}
});
if let Some(handle) = self.handle.take() {
let _ = handle.await;
}
tracing::info!("Waiting for all workers to finish...");
let sema = Arc::clone(&self.sema);
let concurrency = sema.available_permits();
for _ in 0..concurrency {
let _ = sema.acquire().await;
}
tracing::info!("All workers have finished");
}
}
/// Returns the queue map with every priority raised to at least 1.
///
/// Queue names are preserved; priorities that are already >= 1 are left unchanged.
fn normalize_queues(queues: HashMap<String, i32>) -> HashMap<String, i32> {
    let mut normalized = HashMap::with_capacity(queues.len());
    for (name, priority) in queues {
        let clamped = if priority < 1 { 1 } else { priority };
        normalized.insert(name, clamped);
    }
    normalized
}
/// Returns the queue names ordered from highest to lowest priority.
///
/// Queues with equal priority are ordered by name, so the result does not depend on the
/// `HashMap`'s (per-process randomized) iteration order.
fn sort_by_priority(queues: &HashMap<String, i32>) -> Vec<String> {
    let mut entries: Vec<(&String, &i32)> = queues.iter().collect();
    // Keys are unique, so (priority desc, name asc) is a total order and an unstable
    // sort is fully deterministic.
    entries.sort_unstable_by(|a, b| b.1.cmp(a.1).then_with(|| a.0.cmp(b.0)));
    entries.into_iter().map(|(name, _)| name.clone()).collect()
}
/// Builds the queue polling order for one dequeue attempt.
///
/// * A single configured queue is returned as-is.
/// * With `ordered_queues` (strict priority) that fixed order is returned.
/// * Otherwise each queue name is entered `priority` times into a pool, the pool is
///   shuffled, and the names are emitted in order of first appearance — higher-priority
///   queues tend to come first, and every queue with a positive priority appears once.
fn get_queues(
    queue_config: &HashMap<String, i32>,
    ordered_queues: Option<&Vec<String>>,
) -> Vec<String> {
    use rand::seq::SliceRandom;

    let queue_count = queue_config.len();
    if queue_count == 1 {
        return queue_config.keys().cloned().collect();
    }
    if let Some(ordered) = ordered_queues {
        return ordered.clone();
    }

    // One pool entry per unit of priority (none for priorities <= 0).
    let mut pool: Vec<String> = queue_config
        .iter()
        .flat_map(|(name, &priority)| (0..priority).map(move |_| name.clone()))
        .collect();
    pool.shuffle(&mut rand::rng());

    // Keep the first occurrence of each name; stop once every queue has been placed.
    let mut seen = std::collections::HashSet::new();
    pool.into_iter()
        .filter(|name| seen.insert(name.clone()))
        .take(queue_count)
        .collect()
}
/// Chooses the execution time limit for a task; always returns `Some`.
///
/// Order of precedence:
/// 1. a positive `timeout` (seconds) on the message;
/// 2. otherwise the seconds remaining until a positive `deadline` (Unix timestamp),
///    provided that deadline is still in the future;
/// 3. otherwise `DEFAULT_TIMEOUT`.
///
/// NOTE(review): a deadline that has already passed falls through to the default
/// timeout rather than expiring the task immediately — confirm this is intended.
fn calculate_task_timeout(task_msg: &crate::proto::TaskMessage) -> Option<Duration> {
    use crate::base::constants::DEFAULT_TIMEOUT;

    let limit = if task_msg.timeout > 0 {
        Duration::from_secs(task_msg.timeout as u64)
    } else {
        let remaining = if task_msg.deadline > 0 {
            task_msg.deadline - chrono::Utc::now().timestamp()
        } else {
            0
        };
        if remaining > 0 {
            Duration::from_secs(remaining as u64)
        } else {
            DEFAULT_TIMEOUT
        }
    };
    Some(limit)
}
/// Computes how long to wait before the next retry of a task.
///
/// With a `retry_policy` the decision is delegated to `policy.calculate_delay(retried)`.
/// Without one, the delay in seconds is `retried^4 + 15` plus a random jitter in
/// `0..30 * (retried + 1)`.
///
/// A negative `retried` is treated as 0, and all arithmetic saturates, so malformed or
/// very large retry counts cannot cause an overflow or a modulo-by-zero panic.
fn calculate_retry_delay(
    retried: i32,
    retry_policy: Option<&crate::backend::option::RetryPolicy>,
) -> Duration {
    match retry_policy {
        Some(policy) => policy.calculate_delay(retried),
        None => {
            // `retried as u64` would wrap a negative count to a huge value.
            let attempts = u64::try_from(retried).unwrap_or(0);
            // The float -> integer cast saturates at u64::MAX.
            let base_delay = ((attempts as f64).powf(4.0) as u64).saturating_add(15);
            // Always >= 30, so the modulo below can never divide by zero.
            let jitter_bound = 30u64.saturating_mul(attempts.saturating_add(1));
            let jitter = rand::random::<u64>() % jitter_bound;
            Duration::from_secs(base_delay.saturating_add(jitter))
        }
    }
}
fn should_retry_task(task_msg: &crate::proto::TaskMessage, err: &Error) -> bool {
task_msg.retried < task_msg.retry && !matches!(err, Error::SkipRetry(_))
}
// Unit tests for the queue helpers, the retry decision, and `CancellationMap`.
#[cfg(test)]
mod tests {
use super::*;
use crate::error::SkipRetryError;
// Priorities of zero or below are raised to 1; positive ones are kept.
#[test]
fn test_normalize_queues() {
let mut queues = HashMap::new();
queues.insert("default".to_string(), 3);
queues.insert("low".to_string(), 0);
queues.insert("high".to_string(), -5);
let normalized = normalize_queues(queues);
assert_eq!(normalized.get("default"), Some(&3));
assert_eq!(normalized.get("low"), Some(&1));
assert_eq!(normalized.get("high"), Some(&1));
}
// Queues come back ordered from highest to lowest priority.
#[test]
fn test_sort_by_priority() {
let mut queues = HashMap::new();
queues.insert("default".to_string(), 3);
queues.insert("low".to_string(), 1);
queues.insert("high".to_string(), 6);
let sorted = sort_by_priority(&queues);
assert_eq!(sorted[0], "high");
assert_eq!(sorted[1], "default");
assert_eq!(sorted[2], "low");
}
// A single configured queue is returned directly.
#[test]
fn test_get_queues_single() {
let mut queues = HashMap::new();
queues.insert("default".to_string(), 3);
let result = get_queues(&queues, None);
assert_eq!(result.len(), 1);
assert_eq!(result[0], "default");
}
// With a strict-priority order supplied, that order is returned unchanged.
#[test]
fn test_get_queues_strict_priority() {
let mut queues = HashMap::new();
queues.insert("default".to_string(), 3);
queues.insert("low".to_string(), 1);
queues.insert("high".to_string(), 6);
let ordered = vec!["high".to_string(), "default".to_string(), "low".to_string()];
let result = get_queues(&queues, Some(&ordered));
assert_eq!(result, ordered);
}
// An ordinary error with retries remaining is retried.
#[test]
fn test_should_retry_task_normal_error() {
let task_msg = crate::proto::TaskMessage {
retried: 0,
retry: 5,
..Default::default()
};
assert!(should_retry_task(&task_msg, &Error::other("failed")));
}
// A skip-retry error is never retried, even with retries remaining.
#[test]
fn test_should_retry_task_skip_retry_error() {
let task_msg = crate::proto::TaskMessage {
retried: 0,
retry: 5,
..Default::default()
};
let skip_retry = Error::from(SkipRetryError::new(std::io::Error::other(
"invalid payload",
)));
assert!(!should_retry_task(&task_msg, &skip_retry));
}
// Once `retried` equals the `retry` limit, no further retry happens.
#[test]
fn test_should_retry_task_reached_limit() {
let task_msg = crate::proto::TaskMessage {
retried: 3,
retry: 3,
..Default::default()
};
assert!(!should_retry_task(&task_msg, &Error::other("failed")));
}
// A new map starts out empty.
#[test]
fn test_cancelations_new() {
let cancelations = CancellationMap::new();
assert!(cancelations.is_empty());
assert_eq!(cancelations.len(), 0);
}
// `add` followed by `remove` leaves the map empty again.
#[test]
fn test_cancelations_add_remove() {
let cancelations = CancellationMap::new();
let token = CancellationToken::new();
cancelations.add("task1".to_string(), token.clone());
assert_eq!(cancelations.len(), 1);
assert!(!cancelations.is_empty());
cancelations.remove("task1");
assert_eq!(cancelations.len(), 0);
assert!(cancelations.is_empty());
}
// Cancelling a registered id returns true and trips the shared token.
#[tokio::test]
async fn test_cancelations_cancel() {
let cancellations = CancellationMap::new();
let token = CancellationToken::new();
let task_id = "task1".to_string();
cancellations.add(task_id.clone(), token.clone());
assert!(!token.is_cancelled());
let cancelled = cancellations.cancel(&task_id);
assert!(cancelled);
assert!(token.is_cancelled());
}
// Cancelling an unknown id returns false.
#[tokio::test]
async fn test_cancelations_cancel_nonexistent() {
let cancellations = CancellationMap::new();
let cancelled = cancellations.cancel("nonexistent");
assert!(!cancelled);
}
// Cancelling one task leaves the tokens of other tasks untouched, and `cancel`
// does not remove entries from the map.
#[tokio::test]
async fn test_cancelations_multiple_tasks() {
let cancellations = CancellationMap::new();
let token1 = CancellationToken::new();
let token2 = CancellationToken::new();
let token3 = CancellationToken::new();
cancellations.add("task1".to_string(), token1.clone());
cancellations.add("task2".to_string(), token2.clone());
cancellations.add("task3".to_string(), token3.clone());
assert_eq!(cancellations.len(), 3);
cancellations.cancel("task2");
assert!(!token1.is_cancelled());
assert!(token2.is_cancelled());
assert!(!token3.is_cancelled());
cancellations.remove("task1");
assert_eq!(cancellations.len(), 2);
cancellations.cancel("task3");
assert!(token3.is_cancelled());
}
}