use crate::common::model::CronConfig;
use crate::common::model::entity::{
account, module, platform, rel_account_platform, rel_module_account, rel_module_platform,
};
use crate::common::model::message::TaskEvent;
use crate::common::state::State;
use crate::engine::task::TaskManager;
use crate::engine::task::task_dispatch_adapter::build_task_dispatch;
use crate::queue::{QueueManager, QueuedItem};
use chrono::{DateTime, TimeZone, Utc};
use cron::Schedule;
use dashmap::DashMap;
use futures::StreamExt;
use log::{error, info, warn};
use sea_orm::prelude::Expr;
use sea_orm::{ColumnTrait, EntityTrait, JoinType, QueryFilter, QuerySelect, RelationTrait};
use std::collections::{HashSet, hash_map::DefaultHasher};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::time::{Duration, sleep};
use crate::sync::LeadershipGate;
use metrics::{counter, histogram};
use tokio::sync::broadcast;
/// RAII counter guard: bumps `counter` when created and decrements it again when dropped.
///
/// `process_module_contexts` holds one for its whole duration so that
/// `has_running_tasks` can tell whether any context job is still in flight.
struct ActiveCronJobGuard<'c> {
    counter: &'c AtomicU64,
}

impl<'c> ActiveCronJobGuard<'c> {
    /// Registers one more active job on `counter` and returns the guard that undoes it.
    fn new(counter: &'c AtomicU64) -> Self {
        counter.fetch_add(1, Ordering::Relaxed);
        Self { counter }
    }
}

