// morok_device/device.rs
1//! Device abstraction following Tinygrad's architecture.
2//!
3//! This module provides a unified Device abstraction that owns:
4//! - **Renderer**: Transforms UOp graphs into source code (ProgramSpec)
5//! - **Compiler**: Transforms source code into executable bytes
6//! - **Runtime**: Creates executable Programs from compiled bytes
7//! - **Allocator**: Manages memory allocation for buffers
8//!
9//! This design allows multiple backends (LLVM, CUDA, Metal, WebGPU) to coexist
10//! and share compiled kernels via the method cache.
11
12use std::sync::Arc;
13
14use morok_dtype::DeviceSpec;
15use morok_ir::UOp;
16
17use crate::allocator::Allocator;
18use crate::error::Result;
19
/// A compiled, executable kernel program.
///
/// This trait abstracts over different backend executors (LLVM JIT, CUDA, Metal, etc.).
/// Each backend implements this to provide a unified execution interface.
///
/// Note: This trait does not require `Send + Sync` because some backends (like LLVM JIT)
/// use non-thread-safe types. Programs are typically executed on the same thread where
/// they were compiled, and caching/sharing is handled at a higher level.
///
/// # Tinygrad Alignment
///
/// This trait follows Tinygrad's `Program` interface where variable values are
/// passed as a positional tuple/array (`vals`) rather than a named HashMap.
/// The order matches `var_names` in [`CompiledSpec`].
pub trait Program {
    /// Execute the kernel with the given buffers and variable values.
    ///
    /// # Arguments
    ///
    /// * `buffers` - Raw pointers to buffer data (input and output buffers)
    /// * `vals` - Variable values in positional order (matches `var_names` in [`CompiledSpec`])
    /// * `global_size` - Global work size (for GPU backends, `None` for CPU)
    /// * `local_size` - Local work size (for GPU backends, `None` for CPU)
    ///
    /// # Safety
    ///
    /// This is unsafe because:
    /// - Buffer pointers must be valid and properly aligned
    /// - Buffer sizes must match what the kernel expects
    /// - The caller must ensure no data races occur during execution
    unsafe fn execute(
        &self,
        buffers: &[*mut u8],
        vals: &[i64],
        global_size: Option<[usize; 3]>,
        local_size: Option<[usize; 3]>,
    ) -> Result<()>;

    /// Get the kernel name (for debugging/profiling).
    fn name(&self) -> &str;
}
61
/// Compilation result carrying source (JIT) or bytes (AOT).
///
/// Different backends need different information:
/// - LLVM JIT: needs source code to compile during runtime
/// - CUDA: needs PTX/CUBIN bytes to load
/// - Metal: needs metallib bytes to load
///
/// This design allows the [`RuntimeFactory`] to access whatever it needs
/// without requiring separate code paths for JIT vs AOT backends.
#[derive(Debug, Clone)]
pub struct CompiledSpec {
    /// Entry point function name.
    pub name: String,

    /// Source code (for JIT backends like LLVM).
    /// `Some(...)` for LLVM JIT, `None` for AOT backends.
    pub src: Option<String>,

    /// Compiled bytes (for AOT backends like CUDA/Metal).
    /// Empty for LLVM JIT, populated for AOT backends.
    pub bytes: Vec<u8>,

    /// Original AST for cache key construction via hash consing.
    pub ast: Arc<UOp>,

    /// Variable names in order for populating the vars array at runtime.
    /// Includes thread_id at the end if threading is enabled.
    pub var_names: Vec<String>,

    /// Global work size for dispatch (GPU backends, CPU threading).
    /// For CPU threading: `[thread_count, 1, 1]`.
    pub global_size: Option<[usize; 3]>,

