Skip to main content

morok_device/
device.rs

1//! Device abstraction following Tinygrad's architecture.
2//!
3//! This module provides a unified Device abstraction that owns:
4//! - **Renderer**: Transforms UOp graphs into source code (ProgramSpec)
5//! - **Compiler**: Transforms source code into executable bytes
6//! - **Runtime**: Creates executable Programs from compiled bytes
7//! - **Allocator**: Manages memory allocation for buffers
8//!
9//! This design allows multiple backends (LLVM, CUDA, Metal, WebGPU) to coexist
10//! and share compiled kernels via the method cache.
11
12use std::sync::Arc;
13
14use morok_dtype::DeviceSpec;
15use morok_ir::UOp;
16
17use crate::allocator::Allocator;
18use crate::error::Result;
19
/// A compiled, executable kernel program.
///
/// This trait abstracts over the different backend executors (LLVM JIT, CUDA, Metal, etc.);
/// each backend implements it so kernels can be launched through one uniform interface.
///
/// Note: this trait deliberately does not require `Send + Sync`, because some backends
/// (such as the LLVM JIT) hold non-thread-safe types. Programs are typically executed on
/// the thread that compiled them; caching and sharing are handled at a higher level.
///
/// # Tinygrad Alignment
///
/// This trait follows Tinygrad's `Program` interface: variable values are passed as a
/// positional array (`vals`) rather than a name-keyed `HashMap`. The positions match
/// [`CompiledSpec::var_names`].
pub trait Program {
    /// Execute the kernel with the given buffers and variable values.
    ///
    /// # Arguments
    ///
    /// * `buffers` - Raw pointers to the buffer data (both input and output buffers).
    ///   NOTE(review): the expected ordering is presumably the kernel's parameter order
    ///   (cf. [`ProgramSpec::globals`]) — confirm against the backend implementations.
    /// * `vals` - Variable values in positional order: `vals[i]` is the value bound to
    ///   `var_names[i]` of the [`CompiledSpec`] this program was created from.
    /// * `global_size` - Global work size, used by GPU backends. CPU backends may receive
    ///   `None`. NOTE(review): [`CompiledSpec::global_size`] documents `[thread_count, 1, 1]`
    ///   for CPU threading, so CPU callers may pass `Some(..)` as well — confirm.
    /// * `local_size` - Local work size (GPU backends; typically `None` for CPU).
    ///
    /// # Errors
    ///
    /// Returns an error if the backend fails to launch or run the kernel; the exact
    /// conditions are backend-specific.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that:
    /// - every pointer in `buffers` is valid and properly aligned for the kernel's accesses;
    /// - every buffer is at least as large as the kernel expects;
    /// - no other code accesses these buffers in a conflicting way while the kernel runs
    ///   (no data races).
    unsafe fn execute(
        &self,
        buffers: &[*mut u8],
        vals: &[i64],
        global_size: Option<[usize; 3]>,
        local_size: Option<[usize; 3]>,
    ) -> Result<()>;

    /// Returns the kernel name (used for debugging/profiling).
    fn name(&self) -> &str;
}
61
/// Compilation result carrying either source code (JIT) or compiled bytes (AOT).
///
/// Different backends need different information:
/// - LLVM JIT: needs the source code, which it compiles at runtime
/// - CUDA: needs PTX/CUBIN bytes to load
/// - Metal: needs metallib bytes to load
///
/// Carrying both representations lets a [`RuntimeFactory`] read whatever it needs without
/// separate code paths for JIT and AOT backends.
#[derive(Debug, Clone)]
pub struct CompiledSpec {
    /// Entry point function name.
    pub name: String,

    /// Source code (for JIT backends like LLVM).
    ///
    /// `Some(...)` for LLVM JIT, `None` for AOT backends.
    pub src: Option<String>,

    /// Compiled bytes (for AOT backends like CUDA/Metal).
    ///
    /// Empty for LLVM JIT, populated for AOT backends.
    pub bytes: Vec<u8>,

    /// Original AST, used for cache key construction via hash consing.
    pub ast: Arc<UOp>,

    /// Variable names, in the positional order used for the `vals` array passed to
    /// [`Program::execute`].
    ///
    /// Includes `thread_id` at the end if threading is enabled. The `from_*` constructors
    /// leave this empty, so it must be filled in afterwards when the kernel has variables.
    pub var_names: Vec<String>,

    /// Global work size for dispatch (GPU backends and CPU threading).
    ///
    /// For CPU threading: `[thread_count, 1, 1]`.
    pub global_size: Option<[usize; 3]>,

    /// Local work size for dispatch (GPU backends).
    pub local_size: Option<[usize; 3]>,

