Skip to main content

reifydb_sub_task/
coordinator.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2026 ReifyDB
3
4use std::{cmp::Ordering, collections::BinaryHeap, error::Error, future, io, sync::Arc};
5
6use reifydb_core::interface::catalog::task::TaskId;
7use reifydb_engine::engine::StandardEngine;
8use reifydb_runtime::{SharedRuntime, context::clock::Instant};
9use tokio::{select, sync::mpsc, task::spawn_blocking, time};
10use tracing::{debug, error, info};
11
12#[cfg(reifydb_assertions)]
13use crate::schedule::Schedule;
14use crate::{
15	context::TaskContext,
16	registry::{TaskEntry, TaskRegistry},
17	task::{ScheduledTask, TaskExecutor, TaskWork},
18};
19
20#[derive(Debug)]
21pub enum TaskCoordinatorMessage {
22	Register(ScheduledTask),
23
24	Unregister(TaskId),
25
26	TaskCompleted {
27		task_id: TaskId,
28		completed_at: Instant,
29	},
30
31	Shutdown,
32}
33
34#[derive(Debug, Clone, PartialEq, Eq)]
35struct HeapEntry {
36	next_execution: Instant,
37	task_id: TaskId,
38}
39
40impl Ord for HeapEntry {
41	fn cmp(&self, other: &Self) -> Ordering {
42		other.next_execution.cmp(&self.next_execution)
43	}
44}
45
46impl PartialOrd for HeapEntry {
47	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
48		Some(self.cmp(other))
49	}
50}
51
52pub async fn run_coordinator(
53	registry: TaskRegistry,
54	mut rx: mpsc::Receiver<TaskCoordinatorMessage>,
55	runtime: SharedRuntime,
56	engine: StandardEngine,
57) {
58	info!("Task coordinator started");
59
60	let (completion_tx, mut completion_rx) = mpsc::unbounded_channel();
61
62	let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::new();
63
64	for entry in registry.iter() {
65		heap.push(HeapEntry {
66			next_execution: entry.value().next_execution.clone(),
67			task_id: *entry.key(),
68		});
69	}
70
71	loop {
72		let sleep_duration = heap.peek().map(|entry| {
73			let now = runtime.clock().instant();
74			if entry.next_execution > now {
75				&entry.next_execution - &now
76			} else {
77				time::Duration::ZERO
78			}
79		});
80
81		select! {
82
83		    _ = async {
84			match sleep_duration {
85			    Some(duration) => time::sleep(duration).await,
86			    None => future::pending::<()>().await,
87			}
88		    } => {
89
90			if let Some(heap_entry) = heap.pop()
91			    && let Some(entry) = registry.get(&heap_entry.task_id)
92			{
93				let task = entry.task.clone();
94				let task_id = heap_entry.task_id;
95				let task_name = task.name.clone();
96
97				spawn_task(
98				    task_id,
99				    task,
100				    runtime.clone(),
101				    engine.clone(),
102				    completion_tx.clone(),
103				);
104
105				debug!("Spawned task: {}", task_name);
106			}
107		    }
108
109
110		    Some((task_id, completed_at)) = completion_rx.recv() => {
111
112			if let Some(mut entry) = registry.get_mut(&task_id) {
113			    if let Some(next_exec) = entry.task.schedule.next_execution(completed_at) {
114				#[cfg(reifydb_assertions)]
115				{
116					assert!(
117						!matches!(entry.task.schedule, Schedule::Once(_)),
118						"a Schedule::Once task entered the reschedule path after completing, so a one-shot task would run repeatedly and duplicate its side effects (task={})",
119						entry.task.name
120					);
121				}
122
123				entry.next_execution = next_exec.clone();
124
125
126				heap.push(HeapEntry {
127				    next_execution: next_exec,
128				    task_id,
129				});
130
131				debug!("Rescheduled task: {}", entry.task.name);
132			    } else {
133
134				let task_name = entry.task.name.clone();
135				drop(entry);
136				registry.remove(&task_id);
137				debug!("Completed one-shot task: {}", task_name);
138			    }
139			}
140		    }
141
142
143		    Some(msg) = rx.recv() => {
144			match msg {
145			    TaskCoordinatorMessage::Register(task) => {
146				let task_id = task.id;
147				let next_execution = runtime.clock().instant() + task.schedule.initial_delay();
148
149				info!("Registering task: {} (id: {})", task.name, task_id);
150
151
152				registry.insert(task_id, TaskEntry {
153				    task: Arc::new(task),
154				    next_execution: next_execution.clone(),
155				});
156
157
158				heap.push(HeapEntry {
159				    next_execution,
160				    task_id,
161				});
162			    }
163
164			    TaskCoordinatorMessage::Unregister(task_id) => {
165				info!("Unregistering task: {}", task_id);
166
167
168				registry.remove(&task_id);
169
170
171				heap.clear();
172				for entry in registry.iter() {
173				    heap.push(HeapEntry {
174					next_execution: entry.value().next_execution.clone(),
175					task_id: *entry.key(),
176				    });
177				}
178			    }
179
180			    TaskCoordinatorMessage::Shutdown => {
181				info!("Task coordinator shutting down");
182				break;
183			    }
184			TaskCoordinatorMessage::TaskCompleted{ .. } => {}}
185		    }
186
187		    else => {
188
189			info!("Coordinator channel closed, shutting down");
190			break;
191		    }
192		}
193	}
194
195	info!("Task coordinator stopped");
196}
197
198fn spawn_task(
199	task_id: TaskId,
200	task: Arc<ScheduledTask>,
201	runtime: SharedRuntime,
202	engine: StandardEngine,
203	completion_tx: mpsc::UnboundedSender<(TaskId, Instant)>,
204) {
205	let task_name = task.name.clone();
206	let executor = task.executor;
207	let work = task.work.clone();
208	let runtime_clone = runtime.clone();
209
210	runtime.spawn(async move {
211		let runtime = runtime_clone;
212		let start = runtime.clock().instant();
213		let ctx = TaskContext::new(engine);
214
215		let result = match (&work, executor) {
216			(TaskWork::Sync(f), TaskExecutor::ComputePool) => {
217				let f = f.clone();
218				let ctx_clone = ctx.clone();
219				spawn_blocking(move || f(ctx_clone))
220					.await
221					.map_err(|e| Box::new(e) as Box<dyn Error + Send>)
222					.and_then(|r| r)
223			}
224			(TaskWork::Async(f), TaskExecutor::Tokio) => f(ctx).await,
225			(TaskWork::Sync(_), TaskExecutor::Tokio) => Err(Box::new(io::Error::new(
226				io::ErrorKind::InvalidInput,
227				"Sync work cannot be executed on Tokio executor",
228			)) as Box<dyn Error + Send>),
229			(TaskWork::Async(_), TaskExecutor::ComputePool) => Err(Box::new(io::Error::new(
230				io::ErrorKind::InvalidInput,
231				"Async work cannot be executed on ComputePool executor",
232			)) as Box<dyn Error + Send>),
233		};
234
235		let duration = start.elapsed();
236		let completed_at = runtime.clock().instant();
237
238		match result {
239			Ok(()) => {
240				debug!("Task '{}' completed successfully in {:?}", task_name, duration);
241			}
242			Err(e) => {
243				error!("Task '{}' failed after {:?}: {}", task_name, duration, e);
244			}
245		}
246
247		let _ = completion_tx.send((task_id, completed_at));
248	});
249}