reifydb_sub_task/
coordinator.rs1use std::{cmp::Ordering, collections::BinaryHeap, error::Error, future, io, sync::Arc, time::Instant};
5
6use reifydb_engine::engine::StandardEngine;
7use reifydb_runtime::SharedRuntime;
8use tokio::{select, sync::mpsc, time};
9use tracing::{debug, error, info};
10
11use crate::{
12 context::TaskContext,
13 registry::{TaskEntry, TaskRegistry},
14 task::{ScheduledTask, TaskExecutor, TaskId, TaskWork},
15};
16
17#[derive(Debug)]
19pub enum CoordinatorMessage {
20 Register(ScheduledTask),
22 Unregister(TaskId),
24 TaskCompleted {
26 task_id: TaskId,
27 completed_at: Instant,
28 },
29 Shutdown,
31}
32
33#[derive(Debug, Clone, PartialEq, Eq)]
35struct HeapEntry {
36 next_execution: Instant,
37 task_id: TaskId,
38}
39
40impl Ord for HeapEntry {
41 fn cmp(&self, other: &Self) -> Ordering {
42 other.next_execution.cmp(&self.next_execution)
44 }
45}
46
47impl PartialOrd for HeapEntry {
48 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
49 Some(self.cmp(other))
50 }
51}
52
53pub async fn run_coordinator(
55 registry: TaskRegistry,
56 mut rx: mpsc::Receiver<CoordinatorMessage>,
57 runtime: SharedRuntime,
58 engine: StandardEngine,
59) {
60 info!("Task coordinator started");
61
62 let (completion_tx, mut completion_rx) = mpsc::unbounded_channel();
64
65 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::new();
67
68 for entry in registry.iter() {
70 heap.push(HeapEntry {
71 next_execution: entry.value().next_execution,
72 task_id: *entry.key(),
73 });
74 }
75
76 loop {
77 let sleep_duration = heap.peek().map(|entry| {
79 let now = Instant::now();
80 if entry.next_execution > now {
81 entry.next_execution - now
82 } else {
83 time::Duration::ZERO
84 }
85 });
86
87 select! {
88 _ = async {
90 match sleep_duration {
91 Some(duration) => time::sleep(duration).await,
92 None => future::pending::<()>().await, }
94 } => {
95 if let Some(heap_entry) = heap.pop() {
97 if let Some(entry) = registry.get(&heap_entry.task_id) {
99 let task = entry.task.clone();
100 let task_id = heap_entry.task_id;
101 let task_name = task.name.clone();
102
103 spawn_task(
105 task_id,
106 task,
107 runtime.clone(),
108 engine.clone(),
109 completion_tx.clone(),
110 );
111
112 debug!("Spawned task: {}", task_name);
113 }
114 }
115 }
116
117 Some((task_id, completed_at)) = completion_rx.recv() => {
119 if let Some(mut entry) = registry.get_mut(&task_id) {
121 if let Some(next_exec) = entry.task.schedule.next_execution(completed_at) {
122 entry.next_execution = next_exec;
124
125 heap.push(HeapEntry {
127 next_execution: next_exec,
128 task_id,
129 });
130
131 debug!("Rescheduled task: {}", entry.task.name);
132 } else {
133 let task_name = entry.task.name.clone();
135 drop(entry); registry.remove(&task_id);
137 debug!("Completed one-shot task: {}", task_name);
138 }
139 }
140 }
141
142 Some(msg) = rx.recv() => {
144 match msg {
145 CoordinatorMessage::Register(task) => {
146 let task_id = task.id;
147 let next_execution = Instant::now() + task.schedule.initial_delay();
148
149 info!("Registering task: {} (id: {})", task.name, task_id);
150
151 registry.insert(task_id, TaskEntry {
153 task: Arc::new(task),
154 next_execution,
155 });
156
157 heap.push(HeapEntry {
159 next_execution,
160 task_id,
161 });
162 }
163
164 CoordinatorMessage::Unregister(task_id) => {
165 info!("Unregistering task: {}", task_id);
166
167 registry.remove(&task_id);
169
170 heap.clear();
172 for entry in registry.iter() {
173 heap.push(HeapEntry {
174 next_execution: entry.value().next_execution,
175 task_id: *entry.key(),
176 });
177 }
178 }
179
180 CoordinatorMessage::Shutdown => {
181 info!("Task coordinator shutting down");
182 break;
183 }
184 CoordinatorMessage::TaskCompleted{ .. } => {}}
185 }
186
187 else => {
188 info!("Coordinator channel closed, shutting down");
190 break;
191 }
192 }
193 }
194
195 info!("Task coordinator stopped");
196}
197
198fn spawn_task(
200 task_id: TaskId,
201 task: Arc<ScheduledTask>,
202 runtime: SharedRuntime,
203 engine: StandardEngine,
204 completion_tx: mpsc::UnboundedSender<(TaskId, Instant)>,
205) {
206 let task_name = task.name.clone();
207 let executor = task.executor;
208 let work = task.work.clone();
209 let runtime_clone = runtime.clone();
210
211 runtime.spawn(async move {
212 let runtime = runtime_clone;
213 let start = Instant::now();
214 let ctx = TaskContext::new(engine);
215
216 let result = match (&work, executor) {
218 (TaskWork::Sync(f), TaskExecutor::ComputePool) => {
219 let f = f.clone();
220 let ctx_clone = ctx.clone();
221 runtime.actor_system()
222 .compute(move || f(ctx_clone))
223 .await
224 .map_err(|e| Box::new(e) as Box<dyn Error + Send>)
225 .and_then(|r| r)
226 }
227 (TaskWork::Async(f), TaskExecutor::Tokio) => f(ctx).await,
228 (TaskWork::Sync(_), TaskExecutor::Tokio) => Err(Box::new(io::Error::new(
229 io::ErrorKind::InvalidInput,
230 "Sync work cannot be executed on Tokio executor",
231 )) as Box<dyn Error + Send>),
232 (TaskWork::Async(_), TaskExecutor::ComputePool) => Err(Box::new(io::Error::new(
233 io::ErrorKind::InvalidInput,
234 "Async work cannot be executed on ComputePool executor",
235 )) as Box<dyn Error + Send>),
236 };
237
238 let duration = start.elapsed();
239 let completed_at = Instant::now();
240
241 match result {
243 Ok(()) => {
244 debug!("Task '{}' completed successfully in {:?}", task_name, duration);
245 }
246 Err(e) => {
247 error!("Task '{}' failed after {:?}: {}", task_name, duration, e);
248 }
249 }
250
251 let _ = completion_tx.send((task_id, completed_at));
253 });
254}