reifydb_sub_task/
coordinator.rs1use std::{cmp::Ordering, collections::BinaryHeap, error::Error, future, io, sync::Arc, time::Instant};
2
3use reifydb_engine::engine::StandardEngine;
4use reifydb_runtime::SharedRuntime;
5use tokio::{select, sync::mpsc, time};
6use tracing::{debug, error, info};
7
8use crate::{
9 context::TaskContext,
10 registry::{TaskEntry, TaskRegistry},
11 task::{ScheduledTask, TaskExecutor, TaskId, TaskWork},
12};
13
14#[derive(Debug)]
16pub enum CoordinatorMessage {
17 Register(ScheduledTask),
19 Unregister(TaskId),
21 TaskCompleted {
23 task_id: TaskId,
24 completed_at: Instant,
25 },
26 Shutdown,
28}
29
30#[derive(Debug, Clone, PartialEq, Eq)]
32struct HeapEntry {
33 next_execution: Instant,
34 task_id: TaskId,
35}
36
37impl Ord for HeapEntry {
38 fn cmp(&self, other: &Self) -> Ordering {
39 other.next_execution.cmp(&self.next_execution)
41 }
42}
43
44impl PartialOrd for HeapEntry {
45 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
46 Some(self.cmp(other))
47 }
48}
49
50pub async fn run_coordinator(
52 registry: TaskRegistry,
53 mut rx: mpsc::Receiver<CoordinatorMessage>,
54 runtime: SharedRuntime,
55 engine: StandardEngine,
56) {
57 info!("Task coordinator started");
58
59 let (completion_tx, mut completion_rx) = mpsc::unbounded_channel();
61
62 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::new();
64
65 for entry in registry.iter() {
67 heap.push(HeapEntry {
68 next_execution: entry.value().next_execution,
69 task_id: *entry.key(),
70 });
71 }
72
73 loop {
74 let sleep_duration = heap.peek().map(|entry| {
76 let now = Instant::now();
77 if entry.next_execution > now {
78 entry.next_execution - now
79 } else {
80 time::Duration::ZERO
81 }
82 });
83
84 select! {
85 _ = async {
87 match sleep_duration {
88 Some(duration) => time::sleep(duration).await,
89 None => future::pending::<()>().await, }
91 } => {
92 if let Some(heap_entry) = heap.pop() {
94 if let Some(entry) = registry.get(&heap_entry.task_id) {
96 let task = entry.task.clone();
97 let task_id = heap_entry.task_id;
98 let task_name = task.name.clone();
99
100 spawn_task(
102 task_id,
103 task,
104 runtime.clone(),
105 engine.clone(),
106 completion_tx.clone(),
107 );
108
109 debug!("Spawned task: {}", task_name);
110 }
111 }
112 }
113
114 Some((task_id, completed_at)) = completion_rx.recv() => {
116 if let Some(mut entry) = registry.get_mut(&task_id) {
118 if let Some(next_exec) = entry.task.schedule.next_execution(completed_at) {
119 entry.next_execution = next_exec;
121
122 heap.push(HeapEntry {
124 next_execution: next_exec,
125 task_id,
126 });
127
128 debug!("Rescheduled task: {}", entry.task.name);
129 } else {
130 let task_name = entry.task.name.clone();
132 drop(entry); registry.remove(&task_id);
134 debug!("Completed one-shot task: {}", task_name);
135 }
136 }
137 }
138
139 Some(msg) = rx.recv() => {
141 match msg {
142 CoordinatorMessage::Register(task) => {
143 let task_id = task.id;
144 let next_execution = Instant::now() + task.schedule.initial_delay();
145
146 info!("Registering task: {} (id: {})", task.name, task_id);
147
148 registry.insert(task_id, TaskEntry {
150 task: Arc::new(task),
151 next_execution,
152 });
153
154 heap.push(HeapEntry {
156 next_execution,
157 task_id,
158 });
159 }
160
161 CoordinatorMessage::Unregister(task_id) => {
162 info!("Unregistering task: {}", task_id);
163
164 registry.remove(&task_id);
166
167 heap.clear();
169 for entry in registry.iter() {
170 heap.push(HeapEntry {
171 next_execution: entry.value().next_execution,
172 task_id: *entry.key(),
173 });
174 }
175 }
176
177 CoordinatorMessage::Shutdown => {
178 info!("Task coordinator shutting down");
179 break;
180 }
181 CoordinatorMessage::TaskCompleted{ .. } => {}}
182 }
183
184 else => {
185 info!("Coordinator channel closed, shutting down");
187 break;
188 }
189 }
190 }
191
192 info!("Task coordinator stopped");
193}
194
195fn spawn_task(
197 task_id: TaskId,
198 task: Arc<ScheduledTask>,
199 runtime: SharedRuntime,
200 engine: StandardEngine,
201 completion_tx: mpsc::UnboundedSender<(TaskId, Instant)>,
202) {
203 let task_name = task.name.clone();
204 let executor = task.executor;
205 let work = task.work.clone();
206 let runtime_clone = runtime.clone();
207
208 runtime.spawn(async move {
209 let runtime = runtime_clone;
210 let start = Instant::now();
211 let ctx = TaskContext::new(engine);
212
213 let result = match (&work, executor) {
215 (TaskWork::Sync(f), TaskExecutor::ComputePool) => {
216 let f = f.clone();
217 let ctx_clone = ctx.clone();
218 runtime.actor_system()
219 .compute(move || f(ctx_clone))
220 .await
221 .map_err(|e| Box::new(e) as Box<dyn Error + Send>)
222 .and_then(|r| r)
223 }
224 (TaskWork::Async(f), TaskExecutor::Tokio) => f(ctx).await,
225 (TaskWork::Sync(_), TaskExecutor::Tokio) => Err(Box::new(io::Error::new(
226 io::ErrorKind::InvalidInput,
227 "Sync work cannot be executed on Tokio executor",
228 )) as Box<dyn Error + Send>),
229 (TaskWork::Async(_), TaskExecutor::ComputePool) => Err(Box::new(io::Error::new(
230 io::ErrorKind::InvalidInput,
231 "Async work cannot be executed on ComputePool executor",
232 )) as Box<dyn Error + Send>),
233 };
234
235 let duration = start.elapsed();
236 let completed_at = Instant::now();
237
238 match result {
240 Ok(()) => {
241 debug!("Task '{}' completed successfully in {:?}", task_name, duration);
242 }
243 Err(e) => {
244 error!("Task '{}' failed after {:?}: {}", task_name, duration, e);
245 }
246 }
247
248 let _ = completion_tx.send((task_id, completed_at));
250 });
251}