// rocketmq_rust/schedule/executor.rs
// Copyright 2023 The RocketMQ Rust Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use std::time::SystemTime;

use tokio::sync::RwLock;
use tokio::sync::Semaphore;
use tokio::time::timeout;
use tracing::error;
use tracing::info;
use tracing::warn;
use uuid::Uuid;

use crate::schedule::task::TaskExecution;
use crate::schedule::SchedulerError;
use crate::schedule::Task;
use crate::schedule::TaskContext;
use crate::schedule::TaskResult;
use crate::schedule::TaskStatus;

/// Task executor pool configuration
#[derive(Debug, Clone)]
pub struct ExecutorConfig {
    // Upper bound on tasks running at the same time (sizes the semaphore).
    pub max_concurrent_tasks: usize,
    // Timeout applied to tasks that do not set their own `timeout`.
    pub default_timeout: Duration,
    // When false, `ExecutionMetrics` are not updated on task completion.
    pub enable_metrics: bool,
}

43impl Default for ExecutorConfig {
44    fn default() -> Self {
45        Self {
46            max_concurrent_tasks: 10,
47            default_timeout: Duration::from_secs(300), // 5 minutes
48            enable_metrics: true,
49        }
50    }
51}
52
/// Task execution metrics
#[derive(Debug, Clone, Default)]
pub struct ExecutionMetrics {
    // Total number of executions whose result was recorded.
    pub total_executions: u64,
    // Executions that finished with `TaskResult::Success` (skipped runs are
    // also counted here by `update_metrics`).
    pub successful_executions: u64,
    // Executions that finished with `TaskResult::Failed`, including timeouts.
    pub failed_executions: u64,
    // NOTE(review): never incremented in this file — confirm a writer exists.
    pub cancelled_executions: u64,
    // NOTE(review): declared but never updated in this file — confirm intent.
    pub average_execution_time: Duration,
    // NOTE(review): declared but never updated in this file — confirm intent.
    pub max_execution_time: Duration,
}

/// Task executor responsible for running tasks
pub struct TaskExecutor {
    // Static configuration: concurrency limit, default timeout, metrics flag.
    config: ExecutorConfig,
    // Bounds concurrent task executions to `config.max_concurrent_tasks`.
    semaphore: Arc<Semaphore>,
    // Join handles of in-flight tasks, keyed by execution id; used by
    // `cancel_task` to abort a running task.
    running_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
    // Aggregated metrics, written only when `config.enable_metrics` is true.
    metrics: Arc<RwLock<ExecutionMetrics>>,
    // Execution records keyed by execution id; pruned by
    // `cleanup_old_executions`.
    executions: Arc<RwLock<HashMap<String, TaskExecution>>>,
}