impl Drop for ActiveCronJobGuard<'_> {
    fn drop(&mut self) {
        // Pairs with the increment in `new`; runs on every exit path of the owner.
        let _previous = self.counter.fetch_sub(1, Ordering::Relaxed);
    }
}
/// Cron scheduler that turns module cron configs into queued `TaskEvent`s.
///
/// Two background loops share this state (see `start`):
/// - a refresh loop that rebuilds the schedule from the database and the module registry;
/// - a once-per-second tick loop that dispatches due modules while this node holds leadership.
pub struct CronScheduler {
    /// Module registry; `refresh_cache` reads each module's name, version and `cron()` config.
    task_manager: Arc<TaskManager>,
    /// Shared application state: live config, cache service (locks, version key) and DB.
    state: Arc<State>,
    /// Provides the task push channel and the namespace used when building dispatches.
    queue_manager: Arc<QueueManager>,
    /// Module name -> (parsed schedule, enabled `(account, platform)` contexts).
    /// Rebuilt by `refresh_cache`, scanned on every tick by `process_tick`.
    schedule_cache: DashMap<String, (Arc<Schedule>, Arc<Vec<(String, String)>>)>,
    /// Memoized `module.cron()` result per module name (`None` when the module returned none).
    cron_config_cache: DashMap<String, Option<CronConfig>>,
    /// Hash of the context set last dispatched immediately for `right_now` /
    /// `run_now_and_schedule` modules; a module is re-dispatched only when this changes.
    right_now_context_hash: DashMap<String, u64>,
    /// Template receiver; each loop calls `resubscribe` on it to obtain its own receiver.
    shutdown_rx: broadcast::Receiver<()>,
    /// Remote config version applied by the last successful full refresh (0 = none seen).
    last_version: AtomicU64,
    /// Hash of all module `(name, version)` pairs at the last successful full refresh.
    last_module_hash: AtomicU64,
    /// Unix time in milliseconds of the last successful full refresh (0 = never).
    last_refresh_at_ms: AtomicU64,
    /// Number of `process_module_contexts` calls currently running.
    active_context_jobs: AtomicU64,
    /// Leader election; only the leader processes cron ticks in `run`.
    leadership_gate: Arc<dyn LeadershipGate>,
}
/// Dependencies required to build a `CronScheduler` via `CronScheduler::new`.
pub struct CronSchedulerConfig {
    /// Module registry the scheduler reads modules and their cron configs from.
    pub task_manager: Arc<TaskManager>,
    /// Shared application state (config, cache service, database).
    pub state: Arc<State>,
    /// Queue manager whose task channel receives scheduled dispatches.
    pub queue_manager: Arc<QueueManager>,
    /// Shutdown broadcast; both scheduler loops exit once it signals.
    pub shutdown_rx: broadcast::Receiver<()>,
    /// Leader election gate; the scheduler's tick loop calls `start` on it.
    pub leadership_gate: Arc<dyn LeadershipGate>,
}
impl CronScheduler {
pub async fn new(config: CronSchedulerConfig) -> Self {
Self {
task_manager: config.task_manager,
state: config.state,
queue_manager: config.queue_manager,
schedule_cache: DashMap::new(),
cron_config_cache: DashMap::new(),
right_now_context_hash: DashMap::new(),
shutdown_rx: config.shutdown_rx,
last_version: AtomicU64::new(0),
last_module_hash: AtomicU64::new(0),
last_refresh_at_ms: AtomicU64::new(0),
active_context_jobs: AtomicU64::new(0),
leadership_gate: config.leadership_gate,
}
}
/// Reports whether at least one `process_module_contexts` call is currently in flight.
pub fn has_running_tasks(&self) -> bool {
    let in_flight = self.active_context_jobs.load(Ordering::Relaxed);
    in_flight != 0
}
/// Maximum gap, in seconds, that the tick loop will replay after missed ticks.
///
/// Reads `scheduler.misfire_tolerance_secs` from the live config and falls back to 300
/// when the scheduler section or the field is absent.
async fn get_misfire_tolerance(&self) -> i64 {
    let cfg = self.state.config.read().await;
    let configured = cfg
        .scheduler
        .as_ref()
        .and_then(|sched| sched.misfire_tolerance_secs);
    configured.unwrap_or(300)
}
pub fn start(self: Arc<Self>) {
let this = self.clone();
tokio::spawn(async move {
this.refresh_loop().await;
});
tokio::spawn(async move {
self.run().await;
});
}
/// Calls `refresh_cache` periodically until shutdown is signalled.
///
/// It refreshes once immediately, then waits `scheduler.refresh_interval_secs` between
/// refreshes. That value defaults to 60 and is re-read from the config on every iteration.
async fn refresh_loop(&self) {
    info!("CronScheduler refresh loop started");
    let mut shutdown = self.shutdown_rx.resubscribe();
    loop {
        self.refresh_cache().await;
        // Scoped so the config read lock is released before sleeping.
        let wait_secs = {
            let cfg = self.state.config.read().await;
            cfg.scheduler
                .as_ref()
                .and_then(|sched| sched.refresh_interval_secs)
                .unwrap_or(60)
        };
        let stop_requested = tokio::select! {
            _ = shutdown.recv() => true,
            _ = sleep(Duration::from_secs(wait_secs)) => false,
        };
        if stop_requested {
            info!("CronScheduler refresh loop received shutdown signal");
            break;
        }
    }
}
/// Rebuilds the in-memory schedule when the configuration or the module set changed.
///
/// A full refresh is skipped only when all of the following hold:
/// - the remote config version (`[namespace:]scheduler:config_version` in the cache
///   service) is set and equals the last applied version;
/// - the registered modules (name + version) are unchanged;
/// - the last successful refresh is no older than `scheduler.max_staleness_secs`
///   (default 120).
///
/// A full refresh runs these steps:
/// 1. Load every enabled (module, account, platform) context.
/// 2. Resolve each module's `CronConfig`, memoized in `cron_config_cache`.
/// 3. Immediately dispatch `right_now` / `run_now_and_schedule` modules whose context set
///    changed since the last dispatch.
/// 4. Store periodic schedules in `schedule_cache`.
/// 5. Prune cache entries for modules that are no longer registered.
///
/// A database error is logged and leaves all caches and version markers untouched, so the
/// next call retries.
async fn refresh_cache(&self) {
    let namespace = self.state.cache_service.namespace();
    let redis_version_key = if namespace.is_empty() {
        "scheduler:config_version".to_string()
    } else {
        format!("{namespace}:scheduler:config_version")
    };
    // A missing key, a cache error or an unparsable value all read as version 0,
    // which never satisfies the "unchanged" check below.
    let remote_version_bytes = self
        .state
        .cache_service
        .get(&redis_version_key)
        .await
        .ok()
        .flatten();
    let remote_version: u64 = if let Some(bytes) = remote_version_bytes {
        String::from_utf8(bytes)
            .ok()
            .and_then(|s| s.parse().ok())
            .unwrap_or(0)
    } else {
        0
    };
    let local_version = self.last_version.load(Ordering::Relaxed);
    let modules = self.task_manager.get_all_modules().await;
    let module_signatures: Vec<(String, i32)> = modules
        .iter()
        .map(|m| (m.name().to_string(), m.version()))
        .collect();
    let module_names: Vec<String> = module_signatures
        .iter()
        .map(|(name, _)| name.clone())
        .collect();
    let module_hash = Self::hash_module_signatures(&module_signatures);
    let local_module_hash = self.last_module_hash.load(Ordering::Relaxed);
    let now_ms = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis() as u64;
    let last_refresh_at_ms = self.last_refresh_at_ms.load(Ordering::Relaxed);
    let max_staleness_secs = self
        .state
        .config
        .read()
        .await
        .scheduler
        .as_ref()
        .and_then(|s| s.max_staleness_secs)
        .unwrap_or(120);
    let staleness_exceeded =
        Self::staleness_exceeded(now_ms, last_refresh_at_ms, max_staleness_secs);
    if !staleness_exceeded
        && remote_version > 0
        && remote_version == local_version
        && module_hash == local_module_hash
    {
        return;
    }
    // Module names/versions changed: memoized cron configs may be outdated.
    if module_hash != local_module_hash {
        self.cron_config_cache.clear();
    }
    let module_set: HashSet<String> = module_names.iter().cloned().collect();
    let start = std::time::Instant::now();
    match self.fetch_all_enabled_contexts().await {
        Ok(contexts) => {
            let fetch_duration = start.elapsed();
            let context_count = contexts.len();
            let mut context_map: std::collections::HashMap<String, Vec<(String, String)>> =
                std::collections::HashMap::new();
            for (m, a, p) in contexts {
                context_map.entry(m).or_default().push((a, p));
            }
            for (module, name) in modules.into_iter().zip(module_names.iter()) {
                // `module.cron()` is memoized per module name.
                let cron_config = if let Some(entry) = self.cron_config_cache.get(name) {
                    entry.value().clone()
                } else {
                    let config = module.cron();
                    self.cron_config_cache.insert(name.clone(), config.clone());
                    config
                };
                if let Some(cron_config) = cron_config {
                    if !cron_config.enable {
                        self.schedule_cache.remove(name);
                        self.right_now_context_hash.remove(name);
                        continue;
                    }
                    let contexts = context_map.remove(name).unwrap_or_default();
                    // Immediate dispatch happens only when the module's context set differs
                    // from the one last dispatched (hash kept in `right_now_context_hash`).
                    // NOTE(review): unlike the tick loop in `run`, this path does not check
                    // `leadership_gate.is_leader()`; confirm every node is meant to fire it.
                    if cron_config.right_now || cron_config.run_now_and_schedule {
                        if contexts.is_empty() {
                            self.schedule_cache.remove(name);
                            self.right_now_context_hash.remove(name);
                            continue;
                        }
                        let context_hash = Self::hash_contexts(&contexts);
                        let last_hash = self
                            .right_now_context_hash
                            .get(name)
                            .map(|entry| *entry.value());
                        if last_hash != Some(context_hash) {
                            self.right_now_context_hash
                                .insert(name.clone(), context_hash);
                            let now = Utc::now();
                            self.process_module_contexts(name, &contexts, now).await;
                        }
                        // Pure `right_now` modules are one-shot: never periodically scheduled.
                        if cron_config.right_now {
                            self.schedule_cache.remove(name);
                            continue;
                        }
                    }
                    if !contexts.is_empty() {
                        let schedule = Arc::new(cron_config.schedule.clone());
                        self.schedule_cache
                            .insert(name.clone(), (schedule, Arc::new(contexts)));
                    } else {
                        self.schedule_cache.remove(name);
                    }
                } else {
                    self.schedule_cache.remove(name);
                }
            }
            // Drop entries for modules that are no longer registered.
            Self::remove_stale_keys(&self.schedule_cache, &module_set);
            Self::remove_stale_keys(&self.cron_config_cache, &module_set);
            Self::remove_stale_keys(&self.right_now_context_hash, &module_set);
            if remote_version > 0 {
                self.last_version.store(remote_version, Ordering::Relaxed);
            }
            self.last_module_hash.store(module_hash, Ordering::Relaxed);
            self.last_refresh_at_ms.store(now_ms, Ordering::Relaxed);
            // Measure once so that Fetch + Process adds up exactly to the logged total.
            let total_duration = start.elapsed();
            let process_duration = total_duration - fetch_duration;
            info!(
                "CronScheduler cache refreshed in {:?}. Fetch: {:?}, Process: {:?}. Total contexts: {}. Active scheduled modules: {}",
                total_duration,
                fetch_duration,
                process_duration,
                context_count,
                self.schedule_cache.len()
            );
        }
        Err(e) => {
            error!("Failed to refresh cron contexts: {}", e);
        }
    }
}
/// Order-independent fingerprint of the registered modules' `(name, version)` pairs.
///
/// The pairs are sorted by name, then version, before hashing, so any permutation of
/// the same set yields the same value. `DefaultHasher` output is only meaningful within
/// one process, which matches its use as an in-memory change detector.
fn hash_module_signatures(signatures: &[(String, i32)]) -> u64 {
    let mut ordered: Vec<&(String, i32)> = signatures.iter().collect();
    ordered.sort();
    ordered
        .into_iter()
        .fold(DefaultHasher::new(), |mut state, (name, version)| {
            name.hash(&mut state);
            version.hash(&mut state);
            state
        })
        .finish()
}
/// Order-independent fingerprint of a module's `(account, platform)` contexts.
///
/// Used to detect whether a `right_now` / `run_now_and_schedule` module's context set
/// changed since it was last dispatched.
fn hash_contexts(contexts: &[(String, String)]) -> u64 {
    let mut pairs: Vec<&(String, String)> = contexts.iter().collect();
    // Equal pairs are identical, so an unstable sort yields the same hash input.
    pairs.sort_unstable();
    let mut hasher = DefaultHasher::new();
    for (acct, plat) in pairs {
        acct.hash(&mut hasher);
        plat.hash(&mut hasher);
    }
    hasher.finish()
}
/// Removes every entry of `map` whose key is not in `allowed`.
///
/// Stale keys are collected first and removed afterwards, so no iterator guard over
/// `map` is alive when `remove` is called.
fn remove_stale_keys<V>(map: &DashMap<String, V>, allowed: &HashSet<String>) {
    let doomed: Vec<String> = map
        .iter()
        .filter(|entry| !allowed.contains(entry.key()))
        .map(|entry| entry.key().clone())
        .collect();
    for key in &doomed {
        map.remove(key);
    }
}
async fn run(self: Arc<Self>) {
{
let leadership_gate = self.leadership_gate.clone();
tokio::spawn(async move {
leadership_gate.start().await;
});
}
info!("CronScheduler started (Leader Election Mode)");
let mut last_tick: Option<DateTime<Utc>> = None;
loop {
let now = Utc::now();
let current_second_ts = now.timestamp();
if let Some(current_second) = Utc.timestamp_opt(current_second_ts, 0).single() {
if self.leadership_gate.is_leader() {
if let Some(last_run) = last_tick {
let diff = current_second.signed_duration_since(last_run).num_seconds();
if diff > 1 {
let misfire_tolerance = self.get_misfire_tolerance().await;
if diff <= misfire_tolerance {
info!(
"Detected missed ticks. Catching up from {} to {}",
last_run, current_second
);
let mut cursor = last_run + chrono::Duration::seconds(1);
while cursor <= current_second {
self.clone().process_tick(cursor).await;
cursor += chrono::Duration::seconds(1);
}
last_tick = Some(current_second);
} else {
warn!(
"Missed ticks gap ({}) exceeds tolerance ({}). Skipping catch-up, setting last_tick to now.",
diff, misfire_tolerance
);
last_tick = Some(current_second);
self.clone().process_tick(current_second).await;
}
} else if diff > 0 {
last_tick = Some(current_second);
self.clone().process_tick(current_second).await;
}
} else {
last_tick = Some(current_second);
self.clone().process_tick(current_second).await;
}
} else {
last_tick = Some(current_second);
}
}
let now = Utc::now();
let next_second = now.timestamp() + 1;
let sleep_secs = (next_second - now.timestamp()).max(0) as u64;
let sleep_duration = std::time::Duration::from_secs(sleep_secs);
let mut shutdown = self.shutdown_rx.resubscribe();
tokio::select! {
_ = shutdown.recv() => {
info!("CronScheduler main loop received shutdown signal");
break;
}
_ = sleep(sleep_duration + Duration::from_millis(10)) => {}
}
}
}
async fn process_tick(self: Arc<Self>, current_tick: DateTime<Utc>) {
let mut tasks = Vec::new();
for r in self.schedule_cache.iter() {
let (module_name, (schedule, contexts)) = r.pair();
if Self::is_schedule_match(schedule, current_tick) {
tasks.push((module_name.clone(), contexts.clone()));
}
}
let start = std::time::Instant::now();
let mut total_triggered = 0;
for (module_name, contexts) in tasks {
let this = self.clone();
total_triggered += contexts.len();
tokio::spawn(async move {
this.process_module_contexts(&module_name, &contexts, current_tick)
.await;
});
}
let duration = start.elapsed().as_secs_f64();
histogram!("mocra_scheduler_tick_duration_seconds").record(duration);
if total_triggered > 0 {
info!(
"Scheduler tick processed {} potential tasks in {:.4}s",
total_triggered, duration
);
}
}
/// Fans one firing of `module_name` out to every `(account, platform)` context.
///
/// Contexts are handled in chunks of 500. For each chunk, a single `set_nx_batch` call tries
/// to claim one lock key per context:
/// `[namespace:]cron:{module}:{account}:{platform}:{unix_ts}`, value `b"1"`, TTL 600 s.
/// Only contexts whose key was claimed by this call are passed to `trigger_single_task`.
///
/// At most `ceil(scheduler.concurrency / 500)` chunks are in flight at once; concurrency
/// defaults to 100. A failed lock call is logged and its chunk skipped. The active-job
/// counter is held for the whole call, which is what `has_running_tasks` observes.
async fn process_module_contexts(
    &self,
    module_name: &str,
    contexts: &[(String, String)],
    current_tick: DateTime<Utc>,
) {
    let _active_guard = ActiveCronJobGuard::new(&self.active_context_jobs);
    let concurrency = self
        .state
        .config
        .read()
        .await
        .scheduler
        .as_ref()
        .and_then(|s| s.concurrency)
        .unwrap_or(100);
    const BATCH_SIZE: usize = 500;
    let max_in_flight = concurrency.div_ceil(BATCH_SIZE);
    let timestamp = current_tick.timestamp();
    // Shared prefix of every lock key; the namespace segment is omitted when empty.
    let lock_prefix = {
        let ns = self.state.cache_service.namespace();
        if ns.is_empty() {
            "cron".to_string()
        } else {
            format!("{ns}:cron")
        }
    };
    let lock_prefix = lock_prefix.as_str();

    futures::stream::iter(contexts.chunks(BATCH_SIZE))
        .for_each_concurrent(Some(max_in_flight), |chunk| async move {
            let keys: Vec<String> = chunk
                .iter()
                .map(|(account, platform)| {
                    format!("{lock_prefix}:{module_name}:{account}:{platform}:{timestamp}")
                })
                .collect();
            let key_refs: Vec<&str> = keys.iter().map(String::as_str).collect();
            let lock_started = std::time::Instant::now();
            let outcome = self
                .state
                .cache_service
                .set_nx_batch(&key_refs, b"1", Some(Duration::from_secs(600)))
                .await;
            let acquired = match outcome {
                Ok(acquired) => acquired,
                Err(e) => {
                    error!("Failed to acquire batch locks for cron task: {}", e);
                    return;
                }
            };
            histogram!("mocra_scheduler_lock_acquisition_seconds")
                .record(lock_started.elapsed().as_secs_f64());
            // Lock results come back in key order, i.e. in the chunk's order.
            for (won, (account, platform)) in acquired.into_iter().zip(chunk) {
                if !won {
                    continue;
                }
                info!(
                    "Triggering cron task for module: {} [{}@{}] at {}",
                    module_name, account, platform, current_tick
                );
                self.trigger_single_task(module_name, account, platform)
                    .await;
            }
        })
        .await;
}
/// Returns `true` when `target` is itself an occurrence of `schedule`.
///
/// This asks for the first occurrence after `target - 1s` and checks that it is
/// exactly `target`.
fn is_schedule_match(schedule: &Schedule, target: DateTime<Utc>) -> bool {
    let one_second_before = target - chrono::Duration::seconds(1);
    schedule.after(&one_second_before).next() == Some(target)
}
async fn trigger_single_task(&self, module_name: &str, account: &str, platform: &str) {
counter!("mocra_scheduled_tasks_total", "module" => module_name.to_string()).increment(1);
let task = TaskEvent {
account: account.to_string(),
platform: platform.to_string(),
module: Some(vec![module_name.to_string()]),
run_id: uuid::Uuid::now_v7(),
priority: crate::common::model::Priority::Normal,
};
let sender = self.queue_manager.get_task_push_channel();
let dispatch = match build_task_dispatch(&task, self.queue_manager.namespace.clone()) {
Ok(dispatch) => dispatch,
Err(e) => {
counter!("mocra_scheduler_task_drops_total", "module" => module_name.to_string(), "reason" => "dispatch_encode_failed").increment(1);
error!(
"Failed to build task dispatch for module {} [{}@{}]: {}",
module_name, account, platform, e
);
return;
}
};
match sender.try_send(QueuedItem::new(dispatch.clone())) {
Ok(()) => {}
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
counter!("mocra_scheduler_task_drops_total", "module" => module_name.to_string(), "reason" => "channel_full").increment(1);
warn!(
"Task queue full, blocking send for module {} [{}@{}]",
module_name, account, platform
);
if let Err(e) = sender.send(QueuedItem::new(dispatch)).await {
counter!("mocra_scheduler_task_drops_total", "module" => module_name.to_string(), "reason" => "send_failed").increment(1);
error!(
"Failed to push cron task to queue for module {} [{}@{}]: {}",
module_name, account, platform, e
);
}
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
counter!("mocra_scheduler_task_drops_total", "module" => module_name.to_string(), "reason" => "channel_closed").increment(1);
error!(
"Task queue channel closed for module {} [{}@{}]",
module_name, account, platform
);
}
}
}
/// Loads every enabled `(module name, account name, platform name)` triple.
///
/// The query inner-joins module → rel_module_account → account → rel_account_platform →
/// platform → rel_module_platform. The last hop is reached through the platform, so the
/// explicit `rel_module_platform.module_id = module.id` predicate pins that link row to the
/// row's own module. This presumes the relation itself joins only on the platform id;
/// verify against the relation definition. The module and all three link rows must have
/// `enabled = true`.
///
/// # Errors
/// Propagates any `sea_orm::DbErr` raised while executing the query.
async fn fetch_all_enabled_contexts(
    &self,
) -> Result<Vec<(String, String, String)>, sea_orm::DbErr> {
    let platform_link_matches_module = Expr::col((
        rel_module_platform::Entity,
        rel_module_platform::Column::ModuleId,
    ))
    .eq(Expr::col((module::Entity, module::Column::Id)));

    let rows: Vec<(String, String, String)> = module::Entity::find()
        .join(JoinType::InnerJoin, module::Relation::RelModuleAccount.def())
        .join(JoinType::InnerJoin, rel_module_account::Relation::Account.def())
        .join(JoinType::InnerJoin, account::Relation::RelAccountPlatform.def())
        .join(JoinType::InnerJoin, rel_account_platform::Relation::Platform.def())
        .join(JoinType::InnerJoin, platform::Relation::RelModulePlatform.def())
        .filter(platform_link_matches_module)
        .filter(rel_module_account::Column::Enabled.eq(true))
        .filter(rel_account_platform::Column::Enabled.eq(true))
        .filter(rel_module_platform::Column::Enabled.eq(true))
        .filter(module::Column::Enabled.eq(true))
        .select_only()
        .column(module::Column::Name)
        .column(account::Column::Name)
        .column(platform::Column::Name)
        .into_tuple()
        .all(&*self.state.db)
        .await?;
    Ok(rows)
}
/// Returns `true` when the last refresh is strictly older than `max_staleness_secs`.
///
/// A `last_refresh_at_ms` of 0 means "never refreshed" and is never considered stale.
/// Both the age and the limit saturate: a clock that stepped backwards yields age 0,
/// and a huge limit caps at `u64::MAX` ms.
fn staleness_exceeded(now_ms: u64, last_refresh_at_ms: u64, max_staleness_secs: u64) -> bool {
    match last_refresh_at_ms {
        0 => false,
        last => {
            let age_ms = now_ms.saturating_sub(last);
            let limit_ms = max_staleness_secs.saturating_mul(1000);
            age_ms > limit_ms
        }
    }
}
}
#[cfg(test)]
mod staleness_tests {
    use super::CronScheduler;

