// morok_device/device.rs
//! Device abstraction following Tinygrad's architecture.
//!
//! This module provides a unified Device abstraction that owns:
//! - **Renderer**: Transforms UOp graphs into source code (ProgramSpec)
//! - **Compiler**: Transforms source code into executable bytes
//! - **Runtime**: Creates executable Programs from compiled bytes
//! - **Allocator**: Manages memory allocation for buffers
//!
//! This design allows multiple backends (LLVM, CUDA, Metal, WebGPU) to coexist
//! and share compiled kernels via the method cache.

use std::sync::Arc;

use morok_dtype::DeviceSpec;
use morok_ir::UOp;

use crate::allocator::Allocator;
use crate::error::Result;

20/// A compiled, executable kernel program.
21///
22/// This trait abstracts over different backend executors (LLVM JIT, CUDA, Metal, etc.).
23/// Each backend implements this to provide unified execution interface.
24///
25/// Note: This trait does not require Send + Sync because some backends (like LLVM JIT)
26/// use non-thread-safe types. Programs are typically executed on the same thread where
27/// they were compiled, and caching/sharing is handled at a higher level.
28///
29/// # Tinygrad Alignment
30///
31/// This trait follows Tinygrad's `Program` interface where variable values are
32/// passed as a positional tuple/array (`vals`) rather than a named HashMap.
33/// The order matches `var_names` in `CompiledSpec`.
34pub trait Program {
35 /// Execute the kernel with given buffers and variable values.
36 ///
37 /// # Arguments
38 ///
39 /// * `buffers` - Raw pointers to buffer data (input and output buffers)
40 /// * `vals` - Variable values in positional order (matches `var_names` in CompiledSpec)
41 /// * `global_size` - Global work size (for GPU backends, None for CPU)
42 /// * `local_size` - Local work size (for GPU backends, None for CPU)
43 ///
44 /// # Safety
45 ///
46 /// This is unsafe because:
47 /// - Buffer pointers must be valid and properly aligned
48 /// - Buffer sizes must match what the kernel expects
49 /// - Caller must ensure no data races during execution
50 unsafe fn execute(
51 &self,
52 buffers: &[*mut u8],
53 vals: &[i64],
54 global_size: Option<[usize; 3]>,
55 local_size: Option<[usize; 3]>,
56 ) -> Result<()>;
57
58 /// Get the kernel name (for debugging/profiling).
59 fn name(&self) -> &str;
60}
61
62/// Compilation result carrying source (JIT) or bytes (AOT).
63///
64/// Different backends need different information:
65/// - LLVM JIT: needs source code to compile during runtime
66/// - CUDA: needs PTX/CUBIN bytes to load
67/// - Metal: needs metallib bytes to load
68///
69/// This design allows the RuntimeFactory to access whatever it needs
70/// without requiring separate code paths for JIT vs AOT backends.
71#[derive(Debug, Clone)]
72pub struct CompiledSpec {
73 /// Entry point function name
74 pub name: String,
75
76 /// Source code (for JIT backends like LLVM)
77 /// Set to Some(...) for LLVM JIT, None for AOT backends
78 pub src: Option<String>,
79
80 /// Compiled bytes (for AOT backends like CUDA/Metal)
81 /// Empty for LLVM JIT, populated for AOT backends
82 pub bytes: Vec<u8>,
83
84 /// Original AST for cache key construction via hash consing
85 pub ast: Arc<UOp>,
86
87 /// Variable names in order for populating vars array at runtime.
88 /// Includes thread_id at the end if threading is enabled.
89 pub var_names: Vec<String>,
90
91 /// Global work size for dispatch (GPU backends, CPU threading)
92 /// For CPU threading: [thread_count, 1, 1]
93 pub global_size: Option<[usize; 3]>,
94
95 /// Local work size for dispatch (GPU backends)
96 pub local_size: Option<[usize; 3]>,
97
98 /// Number of buffer arguments (for CIF construction at compile time).
99 pub buf_count: usize,
100}
101
102impl CompiledSpec {
103 /// Create a new CompiledSpec for JIT backends (source-based).
104 pub fn from_source(name: String, src: String, ast: Arc<UOp>, buf_count: usize) -> Self {
105 Self {
106 name,
107 src: Some(src),
108 bytes: Vec::new(),
109 ast,
110 var_names: Vec::new(),
111 global_size: None,
112 local_size: None,
113 buf_count,
114 }
115 }
116
117 /// Create a new CompiledSpec for AOT backends (bytecode-based).
118 pub fn from_bytes(name: String, bytes: Vec<u8>, ast: Arc<UOp>) -> Self {
119 Self { name, src: None, bytes, ast, var_names: Vec::new(), global_size: None, local_size: None, buf_count: 0 }
120 }
121
122 /// Create a new CompiledSpec with work sizes for JIT backends.
123 pub fn from_source_with_sizes(
124 name: String,
125 src: String,
126 ast: Arc<UOp>,
127 global_size: Option<[usize; 3]>,
128 local_size: Option<[usize; 3]>,
129 buf_count: usize,
130 ) -> Self {
131 Self { name, src: Some(src), bytes: Vec::new(), ast, var_names: Vec::new(), global_size, local_size, buf_count }
132 }
133}
134
135/// A compiler that transforms source code into a compiled specification.
136///
137/// This trait abstracts over different compilation backends:
138/// - LLVM: IR validation (JIT compiles at runtime)
139/// - CUDA: CUDA C -> PTX/CUBIN
140/// - Metal: Metal Shading Language -> metallib
141/// - WebGPU: WGSL -> SPIR-V
142pub trait Compiler: Send + Sync {
143 /// Compile a program specification into executable form.
144 ///
145 /// # Arguments
146 ///
147 /// * `spec` - The program specification containing source code and metadata
148 ///
149 /// # Returns
150 ///
151 /// A CompiledSpec containing:
152 /// - For JIT backends (LLVM): source code in `src` field, empty `bytes`
153 /// - For AOT backends (CUDA/Metal): compiled bytes in `bytes` field, no `src`
154 ///
155 /// # Examples
156 ///
157 /// JIT backend (LLVM):
158 /// ```ignore
159 /// let compiled = compiler.compile(&spec)?;
160 /// assert!(compiled.src.is_some());
161 /// assert!(compiled.bytes.is_empty());
162 /// ```
163 ///
164 /// AOT backend (CUDA):
165 /// ```ignore
166 /// let compiled = compiler.compile(&spec)?;
167 /// assert!(compiled.src.is_none());
168 /// assert!(!compiled.bytes.is_empty());
169 /// ```
170 fn compile(&self, spec: &ProgramSpec) -> Result<CompiledSpec>;
171
172 /// Optional cache key for this compiler configuration.
173 ///
174 /// Used to differentiate compiled artifacts when the same device type
175 /// can have multiple compiler configurations (e.g., different optimization levels).
176 ///
177 /// Returns None if all instances of this compiler produce identical output.
178 fn cache_key(&self) -> Option<&str> {
179 None
180 }
181}
182
183/// A renderer that transforms UOp graphs into source code.
184///
185/// This trait abstracts over different code generation backends:
186/// - LLVM IR generator
187/// - CUDA C generator
188/// - Metal Shading Language generator
189/// - WGSL generator
190pub trait Renderer: Send + Sync {
191 /// Render a UOp graph into source code.
192 ///
193 /// # Arguments
194 ///
195 /// * `ast` - The kernel AST (UOp graph rooted at KERNEL op)
196 /// * `name` - Optional kernel name for debugging (e.g., "r_g16l16R32u4").
197 /// Falls back to "kernel" if None.
198 ///
199 /// # Returns
200 ///
201 /// A ProgramSpec containing:
202 /// - Generated source code
203 /// - Entry point name
204 /// - Variable list
205 /// - Work sizes (for GPU backends)
206 fn render(&self, ast: &Arc<UOp>, name: Option<&str>) -> Result<ProgramSpec>;
207
208 /// Get the device spec for this renderer.
209 ///
210 /// This is used for cache key construction and device selection.
211 fn device(&self) -> &DeviceSpec;
212
213 /// Returns decomposition patterns for operations this backend doesn't support.
214 ///
215 /// This is used by the realization pass to decompose complex operations
216 /// into simpler primitives before rendering.
217 ///
218 /// # Default Implementation
219 ///
220 /// Returns `None`, meaning no decomposition is needed (backend supports all ops).
221 /// Backends that don't support certain operations (e.g., transcendentals)
222 /// should override this to return appropriate patterns.
223 fn decompositor(&self) -> Option<morok_ir::pattern::TypedPatternMatcher<()>> {
224 None
225 }
226}
227
228/// A factory function that creates executable Programs from a compiled specification.
229///
230/// This is a function pointer that wraps the backend-specific loader:
231/// - LLVM: Extract source from CompiledSpec and JIT compile
232/// - CUDA: Extract bytes from CompiledSpec and call cuModuleLoadData + cuModuleGetFunction
233/// - Metal: Extract bytes from CompiledSpec and call newLibraryWithData + newFunctionWithName
234/// - WebGPU: Extract bytes from CompiledSpec and call createShaderModule
235///
236/// The CompiledSpec contains either source (for JIT) or bytes (for AOT),
237/// allowing each backend to access what it needs.
238pub type RuntimeFactory = Arc<dyn Fn(&CompiledSpec) -> Result<Box<dyn Program>> + Send + Sync>;
239
240/// A (Renderer, Compiler) pair for a specific backend.
241///
242/// Devices can have multiple compiler pairs (e.g., different optimization levels).
243pub type CompilerPair = (Arc<dyn Renderer>, Arc<dyn Compiler>);
244
245/// A device that owns renderer, compiler, runtime, and allocator.
246///
247/// This follows Tinygrad's architecture where a Device is a complete
248/// compilation + execution unit for a specific backend.
249///
250/// # Example
251///
252/// ```ignore
253/// let cpu_device = create_cpu_device()?;
254/// let spec = cpu_device.renderer.render(&kernel_ast, Some("E_L3"))?;
255/// let compiled = cpu_device.compiler.compile(&spec)?;
256/// let program = (cpu_device.runtime)(&compiled)?;
257/// unsafe { program.execute(&buffers, &vals, None, None)?; }
258/// ```
259pub struct Device {
260 /// Device specification
261 pub device: DeviceSpec,
262
263 /// Memory allocator for this device
264 pub allocator: Arc<dyn Allocator>,
265
266 /// Available (renderer, compiler) pairs for this device
267 ///
268 /// Most devices have one pair, but some may have multiple
269 /// (e.g., different optimization levels or compilation modes).
270 pub compilers: Vec<CompilerPair>,
271
272 /// Primary renderer for this device
273 ///
274 /// This is typically compilers[0].0, stored separately for convenience.
275 pub renderer: Arc<dyn Renderer>,
276
277 /// Primary compiler for this device
278 ///
279 /// This is typically compilers[0].1, stored separately for convenience.
280 pub compiler: Arc<dyn Compiler>,
281
282 /// Runtime factory for creating executable programs
283 ///
284 /// Takes (entry_point, compiled_bytes) and returns a Program.
285 pub runtime: RuntimeFactory,
286}
287
288impl Device {
289 /// Create a new device with a single compiler pair.
290 ///
291 /// This is a convenience constructor for the common case where
292 /// a device has only one renderer/compiler combination.
293 pub fn new(
294 device: DeviceSpec,
295 allocator: Arc<dyn Allocator>,
296 renderer: Arc<dyn Renderer>,
297 compiler: Arc<dyn Compiler>,
298 runtime: RuntimeFactory,
299 ) -> Self {
300 let compilers = vec![(renderer.clone(), compiler.clone())];
301 Self { device, allocator, compilers, renderer, compiler, runtime }
302 }
303
304 /// Get the base device key (strips device ID).
305 ///
306 /// Used for compiled byte cache sharing across device instances.
307 /// Examples:
308 /// - DeviceSpec::Cpu -> "CPU"
309 /// - DeviceSpec::Cuda { device_id: 0 } -> "CUDA"
310 /// - DeviceSpec::Cuda { device_id: 1 } -> "CUDA"
311 /// - DeviceSpec::Metal { device_id: 0 } -> "Metal"
312 ///
313 /// This allows compiled CUDA kernels to be reused across CUDA:0 and CUDA:1.
314 pub fn base_device_key(&self) -> &'static str {
315 self.device.base_type()
316 }
317}
318
319/// Program specification containing source code and metadata.
320///
321/// This is returned by Renderer::render() and consumed by Compiler::compile().
322/// It bridges the gap between UOp graphs and compiled executables.
323///
324/// # Tinygrad Alignment
325///
326/// Buffer metadata (`globals`, `outs`, `ins`) matches Tinygrad's Program class:
327/// - `globals`: Buffer indices from DefineGlobal ops
328/// - `outs`: Output buffer indices (written by STORE ops)
329/// - `ins`: Input buffer indices (read by LOAD ops)
330#[derive(Debug, Clone)]
331pub struct ProgramSpec {
332 /// Kernel name (for debugging/profiling)
333 pub name: String,
334
335 /// Generated source code (LLVM IR, CUDA C, Metal, WGSL, etc.)
336 pub src: String,
337
338 /// Device specification
339 pub device: DeviceSpec,
340
341 /// Original AST (for cache key construction via hash consing)
342 pub ast: Arc<UOp>,
343
344 /// Global work size (for GPU backends, None for CPU)
345 pub global_size: Option<[usize; 3]>,
346
347 /// Local work size (for GPU backends, None for CPU)
348 pub local_size: Option<[usize; 3]>,
349
350 /// Variable list (for symbolic shapes/strides)
351 pub vars: Vec<Variable>,
352
353 /// Variable names in order for populating vars array at runtime.
354 /// Includes thread_id at the end if threading is enabled.
355 pub var_names: Vec<String>,
356
357 /// Global buffer indices (from DefineGlobal argument values).
358 /// Matches Tinygrad's `globals` field.
359 pub globals: Vec<usize>,
360
361 /// Output buffer indices (written by STORE ops).
362 /// Matches Tinygrad's `outs` field.
363 pub outs: Vec<usize>,
364
365 /// Input buffer indices (read by LOAD ops, excluding outputs).
366 /// Matches Tinygrad's `ins` field.
367 pub ins: Vec<usize>,
368
369 /// Number of buffer arguments (for CIF construction at compile time).
370 pub buf_count: usize,
371}
372
373impl ProgramSpec {
374 /// Create a new program specification.
375 pub fn new(name: String, src: String, device: DeviceSpec, ast: Arc<UOp>) -> Self {
376 Self {
377 name,
378 src,
379 device,
380 ast,
381 global_size: None,
382 local_size: None,
383 vars: Vec::new(),
384 var_names: Vec::new(),
385 globals: Vec::new(),
386 outs: Vec::new(),
387 ins: Vec::new(),
388 buf_count: 0,
389 }
390 }
391
392 /// Add a variable to the program.
393 pub fn add_var(&mut self, var: Variable) {
394 self.vars.push(var);
395 }
396
397 /// Set work sizes for GPU execution.
398 pub fn set_work_sizes(&mut self, global: [usize; 3], local: [usize; 3]) {
399 self.global_size = Some(global);
400 self.local_size = Some(local);
401 }
402
403 /// Set variable names for populating vars array at runtime.
404 pub fn set_var_names(&mut self, var_names: Vec<String>) {
405 self.var_names = var_names;
406 }
407
408 /// Set buffer metadata (globals, outs, ins).
409 pub fn set_buffer_metadata(&mut self, globals: Vec<usize>, outs: Vec<usize>, ins: Vec<usize>) {
410 self.globals = globals;
411 self.outs = outs;
412 self.ins = ins;
413 }
414}
415
/// A variable in the kernel (for symbolic shapes/strides).
///
/// Variables stand for symbolic values bound at kernel execution time, e.g.:
/// - Shape dimensions that vary per input
/// - Stride values computed from shapes
/// - Loop bounds determined by input sizes
#[derive(Debug, Clone)]
pub struct Variable {
    /// Variable name (must be unique within the kernel)
    pub name: String,

    /// Minimum value (for range validation)
    pub min: i64,

    /// Maximum value (for range validation)
    pub max: i64,
}

435impl Variable {
436 /// Create a new variable.
437 pub fn new(name: String, min: i64, max: i64) -> Self {
438 Self { name, min, max }
439 }
440}