use crate::backend::RedisBroker;
use crate::backend::RedisConnectionType;
use crate::base::Broker;
use crate::components::heartbeat::{Heartbeat, HeartbeatMeta, WorkerEventSender};
use crate::components::processor::{Processor, ProcessorParams};
use crate::components::subscriber::SubscriberConfig;
use crate::components::ComponentLifecycle;
pub use crate::config::ServerConfig;
use crate::error::{Error, Result};
use crate::inspector::InspectorTrait;
use crate::inspector::RedisInspector;
use crate::task::Task;
use async_trait::async_trait;
use std::sync::Arc;
use std::time::Duration;
use tokio::signal;
use tokio::task::JoinHandle;
use uuid::Uuid;
/// Per-task callback executed by a [`Server`].
///
/// The server wraps a single handler instance in an `Arc` and hands it to its
/// processor, so implementations must be shareable across threads.
#[async_trait]
pub trait Handler: Sync + Send {
    /// Handles one dequeued task and returns its outcome to the processor.
    async fn process_task(&self, task: Task) -> Result<()>;
}
/// Adapts a synchronous closure `Fn(Task) -> Result<()>` into a [`Handler`].
///
/// The closure is invoked inline from `process_task`, on whichever async task
/// is driving that call.
pub struct HandlerFunc<F> {
    func: F,
}

impl<F> HandlerFunc<F>
where
    F: Send + Sync + Fn(Task) -> Result<()>,
{
    /// Wraps `func` so it can be used wherever a [`Handler`] is expected.
    pub fn new(func: F) -> Self {
        HandlerFunc { func }
    }
}

#[async_trait]
impl<F> Handler for HandlerFunc<F>
where
    F: Send + Sync + Fn(Task) -> Result<()>,
{
    async fn process_task(&self, task: Task) -> Result<()> {
        let HandlerFunc { func } = self;
        func(task)
    }
}
/// Adapts an async closure `Fn(Task) -> Fut` (where `Fut` resolves to
/// `Result<()>`) into a [`Handler`].
pub struct AsyncHandlerFunc<F, Fut> {
    func: F,
    // `fn() -> Fut` records the future type without claiming to own a `Fut`.
    // Function-pointer types are always `Send + Sync`, so the wrapper's
    // auto traits depend only on `F`. With `PhantomData<Fut>` the wrapper was
    // `Sync` only for `Sync` futures, which rejected most real async handlers
    // (e.g. any future that awaits an `#[async_trait]` method).
    _phantom: std::marker::PhantomData<fn() -> Fut>,
}

impl<F, Fut> AsyncHandlerFunc<F, Fut>
where
    F: Fn(Task) -> Fut + Send + Sync,
    // Only `Send` is needed: the future is awaited inside the `Send` future
    // produced by `process_task`; it is never shared between threads.
    Fut: std::future::Future<Output = Result<()>> + Send,
{
    /// Wraps `func` so it can be used wherever a [`Handler`] is expected.
    pub fn new(func: F) -> Self {
        Self {
            func,
            _phantom: std::marker::PhantomData,
        }
    }
}

