Skip to main content

reifydb_sub_task/
coordinator.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{cmp::Ordering, collections::BinaryHeap, error::Error, future, io, sync::Arc};
5
6use reifydb_core::interface::catalog::task::TaskId;
7use reifydb_engine::engine::StandardEngine;
8use reifydb_runtime::{SharedRuntime, context::clock::Instant};
9use tokio::{select, sync::mpsc, task::spawn_blocking, time};
10use tracing::{debug, error, info};
11
12use crate::{
13	context::TaskContext,
14	registry::{TaskEntry, TaskRegistry},
15	task::{ScheduledTask, TaskExecutor, TaskWork},
16};
17
/// Control messages accepted by the task coordinator loop.
#[derive(Debug)]
pub enum TaskCoordinatorMessage {
	/// Adds a task to the registry and schedules its first execution.
	Register(ScheduledTask),

	/// Removes the task with the given id from the registry and from the schedule.
	Unregister(TaskId),

	/// Carries the id of a task and the instant at which it completed.
	///
	/// NOTE(review): `run_coordinator` in this file matches this variant with an
	/// empty arm; completions reach it over a separate internal channel instead.
	TaskCompleted {
		task_id: TaskId,
		completed_at: Instant,
	},

	/// Asks the coordinator loop to stop.
	Shutdown,
}
31
/// A pending execution tracked in the coordinator's scheduling heap.
///
/// NOTE(review): the derived `PartialEq`/`Eq` compare both fields, whereas the
/// `Ord` impl in this file looks only at `next_execution` - two entries for
/// different tasks due at the same instant compare as `Equal` but are not `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct HeapEntry {
	/// Instant at which the task is due to run.
	next_execution: Instant,
	/// Identifier of the task this entry belongs to.
	task_id: TaskId,
}
37
38impl Ord for HeapEntry {
39	fn cmp(&self, other: &Self) -> Ordering {
40		other.next_execution.cmp(&self.next_execution)
41	}
42}
43
44impl PartialOrd for HeapEntry {
45	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
46		Some(self.cmp(other))
47	}
48}
49
50pub async fn run_coordinator(
51	registry: TaskRegistry,
52	mut rx: mpsc::Receiver<TaskCoordinatorMessage>,
53	runtime: SharedRuntime,
54	engine: StandardEngine,
55) {
56	info!("Task coordinator started");
57
58	let (completion_tx, mut completion_rx) = mpsc::unbounded_channel();
59
60	let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::new();
61
62	for entry in registry.iter() {
63		heap.push(HeapEntry {
64			next_execution: entry.value().next_execution.clone(),
65			task_id: *entry.key(),
66		});
67	}
68
69	loop {
70		let sleep_duration = heap.peek().map(|entry| {
71			let now = runtime.clock().instant();
72			if entry.next_execution > now {
73				&entry.next_execution - &now
74			} else {
75				time::Duration::ZERO
76			}
77		});
78
79		select! {
80
81		    _ = async {
82			match sleep_duration {
83			    Some(duration) => time::sleep(duration).await,
84			    None => future::pending::<()>().await,
85			}
86		    } => {
87
88			if let Some(heap_entry) = heap.pop()
89			    && let Some(entry) = registry.get(&heap_entry.task_id)
90			{
91				let task = entry.task.clone();
92				let task_id = heap_entry.task_id;
93				let task_name = task.name.clone();
94
95				spawn_task(
96				    task_id,
97				    task,
98				    runtime.clone(),
99				    engine.clone(),
100				    completion_tx.clone(),
101				);
102
103				debug!("Spawned task: {}", task_name);
104			}
105		    }
106
107
108		    Some((task_id, completed_at)) = completion_rx.recv() => {
109
110			if let Some(mut entry) = registry.get_mut(&task_id) {
111			    if let Some(next_exec) = entry.task.schedule.next_execution(completed_at) {
112
113				entry.next_execution = next_exec.clone();
114
115
116				heap.push(HeapEntry {
117				    next_execution: next_exec,
118				    task_id,
119				});
120
121				debug!("Rescheduled task: {}", entry.task.name);
122			    } else {
123
124				let task_name = entry.task.name.clone();
125				drop(entry);
126				registry.remove(&task_id);
127				debug!("Completed one-shot task: {}", task_name);
128			    }
129			}
130		    }
131
132
133		    Some(msg) = rx.recv() => {
134			match msg {
135			    TaskCoordinatorMessage::Register(task) => {
136				let task_id = task.id;
137				let next_execution = runtime.clock().instant() + task.schedule.initial_delay();
138
139				info!("Registering task: {} (id: {})", task.name, task_id);
140
141
142				registry.insert(task_id, TaskEntry {
143				    task: Arc::new(task),
144				    next_execution: next_execution.clone(),
145				});
146
147
148				heap.push(HeapEntry {
149				    next_execution,
150				    task_id,
151				});
152			    }
153
154			    TaskCoordinatorMessage::Unregister(task_id) => {
155				info!("Unregistering task: {}", task_id);
156
157
158				registry.remove(&task_id);
159
160
161				heap.clear();
162				for entry in registry.iter() {
163				    heap.push(HeapEntry {
164					next_execution: entry.value().next_execution.clone(),
165					task_id: *entry.key(),
166				    });
167				}
168			    }
169
170			    TaskCoordinatorMessage::Shutdown => {
171				info!("Task coordinator shutting down");
172				break;
173			    }
174			TaskCoordinatorMessage::TaskCompleted{ .. } => {}}
175		    }
176
177		    else => {
178
179			info!("Coordinator channel closed, shutting down");
180			break;
181		    }
182		}
183	}
184
185	info!("Task coordinator stopped");
186}
187
188fn spawn_task(
189	task_id: TaskId,
190	task: Arc<ScheduledTask>,
191	runtime: SharedRuntime,
192	engine: StandardEngine,
193	completion_tx: mpsc::UnboundedSender<(TaskId, Instant)>,
194) {
195	let task_name = task.name.clone();
196	let executor = task.executor;
197	let work = task.work.clone();
198	let runtime_clone = runtime.clone();
199
200	runtime.spawn(async move {
201		let runtime = runtime_clone;
202		let start = runtime.clock().instant();
203		let ctx = TaskContext::new(engine);
204
205		let result = match (&work, executor) {
206			(TaskWork::Sync(f), TaskExecutor::ComputePool) => {
207				let f = f.clone();
208				let ctx_clone = ctx.clone();
209				spawn_blocking(move || f(ctx_clone))
210					.await
211					.map_err(|e| Box::new(e) as Box<dyn Error + Send>)
212					.and_then(|r| r)
213			}
214			(TaskWork::Async(f), TaskExecutor::Tokio) => f(ctx).await,
215			(TaskWork::Sync(_), TaskExecutor::Tokio) => Err(Box::new(io::Error::new(
216				io::ErrorKind::InvalidInput,
217				"Sync work cannot be executed on Tokio executor",
218			)) as Box<dyn Error + Send>),
219			(TaskWork::Async(_), TaskExecutor::ComputePool) => Err(Box::new(io::Error::new(
220				io::ErrorKind::InvalidInput,
221				"Async work cannot be executed on ComputePool executor",
222			)) as Box<dyn Error + Send>),
223		};
224
225		let duration = start.elapsed();
226		let completed_at = runtime.clock().instant();
227
228		match result {
229			Ok(()) => {
230				debug!("Task '{}' completed successfully in {:?}", task_name, duration);
231			}
232			Err(e) => {
233				error!("Task '{}' failed after {:?}: {}", task_name, duration, e);
234			}
235		}
236
237		let _ = completion_tx.send((task_id, completed_at));
238	});
239}