    /// Local work size for dispatch (GPU backends).
    pub local_size: Option<[usize; 3]>,
}
98
99impl CompiledSpec {
100 /// Create a new CompiledSpec for JIT backends (source-based).
101 pub fn from_source(name: String, src: String, ast: Arc<UOp>) -> Self {
102 Self {
103 name,
104 src: Some(src),
105 bytes: Vec::new(),
106 ast,
107 var_names: Vec::new(),
108 global_size: None,
109 local_size: None,
110 }
111 }
112
113 /// Create a new CompiledSpec for AOT backends (bytecode-based).
114 pub fn from_bytes(name: String, bytes: Vec<u8>, ast: Arc<UOp>) -> Self {
115 Self { name, src: None, bytes, ast, var_names: Vec::new(), global_size: None, local_size: None }
116 }
117
118 /// Create a new CompiledSpec with work sizes for JIT backends.
119 pub fn from_source_with_sizes(
120 name: String,
121 src: String,
122 ast: Arc<UOp>,
123 global_size: Option<[usize; 3]>,
124 local_size: Option<[usize; 3]>,
125 ) -> Self {
126 Self { name, src: Some(src), bytes: Vec::new(), ast, var_names: Vec::new(), global_size, local_size }
127 }
128}
129
/// A compiler that transforms source code into a compiled specification.
///
/// This trait abstracts over different compilation backends:
/// - LLVM: IR validation (JIT compiles at runtime)
/// - CUDA: CUDA C -> PTX/CUBIN
/// - Metal: Metal Shading Language -> metallib
/// - WebGPU: WGSL -> SPIR-V
pub trait Compiler: Send + Sync {
    /// Compile a program specification into executable form.
    ///
    /// # Arguments
    ///
    /// * `spec` - The program specification containing source code and metadata
    ///
    /// # Returns
    ///
    /// A [`CompiledSpec`] containing:
    /// - For JIT backends (LLVM): source code in `src` field, empty `bytes`
    /// - For AOT backends (CUDA/Metal): compiled bytes in `bytes` field, no `src`
    ///
    /// # Examples
    ///
    /// JIT backend (LLVM):
    /// ```ignore
    /// let compiled = compiler.compile(&spec)?;
    /// assert!(compiled.src.is_some());
    /// assert!(compiled.bytes.is_empty());
    /// ```
    ///
    /// AOT backend (CUDA):
    /// ```ignore
    /// let compiled = compiler.compile(&spec)?;
    /// assert!(compiled.src.is_none());
    /// assert!(!compiled.bytes.is_empty());
    /// ```
    fn compile(&self, spec: &ProgramSpec) -> Result<CompiledSpec>;

    /// Optional cache key for this compiler configuration.
    ///
    /// Used to differentiate compiled artifacts when the same device type
    /// can have multiple compiler configurations (e.g., different optimization levels).
    ///
    /// Returns `None` if all instances of this compiler produce identical output.
    fn cache_key(&self) -> Option<&str> {
        None
    }
}
177
/// A renderer that transforms UOp graphs into source code.
///
/// This trait abstracts over different code generation backends:
/// - LLVM IR generator
/// - CUDA C generator
/// - Metal Shading Language generator
/// - WGSL generator
pub trait Renderer: Send + Sync {
    /// Render a UOp graph into source code.
    ///
    /// # Arguments
    ///
    /// * `ast` - The kernel AST (UOp graph rooted at KERNEL op)
    ///
    /// # Returns
    ///
    /// A [`ProgramSpec`] containing:
    /// - Generated source code
    /// - Entry point name
    /// - Variable list
    /// - Work sizes (for GPU backends)
    fn render(&self, ast: &Arc<UOp>) -> Result<ProgramSpec>;

    /// Get the device spec for this renderer.
    ///
    /// This is used for cache key construction and device selection.
    fn device(&self) -> &DeviceSpec;