73impl TaskExecutor {
74    pub fn new(config: ExecutorConfig) -> Self {
75        let semaphore = Arc::new(Semaphore::new(config.max_concurrent_tasks));
76
77        Self {
78            config,
79            semaphore,
80            running_tasks: Arc::new(RwLock::new(HashMap::new())),
81            metrics: Arc::new(RwLock::new(ExecutionMetrics::default())),
82            executions: Arc::new(RwLock::new(HashMap::new())),
83        }
84    }
85
86    /// Execute a task asynchronously
87    pub async fn execute_task(&self, task: Arc<Task>, scheduled_time: SystemTime) -> Result<String, SchedulerError> {
88        let execution_id = Uuid::new_v4().to_string();
89        let mut execution = TaskExecution::new(task.id.clone(), scheduled_time);
90        execution.execution_id = execution_id.clone();
91
92        // Store execution record
93        {
94            let mut executions = self.executions.write().await;
95            executions.insert(execution_id.clone(), execution.clone());
96        }
97
98        // Clone necessary data for the task
99        let executor = self.clone_for_task();
100        let task_clone = task.clone();
101        let execution_id_clone = execution_id.clone();
102
103        // Spawn the task execution
104        let handle = tokio::spawn(async move {
105            executor
106                .run_task_internal(task_clone, execution_id_clone, scheduled_time)
107                .await;
108        });
109
110        // Store the handle
111        {
112            let mut running_tasks = self.running_tasks.write().await;
113            running_tasks.insert(execution_id.clone(), handle);
114        }
115
116        info!("Task scheduled for execution: {} ({})", task.name, execution_id);
117        Ok(execution_id)
118    }
119
120    /// Cancel a running task
121    pub async fn cancel_task(&self, execution_id: &str) -> Result<(), SchedulerError> {
122        let handle = {
123            let mut running_tasks = self.running_tasks.write().await;
124            running_tasks.remove(execution_id)
125        };
126
127        if let Some(handle) = handle {
128            handle.abort();
129
130            // Update execution record
131            if let Some(mut execution) = self.get_execution(execution_id).await {
132                execution.cancel();
133                let mut executions = self.executions.write().await;
134                executions.insert(execution_id.to_string(), execution);
135            }
136
137            info!("Task cancelled: {}", execution_id);
138            Ok(())
139        } else {
140            Err(SchedulerError::TaskNotFound(execution_id.to_string()))
141        }
142    }
143
144    /// Get task execution status
145    pub async fn get_execution(&self, execution_id: &str) -> Option<TaskExecution> {
146        let executions = self.executions.read().await;
147        executions.get(execution_id).cloned()
148    }
149
150    /// Get all executions for a task
151    pub async fn get_task_executions(&self, task_id: &str) -> Vec<TaskExecution> {
152        let executions = self.executions.read().await;
153        executions.values().filter(|e| e.task_id == task_id).cloned().collect()
154    }
155
156    /// Get current metrics
157    pub async fn get_metrics(&self) -> ExecutionMetrics {
158        let metrics = self.metrics.read().await;
159        metrics.clone()
160    }
161
162    /// Get number of running tasks
163    pub async fn running_task_count(&self) -> usize {
164        let running_tasks = self.running_tasks.read().await;
165        running_tasks.len()
166    }
167
168    /// Clean up completed executions older than the specified duration
169    pub async fn cleanup_old_executions(&self, older_than: Duration) {
170        let cutoff_time = SystemTime::now() - older_than;
171        let mut executions = self.executions.write().await;
172
173        executions.retain(|_, execution| match execution.status {
174            TaskStatus::Pending | TaskStatus::Running => true,
175            _ => execution.end_time.is_none_or(|end| end > cutoff_time),
176        });
177    }
178
179    // Internal methods
180
181    fn clone_for_task(&self) -> TaskExecutorInternal {
182        TaskExecutorInternal {
183            config: self.config.clone(),
184            semaphore: self.semaphore.clone(),
185            running_tasks: self.running_tasks.clone(),
186            metrics: self.metrics.clone(),
187            executions: self.executions.clone(),
188        }
189    }
190
191    async fn run_task_internal(&self, task: Arc<Task>, execution_id: String, scheduled_time: SystemTime) {
192        let internal = self.clone_for_task();
193        internal.run_task_internal(task, execution_id, scheduled_time).await;
194    }
195
196    pub async fn execute_task_with_delay(
197        &self,
198        task: Arc<Task>,
199        scheduled_time: SystemTime,
200        execution_delay: Option<Duration>,
201    ) -> Result<String, SchedulerError> {
202        let execution_id = Uuid::new_v4().to_string();
203        let mut execution = TaskExecution::new(task.id.clone(), scheduled_time);
204        execution.execution_id = execution_id.clone();
205
206        // Calculate actual execution time considering delay
207        let actual_execution_time = if let Some(delay) = execution_delay.or(task.execution_delay) {
208            scheduled_time + delay
209        } else {
210            scheduled_time
211        };
212
213        // Store execution record
214        {
215            let mut executions = self.executions.write().await;
216            executions.insert(execution_id.clone(), execution.clone());
217        }
218
219        // Clone necessary data for the task
220        let executor = self.clone_for_task();
221        let task_clone = task.clone();
222        let execution_id_clone = execution_id.clone();
223
224        // Spawn the delayed task execution
225        let handle = tokio::spawn(async move {
226            // Wait until actual execution time
227            let now = SystemTime::now();
228            if actual_execution_time > now {
229                if let Ok(delay_duration) = actual_execution_time.duration_since(now) {
230                    tokio::time::sleep(delay_duration).await;
231                }
232            }
233
234            executor
235                .run_task_internal(task_clone, execution_id_clone, actual_execution_time)
236                .await;
237        });
238
239        // Store the handle
240        {
241            let mut running_tasks = self.running_tasks.write().await;
242            running_tasks.insert(execution_id.clone(), handle);
243        }
244
245        info!(
246            "Task scheduled for delayed execution: {} ({}) - delay: {:?}",
247            task.name,
248            execution_id,
249            execution_delay.or(task.execution_delay)
250        );
251        Ok(execution_id)
252    }
253}
254
// Internal executor for task execution: a cheap, clonable handle over the
// same shared state as `TaskExecutor`, safe to move into spawned tasks.
#[derive(Clone)]
struct TaskExecutorInternal {
    // Copy of the owning executor's configuration.
    config: ExecutorConfig,
    // Shared concurrency limiter (same semaphore as the owning executor).
    semaphore: Arc<Semaphore>,
    // Shared running-task registry; entries are removed on task completion.
    running_tasks: Arc<RwLock<HashMap<String, tokio::task::JoinHandle<()>>>>,
    // Shared metrics store.
    metrics: Arc<RwLock<ExecutionMetrics>>,
    // Shared execution-record store.
    executions: Arc<RwLock<HashMap<String, TaskExecution>>>,
}

