// reifydb_sub_task/coordinator.rs

use std::{cmp::Ordering, collections::BinaryHeap, error::Error, future, io, sync::Arc};

use reifydb_core::interface::catalog::task::TaskId;
use reifydb_engine::engine::StandardEngine;
use reifydb_runtime::{SharedRuntime, context::clock::Instant};
use tokio::{select, sync::mpsc, task::spawn_blocking, time};
use tracing::{debug, error, info};

use crate::{
    context::TaskContext,
    registry::{TaskEntry, TaskRegistry},
    task::{ScheduledTask, TaskExecutor, TaskWork},
};

18#[derive(Debug)]
19pub enum TaskCoordinatorMessage {
20 Register(ScheduledTask),
21
22 Unregister(TaskId),
23
24 TaskCompleted {
25 task_id: TaskId,
26 completed_at: Instant,
27 },
28
29 Shutdown,
30}
31
32#[derive(Debug, Clone, PartialEq, Eq)]
33struct HeapEntry {
34 next_execution: Instant,
35 task_id: TaskId,
36}
37
38impl Ord for HeapEntry {
39 fn cmp(&self, other: &Self) -> Ordering {
40 other.next_execution.cmp(&self.next_execution)
41 }
42}
43
44impl PartialOrd for HeapEntry {
45 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
46 Some(self.cmp(other))
47 }
48}
49
50pub async fn run_coordinator(
51 registry: TaskRegistry,
52 mut rx: mpsc::Receiver<TaskCoordinatorMessage>,
53 runtime: SharedRuntime,
54 engine: StandardEngine,
55) {
56 info!("Task coordinator started");
57
58 let (completion_tx, mut completion_rx) = mpsc::unbounded_channel();
59
60 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::new();
61
62 for entry in registry.iter() {
63 heap.push(HeapEntry {
64 next_execution: entry.value().next_execution.clone(),
65 task_id: *entry.key(),
66 });
67 }
68
69 loop {
70 let sleep_duration = heap.peek().map(|entry| {
71 let now = runtime.clock().instant();
72 if entry.next_execution > now {
73 &entry.next_execution - &now
74 } else {
75 time::Duration::ZERO
76 }
77 });
78
79 select! {
80
81 _ = async {
82 match sleep_duration {
83 Some(duration) => time::sleep(duration).await,
84 None => future::pending::<()>().await,
85 }
86 } => {
87
88 if let Some(heap_entry) = heap.pop()
89 && let Some(entry) = registry.get(&heap_entry.task_id)
90 {
91 let task = entry.task.clone();
92 let task_id = heap_entry.task_id;
93 let task_name = task.name.clone();
94
95 spawn_task(
96 task_id,
97 task,
98 runtime.clone(),
99 engine.clone(),
100 completion_tx.clone(),
101 );
102
103 debug!("Spawned task: {}", task_name);
104 }
105 }
106
107
108 Some((task_id, completed_at)) = completion_rx.recv() => {
109
110 if let Some(mut entry) = registry.get_mut(&task_id) {
111 if let Some(next_exec) = entry.task.schedule.next_execution(completed_at) {
112
113 entry.next_execution = next_exec.clone();
114
115
116 heap.push(HeapEntry {
117 next_execution: next_exec,
118 task_id,
119 });
120
121 debug!("Rescheduled task: {}", entry.task.name);
122 } else {
123
124 let task_name = entry.task.name.clone();
125 drop(entry);
126 registry.remove(&task_id);
127 debug!("Completed one-shot task: {}", task_name);
128 }
129 }
130 }
131
132
133 Some(msg) = rx.recv() => {
134 match msg {
135 TaskCoordinatorMessage::Register(task) => {
136 let task_id = task.id;
137 let next_execution = runtime.clock().instant() + task.schedule.initial_delay();
138
139 info!("Registering task: {} (id: {})", task.name, task_id);
140
141
142 registry.insert(task_id, TaskEntry {
143 task: Arc::new(task),
144 next_execution: next_execution.clone(),
145 });
146
147
148 heap.push(HeapEntry {
149 next_execution,
150 task_id,
151 });
152 }
153
154 TaskCoordinatorMessage::Unregister(task_id) => {
155 info!("Unregistering task: {}", task_id);
156
157
158 registry.remove(&task_id);
159
160
161 heap.clear();
162 for entry in registry.iter() {
163 heap.push(HeapEntry {
164 next_execution: entry.value().next_execution.clone(),
165 task_id: *entry.key(),
166 });
167 }
168 }
169
170 TaskCoordinatorMessage::Shutdown => {
171 info!("Task coordinator shutting down");
172 break;
173 }
174 TaskCoordinatorMessage::TaskCompleted{ .. } => {}}
175 }
176
177 else => {
178
179 info!("Coordinator channel closed, shutting down");
180 break;
181 }
182 }
183 }
184
185 info!("Task coordinator stopped");
186}
187
188fn spawn_task(
189 task_id: TaskId,
190 task: Arc<ScheduledTask>,
191 runtime: SharedRuntime,
192 engine: StandardEngine,
193 completion_tx: mpsc::UnboundedSender<(TaskId, Instant)>,
194) {
195 let task_name = task.name.clone();
196 let executor = task.executor;
197 let work = task.work.clone();
198 let runtime_clone = runtime.clone();
199
200 runtime.spawn(async move {
201 let runtime = runtime_clone;
202 let start = runtime.clock().instant();
203 let ctx = TaskContext::new(engine);
204
205 let result = match (&work, executor) {
206 (TaskWork::Sync(f), TaskExecutor::ComputePool) => {
207 let f = f.clone();
208 let ctx_clone = ctx.clone();
209 spawn_blocking(move || f(ctx_clone))
210 .await
211 .map_err(|e| Box::new(e) as Box<dyn Error + Send>)
212 .and_then(|r| r)
213 }
214 (TaskWork::Async(f), TaskExecutor::Tokio) => f(ctx).await,
215 (TaskWork::Sync(_), TaskExecutor::Tokio) => Err(Box::new(io::Error::new(
216 io::ErrorKind::InvalidInput,
217 "Sync work cannot be executed on Tokio executor",
218 )) as Box<dyn Error + Send>),
219 (TaskWork::Async(_), TaskExecutor::ComputePool) => Err(Box::new(io::Error::new(
220 io::ErrorKind::InvalidInput,
221 "Async work cannot be executed on ComputePool executor",
222 )) as Box<dyn Error + Send>),
223 };
224
225 let duration = start.elapsed();
226 let completed_at = runtime.clock().instant();
227
228 match result {
229 Ok(()) => {
230 debug!("Task '{}' completed successfully in {:?}", task_name, duration);
231 }
232 Err(e) => {
233 error!("Task '{}' failed after {:?}: {}", task_name, duration, e);
234 }
235 }
236
237 let _ = completion_tx.send((task_id, completed_at));
238 });
239}