// persistent_scheduler/core/context.rs
use crate::core::cleaner::TaskCleaner;
use crate::core::cron::next_run;
use crate::core::handlers::TaskHandlers;
use crate::core::model::TaskMeta;
use crate::core::status_updater::TaskStatusUpdater;
use crate::core::store::TaskStore;
use crate::core::task::Task;
use crate::core::task_kind::TaskKind;
use crate::utc_now;
use ahash::AHashMap;
use std::sync::Arc;
use std::time::Duration;
use super::flow::TaskFlow;
/// Central scheduler context: owns the registered task handlers, the
/// per-queue worker concurrency settings, and a shared handle to the store.
pub struct TaskContext<S: TaskStore + Send + Sync + Clone + 'static> {
    /// Configured concurrency for each task queue, keyed by queue name.
    queue_concurrency: AHashMap<String, usize>,
    /// Registered handlers used to process the different task types.
    handlers: TaskHandlers,
    /// Task store behind an `Arc` so it can be handed to the flow, the status
    /// updater and the cleaner without copying the store itself.
    store: Arc<S>,
}
impl<S> TaskContext<S>
where
    S: TaskStore + Send + Sync + Clone + 'static,
{
    /// Concurrency given to a queue when a task is registered on it and no
    /// value has been configured for that queue yet.
    const DEFAULT_QUEUE_CONCURRENCY: usize = 4;

    /// Interval handed to the task cleaner when the context is started (10 minutes).
    const CLEANER_INTERVAL: Duration = Duration::from_secs(60 * 10);

    /// Creates a new, empty `TaskContext` backed by `store`.
    ///
    /// The store is wrapped in an `Arc` so it can later be shared with the
    /// task flow, the status updater and the cleaner.
    pub fn new(store: S) -> Self {
        Self {
            queue_concurrency: AHashMap::new(),
            handlers: TaskHandlers::new(),
            store: Arc::new(store),
        }
    }

    /// Registers the handler for task type `T`.
    ///
    /// The task's queue (`T::TASK_QUEUE`) receives `DEFAULT_QUEUE_CONCURRENCY`
    /// only if no concurrency is configured for it yet, so a value set earlier
    /// through `set_concurrency` is not silently overwritten.
    pub fn register<T>(mut self) -> Self
    where
        T: Task,
    {
        self.handlers.register::<T>();
        self.queue_concurrency
            .entry(T::TASK_QUEUE.to_owned())
            .or_insert(Self::DEFAULT_QUEUE_CONCURRENCY);
        self
    }

    /// Sets (or overrides) the concurrency level for `queue`.
    pub fn set_concurrency(mut self, queue: &str, count: usize) -> Self {
        self.queue_concurrency.insert(queue.to_owned(), count);
        self
    }

    /// Creates a `TaskCleaner` over the shared store and starts it with
    /// `CLEANER_INTERVAL`.
    fn start_task_cleaner(&self) {
        let cleaner = Arc::new(TaskCleaner::new(self.store.clone()));
        cleaner.start(Self::CLEANER_INTERVAL);
    }

    /// Builds the status updater and the task flow for all configured queues,
    /// then awaits `TaskFlow::start`.
    async fn start_flow(&self) {
        // The updater is sized by the number of configured queues.
        let status_updater = Arc::new(TaskStatusUpdater::new(
            self.store.clone(),
            self.queue_concurrency.len(),
        ));
        let flow = Arc::new(TaskFlow::new(
            self.store.clone(),
            &self.queue_concurrency,
            Arc::new(self.handlers.clone()),
            status_updater,
        ));
        flow.start().await;
    }

    /// Starts the task flow (workers) and then the task cleaner, returning the
    /// context so it can keep being used to submit tasks.
    pub async fn start(self) -> Self {
        self.start_flow().await;
        self.start_task_cleaner();
        self
    }

    /// Schedules a single task for execution and persists it in the store.
    ///
    /// `delay_seconds` overrides the task's own default delay for `Once` and
    /// `Repeat` tasks; it is ignored for `Cron` tasks, whose first run is the
    /// next occurrence of their schedule.
    ///
    /// # Errors
    /// Returns a message if a cron task has no schedule or timezone, if its
    /// next occurrence cannot be computed, or if the store rejects the task.
    pub async fn add_task<T>(&self, task: T, delay_seconds: Option<u32>) -> Result<(), String>
    where
        T: Task + Send + Sync + 'static,
    {
        let mut task_meta = task.new_meta();
        Self::schedule_first_run::<T>(&mut task_meta, delay_seconds)?;
        self.store
            .store_task(task_meta)
            .await
            .map_err(|e| e.to_string())
    }

    /// Schedules a batch of tasks of the same type and persists them with a
    /// single `store_tasks` call.
    ///
    /// Each entry's `delay_seconds` behaves as in `add_task`. If any task fails
    /// to schedule, nothing is stored and the first error is returned.
    ///
    /// # Errors
    /// Same conditions as `add_task`.
    pub async fn add_tasks<T>(&self, tasks: Vec<TaskAndDelay<T>>) -> Result<(), String>
    where
        T: Task + Send + Sync + 'static,
    {
        let mut batch: Vec<TaskMeta> = Vec::with_capacity(tasks.len());
        for task in tasks {
            let mut task_meta = task.inner.new_meta();
            Self::schedule_first_run::<T>(&mut task_meta, task.delay_seconds)?;
            batch.push(task_meta);
        }
        self.store
            .store_tasks(batch)
            .await
            .map_err(|e| e.to_string())
    }

    /// Computes the first run time for a task of type `T` and writes it into
    /// both `next_run` and `last_run` of `task_meta`.
    ///
    /// * `Once` / `Repeat`: `utc_now!()` plus the delay (`delay_seconds`, or
    ///   the meta's own `delay_seconds` when `None`) multiplied by 1000.
    ///   NOTE(review): this presumes `utc_now!()` yields epoch milliseconds; confirm.
    /// * `Cron`: the next occurrence of `T::SCHEDULE` in `T::TIMEZONE`.
    ///
    /// # Errors
    /// Returns a message when a cron task lacks a schedule or timezone, or when
    /// its next occurrence cannot be computed.
    fn schedule_first_run<T: Task>(
        task_meta: &mut TaskMeta,
        delay_seconds: Option<u32>,
    ) -> Result<(), String> {
        let run_at = match T::TASK_KIND {
            TaskKind::Once | TaskKind::Repeat => {
                let delay = delay_seconds.unwrap_or(task_meta.delay_seconds);
                // Widen to i64 before scaling: `u32 * 1000` overflows (panic in
                // debug builds, silent wrap in release) for delays above ~49.7 days.
                utc_now!() + i64::from(delay) * 1000
            }
            TaskKind::Cron => {
                let schedule = T::SCHEDULE
                    .ok_or_else(|| "Cron schedule is required for TaskKind::Cron".to_string())?;
                let timezone = T::TIMEZONE
                    .ok_or_else(|| "Timezone is required for TaskKind::Cron".to_string())?;
                next_run(schedule, timezone, 0).ok_or_else(|| {
                    format!(
                        "Failed to calculate next run for cron task '{}': invalid schedule or timezone",
                        T::TASK_KEY
                    )
                })?
            }
        };
        task_meta.next_run = run_at;
        task_meta.last_run = run_at;
        Ok(())
    }
}
/// A task paired with an optional per-task delay, used for batch submission
/// through `TaskContext::add_tasks`.
pub struct TaskAndDelay<T>
where
    T: Task,
{
    /// The task to schedule.
    pub inner: T,
    /// Delay in seconds before the first run; `None` falls back to the task's
    /// own default delay. Not used for cron tasks.
    pub delay_seconds: Option<u32>,
}