Skip to main content

reifydb_sub_task/
coordinator.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{cmp::Ordering, collections::BinaryHeap, error::Error, future, io, sync::Arc};
5
6use reifydb_core::interface::catalog::task::TaskId;
7use reifydb_engine::engine::StandardEngine;
8use reifydb_runtime::{SharedRuntime, context::clock::Instant};
9use tokio::{select, sync::mpsc, task::spawn_blocking, time};
10use tracing::{debug, error, info};
11
12use crate::{
13	context::TaskContext,
14	registry::{TaskEntry, TaskRegistry},
15	task::{ScheduledTask, TaskExecutor, TaskWork},
16};
17
18/// Messages sent to the coordinator task
19#[derive(Debug)]
20pub enum TaskCoordinatorMessage {
21	/// Register a new task
22	Register(ScheduledTask),
23	/// Unregister a task by ID
24	Unregister(TaskId),
25	/// A task has completed execution
26	TaskCompleted {
27		task_id: TaskId,
28		completed_at: Instant,
29	},
30	/// Request immediate shutdown
31	Shutdown,
32}
33
34/// Entry in the scheduling heap
35#[derive(Debug, Clone, PartialEq, Eq)]
36struct HeapEntry {
37	next_execution: Instant,
38	task_id: TaskId,
39}
40
41impl Ord for HeapEntry {
42	fn cmp(&self, other: &Self) -> Ordering {
43		// Reverse ordering to make BinaryHeap a min-heap
44		other.next_execution.cmp(&self.next_execution)
45	}
46}
47
48impl PartialOrd for HeapEntry {
49	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50		Some(self.cmp(other))
51	}
52}
53
54/// Run the coordinator loop
55pub async fn run_coordinator(
56	registry: TaskRegistry,
57	mut rx: mpsc::Receiver<TaskCoordinatorMessage>,
58	runtime: SharedRuntime,
59	engine: StandardEngine,
60) {
61	info!("Task coordinator started");
62
63	// Create a channel for task completion notifications
64	let (completion_tx, mut completion_rx) = mpsc::unbounded_channel();
65
66	// Min-heap of tasks ordered by next execution time
67	let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::new();
68
69	// Build initial heap from registry
70	for entry in registry.iter() {
71		heap.push(HeapEntry {
72			next_execution: entry.value().next_execution.clone(),
73			task_id: *entry.key(),
74		});
75	}
76
77	loop {
78		// Calculate sleep duration until next task
79		let sleep_duration = heap.peek().map(|entry| {
80			let now = runtime.clock().instant();
81			if entry.next_execution > now {
82				&entry.next_execution - &now
83			} else {
84				time::Duration::ZERO
85			}
86		});
87
88		select! {
89		    // Next task is due
90		    _ = async {
91			match sleep_duration {
92			    Some(duration) => time::sleep(duration).await,
93			    None => future::pending::<()>().await, // No tasks, wait forever
94			}
95		    } => {
96			// Pop the task from the heap
97			if let Some(heap_entry) = heap.pop() {
98			    // Get task from registry
99			    if let Some(entry) = registry.get(&heap_entry.task_id) {
100				let task = entry.task.clone();
101				let task_id = heap_entry.task_id;
102				let task_name = task.name.clone();
103
104				// Spawn task execution
105				spawn_task(
106				    task_id,
107				    task,
108				    runtime.clone(),
109				    engine.clone(),
110				    completion_tx.clone(),
111				);
112
113				debug!("Spawned task: {}", task_name);
114			    }
115			}
116		    }
117
118		    // Handle task completion notifications
119		    Some((task_id, completed_at)) = completion_rx.recv() => {
120			// Check if task should be rescheduled
121			if let Some(mut entry) = registry.get_mut(&task_id) {
122			    if let Some(next_exec) = entry.task.schedule.next_execution(completed_at) {
123				// Update next execution time
124				entry.next_execution = next_exec.clone();
125
126				// Add back to heap
127				heap.push(HeapEntry {
128				    next_execution: next_exec,
129				    task_id,
130				});
131
132				debug!("Rescheduled task: {}", entry.task.name);
133			    } else {
134				// One-shot task, remove from registry
135				let task_name = entry.task.name.clone();
136				drop(entry); // Release the lock
137				registry.remove(&task_id);
138				debug!("Completed one-shot task: {}", task_name);
139			    }
140			}
141		    }
142
143		    // Handle coordinator messages
144		    Some(msg) = rx.recv() => {
145			match msg {
146			    TaskCoordinatorMessage::Register(task) => {
147				let task_id = task.id;
148				let next_execution = runtime.clock().instant() + task.schedule.initial_delay();
149
150				info!("Registering task: {} (id: {})", task.name, task_id);
151
152				// Add to registry
153				registry.insert(task_id, TaskEntry {
154				    task: Arc::new(task),
155				    next_execution: next_execution.clone(),
156				});
157
158				// Add to heap
159				heap.push(HeapEntry {
160				    next_execution,
161				    task_id,
162				});
163			    }
164
165			    TaskCoordinatorMessage::Unregister(task_id) => {
166				info!("Unregistering task: {}", task_id);
167
168				// Remove from registry
169				registry.remove(&task_id);
170
171				// Rebuild heap (simplest approach for now)
172				heap.clear();
173				for entry in registry.iter() {
174				    heap.push(HeapEntry {
175					next_execution: entry.value().next_execution.clone(),
176					task_id: *entry.key(),
177				    });
178				}
179			    }
180
181			    TaskCoordinatorMessage::Shutdown => {
182				info!("Task coordinator shutting down");
183				break;
184			    }
185			TaskCoordinatorMessage::TaskCompleted{ .. } => {}}
186		    }
187
188		    else => {
189			// Channel closed, shutdown
190			info!("Coordinator channel closed, shutting down");
191			break;
192		    }
193		}
194	}
195
196	info!("Task coordinator stopped");
197}
198
199/// Spawn a task execution
200fn spawn_task(
201	task_id: TaskId,
202	task: Arc<ScheduledTask>,
203	runtime: SharedRuntime,
204	engine: StandardEngine,
205	completion_tx: mpsc::UnboundedSender<(TaskId, Instant)>,
206) {
207	let task_name = task.name.clone();
208	let executor = task.executor;
209	let work = task.work.clone();
210	let runtime_clone = runtime.clone();
211
212	runtime.spawn(async move {
213		let runtime = runtime_clone;
214		let start = runtime.clock().instant();
215		let ctx = TaskContext::new(engine);
216
217		// Execute the work
218		let result = match (&work, executor) {
219			(TaskWork::Sync(f), TaskExecutor::ComputePool) => {
220				let f = f.clone();
221				let ctx_clone = ctx.clone();
222				spawn_blocking(move || f(ctx_clone))
223					.await
224					.map_err(|e| Box::new(e) as Box<dyn Error + Send>)
225					.and_then(|r| r)
226			}
227			(TaskWork::Async(f), TaskExecutor::Tokio) => f(ctx).await,
228			(TaskWork::Sync(_), TaskExecutor::Tokio) => Err(Box::new(io::Error::new(
229				io::ErrorKind::InvalidInput,
230				"Sync work cannot be executed on Tokio executor",
231			)) as Box<dyn Error + Send>),
232			(TaskWork::Async(_), TaskExecutor::ComputePool) => Err(Box::new(io::Error::new(
233				io::ErrorKind::InvalidInput,
234				"Async work cannot be executed on ComputePool executor",
235			)) as Box<dyn Error + Send>),
236		};
237
238		let duration = start.elapsed();
239		let completed_at = runtime.clock().instant();
240
241		// Log result
242		match result {
243			Ok(()) => {
244				debug!("Task '{}' completed successfully in {:?}", task_name, duration);
245			}
246			Err(e) => {
247				error!("Task '{}' failed after {:?}: {}", task_name, duration, e);
248			}
249		}
250
251		// Send completion notification to coordinator
252		let _ = completion_tx.send((task_id, completed_at));
253	});
254}