265impl TaskExecutorInternal {
266    async fn run_task_internal(&self, task: Arc<Task>, execution_id: String, scheduled_time: SystemTime) {
267        // Acquire semaphore permit
268        let _permit = match self.semaphore.acquire().await {
269            Ok(permit) => permit,
270            Err(_) => {
271                error!("Failed to acquire semaphore permit for task: {}", task.id);
272                return;
273            }
274        };
275
276        // Update execution record - start
277        let mut context = TaskContext::new(task.id.clone(), scheduled_time);
278        context.execution_id = execution_id.clone();
279        context.mark_started();
280
281        {
282            let mut executions = self.executions.write().await;
283            if let Some(execution) = executions.get_mut(&execution_id) {
284                execution.start();
285            }
286        }
287
288        info!("Starting task execution: {} ({})", task.name, execution_id);
289
290        // Execute the task with timeout
291        let task_timeout = task.timeout.unwrap_or(self.config.default_timeout);
292        let result = match timeout(task_timeout, task.execute(context)).await {
293            Ok(result) => result,
294            Err(_) => {
295                warn!("Task execution timed out: {} ({})", task.name, execution_id);
296                TaskResult::Failed("Task execution timed out".to_string())
297            }
298        };
299
300        // Update execution record - complete
301        {
302            let mut executions = self.executions.write().await;
303            if let Some(execution) = executions.get_mut(&execution_id) {
304                execution.complete(result.clone());
305            }
306        }
307
308        // Update metrics
309        if self.config.enable_metrics {
310            self.update_metrics(&result).await;
311        }
312
313        // Remove from running tasks
314        {
315            let mut running_tasks = self.running_tasks.write().await;
316            running_tasks.remove(&execution_id);
317        }
318
319        match result {
320            TaskResult::Success(msg) => {
321                info!(
322                    "Task completed successfully: {} ({}) - {:?}",
323                    task.name, execution_id, msg
324                );
325            }
326            TaskResult::Failed(err) => {
327                error!("Task failed: {} ({}) - {}", task.name, execution_id, err);
328            }
329            TaskResult::Skipped(reason) => {
330                info!("Task skipped: {} ({}) - {}", task.name, execution_id, reason);
331            }
332        }
333    }
334
335    async fn update_metrics(&self, result: &TaskResult) {
336        let mut metrics = self.metrics.write().await;
337        metrics.total_executions += 1;
338
339        match result {
340            TaskResult::Success(_) => metrics.successful_executions += 1,
341            TaskResult::Failed(_) => metrics.failed_executions += 1,
342            TaskResult::Skipped(_) => metrics.successful_executions += 1,
343        }
344    }
345}
346
/// Executor pool for managing multiple executors
pub struct ExecutorPool {
    // The executors that tasks are distributed across.
    executors: Vec<Arc<TaskExecutor>>,
    // Round-robin cursor into `executors`, advanced by `get_executor`.
    current_index: Arc<RwLock<usize>>,
}

353impl ExecutorPool {
354    pub fn new(pool_size: usize, config: ExecutorConfig) -> Self {
355        let executors = (0..pool_size)
356            .map(|_| Arc::new(TaskExecutor::new(config.clone())))
357            .collect();
358
359        Self {
360            executors,
361            current_index: Arc::new(RwLock::new(0)),
362        }
363    }
364
365    /// Get the next executor using round-robin
366    pub async fn get_executor(&self) -> Arc<TaskExecutor> {
367        let mut index = self.current_index.write().await;
368        let executor = self.executors[*index].clone();
369        *index = (*index + 1) % self.executors.len();
370        executor
371    }
372
373    /// Get total number of running tasks across all executors
374    pub async fn total_running_tasks(&self) -> usize {
375        let mut total = 0;
376        for executor in &self.executors {
377            total += executor.running_task_count().await;
378        }
379        total
380    }
381
382    /// Get combined metrics from all executors
383    pub async fn combined_metrics(&self) -> ExecutionMetrics {
384        let mut combined = ExecutionMetrics::default();
385
386        for executor in &self.executors {
387            let metrics = executor.get_metrics().await;
388            combined.total_executions += metrics.total_executions;
389            combined.successful_executions += metrics.successful_executions;
390            combined.failed_executions += metrics.failed_executions;
391            combined.cancelled_executions += metrics.cancelled_executions;
392
393            if metrics.max_execution_time > combined.max_execution_time {
394                combined.max_execution_time = metrics.max_execution_time;
395            }
396        }
397
398        combined
399    }
400}