    /// Number of buffer arguments (used for CIF construction at compile time).
    ///
    /// [`CompiledSpec::from_bytes`] sets this to 0; set the field explicitly if an AOT
    /// backend needs it.
    pub buf_count: usize,
}
101
102impl CompiledSpec {
103    /// Create a new CompiledSpec for JIT backends (source-based).
104    pub fn from_source(name: String, src: String, ast: Arc<UOp>, buf_count: usize) -> Self {
105        Self {
106            name,
107            src: Some(src),
108            bytes: Vec::new(),
109            ast,
110            var_names: Vec::new(),
111            global_size: None,
112            local_size: None,
113            buf_count,
114        }
115    }
116
117    /// Create a new CompiledSpec for AOT backends (bytecode-based).
118    pub fn from_bytes(name: String, bytes: Vec<u8>, ast: Arc<UOp>) -> Self {
119        Self { name, src: None, bytes, ast, var_names: Vec::new(), global_size: None, local_size: None, buf_count: 0 }
120    }
121
122    /// Create a new CompiledSpec with work sizes for JIT backends.
123    pub fn from_source_with_sizes(
124        name: String,
125        src: String,
126        ast: Arc<UOp>,
127        global_size: Option<[usize; 3]>,
128        local_size: Option<[usize; 3]>,
129        buf_count: usize,
130    ) -> Self {
131        Self { name, src: Some(src), bytes: Vec::new(), ast, var_names: Vec::new(), global_size, local_size, buf_count }
132    }
133}
134
/// A compiler that transforms a rendered [`ProgramSpec`] into a [`CompiledSpec`].
///
/// This trait abstracts over different compilation backends:
/// - LLVM: IR validation (JIT compilation happens at runtime)
/// - CUDA: CUDA C -> PTX/CUBIN
/// - Metal: Metal Shading Language -> metallib
/// - WebGPU: WGSL -> SPIR-V
///
/// Implementations must be `Send + Sync` (required by the supertrait bounds).
pub trait Compiler: Send + Sync {
    /// Compile a program specification into executable form.
    ///
    /// # Arguments
    ///
    /// * `spec` - The program specification containing source code and metadata
    ///
    /// # Returns
    ///
    /// A [`CompiledSpec`] containing:
    /// - for JIT backends (LLVM): source code in the `src` field and empty `bytes`
    /// - for AOT backends (CUDA/Metal): compiled bytes in the `bytes` field and no `src`
    ///
    /// # Errors
    ///
    /// Returns an error if the backend cannot compile `spec`; the exact conditions are
    /// backend-specific.
    ///
    /// # Examples
    ///
    /// JIT backend (LLVM):
    /// ```ignore
    /// let compiled = compiler.compile(&spec)?;
    /// assert!(compiled.src.is_some());
    /// assert!(compiled.bytes.is_empty());
    /// ```
    ///
    /// AOT backend (CUDA):
    /// ```ignore
    /// let compiled = compiler.compile(&spec)?;
    /// assert!(compiled.src.is_none());
    /// assert!(!compiled.bytes.is_empty());
    /// ```
    fn compile(&self, spec: &ProgramSpec) -> Result<CompiledSpec>;

    /// Cache key identifying this compiler backend.
    ///
    /// Used to tell compiled artifacts apart when the same device type can have
    /// multiple compiler backends (e.g., clang vs llvm-jit).
    fn cache_key(&self) -> &'static str;
}
178
/// A renderer that transforms UOp graphs into source code.
///
/// This trait abstracts over different code generation backends:
/// - LLVM IR generator
/// - CUDA C generator
/// - Metal Shading Language generator
/// - WGSL generator
///
/// Implementations must be `Send + Sync` (required by the supertrait bounds).
pub trait Renderer: Send + Sync {
    /// Render a UOp graph into source code.
    ///
    /// # Arguments
    ///
    /// * `ast` - The kernel AST (UOp graph rooted at a KERNEL op)
    /// * `name` - Optional kernel name for debugging (e.g., `"r_g16l16R32u4"`).
    ///   Implementations are expected to fall back to `"kernel"` when this is `None`.
    ///
    /// # Returns
    ///
    /// A [`ProgramSpec`] containing:
    /// - the generated source code
    /// - the entry point name
    /// - the variable list
    /// - work sizes (for GPU backends)
    ///
    /// # Errors
    ///
    /// Returns an error if this backend cannot render the graph; the exact conditions are
    /// backend-specific.
    fn render(&self, ast: &Arc<UOp>, name: Option<&str>) -> Result<ProgramSpec>;

    /// Returns the device spec this renderer targets.
    ///
    /// Used for cache key construction and device selection.
    fn device(&self) -> &DeviceSpec;

