Skip to main content

ringkernel_codegen/
lib.rs

1//! Code Generation for RingKernel
2//!
3//! This crate generates GPU kernel source code (CUDA PTX, Metal MSL, WGSL)
4//! from Rust kernel definitions.
5//!
6//! # Supported Targets
7//!
8//! - CUDA PTX (sm_70+)
9//! - Metal MSL
10//! - WebGPU WGSL
11//!
12//! # Example
13//!
14//! ```
15//! use ringkernel_codegen::{CodeGenerator, Target};
16//!
17//! let generator = CodeGenerator::new();
18//! let source = generator.generate_kernel_source(
19//!     "my_kernel",
20//!     "// custom kernel code",
21//!     Target::Cuda,
22//! );
23//! ```
24
25#![warn(missing_docs)]
26#![warn(clippy::unwrap_used)]
27
28pub mod dsl_common;
29
30use std::collections::HashMap;
31use thiserror::Error;
32
33/// Code generation errors.
34#[derive(Error, Debug)]
35pub enum CodegenError {
36    /// Template error.
37    #[error("template error: {0}")]
38    TemplateError(String),
39
40    /// Unsupported target.
41    #[error("unsupported target: {0}")]
42    UnsupportedTarget(String),
43
44    /// Invalid kernel definition.
45    #[error("invalid kernel: {0}")]
46    InvalidKernel(String),
47}
48
49/// Code generation result type.
50pub type Result<T> = std::result::Result<T, CodegenError>;
51
/// GPU platform for which kernel source can be generated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Target {
    /// NVIDIA CUDA (PTX).
    Cuda,
    /// Apple Metal (MSL).
    Metal,
    /// WebGPU (WGSL).
    Wgsl,
}

impl Target {
    /// Static `(file extension, display name)` pair describing this target.
    fn labels(self) -> (&'static str, &'static str) {
        match self {
            Self::Cuda => ("ptx", "CUDA"),
            Self::Metal => ("metal", "Metal"),
            Self::Wgsl => ("wgsl", "WebGPU"),
        }
    }

    /// File extension (without the leading dot) for files generated for this target.
    pub fn extension(&self) -> &'static str {
        let (ext, _) = self.labels();
        ext
    }

    /// Human-readable name of the target platform.
    pub fn name(&self) -> &'static str {
        let (_, display) = self.labels();
        display
    }
}
82
/// Launch and interface parameters describing one kernel.
#[derive(Debug, Clone)]
pub struct KernelConfig {
    /// Identifier of the kernel.
    pub id: String,
    /// Number of blocks in the grid.
    pub grid_size: u32,
    /// Number of threads per block.
    pub block_size: u32,
    /// Shared memory size, in bytes.
    pub shared_memory: usize,
    /// Names of the message types the kernel consumes.
    pub input_types: Vec<String>,
    /// Names of the message types the kernel produces.
    pub output_types: Vec<String>,
}

