Skip to main content

ringkernel_codegen/
lib.rs

//! Code Generation for RingKernel
//!
//! This crate generates GPU kernel source code (CUDA PTX, Metal MSL, WGSL)
//! from Rust kernel definitions.
//!
//! # Supported Targets
//!
//! - CUDA PTX (sm_70+)
//! - Metal MSL
//! - WebGPU WGSL
//!
//! # Example
//!
//! ```
//! use ringkernel_codegen::{CodeGenerator, Target};
//!
//! let generator = CodeGenerator::new();
//! let source = generator.generate_kernel_source(
//!     "my_kernel",
//!     "// custom kernel code",
//!     Target::Cuda,
//! );
//! ```
24
25#![warn(missing_docs)]
26
27pub mod dsl_common;
28
29use std::collections::HashMap;
30use thiserror::Error;
31
32/// Code generation errors.
33#[derive(Error, Debug)]
34pub enum CodegenError {
35    /// Template error.
36    #[error("template error: {0}")]
37    TemplateError(String),
38
39    /// Unsupported target.
40    #[error("unsupported target: {0}")]
41    UnsupportedTarget(String),
42
43    /// Invalid kernel definition.
44    #[error("invalid kernel: {0}")]
45    InvalidKernel(String),
46}
47
48/// Code generation result type.
49pub type Result<T> = std::result::Result<T, CodegenError>;
50
/// GPU platform that kernel source can be generated for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Target {
    /// NVIDIA CUDA (PTX).
    Cuda,
    /// Apple Metal (MSL).
    Metal,
    /// WebGPU (WGSL).
    Wgsl,
}

impl Target {
    /// Returns the file extension (without a leading dot) used for generated
    /// files of this target.
    pub fn extension(&self) -> &'static str {
        match *self {
            Self::Cuda => "ptx",
            Self::Metal => "metal",
            Self::Wgsl => "wgsl",
        }
    }

    /// Returns the human-readable name of this target.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Cuda => "CUDA",
            Self::Metal => "Metal",
            Self::Wgsl => "WebGPU",
        }
    }
}
81
/// Configuration describing a single kernel.
#[derive(Debug, Clone)]
pub struct KernelConfig {
    /// Kernel identifier.
    pub id: String,
    /// Grid size (blocks).
    pub grid_size: u32,
    /// Block size (threads).
    pub block_size: u32,
    /// Shared memory size in bytes.
    pub shared_memory: usize,
    /// Input message types.
    pub input_types: Vec<String>,
    /// Output message types.
    pub output_types: Vec<String>,
}

