morok_device/queue.rs
//! Hardware command queue abstraction for parallel execution.
//!
//! This module provides a device-agnostic interface for command queues,
//! abstracting over CUDA streams, Metal command buffers, CPU task queues, etc.
//!
//! # Design
//!
//! The `HardwareQueue` trait uses a builder pattern for chaining operations:
//!
//! ```ignore
//! queue
//!     .wait(&signal, 1)               // Wait for dependency
//!     .exec(&kernel, &bufs, &params)  // Execute kernel
//!     .signal(&signal, 2)             // Signal completion
//!     .submit()?;                     // Submit to hardware
//! ```
//!
//! # Queue Types
//!
//! Most devices have two queue types:
//! - **Compute queue**: For kernel execution
//! - **Copy queue**: For DMA transfers (optional; some devices have no separate copy queue)
23
24use std::sync::Arc;
25
26use morok_dtype::DeviceSpec;
27
28use crate::buffer::Buffer;
29use crate::error::Result;
30use crate::sync::TimelineSignal;
31
/// Kernel execution parameters.
///
/// Describes the launch geometry of a kernel in up to three dimensions.
/// The lower-dimensional constructors fill every unused dimension with `1`.
///
/// The type is plain data (two `[usize; 3]` arrays), so it is `Copy` and can be
/// compared and hashed, which makes it usable as a cache key or in assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ExecParams {
    /// Global work size (total number of work items per dimension).
    pub global_size: [usize; 3],
    /// Local work size (work group size per dimension).
    pub local_size: [usize; 3],
}

impl ExecParams {
    /// Create 1D execution parameters.
    ///
    /// The second and third dimensions of both sizes are set to `1`.
    #[must_use]
    pub const fn new_1d(global: usize, local: usize) -> Self {
        Self::new_3d([global, 1, 1], [local, 1, 1])
    }

    /// Create 2D execution parameters.
    ///
    /// The third dimension of both sizes is set to `1`.
    #[must_use]
    pub const fn new_2d(global: [usize; 2], local: [usize; 2]) -> Self {
        Self::new_3d([global[0], global[1], 1], [local[0], local[1], 1])
    }

    /// Create 3D execution parameters.
    ///
    /// The arrays are stored as given; no validation is performed.
    #[must_use]
    pub const fn new_3d(global: [usize; 3], local: [usize; 3]) -> Self {
        Self { global_size: global, local_size: local }
    }
}

