svod_runtime/executor.rs
1//! Unified parallel execution for heterogeneous devices.
2//!
3//! The `UnifiedExecutor` handles kernel execution across any mix of devices
4//! (CPU, CUDA, Metal, etc.) with proper synchronization and dependency tracking.
5//!
6//! # Design Principles
7//!
8//! 1. **Single abstraction** - One executor handles any device mix
9//! 2. **Device-agnostic sync** - Timeline signals abstract over device-specific primitives
10//! 3. **Zero overhead for single-device** - Fast path skips synchronization when possible
11//! 4. **Buffer dependency tracking** - Following Tinygrad's `_access_resources()` pattern
12//!
13//! # Example
14//!
15//! ```ignore
//! let mut executor = UnifiedExecutor::new(svod_device::registry::registry());
17//! executor.add_device(DeviceSpec::Cpu)?;
18//!
19//! // Execute schedule - handles dependencies automatically
20//! let output_id = executor.execute(&schedule)?;
21//! ```
22//!
23//! # Execution Graph
24//!
25//! For complex schedules with multiple devices, the executor builds an execution
26//! graph (DAG) where nodes are kernel operations and edges are buffer dependencies.
27//! Independent kernels on the same device can be batched, and kernels on different
28//! devices can run in parallel (with appropriate synchronization).
29
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32use std::sync::atomic::{AtomicU64, Ordering};
33
34use snafu::ResultExt;
35use svod_device::device::Device;
36use svod_device::registry::DeviceRegistry;
37use svod_device::{Allocator, Buffer, BufferId, CpuTimelineSignal, TimelineSignal};
38use svod_dtype::DeviceSpec;
39
40use crate::error::Result;
41
42/// Per-device execution context.
43///
44/// Each device has its own timeline signal, queue, and allocator.
45/// This enables parallel execution across devices with proper synchronization.
46pub struct DeviceContext {
47 /// Device specification (CPU, CUDA:0, etc.).
48 pub device: DeviceSpec,
49 /// Device abstraction for rendering/compiling/executing.
50 pub device_handle: Arc<Device>,
51 /// Timeline signal for this device's operations.
52 pub signal: Arc<dyn TimelineSignal>,
53 /// Current timeline value (monotonically increasing).
54 pub timeline: AtomicU64,
55 /// Allocator for this device.
56 pub allocator: Arc<dyn Allocator>,
57}
58
59impl std::fmt::Debug for DeviceContext {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("DeviceContext")
62 .field("device", &self.device)
63 .field("timeline", &self.timeline.load(Ordering::Relaxed))
64 .finish()
65 }
66}
67
68impl DeviceContext {
69 /// Create a new device context.
70 pub fn new(device: Arc<Device>, signal: Arc<dyn TimelineSignal>) -> Self {
71 let allocator = device.allocator.clone();
72 let device_spec = device.device.clone();
73 Self { device: device_spec, device_handle: device, signal, timeline: AtomicU64::new(0), allocator }
74 }
75
76 /// Get the next timeline value and increment.
77 pub fn next_timeline(&self) -> u64 {
78 self.timeline.fetch_add(1, Ordering::Relaxed) + 1
79 }
80
81 /// Get the current timeline value.
82 pub fn current_timeline(&self) -> u64 {
83 self.timeline.load(Ordering::Relaxed)
84 }
85
86 /// Signal that operations up to the given timeline value are complete.
87 pub fn signal_completion(&self, value: u64) {
88 self.signal.set(value);
89 }
90
91 /// Wait for operations up to the given timeline value to complete.
92 pub fn wait_for(&self, value: u64) -> Result<()> {
93 self.signal.wait(value, 0).context(crate::error::DeviceSnafu)?;
94 Ok(())
95 }
96}
97
/// How a producer on one device must be synchronized with a consumer on another.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncStrategy {
    /// Producer and consumer share a device; its operations are already ordered.
    None,
    /// Two instances of the same device type (e.g. CUDA:0 → CUDA:1);
    /// peer-to-peer events can be used where available.
    PeerToPeer,
    /// Different device types (e.g. CUDA → CPU); the host polls and mediates.
    CpuMediated,
}
110
111/// A node in the execution graph representing a kernel or transfer operation.
112#[derive(Debug, Clone)]
113pub struct ExecutionNode {
114 /// Unique identifier for this node (typically the kernel AST ID).
115 pub id: u64,
116 /// Device this operation executes on.
117 pub device: DeviceSpec,
118 /// Buffer IDs read by this operation.
119 pub inputs: Vec<BufferId>,
120 /// Buffer IDs written by this operation.
121 pub outputs: Vec<BufferId>,
122 /// IDs of nodes that must complete before this one (dependencies).
123 pub predecessors: Vec<u64>,
124 /// Whether this is a data transfer (COPY) or a computational kernel.
125 pub is_transfer: bool,
126 /// Buffer access information for parallel execution.
127 /// Contains the full buffer list and output indices for dependency tracking.
128 pub buffer_access: Option<KernelBufferAccess>,
129}
130
131/// Buffer access information for parallel kernel execution.
132///
133/// This struct captures which buffers a kernel accesses and which are outputs,
134/// enabling precise dependency tracking in `execute_parallel_group`.
135#[derive(Debug, Clone)]
136pub struct KernelBufferAccess {
137 /// All buffer IDs accessed by this kernel (inputs and outputs).
138 pub buffers: Vec<BufferId>,
139 /// Indices into `buffers` that are outputs (written by the kernel).
140 /// Other indices are inputs (read-only).
141 pub output_indices: Vec<usize>,
142}
143
144/// Execution graph representing a DAG of kernel operations.
145///
146/// The graph is built from a schedule and captures buffer dependencies
147/// between kernels. Independent kernels can be executed in parallel.
148#[derive(Debug, Default)]
149pub struct ExecutionGraph {
150 /// Nodes in the graph, indexed by ID.
151 nodes: HashMap<u64, ExecutionNode>,
152 /// Execution order (topologically sorted node IDs).
153 execution_order: Vec<u64>,
154 /// Nodes grouped by device for batched execution.
155 device_groups: HashMap<DeviceSpec, Vec<u64>>,
156}
157
158impl ExecutionGraph {
159 /// Create a new empty execution graph.
160 pub fn new() -> Self {
161 Self::default()
162 }
163
164 /// Add a node to the graph.
165 pub fn add_node(&mut self, node: ExecutionNode) {
166 let id = node.id;
167 let device = node.device.clone();
168 self.nodes.insert(id, node);
169 self.device_groups.entry(device).or_default().push(id);
170 }
171
172 /// Get a node by ID.
173 pub fn node(&self, id: u64) -> Option<&ExecutionNode> {
174 self.nodes.get(&id)
175 }
176
177 /// Get all nodes.
178 pub fn nodes(&self) -> impl Iterator<Item = &ExecutionNode> {
179 self.nodes.values()
180 }
181
182 /// Compute topological order and find parallelizable groups.
183 ///
184 /// Returns groups of nodes that can be executed in parallel.
185 /// Each group contains nodes with no dependencies on each other.
186 pub fn compute_parallel_groups(&mut self) -> Vec<Vec<u64>> {
187 // Build in-degree map. Predecessor lists are deduplicated via
188 // sort-and-dedup on a local SmallVec so a node that lists the same
189 // predecessor twice doesn't decrement in-degree more than once during
190 // Kahn's traversal. Successor sets remain HashSet to keep insertion
191 // O(1) when many nodes share predecessors.
192 let mut in_degree: HashMap<u64, usize> = HashMap::new();
193 let mut successors: HashMap<u64, HashSet<u64>> = HashMap::new();
194
195 for node in self.nodes.values() {
196 in_degree.entry(node.id).or_insert(0);
197 let mut preds: smallvec::SmallVec<[u64; 8]> = node.predecessors.iter().copied().collect();
198 preds.sort_unstable();
199 preds.dedup();
200 for &pred in &preds {
201 successors.entry(pred).or_default().insert(node.id);
202 *in_degree.entry(node.id).or_insert(0) += 1;
203 }
204 }
205
206 // Kahn's algorithm with level grouping
207 let mut groups = Vec::new();
208 let mut ready: Vec<u64> = in_degree.iter().filter(|&(_, deg)| *deg == 0).map(|(&id, _)| id).collect();
209
210 while !ready.is_empty() {
211 // All nodes in ready can be executed in parallel
212 groups.push(ready.clone());
213
214 // Record execution order
215 self.execution_order.extend(ready.iter().copied());
216
217 // Find next batch
218 let mut next_ready = Vec::new();
219 for id in ready {
220 if let Some(succs) = successors.get(&id) {
221 for &succ in succs {
222 let deg = in_degree.get_mut(&succ).unwrap();
223 *deg -= 1;
224 if *deg == 0 {
225 next_ready.push(succ);
226 }
227 }
228 }
229 }
230 ready = next_ready;
231 }
232
233 groups
234 }
235
236 /// Get nodes grouped by device.
237 pub fn device_groups(&self) -> &HashMap<DeviceSpec, Vec<u64>> {
238 &self.device_groups
239 }
240
241 /// Check if all nodes have been visited (no cycles).
242 pub fn is_valid(&self) -> bool {
243 self.execution_order.len() == self.nodes.len()
244 }
245}
246
247/// Unified executor for heterogeneous device execution.
248///
249/// Manages device contexts to enable parallel execution across any mix of devices.
250/// Uses timeline signals for cross-device synchronization.
251///
252/// # Stateless Execution Model (Tinygrad-Aligned)
253///
254/// The executor follows Tinygrad's stateless execution model where:
255/// - Dependencies are computed at schedule time, not runtime
256/// - ExecutionPlan pre-computes kernel order via topological sort
257/// - No runtime dependency tracking is needed (zero memory accumulation)
258/// - Timeline signals handle cross-device synchronization only
259pub struct UnifiedExecutor {
260 /// Per-device execution contexts.
261 contexts: HashMap<DeviceSpec, DeviceContext>,
262
263 /// Device registry for looking up allocators.
264 registry: &'static DeviceRegistry,
265}
266
267impl std::fmt::Debug for UnifiedExecutor {
268 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269 f.debug_struct("UnifiedExecutor").field("contexts", &self.contexts.keys().collect::<Vec<_>>()).finish()
270 }
271}
272
273impl UnifiedExecutor {
274 /// Create a new unified executor.
275 pub fn new(registry: &'static DeviceRegistry) -> Self {
276 Self { contexts: HashMap::new(), registry }
277 }
278
279 /// Add a device to the executor.
280 ///
281 /// Creates the device context with timeline signal and queues.
282 pub fn add_device(&mut self, device_spec: DeviceSpec) -> Result<()> {
283 if self.contexts.contains_key(&device_spec) {
284 return Ok(()); // Already added
285 }
286
287 // Create device handle
288 let device = crate::DEVICE_FACTORIES.device(&device_spec, self.registry)?;
289
290 // Create timeline signal based on device type
291 let signal: Arc<dyn TimelineSignal> = match &device_spec {
292 DeviceSpec::Cpu => Arc::new(CpuTimelineSignal::new()),
293 #[cfg(feature = "cuda")]
294 DeviceSpec::Cuda { .. } => {
295 // TODO: Create CUDA timeline signal with events
296 // For now, fall back to CPU signal (works, but less efficient)
297 Arc::new(CpuTimelineSignal::new())
298 }
299 _ => Arc::new(CpuTimelineSignal::new()),
300 };
301
302 let ctx = DeviceContext::new(device, signal);
303 self.contexts.insert(device_spec, ctx);
304
305 Ok(())
306 }
307
308 /// Get the device context for a device specification.
309 pub fn context(&self, device: &DeviceSpec) -> Option<&DeviceContext> {
310 self.contexts.get(device)
311 }
312
313 /// Get the device context mutably.
314 pub fn context_mut(&mut self, device: &DeviceSpec) -> Option<&mut DeviceContext> {
315 self.contexts.get_mut(device)
316 }
317
318 /// Determine the synchronization strategy between two devices.
319 pub fn sync_strategy(from: &DeviceSpec, to: &DeviceSpec) -> SyncStrategy {
320 if from == to {
321 SyncStrategy::None
322 } else if std::mem::discriminant(from) == std::mem::discriminant(to) {
323 // Same device type (e.g., both CUDA)
324 SyncStrategy::PeerToPeer
325 } else {
326 // Different device types
327 SyncStrategy::CpuMediated
328 }
329 }
330
331 /// Check if all operations on a single device.
332 ///
333 /// Returns `Some(device)` if all buffers are on the same device,
334 /// enabling the fast single-device path.
335 pub fn single_device_check(&self, buffers: &[&Buffer]) -> Option<DeviceSpec> {
336 if buffers.is_empty() {
337 return None;
338 }
339
340 let first_device = buffers[0].allocator().device_spec();
341
342 for buffer in buffers.iter().skip(1) {
343 if buffer.allocator().device_spec() != first_device {
344 return None;
345 }
346 }
347
348 Some(first_device)
349 }
350
351 /// Synchronize all devices.
352 ///
353 /// Waits for all pending operations to complete on all devices.
354 pub fn synchronize_all(&self) -> Result<()> {
355 for ctx in self.contexts.values() {
356 let current = ctx.current_timeline();
357 if current > 0 {
358 ctx.wait_for(current)?;
359 }
360 }
361 Ok(())
362 }
363
364 /// Execute a kernel (sequential execution).
365 ///
366 /// ExecutionPlan pre-computes kernel order at schedule time, so no runtime
367 /// dependency tracking is needed. This follows Tinygrad's stateless execution model.
368 ///
369 /// # Arguments
370 ///
371 /// * `device` - Device to execute on
372 /// * `execute_fn` - Function that performs the actual kernel execution
373 ///
374 /// # Returns
375 ///
376 /// The timeline value for this execution (can be used for cross-device sync).
377 pub fn execute_kernel<F>(&mut self, device: &DeviceSpec, execute_fn: F) -> Result<u64>
378 where
379 F: FnOnce() -> Result<()>,
380 {
381 // 1. Ensure device context exists
382 if !self.contexts.contains_key(device) {
383 self.add_device(device.clone())?;
384 }
385
386 // 2. Get next timeline value for this execution
387 let timeline = self.contexts.get(device).unwrap().next_timeline();
388
389 // 3. Execute the kernel
390 execute_fn()?;
391
392 // 4. Signal completion (for cross-device synchronization)
393 if let Some(ctx) = self.contexts.get(device) {
394 ctx.signal_completion(timeline);
395 }
396
397 Ok(timeline)
398 }
399
400 /// Execute a buffer transfer (COPY operation).
401 ///
402 /// Handles cross-device transfers with appropriate synchronization:
403 /// - Same device: Direct copy using device's copy queue
404 /// - Same vendor (e.g., CUDA:0 → CUDA:1): Peer-to-peer transfer
405 /// - Different vendors (e.g., CUDA → CPU): Stage through host memory
406 ///
407 /// # Arguments
408 ///
409 /// * `src` - Source buffer
410 /// * `dst` - Destination buffer (must be pre-allocated)
411 /// * `src_device` - Device the source buffer is on
412 /// * `dst_device` - Device the destination buffer is on
413 ///
414 /// # Returns
415 ///
416 /// The timeline value for this transfer operation.
417 pub fn execute_transfer(
418 &mut self,
419 src: &Buffer,
420 dst: &mut Buffer,
421 src_device: &DeviceSpec,
422 dst_device: &DeviceSpec,
423 ) -> Result<u64> {
424 // Ensure both device contexts exist
425 if !self.contexts.contains_key(src_device) {
426 self.add_device(src_device.clone())?;
427 }
428 if !self.contexts.contains_key(dst_device) {
429 self.add_device(dst_device.clone())?;
430 }
431
432 // Get timeline for destination device (where the result will be used)
433 let timeline = self.contexts.get(dst_device).unwrap().next_timeline();
434
435 // Perform the transfer based on sync strategy
436 match Self::sync_strategy(src_device, dst_device) {
437 SyncStrategy::None => {
438 // Same device - direct copy
439 dst.copy_from(src).context(crate::error::DeviceSnafu)?;
440 }
441 SyncStrategy::PeerToPeer => {
442 // Same vendor (e.g., both CUDA) - use peer-to-peer if available
443 // For now, fall back to copy_from which handles this
444 dst.copy_from(src).context(crate::error::DeviceSnafu)?;
445 }
446 SyncStrategy::CpuMediated => {
447 // Different vendors - stage through CPU
448 // First, wait for source device operations to complete
449 if let Some(src_ctx) = self.contexts.get(src_device) {
450 let src_timeline = src_ctx.current_timeline();
451 if src_timeline > 0 {
452 src_ctx.wait_for(src_timeline)?;
453 }
454 }
455
456 // copy_from handles the staging internally for cross-device copies
457 dst.copy_from(src).context(crate::error::DeviceSnafu)?;
458
459 // Wait for destination device operations if needed
460 if let Some(dst_ctx) = self.contexts.get(dst_device) {
461 let dst_timeline = dst_ctx.current_timeline();
462 if dst_timeline > 0 {
463 dst_ctx.wait_for(dst_timeline)?;
464 }
465 }
466 }
467 }
468
469 // Signal completion (for cross-device synchronization)
470 if let Some(ctx) = self.contexts.get(dst_device) {
471 ctx.signal_completion(timeline);
472 }
473
474 Ok(timeline)
475 }
476}
477
478/// Global executor instance.
479///
480/// For most use cases, a single global executor is sufficient.
481/// Thread-safety is handled by timeline signals and dependency tracking.
482static EXECUTOR: once_cell::sync::Lazy<parking_lot::Mutex<UnifiedExecutor>> =
483 once_cell::sync::Lazy::new(|| parking_lot::Mutex::new(UnifiedExecutor::new(svod_device::registry::registry())));
484
485/// Get access to the global executor.
486pub fn global_executor() -> parking_lot::MutexGuard<'static, UnifiedExecutor> {
487 EXECUTOR.lock()
488}
489
// Unit tests for this module. They are compiled only under `cfg(test)` and are
// loaded from the external file named in `#[path]` rather than declared inline.
#[cfg(test)]
#[path = "test/unit/executor.rs"]
mod tests;