Skip to main content

morok_device/
device.rs

1//! Device abstraction following Tinygrad's architecture.
2//!
3//! This module provides a unified Device abstraction that owns:
4//! - **Renderer**: Transforms UOp graphs into source code (ProgramSpec)
5//! - **Compiler**: Transforms source code into executable bytes
6//! - **Runtime**: Creates executable Programs from compiled bytes
7//! - **Allocator**: Manages memory allocation for buffers
8//!
9//! This design allows multiple backends (LLVM, CUDA, Metal, WebGPU) to coexist
10//! and share compiled kernels via the method cache.
11
12use std::sync::Arc;
13
14use morok_dtype::DeviceSpec;
15use morok_ir::UOp;
16
17use crate::allocator::Allocator;
18use crate::error::Result;
19
20/// A compiled, executable kernel program.
21///
22/// This trait abstracts over different backend executors (LLVM JIT, CUDA, Metal, etc.).
23/// Each backend implements this to provide unified execution interface.
24///
25/// Note: This trait does not require Send + Sync because some backends (like LLVM JIT)
26/// use non-thread-safe types. Programs are typically executed on the same thread where
27/// they were compiled, and caching/sharing is handled at a higher level.
28///
29/// # Tinygrad Alignment
30///
31/// This trait follows Tinygrad's `Program` interface where variable values are
32/// passed as a positional tuple/array (`vals`) rather than a named HashMap.
33/// The order matches `var_names` in `CompiledSpec`.
34pub trait Program {
35    /// Execute the kernel with given buffers and variable values.
36    ///
37    /// # Arguments
38    ///
39    /// * `buffers` - Raw pointers to buffer data (input and output buffers)
40    /// * `vals` - Variable values in positional order (matches `var_names` in CompiledSpec)
41    /// * `global_size` - Global work size (for GPU backends, None for CPU)
42    /// * `local_size` - Local work size (for GPU backends, None for CPU)
43    ///
44    /// # Safety
45    ///
46    /// This is unsafe because:
47    /// - Buffer pointers must be valid and properly aligned
48    /// - Buffer sizes must match what the kernel expects
49    /// - Caller must ensure no data races during execution
50    unsafe fn execute(
51        &self,
52        buffers: &[*mut u8],
53        vals: &[i64],
54        global_size: Option<[usize; 3]>,
55        local_size: Option<[usize; 3]>,
56    ) -> Result<()>;
57
58    /// Get the kernel name (for debugging/profiling).
59    fn name(&self) -> &str;
60}
61
62/// Compilation result carrying source (JIT) or bytes (AOT).
63///
64/// Different backends need different information:
65/// - LLVM JIT: needs source code to compile during runtime
66/// - CUDA: needs PTX/CUBIN bytes to load
67/// - Metal: needs metallib bytes to load
68///
69/// This design allows the RuntimeFactory to access whatever it needs
70/// without requiring separate code paths for JIT vs AOT backends.
71#[derive(Debug, Clone)]
72pub struct CompiledSpec {
73    /// Entry point function name
74    pub name: String,
75
76    /// Source code (for JIT backends like LLVM)
77    /// Set to Some(...) for LLVM JIT, None for AOT backends
78    pub src: Option<String>,
79
80    /// Compiled bytes (for AOT backends like CUDA/Metal)
81    /// Empty for LLVM JIT, populated for AOT backends
82    pub bytes: Vec<u8>,
83
84    /// Original AST for cache key construction via hash consing
85    pub ast: Arc<UOp>,
86
87    /// Variable names in order for populating vars array at runtime.
88    /// Includes thread_id at the end if threading is enabled.
89    pub var_names: Vec<String>,
90
91    /// Global work size for dispatch (GPU backends, CPU threading)
92    /// For CPU threading: [thread_count, 1, 1]
93    pub global_size: Option<[usize; 3]>,
94
95    /// Local work size for dispatch (GPU backends)
96    pub local_size: Option<[usize; 3]>,
97
98    /// Number of buffer arguments (for CIF construction at compile time).
99    pub buf_count: usize,
100}
101
102impl CompiledSpec {
103    /// Create a new CompiledSpec for JIT backends (source-based).
104    pub fn from_source(name: String, src: String, ast: Arc<UOp>, buf_count: usize) -> Self {
105        Self {
106            name,
107            src: Some(src),
108            bytes: Vec::new(),
109            ast,
110            var_names: Vec::new(),
111            global_size: None,
112            local_size: None,
113            buf_count,
114        }
115    }
116
117    /// Create a new CompiledSpec for AOT backends (bytecode-based).
118    pub fn from_bytes(name: String, bytes: Vec<u8>, ast: Arc<UOp>) -> Self {
119        Self { name, src: None, bytes, ast, var_names: Vec::new(), global_size: None, local_size: None, buf_count: 0 }
120    }
121
122    /// Create a new CompiledSpec with work sizes for JIT backends.
123    pub fn from_source_with_sizes(
124        name: String,
125        src: String,
126        ast: Arc<UOp>,
127        global_size: Option<[usize; 3]>,
128        local_size: Option<[usize; 3]>,
129        buf_count: usize,
130    ) -> Self {
131        Self { name, src: Some(src), bytes: Vec::new(), ast, var_names: Vec::new(), global_size, local_size, buf_count }
132    }
133}
134
135/// A compiler that transforms source code into a compiled specification.
136///
137/// This trait abstracts over different compilation backends:
138/// - LLVM: IR validation (JIT compiles at runtime)
139/// - CUDA: CUDA C -> PTX/CUBIN
140/// - Metal: Metal Shading Language -> metallib
141/// - WebGPU: WGSL -> SPIR-V
142pub trait Compiler: Send + Sync {
143    /// Compile a program specification into executable form.
144    ///
145    /// # Arguments
146    ///
147    /// * `spec` - The program specification containing source code and metadata
148    ///
149    /// # Returns
150    ///
151    /// A CompiledSpec containing:
152    /// - For JIT backends (LLVM): source code in `src` field, empty `bytes`
153    /// - For AOT backends (CUDA/Metal): compiled bytes in `bytes` field, no `src`
154    ///
155    /// # Examples
156    ///
157    /// JIT backend (LLVM):
158    /// ```ignore
159    /// let compiled = compiler.compile(&spec)?;
160    /// assert!(compiled.src.is_some());
161    /// assert!(compiled.bytes.is_empty());
162    /// ```
163    ///
164    /// AOT backend (CUDA):
165    /// ```ignore
166    /// let compiled = compiler.compile(&spec)?;
167    /// assert!(compiled.src.is_none());
168    /// assert!(!compiled.bytes.is_empty());
169    /// ```
170    fn compile(&self, spec: &ProgramSpec) -> Result<CompiledSpec>;
171
172    /// Optional cache key for this compiler configuration.
173    ///
174    /// Used to differentiate compiled artifacts when the same device type
175    /// can have multiple compiler configurations (e.g., different optimization levels).
176    ///
177    /// Returns None if all instances of this compiler produce identical output.
178    fn cache_key(&self) -> Option<&str> {
179        None
180    }
181}
182
183/// A renderer that transforms UOp graphs into source code.
184///
185/// This trait abstracts over different code generation backends:
186/// - LLVM IR generator
187/// - CUDA C generator
188/// - Metal Shading Language generator
189/// - WGSL generator
190pub trait Renderer: Send + Sync {
191    /// Render a UOp graph into source code.
192    ///
193    /// # Arguments
194    ///
195    /// * `ast` - The kernel AST (UOp graph rooted at KERNEL op)
196    /// * `name` - Optional kernel name for debugging (e.g., "r_g16l16R32u4").
197    ///   Falls back to "kernel" if None.
198    ///
199    /// # Returns
200    ///
201    /// A ProgramSpec containing:
202    /// - Generated source code
203    /// - Entry point name
204    /// - Variable list
205    /// - Work sizes (for GPU backends)
206    fn render(&self, ast: &Arc<UOp>, name: Option<&str>) -> Result<ProgramSpec>;
207
208    /// Get the device spec for this renderer.
209    ///
210    /// This is used for cache key construction and device selection.
211    fn device(&self) -> &DeviceSpec;
212
213    /// Returns decomposition patterns for operations this backend doesn't support.
214    ///
215    /// This is used by the realization pass to decompose complex operations
216    /// into simpler primitives before rendering.
217    ///
218    /// # Default Implementation
219    ///
220    /// Returns `None`, meaning no decomposition is needed (backend supports all ops).
221    /// Backends that don't support certain operations (e.g., transcendentals)
222    /// should override this to return appropriate patterns.
223    fn decompositor(&self) -> Option<morok_ir::pattern::TypedPatternMatcher<()>> {
224        None
225    }
226}
227
228/// A factory function that creates executable Programs from a compiled specification.
229///
230/// This is a function pointer that wraps the backend-specific loader:
231/// - LLVM: Extract source from CompiledSpec and JIT compile
232/// - CUDA: Extract bytes from CompiledSpec and call cuModuleLoadData + cuModuleGetFunction
233/// - Metal: Extract bytes from CompiledSpec and call newLibraryWithData + newFunctionWithName
234/// - WebGPU: Extract bytes from CompiledSpec and call createShaderModule
235///
236/// The CompiledSpec contains either source (for JIT) or bytes (for AOT),
237/// allowing each backend to access what it needs.
238pub type RuntimeFactory = Arc<dyn Fn(&CompiledSpec) -> Result<Box<dyn Program>> + Send + Sync>;
239
240/// A (Renderer, Compiler) pair for a specific backend.
241///
242/// Devices can have multiple compiler pairs (e.g., different optimization levels).
243pub type CompilerPair = (Arc<dyn Renderer>, Arc<dyn Compiler>);
244
245/// A device that owns renderer, compiler, runtime, and allocator.
246///
247/// This follows Tinygrad's architecture where a Device is a complete
248/// compilation + execution unit for a specific backend.
249///
250/// # Example
251///
252/// ```ignore
253/// let cpu_device = create_cpu_device()?;
254/// let spec = cpu_device.renderer.render(&kernel_ast, Some("E_L3"))?;
255/// let compiled = cpu_device.compiler.compile(&spec)?;
256/// let program = (cpu_device.runtime)(&compiled)?;
257/// unsafe { program.execute(&buffers, &vals, None, None)?; }
258/// ```
259pub struct Device {
260    /// Device specification
261    pub device: DeviceSpec,
262
263    /// Memory allocator for this device
264    pub allocator: Arc<dyn Allocator>,
265
266    /// Available (renderer, compiler) pairs for this device
267    ///
268    /// Most devices have one pair, but some may have multiple
269    /// (e.g., different optimization levels or compilation modes).
270    pub compilers: Vec<CompilerPair>,
271
272    /// Primary renderer for this device
273    ///
274    /// This is typically compilers[0].0, stored separately for convenience.
275    pub renderer: Arc<dyn Renderer>,
276
277    /// Primary compiler for this device
278    ///
279    /// This is typically compilers[0].1, stored separately for convenience.
280    pub compiler: Arc<dyn Compiler>,
281
282    /// Runtime factory for creating executable programs
283    ///
284    /// Takes (entry_point, compiled_bytes) and returns a Program.
285    pub runtime: RuntimeFactory,
286}
287
288impl Device {
289    /// Create a new device with a single compiler pair.
290    ///
291    /// This is a convenience constructor for the common case where
292    /// a device has only one renderer/compiler combination.
293    pub fn new(
294        device: DeviceSpec,
295        allocator: Arc<dyn Allocator>,
296        renderer: Arc<dyn Renderer>,
297        compiler: Arc<dyn Compiler>,
298        runtime: RuntimeFactory,
299    ) -> Self {
300        let compilers = vec![(renderer.clone(), compiler.clone())];
301        Self { device, allocator, compilers, renderer, compiler, runtime }
302    }
303
304    /// Get the base device key (strips device ID).
305    ///
306    /// Used for compiled byte cache sharing across device instances.
307    /// Examples:
308    /// - DeviceSpec::Cpu -> "CPU"
309    /// - DeviceSpec::Cuda { device_id: 0 } -> "CUDA"
310    /// - DeviceSpec::Cuda { device_id: 1 } -> "CUDA"
311    /// - DeviceSpec::Metal { device_id: 0 } -> "Metal"
312    ///
313    /// This allows compiled CUDA kernels to be reused across CUDA:0 and CUDA:1.
314    pub fn base_device_key(&self) -> &'static str {
315        self.device.base_type()
316    }
317}
318
319/// Program specification containing source code and metadata.
320///
321/// This is returned by Renderer::render() and consumed by Compiler::compile().
322/// It bridges the gap between UOp graphs and compiled executables.
323///
324/// # Tinygrad Alignment
325///
326/// Buffer metadata (`globals`, `outs`, `ins`) matches Tinygrad's Program class:
327/// - `globals`: Buffer indices from DefineGlobal ops
328/// - `outs`: Output buffer indices (written by STORE ops)
329/// - `ins`: Input buffer indices (read by LOAD ops)
330#[derive(Debug, Clone)]
331pub struct ProgramSpec {
332    /// Kernel name (for debugging/profiling)
333    pub name: String,
334
335    /// Generated source code (LLVM IR, CUDA C, Metal, WGSL, etc.)
336    pub src: String,
337
338    /// Device specification
339    pub device: DeviceSpec,
340
341    /// Original AST (for cache key construction via hash consing)
342    pub ast: Arc<UOp>,
343
344    /// Global work size (for GPU backends, None for CPU)
345    pub global_size: Option<[usize; 3]>,
346
347    /// Local work size (for GPU backends, None for CPU)
348    pub local_size: Option<[usize; 3]>,
349
350    /// Variable list (for symbolic shapes/strides)
351    pub vars: Vec<Variable>,
352
353    /// Variable names in order for populating vars array at runtime.
354    /// Includes thread_id at the end if threading is enabled.
355    pub var_names: Vec<String>,
356
357    /// Global buffer indices (from DefineGlobal argument values).
358    /// Matches Tinygrad's `globals` field.
359    pub globals: Vec<usize>,
360
361    /// Output buffer indices (written by STORE ops).
362    /// Matches Tinygrad's `outs` field.
363    pub outs: Vec<usize>,
364
365    /// Input buffer indices (read by LOAD ops, excluding outputs).
366    /// Matches Tinygrad's `ins` field.
367    pub ins: Vec<usize>,
368
369    /// Number of buffer arguments (for CIF construction at compile time).
370    pub buf_count: usize,
371}
372
373impl ProgramSpec {
374    /// Create a new program specification.
375    pub fn new(name: String, src: String, device: DeviceSpec, ast: Arc<UOp>) -> Self {
376        Self {
377            name,
378            src,
379            device,
380            ast,
381            global_size: None,
382            local_size: None,
383            vars: Vec::new(),
384            var_names: Vec::new(),
385            globals: Vec::new(),
386            outs: Vec::new(),
387            ins: Vec::new(),
388            buf_count: 0,
389        }
390    }
391
392    /// Add a variable to the program.
393    pub fn add_var(&mut self, var: Variable) {
394        self.vars.push(var);
395    }
396
397    /// Set work sizes for GPU execution.
398    pub fn set_work_sizes(&mut self, global: [usize; 3], local: [usize; 3]) {
399        self.global_size = Some(global);
400        self.local_size = Some(local);
401    }
402
403    /// Set variable names for populating vars array at runtime.
404    pub fn set_var_names(&mut self, var_names: Vec<String>) {
405        self.var_names = var_names;
406    }
407
408    /// Set buffer metadata (globals, outs, ins).
409    pub fn set_buffer_metadata(&mut self, globals: Vec<usize>, outs: Vec<usize>, ins: Vec<usize>) {
410        self.globals = globals;
411        self.outs = outs;
412        self.ins = ins;
413    }
414}
415
/// A symbolic kernel variable whose concrete value is bound when the kernel executes.
///
/// Typical uses include:
/// - shape dimensions that differ from one input to the next,
/// - strides derived from those shapes,
/// - loop bounds that depend on input sizes.
#[derive(Debug, Clone)]
pub struct Variable {
    /// Name of the variable; must be unique within its kernel.
    pub name: String,

    /// Smallest allowed value, used for range validation.
    pub min: i64,

    /// Largest allowed value, used for range validation.
    pub max: i64,
}

impl Variable {
    /// Builds a variable from its name and bounds.
    ///
    /// The bounds are stored exactly as given; no ordering check is performed.
    pub fn new(name: String, min: i64, max: i64) -> Self {
        Variable { name, min, max }
    }
}