impl Default for ExecParams {
    /// A single work item in a single work group: every dimension is `1`.
    fn default() -> Self {
        Self::new_1d(1, 1)
    }
}
63
64/// Compiled program that can be executed on a queue.
65///
66/// This is a thin wrapper around device-specific program handles
67/// (JIT function pointers, CUDA modules, etc.).
68pub trait Program: Send + Sync + std::fmt::Debug {
69 /// Get the device this program is compiled for.
70 fn device(&self) -> &DeviceSpec;
71
72 /// Get the program name (for debugging).
73 fn name(&self) -> &str;
74}
75
76/// Hardware command queue for submitting operations to a device.
77///
78/// Queues batch operations and submit them to hardware atomically.
79/// All operations are non-blocking until `submit()` is called.
80///
81/// # Thread Safety
82///
83/// Queues are `Send` but not necessarily `Sync`. Each queue should be
84/// owned by a single thread/task at a time.
85pub trait HardwareQueue: Send + std::fmt::Debug {
86 /// The timeline signal type used by this queue.
87 type Signal: TimelineSignal;
88
89 /// Wait for a signal to reach a value before executing subsequent operations.
90 ///
91 /// This creates a dependency: operations after this call won't start
92 /// until the signal reaches `value`.
93 fn wait(&mut self, signal: &Self::Signal, value: u64) -> &mut Self;
94
95 /// Signal a value after all previous operations complete.
96 ///
97 /// Operations submitted after this call may start before the signal is set.
98 fn signal(&mut self, signal: &Self::Signal, value: u64) -> &mut Self;
99
100 /// Execute a compiled program with the given buffers and parameters.
101 ///
102 /// # Arguments
103 ///
104 /// * `program` - The compiled program to execute
105 /// * `buffers` - Buffer arguments (raw pointers extracted internally)
106 /// * `params` - Execution parameters (grid size, etc.)
107 ///
108 /// # Safety
109 ///
110 /// Caller must ensure:
111 /// - All buffers are allocated
112 /// - No conflicting buffer accesses (handled by executor)
113 fn exec(&mut self, program: &dyn Program, buffers: &[&Buffer], params: &ExecParams) -> &mut Self;
114
115 /// Copy data between buffers.
116 ///
117 /// Both buffers must be accessible from this queue's device.
118 /// For cross-device copies, use the executor's transfer mechanism.
119 fn copy(&mut self, dst: &Buffer, src: &Buffer) -> &mut Self;
120
121 /// Insert a memory barrier.
122 ///
123 /// Ensures all previous memory operations are visible to subsequent operations.
124 /// Mostly needed for CPU and some GPU memory models.
125 fn memory_barrier(&mut self) -> &mut Self;
126
127 /// Submit all batched operations to the hardware.
128 ///
129 /// This is the only blocking point - it submits work but doesn't wait
130 /// for completion. Use signals to synchronize.
131 fn submit(&mut self) -> Result<()>;
132
133 /// Get the device this queue belongs to.
134 fn device(&self) -> &DeviceSpec;
135}
136
137/// Factory for creating hardware queues.
138///
139/// Each device implementation provides a factory that creates queues
140/// for that device type.
141pub trait QueueFactory: Send + Sync + std::fmt::Debug {
142 /// The queue type produced by this factory.
143 type Queue: HardwareQueue;
144
145 /// The signal type used by queues from this factory.
146 type Signal: TimelineSignal;
147
148 /// Create a new compute queue.
149 fn create_compute_queue(&self) -> Result<Self::Queue>;
150
151 /// Create a new copy/DMA queue if supported.
152 ///
153 /// Returns `None` if the device doesn't support separate copy queues.
154 fn create_copy_queue(&self) -> Result<Option<Self::Queue>>;
155
156 /// Create a new timeline signal.
157 fn create_signal(&self) -> Result<Arc<Self::Signal>>;
158
159 /// Get the device specification.
160 fn device(&self) -> &DeviceSpec;
161}
162
163/// Type-erased queue for use in the unified executor.
164///
165/// This wraps a concrete `HardwareQueue` implementation and provides
166/// a common interface that doesn't require knowing the signal type.
167pub struct DynQueue {
168 inner: Box<dyn DynQueueInner>,
169}
170
171impl std::fmt::Debug for DynQueue {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 f.debug_struct("DynQueue").field("device", &self.inner.device()).finish()
174 }
175}
176
177impl DynQueue {
178 /// Create a new type-erased queue from a concrete implementation.
179 pub fn new<Q: HardwareQueue + 'static>(queue: Q) -> Self
180 where
181 Q::Signal: 'static,
182 {
183 Self { inner: Box::new(DynQueueWrapper { queue, _phantom: std::marker::PhantomData }) }
184 }
185
186 /// Wait for a type-erased signal.
187 pub fn wait(&mut self, signal: &dyn TimelineSignal, value: u64) -> &mut Self {
188 self.inner.wait_dyn(signal, value);
189 self
190 }
191
192 /// Signal completion.
193 pub fn signal(&mut self, signal: &dyn TimelineSignal, value: u64) -> &mut Self {
194 self.inner.signal_dyn(signal, value);
195 self
196 }
197
198 /// Execute a program.
199 pub fn exec(&mut self, program: &dyn Program, buffers: &[&Buffer], params: &ExecParams) -> &mut Self {
200 self.inner.exec_dyn(program, buffers, params);
201 self
202 }
203
204 /// Copy between buffers.
205 pub fn copy(&mut self, dst: &Buffer, src: &Buffer) -> &mut Self {
206 self.inner.copy_dyn(dst, src);
207 self
208 }
209
210 /// Insert memory barrier.
211 pub fn memory_barrier(&mut self) -> &mut Self {
212 self.inner.memory_barrier_dyn();
213 self
214 }
215
216 /// Submit to hardware.
217 pub fn submit(&mut self) -> Result<()> {
218 self.inner.submit_dyn()
219 }
220
221 /// Get the device.
222 pub fn device(&self) -> &DeviceSpec {
223 self.inner.device()
224 }
225}
226
227/// Internal trait for type erasure.
228trait DynQueueInner: Send + std::fmt::Debug {
229 fn wait_dyn(&mut self, signal: &dyn TimelineSignal, value: u64);
230 fn signal_dyn(&mut self, signal: &dyn TimelineSignal, value: u64);
231 fn exec_dyn(&mut self, program: &dyn Program, buffers: &[&Buffer], params: &ExecParams);
232 fn copy_dyn(&mut self, dst: &Buffer, src: &Buffer);
233 fn memory_barrier_dyn(&mut self);
234 fn submit_dyn(&mut self) -> Result<()>;
235 fn device(&self) -> &DeviceSpec;
236}
237
238/// Wrapper for concrete queue types.
239struct DynQueueWrapper<Q: HardwareQueue> {
240 queue: Q,
241 _phantom: std::marker::PhantomData<Q::Signal>,
242}
243
244impl<Q: HardwareQueue> std::fmt::Debug for DynQueueWrapper<Q> {
245 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246 f.debug_struct("DynQueueWrapper").field("queue", &self.queue).finish()
247 }
248}
249
250impl<Q: HardwareQueue + 'static> DynQueueInner for DynQueueWrapper<Q>
251where
252 Q::Signal: 'static,
253{
254 fn wait_dyn(&mut self, _signal: &dyn TimelineSignal, _value: u64) {
255 // Note: In a full implementation, we'd need to downcast the signal
256 // to the concrete type. For now, this is a placeholder that demonstrates
257 // the interface. The real implementation will use concrete types
258 // in the executor where the signal type is known.
259 //
260 // The type-erased DynQueue is mainly for heterogeneous collections;
261 // most code paths will use concrete types directly.
262 }
263
264 fn signal_dyn(&mut self, _signal: &dyn TimelineSignal, _value: u64) {
265 // See wait_dyn comment
266 }
267
268 fn exec_dyn(&mut self, program: &dyn Program, buffers: &[&Buffer], params: &ExecParams) {
269 self.queue.exec(program, buffers, params);
270 }
271
272 fn copy_dyn(&mut self, dst: &Buffer, src: &Buffer) {
273 self.queue.copy(dst, src);
274 }
275
276 fn memory_barrier_dyn(&mut self) {
277 self.queue.memory_barrier();
278 }
279
280 fn submit_dyn(&mut self) -> Result<()> {
281 self.queue.submit()
282 }
283
284 fn device(&self) -> &DeviceSpec {
285 self.queue.device()
286 }
287}