    /// Returns decomposition patterns for operations this backend doesn't support natively.
    ///
    /// The realization pass uses these patterns to decompose complex operations into
    /// simpler primitives before rendering.
    ///
    /// # Default Implementation
    ///
    /// Returns `None`, meaning no decomposition is needed (the backend supports all ops).
    /// Backends that lack certain operations (e.g., transcendentals) should override this
    /// to return appropriate patterns.
    fn decompositor(&self) -> Option<morok_ir::pattern::TypedPatternMatcher<()>> {
        None
    }
}
223
/// A factory that creates executable [`Program`]s from a [`CompiledSpec`].
///
/// It is a shared, thread-safe closure (`Arc<dyn Fn + Send + Sync>`), not a bare function
/// pointer, and it wraps the backend-specific loader:
/// - LLVM: take the source from the `CompiledSpec` and JIT-compile it
/// - CUDA: take the bytes and call `cuModuleLoadData` + `cuModuleGetFunction`
/// - Metal: take the bytes and call `newLibraryWithData` + `newFunctionWithName`
/// - WebGPU: take the bytes and call `createShaderModule`
///
/// Because a `CompiledSpec` carries either source (JIT) or bytes (AOT), each backend can
/// read whichever representation it needs. The returned `Box<dyn Program>` is not required
/// to be `Send`/`Sync`, even though the factory itself is.
pub type RuntimeFactory = Arc<dyn Fn(&CompiledSpec) -> Result<Box<dyn Program>> + Send + Sync>;
235
/// A `(Renderer, Compiler)` pair for a specific backend.
///
/// The renderer turns a UOp graph into a [`ProgramSpec`], and the compiler turns that spec
/// into a [`CompiledSpec`]. Devices can hold several pairs (e.g., for different
/// optimization levels).
pub type CompilerPair = (Arc<dyn Renderer>, Arc<dyn Compiler>);
240
/// A device that owns a renderer, compiler, runtime, and allocator.
///
/// Following Tinygrad's architecture, a `Device` is a complete compilation + execution
/// unit for a specific backend.
///
/// # Example
///
/// ```ignore
/// let cpu_device = create_cpu_device()?;
/// let spec = cpu_device.renderer.render(&kernel_ast, Some("E_L3"))?;
/// let compiled = cpu_device.compiler.compile(&spec)?;
/// let program = (cpu_device.runtime)(&compiled)?;
/// unsafe { program.execute(&buffers, &vals, None, None)?; }
/// ```
pub struct Device {
    /// Device specification.
    pub device: DeviceSpec,

    /// Memory allocator for this device.
    pub allocator: Arc<dyn Allocator>,

    /// Available (renderer, compiler) pairs for this device.
    ///
    /// Most devices have one pair, but some may have several
    /// (e.g., different optimization levels or compilation modes).
    pub compilers: Vec<CompilerPair>,

    /// Primary renderer for this device.
    ///
    /// [`Device::new`] sets this to the same `Arc` as `compilers[0].0`. Because the fields
    /// are public, nothing keeps the two in sync if either is modified later.
    pub renderer: Arc<dyn Renderer>,

    /// Primary compiler for this device.
    ///
    /// [`Device::new`] sets this to the same `Arc` as `compilers[0].1`. As with `renderer`,
    /// nothing keeps the two in sync if either is modified later.
    pub compiler: Arc<dyn Compiler>,

    /// Runtime factory for creating executable programs.
    ///
    /// Called with a [`CompiledSpec`]; returns a boxed [`Program`] or an error.
    pub runtime: RuntimeFactory,
}
283
284impl Device {
285    /// Create a new device with a single compiler pair.
286    ///
287    /// This is a convenience constructor for the common case where
288    /// a device has only one renderer/compiler combination.
289    pub fn new(
290        device: DeviceSpec,
291        allocator: Arc<dyn Allocator>,
292        renderer: Arc<dyn Renderer>,
293        compiler: Arc<dyn Compiler>,
294        runtime: RuntimeFactory,
295    ) -> Self {
296        let compilers = vec![(renderer.clone(), compiler.clone())];
297        Self { device, allocator, compilers, renderer, compiler, runtime }
298    }
299
300    /// Get the base device key (strips device ID).
301    ///
302    /// Used for compiled byte cache sharing across device instances.
303    /// Examples:
304    /// - DeviceSpec::Cpu -> "CPU"
305    /// - DeviceSpec::Cuda { device_id: 0 } -> "CUDA"
306    /// - DeviceSpec::Cuda { device_id: 1 } -> "CUDA"
307    /// - DeviceSpec::Metal { device_id: 0 } -> "Metal"
308    ///
309    /// This allows compiled CUDA kernels to be reused across CUDA:0 and CUDA:1.
310    pub fn base_device_key(&self) -> &'static str {
311        self.device.base_type()
312    }
313}
314
/// Program specification containing rendered source code and metadata.
///
/// Returned by [`Renderer::render`] and consumed by [`Compiler::compile`], it bridges the
/// gap between UOp graphs and compiled executables.
///
/// # Tinygrad Alignment
///
/// Buffer metadata (`globals`, `outs`, `ins`) matches Tinygrad's `Program` class:
/// - `globals`: buffer indices from PARAM ops
/// - `outs`: output buffer indices (written by STORE ops)
/// - `ins`: input buffer indices (read by LOAD ops)
#[derive(Debug, Clone)]
pub struct ProgramSpec {
    /// Kernel name (for debugging/profiling).
    pub name: String,

