Skip to main content

reifydb_sub_task/
coordinator.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{cmp::Ordering, collections::BinaryHeap, error::Error, future, io, sync::Arc, time::Instant};
5
6use reifydb_engine::engine::StandardEngine;
7use reifydb_runtime::SharedRuntime;
8use tokio::{select, sync::mpsc, time};
9use tracing::{debug, error, info};
10
11use crate::{
12	context::TaskContext,
13	registry::{TaskEntry, TaskRegistry},
14	task::{ScheduledTask, TaskExecutor, TaskId, TaskWork},
15};
16
17/// Messages sent to the coordinator task
18#[derive(Debug)]
19pub enum CoordinatorMessage {
20	/// Register a new task
21	Register(ScheduledTask),
22	/// Unregister a task by ID
23	Unregister(TaskId),
24	/// A task has completed execution
25	TaskCompleted {
26		task_id: TaskId,
27		completed_at: Instant,
28	},
29	/// Request immediate shutdown
30	Shutdown,
31}
32
33/// Entry in the scheduling heap
34#[derive(Debug, Clone, PartialEq, Eq)]
35struct HeapEntry {
36	next_execution: Instant,
37	task_id: TaskId,
38}
39
40impl Ord for HeapEntry {
41	fn cmp(&self, other: &Self) -> Ordering {
42		// Reverse ordering to make BinaryHeap a min-heap
43		other.next_execution.cmp(&self.next_execution)
44	}
45}
46
47impl PartialOrd for HeapEntry {
48	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
49		Some(self.cmp(other))
50	}
51}
52
53/// Run the coordinator loop
54pub async fn run_coordinator(
55	registry: TaskRegistry,
56	mut rx: mpsc::Receiver<CoordinatorMessage>,
57	runtime: SharedRuntime,
58	engine: StandardEngine,
59) {
60	info!("Task coordinator started");
61
62	// Create a channel for task completion notifications
63	let (completion_tx, mut completion_rx) = mpsc::unbounded_channel();
64
65	// Min-heap of tasks ordered by next execution time
66	let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::new();
67
68	// Build initial heap from registry
69	for entry in registry.iter() {
70		heap.push(HeapEntry {
71			next_execution: entry.value().next_execution,
72			task_id: *entry.key(),
73		});
74	}
75
76	loop {
77		// Calculate sleep duration until next task
78		let sleep_duration = heap.peek().map(|entry| {
79			let now = Instant::now();
80			if entry.next_execution > now {
81				entry.next_execution - now
82			} else {
83				time::Duration::ZERO
84			}
85		});
86
87		select! {
88		    // Next task is due
89		    _ = async {
90			match sleep_duration {
91			    Some(duration) => time::sleep(duration).await,
92			    None => future::pending::<()>().await, // No tasks, wait forever
93			}
94		    } => {
95			// Pop the task from the heap
96			if let Some(heap_entry) = heap.pop() {
97			    // Get task from registry
98			    if let Some(entry) = registry.get(&heap_entry.task_id) {
99				let task = entry.task.clone();
100				let task_id = heap_entry.task_id;
101				let task_name = task.name.clone();
102
103				// Spawn task execution
104				spawn_task(
105				    task_id,
106				    task,
107				    runtime.clone(),
108				    engine.clone(),
109				    completion_tx.clone(),
110				);
111
112				debug!("Spawned task: {}", task_name);
113			    }
114			}
115		    }
116
117		    // Handle task completion notifications
118		    Some((task_id, completed_at)) = completion_rx.recv() => {
119			// Check if task should be rescheduled
120			if let Some(mut entry) = registry.get_mut(&task_id) {
121			    if let Some(next_exec) = entry.task.schedule.next_execution(completed_at) {
122				// Update next execution time
123				entry.next_execution = next_exec;
124
125				// Add back to heap
126				heap.push(HeapEntry {
127				    next_execution: next_exec,
128				    task_id,
129				});
130
131				debug!("Rescheduled task: {}", entry.task.name);
132			    } else {
133				// One-shot task, remove from registry
134				let task_name = entry.task.name.clone();
135				drop(entry); // Release the lock
136				registry.remove(&task_id);
137				debug!("Completed one-shot task: {}", task_name);
138			    }
139			}
140		    }
141
142		    // Handle coordinator messages
143		    Some(msg) = rx.recv() => {
144			match msg {
145			    CoordinatorMessage::Register(task) => {
146				let task_id = task.id;
147				let next_execution = Instant::now() + task.schedule.initial_delay();
148
149				info!("Registering task: {} (id: {})", task.name, task_id);
150
151				// Add to registry
152				registry.insert(task_id, TaskEntry {
153				    task: Arc::new(task),
154				    next_execution,
155				});
156
157				// Add to heap
158				heap.push(HeapEntry {
159				    next_execution,
160				    task_id,
161				});
162			    }
163
164			    CoordinatorMessage::Unregister(task_id) => {
165				info!("Unregistering task: {}", task_id);
166
167				// Remove from registry
168				registry.remove(&task_id);
169
170				// Rebuild heap (simplest approach for now)
171				heap.clear();
172				for entry in registry.iter() {
173				    heap.push(HeapEntry {
174					next_execution: entry.value().next_execution,
175					task_id: *entry.key(),
176				    });
177				}
178			    }
179
180			    CoordinatorMessage::Shutdown => {
181				info!("Task coordinator shutting down");
182				break;
183			    }
184			CoordinatorMessage::TaskCompleted{ .. } => {}}
185		    }
186
187		    else => {
188			// Channel closed, shutdown
189			info!("Coordinator channel closed, shutting down");
190			break;
191		    }
192		}
193	}
194
195	info!("Task coordinator stopped");
196}
197
198/// Spawn a task execution
199fn spawn_task(
200	task_id: TaskId,
201	task: Arc<ScheduledTask>,
202	runtime: SharedRuntime,
203	engine: StandardEngine,
204	completion_tx: mpsc::UnboundedSender<(TaskId, Instant)>,
205) {
206	let task_name = task.name.clone();
207	let executor = task.executor;
208	let work = task.work.clone();
209	let runtime_clone = runtime.clone();
210
211	runtime.spawn(async move {
212		let runtime = runtime_clone;
213		let start = Instant::now();
214		let ctx = TaskContext::new(engine);
215
216		// Execute the work
217		let result = match (&work, executor) {
218			(TaskWork::Sync(f), TaskExecutor::ComputePool) => {
219				let f = f.clone();
220				let ctx_clone = ctx.clone();
221				runtime.actor_system()
222					.compute(move || f(ctx_clone))
223					.await
224					.map_err(|e| Box::new(e) as Box<dyn Error + Send>)
225					.and_then(|r| r)
226			}
227			(TaskWork::Async(f), TaskExecutor::Tokio) => f(ctx).await,
228			(TaskWork::Sync(_), TaskExecutor::Tokio) => Err(Box::new(io::Error::new(
229				io::ErrorKind::InvalidInput,
230				"Sync work cannot be executed on Tokio executor",
231			)) as Box<dyn Error + Send>),
232			(TaskWork::Async(_), TaskExecutor::ComputePool) => Err(Box::new(io::Error::new(
233				io::ErrorKind::InvalidInput,
234				"Async work cannot be executed on ComputePool executor",
235			)) as Box<dyn Error + Send>),
236		};
237
238		let duration = start.elapsed();
239		let completed_at = Instant::now();
240
241		// Log result
242		match result {
243			Ok(()) => {
244				debug!("Task '{}' completed successfully in {:?}", task_name, duration);
245			}
246			Err(e) => {
247				error!("Task '{}' failed after {:?}: {}", task_name, duration, e);
248			}
249		}
250
251		// Send completion notification to coordinator
252		let _ = completion_tx.send((task_id, completed_at));
253	});
254}