use crate::{
TaskBackend, TaskStatus,
locking::TaskLock,
registry::TaskRegistry,
result::{ResultBackend, TaskResultMetadata},
webhook::{HttpWebhookSender, WebhookConfig, WebhookEvent, WebhookSender},
};
use chrono::Utc;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Semaphore, broadcast};
/// Configuration for a [`Worker`]: identity, parallelism limit, polling
/// cadence, and webhook notification targets.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
/// Human-readable worker name, used in log/tracing output.
pub name: String,
/// Maximum number of concurrency permits (backs the worker's semaphore).
pub concurrency: usize,
/// How often the worker polls the backend for new tasks.
pub poll_interval: Duration,
/// Webhook endpoints notified when a task finishes.
pub webhook_configs: Vec<WebhookConfig>,
}
impl WorkerConfig {
    /// Create a configuration named `name` with the defaults: concurrency 4,
    /// a 1-second poll interval, and no webhook targets.
    pub fn new(name: String) -> Self {
        Self {
            name,
            concurrency: 4,
            poll_interval: Duration::from_secs(1),
            webhook_configs: Vec::new(),
        }
    }

    /// Set the maximum number of concurrency permits.
    pub fn with_concurrency(self, concurrency: usize) -> Self {
        Self { concurrency, ..self }
    }

    /// Set how often the worker polls the backend for work.
    pub fn with_poll_interval(self, interval: Duration) -> Self {
        Self {
            poll_interval: interval,
            ..self
        }
    }

    /// Append one webhook target to the existing list.
    pub fn with_webhook(mut self, webhook_config: WebhookConfig) -> Self {
        self.webhook_configs.push(webhook_config);
        self
    }

