Skip to main content

morok_device/
queue.rs

1//! Hardware command queue abstraction for parallel execution.
2//!
3//! This module provides a device-agnostic interface for command queues,
4//! abstracting over CUDA streams, Metal command buffers, CPU task queues, etc.
5//!
6//! # Design
7//!
8//! The `HardwareQueue` trait uses a builder pattern for chaining operations:
9//!
//! ```ignore
//! queue
//!     .wait(&signal, 1)                // Wait for dependency
//!     .exec(&kernel, &bufs, &params)   // Execute kernel
//!     .signal(&signal, 2)              // Signal completion
//!     .submit()?;                      // Submit to hardware
//! ```
17//!
18//! # Queue Types
19//!
//! Most devices have two queue types:
//! - **Compute queue**: For kernel execution
//! - **Copy queue**: For DMA transfers (optional; on devices without a
//!   separate copy queue, `QueueFactory::create_copy_queue` returns `None`)
23
24use std::sync::Arc;
25
26use morok_dtype::DeviceSpec;
27
28use crate::buffer::Buffer;
29use crate::error::Result;
30use crate::sync::TimelineSignal;
31
/// Kernel launch geometry: how many work items to run and how they are grouped.
///
/// All three dimensions are always present; unused dimensions are set to `1`
/// (as done by the `new_1d`/`new_2d` constructors).
///
/// Equality and hashing compare both extents, so identical launch
/// configurations can be deduplicated or used as lookup keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ExecParams {
    /// Global work size (total number of work items per dimension).
    pub global_size: [usize; 3],
    /// Local work size (work group size per dimension).
    pub local_size: [usize; 3],
}
40
41impl ExecParams {
42    /// Create 1D execution parameters.
43    pub fn new_1d(global: usize, local: usize) -> Self {
44        Self { global_size: [global, 1, 1], local_size: [local, 1, 1] }
45    }
46
47    /// Create 2D execution parameters.
48    pub fn new_2d(global: [usize; 2], local: [usize; 2]) -> Self {
49        Self { global_size: [global[0], global[1], 1], local_size: [local[0], local[1], 1] }
50    }
51
52    /// Create 3D execution parameters.
53    pub fn new_3d(global: [usize; 3], local: [usize; 3]) -> Self {
54        Self { global_size: global, local_size: local }
55    }
56}
57
58impl Default for ExecParams {
59    fn default() -> Self {
60        Self { global_size: [1, 1, 1], local_size: [1, 1, 1] }
61    }
62}
63
64/// Compiled program that can be executed on a queue.
65///
66/// This is a thin wrapper around device-specific program handles
67/// (JIT function pointers, CUDA modules, etc.).
68pub trait Program: Send + Sync + std::fmt::Debug {
69    /// Get the device this program is compiled for.
70    fn device(&self) -> &DeviceSpec;
71
72    /// Get the program name (for debugging).
73    fn name(&self) -> &str;
74}
75
76/// Hardware command queue for submitting operations to a device.
77///
78/// Queues batch operations and submit them to hardware atomically.
79/// All operations are non-blocking until `submit()` is called.
80///
81/// # Thread Safety
82///
83/// Queues are `Send` but not necessarily `Sync`. Each queue should be
84/// owned by a single thread/task at a time.
85pub trait HardwareQueue: Send + std::fmt::Debug {
86    /// The timeline signal type used by this queue.
87    type Signal: TimelineSignal;
88
89    /// Wait for a signal to reach a value before executing subsequent operations.
90    ///
91    /// This creates a dependency: operations after this call won't start
92    /// until the signal reaches `value`.
93    fn wait(&mut self, signal: &Self::Signal, value: u64) -> &mut Self;
94
95    /// Signal a value after all previous operations complete.
96    ///
97    /// Operations submitted after this call may start before the signal is set.
98    fn signal(&mut self, signal: &Self::Signal, value: u64) -> &mut Self;
99
100    /// Execute a compiled program with the given buffers and parameters.
101    ///
102    /// # Arguments
103    ///
104    /// * `program` - The compiled program to execute
105    /// * `buffers` - Buffer arguments (raw pointers extracted internally)
106    /// * `params` - Execution parameters (grid size, etc.)
107    ///
108    /// # Safety
109    ///
110    /// Caller must ensure:
111    /// - All buffers are allocated
112    /// - No conflicting buffer accesses (handled by executor)
113    fn exec(&mut self, program: &dyn Program, buffers: &[&Buffer], params: &ExecParams) -> &mut Self;
114
115    /// Copy data between buffers.
116    ///
117    /// Both buffers must be accessible from this queue's device.
118    /// For cross-device copies, use the executor's transfer mechanism.
119    fn copy(&mut self, dst: &Buffer, src: &Buffer) -> &mut Self;
120
121    /// Insert a memory barrier.
122    ///
123    /// Ensures all previous memory operations are visible to subsequent operations.
124    /// Mostly needed for CPU and some GPU memory models.
125    fn memory_barrier(&mut self) -> &mut Self;
126
127    /// Submit all batched operations to the hardware.
128    ///
129    /// This is the only blocking point - it submits work but doesn't wait
130    /// for completion. Use signals to synchronize.
131    fn submit(&mut self) -> Result<()>;
132
133    /// Get the device this queue belongs to.
134    fn device(&self) -> &DeviceSpec;
135}
136
137/// Factory for creating hardware queues.
138///
139/// Each device implementation provides a factory that creates queues
140/// for that device type.
141pub trait QueueFactory: Send + Sync + std::fmt::Debug {
142    /// The queue type produced by this factory.
143    type Queue: HardwareQueue;
144
145    /// The signal type used by queues from this factory.
146    type Signal: TimelineSignal;
147
148    /// Create a new compute queue.
149    fn create_compute_queue(&self) -> Result<Self::Queue>;
150
151    /// Create a new copy/DMA queue if supported.
152    ///
153    /// Returns `None` if the device doesn't support separate copy queues.
154    fn create_copy_queue(&self) -> Result<Option<Self::Queue>>;
155
156    /// Create a new timeline signal.
157    fn create_signal(&self) -> Result<Arc<Self::Signal>>;
158
159    /// Get the device specification.
160    fn device(&self) -> &DeviceSpec;
161}
162
163/// Type-erased queue for use in the unified executor.
164///
165/// This wraps a concrete `HardwareQueue` implementation and provides
166/// a common interface that doesn't require knowing the signal type.
167pub struct DynQueue {
168    inner: Box<dyn DynQueueInner>,
169}
170
171impl std::fmt::Debug for DynQueue {
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        f.debug_struct("DynQueue").field("device", &self.inner.device()).finish()
174    }
175}
176
177impl DynQueue {
178    /// Create a new type-erased queue from a concrete implementation.
179    pub fn new<Q: HardwareQueue + 'static>(queue: Q) -> Self
180    where
181        Q::Signal: 'static,
182    {
183        Self { inner: Box::new(DynQueueWrapper { queue, _phantom: std::marker::PhantomData }) }
184    }
185
186    /// Wait for a type-erased signal.
187    pub fn wait(&mut self, signal: &dyn TimelineSignal, value: u64) -> &mut Self {
188        self.inner.wait_dyn(signal, value);
189        self
190    }
191
192    /// Signal completion.
193    pub fn signal(&mut self, signal: &dyn TimelineSignal, value: u64) -> &mut Self {
194        self.inner.signal_dyn(signal, value);
195        self
196    }
197
198    /// Execute a program.
199    pub fn exec(&mut self, program: &dyn Program, buffers: &[&Buffer], params: &ExecParams) -> &mut Self {
200        self.inner.exec_dyn(program, buffers, params);
201        self
202    }
203
204    /// Copy between buffers.
205    pub fn copy(&mut self, dst: &Buffer, src: &Buffer) -> &mut Self {
206        self.inner.copy_dyn(dst, src);
207        self
208    }
209
210    /// Insert memory barrier.
211    pub fn memory_barrier(&mut self) -> &mut Self {
212        self.inner.memory_barrier_dyn();
213        self
214    }
215
216    /// Submit to hardware.
217    pub fn submit(&mut self) -> Result<()> {
218        self.inner.submit_dyn()
219    }
220
221    /// Get the device.
222    pub fn device(&self) -> &DeviceSpec {
223        self.inner.device()
224    }
225}
226
227/// Internal trait for type erasure.
228trait DynQueueInner: Send + std::fmt::Debug {
229    fn wait_dyn(&mut self, signal: &dyn TimelineSignal, value: u64);
230    fn signal_dyn(&mut self, signal: &dyn TimelineSignal, value: u64);
231    fn exec_dyn(&mut self, program: &dyn Program, buffers: &[&Buffer], params: &ExecParams);
232    fn copy_dyn(&mut self, dst: &Buffer, src: &Buffer);
233    fn memory_barrier_dyn(&mut self);
234    fn submit_dyn(&mut self) -> Result<()>;
235    fn device(&self) -> &DeviceSpec;
236}
237
238/// Wrapper for concrete queue types.
239struct DynQueueWrapper<Q: HardwareQueue> {
240    queue: Q,
241    _phantom: std::marker::PhantomData<Q::Signal>,
242}
243
244impl<Q: HardwareQueue> std::fmt::Debug for DynQueueWrapper<Q> {
245    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246        f.debug_struct("DynQueueWrapper").field("queue", &self.queue).finish()
247    }
248}
249
250impl<Q: HardwareQueue + 'static> DynQueueInner for DynQueueWrapper<Q>
251where
252    Q::Signal: 'static,
253{
254    fn wait_dyn(&mut self, _signal: &dyn TimelineSignal, _value: u64) {
255        // Note: In a full implementation, we'd need to downcast the signal
256        // to the concrete type. For now, this is a placeholder that demonstrates
257        // the interface. The real implementation will use concrete types
258        // in the executor where the signal type is known.
259        //
260        // The type-erased DynQueue is mainly for heterogeneous collections;
261        // most code paths will use concrete types directly.
262    }
263
264    fn signal_dyn(&mut self, _signal: &dyn TimelineSignal, _value: u64) {
265        // See wait_dyn comment
266    }
267
268    fn exec_dyn(&mut self, program: &dyn Program, buffers: &[&Buffer], params: &ExecParams) {
269        self.queue.exec(program, buffers, params);
270    }
271
272    fn copy_dyn(&mut self, dst: &Buffer, src: &Buffer) {
273        self.queue.copy(dst, src);
274    }
275
276    fn memory_barrier_dyn(&mut self) {
277        self.queue.memory_barrier();
278    }
279
280    fn submit_dyn(&mut self) -> Result<()> {
281        self.queue.submit()
282    }
283
284    fn device(&self) -> &DeviceSpec {
285        self.queue.device()
286    }
287}