    /// Table of `(last_refresh_at_ms, max_staleness_secs, expected)` evaluated at a fixed "now".
    #[test]
    fn test_staleness_exceeded() {
        const NOW_MS: u64 = 10_000;
        let cases = [(0, 120, false), (9_500, 1, false), (8_000, 1, true)];
        for (last_refresh_at_ms, max_staleness_secs, expected) in cases {
            assert_eq!(
                CronScheduler::staleness_exceeded(NOW_MS, last_refresh_at_ms, max_staleness_secs),
                expected
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::TimeZone;
    use std::str::FromStr;

    /// 2024-01-01 at the given UTC wall-clock time.
    fn jan_first(hour: u32, minute: u32, second: u32) -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, 1, hour, minute, second)
            .unwrap()
    }

    #[test]
    fn test_is_schedule_match() {
        let every_second = Schedule::from_str("* * * * * *").unwrap();
        for second in [0, 1] {
            assert!(CronScheduler::is_schedule_match(
                &every_second,
                jan_first(0, 0, second)
            ));
        }
    }

    #[test]
    fn test_specific_schedule_match() {
        let at_minute_five = Schedule::from_str("0 5 * * * *").unwrap();
        assert!(CronScheduler::is_schedule_match(
            &at_minute_five,
            jan_first(10, 5, 0)
        ));
        assert!(!CronScheduler::is_schedule_match(
            &at_minute_five,
            jan_first(10, 6, 0)
        ));
    }
}