morok_device/device.rs
1//! Device abstraction following Tinygrad's architecture.
2//!
3//! This module provides a unified Device abstraction that owns:
4//! - **Renderer**: Transforms UOp graphs into source code (ProgramSpec)
5//! - **Compiler**: Transforms source code into executable bytes
6//! - **Runtime**: Creates executable Programs from compiled bytes
7//! - **Allocator**: Manages memory allocation for buffers
8//!
9//! This design allows multiple backends (LLVM, CUDA, Metal, WebGPU) to coexist
10//! and share compiled kernels via the method cache.
11
12use std::sync::Arc;
13
14use morok_dtype::DeviceSpec;
15use morok_ir::UOp;
16
17use crate::allocator::Allocator;
18use crate::error::Result;
19
20/// A compiled, executable kernel program.
21///
22/// This trait abstracts over different backend executors (LLVM JIT, CUDA, Metal, etc.).
23/// Each backend implements this to provide unified execution interface.
24///
25/// Note: This trait does not require Send + Sync because some backends (like LLVM JIT)
26/// use non-thread-safe types. Programs are typically executed on the same thread where
27/// they were compiled, and caching/sharing is handled at a higher level.
28///
29/// # Tinygrad Alignment
30///
31/// This trait follows Tinygrad's `Program` interface where variable values are
32/// passed as a positional tuple/array (`vals`) rather than a named HashMap.
33/// The order matches `var_names` in `CompiledSpec`.
34pub trait Program {
35 /// Execute the kernel with given buffers and variable values.
36 ///
37 /// # Arguments
38 ///
39 /// * `buffers` - Raw pointers to buffer data (input and output buffers)
40 /// * `vals` - Variable values in positional order (matches `var_names` in CompiledSpec)
41 /// * `global_size` - Global work size (for GPU backends, None for CPU)
42 /// * `local_size` - Local work size (for GPU backends, None for CPU)
43 ///
44 /// # Safety
45 ///
46 /// This is unsafe because:
47 /// - Buffer pointers must be valid and properly aligned
48 /// - Buffer sizes must match what the kernel expects
49 /// - Caller must ensure no data races during execution
50 unsafe fn execute(
51 &self,
52 buffers: &[*mut u8],
53 vals: &[i64],
54 global_size: Option<[usize; 3]>,
55 local_size: Option<[usize; 3]>,
56 ) -> Result<()>;
57
58 /// Get the kernel name (for debugging/profiling).
59 fn name(&self) -> &str;
60}
61
62/// Compilation result carrying source (JIT) or bytes (AOT).
63///
64/// Different backends need different information:
65/// - LLVM JIT: needs source code to compile during runtime
66/// - CUDA: needs PTX/CUBIN bytes to load
67/// - Metal: needs metallib bytes to load
68///
69/// This design allows the RuntimeFactory to access whatever it needs
70/// without requiring separate code paths for JIT vs AOT backends.
71#[derive(Debug, Clone)]
72pub struct CompiledSpec {
73 /// Entry point function name
74 pub name: String,
75
76 /// Source code (for JIT backends like LLVM)
77 /// Set to Some(...) for LLVM JIT, None for AOT backends
78 pub src: Option<String>,
79
80 /// Compiled bytes (for AOT backends like CUDA/Metal)
81 /// Empty for LLVM JIT, populated for AOT backends
82 pub bytes: Vec<u8>,
83
84 /// Original AST for cache key construction via hash consing
85 pub ast: Arc<UOp>,
86
87 /// Variable names in order for populating vars array at runtime.
88 /// Includes thread_id at the end if threading is enabled.
89 pub var_names: Vec<String>,
90
91 /// Global work size for dispatch (GPU backends, CPU threading)
92 /// For CPU threading: [thread_count, 1, 1]
93 pub global_size: Option<[usize; 3]>,
94
95 /// Local work size for dispatch (GPU backends)
96 pub local_size: Option<[usize; 3]>,
97
98 /// Number of buffer arguments (for CIF construction at compile time).
99 pub buf_count: usize,
100}
101
102impl CompiledSpec {
103 /// Create a new CompiledSpec for JIT backends (source-based).
104 pub fn from_source(name: String, src: String, ast: Arc<UOp>, buf_count: usize) -> Self {
105 Self {
106 name,
107 src: Some(src),
108 bytes: Vec::new(),
109 ast,
110 var_names: Vec::new(),
111 global_size: None,
112 local_size: None,
113 buf_count,
114 }
115 }
116
117 /// Create a new CompiledSpec for AOT backends (bytecode-based).
118 pub fn from_bytes(name: String, bytes: Vec<u8>, ast: Arc<UOp>) -> Self {
119 Self { name, src: None, bytes, ast, var_names: Vec::new(), global_size: None, local_size: None, buf_count: 0 }
120 }
121
122 /// Create a new CompiledSpec with work sizes for JIT backends.
123 pub fn from_source_with_sizes(
124 name: String,
125 src: String,
126 ast: Arc<UOp>,
127 global_size: Option<[usize; 3]>,
128 local_size: Option<[usize; 3]>,
129 buf_count: usize,
130 ) -> Self {
131 Self { name, src: Some(src), bytes: Vec::new(), ast, var_names: Vec::new(), global_size, local_size, buf_count }
132 }
133}
134
135/// A compiler that transforms source code into a compiled specification.
136///
137/// This trait abstracts over different compilation backends:
138/// - LLVM: IR validation (JIT compiles at runtime)
139/// - CUDA: CUDA C -> PTX/CUBIN
140/// - Metal: Metal Shading Language -> metallib
141/// - WebGPU: WGSL -> SPIR-V
142pub trait Compiler: Send + Sync {
143 /// Compile a program specification into executable form.
144 ///
145 /// # Arguments
146 ///
147 /// * `spec` - The program specification containing source code and metadata
148 ///
149 /// # Returns
150 ///
151 /// A CompiledSpec containing:
152 /// - For JIT backends (LLVM): source code in `src` field, empty `bytes`
153 /// - For AOT backends (CUDA/Metal): compiled bytes in `bytes` field, no `src`
154 ///
155 /// # Examples
156 ///
157 /// JIT backend (LLVM):
158 /// ```ignore
159 /// let compiled = compiler.compile(&spec)?;
160 /// assert!(compiled.src.is_some());
161 /// assert!(compiled.bytes.is_empty());
162 /// ```
163 ///
164 /// AOT backend (CUDA):
165 /// ```ignore
166 /// let compiled = compiler.compile(&spec)?;
167 /// assert!(compiled.src.is_none());
168 /// assert!(!compiled.bytes.is_empty());
169 /// ```
170 fn compile(&self, spec: &ProgramSpec) -> Result<CompiledSpec>;
171
172 /// Cache key identifying this compiler backend.
173 ///
174 /// Used to differentiate compiled artifacts when the same device type
175 /// can have multiple compiler backends (e.g., clang vs llvm-jit).
176 fn cache_key(&self) -> &'static str;
177}
178
179/// A renderer that transforms UOp graphs into source code.
180///
181/// This trait abstracts over different code generation backends:
182/// - LLVM IR generator
183/// - CUDA C generator
184/// - Metal Shading Language generator
185/// - WGSL generator
186pub trait Renderer: Send + Sync {
187 /// Render a UOp graph into source code.
188 ///
189 /// # Arguments
190 ///
191 /// * `ast` - The kernel AST (UOp graph rooted at KERNEL op)
192 /// * `name` - Optional kernel name for debugging (e.g., "r_g16l16R32u4").
193 /// Falls back to "kernel" if None.
194 ///
195 /// # Returns
196 ///
197 /// A ProgramSpec containing:
198 /// - Generated source code
199 /// - Entry point name
200 /// - Variable list
201 /// - Work sizes (for GPU backends)
202 fn render(&self, ast: &Arc<UOp>, name: Option<&str>) -> Result<ProgramSpec>;
203
204 /// Get the device spec for this renderer.
205 ///
206 /// This is used for cache key construction and device selection.
207 fn device(&self) -> &DeviceSpec;
208
209 /// Returns decomposition patterns for operations this backend doesn't support.
210 ///
211 /// This is used by the realization pass to decompose complex operations
212 /// into simpler primitives before rendering.
213 ///
214 /// # Default Implementation
215 ///
216 /// Returns `None`, meaning no decomposition is needed (backend supports all ops).
217 /// Backends that don't support certain operations (e.g., transcendentals)
218 /// should override this to return appropriate patterns.
219 fn decompositor(&self) -> Option<morok_ir::pattern::TypedPatternMatcher<()>> {
220 None
221 }
222}
223
224/// A factory function that creates executable Programs from a compiled specification.
225///
226/// This is a function pointer that wraps the backend-specific loader:
227/// - LLVM: Extract source from CompiledSpec and JIT compile
228/// - CUDA: Extract bytes from CompiledSpec and call cuModuleLoadData + cuModuleGetFunction
229/// - Metal: Extract bytes from CompiledSpec and call newLibraryWithData + newFunctionWithName
230/// - WebGPU: Extract bytes from CompiledSpec and call createShaderModule
231///
232/// The CompiledSpec contains either source (for JIT) or bytes (for AOT),
233/// allowing each backend to access what it needs.
234pub type RuntimeFactory = Arc<dyn Fn(&CompiledSpec) -> Result<Box<dyn Program>> + Send + Sync>;
235
236/// A (Renderer, Compiler) pair for a specific backend.
237///
238/// Devices can have multiple compiler pairs (e.g., different optimization levels).
239pub type CompilerPair = (Arc<dyn Renderer>, Arc<dyn Compiler>);
240
241/// A device that owns renderer, compiler, runtime, and allocator.
242///
243/// This follows Tinygrad's architecture where a Device is a complete
244/// compilation + execution unit for a specific backend.
245///
246/// # Example
247///
248/// ```ignore
249/// let cpu_device = create_cpu_device()?;
250/// let spec = cpu_device.renderer.render(&kernel_ast, Some("E_L3"))?;
251/// let compiled = cpu_device.compiler.compile(&spec)?;
252/// let program = (cpu_device.runtime)(&compiled)?;
253/// unsafe { program.execute(&buffers, &vals, None, None)?; }
254/// ```
255pub struct Device {
256 /// Device specification
257 pub device: DeviceSpec,
258
259 /// Memory allocator for this device
260 pub allocator: Arc<dyn Allocator>,
261
262 /// Available (renderer, compiler) pairs for this device
263 ///
264 /// Most devices have one pair, but some may have multiple
265 /// (e.g., different optimization levels or compilation modes).
266 pub compilers: Vec<CompilerPair>,
267
268 /// Primary renderer for this device
269 ///
270 /// This is typically compilers[0].0, stored separately for convenience.
271 pub renderer: Arc<dyn Renderer>,
272
273 /// Primary compiler for this device
274 ///
275 /// This is typically compilers[0].1, stored separately for convenience.
276 pub compiler: Arc<dyn Compiler>,
277
278 /// Runtime factory for creating executable programs
279 ///
280 /// Takes (entry_point, compiled_bytes) and returns a Program.
281 pub runtime: RuntimeFactory,
282}
283
284impl Device {
285 /// Create a new device with a single compiler pair.
286 ///
287 /// This is a convenience constructor for the common case where
288 /// a device has only one renderer/compiler combination.
289 pub fn new(
290 device: DeviceSpec,
291 allocator: Arc<dyn Allocator>,
292 renderer: Arc<dyn Renderer>,
293 compiler: Arc<dyn Compiler>,
294 runtime: RuntimeFactory,
295 ) -> Self {
296 let compilers = vec![(renderer.clone(), compiler.clone())];
297 Self { device, allocator, compilers, renderer, compiler, runtime }
298 }
299
300 /// Get the base device key (strips device ID).
301 ///
302 /// Used for compiled byte cache sharing across device instances.
303 /// Examples:
304 /// - DeviceSpec::Cpu -> "CPU"
305 /// - DeviceSpec::Cuda { device_id: 0 } -> "CUDA"
306 /// - DeviceSpec::Cuda { device_id: 1 } -> "CUDA"
307 /// - DeviceSpec::Metal { device_id: 0 } -> "Metal"
308 ///
309 /// This allows compiled CUDA kernels to be reused across CUDA:0 and CUDA:1.
310 pub fn base_device_key(&self) -> &'static str {
311 self.device.base_type()
312 }
313}
314
315/// Program specification containing source code and metadata.
316///
317/// This is returned by Renderer::render() and consumed by Compiler::compile().
318/// It bridges the gap between UOp graphs and compiled executables.
319///
320/// # Tinygrad Alignment
321///
322/// Buffer metadata (`globals`, `outs`, `ins`) matches Tinygrad's Program class:
323/// - `globals`: Buffer indices from PARAM ops
324/// - `outs`: Output buffer indices (written by STORE ops)
325/// - `ins`: Input buffer indices (read by LOAD ops)
326#[derive(Debug, Clone)]
327pub struct ProgramSpec {
328 /// Kernel name (for debugging/profiling)
329 pub name: String,
330
331 /// Generated source code (LLVM IR, CUDA C, Metal, WGSL, etc.)
332 pub src: String,
333
334 /// Device specification
335 pub device: DeviceSpec,
336
337 /// Original AST (for cache key construction via hash consing)
338 pub ast: Arc<UOp>,
339
340 /// Global work size (for GPU backends, None for CPU)
341 pub global_size: Option<[usize; 3]>,
342
343 /// Local work size (for GPU backends, None for CPU)
344 pub local_size: Option<[usize; 3]>,
345
346 /// Variable list (for symbolic shapes/strides)
347 pub vars: Vec<Variable>,
348
349 /// Variable names in order for populating vars array at runtime.
350 /// Includes thread_id at the end if threading is enabled.
351 pub var_names: Vec<String>,
352
353 /// Global buffer indices (from PARAM slot values).
354 /// Matches Tinygrad's `globals` field.
355 pub globals: Vec<usize>,
356
357 /// Output buffer indices (written by STORE ops).
358 /// Matches Tinygrad's `outs` field.
359 pub outs: Vec<usize>,
360
361 /// Input buffer indices (read by LOAD ops, excluding outputs).
362 /// Matches Tinygrad's `ins` field.
363 pub ins: Vec<usize>,
364
365 /// Number of buffer arguments (for CIF construction at compile time).
366 pub buf_count: usize,
367}
368
369impl ProgramSpec {
370 /// Create a new program specification.
371 pub fn new(name: String, src: String, device: DeviceSpec, ast: Arc<UOp>) -> Self {
372 Self {
373 name,
374 src,
375 device,
376 ast,
377 global_size: None,
378 local_size: None,
379 vars: Vec::new(),
380 var_names: Vec::new(),
381 globals: Vec::new(),
382 outs: Vec::new(),
383 ins: Vec::new(),
384 buf_count: 0,
385 }
386 }
387
388 /// Add a variable to the program.
389 pub fn add_var(&mut self, var: Variable) {
390 self.vars.push(var);
391 }
392
393 /// Set work sizes for GPU execution.
394 pub fn set_work_sizes(&mut self, global: [usize; 3], local: [usize; 3]) {
395 self.global_size = Some(global);
396 self.local_size = Some(local);
397 }
398
399 /// Set variable names for populating vars array at runtime.
400 pub fn set_var_names(&mut self, var_names: Vec<String>) {
401 self.var_names = var_names;
402 }
403
404 /// Set buffer metadata (globals, outs, ins).
405 pub fn set_buffer_metadata(&mut self, globals: Vec<usize>, outs: Vec<usize>, ins: Vec<usize>) {
406 self.globals = globals;
407 self.outs = outs;
408 self.ins = ins;
409 }
410}
411
/// A symbolic kernel variable (used for symbolic shapes/strides).
///
/// Variables stand for symbolic values that are bound when the kernel is
/// executed. Examples:
/// - shape dimensions that vary per input
/// - stride values computed from shapes
/// - loop bounds determined by input sizes
#[derive(Debug, Clone)]
pub struct Variable {
    /// Name of the variable (must be unique within the kernel).
    pub name: String,

    /// Minimum value (for range validation).
    pub min: i64,

    /// Maximum value (for range validation).
    pub max: i64,
}

impl Variable {
    /// Build a variable from its name and its `min`/`max` bounds.
    ///
    /// The values are stored as given; no check is performed on the bounds.
    pub fn new(name: String, min: i64, max: i64) -> Self {
        Variable { name, min, max }
    }
}