Skip to main content

cron_task_scheduler/actor/
worker.rs

1use crate::models::{ExecutionPolicy, ReactiveTask, SchedulingPolicy, TaskContext, TaskType};
2use chrono::{DateTime, Utc};
3use priority_queue::PriorityQueue;
4use std::collections::HashMap;
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7use tokio::sync::{mpsc, Mutex, Semaphore};
8use tracing::{error, info, warn};
9use uuid::Uuid;
10
11pub enum WorkerMessage {
12    Execute {
13        task: Arc<dyn ReactiveTask>,
14        context: TaskContext,
15        execution_policy: ExecutionPolicy,
16        scheduling_policy: SchedulingPolicy,
17    },
18}
19
20struct PendingTask {
21    id: Uuid,
22    task: Arc<dyn ReactiveTask>,
23    context: TaskContext,
24    execution_policy: ExecutionPolicy,
25    scheduling_policy: SchedulingPolicy,
26    arrival_time: DateTime<Utc>,
27}
28
29impl PartialEq for PendingTask {
30    fn eq(&self, other: &Self) -> bool {
31        self.id == other.id
32    }
33}
34impl Eq for PendingTask {}
35impl Hash for PendingTask {
36    fn hash<H: Hasher>(&self, state: &mut H) {
37        self.id.hash(state);
38    }
39}
40
41pub struct WorkerActor {
42    receiver: mpsc::Receiver<WorkerMessage>,
43    task_locks: HashMap<String, Arc<Mutex<()>>>,
44}
45
46impl WorkerActor {
47    pub fn new(receiver: mpsc::Receiver<WorkerMessage>) -> Self {
48        Self {
49            receiver,
50            task_locks: HashMap::new(),
51        }
52    }
53
54    pub async fn run(mut self) {
55        info!("Worker actor started");
56        let mut async_queue = PriorityQueue::new();
57        let mut blocking_queue = PriorityQueue::new();
58
59        // Limit concurrency
60        let async_semaphore = Arc::new(Semaphore::new(100));
61        // Blocking tasks usually limited by CPU cores
62        let blocking_semaphore = Arc::new(Semaphore::new(num_cpus::get()));
63
64        let mut log_interval = tokio::time::interval(std::time::Duration::from_secs(5));
65
66        loop {
67            tokio::select! {
68                _ = log_interval.tick() => {
69                     info!("Async tasks queued: {}, Blocking tasks queued: {}", async_queue.len(), blocking_queue.len());
70                }
71                msg = self.receiver.recv() => {
72                    match msg {
73                        Some(WorkerMessage::Execute {
74                            task,
75                            context,
76                            execution_policy,
77                            scheduling_policy,
78                        }) => {
79                            let pending = PendingTask {
80                                id: Uuid::new_v4(),
81                                task: task.clone(),
82                                context,
83                                execution_policy,
84                                scheduling_policy,
85                                arrival_time: Utc::now(),
86                            };
87                            let priority = self.calculate_priority(&pending);
88                            match task.task_type() {
89                                TaskType::Async => { async_queue.push(pending, priority); }
90                                TaskType::Blocking => { blocking_queue.push(pending, priority); }
91                            }
92                        }
93                        None => {
94                            info!("Worker actor channel closed");
95                            break;
96                        }
97                    }
98                }
99                // Try to acquire a permit to execute an async task
100                permit = async_semaphore.clone().acquire_owned(), if !async_queue.is_empty() => {
101                    if let Ok(permit) = permit {
102                        if let Some((pending, _)) = async_queue.pop() {
103                            self.execute_pending_task(pending, permit).await;
104                        }
105                    }
106                }
107                // Try to acquire a permit to execute a blocking task
108                permit = blocking_semaphore.clone().acquire_owned(), if !blocking_queue.is_empty() => {
109                    if let Ok(permit) = permit {
110                        if let Some((pending, _)) = blocking_queue.pop() {
111                            self.execute_pending_task(pending, permit).await;
112                        }
113                    }
114                }
115            }
116        }
117        info!("Worker actor stopped");
118    }
119
120    // Priority is (class, value). Higher class = higher priority.
121    // Class 2: Priority (weight based)
122    // Class 1: FirstInFirstOut, Delayed, Fair (time based)
123    // Class 0: RateLimited
124    fn calculate_priority(&self, task: &PendingTask) -> (i8, i64) {
125        match task.scheduling_policy {
126            SchedulingPolicy::Priority => {
127                // Weight: -20 (high) to 19 (low).
128                // We want lower weight to have higher priority.
129                // -weight: 20 (high) to -19 (low).
130                (2, -(task.context.weight as i64))
131            }
132            SchedulingPolicy::FirstInFirstOut => {
133                // Earlier arrival = higher priority.
134                (1, -task.arrival_time.timestamp_nanos_opt().unwrap_or(50))
135            }
136            SchedulingPolicy::Delayed => {
137                // Earlier scheduled time = higher priority.
138                (1, -task.context.scheduled_time.timestamp_millis())
139            }
140            SchedulingPolicy::Fair => {
141                // Treat as FIFO for now
142                (1, -task.arrival_time.timestamp_nanos_opt().unwrap_or(50))
143            }
144            SchedulingPolicy::RateLimited => {
145                // Lowest priority
146                (0, 0)
147            }
148        }
149    }
150
151    async fn execute_pending_task(&mut self, pending: PendingTask, permit: tokio::sync::OwnedSemaphorePermit) {
152        let task = pending.task;
153        let context = pending.context;
154        let policy = pending.execution_policy;
155
156        let task_id = task.id().to_string();
157        let lock = self
158            .task_locks
159            .entry(task_id)
160            .or_insert_with(|| Arc::new(Mutex::new(())))
161            .clone();
162
163        let task_type = task.task_type();
164
165        // We spawn the execution logic so we don't block the actor loop
166        tokio::spawn(async move {
167            // permit is held until this task finishes
168            let _permit = permit;
169
170            match policy {
171                ExecutionPolicy::SkipIfRunning => {
172                    if let Ok(guard) = lock.try_lock_owned() {
173                        let _guard = guard;
174                        info!("Executing task (skip-if-running): {}", task.id());
175                        execute_task_by_type(task, context, task_type).await;
176                    } else {
177                        warn!("Task {} already running, skipping execution", task.id());
178                    }
179                }
180                ExecutionPolicy::Parallel => {
181                    info!("Executing task (parallel): {}", task.id());
182                    execute_task_by_type(task, context, task_type).await;
183                }
184                ExecutionPolicy::Sequential => {
185                    let _guard = lock.lock_owned().await;
186                    info!("Executing task (sequential): {}", task.id());
187                    execute_task_by_type(task, context, task_type).await;
188                }
189            }
190        });
191    }
192}
193
194async fn execute_task_by_type(task: Arc<dyn ReactiveTask>, context: TaskContext, task_type: TaskType) {
195    match task_type {
196        TaskType::Async => {
197            if let Err(e) = task.execute(context).await {
198                error!("Task {} failed: {}", task.id(), e);
199            } else {
200                info!("Task {} completed successfully", task.id());
201            }
202        }
203        TaskType::Blocking => {
204            let task_clone = task.clone();
205            let context_clone = context.clone();
206            let handle = tokio::task::spawn_blocking(move || {
207                futures::executor::block_on(task_clone.execute(context_clone))
208            });
209            match handle.await {
210                Ok(res) => {
211                    if let Err(e) = res {
212                        error!("Task {} failed: {}", task.id(), e);
213                    } else {
214                        info!("Task {} completed successfully", task.id());
215                    }
216                }
217                Err(e) => {
218                    error!("Task {} join error: {}", task.id(), e);
219                }
220            }
221        }
222    }
223}