    /// Returns decomposition patterns for operations this backend doesn't support.
    ///
    /// This is used by the realization pass to decompose complex operations
    /// into simpler primitives before rendering.
    ///
    /// # Default Implementation
    ///
    /// Returns `None`, meaning no decomposition is needed (backend supports all ops).
    /// Backends that don't support certain operations (e.g., transcendentals)
    /// should override this to return appropriate patterns.
    fn decompositor(&self) -> Option<morok_ir::pattern::TypedPatternMatcher<()>> {
        None
    }
}
220
/// A factory that creates executable [`Program`]s from a compiled specification.
///
/// This is a shared (`Arc`) closure trait object — not a plain function
/// pointer — that wraps the backend-specific loader:
/// - LLVM: Extract source from CompiledSpec and JIT compile
/// - CUDA: Extract bytes from CompiledSpec and call cuModuleLoadData + cuModuleGetFunction
/// - Metal: Extract bytes from CompiledSpec and call newLibraryWithData + newFunctionWithName
/// - WebGPU: Extract bytes from CompiledSpec and call createShaderModule
///
/// The [`CompiledSpec`] contains either source (for JIT) or bytes (for AOT),
/// allowing each backend to access what it needs.
pub type RuntimeFactory = Arc<dyn Fn(&CompiledSpec) -> Result<Box<dyn Program>> + Send + Sync>;
232
/// A (Renderer, Compiler) pair for a specific backend.
///
/// Devices can have multiple compiler pairs (e.g., different optimization levels).
pub type CompilerPair = (Arc<dyn Renderer>, Arc<dyn Compiler>);
237
/// A device that owns renderer, compiler, runtime, and allocator.
///
/// This follows Tinygrad's architecture where a Device is a complete
/// compilation + execution unit for a specific backend.
///
/// # Example
///
/// ```ignore
/// let cpu_device = create_cpu_device()?;
/// let spec = cpu_device.renderer.render(&kernel_ast)?;
/// let compiled = cpu_device.compiler.compile(&spec)?;
/// let program = (cpu_device.runtime)(&compiled)?;
/// unsafe { program.execute(&buffers, &vals, None, None)?; }
/// ```
pub struct Device {
    /// Device specification.
    pub device: DeviceSpec,

    /// Memory allocator for this device.
    pub allocator: Arc<dyn Allocator>,

    /// Available (renderer, compiler) pairs for this device.
    ///
    /// Most devices have one pair, but some may have multiple
    /// (e.g., different optimization levels or compilation modes).
    pub compilers: Vec<CompilerPair>,

    /// Primary renderer for this device.
    ///
    /// This is typically `compilers[0].0`, stored separately for convenience.
    pub renderer: Arc<dyn Renderer>,

    /// Primary compiler for this device.
    ///
    /// This is typically `compilers[0].1`, stored separately for convenience.
    pub compiler: Arc<dyn Compiler>,

    /// Runtime factory for creating executable programs.
    ///
    /// Takes a `&CompiledSpec` (source for JIT, bytes for AOT) and returns a
    /// boxed [`Program`]; see [`RuntimeFactory`].
    pub runtime: RuntimeFactory,
}
280
281impl Device {
282 /// Create a new device with a single compiler pair.
283 ///
284 /// This is a convenience constructor for the common case where
285 /// a device has only one renderer/compiler combination.
286 pub fn new(
287 device: DeviceSpec,
288 allocator: Arc<dyn Allocator>,
289 renderer: Arc<dyn Renderer>,
290 compiler: Arc<dyn Compiler>,
291 runtime: RuntimeFactory,
292 ) -> Self {
293 let compilers = vec![(renderer.clone(), compiler.clone())];
294 Self { device, allocator, compilers, renderer, compiler, runtime }
295 }
296
297 /// Get the base device key (strips device ID).
298 ///
299 /// Used for compiled byte cache sharing across device instances.
300 /// Examples:
301 /// - DeviceSpec::Cpu -> "CPU"
302 /// - DeviceSpec::Cuda { device_id: 0 } -> "CUDA"
303 /// - DeviceSpec::Cuda { device_id: 1 } -> "CUDA"
304 /// - DeviceSpec::Metal { device_id: 0 } -> "Metal"
305 ///
306 /// This allows compiled CUDA kernels to be reused across CUDA:0 and CUDA:1.
307 pub fn base_device_key(&self) -> &'static str {
308 self.device.base_type()
309 }
310}
311
/// Program specification containing source code and metadata.
///
/// This is returned by [`Renderer::render`] and consumed by [`Compiler::compile`].
/// It bridges the gap between UOp graphs and compiled executables.
///
/// # Tinygrad Alignment
///
/// Buffer metadata (`globals`, `outs`, `ins`) matches Tinygrad's Program class:
/// - `globals`: Buffer indices from DefineGlobal ops
/// - `outs`: Output buffer indices (written by STORE ops)
/// - `ins`: Input buffer indices (read by LOAD ops)
#[derive(Debug, Clone)]
pub struct ProgramSpec {
    /// Kernel name (for debugging/profiling).
    pub name: String,

