Skip to main content

reifydb_sub_task/
coordinator.rs

1use std::{cmp::Ordering, collections::BinaryHeap, error::Error, future, io, sync::Arc, time::Instant};
2
3use reifydb_engine::engine::StandardEngine;
4use reifydb_runtime::SharedRuntime;
5use tokio::{select, sync::mpsc, time};
6use tracing::{debug, error, info};
7
8use crate::{
9	context::TaskContext,
10	registry::{TaskEntry, TaskRegistry},
11	task::{ScheduledTask, TaskExecutor, TaskId, TaskWork},
12};
13
14/// Messages sent to the coordinator task
15#[derive(Debug)]
16pub enum CoordinatorMessage {
17	/// Register a new task
18	Register(ScheduledTask),
19	/// Unregister a task by ID
20	Unregister(TaskId),
21	/// A task has completed execution
22	TaskCompleted {
23		task_id: TaskId,
24		completed_at: Instant,
25	},
26	/// Request immediate shutdown
27	Shutdown,
28}
29
30/// Entry in the scheduling heap
31#[derive(Debug, Clone, PartialEq, Eq)]
32struct HeapEntry {
33	next_execution: Instant,
34	task_id: TaskId,
35}
36
37impl Ord for HeapEntry {
38	fn cmp(&self, other: &Self) -> Ordering {
39		// Reverse ordering to make BinaryHeap a min-heap
40		other.next_execution.cmp(&self.next_execution)
41	}
42}
43
44impl PartialOrd for HeapEntry {
45	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
46		Some(self.cmp(other))
47	}
48}
49
50/// Run the coordinator loop
51pub async fn run_coordinator(
52	registry: TaskRegistry,
53	mut rx: mpsc::Receiver<CoordinatorMessage>,
54	runtime: SharedRuntime,
55	engine: StandardEngine,
56) {
57	info!("Task coordinator started");
58
59	// Create a channel for task completion notifications
60	let (completion_tx, mut completion_rx) = mpsc::unbounded_channel();
61
62	// Min-heap of tasks ordered by next execution time
63	let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::new();
64
65	// Build initial heap from registry
66	for entry in registry.iter() {
67		heap.push(HeapEntry {
68			next_execution: entry.value().next_execution,
69			task_id: *entry.key(),
70		});
71	}
72
73	loop {
74		// Calculate sleep duration until next task
75		let sleep_duration = heap.peek().map(|entry| {
76			let now = Instant::now();
77			if entry.next_execution > now {
78				entry.next_execution - now
79			} else {
80				time::Duration::ZERO
81			}
82		});
83
84		select! {
85		    // Next task is due
86		    _ = async {
87			match sleep_duration {
88			    Some(duration) => time::sleep(duration).await,
89			    None => future::pending::<()>().await, // No tasks, wait forever
90			}
91		    } => {
92			// Pop the task from the heap
93			if let Some(heap_entry) = heap.pop() {
94			    // Get task from registry
95			    if let Some(entry) = registry.get(&heap_entry.task_id) {
96				let task = entry.task.clone();
97				let task_id = heap_entry.task_id;
98				let task_name = task.name.clone();
99
100				// Spawn task execution
101				spawn_task(
102				    task_id,
103				    task,
104				    runtime.clone(),
105				    engine.clone(),
106				    completion_tx.clone(),
107				);
108
109				debug!("Spawned task: {}", task_name);
110			    }
111			}
112		    }
113
114		    // Handle task completion notifications
115		    Some((task_id, completed_at)) = completion_rx.recv() => {
116			// Check if task should be rescheduled
117			if let Some(mut entry) = registry.get_mut(&task_id) {
118			    if let Some(next_exec) = entry.task.schedule.next_execution(completed_at) {
119				// Update next execution time
120				entry.next_execution = next_exec;
121
122				// Add back to heap
123				heap.push(HeapEntry {
124				    next_execution: next_exec,
125				    task_id,
126				});
127
128				debug!("Rescheduled task: {}", entry.task.name);
129			    } else {
130				// One-shot task, remove from registry
131				let task_name = entry.task.name.clone();
132				drop(entry); // Release the lock
133				registry.remove(&task_id);
134				debug!("Completed one-shot task: {}", task_name);
135			    }
136			}
137		    }
138
139		    // Handle coordinator messages
140		    Some(msg) = rx.recv() => {
141			match msg {
142			    CoordinatorMessage::Register(task) => {
143				let task_id = task.id;
144				let next_execution = Instant::now() + task.schedule.initial_delay();
145
146				info!("Registering task: {} (id: {})", task.name, task_id);
147
148				// Add to registry
149				registry.insert(task_id, TaskEntry {
150				    task: Arc::new(task),
151				    next_execution,
152				});
153
154				// Add to heap
155				heap.push(HeapEntry {
156				    next_execution,
157				    task_id,
158				});
159			    }
160
161			    CoordinatorMessage::Unregister(task_id) => {
162				info!("Unregistering task: {}", task_id);
163
164				// Remove from registry
165				registry.remove(&task_id);
166
167				// Rebuild heap (simplest approach for now)
168				heap.clear();
169				for entry in registry.iter() {
170				    heap.push(HeapEntry {
171					next_execution: entry.value().next_execution,
172					task_id: *entry.key(),
173				    });
174				}
175			    }
176
177			    CoordinatorMessage::Shutdown => {
178				info!("Task coordinator shutting down");
179				break;
180			    }
181			CoordinatorMessage::TaskCompleted{ .. } => {}}
182		    }
183
184		    else => {
185			// Channel closed, shutdown
186			info!("Coordinator channel closed, shutting down");
187			break;
188		    }
189		}
190	}
191
192	info!("Task coordinator stopped");
193}
194
195/// Spawn a task execution
196fn spawn_task(
197	task_id: TaskId,
198	task: Arc<ScheduledTask>,
199	runtime: SharedRuntime,
200	engine: StandardEngine,
201	completion_tx: mpsc::UnboundedSender<(TaskId, Instant)>,
202) {
203	let task_name = task.name.clone();
204	let executor = task.executor;
205	let work = task.work.clone();
206	let runtime_clone = runtime.clone();
207
208	runtime.spawn(async move {
209		let runtime = runtime_clone;
210		let start = Instant::now();
211		let ctx = TaskContext::new(engine);
212
213		// Execute the work
214		let result = match (&work, executor) {
215			(TaskWork::Sync(f), TaskExecutor::ComputePool) => {
216				let f = f.clone();
217				let ctx_clone = ctx.clone();
218				runtime.actor_system()
219					.compute(move || f(ctx_clone))
220					.await
221					.map_err(|e| Box::new(e) as Box<dyn Error + Send>)
222					.and_then(|r| r)
223			}
224			(TaskWork::Async(f), TaskExecutor::Tokio) => f(ctx).await,
225			(TaskWork::Sync(_), TaskExecutor::Tokio) => Err(Box::new(io::Error::new(
226				io::ErrorKind::InvalidInput,
227				"Sync work cannot be executed on Tokio executor",
228			)) as Box<dyn Error + Send>),
229			(TaskWork::Async(_), TaskExecutor::ComputePool) => Err(Box::new(io::Error::new(
230				io::ErrorKind::InvalidInput,
231				"Async work cannot be executed on ComputePool executor",
232			)) as Box<dyn Error + Send>),
233		};
234
235		let duration = start.elapsed();
236		let completed_at = Instant::now();
237
238		// Log result
239		match result {
240			Ok(()) => {
241				debug!("Task '{}' completed successfully in {:?}", task_name, duration);
242			}
243			Err(e) => {
244				error!("Task '{}' failed after {:?}: {}", task_name, duration, e);
245			}
246		}
247
248		// Send completion notification to coordinator
249		let _ = completion_tx.send((task_id, completed_at));
250	});
251}