1use std::collections::HashMap;
16use std::sync::Arc;
17use std::time::Duration;
18use std::time::SystemTime;
19
20use tokio::sync::RwLock;
21use tokio::sync::Semaphore;
22use tokio::time::timeout;
23use tracing::error;
24use tracing::info;
25use tracing::warn;
26use uuid::Uuid;
27
28use crate::schedule::task::TaskExecution;
29use crate::schedule::SchedulerError;
30use crate::schedule::Task;
31use crate::schedule::TaskContext;
32use crate::schedule::TaskResult;
33use crate::schedule::TaskStatus;
34
/// Tunable settings for a `TaskExecutor`.
#[derive(Debug, Clone)]
pub struct ExecutorConfig {
    /// Maximum number of tasks allowed to run at the same time
    /// (enforced with a semaphore of this size).
    pub max_concurrent_tasks: usize,
    /// Timeout applied to a task when `Task::timeout` is `None`.
    pub default_timeout: Duration,
    /// When `false`, execution metrics are not recorded.
    pub enable_metrics: bool,
}
42
43impl Default for ExecutorConfig {
44 fn default() -> Self {
45 Self {
46 max_concurrent_tasks: 10,
47 default_timeout: Duration::from_secs(300), enable_metrics: true,
49 }
50 }
51}
52
/// Aggregated counters describing the executions an executor has seen.
#[derive(Debug, Clone, Default)]
pub struct ExecutionMetrics {
    /// Total number of executions recorded.
    pub total_executions: u64,
    /// Executions that finished with `TaskResult::Success` (skipped tasks
    /// are also counted here by `update_metrics`).
    pub successful_executions: u64,
    /// Executions that finished with `TaskResult::Failed` (including timeouts).
    pub failed_executions: u64,
    /// Executions aborted through cancellation.
    pub cancelled_executions: u64,
    /// Mean wall-clock execution time.
    /// NOTE(review): in the code as written this field is never updated
    /// after default-initialization — confirm intended semantics.
    pub average_execution_time: Duration,
    /// Longest wall-clock execution time observed.
    /// NOTE(review): likewise never updated in the code as written.
    pub max_execution_time: Duration,
}
63
/// Runs tasks on spawned tokio tasks with bounded concurrency.
///
/// All state is behind `Arc<RwLock<...>>` so it can be shared with the
/// per-task worker (`TaskExecutorInternal`) created by `clone_for_task`.
pub struct TaskExecutor {
    config: ExecutorConfig,
    // Limits concurrent executions to `config.max_concurrent_tasks`.
    semaphore: Arc<Semaphore>,
    // Join handles of in-flight executions, keyed by execution id;
    // used by `cancel_task` to abort them.
    running_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
    // Aggregated execution counters (see `ExecutionMetrics`).
    metrics: Arc<RwLock<ExecutionMetrics>>,
    // Execution records (pending, running, and finished), keyed by execution id.
    executions: Arc<RwLock<HashMap<String, TaskExecution>>>,
}
72
73impl TaskExecutor {
74 pub fn new(config: ExecutorConfig) -> Self {
75 let semaphore = Arc::new(Semaphore::new(config.max_concurrent_tasks));
76
77 Self {
78 config,
79 semaphore,
80 running_tasks: Arc::new(RwLock::new(HashMap::new())),
81 metrics: Arc::new(RwLock::new(ExecutionMetrics::default())),
82 executions: Arc::new(RwLock::new(HashMap::new())),
83 }
84 }
85
86 pub async fn execute_task(&self, task: Arc<Task>, scheduled_time: SystemTime) -> Result<String, SchedulerError> {
88 let execution_id = Uuid::new_v4().to_string();
89 let mut execution = TaskExecution::new(task.id.clone(), scheduled_time);
90 execution.execution_id = execution_id.clone();
91
92 {
94 let mut executions = self.executions.write().await;
95 executions.insert(execution_id.clone(), execution.clone());
96 }
97
98 let executor = self.clone_for_task();
100 let task_clone = task.clone();
101 let execution_id_clone = execution_id.clone();
102
103 let handle = tokio::spawn(async move {
105 executor
106 .run_task_internal(task_clone, execution_id_clone, scheduled_time)
107 .await;
108 });
109
110 {
112 let mut running_tasks = self.running_tasks.write().await;
113 running_tasks.insert(execution_id.clone(), handle);
114 }
115
116 info!("Task scheduled for execution: {} ({})", task.name, execution_id);
117 Ok(execution_id)
118 }
119
120 pub async fn cancel_task(&self, execution_id: &str) -> Result<(), SchedulerError> {
122 let handle = {
123 let mut running_tasks = self.running_tasks.write().await;
124 running_tasks.remove(execution_id)
125 };
126
127 if let Some(handle) = handle {
128 handle.abort();
129
130 if let Some(mut execution) = self.get_execution(execution_id).await {
132 execution.cancel();
133 let mut executions = self.executions.write().await;
134 executions.insert(execution_id.to_string(), execution);
135 }
136
137 info!("Task cancelled: {}", execution_id);
138 Ok(())
139 } else {
140 Err(SchedulerError::TaskNotFound(execution_id.to_string()))
141 }
142 }
143
144 pub async fn get_execution(&self, execution_id: &str) -> Option<TaskExecution> {
146 let executions = self.executions.read().await;
147 executions.get(execution_id).cloned()
148 }
149
150 pub async fn get_task_executions(&self, task_id: &str) -> Vec<TaskExecution> {
152 let executions = self.executions.read().await;
153 executions.values().filter(|e| e.task_id == task_id).cloned().collect()
154 }
155
156 pub async fn get_metrics(&self) -> ExecutionMetrics {
158 let metrics = self.metrics.read().await;
159 metrics.clone()
160 }
161
162 pub async fn running_task_count(&self) -> usize {
164 let running_tasks = self.running_tasks.read().await;
165 running_tasks.len()
166 }
167
168 pub async fn cleanup_old_executions(&self, older_than: Duration) {
170 let cutoff_time = SystemTime::now() - older_than;
171 let mut executions = self.executions.write().await;
172
173 executions.retain(|_, execution| match execution.status {
174 TaskStatus::Pending | TaskStatus::Running => true,
175 _ => execution.end_time.is_none_or(|end| end > cutoff_time),
176 });
177 }
178
179 fn clone_for_task(&self) -> TaskExecutorInternal {
182 TaskExecutorInternal {
183 config: self.config.clone(),
184 semaphore: self.semaphore.clone(),
185 running_tasks: self.running_tasks.clone(),
186 metrics: self.metrics.clone(),
187 executions: self.executions.clone(),
188 }
189 }
190
191 async fn run_task_internal(&self, task: Arc<Task>, execution_id: String, scheduled_time: SystemTime) {
192 let internal = self.clone_for_task();
193 internal.run_task_internal(task, execution_id, scheduled_time).await;
194 }
195
196 pub async fn execute_task_with_delay(
197 &self,
198 task: Arc<Task>,
199 scheduled_time: SystemTime,
200 execution_delay: Option<Duration>,
201 ) -> Result<String, SchedulerError> {
202 let execution_id = Uuid::new_v4().to_string();
203 let mut execution = TaskExecution::new(task.id.clone(), scheduled_time);
204 execution.execution_id = execution_id.clone();
205
206 let actual_execution_time = if let Some(delay) = execution_delay.or(task.execution_delay) {
208 scheduled_time + delay
209 } else {
210 scheduled_time
211 };
212
213 {
215 let mut executions = self.executions.write().await;
216 executions.insert(execution_id.clone(), execution.clone());
217 }
218
219 let executor = self.clone_for_task();
221 let task_clone = task.clone();
222 let execution_id_clone = execution_id.clone();
223
224 let handle = tokio::spawn(async move {
226 let now = SystemTime::now();
228 if actual_execution_time > now {
229 if let Ok(delay_duration) = actual_execution_time.duration_since(now) {
230 tokio::time::sleep(delay_duration).await;
231 }
232 }
233
234 executor
235 .run_task_internal(task_clone, execution_id_clone, actual_execution_time)
236 .await;
237 });
238
239 {
241 let mut running_tasks = self.running_tasks.write().await;
242 running_tasks.insert(execution_id.clone(), handle);
243 }
244
245 info!(
246 "Task scheduled for delayed execution: {} ({}) - delay: {:?}",
247 task.name,
248 execution_id,
249 execution_delay.or(task.execution_delay)
250 );
251 Ok(execution_id)
252 }
253}
254
/// Cheap, cloneable worker handle moved into each spawned execution task.
/// Shares all state with the owning `TaskExecutor` via `Arc`s.
#[derive(Clone)]
struct TaskExecutorInternal {
    config: ExecutorConfig,
    // Same semaphore as the parent executor; bounds concurrency.
    semaphore: Arc<Semaphore>,
    // Shared handle registry; entries are removed when an execution finishes.
    running_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
    // Shared metrics store.
    metrics: Arc<RwLock<ExecutionMetrics>>,
    // Shared execution-record store.
    executions: Arc<RwLock<HashMap<String, TaskExecution>>>,
}
264
265impl TaskExecutorInternal {
266 async fn run_task_internal(&self, task: Arc<Task>, execution_id: String, scheduled_time: SystemTime) {
267 let _permit = match self.semaphore.acquire().await {
269 Ok(permit) => permit,
270 Err(_) => {
271 error!("Failed to acquire semaphore permit for task: {}", task.id);
272 return;
273 }
274 };
275
276 let mut context = TaskContext::new(task.id.clone(), scheduled_time);
278 context.execution_id = execution_id.clone();
279 context.mark_started();
280
281 {
282 let mut executions = self.executions.write().await;
283 if let Some(execution) = executions.get_mut(&execution_id) {
284 execution.start();
285 }
286 }
287
288 info!("Starting task execution: {} ({})", task.name, execution_id);
289
290 let task_timeout = task.timeout.unwrap_or(self.config.default_timeout);
292 let result = match timeout(task_timeout, task.execute(context)).await {
293 Ok(result) => result,
294 Err(_) => {
295 warn!("Task execution timed out: {} ({})", task.name, execution_id);
296 TaskResult::Failed("Task execution timed out".to_string())
297 }
298 };
299
300 {
302 let mut executions = self.executions.write().await;
303 if let Some(execution) = executions.get_mut(&execution_id) {
304 execution.complete(result.clone());
305 }
306 }
307
308 if self.config.enable_metrics {
310 self.update_metrics(&result).await;
311 }
312
313 {
315 let mut running_tasks = self.running_tasks.write().await;
316 running_tasks.remove(&execution_id);
317 }
318
319 match result {
320 TaskResult::Success(msg) => {
321 info!(
322 "Task completed successfully: {} ({}) - {:?}",
323 task.name, execution_id, msg
324 );
325 }
326 TaskResult::Failed(err) => {
327 error!("Task failed: {} ({}) - {}", task.name, execution_id, err);
328 }
329 TaskResult::Skipped(reason) => {
330 info!("Task skipped: {} ({}) - {}", task.name, execution_id, reason);
331 }
332 }
333 }
334
335 async fn update_metrics(&self, result: &TaskResult) {
336 let mut metrics = self.metrics.write().await;
337 metrics.total_executions += 1;
338
339 match result {
340 TaskResult::Success(_) => metrics.successful_executions += 1,
341 TaskResult::Failed(_) => metrics.failed_executions += 1,
342 TaskResult::Skipped(_) => metrics.successful_executions += 1,
343 }
344 }
345}
346
/// A fixed set of independent `TaskExecutor`s handed out round-robin.
pub struct ExecutorPool {
    executors: Vec<Arc<TaskExecutor>>,
    // Index of the next executor to hand out (round-robin cursor).
    current_index: Arc<RwLock<usize>>,
}
352
353impl ExecutorPool {
354 pub fn new(pool_size: usize, config: ExecutorConfig) -> Self {
355 let executors = (0..pool_size)
356 .map(|_| Arc::new(TaskExecutor::new(config.clone())))
357 .collect();
358
359 Self {
360 executors,
361 current_index: Arc::new(RwLock::new(0)),
362 }
363 }
364
365 pub async fn get_executor(&self) -> Arc<TaskExecutor> {
367 let mut index = self.current_index.write().await;
368 let executor = self.executors[*index].clone();
369 *index = (*index + 1) % self.executors.len();
370 executor
371 }
372
373 pub async fn total_running_tasks(&self) -> usize {
375 let mut total = 0;
376 for executor in &self.executors {
377 total += executor.running_task_count().await;
378 }
379 total
380 }
381
382 pub async fn combined_metrics(&self) -> ExecutionMetrics {
384 let mut combined = ExecutionMetrics::default();
385
386 for executor in &self.executors {
387 let metrics = executor.get_metrics().await;
388 combined.total_executions += metrics.total_executions;
389 combined.successful_executions += metrics.successful_executions;
390 combined.failed_executions += metrics.failed_executions;
391 combined.cancelled_executions += metrics.cancelled_executions;
392
393 if metrics.max_execution_time > combined.max_execution_time {
394 combined.max_execution_time = metrics.max_execution_time;
395 }
396 }
397
398 combined
399 }
400}