reifydb_sub_task/
coordinator.rs1use std::{cmp::Ordering, collections::BinaryHeap, error::Error, future, io, sync::Arc};
5
6use reifydb_core::interface::catalog::task::TaskId;
7use reifydb_engine::engine::StandardEngine;
8use reifydb_runtime::{SharedRuntime, context::clock::Instant};
9use tokio::{select, sync::mpsc, task::spawn_blocking, time};
10use tracing::{debug, error, info};
11
12use crate::{
13 context::TaskContext,
14 registry::{TaskEntry, TaskRegistry},
15 task::{ScheduledTask, TaskExecutor, TaskWork},
16};
17
18#[derive(Debug)]
20pub enum TaskCoordinatorMessage {
21 Register(ScheduledTask),
23 Unregister(TaskId),
25 TaskCompleted {
27 task_id: TaskId,
28 completed_at: Instant,
29 },
30 Shutdown,
32}
33
34#[derive(Debug, Clone, PartialEq, Eq)]
36struct HeapEntry {
37 next_execution: Instant,
38 task_id: TaskId,
39}
40
41impl Ord for HeapEntry {
42 fn cmp(&self, other: &Self) -> Ordering {
43 other.next_execution.cmp(&self.next_execution)
45 }
46}
47
48impl PartialOrd for HeapEntry {
49 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50 Some(self.cmp(other))
51 }
52}
53
54pub async fn run_coordinator(
56 registry: TaskRegistry,
57 mut rx: mpsc::Receiver<TaskCoordinatorMessage>,
58 runtime: SharedRuntime,
59 engine: StandardEngine,
60) {
61 info!("Task coordinator started");
62
63 let (completion_tx, mut completion_rx) = mpsc::unbounded_channel();
65
66 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::new();
68
69 for entry in registry.iter() {
71 heap.push(HeapEntry {
72 next_execution: entry.value().next_execution.clone(),
73 task_id: *entry.key(),
74 });
75 }
76
77 loop {
78 let sleep_duration = heap.peek().map(|entry| {
80 let now = runtime.clock().instant();
81 if entry.next_execution > now {
82 &entry.next_execution - &now
83 } else {
84 time::Duration::ZERO
85 }
86 });
87
88 select! {
89 _ = async {
91 match sleep_duration {
92 Some(duration) => time::sleep(duration).await,
93 None => future::pending::<()>().await, }
95 } => {
96 if let Some(heap_entry) = heap.pop() {
98 if let Some(entry) = registry.get(&heap_entry.task_id) {
100 let task = entry.task.clone();
101 let task_id = heap_entry.task_id;
102 let task_name = task.name.clone();
103
104 spawn_task(
106 task_id,
107 task,
108 runtime.clone(),
109 engine.clone(),
110 completion_tx.clone(),
111 );
112
113 debug!("Spawned task: {}", task_name);
114 }
115 }
116 }
117
118 Some((task_id, completed_at)) = completion_rx.recv() => {
120 if let Some(mut entry) = registry.get_mut(&task_id) {
122 if let Some(next_exec) = entry.task.schedule.next_execution(completed_at) {
123 entry.next_execution = next_exec.clone();
125
126 heap.push(HeapEntry {
128 next_execution: next_exec,
129 task_id,
130 });
131
132 debug!("Rescheduled task: {}", entry.task.name);
133 } else {
134 let task_name = entry.task.name.clone();
136 drop(entry); registry.remove(&task_id);
138 debug!("Completed one-shot task: {}", task_name);
139 }
140 }
141 }
142
143 Some(msg) = rx.recv() => {
145 match msg {
146 TaskCoordinatorMessage::Register(task) => {
147 let task_id = task.id;
148 let next_execution = runtime.clock().instant() + task.schedule.initial_delay();
149
150 info!("Registering task: {} (id: {})", task.name, task_id);
151
152 registry.insert(task_id, TaskEntry {
154 task: Arc::new(task),
155 next_execution: next_execution.clone(),
156 });
157
158 heap.push(HeapEntry {
160 next_execution,
161 task_id,
162 });
163 }
164
165 TaskCoordinatorMessage::Unregister(task_id) => {
166 info!("Unregistering task: {}", task_id);
167
168 registry.remove(&task_id);
170
171 heap.clear();
173 for entry in registry.iter() {
174 heap.push(HeapEntry {
175 next_execution: entry.value().next_execution.clone(),
176 task_id: *entry.key(),
177 });
178 }
179 }
180
181 TaskCoordinatorMessage::Shutdown => {
182 info!("Task coordinator shutting down");
183 break;
184 }
185 TaskCoordinatorMessage::TaskCompleted{ .. } => {}}
186 }
187
188 else => {
189 info!("Coordinator channel closed, shutting down");
191 break;
192 }
193 }
194 }
195
196 info!("Task coordinator stopped");
197}
198
199fn spawn_task(
201 task_id: TaskId,
202 task: Arc<ScheduledTask>,
203 runtime: SharedRuntime,
204 engine: StandardEngine,
205 completion_tx: mpsc::UnboundedSender<(TaskId, Instant)>,
206) {
207 let task_name = task.name.clone();
208 let executor = task.executor;
209 let work = task.work.clone();
210 let runtime_clone = runtime.clone();
211
212 runtime.spawn(async move {
213 let runtime = runtime_clone;
214 let start = runtime.clock().instant();
215 let ctx = TaskContext::new(engine);
216
217 let result = match (&work, executor) {
219 (TaskWork::Sync(f), TaskExecutor::ComputePool) => {
220 let f = f.clone();
221 let ctx_clone = ctx.clone();
222 spawn_blocking(move || f(ctx_clone))
223 .await
224 .map_err(|e| Box::new(e) as Box<dyn Error + Send>)
225 .and_then(|r| r)
226 }
227 (TaskWork::Async(f), TaskExecutor::Tokio) => f(ctx).await,
228 (TaskWork::Sync(_), TaskExecutor::Tokio) => Err(Box::new(io::Error::new(
229 io::ErrorKind::InvalidInput,
230 "Sync work cannot be executed on Tokio executor",
231 )) as Box<dyn Error + Send>),
232 (TaskWork::Async(_), TaskExecutor::ComputePool) => Err(Box::new(io::Error::new(
233 io::ErrorKind::InvalidInput,
234 "Async work cannot be executed on ComputePool executor",
235 )) as Box<dyn Error + Send>),
236 };
237
238 let duration = start.elapsed();
239 let completed_at = runtime.clock().instant();
240
241 match result {
243 Ok(()) => {
244 debug!("Task '{}' completed successfully in {:?}", task_name, duration);
245 }
246 Err(e) => {
247 error!("Task '{}' failed after {:?}: {}", task_name, duration, e);
248 }
249 }
250
251 let _ = completion_tx.send((task_id, completed_at));
253 });
254}