Skip to main content

svod_runtime/
executor.rs

//! Unified parallel execution for heterogeneous devices.
//!
//! The `UnifiedExecutor` handles kernel execution across any mix of devices
//! (CPU, CUDA, Metal, etc.) with proper synchronization and dependency tracking.
//!
//! # Design Principles
//!
//! 1. **Single abstraction** - One executor handles any device mix
//! 2. **Device-agnostic sync** - Timeline signals abstract over device-specific primitives
//! 3. **Zero overhead for single-device** - Fast path skips synchronization when possible
//! 4. **Buffer dependency tracking** - Following Tinygrad's `_access_resources()` pattern
//!
//! # Example
//!
//! ```ignore
//! let mut executor = UnifiedExecutor::new(svod_device::registry::registry());
//! executor.add_device(DeviceSpec::Cpu)?;
//!
//! // Execute schedule - handles dependencies automatically
//! let output_id = executor.execute(&schedule)?;
//! ```
//!
//! # Execution Graph
//!
//! For complex schedules with multiple devices, the executor builds an execution
//! graph (DAG) where nodes are kernel operations and edges are buffer dependencies.
//! Independent kernels on the same device can be batched, and kernels on different
//! devices can run in parallel (with appropriate synchronization).
29
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32use std::sync::atomic::{AtomicU64, Ordering};
33
34use snafu::ResultExt;
35use svod_device::device::Device;
36use svod_device::registry::DeviceRegistry;
37use svod_device::{Allocator, Buffer, BufferId, CpuTimelineSignal, TimelineSignal};
38use svod_dtype::DeviceSpec;
39
40use crate::error::Result;
41
42/// Per-device execution context.
43///
44/// Each device has its own timeline signal, queue, and allocator.
45/// This enables parallel execution across devices with proper synchronization.
46pub struct DeviceContext {
47    /// Device specification (CPU, CUDA:0, etc.).
48    pub device: DeviceSpec,
49    /// Device abstraction for rendering/compiling/executing.
50    pub device_handle: Arc<Device>,
51    /// Timeline signal for this device's operations.
52    pub signal: Arc<dyn TimelineSignal>,
53    /// Current timeline value (monotonically increasing).
54    pub timeline: AtomicU64,
55    /// Allocator for this device.
56    pub allocator: Arc<dyn Allocator>,
57}
58
59impl std::fmt::Debug for DeviceContext {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("DeviceContext")
62            .field("device", &self.device)
63            .field("timeline", &self.timeline.load(Ordering::Relaxed))
64            .finish()
65    }
66}
67
68impl DeviceContext {
69    /// Create a new device context.
70    pub fn new(device: Arc<Device>, signal: Arc<dyn TimelineSignal>) -> Self {
71        let allocator = device.allocator.clone();
72        let device_spec = device.device.clone();
73        Self { device: device_spec, device_handle: device, signal, timeline: AtomicU64::new(0), allocator }
74    }
75
76    /// Get the next timeline value and increment.
77    pub fn next_timeline(&self) -> u64 {
78        self.timeline.fetch_add(1, Ordering::Relaxed) + 1
79    }
80
81    /// Get the current timeline value.
82    pub fn current_timeline(&self) -> u64 {
83        self.timeline.load(Ordering::Relaxed)
84    }
85
86    /// Signal that operations up to the given timeline value are complete.
87    pub fn signal_completion(&self, value: u64) {
88        self.signal.set(value);
89    }
90
91    /// Wait for operations up to the given timeline value to complete.
92    pub fn wait_for(&self, value: u64) -> Result<()> {
93        self.signal.wait(value, 0).context(crate::error::DeviceSnafu)?;
94        Ok(())
95    }
96}
97
/// How two devices have to be synchronized when data moves between them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncStrategy {
    /// Source and destination are the same device; its operations are already
    /// ordered, so nothing extra is required.
    None,
    /// Two instances of the same device kind (e.g. CUDA:0 → CUDA:1); peer-to-peer
    /// events can be used where the hardware supports them.
    PeerToPeer,
    /// Devices of different kinds (e.g. CUDA → CPU); synchronization goes
    /// through CPU-side polling.
    CpuMediated,
}
110
111/// A node in the execution graph representing a kernel or transfer operation.
112#[derive(Debug, Clone)]
113pub struct ExecutionNode {
114    /// Unique identifier for this node (typically the kernel AST ID).
115    pub id: u64,
116    /// Device this operation executes on.
117    pub device: DeviceSpec,
118    /// Buffer IDs read by this operation.
119    pub inputs: Vec<BufferId>,
120    /// Buffer IDs written by this operation.
121    pub outputs: Vec<BufferId>,
122    /// IDs of nodes that must complete before this one (dependencies).
123    pub predecessors: Vec<u64>,
124    /// Whether this is a data transfer (COPY) or a computational kernel.
125    pub is_transfer: bool,
126    /// Buffer access information for parallel execution.
127    /// Contains the full buffer list and output indices for dependency tracking.
128    pub buffer_access: Option<KernelBufferAccess>,
129}
130
131/// Buffer access information for parallel kernel execution.
132///
133/// This struct captures which buffers a kernel accesses and which are outputs,
134/// enabling precise dependency tracking in `execute_parallel_group`.
135#[derive(Debug, Clone)]
136pub struct KernelBufferAccess {
137    /// All buffer IDs accessed by this kernel (inputs and outputs).
138    pub buffers: Vec<BufferId>,
139    /// Indices into `buffers` that are outputs (written by the kernel).
140    /// Other indices are inputs (read-only).
141    pub output_indices: Vec<usize>,
142}
143
144/// Execution graph representing a DAG of kernel operations.
145///
146/// The graph is built from a schedule and captures buffer dependencies
147/// between kernels. Independent kernels can be executed in parallel.
148#[derive(Debug, Default)]
149pub struct ExecutionGraph {
150    /// Nodes in the graph, indexed by ID.
151    nodes: HashMap<u64, ExecutionNode>,
152    /// Execution order (topologically sorted node IDs).
153    execution_order: Vec<u64>,
154    /// Nodes grouped by device for batched execution.
155    device_groups: HashMap<DeviceSpec, Vec<u64>>,
156}
157
158impl ExecutionGraph {
159    /// Create a new empty execution graph.
160    pub fn new() -> Self {
161        Self::default()
162    }
163
164    /// Add a node to the graph.
165    pub fn add_node(&mut self, node: ExecutionNode) {
166        let id = node.id;
167        let device = node.device.clone();
168        self.nodes.insert(id, node);
169        self.device_groups.entry(device).or_default().push(id);
170    }
171
172    /// Get a node by ID.
173    pub fn node(&self, id: u64) -> Option<&ExecutionNode> {
174        self.nodes.get(&id)
175    }
176
177    /// Get all nodes.
178    pub fn nodes(&self) -> impl Iterator<Item = &ExecutionNode> {
179        self.nodes.values()
180    }
181
182    /// Compute topological order and find parallelizable groups.
183    ///
184    /// Returns groups of nodes that can be executed in parallel.
185    /// Each group contains nodes with no dependencies on each other.
186    pub fn compute_parallel_groups(&mut self) -> Vec<Vec<u64>> {
187        // Build in-degree map. Predecessor lists are deduplicated via
188        // sort-and-dedup on a local SmallVec so a node that lists the same
189        // predecessor twice doesn't decrement in-degree more than once during
190        // Kahn's traversal. Successor sets remain HashSet to keep insertion
191        // O(1) when many nodes share predecessors.
192        let mut in_degree: HashMap<u64, usize> = HashMap::new();
193        let mut successors: HashMap<u64, HashSet<u64>> = HashMap::new();
194
195        for node in self.nodes.values() {
196            in_degree.entry(node.id).or_insert(0);
197            let mut preds: smallvec::SmallVec<[u64; 8]> = node.predecessors.iter().copied().collect();
198            preds.sort_unstable();
199            preds.dedup();
200            for &pred in &preds {
201                successors.entry(pred).or_default().insert(node.id);
202                *in_degree.entry(node.id).or_insert(0) += 1;
203            }
204        }
205
206        // Kahn's algorithm with level grouping
207        let mut groups = Vec::new();
208        let mut ready: Vec<u64> = in_degree.iter().filter(|&(_, deg)| *deg == 0).map(|(&id, _)| id).collect();
209
210        while !ready.is_empty() {
211            // All nodes in ready can be executed in parallel
212            groups.push(ready.clone());
213
214            // Record execution order
215            self.execution_order.extend(ready.iter().copied());
216
217            // Find next batch
218            let mut next_ready = Vec::new();
219            for id in ready {
220                if let Some(succs) = successors.get(&id) {
221                    for &succ in succs {
222                        let deg = in_degree.get_mut(&succ).unwrap();
223                        *deg -= 1;
224                        if *deg == 0 {
225                            next_ready.push(succ);
226                        }
227                    }
228                }
229            }
230            ready = next_ready;
231        }
232
233        groups
234    }
235
236    /// Get nodes grouped by device.
237    pub fn device_groups(&self) -> &HashMap<DeviceSpec, Vec<u64>> {
238        &self.device_groups
239    }
240
241    /// Check if all nodes have been visited (no cycles).
242    pub fn is_valid(&self) -> bool {
243        self.execution_order.len() == self.nodes.len()
244    }
245}
246
247/// Unified executor for heterogeneous device execution.
248///
249/// Manages device contexts to enable parallel execution across any mix of devices.
250/// Uses timeline signals for cross-device synchronization.
251///
252/// # Stateless Execution Model (Tinygrad-Aligned)
253///
254/// The executor follows Tinygrad's stateless execution model where:
255/// - Dependencies are computed at schedule time, not runtime
256/// - ExecutionPlan pre-computes kernel order via topological sort
257/// - No runtime dependency tracking is needed (zero memory accumulation)
258/// - Timeline signals handle cross-device synchronization only
259pub struct UnifiedExecutor {
260    /// Per-device execution contexts.
261    contexts: HashMap<DeviceSpec, DeviceContext>,
262
263    /// Device registry for looking up allocators.
264    registry: &'static DeviceRegistry,
265}
266
267impl std::fmt::Debug for UnifiedExecutor {
268    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269        f.debug_struct("UnifiedExecutor").field("contexts", &self.contexts.keys().collect::<Vec<_>>()).finish()
270    }
271}
272
273impl UnifiedExecutor {
274    /// Create a new unified executor.
275    pub fn new(registry: &'static DeviceRegistry) -> Self {
276        Self { contexts: HashMap::new(), registry }
277    }
278
279    /// Add a device to the executor.
280    ///
281    /// Creates the device context with timeline signal and queues.
282    pub fn add_device(&mut self, device_spec: DeviceSpec) -> Result<()> {
283        if self.contexts.contains_key(&device_spec) {
284            return Ok(()); // Already added
285        }
286
287        // Create device handle
288        let device = crate::DEVICE_FACTORIES.device(&device_spec, self.registry)?;
289
290        // Create timeline signal based on device type
291        let signal: Arc<dyn TimelineSignal> = match &device_spec {
292            DeviceSpec::Cpu => Arc::new(CpuTimelineSignal::new()),
293            #[cfg(feature = "cuda")]
294            DeviceSpec::Cuda { .. } => {
295                // TODO: Create CUDA timeline signal with events
296                // For now, fall back to CPU signal (works, but less efficient)
297                Arc::new(CpuTimelineSignal::new())
298            }
299            _ => Arc::new(CpuTimelineSignal::new()),
300        };
301
302        let ctx = DeviceContext::new(device, signal);
303        self.contexts.insert(device_spec, ctx);
304
305        Ok(())
306    }
307
308    /// Get the device context for a device specification.
309    pub fn context(&self, device: &DeviceSpec) -> Option<&DeviceContext> {
310        self.contexts.get(device)
311    }
312
313    /// Get the device context mutably.
314    pub fn context_mut(&mut self, device: &DeviceSpec) -> Option<&mut DeviceContext> {
315        self.contexts.get_mut(device)
316    }
317
318    /// Determine the synchronization strategy between two devices.
319    pub fn sync_strategy(from: &DeviceSpec, to: &DeviceSpec) -> SyncStrategy {
320        if from == to {
321            SyncStrategy::None
322        } else if std::mem::discriminant(from) == std::mem::discriminant(to) {
323            // Same device type (e.g., both CUDA)
324            SyncStrategy::PeerToPeer
325        } else {
326            // Different device types
327            SyncStrategy::CpuMediated
328        }
329    }
330
331    /// Check if all operations on a single device.
332    ///
333    /// Returns `Some(device)` if all buffers are on the same device,
334    /// enabling the fast single-device path.
335    pub fn single_device_check(&self, buffers: &[&Buffer]) -> Option<DeviceSpec> {
336        if buffers.is_empty() {
337            return None;
338        }
339
340        let first_device = buffers[0].allocator().device_spec();
341
342        for buffer in buffers.iter().skip(1) {
343            if buffer.allocator().device_spec() != first_device {
344                return None;
345            }
346        }
347
348        Some(first_device)
349    }
350
351    /// Synchronize all devices.
352    ///
353    /// Waits for all pending operations to complete on all devices.
354    pub fn synchronize_all(&self) -> Result<()> {
355        for ctx in self.contexts.values() {
356            let current = ctx.current_timeline();
357            if current > 0 {
358                ctx.wait_for(current)?;
359            }
360        }
361        Ok(())
362    }
363
364    /// Execute a kernel (sequential execution).
365    ///
366    /// ExecutionPlan pre-computes kernel order at schedule time, so no runtime
367    /// dependency tracking is needed. This follows Tinygrad's stateless execution model.
368    ///
369    /// # Arguments
370    ///
371    /// * `device` - Device to execute on
372    /// * `execute_fn` - Function that performs the actual kernel execution
373    ///
374    /// # Returns
375    ///
376    /// The timeline value for this execution (can be used for cross-device sync).
377    pub fn execute_kernel<F>(&mut self, device: &DeviceSpec, execute_fn: F) -> Result<u64>
378    where
379        F: FnOnce() -> Result<()>,
380    {
381        // 1. Ensure device context exists
382        if !self.contexts.contains_key(device) {
383            self.add_device(device.clone())?;
384        }
385
386        // 2. Get next timeline value for this execution
387        let timeline = self.contexts.get(device).unwrap().next_timeline();
388
389        // 3. Execute the kernel
390        execute_fn()?;
391
392        // 4. Signal completion (for cross-device synchronization)
393        if let Some(ctx) = self.contexts.get(device) {
394            ctx.signal_completion(timeline);
395        }
396
397        Ok(timeline)
398    }
399
400    /// Execute a buffer transfer (COPY operation).
401    ///
402    /// Handles cross-device transfers with appropriate synchronization:
403    /// - Same device: Direct copy using device's copy queue
404    /// - Same vendor (e.g., CUDA:0 → CUDA:1): Peer-to-peer transfer
405    /// - Different vendors (e.g., CUDA → CPU): Stage through host memory
406    ///
407    /// # Arguments
408    ///
409    /// * `src` - Source buffer
410    /// * `dst` - Destination buffer (must be pre-allocated)
411    /// * `src_device` - Device the source buffer is on
412    /// * `dst_device` - Device the destination buffer is on
413    ///
414    /// # Returns
415    ///
416    /// The timeline value for this transfer operation.
417    pub fn execute_transfer(
418        &mut self,
419        src: &Buffer,
420        dst: &mut Buffer,
421        src_device: &DeviceSpec,
422        dst_device: &DeviceSpec,
423    ) -> Result<u64> {
424        // Ensure both device contexts exist
425        if !self.contexts.contains_key(src_device) {
426            self.add_device(src_device.clone())?;
427        }
428        if !self.contexts.contains_key(dst_device) {
429            self.add_device(dst_device.clone())?;
430        }
431
432        // Get timeline for destination device (where the result will be used)
433        let timeline = self.contexts.get(dst_device).unwrap().next_timeline();
434
435        // Perform the transfer based on sync strategy
436        match Self::sync_strategy(src_device, dst_device) {
437            SyncStrategy::None => {
438                // Same device - direct copy
439                dst.copy_from(src).context(crate::error::DeviceSnafu)?;
440            }
441            SyncStrategy::PeerToPeer => {
442                // Same vendor (e.g., both CUDA) - use peer-to-peer if available
443                // For now, fall back to copy_from which handles this
444                dst.copy_from(src).context(crate::error::DeviceSnafu)?;
445            }
446            SyncStrategy::CpuMediated => {
447                // Different vendors - stage through CPU
448                // First, wait for source device operations to complete
449                if let Some(src_ctx) = self.contexts.get(src_device) {
450                    let src_timeline = src_ctx.current_timeline();
451                    if src_timeline > 0 {
452                        src_ctx.wait_for(src_timeline)?;
453                    }
454                }
455
456                // copy_from handles the staging internally for cross-device copies
457                dst.copy_from(src).context(crate::error::DeviceSnafu)?;
458
459                // Wait for destination device operations if needed
460                if let Some(dst_ctx) = self.contexts.get(dst_device) {
461                    let dst_timeline = dst_ctx.current_timeline();
462                    if dst_timeline > 0 {
463                        dst_ctx.wait_for(dst_timeline)?;
464                    }
465                }
466            }
467        }
468
469        // Signal completion (for cross-device synchronization)
470        if let Some(ctx) = self.contexts.get(dst_device) {
471            ctx.signal_completion(timeline);
472        }
473
474        Ok(timeline)
475    }
476}
477
478/// Global executor instance.
479///
480/// For most use cases, a single global executor is sufficient.
481/// Thread-safety is handled by timeline signals and dependency tracking.
482static EXECUTOR: once_cell::sync::Lazy<parking_lot::Mutex<UnifiedExecutor>> =
483    once_cell::sync::Lazy::new(|| parking_lot::Mutex::new(UnifiedExecutor::new(svod_device::registry::registry())));
484
485/// Get access to the global executor.
486pub fn global_executor() -> parking_lot::MutexGuard<'static, UnifiedExecutor> {
487    EXECUTOR.lock()
488}
489
// Unit tests live out of line, in the file named by the `#[path]` attribute.
#[cfg(test)]
#[path = "test/unit/executor.rs"]
mod tests;