    /// Generated source code (LLVM IR, CUDA C, Metal, WGSL, etc.).
    pub src: String,

    /// Device specification.
    pub device: DeviceSpec,

    /// Original AST, used for cache key construction via hash consing.
    pub ast: Arc<UOp>,

    /// Global work size (for GPU backends; `None` after [`ProgramSpec::new`]).
    ///
    /// NOTE(review): [`CompiledSpec::global_size`] documents `[thread_count, 1, 1]` for CPU
    /// threading — confirm whether the CPU renderer populates this field as well.
    pub global_size: Option<[usize; 3]>,

    /// Local work size (for GPU backends; `None` after [`ProgramSpec::new`]).
    pub local_size: Option<[usize; 3]>,

    /// Symbolic variables (shapes/strides) used by the kernel; see [`ProgramSpec::add_var`].
    pub vars: Vec<Variable>,

    /// Variable names in the positional order used to populate the runtime `vals` array.
    ///
    /// Includes `thread_id` at the end if threading is enabled. Presumably carried over to
    /// [`CompiledSpec::var_names`] by the compiler — confirm per backend.
    pub var_names: Vec<String>,

    /// Global buffer indices (from PARAM slot values).
    /// Matches Tinygrad's `globals` field.
    pub globals: Vec<usize>,

    /// Output buffer indices (written by STORE ops).
    /// Matches Tinygrad's `outs` field.
    pub outs: Vec<usize>,

    /// Input buffer indices (read by LOAD ops, excluding outputs).
    /// Matches Tinygrad's `ins` field.
    pub ins: Vec<usize>,

    /// Number of buffer arguments (used for CIF construction at compile time).
    ///
    /// [`ProgramSpec::new`] initializes this to 0.
    pub buf_count: usize,
}
368
369impl ProgramSpec {
370    /// Create a new program specification.
371    pub fn new(name: String, src: String, device: DeviceSpec, ast: Arc<UOp>) -> Self {
372        Self {
373            name,
374            src,
375            device,
376            ast,
377            global_size: None,
378            local_size: None,
379            vars: Vec::new(),
380            var_names: Vec::new(),
381            globals: Vec::new(),
382            outs: Vec::new(),
383            ins: Vec::new(),
384            buf_count: 0,
385        }
386    }
387
388    /// Add a variable to the program.
389    pub fn add_var(&mut self, var: Variable) {
390        self.vars.push(var);
391    }
392
393    /// Set work sizes for GPU execution.
394    pub fn set_work_sizes(&mut self, global: [usize; 3], local: [usize; 3]) {
395        self.global_size = Some(global);
396        self.local_size = Some(local);
397    }
398
399    /// Set variable names for populating vars array at runtime.
400    pub fn set_var_names(&mut self, var_names: Vec<String>) {
401        self.var_names = var_names;
402    }
403
404    /// Set buffer metadata (globals, outs, ins).
405    pub fn set_buffer_metadata(&mut self, globals: Vec<usize>, outs: Vec<usize>, ins: Vec<usize>) {
406        self.globals = globals;
407        self.outs = outs;
408        self.ins = ins;
409    }
410}
411
/// A symbolic variable in a kernel (e.g., a symbolic shape or stride).
///
/// Variables stand for values that are bound only when the kernel executes, such as:
/// - shape dimensions that vary per input
/// - stride values computed from shapes
/// - loop bounds determined by input sizes
///
/// `PartialEq`, `Eq` and `Hash` are derived so variables can be compared, deduplicated,
/// and used as keys in maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Variable {
    /// Variable name (must be unique within the kernel).
    pub name: String,

    /// Minimum value, used for range validation.
    ///
    /// NOTE(review): presumably inclusive (as in Tinygrad) — confirm.
    pub min: i64,

    /// Maximum value, used for range validation.
    ///
    /// NOTE(review): presumably inclusive (as in Tinygrad) — confirm.
    pub max: i64,
}
430
431impl Variable {
432    /// Create a new variable.
433    pub fn new(name: String, min: i64, max: i64) -> Self {
434        Self { name, min, max }
435    }
436}