use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use std::time::SystemTime;

use tokio::sync::RwLock;
use tokio::sync::Semaphore;
use tokio::time::timeout;
use tracing::error;
use tracing::info;
use tracing::warn;
use uuid::Uuid;
30
31use crate::schedule::task::TaskExecution;
32use crate::schedule::SchedulerError;
33use crate::schedule::Task;
34use crate::schedule::TaskContext;
35use crate::schedule::TaskResult;
36use crate::schedule::TaskStatus;
37
/// Tunable settings for a [`TaskExecutor`].
#[derive(Clone, Debug)]
pub struct ExecutorConfig {
    /// Number of semaphore permits, i.e. how many executions may run at once.
    pub max_concurrent_tasks: usize,
    /// Timeout used for a task that does not carry its own `timeout`.
    pub default_timeout: Duration,
    /// When `true`, finished executions are folded into [`ExecutionMetrics`].
    pub enable_metrics: bool,
}
45
46impl Default for ExecutorConfig {
47 fn default() -> Self {
48 Self {
49 max_concurrent_tasks: 10,
50 default_timeout: Duration::from_secs(300), enable_metrics: true,
52 }
53 }
54}
55
/// Aggregate counters describing the executions an executor has recorded.
///
/// `Default` yields all-zero counters and zero durations.
#[derive(Clone, Debug, Default)]
pub struct ExecutionMetrics {
    /// Number of executions whose outcome was recorded.
    pub total_executions: u64,
    /// Recorded executions counted as successful.
    pub successful_executions: u64,
    /// Recorded executions counted as failed.
    pub failed_executions: u64,
    /// Executions counted as cancelled; stays zero unless something increments it.
    pub cancelled_executions: u64,
    /// Mean execution duration; stays zero unless something updates it.
    pub average_execution_time: Duration,
    /// Longest execution duration; stays zero unless something updates it.
    pub max_execution_time: Duration,
}
66
67pub struct TaskExecutor {
69 config: ExecutorConfig,
70 semaphore: Arc<Semaphore>,
71 running_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
72 metrics: Arc<RwLock<ExecutionMetrics>>,
73 executions: Arc<RwLock<HashMap<String, TaskExecution>>>,
74}
75
76impl TaskExecutor {
77 pub fn new(config: ExecutorConfig) -> Self {
78 let semaphore = Arc::new(Semaphore::new(config.max_concurrent_tasks));
79
80 Self {
81 config,
82 semaphore,
83 running_tasks: Arc::new(RwLock::new(HashMap::new())),
84 metrics: Arc::new(RwLock::new(ExecutionMetrics::default())),
85 executions: Arc::new(RwLock::new(HashMap::new())),
86 }
87 }
88
89 pub async fn execute_task(
91 &self,
92 task: Arc<Task>,
93 scheduled_time: SystemTime,
94 ) -> Result<String, SchedulerError> {
95 let execution_id = Uuid::new_v4().to_string();
96 let mut execution = TaskExecution::new(task.id.clone(), scheduled_time);
97 execution.execution_id = execution_id.clone();
98
99 {
101 let mut executions = self.executions.write().await;
102 executions.insert(execution_id.clone(), execution.clone());
103 }
104
105 let executor = self.clone_for_task();
107 let task_clone = task.clone();
108 let execution_id_clone = execution_id.clone();
109
110 let handle = tokio::spawn(async move {
112 executor
113 .run_task_internal(task_clone, execution_id_clone, scheduled_time)
114 .await;
115 });
116
117 {
119 let mut running_tasks = self.running_tasks.write().await;
120 running_tasks.insert(execution_id.clone(), handle);
121 }
122
123 info!(
124 "Task scheduled for execution: {} ({})",
125 task.name, execution_id
126 );
127 Ok(execution_id)
128 }
129
130 pub async fn cancel_task(&self, execution_id: &str) -> Result<(), SchedulerError> {
132 let handle = {
133 let mut running_tasks = self.running_tasks.write().await;
134 running_tasks.remove(execution_id)
135 };
136
137 if let Some(handle) = handle {
138 handle.abort();
139
140 if let Some(mut execution) = self.get_execution(execution_id).await {
142 execution.cancel();
143 let mut executions = self.executions.write().await;
144 executions.insert(execution_id.to_string(), execution);
145 }
146
147 info!("Task cancelled: {}", execution_id);
148 Ok(())
149 } else {
150 Err(SchedulerError::TaskNotFound(execution_id.to_string()))
151 }
152 }
153
154 pub async fn get_execution(&self, execution_id: &str) -> Option<TaskExecution> {
156 let executions = self.executions.read().await;
157 executions.get(execution_id).cloned()
158 }
159
160 pub async fn get_task_executions(&self, task_id: &str) -> Vec<TaskExecution> {
162 let executions = self.executions.read().await;
163 executions
164 .values()
165 .filter(|e| e.task_id == task_id)
166 .cloned()
167 .collect()
168 }
169
170 pub async fn get_metrics(&self) -> ExecutionMetrics {
172 let metrics = self.metrics.read().await;
173 metrics.clone()
174 }
175
176 pub async fn running_task_count(&self) -> usize {
178 let running_tasks = self.running_tasks.read().await;
179 running_tasks.len()
180 }
181
182 pub async fn cleanup_old_executions(&self, older_than: Duration) {
184 let cutoff_time = SystemTime::now() - older_than;
185 let mut executions = self.executions.write().await;
186
187 executions.retain(|_, execution| match execution.status {
188 TaskStatus::Pending | TaskStatus::Running => true,
189 _ => execution.end_time.is_none_or(|end| end > cutoff_time),
190 });
191 }
192
193 fn clone_for_task(&self) -> TaskExecutorInternal {
196 TaskExecutorInternal {
197 config: self.config.clone(),
198 semaphore: self.semaphore.clone(),
199 running_tasks: self.running_tasks.clone(),
200 metrics: self.metrics.clone(),
201 executions: self.executions.clone(),
202 }
203 }
204
205 async fn run_task_internal(
206 &self,
207 task: Arc<Task>,
208 execution_id: String,
209 scheduled_time: SystemTime,
210 ) {
211 let internal = self.clone_for_task();
212 internal
213 .run_task_internal(task, execution_id, scheduled_time)
214 .await;
215 }
216
217 pub async fn execute_task_with_delay(
218 &self,
219 task: Arc<Task>,
220 scheduled_time: SystemTime,
221 execution_delay: Option<Duration>,
222 ) -> Result<String, SchedulerError> {
223 let execution_id = Uuid::new_v4().to_string();
224 let mut execution = TaskExecution::new(task.id.clone(), scheduled_time);
225 execution.execution_id = execution_id.clone();
226
227 let actual_execution_time = if let Some(delay) = execution_delay.or(task.execution_delay) {
229 scheduled_time + delay
230 } else {
231 scheduled_time
232 };
233
234 {
236 let mut executions = self.executions.write().await;
237 executions.insert(execution_id.clone(), execution.clone());
238 }
239
240 let executor = self.clone_for_task();
242 let task_clone = task.clone();
243 let execution_id_clone = execution_id.clone();
244
245 let handle = tokio::spawn(async move {
247 let now = SystemTime::now();
249 if actual_execution_time > now {
250 if let Ok(delay_duration) = actual_execution_time.duration_since(now) {
251 tokio::time::sleep(delay_duration).await;
252 }
253 }
254
255 executor
256 .run_task_internal(task_clone, execution_id_clone, actual_execution_time)
257 .await;
258 });
259
260 {
262 let mut running_tasks = self.running_tasks.write().await;
263 running_tasks.insert(execution_id.clone(), handle);
264 }
265
266 info!(
267 "Task scheduled for delayed execution: {} ({}) - delay: {:?}",
268 task.name,
269 execution_id,
270 execution_delay.or(task.execution_delay)
271 );
272 Ok(execution_id)
273 }
274}
275
276#[derive(Clone)]
278struct TaskExecutorInternal {
279 config: ExecutorConfig,
280 semaphore: Arc<Semaphore>,
281 running_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
282 metrics: Arc<RwLock<ExecutionMetrics>>,
283 executions: Arc<RwLock<HashMap<String, TaskExecution>>>,
284}
285
286impl TaskExecutorInternal {
287 async fn run_task_internal(
288 &self,
289 task: Arc<Task>,
290 execution_id: String,
291 scheduled_time: SystemTime,
292 ) {
293 let _permit = match self.semaphore.acquire().await {
295 Ok(permit) => permit,
296 Err(_) => {
297 error!("Failed to acquire semaphore permit for task: {}", task.id);
298 return;
299 }
300 };
301
302 let mut context = TaskContext::new(task.id.clone(), scheduled_time);
304 context.execution_id = execution_id.clone();
305 context.mark_started();
306
307 {
308 let mut executions = self.executions.write().await;
309 if let Some(execution) = executions.get_mut(&execution_id) {
310 execution.start();
311 }
312 }
313
314 info!("Starting task execution: {} ({})", task.name, execution_id);
315
316 let task_timeout = task.timeout.unwrap_or(self.config.default_timeout);
318 let result = match timeout(task_timeout, task.execute(context)).await {
319 Ok(result) => result,
320 Err(_) => {
321 warn!("Task execution timed out: {} ({})", task.name, execution_id);
322 TaskResult::Failed("Task execution timed out".to_string())
323 }
324 };
325
326 {
328 let mut executions = self.executions.write().await;
329 if let Some(execution) = executions.get_mut(&execution_id) {
330 execution.complete(result.clone());
331 }
332 }
333
334 if self.config.enable_metrics {
336 self.update_metrics(&result).await;
337 }
338
339 {
341 let mut running_tasks = self.running_tasks.write().await;
342 running_tasks.remove(&execution_id);
343 }
344
345 match result {
346 TaskResult::Success(msg) => {
347 info!(
348 "Task completed successfully: {} ({}) - {:?}",
349 task.name, execution_id, msg
350 );
351 }
352 TaskResult::Failed(err) => {
353 error!("Task failed: {} ({}) - {}", task.name, execution_id, err);
354 }
355 TaskResult::Skipped(reason) => {
356 info!(
357 "Task skipped: {} ({}) - {}",
358 task.name, execution_id, reason
359 );
360 }
361 }
362 }
363
364 async fn update_metrics(&self, result: &TaskResult) {
365 let mut metrics = self.metrics.write().await;
366 metrics.total_executions += 1;
367
368 match result {
369 TaskResult::Success(_) => metrics.successful_executions += 1,
370 TaskResult::Failed(_) => metrics.failed_executions += 1,
371 TaskResult::Skipped(_) => metrics.successful_executions += 1,
372 }
373 }
374}
375
376pub struct ExecutorPool {
378 executors: Vec<Arc<TaskExecutor>>,
379 current_index: Arc<RwLock<usize>>,
380}
381
382impl ExecutorPool {
383 pub fn new(pool_size: usize, config: ExecutorConfig) -> Self {
384 let executors = (0..pool_size)
385 .map(|_| Arc::new(TaskExecutor::new(config.clone())))
386 .collect();
387
388 Self {
389 executors,
390 current_index: Arc::new(RwLock::new(0)),
391 }
392 }
393
394 pub async fn get_executor(&self) -> Arc<TaskExecutor> {
396 let mut index = self.current_index.write().await;
397 let executor = self.executors[*index].clone();
398 *index = (*index + 1) % self.executors.len();
399 executor
400 }
401
402 pub async fn total_running_tasks(&self) -> usize {
404 let mut total = 0;
405 for executor in &self.executors {
406 total += executor.running_task_count().await;
407 }
408 total
409 }
410
411 pub async fn combined_metrics(&self) -> ExecutionMetrics {
413 let mut combined = ExecutionMetrics::default();
414
415 for executor in &self.executors {
416 let metrics = executor.get_metrics().await;
417 combined.total_executions += metrics.total_executions;
418 combined.successful_executions += metrics.successful_executions;
419 combined.failed_executions += metrics.failed_executions;
420 combined.cancelled_executions += metrics.cancelled_executions;
421
422 if metrics.max_execution_time > combined.max_execution_time {
423 combined.max_execution_time = metrics.max_execution_time;
424 }
425 }
426
427 combined
428 }
429}