use crate::components::ComponentLifecycle;
use crate::error::Result;
use crate::scheduler::{PeriodicTask, Scheduler};
use async_trait::async_trait;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
/// Source of the periodic task configurations that a `PeriodicTaskManager` keeps registered.
///
/// The manager calls [`get_configs`](Self::get_configs) on every sync. It treats the returned
/// list as the complete desired set: registered tasks whose config is no longer returned
/// are unregistered.
#[async_trait]
pub trait PeriodicTaskConfigProvider: Sync + Send {
    /// Returns the full, current list of desired periodic task configs.
    ///
    /// # Errors
    ///
    /// Any error aborts the sync that requested the configs. The manager logs it and tries
    /// again on its next sync.
    async fn get_configs(&self) -> Result<Vec<PeriodicTaskConfig>>;
}
/// Declarative description of one periodic task: what to run, when, with which payload,
/// and on which queue.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct PeriodicTaskConfig {
    /// Task type name handed to the scheduler.
    pub task: String,
    /// Cron expression controlling when the task fires.
    pub cron: String,
    /// Opaque payload bytes passed along with each enqueued task.
    pub payload: Vec<u8>,
    /// Name of the queue the task is enqueued on.
    pub queue: String,
}
impl PeriodicTaskConfig {
    /// Creates a config from its parts.
    ///
    /// No validation happens here. Any problem with the values surfaces later from
    /// [`to_periodic_task`](Self::to_periodic_task).
    pub fn new(task: String, cron: String, payload: Vec<u8>, queue: String) -> Self {
        Self { task, cron, payload, queue }
    }

    /// Builds the scheduler's `PeriodicTask` from a copy of this config.
    ///
    /// # Errors
    ///
    /// Any error from `PeriodicTask::new` is converted into `Error::other`. The message is
    /// prefixed with "Failed to create PeriodicTask: ".
    pub fn to_periodic_task(&self) -> Result<PeriodicTask> {
        let Self { task, cron, payload, queue } = self;
        PeriodicTask::new(task.clone(), cron.clone(), payload.clone(), queue.clone())
            .map_err(|e| crate::error::Error::other(format!("Failed to create PeriodicTask: {e}")))
    }

    /// Returns the identity key that distinguishes configs during a sync.
    ///
    /// The key has the form `"{task}:{cron}:{queue}:{payload_hash}"`. `payload_hash` is the
    /// 64-bit `DefaultHasher` digest of the payload, so the payload bytes are not embedded.
    ///
    /// `DefaultHasher`'s algorithm is not guaranteed to stay the same across Rust releases.
    /// Compare keys in-process only; do not persist them.
    pub fn config_key(&self) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let payload_hash = {
            let mut hasher = DefaultHasher::new();
            Hash::hash(&self.payload, &mut hasher);
            hasher.finish()
        };
        format!("{}:{}:{}:{}", self.task, self.cron, self.queue, payload_hash)
    }
}
/// Tuning knobs for `PeriodicTaskManager`.
#[derive(Clone, Debug)]
pub struct PeriodicTaskManagerConfig {
    /// Time between two reconciliations of the scheduler with the config provider.
    pub sync_interval: std::time::Duration,
}