impl Default for KernelConfig {
    /// A kernel named `"kernel"` launched as one block of 256 threads,
    /// with no shared memory and no declared message types.
    fn default() -> Self {
        let id = String::from("kernel");
        Self {
            id,
            grid_size: 1,
            block_size: 256,
            shared_memory: 0,
            input_types: Vec::new(),
            output_types: Vec::new(),
        }
    }
}
112
113/// Code generator for GPU kernels.
114pub struct CodeGenerator {
115    /// Template variables.
116    variables: HashMap<String, String>,
117}
118
119impl CodeGenerator {
120    /// Create a new code generator.
121    pub fn new() -> Self {
122        Self {
123            variables: HashMap::new(),
124        }
125    }
126
127    /// Set a template variable.
128    pub fn set_variable(&mut self, key: impl Into<String>, value: impl Into<String>) {
129        self.variables.insert(key.into(), value.into());
130    }
131
132    /// Generate kernel source code for the specified target.
133    pub fn generate_kernel_source(
134        &self,
135        kernel_id: &str,
136        user_code: &str,
137        target: Target,
138    ) -> Result<String> {
139        let template = self.get_template(target);
140        let source = self.substitute_template(template, kernel_id, user_code);
141        Ok(source)
142    }
143
144    /// Generate complete kernel file.
145    pub fn generate_kernel_file(
146        &self,
147        config: &KernelConfig,
148        user_code: &str,
149        target: Target,
150    ) -> Result<GeneratedFile> {
151        let source = self.generate_kernel_source(&config.id, user_code, target)?;
152        Ok(GeneratedFile {
153            filename: format!("{}.{}", config.id, target.extension()),
154            content: source,
155            target,
156        })
157    }
158
159    /// Generate for all targets.
160    pub fn generate_all_targets(
161        &self,
162        config: &KernelConfig,
163        user_code: &str,
164    ) -> Result<Vec<GeneratedFile>> {
165        let targets = [Target::Cuda, Target::Metal, Target::Wgsl];
166        let mut files = Vec::with_capacity(targets.len());
167
168        for target in targets {
169            files.push(self.generate_kernel_file(config, user_code, target)?);
170        }
171
172        Ok(files)
173    }
174
175    fn get_template(&self, target: Target) -> &'static str {
176        match target {
177            Target::Cuda => include_str!("templates/cuda.ptx.template"),
178            Target::Metal => include_str!("templates/metal.msl.template"),
179            Target::Wgsl => include_str!("templates/wgsl.template"),
180        }
181    }
182
183    fn substitute_template(&self, template: &str, kernel_id: &str, user_code: &str) -> String {
184        let mut result = template.to_string();
185        result = result.replace("{{KERNEL_ID}}", kernel_id);
186        result = result.replace("{{USER_CODE}}", user_code);
187        result = result.replace("// USER_KERNEL_CODE", user_code);
188
189        // Apply custom variables
190        for (key, value) in &self.variables {
191            result = result.replace(&format!("{{{{{}}}}}", key), value);
192        }
193
194        result
195    }
196}
197
198impl Default for CodeGenerator {
199    fn default() -> Self {
200        Self::new()
201    }
202}
203
204/// Generated kernel file.
205#[derive(Debug, Clone)]
206pub struct GeneratedFile {
207    /// Output filename.
208    pub filename: String,
209    /// Generated source code.
210    pub content: String,
211    /// Target platform.
212    pub target: Target,
213}
214
215/// Intrinsic mapping from Rust to GPU code.
216#[derive(Debug, Clone)]
217pub struct IntrinsicMap {
218    /// Rust function name.
219    pub rust_name: String,
220    /// CUDA equivalent.
221    pub cuda: String,
222    /// Metal equivalent.
223    pub metal: String,
224    /// WGSL equivalent.
225    pub wgsl: String,
226}
227
228impl IntrinsicMap {
229    /// Get intrinsic for the specified target.
230    pub fn get(&self, target: Target) -> &str {
231        match target {
232            Target::Cuda => &self.cuda,
233            Target::Metal => &self.metal,
234            Target::Wgsl => &self.wgsl,
235        }
236    }
237}
238
239/// Standard intrinsic mappings.
240pub fn standard_intrinsics() -> Vec<IntrinsicMap> {
241    vec![
242        IntrinsicMap {
243            rust_name: "sync_threads".to_string(),
244            cuda: "__syncthreads()".to_string(),
245            metal: "threadgroup_barrier(mem_flags::mem_threadgroup)".to_string(),
246            wgsl: "workgroupBarrier()".to_string(),
247        },
248        IntrinsicMap {
249            rust_name: "thread_fence_block".to_string(),
250            cuda: "__threadfence_block()".to_string(),
251            metal: "threadgroup_barrier(mem_flags::mem_device)".to_string(),
252            wgsl: "storageBarrier()".to_string(),
253        },
254        IntrinsicMap {
255            rust_name: "thread_fence".to_string(),
256            cuda: "__threadfence()".to_string(),
257            metal: "threadgroup_barrier(mem_flags::mem_device)".to_string(),
258            wgsl: "storageBarrier()".to_string(),
259        },
260        IntrinsicMap {
261            rust_name: "atomic_add".to_string(),
262            cuda: "atomicAdd".to_string(),
263            metal: "atomic_fetch_add_explicit".to_string(),
264            wgsl: "atomicAdd".to_string(),
265        },
266        IntrinsicMap {
267            rust_name: "atomic_cas".to_string(),
268            cuda: "atomicCAS".to_string(),
269            metal: "atomic_compare_exchange_weak_explicit".to_string(),
270            wgsl: "atomicCompareExchangeWeak".to_string(),
271        },
272    ]
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_code_generator() {
        // `gen` is a reserved keyword in Rust 2024, so use a different binding name.
        let generator = CodeGenerator::new();
        let source = generator
            .generate_kernel_source("test_kernel", "// test code", Target::Cuda)
            .expect("generating CUDA source should succeed");

        assert!(source.contains("test_kernel") || source.contains("ring_kernel"));
    }

    #[test]
    fn test_target_extension() {
        assert_eq!(Target::Cuda.extension(), "ptx");
        assert_eq!(Target::Metal.extension(), "metal");
        assert_eq!(Target::Wgsl.extension(), "wgsl");
    }

    #[test]
    fn test_target_name() {
        assert_eq!(Target::Cuda.name(), "CUDA");
        assert_eq!(Target::Metal.name(), "Metal");
        assert_eq!(Target::Wgsl.name(), "WebGPU");
    }

    #[test]
    fn test_generate_all_targets_filenames_and_order() {
        let generator = CodeGenerator::new();
        let config = KernelConfig {
            id: "my_kernel".to_string(),
            ..KernelConfig::default()
        };
        let files = generator
            .generate_all_targets(&config, "// body")
            .expect("generating all targets should succeed");

        let names: Vec<&str> = files.iter().map(|f| f.filename.as_str()).collect();
        assert_eq!(names, ["my_kernel.ptx", "my_kernel.metal", "my_kernel.wgsl"]);

        let targets: Vec<Target> = files.iter().map(|f| f.target).collect();
        assert_eq!(targets, [Target::Cuda, Target::Metal, Target::Wgsl]);
    }

    #[test]
    fn test_intrinsic_mapping() {
        let intrinsics = standard_intrinsics();
        let sync = intrinsics
            .iter()
            .find(|i| i.rust_name == "sync_threads")
            .expect("sync_threads must be in the standard intrinsic table");

        assert_eq!(sync.get(Target::Cuda), "__syncthreads()");
        assert!(sync.get(Target::Metal).contains("barrier"));
    }
}