use crate::{Error, Job, JobPayload, QueueConnection};
use async_trait::async_trait;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tracing::{debug, error, info, warn};
/// Hook for running a job inside a tenant-specific scope.
///
/// The worker invokes the provider only when one has been set via
/// `Worker::with_tenant_scope` *and* the job payload carries a tenant id; in
/// every other case the handler future is awaited directly.
#[async_trait]
pub trait TenantScopeProvider: Send + Sync {
/// Runs the job future `f` on behalf of `tenant_id`.
///
/// The worker treats the returned `Result` as the outcome of the job, so an
/// `Err` returned here (whether it came from `f` or from the provider itself)
/// goes through the worker's normal retry / failed-queue handling.
/// Implementations are expected to await `f`; nothing in the trait enforces it.
async fn with_scope(
&self,
tenant_id: i64,
f: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
) -> Result<(), Error>;
}
/// Tunable settings for a [`Worker`].
#[derive(Debug, Clone)]
pub struct WorkerConfig {
/// Queue names to poll. They are checked in this order on every poll and the
/// first queue that yields a job wins, so earlier entries take priority.
pub queues: Vec<String>,
/// Maximum number of jobs running concurrently (size of the worker's semaphore).
pub max_jobs: usize,
/// How long the worker sleeps after a poll in which every queue was empty.
pub sleep_duration: Duration,
/// When `true`, `Worker::run` returns the first error raised while polling the
/// queues. Failures of individual job handlers are handled inside the job task
/// and never reach this switch.
pub stop_on_error: bool,
}
impl Default for WorkerConfig {
    /// Baseline settings: the single queue `"default"`, up to 10 concurrent jobs,
    /// a one-second idle sleep, and polling errors that do not stop the worker.
    fn default() -> Self {
        let queues = vec![String::from("default")];
        let sleep_duration = Duration::from_secs(1);
        Self {
            queues,
            sleep_duration,
            max_jobs: 10,
            stop_on_error: false,
        }
    }
}
impl WorkerConfig {
    /// Builds a configuration that polls `queues`, taking every other setting
    /// from [`WorkerConfig::default`].
    pub fn new(queues: Vec<String>) -> Self {
        let Self {
            max_jobs,
            sleep_duration,
            stop_on_error,
            ..
        } = Self::default();
        Self {
            queues,
            max_jobs,
            sleep_duration,
            stop_on_error,
        }
    }