    /// Generated source code (LLVM IR, CUDA C, Metal, WGSL, etc.).
    pub src: String,

    /// Device specification.
    pub device: DeviceSpec,

    /// Original AST (for cache key construction via hash consing).
    pub ast: Arc<UOp>,

    /// Global work size (for GPU backends, `None` for CPU).
    pub global_size: Option<[usize; 3]>,

    /// Local work size (for GPU backends, `None` for CPU).
    pub local_size: Option<[usize; 3]>,

    /// Variable list (for symbolic shapes/strides).
    pub vars: Vec<Variable>,

    /// Variable names in order for populating the vars array at runtime.
    /// Includes thread_id at the end if threading is enabled.
    pub var_names: Vec<String>,

    /// Global buffer indices (from DefineGlobal argument values).
    /// Matches Tinygrad's `globals` field.
    pub globals: Vec<usize>,

    /// Output buffer indices (written by STORE ops).
    /// Matches Tinygrad's `outs` field.
    pub outs: Vec<usize>,

    /// Input buffer indices (read by LOAD ops, excluding outputs).
    /// Matches Tinygrad's `ins` field.
    pub ins: Vec<usize>,
}
362
363impl ProgramSpec {
364 /// Create a new program specification.
365 pub fn new(name: String, src: String, device: DeviceSpec, ast: Arc<UOp>) -> Self {
366 Self {
367 name,
368 src,
369 device,
370 ast,
371 global_size: None,
372 local_size: None,
373 vars: Vec::new(),
374 var_names: Vec::new(),
375 globals: Vec::new(),
376 outs: Vec::new(),
377 ins: Vec::new(),
378 }
379 }
380
381 /// Add a variable to the program.
382 pub fn add_var(&mut self, var: Variable) {
383 self.vars.push(var);
384 }
385
386 /// Set work sizes for GPU execution.
387 pub fn set_work_sizes(&mut self, global: [usize; 3], local: [usize; 3]) {
388 self.global_size = Some(global);
389 self.local_size = Some(local);
390 }
391
392 /// Set variable names for populating vars array at runtime.
393 pub fn set_var_names(&mut self, var_names: Vec<String>) {
394 self.var_names = var_names;
395 }
396
397 /// Set buffer metadata (globals, outs, ins).
398 pub fn set_buffer_metadata(&mut self, globals: Vec<usize>, outs: Vec<usize>, ins: Vec<usize>) {
399 self.globals = globals;
400 self.outs = outs;
401 self.ins = ins;
402 }
403}
404
/// A variable in the kernel (for symbolic shapes/strides).
///
/// Variables represent symbolic values that are bound at kernel execution time.
/// Examples:
/// - Shape dimensions that vary per input
/// - Stride values computed from shapes
/// - Loop bounds determined by input sizes
#[derive(Debug, Clone)]
pub struct Variable {
    /// Variable name (must be unique within the kernel).
    pub name: String,

    /// Minimum value (for range validation).
    pub min: i64,

    /// Maximum value (for range validation).
    pub max: i64,
}
423
424impl Variable {
425 /// Create a new variable.
426 pub fn new(name: String, min: i64, max: i64) -> Self {
427 Self { name, min, max }
428 }
429}