    /// Replace the entire webhook target list.
    pub fn with_webhooks(self, webhook_configs: Vec<WebhookConfig>) -> Self {
        Self {
            webhook_configs,
            ..self
        }
    }
}
impl Default for WorkerConfig {
    /// Default configuration: a worker named "worker" with `new`'s defaults.
    fn default() -> Self {
        Self::new(String::from("worker"))
    }
}
/// A task worker that polls a backend for queued tasks and executes them,
/// optionally coordinating via a distributed lock, persisting results, and
/// sending webhook notifications.
pub struct Worker {
/// Static configuration (name, concurrency, poll interval, webhooks).
config: WorkerConfig,
/// Broadcast channel used to signal running loops to shut down.
shutdown_tx: broadcast::Sender<()>,
/// Optional registry used to deserialize and execute task payloads.
registry: Option<Arc<TaskRegistry>>,
/// Optional distributed lock to avoid duplicate execution across workers.
task_lock: Option<Arc<dyn TaskLock>>,
/// Optional backend for persisting task result metadata.
result_backend: Option<Arc<dyn ResultBackend>>,
/// One sender per configured webhook target (built in `new`).
webhook_senders: Vec<Arc<dyn WebhookSender>>,
/// Semaphore sized from `config.concurrency`, acquired per task attempt.
concurrency_semaphore: Arc<Semaphore>,
}
impl Worker {
/// Build a worker from `config`, pre-creating the shutdown channel, the
/// concurrency semaphore, and one HTTP sender per configured webhook.
pub fn new(config: WorkerConfig) -> Self {
    let (shutdown_tx, _initial_rx) = broadcast::channel(1);
    let mut webhook_senders: Vec<Arc<dyn WebhookSender>> =
        Vec::with_capacity(config.webhook_configs.len());
    for webhook_config in &config.webhook_configs {
        webhook_senders.push(Arc::new(HttpWebhookSender::new(webhook_config.clone())));
    }
    Self {
        // Read `config.concurrency` before `config` is moved into the struct.
        concurrency_semaphore: Arc::new(Semaphore::new(config.concurrency)),
        config,
        shutdown_tx,
        registry: None,
        task_lock: None,
        result_backend: None,
        webhook_senders,
    }
}
pub fn with_registry(mut self, registry: Arc<TaskRegistry>) -> Self {
self.registry = Some(registry);
self
}
pub fn with_lock(mut self, task_lock: Arc<dyn TaskLock>) -> Self {
self.task_lock = Some(task_lock);
self
}
pub fn with_result_backend(mut self, result_backend: Arc<dyn ResultBackend>) -> Self {
self.result_backend = Some(result_backend);
self
}
/// Main worker loop: on every `poll_interval` tick, attempt to dequeue and
/// process one task; exit cleanly when a shutdown signal is broadcast via
/// [`Worker::stop`].
pub async fn run(
    &self,
    backend: Arc<dyn TaskBackend>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    let mut shutdown_rx = self.shutdown_tx.subscribe();
    let mut poll_interval = tokio::time::interval(self.config.poll_interval);
    tracing::info!(
        worker = %self.config.name,
        concurrency = self.config.concurrency,
        "Worker started"
    );
    loop {
        tokio::select! {
            _ = shutdown_rx.recv() => {
                tracing::info!(worker = %self.config.name, "Shutdown signal received");
                break;
            }
            _ = poll_interval.tick() => {
                self.try_process_task(Arc::clone(&backend)).await;
            }
        }
    }
    tracing::info!(worker = %self.config.name, "Worker stopped");
    Ok(())
}
/// Attempt to dequeue and process a single task, then record its final
/// status on the backend.
///
/// A concurrency permit is taken before dequeuing so a saturated worker does
/// not pull work it cannot start. NOTE(review): the task is awaited inline
/// rather than spawned, so per `run` loop only one task executes at a time
/// regardless of the configured concurrency — confirm this is intended.
async fn try_process_task(&self, backend: Arc<dyn TaskBackend>) {
    let permit = match self.concurrency_semaphore.clone().acquire_owned().await {
        Ok(acquired) => acquired,
        Err(_) => {
            // The semaphore is only closed explicitly; treat it as fatal for
            // this attempt and bail out.
            tracing::error!(
                worker = %self.config.name,
                "Concurrency semaphore closed unexpectedly"
            );
            return;
        }
    };
    match backend.dequeue().await {
        Ok(Some(task_id)) => {
            tracing::info!(worker = %self.config.name, task_id = %task_id, "Processing task");
            // Run the task, then fold both outcomes into a single status
            // update instead of duplicating the update call per branch.
            let final_status = match self.execute_task(task_id, backend.clone()).await {
                Ok(_) => {
                    tracing::info!(
                        worker = %self.config.name,
                        task_id = %task_id,
                        "Task completed successfully"
                    );
                    TaskStatus::Success
                }
                Err(e) => {
                    tracing::error!(
                        worker = %self.config.name,
                        task_id = %task_id,
                        error = %e,
                        "Task failed"
                    );
                    TaskStatus::Failure
                }
            };
            if let Err(e) = backend.update_status(task_id, final_status).await {
                tracing::error!(
                    worker = %self.config.name,
                    task_id = %task_id,
                    error = %e,
                    "Failed to update task status"
                );
            }
            drop(permit);
        }
        Ok(None) => {
            // Queue empty: nothing to do; the permit is released on drop.
        }
        Err(e) => {
            tracing::error!(worker = %self.config.name, error = %e, "Failed to dequeue task");
        }
    }
}
async fn execute_task(
&self,
task_id: crate::TaskId,
backend: Arc<dyn TaskBackend>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
tracing::debug!(worker = %self.config.name, task_id = %task_id, "Executing task");
let started_at = Utc::now();
let mut lock_token = None;
if let Some(ref lock) = self.task_lock {
match lock.acquire(task_id, Duration::from_secs(300)).await? {
Some(token) => lock_token = Some(token),
None => {
tracing::info!(
worker = %self.config.name,
task_id = %task_id,
"Task already locked by another worker"
);
return Ok(());
}
}
}
let serialized_task = backend.get_task_data(task_id).await?;
let task_name = serialized_task
.as_ref()
.map(|t| t.name().to_string())
.unwrap_or_else(|| "unknown_task".to_string());
let result: Result<(), Box<dyn std::error::Error + Send + Sync>> =
if let Some(ref registry) = self.registry {
match serialized_task {
Some(serialized_task) => {
tracing::debug!(
worker = %self.config.name,
task_name = %task_name,
"Executing task with registry"
);
match registry
.create(serialized_task.name(), serialized_task.data())
.await
{
Ok(task_executor) => {
match task_executor.execute().await {
Ok(_) => {
tracing::info!(
worker = %self.config.name,
task_name = %task_name,
"Task completed successfully"
);
Ok(())
}
Err(e) => {
tracing::error!(
worker = %self.config.name,
task_name = %task_name,
error = %e,
"Task failed"
);
Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
}
}
Err(e) => {
tracing::error!(
worker = %self.config.name,
task_name = %task_name,
error = %e,
"Failed to deserialize task"
);
Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
}
}
}
None => {
tracing::warn!(
worker = %self.config.name,
task_id = %task_id,
"Task not found in backend"
);
Err(format!("Task {} not found", task_id).into())
}
}
} else {
tracing::debug!(
worker = %self.config.name,
"Task execution without registry (basic mode)"
);
Ok(())
};
let completed_at = Utc::now();
let duration_ms = (completed_at - started_at).num_milliseconds().max(0) as u64;
let (task_status, webhook_status) = match &result {
Ok(_) => (TaskStatus::Success, crate::webhook::TaskStatus::Success),
Err(_) => (TaskStatus::Failure, crate::webhook::TaskStatus::Failed),
};
let store_error = if let Some(ref result_backend) = self.result_backend {
let metadata = match result {
Ok(_) => TaskResultMetadata::new(
task_id,
task_status,
Some("Task completed successfully".to_string()),
),
Err(ref e) => {
TaskResultMetadata::with_error(task_id, format!("Task failed: {}", e))
}
};
result_backend.store_result(metadata).await.err()
} else {
None
};
if !self.webhook_senders.is_empty() {
let webhook_event = WebhookEvent {
task_id,
task_name,
status: webhook_status,
result: match webhook_status {
crate::webhook::TaskStatus::Success => {
Some("Task completed successfully".to_string())
}
crate::webhook::TaskStatus::Failed => None,
crate::webhook::TaskStatus::Cancelled => None,
},
error: match webhook_status {
crate::webhook::TaskStatus::Failed => match &result {
Err(e) => Some(e.to_string()),
_ => Some("Unknown error".to_string()),
},
_ => None,
},
started_at,
completed_at,
duration_ms,
};
for sender in &self.webhook_senders {
let sender_clone = Arc::clone(sender);
let event_clone = webhook_event.clone();
tokio::spawn(async move {
if let Err(e) = sender_clone.send(&event_clone).await {
tracing::error!(error = %e, "Failed to send webhook notification");
}
});
}
}
if let Some(ref lock) = self.task_lock
&& let Some(ref token) = lock_token
{
match lock.release(task_id, token).await {
Ok(false) => {
tracing::warn!(
worker = %self.config.name,
task_id = %task_id,
"Lock release returned false: token mismatch or lock already expired"
);
}
Err(e) => {
tracing::error!(
worker = %self.config.name,
task_id = %task_id,
error = %e,
"Failed to release task lock"
);
}
Ok(true) => {}
}
}
if let Some(e) = store_error {
return Err(Box::new(e));
}
result
}
/// Broadcast a shutdown signal to every active `run` loop of this worker.
pub async fn stop(&self) {
    // A send error only means no run loop is currently subscribed, which is
    // harmless — there is simply nothing to stop.
    self.shutdown_tx.send(()).ok();
}
}
impl Default for Worker {
    /// Build a worker with the default configuration.
    ///
    /// Delegates to [`Worker::new`] instead of duplicating field-by-field
    /// construction: the previous hand-rolled version skipped webhook-sender
    /// construction and could silently drift from `new` as fields evolve.
    /// Behavior is unchanged because the default config has no webhooks.
    fn default() -> Self {
        Self::new(WorkerConfig::default())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{DummyBackend, Task, TaskId, TaskPriority};
use rstest::rstest;
use std::time::Duration;
use tokio::time::sleep;
#[allow(dead_code)]
// Minimal Task implementation used as a fixture by the tests below.
struct TestTask {
// Unique identifier returned by `id()`.
id: TaskId,
// Name returned by `name()`.
name: String,
}
impl Task for TestTask {
fn id(&self) -> TaskId {
self.id
}
fn name(&self) -> &str {
&self.name
}
// Fixed mid-range priority for the fixture.
fn priority(&self) -> TaskPriority {
TaskPriority::new(5)
}
}
#[rstest]
#[tokio::test]
async fn test_worker_creation() {
    // The worker must retain the name from its configuration.
    let worker = Worker::new(WorkerConfig::new(String::from("test-worker")));
    assert_eq!(worker.config.name, "test-worker");
}
#[rstest]
#[tokio::test]
async fn test_worker_config_builder() {
    // Builder methods must override the defaults set by `new`.
    let config = WorkerConfig::new(String::from("test"))
        .with_poll_interval(Duration::from_millis(100))
        .with_concurrency(8);
    assert_eq!(config.poll_interval, Duration::from_millis(100));
    assert_eq!(config.concurrency, 8);
}
#[rstest]
#[tokio::test]
async fn test_worker_start_and_stop() {
let worker = Worker::new(WorkerConfig::default());
let backend = Arc::new(DummyBackend::new());
// `Worker` does not implement Clone, so build a second handle sharing the
// shutdown channel (and semaphore) to be able to call `stop` after `run`
// takes ownership of the original.
let worker_clone = Worker {
config: worker.config.clone(),
shutdown_tx: worker.shutdown_tx.clone(),
registry: None,
task_lock: None,
result_backend: None,
webhook_senders: Vec::new(),
concurrency_semaphore: worker.concurrency_semaphore.clone(),
};
let handle = tokio::spawn(async move { worker.run(backend).await });
sleep(Duration::from_millis(100)).await;
worker_clone.stop().await;
// The run loop must observe the shutdown signal and exit within 2 seconds.
let result = tokio::time::timeout(Duration::from_secs(2), handle).await;
assert!(result.is_ok());
}
#[rstest]
#[tokio::test]
async fn test_worker_with_registry() {
use crate::registry::TaskRegistry;
let registry = Arc::new(TaskRegistry::new());
let worker = Worker::new(WorkerConfig::default()).with_registry(registry);
assert!(worker.registry.is_some());
}
#[rstest]
#[tokio::test]
async fn test_worker_with_lock() {
use crate::locking::MemoryTaskLock;
let lock = Arc::new(MemoryTaskLock::new());
let worker = Worker::new(WorkerConfig::default()).with_lock(lock);
assert!(worker.task_lock.is_some());
}
#[rstest]
#[tokio::test]
async fn test_try_process_task_returns_early_when_semaphore_closed() {
    // Closing the semaphore makes `acquire_owned` fail; the worker must
    // return early without panicking.
    let semaphore = Arc::new(Semaphore::new(1));
    semaphore.close();
    let worker = Worker {
        config: WorkerConfig::new(String::from("test-worker")),
        shutdown_tx: broadcast::channel(1).0,
        registry: None,
        task_lock: None,
        result_backend: None,
        webhook_senders: Vec::new(),
        concurrency_semaphore: semaphore,
    };
    worker.try_process_task(Arc::new(DummyBackend::new())).await;
}
#[rstest]
#[tokio::test]
async fn test_worker_with_result_backend() {
use crate::result::MemoryResultBackend;
let backend = Arc::new(MemoryResultBackend::new());
let worker = Worker::new(WorkerConfig::default()).with_result_backend(backend);
assert!(worker.result_backend.is_some());
}
}