    /// Builder-style setter: returns the configuration with the concurrency
    /// limit replaced by `max`.
    pub fn max_jobs(self, max: usize) -> Self {
        Self {
            max_jobs: max,
            ..self
        }
    }
}
/// Type-erased job handler: takes the payload's raw `data` string and returns a
/// boxed future that resolves to the job's result. Stored behind an `Arc` so a
/// handler can be handed to a spawned job task without copying the closure.
type JobHandler =
Arc<dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send>> + Send + Sync>;
/// Polls the configured queues and runs each popped job on its own Tokio task.
pub struct Worker {
/// Queue backend handle used to pop jobs, release them for retry, move them to
/// the failed queue, and migrate delayed jobs.
connection: QueueConnection,
/// Queues to poll and concurrency / sleep / error settings.
config: WorkerConfig,
/// Registered handlers keyed by job type name (see `register`).
handlers: HashMap<String, JobHandler>,
/// Caps the number of concurrently running jobs; created with
/// `config.max_jobs` permits, one of which is held by each running job.
semaphore: Arc<Semaphore>,
/// Signalled by `shutdown()` to stop `run()` and its background task.
shutdown: Arc<tokio::sync::Notify>,
/// Optional provider that wraps jobs carrying a tenant id.
tenant_scope: Option<Arc<dyn TenantScopeProvider>>,
}
impl Worker {
pub fn new(connection: QueueConnection, config: WorkerConfig) -> Self {
let semaphore = Arc::new(Semaphore::new(config.max_jobs));
Self {
connection,
config,
handlers: HashMap::new(),
semaphore,
shutdown: Arc::new(tokio::sync::Notify::new()),
tenant_scope: None,
}
}
pub fn with_tenant_scope(mut self, provider: Arc<dyn TenantScopeProvider>) -> Self {
self.tenant_scope = Some(provider);
self
}
pub fn register<J>(&mut self)
where
J: Job + serde::de::DeserializeOwned + 'static,
{
let type_name = std::any::type_name::<J>().to_string();
let handler: JobHandler = Arc::new(move |data: String| {
Box::pin(async move {
let job: J = serde_json::from_str(&data)
.map_err(|e| Error::DeserializationFailed(e.to_string()))?;
job.handle().await
})
});
self.handlers.insert(type_name, handler);
}
pub async fn run(&self) -> Result<(), Error> {
info!(
queues = ?self.config.queues,
max_jobs = self.config.max_jobs,
"Starting queue worker"
);
let conn = self.connection.clone();
let queues = self.config.queues.clone();
let shutdown = self.shutdown.clone();
tokio::spawn(async move {
loop {
tokio::select! {
_ = shutdown.notified() => break,
_ = tokio::time::sleep(Duration::from_secs(1)) => {
for queue in &queues {
if let Err(e) = conn.migrate_delayed(queue).await {
error!(queue = queue, error = %e, "Failed to migrate delayed jobs");
}
}
}
}
}
});
loop {
tokio::select! {
_ = self.shutdown.notified() => {
info!("Worker shutting down");
info!("Waiting for in-flight jobs to complete");
let _ = self.semaphore.acquire_many(self.config.max_jobs as u32).await;
return Ok(());
}
result = self.process_next() => {
if let Err(e) = result {
error!(error = %e, "Error processing job");
if self.config.stop_on_error {
return Err(e);
}
}
}
}
}
}
async fn process_next(&self) -> Result<(), Error> {
for queue in &self.config.queues {
if let Some(payload) = self.connection.pop_nowait(queue).await? {
self.process_job(payload).await?;
return Ok(());
}
}
tokio::time::sleep(self.config.sleep_duration).await;
Ok(())
}
async fn process_job(&self, payload: JobPayload) -> Result<(), Error> {
let permit = self.semaphore.clone().acquire_owned().await.unwrap();
let connection = self.connection.clone();
let handlers = self.handlers.clone();
let job_type = payload.job_type.clone();
let job_id = payload.id;
let tenant_scope = self.tenant_scope.clone();
let tenant_id = payload.tenant_id;
tokio::spawn(async move {
let _permit = permit;
debug!(
job_id = %job_id,
job_type = &job_type,
tenant_id = ?tenant_id,
"Processing job"
);
let handler = match handlers.get(&job_type) {
Some(h) => h,
None => {
warn!(job_type = &job_type, "No handler registered for job type");
return;
}
};
let job_result = match (&tenant_scope, tenant_id) {
(Some(scope), Some(id)) => {
let job_fut = Box::pin(handler(payload.data.clone()));
scope.with_scope(id, job_fut).await
}
_ => handler(payload.data.clone()).await,
};
match job_result {
Ok(()) => {
info!(job_id = %job_id, job_type = &job_type, "Job completed successfully");
}
Err(e) => {
error!(job_id = %job_id, job_type = &job_type, error = %e, "Job failed");
if payload.has_exceeded_retries() {
warn!(job_id = %job_id, "Job exceeded max retries, moving to failed queue");
if let Err(e) = connection.fail(payload, &e).await {
error!(error = %e, "Failed to move job to failed queue");
}
} else {
let delay = Duration::from_secs(2u64.pow(payload.attempts));
if let Err(e) = connection.release(payload, delay).await {
error!(error = %e, "Failed to release job for retry");
}
}
}
}
});
Ok(())
}
pub fn shutdown(&self) {
self.shutdown.notify_waiters();
}
}
impl Clone for Worker {
    /// Produces a worker equivalent to `self`.
    ///
    /// The clone shares the concurrency semaphore and the shutdown notifier
    /// with the original (both are `Arc`s), so the job limit applies across
    /// all clones and `shutdown()` on any of them signals every one.
    fn clone(&self) -> Self {
        Self {
            connection: self.connection.clone(),
            config: self.config.clone(),
            // Handlers are `Arc`s, so this copies the keys and bumps reference
            // counts. Starting the clone with an empty map would make it pop
            // jobs and discard each one as "No handler registered".
            handlers: self.handlers.clone(),
            semaphore: Arc::clone(&self.semaphore),
            shutdown: Arc::clone(&self.shutdown),
            tenant_scope: self.tenant_scope.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex;
#[test]
fn test_tenant_scope_provider_is_object_safe() {
    // Minimal provider that just awaits the job future it is given.
    struct PassThrough;

    #[async_trait]
    impl TenantScopeProvider for PassThrough {
        async fn with_scope(
            &self,
            _tenant_id: i64,
            f: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
        ) -> Result<(), Error> {
            f.await
        }
    }

    // This only compiles if the trait can be used as a trait object.
    let concrete = Arc::new(PassThrough);
    let _provider: Arc<dyn TenantScopeProvider> = concrete;
}
/// Test double for `TenantScopeProvider` that records every tenant id it is
/// called with and can be told to fail instead of running the job.
struct MockScopeProvider {
/// Tenant ids passed to `with_scope`, in call order. Shared so a test can keep
/// a handle after the provider has been moved into an `Arc<dyn ..>`.
called_with: Arc<Mutex<Vec<i64>>>,
/// When `true`, `with_scope` returns a tenant-not-found error without awaiting
/// the job future.
should_fail: bool,
}
impl MockScopeProvider {
fn new() -> Self {
Self {
called_with: Arc::new(Mutex::new(Vec::new())),
should_fail: false,
}
}
fn failing() -> Self {
Self {
called_with: Arc::new(Mutex::new(Vec::new())),
should_fail: true,
}
}
}
#[async_trait]
impl TenantScopeProvider for MockScopeProvider {
async fn with_scope(
&self,
tenant_id: i64,
f: Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>,
) -> Result<(), Error> {
self.called_with.lock().unwrap().push(tenant_id);
if self.should_fail {
return Err(Error::tenant_not_found(tenant_id));
}
f.await
}
}
/// Builds a `Worker` with the default config, connected to a throwaway
/// in-process TCP server that stands in for Redis.
///
/// Panics if the connection is not established within two seconds or fails.
async fn make_worker() -> Worker {
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpListener;
// Port 0 lets the OS pick a free port, so tests can run in parallel.
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
// Fake server: accept connections until `accept` fails, serving each one on
// its own task.
tokio::spawn(async move {
loop {
let Ok((mut stream, _)) = listener.accept().await else {
break;
};
tokio::spawn(async move {
let (reader, mut writer) = stream.split();
let mut lines = BufReader::new(reader).lines();
// Answer every received line with `+OK\r\n` (a RESP simple-string reply),
// regardless of what the line contained; write errors are ignored.
while let Ok(Some(_line)) = lines.next_line().await {
let _ = writer.write_all(b"+OK\r\n").await;
}
});
}
});
let config = crate::QueueConfig::new(format!("redis://127.0.0.1:{port}"));
// Bound the connection attempt so a stuck handshake fails the test instead of
// hanging it.
let conn = tokio::time::timeout(
std::time::Duration::from_secs(2),
crate::QueueConnection::new(config),
)
.await
.expect("fake Redis connection timed out")
.expect("fake Redis connection failed");
Worker::new(conn, WorkerConfig::default())
}
#[tokio::test]
async fn test_with_tenant_scope_stores_provider() {
    // Attaching a provider through the builder must populate the field.
    let scoped = make_worker()
        .await
        .with_tenant_scope(Arc::new(MockScopeProvider::new()));

    let has_scope = scoped.tenant_scope.is_some();
    assert!(
        has_scope,
        "tenant_scope must be Some after with_tenant_scope()"
    );
}
#[tokio::test]
async fn test_worker_without_scope_has_none_by_default() {
    // A freshly constructed worker has no tenant scope provider.
    let fresh = make_worker().await;
    let scope_missing = fresh.tenant_scope.is_none();
    assert!(scope_missing, "tenant_scope must be None by default");
}
#[tokio::test]
async fn test_clone_preserves_tenant_scope() {
    // Cloning a worker that has a provider must carry the provider over.
    let scope: Arc<dyn TenantScopeProvider> = Arc::new(MockScopeProvider::new());
    let original = make_worker().await.with_tenant_scope(scope);

    let duplicate = Clone::clone(&original);
    let scope_kept = duplicate.tenant_scope.is_some();
    assert!(scope_kept, "Clone must preserve tenant_scope");
}
#[tokio::test]
async fn test_clone_without_scope_preserves_none() {
    // Cloning a worker without a provider must not invent one.
    let original = make_worker().await;
    let duplicate = Clone::clone(&original);

    let scope_absent = duplicate.tenant_scope.is_none();
    assert!(scope_absent, "Clone must preserve None tenant_scope");
}
#[tokio::test]
async fn test_mock_scope_provider_calls_future() {
let provider = MockScopeProvider::new();
let calls = provider.called_with.clone();
let result = provider.with_scope(42, Box::pin(async { Ok(()) })).await;
assert!(result.is_ok());
assert_eq!(calls.lock().unwrap().as_slice(), &[42]);
}
#[tokio::test]
async fn test_mock_scope_provider_failure_returns_tenant_not_found() {
    // The failing mock must surface a tenant-not-found error for the id it was given.
    let failing_provider = MockScopeProvider::failing();
    let noop_job = Box::pin(async { Ok(()) });

    let result = failing_provider.with_scope(99, noop_job).await;

    assert!(matches!(
        result,
        Err(Error::TenantNotFound { tenant_id: 99 })
    ));
}
#[tokio::test]
async fn test_scope_dispatch_tenant_id_some_calls_with_scope() {
    // Job future that flips a shared flag when it actually runs.
    let job_ran = Arc::new(Mutex::new(false));
    let ran_flag = Arc::clone(&job_ran);
    let job_fut = Box::pin(async move {
        *ran_flag.lock().unwrap() = true;
        Ok(())
    });

    // Provider present and tenant id present: the scoped path must be taken.
    let mock = MockScopeProvider::new();
    let recorded = Arc::clone(&mock.called_with);
    let tenant_scope: Option<Arc<dyn TenantScopeProvider>> = Some(Arc::new(mock));
    let tenant_id = Some(1_i64);

    // Same dispatch shape the worker uses when running a job.
    let result = match (&tenant_scope, tenant_id) {
        (Some(scope), Some(id)) => scope.with_scope(id, job_fut).await,
        _ => job_fut.await,
    };

    assert!(result.is_ok());
    assert_eq!(*recorded.lock().unwrap(), vec![1_i64]);
    let ran = *job_ran.lock().unwrap();
    assert!(ran, "job future must have been called");
}
#[tokio::test]
async fn test_scope_dispatch_tenant_id_none_skips_with_scope() {
    // Job future that flips a shared flag when it actually runs.
    let job_ran = Arc::new(Mutex::new(false));
    let ran_flag = Arc::clone(&job_ran);
    let job_fut = Box::pin(async move {
        *ran_flag.lock().unwrap() = true;
        Ok(())
    });

    // Provider present but no tenant id: the provider must be bypassed.
    let mock = MockScopeProvider::new();
    let recorded = Arc::clone(&mock.called_with);
    let tenant_scope: Option<Arc<dyn TenantScopeProvider>> = Some(Arc::new(mock));
    let tenant_id: Option<i64> = None;

    // Same dispatch shape the worker uses when running a job.
    let result = match (&tenant_scope, tenant_id) {
        (Some(scope), Some(id)) => scope.with_scope(id, job_fut).await,
        _ => job_fut.await,
    };

    assert!(result.is_ok());
    let provider_untouched = recorded.lock().unwrap().is_empty();
    assert!(
        provider_untouched,
        "with_scope must not be called when tenant_id is None"
    );
    let ran = *job_ran.lock().unwrap();
    assert!(ran, "job future must still run directly");
}
#[tokio::test]
async fn test_scope_dispatch_no_provider_runs_job_directly() {
    // Job future that flips a shared flag when it actually runs.
    let job_ran = Arc::new(Mutex::new(false));
    let ran_flag = Arc::clone(&job_ran);
    let job_fut = Box::pin(async move {
        *ran_flag.lock().unwrap() = true;
        Ok(())
    });

    // Tenant id present but no provider configured: the job runs unwrapped.
    let tenant_scope: Option<Arc<dyn TenantScopeProvider>> = None;
    let tenant_id = Some(1_i64);

    // Same dispatch shape the worker uses when running a job.
    let result = match (&tenant_scope, tenant_id) {
        (Some(scope), Some(id)) => scope.with_scope(id, job_fut).await,
        _ => job_fut.await,
    };

    assert!(result.is_ok());
    let ran = *job_ran.lock().unwrap();
    assert!(ran, "job must run directly without a provider");
}
}