Skip to main content

morok_device/
device.rs

//! Device abstraction following Tinygrad's architecture.
//!
//! This module provides a unified Device abstraction that owns:
//! - **Renderer**: Transforms UOp graphs into source code (ProgramSpec)
//! - **Compiler**: Transforms source code into a CompiledSpec (source for JIT backends,
//!   compiled bytes for AOT backends)
//! - **Runtime**: Creates executable Programs from a CompiledSpec
//! - **Allocator**: Manages memory allocation for buffers
//!
//! This design allows multiple backends (LLVM, CUDA, Metal, WebGPU) to coexist
//! and share compiled kernels via the method cache.

12use std::sync::Arc;
13
14use morok_dtype::DeviceSpec;
15use morok_ir::UOp;
16
17use crate::allocator::Allocator;
18use crate::error::Result;
19
20/// A compiled, executable kernel program.
21///
22/// This trait abstracts over different backend executors (LLVM JIT, CUDA, Metal, etc.).
23/// Each backend implements this to provide unified execution interface.
24///
25/// Note: This trait does not require Send + Sync because some backends (like LLVM JIT)
26/// use non-thread-safe types. Programs are typically executed on the same thread where
27/// they were compiled, and caching/sharing is handled at a higher level.
28///
29/// # Tinygrad Alignment
30///
31/// This trait follows Tinygrad's `Program` interface where variable values are
32/// passed as a positional tuple/array (`vals`) rather than a named HashMap.
33/// The order matches `var_names` in `CompiledSpec`.
34pub trait Program {
35    /// Execute the kernel with given buffers and variable values.
36    ///
37    /// # Arguments
38    ///
39    /// * `buffers` - Raw pointers to buffer data (input and output buffers)
40    /// * `vals` - Variable values in positional order (matches `var_names` in CompiledSpec)
41    /// * `global_size` - Global work size (for GPU backends, None for CPU)
42    /// * `local_size` - Local work size (for GPU backends, None for CPU)
43    ///
44    /// # Safety
45    ///
46    /// This is unsafe because:
47    /// - Buffer pointers must be valid and properly aligned
48    /// - Buffer sizes must match what the kernel expects
49    /// - Caller must ensure no data races during execution
50    unsafe fn execute(
51        &self,
52        buffers: &[*mut u8],
53        vals: &[i64],
54        global_size: Option<[usize; 3]>,
55        local_size: Option<[usize; 3]>,
56    ) -> Result<()>;
57
58    /// Get the kernel name (for debugging/profiling).
59    fn name(&self) -> &str;
60}
61
62/// Compilation result carrying source (JIT) or bytes (AOT).
63///
64/// Different backends need different information:
65/// - LLVM JIT: needs source code to compile during runtime
66/// - CUDA: needs PTX/CUBIN bytes to load
67/// - Metal: needs metallib bytes to load
68///
69/// This design allows the RuntimeFactory to access whatever it needs
70/// without requiring separate code paths for JIT vs AOT backends.
71#[derive(Debug, Clone)]
72pub struct CompiledSpec {
73    /// Entry point function name
74    pub name: String,
75
76    /// Source code (for JIT backends like LLVM)
77    /// Set to Some(...) for LLVM JIT, None for AOT backends
78    pub src: Option<String>,
79
80    /// Compiled bytes (for AOT backends like CUDA/Metal)
81    /// Empty for LLVM JIT, populated for AOT backends
82    pub bytes: Vec<u8>,
83
84    /// Original AST for cache key construction via hash consing
85    pub ast: Arc<UOp>,
86
87    /// Variable names in order for populating vars array at runtime.
88    /// Includes thread_id at the end if threading is enabled.
89    pub var_names: Vec<String>,
90
91    /// Global work size for dispatch (GPU backends, CPU threading)
92    /// For CPU threading: [thread_count, 1, 1]
93    pub global_size: Option<[usize; 3]>,
94
95    /// Local work size for dispatch (GPU backends)
96    pub local_size: Option<[usize; 3]>,
97}
98
99impl CompiledSpec {
100    /// Create a new CompiledSpec for JIT backends (source-based).
101    pub fn from_source(name: String, src: String, ast: Arc<UOp>) -> Self {
102        Self {
103            name,
104            src: Some(src),
105            bytes: Vec::new(),
106            ast,
107            var_names: Vec::new(),
108            global_size: None,
109            local_size: None,
110        }
111    }
112
113    /// Create a new CompiledSpec for AOT backends (bytecode-based).
114    pub fn from_bytes(name: String, bytes: Vec<u8>, ast: Arc<UOp>) -> Self {
115        Self { name, src: None, bytes, ast, var_names: Vec::new(), global_size: None, local_size: None }
116    }
117
118    /// Create a new CompiledSpec with work sizes for JIT backends.
119    pub fn from_source_with_sizes(
120        name: String,
121        src: String,
122        ast: Arc<UOp>,
123        global_size: Option<[usize; 3]>,
124        local_size: Option<[usize; 3]>,
125    ) -> Self {
126        Self { name, src: Some(src), bytes: Vec::new(), ast, var_names: Vec::new(), global_size, local_size }
127    }
128}
129
130/// A compiler that transforms source code into a compiled specification.
131///
132/// This trait abstracts over different compilation backends:
133/// - LLVM: IR validation (JIT compiles at runtime)
134/// - CUDA: CUDA C -> PTX/CUBIN
135/// - Metal: Metal Shading Language -> metallib
136/// - WebGPU: WGSL -> SPIR-V
137pub trait Compiler: Send + Sync {
138    /// Compile a program specification into executable form.
139    ///
140    /// # Arguments
141    ///
142    /// * `spec` - The program specification containing source code and metadata
143    ///
144    /// # Returns
145    ///
146    /// A CompiledSpec containing:
147    /// - For JIT backends (LLVM): source code in `src` field, empty `bytes`
148    /// - For AOT backends (CUDA/Metal): compiled bytes in `bytes` field, no `src`
149    ///
150    /// # Examples
151    ///
152    /// JIT backend (LLVM):
153    /// ```ignore
154    /// let compiled = compiler.compile(&spec)?;
155    /// assert!(compiled.src.is_some());
156    /// assert!(compiled.bytes.is_empty());
157    /// ```
158    ///
159    /// AOT backend (CUDA):
160    /// ```ignore
161    /// let compiled = compiler.compile(&spec)?;
162    /// assert!(compiled.src.is_none());
163    /// assert!(!compiled.bytes.is_empty());
164    /// ```
165    fn compile(&self, spec: &ProgramSpec) -> Result<CompiledSpec>;
166
167    /// Optional cache key for this compiler configuration.
168    ///
169    /// Used to differentiate compiled artifacts when the same device type
170    /// can have multiple compiler configurations (e.g., different optimization levels).
171    ///
172    /// Returns None if all instances of this compiler produce identical output.
173    fn cache_key(&self) -> Option<&str> {
174        None
175    }
176}
177
178/// A renderer that transforms UOp graphs into source code.
179///
180/// This trait abstracts over different code generation backends:
181/// - LLVM IR generator
182/// - CUDA C generator
183/// - Metal Shading Language generator
184/// - WGSL generator
185pub trait Renderer: Send + Sync {
186    /// Render a UOp graph into source code.
187    ///
188    /// # Arguments
189    ///
190    /// * `ast` - The kernel AST (UOp graph rooted at KERNEL op)
191    ///
192    /// # Returns
193    ///
194    /// A ProgramSpec containing:
195    /// - Generated source code
196    /// - Entry point name
197    /// - Variable list
198    /// - Work sizes (for GPU backends)
199    fn render(&self, ast: &Arc<UOp>) -> Result<ProgramSpec>;
200
201    /// Get the device spec for this renderer.
202    ///
203    /// This is used for cache key construction and device selection.
204    fn device(&self) -> &DeviceSpec;
205
206    /// Returns decomposition patterns for operations this backend doesn't support.
207    ///
208    /// This is used by the realization pass to decompose complex operations
209    /// into simpler primitives before rendering.
210    ///
211    /// # Default Implementation
212    ///
213    /// Returns `None`, meaning no decomposition is needed (backend supports all ops).
214    /// Backends that don't support certain operations (e.g., transcendentals)
215    /// should override this to return appropriate patterns.
216    fn decompositor(&self) -> Option<morok_ir::pattern::TypedPatternMatcher<()>> {
217        None
218    }
219}
220
221/// A factory function that creates executable Programs from a compiled specification.
222///
223/// This is a function pointer that wraps the backend-specific loader:
224/// - LLVM: Extract source from CompiledSpec and JIT compile
225/// - CUDA: Extract bytes from CompiledSpec and call cuModuleLoadData + cuModuleGetFunction
226/// - Metal: Extract bytes from CompiledSpec and call newLibraryWithData + newFunctionWithName
227/// - WebGPU: Extract bytes from CompiledSpec and call createShaderModule
228///
229/// The CompiledSpec contains either source (for JIT) or bytes (for AOT),
230/// allowing each backend to access what it needs.
231pub type RuntimeFactory = Arc<dyn Fn(&CompiledSpec) -> Result<Box<dyn Program>> + Send + Sync>;
232
233/// A (Renderer, Compiler) pair for a specific backend.
234///
235/// Devices can have multiple compiler pairs (e.g., different optimization levels).
236pub type CompilerPair = (Arc<dyn Renderer>, Arc<dyn Compiler>);
237
238/// A device that owns renderer, compiler, runtime, and allocator.
239///
240/// This follows Tinygrad's architecture where a Device is a complete
241/// compilation + execution unit for a specific backend.
242///
243/// # Example
244///
245/// ```ignore
246/// let cpu_device = create_cpu_device()?;
247/// let spec = cpu_device.renderer.render(&kernel_ast)?;
248/// let compiled = cpu_device.compiler.compile(&spec)?;
249/// let program = (cpu_device.runtime)(&compiled)?;
250/// unsafe { program.execute(&buffers, &vals, None, None)?; }
251/// ```
252pub struct Device {
253    /// Device specification
254    pub device: DeviceSpec,
255
256    /// Memory allocator for this device
257    pub allocator: Arc<dyn Allocator>,
258
259    /// Available (renderer, compiler) pairs for this device
260    ///
261    /// Most devices have one pair, but some may have multiple
262    /// (e.g., different optimization levels or compilation modes).
263    pub compilers: Vec<CompilerPair>,
264
265    /// Primary renderer for this device
266    ///
267    /// This is typically compilers[0].0, stored separately for convenience.
268    pub renderer: Arc<dyn Renderer>,
269
270    /// Primary compiler for this device
271    ///
272    /// This is typically compilers[0].1, stored separately for convenience.
273    pub compiler: Arc<dyn Compiler>,
274
275    /// Runtime factory for creating executable programs
276    ///
277    /// Takes (entry_point, compiled_bytes) and returns a Program.
278    pub runtime: RuntimeFactory,
279}
280
281impl Device {
282    /// Create a new device with a single compiler pair.
283    ///
284    /// This is a convenience constructor for the common case where
285    /// a device has only one renderer/compiler combination.
286    pub fn new(
287        device: DeviceSpec,
288        allocator: Arc<dyn Allocator>,
289        renderer: Arc<dyn Renderer>,
290        compiler: Arc<dyn Compiler>,
291        runtime: RuntimeFactory,
292    ) -> Self {
293        let compilers = vec![(renderer.clone(), compiler.clone())];
294        Self { device, allocator, compilers, renderer, compiler, runtime }
295    }
296
297    /// Get the base device key (strips device ID).
298    ///
299    /// Used for compiled byte cache sharing across device instances.
300    /// Examples:
301    /// - DeviceSpec::Cpu -> "CPU"
302    /// - DeviceSpec::Cuda { device_id: 0 } -> "CUDA"
303    /// - DeviceSpec::Cuda { device_id: 1 } -> "CUDA"
304    /// - DeviceSpec::Metal { device_id: 0 } -> "Metal"
305    ///
306    /// This allows compiled CUDA kernels to be reused across CUDA:0 and CUDA:1.
307    pub fn base_device_key(&self) -> &'static str {
308        self.device.base_type()
309    }
310}
311
312/// Program specification containing source code and metadata.
313///
314/// This is returned by Renderer::render() and consumed by Compiler::compile().
315/// It bridges the gap between UOp graphs and compiled executables.
316///
317/// # Tinygrad Alignment
318///
319/// Buffer metadata (`globals`, `outs`, `ins`) matches Tinygrad's Program class:
320/// - `globals`: Buffer indices from DefineGlobal ops
321/// - `outs`: Output buffer indices (written by STORE ops)
322/// - `ins`: Input buffer indices (read by LOAD ops)
323#[derive(Debug, Clone)]
324pub struct ProgramSpec {
325    /// Kernel name (for debugging/profiling)
326    pub name: String,
327
328    /// Generated source code (LLVM IR, CUDA C, Metal, WGSL, etc.)
329    pub src: String,
330
331    /// Device specification
332    pub device: DeviceSpec,
333
334    /// Original AST (for cache key construction via hash consing)
335    pub ast: Arc<UOp>,
336
337    /// Global work size (for GPU backends, None for CPU)
338    pub global_size: Option<[usize; 3]>,
339
340    /// Local work size (for GPU backends, None for CPU)
341    pub local_size: Option<[usize; 3]>,
342
343    /// Variable list (for symbolic shapes/strides)
344    pub vars: Vec<Variable>,
345
346    /// Variable names in order for populating vars array at runtime.
347    /// Includes thread_id at the end if threading is enabled.
348    pub var_names: Vec<String>,
349
350    /// Global buffer indices (from DefineGlobal argument values).
351    /// Matches Tinygrad's `globals` field.
352    pub globals: Vec<usize>,
353
354    /// Output buffer indices (written by STORE ops).
355    /// Matches Tinygrad's `outs` field.
356    pub outs: Vec<usize>,
357
358    /// Input buffer indices (read by LOAD ops, excluding outputs).
359    /// Matches Tinygrad's `ins` field.
360    pub ins: Vec<usize>,
361}
362
363impl ProgramSpec {
364    /// Create a new program specification.
365    pub fn new(name: String, src: String, device: DeviceSpec, ast: Arc<UOp>) -> Self {
366        Self {
367            name,
368            src,
369            device,
370            ast,
371            global_size: None,
372            local_size: None,
373            vars: Vec::new(),
374            var_names: Vec::new(),
375            globals: Vec::new(),
376            outs: Vec::new(),
377            ins: Vec::new(),
378        }
379    }
380
381    /// Add a variable to the program.
382    pub fn add_var(&mut self, var: Variable) {
383        self.vars.push(var);
384    }
385
386    /// Set work sizes for GPU execution.
387    pub fn set_work_sizes(&mut self, global: [usize; 3], local: [usize; 3]) {
388        self.global_size = Some(global);
389        self.local_size = Some(local);
390    }
391
392    /// Set variable names for populating vars array at runtime.
393    pub fn set_var_names(&mut self, var_names: Vec<String>) {
394        self.var_names = var_names;
395    }
396
397    /// Set buffer metadata (globals, outs, ins).
398    pub fn set_buffer_metadata(&mut self, globals: Vec<usize>, outs: Vec<usize>, ins: Vec<usize>) {
399        self.globals = globals;
400        self.outs = outs;
401        self.ins = ins;
402    }
403}
404
/// A variable in the kernel (for symbolic shapes/strides).
///
/// Variables represent symbolic values that are bound at kernel execution time.
/// Examples:
/// - Shape dimensions that vary per input
/// - Stride values computed from shapes
/// - Loop bounds determined by input sizes
///
/// All fields are plain data, so `PartialEq`, `Eq`, and `Hash` are derived: variables can
/// be compared, deduplicated, and used as set/map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Variable {
    /// Variable name (must be unique within the kernel)
    pub name: String,

    /// Minimum value (for range validation)
    pub min: i64,

    /// Maximum value (for range validation)
    pub max: i64,
}

424impl Variable {
425    /// Create a new variable.
426    pub fn new(name: String, min: i64, max: i64) -> Self {
427        Self { name, min, max }
428    }
429}