rocketmq_rust/schedule/
executor.rs

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;

use tokio::sync::RwLock;
use tokio::sync::Semaphore;
use tokio::time::timeout;
use tracing::error;
use tracing::info;
use tracing::warn;
use uuid::Uuid;

use crate::schedule::task::TaskExecution;
use crate::schedule::SchedulerError;
use crate::schedule::Task;
use crate::schedule::TaskContext;
use crate::schedule::TaskResult;
use crate::schedule::TaskStatus;

/// Task executor pool configuration
#[derive(Debug, Clone)]
pub struct ExecutorConfig {
    pub max_concurrent_tasks: usize,
    pub default_timeout: Duration,
    pub enable_metrics: bool,
}

impl Default for ExecutorConfig {
    fn default() -> Self {
        Self {
            max_concurrent_tasks: 10,
            default_timeout: Duration::from_secs(300), // 5 minutes
            enable_metrics: true,
        }
    }
}
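
// Example (sketch): callers can override a single setting and keep the rest of
// the defaults via struct-update syntax, e.g.
//
//     let config = ExecutorConfig { max_concurrent_tasks: 32, ..Default::default() };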

/// Task execution metrics
#[derive(Debug, Clone, Default)]
pub struct ExecutionMetrics {
    pub total_executions: u64,
    pub successful_executions: u64,
    pub failed_executions: u64,
    pub cancelled_executions: u64,
    pub average_execution_time: Duration,
    pub max_execution_time: Duration,
}

/// Task executor responsible for running tasks
pub struct TaskExecutor {
    config: ExecutorConfig,
    semaphore: Arc<Semaphore>,
    running_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
    metrics: Arc<RwLock<ExecutionMetrics>>,
    executions: Arc<RwLock<HashMap<String, TaskExecution>>>,
}

impl TaskExecutor {
    pub fn new(config: ExecutorConfig) -> Self {
        let semaphore = Arc::new(Semaphore::new(config.max_concurrent_tasks));

        Self {
            config,
            semaphore,
            running_tasks: Arc::new(RwLock::new(HashMap::new())),
            metrics: Arc::new(RwLock::new(ExecutionMetrics::default())),
            executions: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Execute a task asynchronously
    pub async fn execute_task(
        &self,
        task: Arc<Task>,
        scheduled_time: SystemTime,
    ) -> Result<String, SchedulerError> {
        let execution_id = Uuid::new_v4().to_string();
        let mut execution = TaskExecution::new(task.id.clone(), scheduled_time);
        execution.execution_id = execution_id.clone();

        // Store execution record
        {
            let mut executions = self.executions.write().await;
            executions.insert(execution_id.clone(), execution.clone());
        }

        // Clone necessary data for the task
        let executor = self.clone_for_task();
        let task_clone = task.clone();
        let execution_id_clone = execution_id.clone();

        // Spawn the task execution
        let handle = tokio::spawn(async move {
            executor
                .run_task_internal(task_clone, execution_id_clone, scheduled_time)
                .await;
        });

        // Store the handle
        {
            let mut running_tasks = self.running_tasks.write().await;
            running_tasks.insert(execution_id.clone(), handle);
        }

        info!(
            "Task scheduled for execution: {} ({})",
            task.name, execution_id
        );
        Ok(execution_id)
    }
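
    // Typical call flow (sketch): `execute_task` returns immediately with an
    // execution id while the task body runs on a spawned tokio task, bounded
    // by the semaphore. How the `Arc<Task>` is obtained is assumed here and
    // not shown; see the task module.
    //
    //     let execution_id = executor.execute_task(task, SystemTime::now()).await?;
    //     if let Some(execution) = executor.get_execution(&execution_id).await {
    //         println!("status: {:?}", execution.status);
    //     }
    //     executor.cancel_task(&execution_id).await?;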

    /// Cancel a running task
    pub async fn cancel_task(&self, execution_id: &str) -> Result<(), SchedulerError> {
        let handle = {
            let mut running_tasks = self.running_tasks.write().await;
            running_tasks.remove(execution_id)
        };

        if let Some(handle) = handle {
            handle.abort();

            // Update execution record
            if let Some(mut execution) = self.get_execution(execution_id).await {
                execution.cancel();
                let mut executions = self.executions.write().await;
                executions.insert(execution_id.to_string(), execution);
            }

            // Record the cancellation so `cancelled_executions` reflects aborted runs
            if self.config.enable_metrics {
                let mut metrics = self.metrics.write().await;
                metrics.cancelled_executions += 1;
            }

            info!("Task cancelled: {}", execution_id);
            Ok(())
        } else {
            Err(SchedulerError::TaskNotFound(execution_id.to_string()))
        }
    }

    /// Get task execution status
    pub async fn get_execution(&self, execution_id: &str) -> Option<TaskExecution> {
        let executions = self.executions.read().await;
        executions.get(execution_id).cloned()
    }

    /// Get all executions for a task
    pub async fn get_task_executions(&self, task_id: &str) -> Vec<TaskExecution> {
        let executions = self.executions.read().await;
        executions
            .values()
            .filter(|e| e.task_id == task_id)
            .cloned()
            .collect()
    }

    /// Get current metrics
    pub async fn get_metrics(&self) -> ExecutionMetrics {
        let metrics = self.metrics.read().await;
        metrics.clone()
    }

    /// Get number of running tasks
    pub async fn running_task_count(&self) -> usize {
        let running_tasks = self.running_tasks.read().await;
        running_tasks.len()
    }

    /// Clean up completed executions older than the specified duration
    pub async fn cleanup_old_executions(&self, older_than: Duration) {
        let cutoff_time = SystemTime::now() - older_than;
        let mut executions = self.executions.write().await;

        executions.retain(|_, execution| match execution.status {
            TaskStatus::Pending | TaskStatus::Running => true,
            _ => execution.end_time.is_none_or(|end| end > cutoff_time),
        });
    }

    // Internal methods

    fn clone_for_task(&self) -> TaskExecutorInternal {
        TaskExecutorInternal {
            config: self.config.clone(),
            semaphore: self.semaphore.clone(),
            running_tasks: self.running_tasks.clone(),
            metrics: self.metrics.clone(),
            executions: self.executions.clone(),
        }
    }

    /// Convenience wrapper that clones the internal executor state and delegates to it
    async fn run_task_internal(
        &self,
        task: Arc<Task>,
        execution_id: String,
        scheduled_time: SystemTime,
    ) {
        let internal = self.clone_for_task();
        internal
            .run_task_internal(task, execution_id, scheduled_time)
            .await;
    }

    /// Execute a task, deferring it by an optional delay past the scheduled time
    pub async fn execute_task_with_delay(
        &self,
        task: Arc<Task>,
        scheduled_time: SystemTime,
        execution_delay: Option<Duration>,
    ) -> Result<String, SchedulerError> {
        let execution_id = Uuid::new_v4().to_string();
        let mut execution = TaskExecution::new(task.id.clone(), scheduled_time);
        execution.execution_id = execution_id.clone();

        // Calculate actual execution time considering delay
        let actual_execution_time = if let Some(delay) = execution_delay.or(task.execution_delay) {
            scheduled_time + delay
        } else {
            scheduled_time
        };

        // Store execution record
        {
            let mut executions = self.executions.write().await;
            executions.insert(execution_id.clone(), execution.clone());
        }

        // Clone necessary data for the task
        let executor = self.clone_for_task();
        let task_clone = task.clone();
        let execution_id_clone = execution_id.clone();

        // Spawn the delayed task execution
        let handle = tokio::spawn(async move {
            // Wait until actual execution time
            let now = SystemTime::now();
            if actual_execution_time > now {
                if let Ok(delay_duration) = actual_execution_time.duration_since(now) {
                    tokio::time::sleep(delay_duration).await;
                }
            }

            executor
                .run_task_internal(task_clone, execution_id_clone, actual_execution_time)
                .await;
        });

        // Store the handle
        {
            let mut running_tasks = self.running_tasks.write().await;
            running_tasks.insert(execution_id.clone(), handle);
        }

        info!(
            "Task scheduled for delayed execution: {} ({}) - delay: {:?}",
            task.name,
            execution_id,
            execution_delay.or(task.execution_delay)
        );
        Ok(execution_id)
    }
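
    // Delay resolution (sketch): an explicit delay passed to
    // `execute_task_with_delay` takes precedence over the task's own
    // `execution_delay`; with neither present the call behaves like
    // `execute_task`. For example, to run roughly 30 seconds after the
    // scheduled time regardless of what the task declares (assuming `task`
    // and `executor` were obtained elsewhere):
    //
    //     let id = executor
    //         .execute_task_with_delay(task, SystemTime::now(), Some(Duration::from_secs(30)))
    //         .await?;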
}

// Internal executor for task execution
#[derive(Clone)]
struct TaskExecutorInternal {
    config: ExecutorConfig,
    semaphore: Arc<Semaphore>,
    running_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
    metrics: Arc<RwLock<ExecutionMetrics>>,
    executions: Arc<RwLock<HashMap<String, TaskExecution>>>,
}

impl TaskExecutorInternal {
    async fn run_task_internal(
        &self,
        task: Arc<Task>,
        execution_id: String,
        scheduled_time: SystemTime,
    ) {
        // Acquire semaphore permit
        let _permit = match self.semaphore.acquire().await {
            Ok(permit) => permit,
            Err(_) => {
                error!("Failed to acquire semaphore permit for task: {}", task.id);
                return;
            }
        };

        // Update execution record - start
        let mut context = TaskContext::new(task.id.clone(), scheduled_time);
        context.execution_id = execution_id.clone();
        context.mark_started();

        {
            let mut executions = self.executions.write().await;
            if let Some(execution) = executions.get_mut(&execution_id) {
                execution.start();
            }
        }

        info!("Starting task execution: {} ({})", task.name, execution_id);

        // Execute the task with timeout, tracking wall-clock time for the metrics
        let task_timeout = task.timeout.unwrap_or(self.config.default_timeout);
        let started = std::time::Instant::now();
        let result = match timeout(task_timeout, task.execute(context)).await {
            Ok(result) => result,
            Err(_) => {
                warn!("Task execution timed out: {} ({})", task.name, execution_id);
                TaskResult::Failed("Task execution timed out".to_string())
            }
        };
        let execution_time = started.elapsed();

        // Update execution record - complete
        {
            let mut executions = self.executions.write().await;
            if let Some(execution) = executions.get_mut(&execution_id) {
                execution.complete(result.clone());
            }
        }

        // Update metrics
        if self.config.enable_metrics {
            self.update_metrics(&result, execution_time).await;
        }

        // Remove from running tasks
        {
            let mut running_tasks = self.running_tasks.write().await;
            running_tasks.remove(&execution_id);
        }

        match result {
            TaskResult::Success(msg) => {
                info!(
                    "Task completed successfully: {} ({}) - {:?}",
                    task.name, execution_id, msg
                );
            }
            TaskResult::Failed(err) => {
                error!("Task failed: {} ({}) - {}", task.name, execution_id, err);
            }
            TaskResult::Skipped(reason) => {
                info!(
                    "Task skipped: {} ({}) - {}",
                    task.name, execution_id, reason
                );
            }
        }
    }

    async fn update_metrics(&self, result: &TaskResult, execution_time: Duration) {
        let mut metrics = self.metrics.write().await;
        metrics.total_executions += 1;

        match result {
            TaskResult::Success(_) => metrics.successful_executions += 1,
            TaskResult::Failed(_) => metrics.failed_executions += 1,
            // Skipped runs are counted as successful rather than failed
            TaskResult::Skipped(_) => metrics.successful_executions += 1,
        }

        // Track the slowest run and keep a running average across all executions
        if execution_time > metrics.max_execution_time {
            metrics.max_execution_time = execution_time;
        }
        let count = metrics.total_executions.min(u32::MAX as u64) as u32;
        metrics.average_execution_time =
            (metrics.average_execution_time * (count - 1) + execution_time) / count;
    }
}

/// Executor pool for managing multiple executors
pub struct ExecutorPool {
    executors: Vec<Arc<TaskExecutor>>,
    current_index: Arc<RwLock<usize>>,
}

impl ExecutorPool {
    pub fn new(pool_size: usize, config: ExecutorConfig) -> Self {
        let executors = (0..pool_size)
            .map(|_| Arc::new(TaskExecutor::new(config.clone())))
            .collect();

        Self {
            executors,
            current_index: Arc::new(RwLock::new(0)),
        }
    }

    /// Get the next executor using round-robin
    pub async fn get_executor(&self) -> Arc<TaskExecutor> {
        let mut index = self.current_index.write().await;
        let executor = self.executors[*index].clone();
        *index = (*index + 1) % self.executors.len();
        executor
    }

    /// Get total number of running tasks across all executors
    pub async fn total_running_tasks(&self) -> usize {
        let mut total = 0;
        for executor in &self.executors {
            total += executor.running_task_count().await;
        }
        total
    }

    /// Get combined metrics from all executors
    pub async fn combined_metrics(&self) -> ExecutionMetrics {
        let mut combined = ExecutionMetrics::default();

        for executor in &self.executors {
            let metrics = executor.get_metrics().await;
            combined.total_executions += metrics.total_executions;
            combined.successful_executions += metrics.successful_executions;
            combined.failed_executions += metrics.failed_executions;
            combined.cancelled_executions += metrics.cancelled_executions;

            if metrics.max_execution_time > combined.max_execution_time {
                combined.max_execution_time = metrics.max_execution_time;
            }
        }

        combined
    }
}
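
#[cfg(test)]
mod tests {
    // A minimal sketch exercising only what this file defines (config
    // defaults, pool round-robin, and the idle counters); constructing and
    // running a real `Task` belongs to the task module and is not covered
    // here. `#[tokio::test]` assumes the crate's tokio dependency enables the
    // test macros, as the rest of the scheduler already uses tokio.
    use super::*;

    #[test]
    fn default_config_has_expected_limits() {
        let config = ExecutorConfig::default();
        assert_eq!(config.max_concurrent_tasks, 10);
        assert_eq!(config.default_timeout, Duration::from_secs(300));
        assert!(config.enable_metrics);
    }

    #[tokio::test]
    async fn pool_round_robin_starts_idle() {
        let pool = ExecutorPool::new(3, ExecutorConfig::default());

        // Selecting more executors than the pool size wraps around; with no
        // tasks submitted, every executor reports zero running tasks.
        for _ in 0..6 {
            let executor = pool.get_executor().await;
            assert_eq!(executor.running_task_count().await, 0);
        }

        assert_eq!(pool.total_running_tasks().await, 0);
        assert_eq!(pool.combined_metrics().await.total_executions, 0);
    }
}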