impl Default for PeriodicTaskManagerConfig {
    /// Syncs once per minute.
    fn default() -> Self {
        let sync_interval = std::time::Duration::new(60, 0);
        Self { sync_interval }
    }
}
/// Keeps the scheduler's periodic tasks in step with the configs returned by a
/// `PeriodicTaskConfigProvider`.
pub struct PeriodicTaskManager {
    /// Scheduler that the tasks are registered with and unregistered from.
    scheduler: Arc<Scheduler>,
    /// Sync-loop settings.
    config: PeriodicTaskManagerConfig,
    /// Source of the desired configs, queried on every sync.
    config_provider: Arc<dyn PeriodicTaskConfigProvider>,
    /// Maps each `PeriodicTaskConfig::config_key` to the entry id returned by
    /// `Scheduler::register` for it.
    registered_tasks: Arc<Mutex<HashMap<String, String>>>,
    /// Set once shutdown has been requested.
    done: Arc<AtomicBool>,
}
impl PeriodicTaskManager {
pub fn new(
scheduler: Arc<Scheduler>,
config: PeriodicTaskManagerConfig,
config_provider: Arc<dyn PeriodicTaskConfigProvider>,
) -> Self {
Self {
scheduler,
config,
config_provider,
registered_tasks: Arc::new(Mutex::new(HashMap::new())),
done: Arc::new(AtomicBool::new(false)),
}
}
pub fn start(self: Arc<Self>) -> JoinHandle<()> {
let scheduler = self.scheduler.clone();
tokio::spawn(async move {
scheduler.start().await;
});
tokio::spawn(async move {
let mut interval = tokio::time::interval(self.config.sync_interval);
loop {
interval.tick().await;
if self.done.load(Ordering::Relaxed) {
tracing::debug!("PeriodicTaskManager: shutting down");
break;
}
if let Err(e) = self.sync_tasks().await {
tracing::error!("PeriodicTaskManager sync error: {}", e);
}
}
})
}
async fn sync_tasks(&self) -> Result<()> {
let new_configs = self.config_provider.get_configs().await?;
let new_config_keys: HashSet<String> = new_configs.iter().map(|c| c.config_key()).collect();
let mut registered = self.registered_tasks.lock().await;
let current_keys: HashSet<String> = registered.keys().cloned().collect();
let to_add: Vec<_> = new_configs
.iter()
.filter(|c| !current_keys.contains(&c.config_key()))
.collect();
let to_remove: Vec<_> = current_keys.difference(&new_config_keys).cloned().collect();
for config_key in to_remove {
if let Some(entry_id) = registered.remove(&config_key) {
if let Err(e) = self.scheduler.unregister(&entry_id).await {
tracing::error!("Failed to unregister task {}: {}", config_key, e);
} else {
tracing::info!("PeriodicTaskManager: unregistered task {}", config_key);
}
}
}
for config in to_add {
let periodic_task = match config.to_periodic_task() {
Ok(task) => task,
Err(e) => {
tracing::error!("Failed to create PeriodicTask from config: {}", e);
continue;
}
};
match self.scheduler.register(periodic_task, &config.queue).await {
Ok(entry_id) => {
let config_key = config.config_key();
registered.insert(config_key.clone(), entry_id);
tracing::info!(
"PeriodicTaskManager: registered task {} with cron '{}'",
config_key,
config.cron
);
}
Err(e) => {
tracing::error!("Failed to register task {}: {}", config.task, e);
}
}
}
Ok(())
}
pub fn shutdown(&self) {
self.done.store(true, Ordering::Relaxed);
let scheduler = self.scheduler.clone();
tokio::spawn(async move {
scheduler.stop().await;
});
}
pub fn is_done(&self) -> bool {
self.done.load(Ordering::Relaxed)
}
}
/// Exposes the manager through the generic component lifecycle.
///
/// Each method forwards to the inherent method of the same name. Inherent associated items
/// take precedence over trait items in `Self::` paths, so these calls do not recurse.
impl ComponentLifecycle for PeriodicTaskManager {
    fn start(self: Arc<Self>) -> JoinHandle<()> {
        Self::start(self)
    }

    fn shutdown(&self) {
        Self::shutdown(self)
    }