#[async_trait]
impl<F, Fut> Handler for AsyncHandlerFunc<F, Fut>
where
    F: Fn(Task) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<()>> + Send,
{
    /// Calls the wrapped closure and awaits the future it returns.
    async fn process_task(&self, task: Task) -> Result<()> {
        (self.func)(task).await
    }
}
/// Lifecycle state of a [`Server`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ServerState {
    /// Constructed; `start` has not been called yet.
    New,
    /// `start` has been called and the server is processing.
    Running,
    /// `stop` was called on a running server.
    Stopped,
    /// `shutdown` has run; the server cannot be restarted.
    Closed,
}
/// Task-processing server. `start` registers it with the broker, launches the
/// background components (heartbeat, janitor, subscriber, recoverer,
/// forwarder, healthcheck and, when enabled, the group aggregator) and feeds
/// dequeued tasks to a [`Handler`] through a [`Processor`].
pub struct Server {
/// Backend used for registration, by every component, and for shutdown cleanup.
broker: Arc<dyn Broker>,
/// Inspector handed to the processor when it is created in `start`.
inspector: Arc<dyn InspectorTrait>,
/// Server configuration (validated by the constructors).
config: ServerConfig,
/// Current lifecycle state.
state: ServerState,
/// Local hostname; part of the `host:pid:uuid` server identity.
host: String,
/// Current process id; part of the server identity.
pid: i32,
/// Random UUID v4 generated at construction; part of the server identity.
server_uuid: String,
/// Created by `init_heartbeat` and moved into the processor by `start`.
worker_event_sender: Option<WorkerEventSender>,
/// Set once `start` has launched the processor.
processor: Option<Processor>,
/// Running background components paired with their task handles.
components: Vec<(Arc<dyn ComponentLifecycle + Send + Sync>, JoinHandle<()>)>,
/// Optional user callback forwarded to the aggregator component.
group_aggregator: Option<Arc<dyn crate::components::aggregator::GroupAggregator>>,
}
impl Server {
pub fn full_server_id(&self) -> String {
format!("{}:{}:{}", self.host, self.pid, self.server_uuid)
}
pub async fn new(
redis_connection_config: RedisConnectionType,
config: ServerConfig,
) -> Result<Self> {
config.validate()?;
let redis_broker = RedisBroker::new(redis_connection_config).await?;
let broker = Arc::new(redis_broker);
Self::with_redis_broker(broker, config).await
}
pub async fn with_broker_and_inspector(
broker: Arc<dyn Broker>,
inspector: Arc<dyn InspectorTrait>,
config: ServerConfig,
) -> Result<Self> {
config.validate()?;
let host = hostname::get()
.unwrap_or_default()
.to_string_lossy()
.to_string();
let pid = std::process::id() as i32;
let server_uuid = Uuid::new_v4().to_string();
Ok(Self {
broker,
inspector,
config,
state: ServerState::New,
host,
pid,
server_uuid,
worker_event_sender: None,
processor: None,
components: Vec::new(),
group_aggregator: None,
})
}
pub async fn with_redis_broker(broker: Arc<RedisBroker>, config: ServerConfig) -> Result<Self> {
let inspector = Arc::new(RedisInspector::from_broker(broker.clone()));
Self::with_broker_and_inspector(broker, inspector, config).await
}
pub fn set_group_aggregator<A>(&mut self, aggregator: A)
where
A: crate::components::aggregator::GroupAggregator + 'static,
{
self.group_aggregator = Some(Arc::new(aggregator));
}
pub async fn start<H>(&mut self, handler: H) -> Result<()>
where
H: Handler + 'static,
{
if self.state != ServerState::New {
return Err(Error::ServerRunning);
}
self.state = ServerState::Running;
self.register_server().await?;
self.init_heartbeat();
self.init_janitor();
let event_rx = self.init_subscriber();
self.init_recoverer();
self.init_forwarder();
self.init_healthcheck();
self.init_aggregator();
let worker_event_sender = self.worker_event_sender.take();
let processor_params = ProcessorParams {
broker: Arc::clone(&self.broker),
inspector: Arc::clone(&self.inspector),
queues: self.config.get_queues_with_prefix(),
concurrency: self.config.concurrency,
strict_priority: self.config.strict_priority,
task_check_interval: self.config.task_check_interval,
shutdown_timeout: self.config.shutdown_timeout,
worker_event_sender,
};
let mut processor = Processor::new(processor_params);
let handler = Arc::new(handler);
processor.start(handler);
let cancellations = processor.cancellations();
if let Some(mut rx) = event_rx {
tokio::spawn(async move {
use crate::components::subscriber::SubscriptionEvent;
while let Some(event) = rx.recv().await {
if let SubscriptionEvent::TaskCancelled { task_id } = event {
tracing::info!("Received cancellation request for task: {}", task_id);
if cancellations.cancel(&task_id) {
tracing::info!("Successfully cancelled task: {}", task_id);
} else {
tracing::debug!("Task {} not found or already completed", task_id);
}
}
}
});
}
self.processor = Some(processor);
self.wait_for_signal().await;
Ok(())
}
pub async fn run<H>(&mut self, handler: H) -> Result<()>
where
H: Handler + 'static,
{
let result = self.start(handler).await;
self.shutdown().await?;
result
}
pub async fn stop(&mut self) -> Result<()> {
if self.state == ServerState::Running {
self.state = ServerState::Stopped;
}
Ok(())
}
pub async fn shutdown(&mut self) -> Result<()> {
if self.state == ServerState::Closed {
return Ok(());
}
self.state = ServerState::Closed;
for (component, handle) in self.components.drain(..) {
component.shutdown();
let _ = tokio::time::timeout(Duration::from_secs(5), handle).await;
}
if let Some(processor) = self.processor.as_mut() {
processor.shutdown().await;
}
if let Err(e) = self
.broker
.clear_server_state(
&self.host,
self.pid,
&self.server_uuid,
self.config.acl_tenant.as_deref(),
)
.await
{
tracing::warn!(
"Failed to clear server state ({}:{}:{}): {}",
self.host,
self.pid,
self.server_uuid,
e
);
} else {
tracing::debug!("Server state cleared: {}", self.full_server_id());
}
self.broker.close().await?;
Ok(())
}
pub async fn ping(&self) -> Result<()> {
self.broker.ping().await
}
async fn register_server(&self) -> Result<()> {
let server_info = crate::proto::ServerInfo {
host: self.host.clone(),
pid: self.pid,
server_id: self.server_uuid.clone(),
concurrency: self.config.concurrency as i32,
queues: self.config.get_queues_with_prefix(),
strict_priority: self.config.strict_priority,
status: "active".to_string(),
start_time: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
active_worker_count: 0,
};
self
.broker
.write_server_state(
&server_info,
vec![],
Duration::from_secs(3600),
self.config.acl_tenant.as_deref(),
)
.await
}
fn init_heartbeat(&mut self) {
let meta = HeartbeatMeta {
host: self.host.clone(),
pid: self.pid,
server_uuid: self.server_uuid.clone(),
concurrency: self.config.concurrency,
queues: self.config.get_queues_with_prefix(),
strict_priority: self.config.strict_priority,
started: std::time::SystemTime::now(),
acl_tenant: self.config.acl_tenant.clone(),
};
let (heartbeat, worker_event_sender) = Heartbeat::new(
Arc::clone(&self.broker),
self.config.heartbeat_interval,
meta,
);
self.worker_event_sender = Some(worker_event_sender);
let hb = Arc::new(heartbeat);
let hb_handle = hb.clone().start();
self
.components
.push((hb as Arc<dyn ComponentLifecycle + Send + Sync>, hb_handle));
}
fn init_janitor(&mut self) {
let janitor_config = crate::components::janitor::JanitorConfig {
interval: self.config.janitor_interval,
batch_size: self.config.janitor_batch_size,
queues: self
.config
.get_queues_with_prefix()
.keys()
.cloned()
.collect(),
};
let janitor = Arc::new(crate::components::janitor::Janitor::new(
Arc::clone(&self.broker),
janitor_config,
));
let janitor_handle = janitor.clone().start();
self.components.push((
janitor as Arc<dyn ComponentLifecycle + Send + Sync>,
janitor_handle,
));
}
fn init_subscriber(
&mut self,
) -> Option<tokio::sync::mpsc::Receiver<crate::components::subscriber::SubscriptionEvent>> {
let mut subscriber = crate::components::subscriber::Subscriber::new(
Arc::clone(&self.broker),
SubscriberConfig::default(),
);
let event_rx = subscriber.take_receiver();
let subscriber = Arc::new(subscriber);
let subscriber_handle = subscriber.clone().start();
self.components.push((
subscriber as Arc<dyn ComponentLifecycle + Send + Sync>,
subscriber_handle,
));
event_rx
}
fn init_recoverer(&mut self) {
let recoverer_config = crate::components::recoverer::RecovererConfig {
interval: self.config.janitor_interval, queues: self
.config
.get_queues_with_prefix()
.keys()
.cloned()
.collect(),
};
let recoverer = Arc::new(crate::components::recoverer::Recoverer::new(
Arc::clone(&self.broker),
recoverer_config,
));
let recoverer_handle = recoverer.clone().start();
self.components.push((
recoverer as Arc<dyn ComponentLifecycle + Send + Sync>,
recoverer_handle,
));
}
fn init_forwarder(&mut self) {
let forwarder_config = crate::components::forwarder::ForwarderConfig {
interval: self.config.delayed_task_check_interval,
queues: self
.config
.get_queues_with_prefix()
.keys()
.cloned()
.collect(),
};
let forwarder = Arc::new(crate::components::forwarder::Forwarder::new(
Arc::clone(&self.broker),
forwarder_config,
));
let forwarder_handle = forwarder.clone().start();
self.components.push((
forwarder as Arc<dyn ComponentLifecycle + Send + Sync>,
forwarder_handle,
));
}
fn init_healthcheck(&mut self) {
let healthcheck_config = crate::components::healthcheck::HealthcheckConfig {
interval: self.config.health_check_interval,
};
let healthcheck = Arc::new(crate::components::healthcheck::Healthcheck::new(
Arc::clone(&self.broker),
healthcheck_config,
));
let healthcheck_handle = healthcheck.clone().start();
self.components.push((
healthcheck as Arc<dyn ComponentLifecycle + Send + Sync>,
healthcheck_handle,
));
}
fn init_aggregator(&mut self) {
if self.config.group_aggregator_enabled {
let aggregator_config = crate::components::aggregator::AggregatorConfig {
interval: Duration::from_secs(5),
queues: self
.config
.get_queues_with_prefix()
.keys()
.cloned()
.collect(),
grace_period: self.config.group_grace_period,
max_delay: self.config.group_max_delay,
max_size: self.config.group_max_size,
group_aggregator: self.group_aggregator.clone(),
};
let aggregator = Arc::new(crate::components::aggregator::Aggregator::new(
Arc::clone(&self.broker),
aggregator_config,
));
let aggregator_handle = aggregator.clone().start();
self.components.push((
aggregator as Arc<dyn ComponentLifecycle + Send + Sync>,
aggregator_handle,
));
}
}
#[allow(dead_code)]
fn is_task_expired(&self, task_msg: &crate::proto::TaskMessage) -> bool {
if task_msg.deadline <= 0 {
return false;
}
let now = chrono::Utc::now().timestamp();
now > task_msg.deadline
}
async fn wait_for_signal(&self) {
let _ = signal::ctrl_c().await;
tracing::info!("Received shutdown signal");
}
}
impl Drop for Server {
    /// Best-effort cleanup for a server dropped without a completed
    /// [`Server::shutdown`].
    ///
    /// Signals every background component to stop, then clears the server's
    /// registered state on the current tokio runtime if there is one.
    /// Otherwise it logs that the keys will be left to expire via their TTL.
    fn drop(&mut self) {
        if self.state == ServerState::Closed {
            return;
        }

        // Component tasks hold their own `Arc` clones and would otherwise keep
        // running after the server is gone; the heartbeat in particular could
        // keep re-writing the state cleared below. Signalling is asynchronous,
        // so a write already in flight may still race with the cleanup.
        // Dropping each `JoinHandle` merely detaches its task.
        for (component, _handle) in self.components.drain(..) {
            component.shutdown();
        }

        let host = self.host.clone();
        let pid = self.pid;
        let uuid = self.server_uuid.clone();
        let tenant = self.config.acl_tenant.clone();
        let broker = Arc::clone(&self.broker);

        match tokio::runtime::Handle::try_current() {
            Ok(rt) => {
                // `drop` cannot await, so the async cleanup runs as a detached task.
                rt.spawn(async move {
                    match broker
                        .clear_server_state(&host, pid, &uuid, tenant.as_deref())
                        .await
                    {
                        Err(e) => tracing::warn!(
                            "(Drop) Failed to clear server state {}:{}:{}: {}",
                            host,
                            pid,
                            uuid,
                            e
                        ),
                        Ok(_) => {
                            tracing::debug!("(Drop) Server state cleared {}:{}:{}", host, pid, uuid)
                        }
                    }
                });
            }
            Err(_) => {
                tracing::error!(
                    "[asynq] Drop without runtime; server keys will expire via TTL for {}:{}:{}",
                    host,
                    pid,
                    uuid
                );
            }
        }
    }
}
/// Builder for [`Server`].
///
/// Configure either a Redis connection (`redis_config`) or a custom broker
/// together with an inspector. When a custom broker is present, `build`
/// uses it in preference to the Redis configuration.
pub struct ServerBuilder {
/// Redis connection settings, used when no custom broker is set.
redis_config: Option<RedisConnectionType>,
/// Custom broker; when set, `build` also requires an inspector.
broker: Option<Arc<dyn Broker>>,
/// Inspector supplied via `inspector()` or `postgres_broker()`.
inspector: Option<Arc<dyn InspectorTrait>>,
/// Server configuration handed to the built server.
config: ServerConfig,
}
impl ServerBuilder {
pub fn new() -> Self {
Self {
redis_config: None,
broker: None,
inspector: None,
config: ServerConfig::default(),
}
}
pub fn redis_config(mut self, config: RedisConnectionType) -> Self {
self.redis_config = Some(config);
self
}
#[cfg(feature = "postgres")]
pub fn postgres_broker(mut self, broker: Arc<crate::backend::pgdb::PostgresBroker>) -> Self {
let inspector = Arc::new(crate::backend::pgdb::PostgresInspector::from_broker(
broker.clone(),
));
self.broker = Some(broker);
self.inspector = Some(inspector);
self
}
pub fn inspector<I: InspectorTrait + 'static>(mut self, inspector: Arc<I>) -> Self {
self.inspector = Some(inspector);
self
}
pub fn server_config(mut self, config: ServerConfig) -> Self {
self.config = config;
self
}
pub fn concurrency(mut self, concurrency: usize) -> Self {
self.config = self.config.concurrency(concurrency);
self
}
pub fn add_queue<S: AsRef<str>>(mut self, name: S, priority: i32) -> Result<Self> {
self.config = self.config.add_queue(name, priority)?;
Ok(self)
}
pub async fn build(self) -> Result<Server> {
if let Some(broker) = self.broker {
if let Some(inspector) = self.inspector {
return Server::with_broker_and_inspector(broker, inspector, self.config).await;
}
return Err(Error::config(
"When providing a custom broker, you must also provide an inspector using .inspector(). Example: .broker(my_broker).inspector(my_inspector)",
));
}
let redis_config = self
.redis_config
.ok_or_else(|| Error::config("Redis configuration is required"))?;
Server::new(redis_config, self.config).await
}
}
impl Default for ServerBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::base::constants::DEFAULT_TIMEOUT;
    use std::time::Duration;

    /// Timeout-selection rule exercised below: an explicit `timeout` wins;
    /// otherwise the seconds left until `deadline` (`None` once it has passed);
    /// otherwise `DEFAULT_TIMEOUT`.
    // NOTE(review): this mirrors logic presumably living in the processor;
    // TODO exercise the real implementation instead of a local copy.
    fn effective_timeout(task_msg: &crate::proto::TaskMessage, now: i64) -> Option<Duration> {
        if task_msg.timeout > 0 {
            Some(Duration::from_secs(task_msg.timeout as u64))
        } else if task_msg.deadline > 0 {
            let remaining = task_msg.deadline - now;
            if remaining > 0 {
                Some(Duration::from_secs(remaining as u64))
            } else {
                None
            }
        } else {
            Some(DEFAULT_TIMEOUT)
        }
    }

    /// Same rule as `Server::is_task_expired`, evaluated at a fixed `now`.
    fn is_expired(task_msg: &crate::proto::TaskMessage, now: i64) -> bool {
        task_msg.deadline > 0 && now > task_msg.deadline
    }

    #[tokio::test]
    async fn test_handler_func() {
        let handler = HandlerFunc::new(|task: Task| {
            println!("Processing task: {}", task.get_type());
            Ok(())
        });
        let task = Task::new("test", b"payload").unwrap();
        let result = handler.process_task(task).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_async_handler_func() {
        let handler = AsyncHandlerFunc::new(|task: Task| async move {
            println!("Processing async task: {}", task.get_type());
            Ok(())
        });
        let task = Task::new("test", b"payload").unwrap();
        let result = handler.process_task(task).await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_server_builder() {
        let builder = ServerBuilder::new().concurrency(4);
        assert_eq!(builder.config.concurrency, 4);
    }

    #[test]
    fn test_timeout_calculation_logic() {
        let now = chrono::Utc::now().timestamp();
        let mut task_msg = crate::proto::TaskMessage {
            deadline: now + 300,
            ..Default::default()
        };
        // Deadline only: the remaining time becomes the timeout.
        assert_eq!(
            effective_timeout(&task_msg, now),
            Some(Duration::from_secs(300))
        );

        task_msg.deadline = now + 600;
        let timeout =
            effective_timeout(&task_msg, now).expect("a future deadline yields a timeout");
        assert!(timeout.as_secs() > 590);

        // A deadline already in the past leaves no time to run.
        task_msg.deadline = now - 1;
        assert_eq!(effective_timeout(&task_msg, now), None);

        // An explicit timeout takes precedence over the deadline.
        task_msg.timeout = 30;
        task_msg.deadline = now + 600;
        assert_eq!(
            effective_timeout(&task_msg, now),
            Some(Duration::from_secs(30))
        );

        // Neither set: fall back to the default.
        task_msg.timeout = 0;
        task_msg.deadline = 0;
        assert_eq!(effective_timeout(&task_msg, now), Some(DEFAULT_TIMEOUT));
    }

    #[test]
    fn test_expiry_check_logic() {
        let now = chrono::Utc::now().timestamp();
        let mut task_msg = crate::proto::TaskMessage {
            deadline: now + 300,
            ..Default::default()
        };
        assert!(!is_expired(&task_msg, now));

        task_msg.deadline = now - 300;
        assert!(is_expired(&task_msg, now));

        // A zero deadline means "no deadline", never expired.
        task_msg.deadline = 0;
        assert!(!is_expired(&task_msg, now));
    }
}