1use std::{cmp::Ordering, collections::BinaryHeap, error::Error, future, io, sync::Arc};
5
6use reifydb_core::interface::catalog::task::TaskId;
7use reifydb_engine::engine::StandardEngine;
8use reifydb_runtime::{SharedRuntime, context::clock::Instant};
9use tokio::{select, sync::mpsc, task::spawn_blocking, time};
10use tracing::{debug, error, info};
11
12#[cfg(reifydb_assertions)]
13use crate::schedule::Schedule;
14use crate::{
15 context::TaskContext,
16 registry::{TaskEntry, TaskRegistry},
17 task::{ScheduledTask, TaskExecutor, TaskWork},
18};
19
20#[derive(Debug)]
21pub enum TaskCoordinatorMessage {
22 Register(ScheduledTask),
23
24 Unregister(TaskId),
25
26 TaskCompleted {
27 task_id: TaskId,
28 completed_at: Instant,
29 },
30
31 Shutdown,
32}
33
34#[derive(Debug, Clone, PartialEq, Eq)]
35struct HeapEntry {
36 next_execution: Instant,
37 task_id: TaskId,
38}
39
40impl Ord for HeapEntry {
41 fn cmp(&self, other: &Self) -> Ordering {
42 other.next_execution.cmp(&self.next_execution)
43 }
44}
45
46impl PartialOrd for HeapEntry {
47 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
48 Some(self.cmp(other))
49 }
50}
51
52pub async fn run_coordinator(
53 registry: TaskRegistry,
54 mut rx: mpsc::Receiver<TaskCoordinatorMessage>,
55 runtime: SharedRuntime,
56 engine: StandardEngine,
57) {
58 info!("Task coordinator started");
59
60 let (completion_tx, mut completion_rx) = mpsc::unbounded_channel();
61
62 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::new();
63
64 for entry in registry.iter() {
65 heap.push(HeapEntry {
66 next_execution: entry.value().next_execution.clone(),
67 task_id: *entry.key(),
68 });
69 }
70
71 loop {
72 let sleep_duration = heap.peek().map(|entry| {
73 let now = runtime.clock().instant();
74 if entry.next_execution > now {
75 &entry.next_execution - &now
76 } else {
77 time::Duration::ZERO
78 }
79 });
80
81 select! {
82
83 _ = async {
84 match sleep_duration {
85 Some(duration) => time::sleep(duration).await,
86 None => future::pending::<()>().await,
87 }
88 } => {
89
90 if let Some(heap_entry) = heap.pop()
91 && let Some(entry) = registry.get(&heap_entry.task_id)
92 {
93 let task = entry.task.clone();
94 let task_id = heap_entry.task_id;
95 let task_name = task.name.clone();
96
97 spawn_task(
98 task_id,
99 task,
100 runtime.clone(),
101 engine.clone(),
102 completion_tx.clone(),
103 );
104
105 debug!("Spawned task: {}", task_name);
106 }
107 }
108
109
110 Some((task_id, completed_at)) = completion_rx.recv() => {
111
112 if let Some(mut entry) = registry.get_mut(&task_id) {
113 if let Some(next_exec) = entry.task.schedule.next_execution(completed_at) {
114 #[cfg(reifydb_assertions)]
115 {
116 assert!(
117 !matches!(entry.task.schedule, Schedule::Once(_)),
118 "a Schedule::Once task entered the reschedule path after completing, so a one-shot task would run repeatedly and duplicate its side effects (task={})",
119 entry.task.name
120 );
121 }
122
123 entry.next_execution = next_exec.clone();
124
125
126 heap.push(HeapEntry {
127 next_execution: next_exec,
128 task_id,
129 });
130
131 debug!("Rescheduled task: {}", entry.task.name);
132 } else {
133
134 let task_name = entry.task.name.clone();
135 drop(entry);
136 registry.remove(&task_id);
137 debug!("Completed one-shot task: {}", task_name);
138 }
139 }
140 }
141
142
143 Some(msg) = rx.recv() => {
144 match msg {
145 TaskCoordinatorMessage::Register(task) => {
146 let task_id = task.id;
147 let next_execution = runtime.clock().instant() + task.schedule.initial_delay();
148
149 info!("Registering task: {} (id: {})", task.name, task_id);
150
151
152 registry.insert(task_id, TaskEntry {
153 task: Arc::new(task),
154 next_execution: next_execution.clone(),
155 });
156
157
158 heap.push(HeapEntry {
159 next_execution,
160 task_id,
161 });
162 }
163
164 TaskCoordinatorMessage::Unregister(task_id) => {
165 info!("Unregistering task: {}", task_id);
166
167
168 registry.remove(&task_id);
169
170
171 heap.clear();
172 for entry in registry.iter() {
173 heap.push(HeapEntry {
174 next_execution: entry.value().next_execution.clone(),
175 task_id: *entry.key(),
176 });
177 }
178 }
179
180 TaskCoordinatorMessage::Shutdown => {
181 info!("Task coordinator shutting down");
182 break;
183 }
184 TaskCoordinatorMessage::TaskCompleted{ .. } => {}}
185 }
186
187 else => {
188
189 info!("Coordinator channel closed, shutting down");
190 break;
191 }
192 }
193 }
194
195 info!("Task coordinator stopped");
196}
197
198fn spawn_task(
199 task_id: TaskId,
200 task: Arc<ScheduledTask>,
201 runtime: SharedRuntime,
202 engine: StandardEngine,
203 completion_tx: mpsc::UnboundedSender<(TaskId, Instant)>,
204) {
205 let task_name = task.name.clone();
206 let executor = task.executor;
207 let work = task.work.clone();
208 let runtime_clone = runtime.clone();
209
210 runtime.spawn(async move {
211 let runtime = runtime_clone;
212 let start = runtime.clock().instant();
213 let ctx = TaskContext::new(engine);
214
215 let result = match (&work, executor) {
216 (TaskWork::Sync(f), TaskExecutor::ComputePool) => {
217 let f = f.clone();
218 let ctx_clone = ctx.clone();
219 spawn_blocking(move || f(ctx_clone))
220 .await
221 .map_err(|e| Box::new(e) as Box<dyn Error + Send>)
222 .and_then(|r| r)
223 }
224 (TaskWork::Async(f), TaskExecutor::Tokio) => f(ctx).await,
225 (TaskWork::Sync(_), TaskExecutor::Tokio) => Err(Box::new(io::Error::new(
226 io::ErrorKind::InvalidInput,
227 "Sync work cannot be executed on Tokio executor",
228 )) as Box<dyn Error + Send>),
229 (TaskWork::Async(_), TaskExecutor::ComputePool) => Err(Box::new(io::Error::new(
230 io::ErrorKind::InvalidInput,
231 "Async work cannot be executed on ComputePool executor",
232 )) as Box<dyn Error + Send>),
233 };
234
235 let duration = start.elapsed();
236 let completed_at = runtime.clock().instant();
237
238 match result {
239 Ok(()) => {
240 debug!("Task '{}' completed successfully in {:?}", task_name, duration);
241 }
242 Err(e) => {
243 error!("Task '{}' failed after {:?}: {}", task_name, duration, e);
244 }
245 }
246
247 let _ = completion_tx.send((task_id, completed_at));
248 });
249}