    fn is_done(&self) -> bool {
        Self::is_done(self)
    }
}
#[cfg(feature = "default")]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::RedisConnectionType;
    use crate::client::Client;

    /// In-memory provider whose config list can be replaced between syncs.
    struct TestConfigProvider {
        configs: Arc<Mutex<Vec<PeriodicTaskConfig>>>,
    }

    impl TestConfigProvider {
        fn new(configs: Vec<PeriodicTaskConfig>) -> Self {
            Self {
                configs: Arc::new(Mutex::new(configs)),
            }
        }

        async fn set_configs(&self, configs: Vec<PeriodicTaskConfig>) {
            let mut c = self.configs.lock().await;
            *c = configs;
        }
    }

    #[async_trait]
    impl PeriodicTaskConfigProvider for TestConfigProvider {
        async fn get_configs(&self) -> Result<Vec<PeriodicTaskConfig>> {
            let configs = self.configs.lock().await;
            Ok(configs.clone())
        }
    }

    #[test]
    fn test_periodic_task_manager_config_default() {
        let config = PeriodicTaskManagerConfig::default();
        assert_eq!(config.sync_interval, std::time::Duration::from_secs(60));
    }

    #[test]
    fn test_periodic_task_config() {
        let config = PeriodicTaskConfig::new(
            "test:task".to_string(),
            "* * * * * *".to_string(),
            b"test payload".to_vec(),
            "default".to_string(),
        );
        assert_eq!(config.task, "test:task");
        assert_eq!(config.cron, "* * * * * *");
        assert_eq!(config.payload, b"test payload".to_vec());
        assert_eq!(config.queue, "default");
    }

    #[test]
    fn test_config_key_distinguishes_every_field() {
        let base = PeriodicTaskConfig::new(
            "test:task".to_string(),
            "* * * * * *".to_string(),
            b"payload".to_vec(),
            "default".to_string(),
        );
        // Equal configs must map to the same key, otherwise every sync would re-register them.
        assert_eq!(base.config_key(), base.clone().config_key());
        assert!(base.config_key().starts_with("test:task:* * * * * *:default:"));

        let variants = [
            PeriodicTaskConfig { task: "other:task".to_string(), ..base.clone() },
            PeriodicTaskConfig { cron: "0 * * * * *".to_string(), ..base.clone() },
            PeriodicTaskConfig { payload: b"other".to_vec(), ..base.clone() },
            PeriodicTaskConfig { queue: "critical".to_string(), ..base.clone() },
        ];
        for variant in &variants {
            assert_ne!(base.config_key(), variant.config_key(), "{variant:?}");
        }
    }

    #[tokio::test]
    #[ignore = "requires a Redis server at localhost:6379"]
    async fn test_periodic_task_manager_sync() {
        let redis_connection_config = RedisConnectionType::single("redis://localhost:6379").unwrap();
        let client = Arc::new(Client::new(redis_connection_config).await.unwrap());
        let scheduler = Arc::new(Scheduler::new(client, None).await.unwrap());
        let initial_configs = vec![PeriodicTaskConfig::new(
            "task1".to_string(),
            "* * * * * *".to_string(),
            b"payload1".to_vec(),
            "default".to_string(),
        )];
        let provider = Arc::new(TestConfigProvider::new(initial_configs));
        let config = PeriodicTaskManagerConfig::default();
        let manager = PeriodicTaskManager::new(scheduler, config, provider.clone());

        manager.sync_tasks().await.unwrap();
        assert_eq!(manager.registered_tasks.lock().await.len(), 1);

        // Re-syncing an unchanged config set must not register anything again.
        manager.sync_tasks().await.unwrap();
        assert_eq!(manager.registered_tasks.lock().await.len(), 1);

        // Removing the config from the provider unregisters the task on the next sync.
        provider.set_configs(vec![]).await;
        manager.sync_tasks().await.unwrap();
        assert!(manager.registered_tasks.lock().await.is_empty());
    }

    #[tokio::test]
    #[ignore = "requires a Redis server at localhost:6379"]
    async fn test_periodic_task_manager_shutdown() {
        let redis_connection_config = RedisConnectionType::single("redis://localhost:6379").unwrap();
        let client = Arc::new(Client::new(redis_connection_config).await.unwrap());
        let scheduler = Arc::new(Scheduler::new(client, None).await.unwrap());
        let provider = Arc::new(TestConfigProvider::new(vec![]));
        let config = PeriodicTaskManagerConfig::default();
        let manager = PeriodicTaskManager::new(scheduler, config, provider);
        assert!(!manager.is_done());
        manager.shutdown();
        assert!(manager.is_done());
    }
}