use crate::{
storage::{Keys, RedisClient, dependencies},
Error, Result, Task,
task::TaskStatus,
progress::{ProgressContext, ProgressConfig},
};
use crate::processor::{Mux, HandlerContext};
use crate::server::config::ServerState;
use crate::task::progress_ext::set_progress_context;
use chrono::Utc;
use fred::prelude::{RedisKey, RedisValue};
use rmp_serde;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
/// A queue worker: polls the configured queues round-robin and runs each
/// dequeued task through the handler mux until shutdown is signalled.
pub struct Worker {
    /// Unique worker id; used to build the registration/heartbeat keys in Redis.
    pub id: String,
    // Shared server state: config, Redis client and middleware chain.
    state: Arc<ServerState>,
    // Cooperative shutdown flag, also observed by the heartbeat task.
    shutdown: Arc<AtomicBool>,
    // Handler multiplexer; locked for the duration of each task's handler.
    mux: Arc<Mutex<Mux>>,
    // Monotonic counter driving round-robin queue selection in `next_queue`.
    queue_index: Arc<AtomicUsize>,
    // Queue names pre-wrapped in Arc so polling hands out cheap clones.
    queues: Vec<Arc<String>>,
}
impl Worker {
pub fn new(
id: String,
state: Arc<ServerState>,
shutdown: Arc<AtomicBool>,
mux: Arc<Mutex<Mux>>,
) -> Self {
let queues: Vec<Arc<String>> = state.config.queues
.iter()
.map(|s| Arc::new(s.clone()))
.collect();
Self {
id,
state,
shutdown,
mux,
queue_index: Arc::new(AtomicUsize::new(0)),
queues,
}
}
pub async fn run(self) -> Result<()> {
tracing::info!("Worker {} starting", self.id);
self.register().await?;
let heartbeat = self.start_heartbeat();
let result = self.task_loop().await;
if let Err(e) = self.unregister().await {
tracing::error!("Failed to unregister worker: {}", e);
}
heartbeat.abort();
tracing::info!("Worker {} stopped", self.id);
result
}
async fn register(&self) -> Result<()> {
let metadata = WorkerMetadata {
id: self.id.clone(),
server_name: self.state.config.server_name.clone(),
queues: self.state.config.queues.clone(),
started_at: Utc::now().timestamp(),
last_heartbeat: Utc::now().timestamp(),
processed_total: 0,
status: "idle".to_string(),
};
let data = rmp_serde::to_vec(&metadata)
.map_err(|e| Error::Serialization(e.to_string()))?;
let worker_key: RedisKey = Keys::meta_worker(&self.id).into();
self.state.redis.set(worker_key, RedisValue::Bytes(data.into())).await?;
let workers_key: RedisKey = Keys::meta_workers().into();
self.state.redis.sadd(workers_key, self.id.as_str().into()).await?;
let queues_key: RedisKey = Keys::meta_queues().into();
for queue in &self.state.config.queues {
self.state.redis.sadd(queues_key.clone(), queue.as_str().into()).await?;
}
self.update_heartbeat().await?;
tracing::debug!("Worker {} registered", self.id);
Ok(())
}
async fn unregister(&self) -> Result<()> {
let workers_key: RedisKey = Keys::meta_workers().into();
self.state.redis.srem(workers_key, self.id.as_str().into()).await?;
let worker_key: RedisKey = Keys::meta_worker(&self.id).into();
self.state.redis.del(vec![worker_key]).await?;
let heartbeat_key: RedisKey = Keys::meta_heartbeat(&self.id).into();
self.state.redis.del(vec![heartbeat_key]).await?;
tracing::debug!("Worker {} unregistered", self.id);
Ok(())
}
fn start_heartbeat(&self) -> JoinHandle<()> {
let id = self.id.clone();
let redis = self.state.redis.clone();
let interval = Duration::from_secs(self.state.config.heartbeat_interval);
let worker_timeout = self.state.config.worker_timeout;
let ttl_multiplier = self.state.config.heartbeat_ttl_multiplier;
let shutdown = self.shutdown.clone();
tokio::spawn(async move {
let mut ticker = tokio::time::interval(interval);
while !shutdown.load(Ordering::Relaxed) {
ticker.tick().await;
if let Err(e) = Self::update_heartbeat_for(&id, &redis, worker_timeout, ttl_multiplier).await {
tracing::error!("Heartbeat update failed: {}", e);
}
}
})
}
async fn update_heartbeat(&self) -> Result<()> {
Self::update_heartbeat_for(
&self.id,
&self.state.redis,
self.state.config.worker_timeout,
self.state.config.heartbeat_ttl_multiplier,
).await
}
async fn update_heartbeat_for(worker_id: &str, redis: &RedisClient, worker_timeout: u64, ttl_multiplier: f64) -> Result<()> {
let heartbeat_key: RedisKey = Keys::meta_heartbeat(worker_id).into();
let now = Utc::now().timestamp();
redis.set(heartbeat_key.clone(), now.to_string().into()).await?;
let ttl = (worker_timeout as f64 * ttl_multiplier) as u64;
redis.expire(heartbeat_key, ttl).await?;
tracing::trace!("Heartbeat updated for worker {}, TTL: {}s (multiplier: {})", worker_id, ttl, ttl_multiplier);
Ok(())
}
async fn task_loop(&self) -> Result<()> {
while !self.shutdown.load(Ordering::Relaxed) {
let queue = self.next_queue();
match self.dequeue_task_any(&queue).await {
Ok(Some(task)) => {
let result = self.process_task(task).await;
match result {
Ok(_) => {
tracing::debug!("Task processed successfully");
}
Err(e) => {
tracing::error!("Task processing failed: {}", e);
}
}
}
Ok(None) => {
tokio::time::sleep(Duration::from_millis(self.state.config.poll_interval)).await;
}
Err(Error::QueuePaused(_)) => {
tokio::time::sleep(Duration::from_secs(5)).await;
}
Err(Error::Shutdown) => {
break;
}
Err(e) => {
let error_msg = e.to_string();
if error_msg.contains("Timeout") || error_msg.contains("timed out") {
tracing::debug!("Queue empty, waiting for tasks...");
} else {
tracing::warn!("Dequeue error: {}", e);
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
}
Ok(())
}
fn next_queue(&self) -> Arc<String> {
let index = self.queue_index.fetch_add(1, Ordering::Relaxed) % self.queues.len();
Arc::clone(&self.queues[index])
}
async fn dequeue_task(&self, queue: &str) -> Result<Option<Task>> {
let pause_key: RedisKey = Keys::pause(queue).into();
if self.state.redis.exists(pause_key).await? {
return Err(Error::QueuePaused(queue.to_string()));
}
let queue_key: RedisKey = Keys::queue(queue).into();
let timeout = self.state.config.dequeue_timeout;
match self.state.redis.blpop(queue_key, timeout).await? {
Some((_, task_id)) => {
let active_key: RedisKey = Keys::active(queue).into();
self.state.redis.lpush(active_key, task_id.as_str().into()).await?;
self.load_task(&task_id).await.map(Some)
}
None => Ok(None),
}
}
async fn dequeue_task_priority(&self, queue: &str) -> Result<Option<Task>> {
let pqueue_key: RedisKey = Keys::priority_queue(queue).into();
let active_key: RedisKey = Keys::active(queue).into();
let pause_key: RedisKey = Keys::pause(queue).into();
let dummy_key: RedisKey = "rediq:task:*".to_string().into();
let task_ttl = crate::config::get_task_ttl() as usize;
match self.state.redis.pdequeue_lua(
pqueue_key,
active_key,
pause_key,
dummy_key,
task_ttl,
).await? {
task_id if !task_id.is_empty() => {
self.load_task(&task_id).await.map(Some)
}
_ => Ok(None), }
}
async fn dequeue_task_any(&self, queue: &str) -> Result<Option<Task>> {
tracing::debug!("Attempting to dequeue from priority queue {}", queue);
match self.dequeue_task_priority(queue).await {
Ok(Some(task)) => {
tracing::debug!("Successfully dequeued from priority queue {}", queue);
return Ok(Some(task));
}
Ok(None) => {
tracing::debug!("Priority queue {} is empty, trying regular queue", queue);
}
Err(Error::QueuePaused(_)) => {
return Err(Error::QueuePaused(queue.to_string()));
}
Err(e) => {
tracing::warn!("Priority queue dequeue failed: {}, trying regular queue", e);
}
}
self.dequeue_task(queue).await
}
async fn load_task(&self, task_id: &str) -> Result<Task> {
let task_key: RedisKey = Keys::task(task_id).into();
let data = self.state.redis.hget(task_key.clone(), "data".into()).await?
.ok_or_else(|| Error::TaskNotFound(task_id.to_string()))?;
let bytes = data.as_bytes()
.ok_or_else(|| Error::Serialization("Task data is not bytes".into()))?;
let mut task: Task = rmp_serde::from_slice(bytes)
.map_err(|e| Error::Serialization(e.to_string()))?;
task.status = TaskStatus::Active;
task.processed_at = Some(Utc::now().timestamp());
let current_timestamp = Utc::now().timestamp();
self.state.redis.hset(
task_key.clone(),
vec![
("status".into(), current_timestamp.to_string().into()),
],
).await?;
Ok(task)
}
async fn process_task(&self, mut task: Task) -> Result<()> {
tracing::debug!("Processing task: {}", task.description());
let progress_config = ProgressConfig::default();
let progress_ctx = ProgressContext::new(
task.id.clone(),
self.state.redis.clone(),
progress_config,
);
set_progress_context(Some(progress_ctx.clone()));
let _ = progress_ctx.report(0).await;
let cancelled = Arc::new(AtomicBool::new(false));
let handler_ctx = HandlerContext::new(
task.id.clone(),
self.state.redis.clone(),
Some(progress_ctx.clone()),
cancelled.clone(),
);
if !self.state.middleware.is_empty() {
self.state.middleware.before(&task).await?;
}
let mux = self.mux.lock().await;
let handler = mux.process_with_context(&task, &handler_ctx);
let result = tokio::time::timeout(task.options.timeout, handler).await;
let process_result = match result {
Ok(r) => r,
Err(_) => {
task.last_error = Some("Task timed out".to_string());
Err(Error::Timeout(format!("Task {} timed out after {:?}", task.id, task.options.timeout)))
}
};
drop(mux);
if !self.state.middleware.is_empty() {
let _ = self.state.middleware.after(&task, &process_result).await;
}
set_progress_context(None);
match &process_result {
Ok(_) => {
let _ = progress_ctx.report(100).await;
self.ack_task(&task, TaskStatus::Processed, None).await?;
}
Err(e) => {
task.last_error = Some(e.to_string());
task.retry_cnt += 1;
if task.can_retry() {
self.ack_task(&task, TaskStatus::Retry, Some(e)).await?;
self.schedule_retry(&task).await?;
} else {
self.ack_task(&task, TaskStatus::Dead, Some(e)).await?;
if let Err(dep_err) = self.fail_dependent_tasks(&task.id, &task.queue).await {
tracing::error!("Failed to propagate failure to dependent tasks: {}", dep_err);
}
}
}
}
process_result
}
async fn ack_task(&self, task: &Task, status: TaskStatus, error: Option<&Error>) -> Result<()> {
let active_key: RedisKey = Keys::active(&task.queue).into();
let task_key: RedisKey = Keys::task(&task.id).into();
self.state.redis.lrem(active_key, task.id.as_str().into(), 1).await?;
let mut task_data = task.clone();
task_data.status = status;
task_data.last_error = error.map(|e| e.to_string());
let data = rmp_serde::to_vec(&task_data)
.map_err(|e| Error::Serialization(e.to_string()))?;
self.state.redis.hset(
task_key.clone(),
vec![
("data".into(), RedisValue::Bytes(data.into())),
("queue".into(), task.queue.as_str().into()),
],
).await?;
let task_ttl = crate::config::get_task_ttl();
self.state.redis.expire(task_key, task_ttl).await?;
if status == TaskStatus::Processed {
self.increment_processed().await?;
let stats_key: RedisKey = Keys::stats(&task.queue).into();
let field_key: RedisKey = "processed".into();
let _ = self.state.redis.hincrby(stats_key, field_key, 1).await;
}
if status == TaskStatus::Processed {
if let Err(e) = self.check_dependent_tasks(&task.id).await {
tracing::error!("Failed to check dependent tasks: {}", e);
}
}
Ok(())
}
async fn increment_processed(&self) -> Result<()> {
let worker_key: RedisKey = Keys::meta_worker(&self.id).into();
let data = self.state.redis.get(worker_key.clone()).await?
.ok_or_else(|| Error::Validation("Worker metadata not found".into()))?;
let bytes = data.as_bytes()
.ok_or_else(|| Error::Serialization("Worker data is not bytes".into()))?;
let mut metadata: WorkerMetadata = rmp_serde::from_slice(bytes)
.map_err(|e| Error::Serialization(e.to_string()))?;
metadata.processed_total += 1;
metadata.last_heartbeat = Utc::now().timestamp();
let new_data = rmp_serde::to_vec(&metadata)
.map_err(|e| Error::Serialization(e.to_string()))?;
self.state.redis.set(worker_key, RedisValue::Bytes(new_data.into())).await?;
Ok(())
}
async fn check_dependent_tasks(&self, completed_task_id: &str) -> Result<()> {
dependencies::check_dependents(&self.state.redis, completed_task_id).await?;
Ok(())
}
async fn fail_dependent_tasks(&self, failed_task_id: &str, queue: &str) -> Result<()> {
let count = dependencies::fail_dependents(&self.state.redis, failed_task_id, queue).await?;
if count > 0 {
tracing::warn!(
"Task {} failed, {} dependent tasks moved to dead queue",
failed_task_id,
count
);
}
Ok(())
}
async fn schedule_retry(&self, task: &Task) -> Result<()> {
let delay = task.retry_delay()
.ok_or_else(|| Error::Validation("No retry delay available".into()))?;
let execute_at = Utc::now().timestamp() + delay.as_secs() as i64;
let retry_key: RedisKey = Keys::retry(&task.queue).into();
self.state.redis.zadd(
retry_key,
task.id.as_str().into(),
execute_at,
).await?;
tracing::debug!("Task {} scheduled for retry in {:?}", task.id, delay);
Ok(())
}
}
/// Worker registration record, msgpack-serialized under the worker's
/// metadata key in Redis (written by `register`, updated by
/// `increment_processed`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkerMetadata {
    /// Worker id, matching the Redis key it is stored under.
    pub id: String,
    /// Name of the server instance that owns this worker.
    pub server_name: String,
    /// Queues this worker polls.
    pub queues: Vec<String>,
    /// Unix timestamp when the worker registered.
    pub started_at: i64,
    /// Unix timestamp of the last metadata refresh.
    pub last_heartbeat: i64,
    /// Number of tasks this worker has acked as processed.
    pub processed_total: u64,
    /// Free-form status string (registered as "idle").
    pub status: String,
}
#[cfg(test)]
mod tests {
    // TODO(review): empty stub. Exercising `next_queue`'s round-robin order
    // requires constructing a `Worker`, which needs a `ServerState` fixture;
    // add one or refactor `next_queue` to be testable in isolation.
    #[test]
    fn test_next_queue_round_robin() {
    }
}