impl Default for KernelConfig {
    /// Builds a configuration with id `"kernel"`, a grid of 1 block,
    /// 256 threads per block, no shared memory and no message types.
    fn default() -> Self {
        Self {
            id: String::from("kernel"),
            grid_size: 1,
            block_size: 256,
            shared_memory: 0,
            input_types: Vec::new(),
            output_types: Vec::new(),
        }
    }
}
111
112/// Code generator for GPU kernels.
113pub struct CodeGenerator {
114    /// Template variables.
115    variables: HashMap<String, String>,
116}
117
118impl CodeGenerator {
119    /// Create a new code generator.
120    pub fn new() -> Self {
121        Self {
122            variables: HashMap::new(),
123        }
124    }
125
126    /// Set a template variable.
127    pub fn set_variable(&mut self, key: impl Into<String>, value: impl Into<String>) {
128        self.variables.insert(key.into(), value.into());
129    }
130
131    /// Generate kernel source code for the specified target.
132    pub fn generate_kernel_source(
133        &self,
134        kernel_id: &str,
135        user_code: &str,
136        target: Target,
137    ) -> Result<String> {
138        let template = self.get_template(target);
139        let source = self.substitute_template(template, kernel_id, user_code);
140        Ok(source)
141    }
142
143    /// Generate complete kernel file.
144    pub fn generate_kernel_file(
145        &self,
146        config: &KernelConfig,
147        user_code: &str,
148        target: Target,
149    ) -> Result<GeneratedFile> {
150        let source = self.generate_kernel_source(&config.id, user_code, target)?;
151        Ok(GeneratedFile {
152            filename: format!("{}.{}", config.id, target.extension()),
153            content: source,
154            target,
155        })
156    }
157
158    /// Generate for all targets.
159    pub fn generate_all_targets(
160        &self,
161        config: &KernelConfig,
162        user_code: &str,
163    ) -> Result<Vec<GeneratedFile>> {
164        let targets = [Target::Cuda, Target::Metal, Target::Wgsl];
165        let mut files = Vec::with_capacity(targets.len());
166
167        for target in targets {
168            files.push(self.generate_kernel_file(config, user_code, target)?);
169        }
170
171        Ok(files)
172    }
173
174    fn get_template(&self, target: Target) -> &'static str {
175        match target {
176            Target::Cuda => include_str!("templates/cuda.ptx.template"),
177            Target::Metal => include_str!("templates/metal.msl.template"),
178            Target::Wgsl => include_str!("templates/wgsl.template"),
179        }
180    }
181
182    fn substitute_template(&self, template: &str, kernel_id: &str, user_code: &str) -> String {
183        let mut result = template.to_string();
184        result = result.replace("{{KERNEL_ID}}", kernel_id);
185        result = result.replace("{{USER_CODE}}", user_code);
186        result = result.replace("// USER_KERNEL_CODE", user_code);
187
188        // Apply custom variables
189        for (key, value) in &self.variables {
190            result = result.replace(&format!("{{{{{}}}}}", key), value);
191        }
192
193        result
194    }
195}
196
197impl Default for CodeGenerator {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203/// Generated kernel file.
204#[derive(Debug, Clone)]
205pub struct GeneratedFile {
206    /// Output filename.
207    pub filename: String,
208    /// Generated source code.
209    pub content: String,
210    /// Target platform.
211    pub target: Target,
212}
213
214/// Intrinsic mapping from Rust to GPU code.
215#[derive(Debug, Clone)]
216pub struct IntrinsicMap {
217    /// Rust function name.
218    pub rust_name: String,
219    /// CUDA equivalent.
220    pub cuda: String,
221    /// Metal equivalent.
222    pub metal: String,
223    /// WGSL equivalent.
224    pub wgsl: String,
225}
226
227impl IntrinsicMap {
228    /// Get intrinsic for the specified target.
229    pub fn get(&self, target: Target) -> &str {
230        match target {
231            Target::Cuda => &self.cuda,
232            Target::Metal => &self.metal,
233            Target::Wgsl => &self.wgsl,
234        }
235    }
236}
237
238/// Standard intrinsic mappings.
239pub fn standard_intrinsics() -> Vec<IntrinsicMap> {
240    vec![
241        IntrinsicMap {
242            rust_name: "sync_threads".to_string(),
243            cuda: "__syncthreads()".to_string(),
244            metal: "threadgroup_barrier(mem_flags::mem_threadgroup)".to_string(),
245            wgsl: "workgroupBarrier()".to_string(),
246        },
247        IntrinsicMap {
248            rust_name: "thread_fence_block".to_string(),
249            cuda: "__threadfence_block()".to_string(),
250            metal: "threadgroup_barrier(mem_flags::mem_device)".to_string(),
251            wgsl: "storageBarrier()".to_string(),
252        },
253        IntrinsicMap {
254            rust_name: "thread_fence".to_string(),
255            cuda: "__threadfence()".to_string(),
256            metal: "threadgroup_barrier(mem_flags::mem_device)".to_string(),
257            wgsl: "storageBarrier()".to_string(),
258        },
259        IntrinsicMap {
260            rust_name: "atomic_add".to_string(),
261            cuda: "atomicAdd".to_string(),
262            metal: "atomic_fetch_add_explicit".to_string(),
263            wgsl: "atomicAdd".to_string(),
264        },
265        IntrinsicMap {
266            rust_name: "atomic_cas".to_string(),
267            cuda: "atomicCAS".to_string(),
268            metal: "atomic_compare_exchange_weak_explicit".to_string(),
269            wgsl: "atomicCompareExchangeWeak".to_string(),
270        },
271    ]
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Generating CUDA source succeeds and the output mentions either the
    /// requested kernel id or the `ring_kernel` name.
    #[test]
    fn test_code_generator() {
        let gen = CodeGenerator::new();
        let source = gen
            .generate_kernel_source("test_kernel", "// test code", Target::Cuda)
            .unwrap();

        assert!(source.contains("test_kernel") || source.contains("ring_kernel"));
    }

    /// Each target reports its expected file extension.
    #[test]
    fn test_target_extension() {
        assert_eq!(Target::Cuda.extension(), "ptx");
        assert_eq!(Target::Metal.extension(), "metal");
        assert_eq!(Target::Wgsl.extension(), "wgsl");
    }

    /// The `sync_threads` intrinsic is present and maps to a barrier on
    /// CUDA and Metal.
    #[test]
    fn test_intrinsic_mapping() {
        let intrinsics = standard_intrinsics();
        let sync = intrinsics
            .iter()
            .find(|i| i.rust_name == "sync_threads")
            .unwrap();

        assert_eq!(sync.get(Target::Cuda), "__syncthreads()");
        assert!(sync.get(Target::Metal).contains("barrier"));
    }
}