use crate::{
storage::{Keys, RedisClient, dependencies},
aggregator::AggregatorManager,
Error, Result, Task,
};
use chrono::Utc;
use fred::prelude::{RedisKey, RedisValue};
use rmp_serde;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
/// Fallback batch size for task aggregation.
/// NOTE(review): not referenced anywhere in this file — confirm it is used
/// elsewhere in the crate or remove as dead code.
const DEFAULT_AGGREGATION_SIZE: usize = 10;
/// Background scheduler that services a fixed set of queues: it promotes
/// retry/delayed tasks whose deadlines have passed, instantiates due cron
/// tasks, and flushes aggregation groups that reached their size limit.
pub struct Scheduler {
    // Redis client used for all queue/zset/hash operations.
    redis: RedisClient,
    // Names of the queues this scheduler is responsible for.
    queues: Vec<String>,
    // Cooperative stop flag, polled once per tick at the top of `run`.
    shutdown: Arc<AtomicBool>,
    // Optional aggregation manager; `check_aggregation` is a no-op when `None`.
    aggregator: Option<Arc<AggregatorManager>>,
}
impl Scheduler {
#[must_use]
/// Builds a scheduler for `queues` without aggregation support.
///
/// The shutdown flag starts cleared; call [`Scheduler::shutdown`] to stop
/// a running loop.
pub fn new(redis: RedisClient, queues: Vec<String>) -> Self {
    let shutdown = Arc::new(AtomicBool::new(false));
    Self {
        redis,
        queues,
        shutdown,
        aggregator: None,
    }
}
#[must_use]
/// Builds a scheduler for `queues` that additionally flushes aggregation
/// groups through `aggregator` on its periodic tick.
pub fn with_aggregator(redis: RedisClient, queues: Vec<String>, aggregator: Arc<AggregatorManager>) -> Self {
    let shutdown = Arc::new(AtomicBool::new(false));
    Self {
        redis,
        queues,
        shutdown,
        aggregator: Some(aggregator),
    }
}
/// Drives the scheduler loop until the shutdown flag is set.
///
/// Each tick is ~1 second. Per tick: retry promotion runs every tick,
/// aggregation every 2 ticks, delayed promotion every 5, cron every 60.
/// Errors from the individual checks are logged and never abort the loop.
///
/// Consumes `self`; the loop only ends once [`Scheduler::shutdown`] has
/// been called (the flag is shared via `Arc`).
pub async fn run(self) -> Result<()> {
    tracing::info!("Scheduler started for queues: {:?}", self.queues);
    let mut ticks: u64 = 0;
    while !self.shutdown.load(Ordering::Relaxed) {
        ticks += 1;
        // Retry promotion runs on every tick.
        if let Err(e) = self.check_retry_tasks().await {
            tracing::error!("Retry check error: {}", e);
        }
        if ticks % 5 == 0 {
            if let Err(e) = self.check_delayed_tasks().await {
                tracing::error!("Delayed check error: {}", e);
            }
        }
        if ticks % 60 == 0 {
            if let Err(e) = self.check_cron_tasks().await {
                tracing::error!("Cron check error: {}", e);
            }
        }
        if ticks % 2 == 0 {
            if let Err(e) = self.check_aggregation().await {
                tracing::error!("Aggregation check error: {}", e);
            }
        }
        tokio::time::sleep(Duration::from_secs(1)).await;
    }
    tracing::info!("Scheduler stopped after {} ticks", ticks);
    Ok(())
}
/// Signals the run loop to stop; takes effect at the top of the next tick
/// (up to ~1s later, after the current sleep completes).
///
/// NOTE(review): `run` consumes `self`, so a caller cannot hold `&self` to
/// invoke this while `run` is executing — confirm the intended shutdown
/// pathway (e.g. exposing a clone of the flag) with the crate's public API.
pub fn shutdown(&self) {
    self.shutdown.store(true, Ordering::Relaxed);
}
/// Moves retry tasks whose backoff deadline (zset score) has passed back
/// onto their queues.
///
/// # Errors
/// Currently never returns an error: per-queue failures are logged and
/// skipped so one bad queue cannot stall the others.
async fn check_retry_tasks(&self) -> Result<()> {
    let now = Utc::now().timestamp();
    for queue in &self.queues {
        let source: RedisKey = Keys::retry(queue).into();
        let dest: RedisKey = Keys::queue(queue).into();
        self.promote_expired(source, dest, queue, now, "retry").await;
    }
    Ok(())
}

/// Moves delayed tasks whose scheduled time has arrived onto their queues.
///
/// # Errors
/// Currently never returns an error: per-queue failures are logged and
/// skipped.
async fn check_delayed_tasks(&self) -> Result<()> {
    let now = Utc::now().timestamp();
    for queue in &self.queues {
        let source: RedisKey = Keys::delayed(queue).into();
        let dest: RedisKey = Keys::queue(queue).into();
        self.promote_expired(source, dest, queue, now, "delayed").await;
    }
    Ok(())
}

/// Shared worker for `check_retry_tasks` / `check_delayed_tasks`
/// (previously two copy-pasted loops): moves up to `BATCH_SIZE` entries
/// scored <= `now` from `source` to `dest` via the server-side Lua script.
/// Failures are logged, not propagated; `label` names the source set in
/// log output ("retry" / "delayed") so the rendered messages are unchanged.
async fn promote_expired(&self, source: RedisKey, dest: RedisKey, queue: &str, now: i64, label: &str) {
    const BATCH_SIZE: usize = 100;
    match self.redis.move_expired_tasks_lua(source, dest, now, BATCH_SIZE).await {
        Ok(count) if count > 0 => {
            tracing::debug!("Moved {} tasks from {} to queue {}", count, label, queue);
        }
        Ok(_) => {}
        Err(e) => {
            tracing::warn!("Failed to move {} tasks for queue {}: {}", label, queue, e);
        }
    }
}
/// Instantiates due cron tasks and reschedules their templates.
///
/// For each queue, reads template ids from the cron zset whose score (next
/// fire time) is <= now, builds a fresh one-shot `Task` from the stored
/// template, pushes its id onto the queue, and re-adds the template id with
/// the next computed fire time.
///
/// # Errors
/// Propagates Redis and (de)serialization failures; a failure on one task
/// aborts processing of the remaining due tasks for this tick.
async fn check_cron_tasks(&self) -> Result<()> {
    let now = Utc::now().timestamp();
    for queue in &self.queues {
        let cron_key: RedisKey = Keys::cron_queue(queue).into();
        let queue_key: RedisKey = Keys::queue(queue).into();
        // All templates due at or before `now` (score range [0, now]).
        let task_ids = self.redis.zrangebyscore(cron_key.clone(), 0, now).await?;
        for task_id in task_ids {
            let task_key: RedisKey = Keys::task(&task_id).into();
            // NOTE(review): template bodies are read with GET here, while
            // `check_aggregation` stores/reads task bodies via HSET/HGET
            // field "data" — confirm the two storage schemas are intentional.
            // If the key is missing, the id stays in the cron zset with a
            // past score and is re-fetched (and skipped) every cron tick.
            if let Some(data) = self.redis.get(task_key).await? {
                let bytes = data.as_bytes()
                    .ok_or_else(|| Error::Serialization("Task data is not bytes".into()))?;
                let cron_task: Task = rmp_serde::from_slice(bytes)
                    .map_err(|e| Error::Serialization(e.to_string()))?;
                let cron_expr = cron_task.options.cron.clone()
                    .ok_or_else(|| Error::Validation("Cron task missing cron expression".into()))?;
                // Remove the template from the zset now; it is re-added
                // below either with a retry score (build failure) or with
                // the next fire time.
                self.redis.zrem(cron_key.clone(), task_id.as_str().into()).await?;
                // Materialize a one-shot instance copying the template's
                // retry/timeout/priority options and raw payload.
                let new_task = match Task::builder(cron_task.task_type.clone())
                    .queue(queue.clone())
                    .max_retry(cron_task.options.max_retry)
                    .timeout(cron_task.options.timeout)
                    .priority(cron_task.options.priority)
                    .raw_payload(cron_task.payload.clone())
                    .build()
                {
                    Ok(task) => task,
                    Err(e) => {
                        tracing::error!("Failed to create cron task instance for {}: {}", task_id, e);
                        // Re-schedule the template 60s out and try again.
                        self.redis.zadd(cron_key.clone(), task_id.as_str().into(), now + 60).await?;
                        continue;
                    }
                };
                // Persist the instance body, then enqueue its id.
                let new_task_key: RedisKey = Keys::task(&new_task.id).into();
                let new_task_data = rmp_serde::to_vec(&new_task)
                    .map_err(|e| Error::Serialization(e.to_string()))?;
                self.redis.set(new_task_key, RedisValue::Bytes(new_task_data.into())).await?;
                self.redis.rpush(queue_key.clone(), new_task.id.as_str().into()).await?;
                tracing::debug!("Cron task {} instantiated and queued", task_id);
                if let Some(next_time) = self.calculate_next_cron_time(&cron_expr, now) {
                    self.redis.zadd(cron_key.clone(), task_id.as_str().into(), next_time).await?;
                    tracing::debug!("Cron task {} rescheduled for {}", task_id, next_time);
                } else {
                    // NOTE(review): the template was already zrem'd above, so
                    // hitting this branch permanently unschedules the cron
                    // task (only a warning is emitted) — confirm this is the
                    // desired behavior for unparsable/exhausted expressions.
                    tracing::warn!("Could not calculate next time for cron task {}", task_id);
                }
            }
        }
    }
    Ok(())
}
/// Computes the next fire time (unix seconds, UTC) of `cron_expr` strictly
/// after `from_timestamp`.
///
/// Returns `None` when the expression does not parse, the timestamp is out
/// of range, or the schedule has no future occurrence.
fn calculate_next_cron_time(&self, cron_expr: &str, from_timestamp: i64) -> Option<i64> {
    use cron::Schedule;
    let schedule = Schedule::try_from(cron_expr).ok()?;
    // `DateTime::from_timestamp` yields a `DateTime<Utc>`.
    let from_datetime = chrono::DateTime::from_timestamp(from_timestamp, 0)?;
    // BUG FIX: the previous code called `schedule.upcoming(tz)`, which
    // iterates from the CURRENT wall-clock time and silently ignored
    // `from_timestamp` (it was only used to derive the `Utc` marker).
    // `after` anchors the iteration at the requested instant.
    schedule.after(&from_datetime).next().map(|dt| dt.timestamp())
}
async fn check_aggregation(&self) -> Result<()> {
let aggregator = match &self.aggregator {
Some(a) => a,
None => return Ok(()),
};
let now = Utc::now().timestamp();
let group_pattern = "rediq:meta:group:*";
let (mut cursor, keys) = self.redis.scan_match(0, group_pattern, 100).await?;
for meta_key in keys {
let group_name = meta_key.strip_prefix("rediq:meta:group:").unwrap_or(&meta_key);
let count: i64 = match self.redis.hget(meta_key.clone().into(), "count".into()).await? {
Some(v) => v.as_string().and_then(|s| s.parse().ok()).unwrap_or(0),
None => continue,
};
let config = aggregator.default_config();
let should_aggregate = count as usize >= config.max_size;
if should_aggregate {
let group_key: RedisKey = Keys::group(group_name).into();
let task_ids = self.redis.zrange(group_key.clone(), 0, -1).await?;
if task_ids.is_empty() {
continue;
}
let mut tasks = Vec::new();
for task_id in &task_ids {
let task_key: RedisKey = Keys::task(task_id).into();
if let Some(data) = self.redis.hget(task_key.clone(), "data".into()).await? {
if let Some(bytes) = data.as_bytes() {
if let Ok(task) = rmp_serde::from_slice::<Task>(bytes) {
tasks.push(task);
}
}
}
}
let aggregated_task = if let Some(agg) = aggregator.get(group_name) {
agg.aggregate(group_name, tasks)?
} else {
continue;
};
for task_id in &task_ids {
self.redis.zrem(group_key.clone(), task_id.as_str().into()).await?;
}
self.redis.hset(
meta_key.clone().into(),
vec![("count".into(), "0".into())],
).await?;
if let Some(new_task) = aggregated_task {
let new_task_key: RedisKey = Keys::task(&new_task.id).into();
let new_task_data = rmp_serde::to_vec(&new_task)
.map_err(|e| Error::Serialization(e.to_string()))?;
self.redis.hset(
new_task_key.clone(),
vec![
("data".into(), RedisValue::Bytes(new_task_data.into())),
("queue".into(), new_task.queue.as_str().into()),
],
).await?;
let queue_key: RedisKey = Keys::queue(&new_task.queue).into();
self.redis.rpush(queue_key, new_task.id.as_str().into()).await?;
tracing::debug!("Aggregated {} tasks from group {} into task {}",
task_ids.len(), group_name, new_task.id);
}
}
}
while cursor != 0 {
let (next_cursor, more_keys) = self.redis.scan_match(cursor, group_pattern, 100).await?;
cursor = next_cursor;
for _key in more_keys {
}
}
Ok(())
}
/// Registers `task`'s `depends_on` list with the dependency tracker.
///
/// No-op (returns `Ok`) when the task has no dependencies — i.e. the list
/// is absent or empty.
pub async fn register_dependencies(&self, task: &Task) -> Result<()> {
    let Some(deps) = task.options.depends_on.clone().filter(|d| !d.is_empty()) else {
        return Ok(());
    };
    dependencies::register(&self.redis, &task.id, &deps).await
}
/// Runs the dependency check for tasks waiting on `completed_task_id`,
/// delegating to `dependencies::check_dependents` and discarding its
/// result value (only the error, if any, is propagated).
///
/// NOTE(review): presumably this releases tasks whose dependencies are now
/// all satisfied — confirm against the `dependencies` module.
pub async fn check_dependent_tasks(&self, completed_task_id: &str) -> Result<()> {
    dependencies::check_dependents(&self.redis, completed_task_id).await?;
    Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // Smoke test: a freshly constructed scheduler records its queue list.
    // Ignored by default because `RedisClient::from_url` needs a reachable
    // Redis server (override the target via the REDIS_URL env var).
    #[tokio::test]
    #[ignore = "Requires Redis server"]
    async fn test_scheduler_creation() {
        let redis_url = std::env::var("REDIS_URL")
            .unwrap_or_else(|_| "redis://localhost:6379".to_string());
        let redis = RedisClient::from_url(&redis_url)
            .await
            .unwrap();
        let scheduler = Scheduler::new(redis, vec!["default".to_string()]);
        assert_eq!(scheduler.queues.len(), 1);
    }
}