ringkernel_codegen/
lib.rs

1//! Code Generation for RingKernel
2//!
3//! This crate generates GPU kernel source code (CUDA PTX, Metal MSL, WGSL)
4//! from Rust kernel definitions.
5//!
6//! # Supported Targets
7//!
8//! - CUDA PTX (sm_70+)
9//! - Metal MSL
10//! - WebGPU WGSL
11//!
12//! # Example
13//!
14//! ```
15//! use ringkernel_codegen::{CodeGenerator, Target};
16//!
17//! let generator = CodeGenerator::new();
18//! let source = generator.generate_kernel_source(
19//!     "my_kernel",
20//!     "// custom kernel code",
21//!     Target::Cuda,
22//! );
23//! ```
24
25#![warn(missing_docs)]
26
27use std::collections::HashMap;
28use thiserror::Error;
29
/// Errors that can be reported by code generation.
///
/// Every variant wraps a free-form `String` detail that is interpolated into
/// the `Display` message declared by the variant's `#[error(...)]` attribute.
#[derive(Error, Debug)]
pub enum CodegenError {
    /// A template-related failure; displayed as `template error: <detail>`.
    #[error("template error: {0}")]
    TemplateError(String),

    /// The requested target is not supported; displayed as
    /// `unsupported target: <detail>`.
    #[error("unsupported target: {0}")]
    UnsupportedTarget(String),

    /// The kernel definition is invalid; displayed as
    /// `invalid kernel: <detail>`.
    #[error("invalid kernel: {0}")]
    InvalidKernel(String),
}
45
/// Result alias used by this crate: `std::result::Result` with the error type
/// fixed to [`CodegenError`].
pub type Result<T> = std::result::Result<T, CodegenError>;
48
/// GPU backend that kernel source can be generated for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Target {
    /// NVIDIA CUDA (PTX).
    Cuda,
    /// Apple Metal (MSL).
    Metal,
    /// WebGPU (WGSL).
    Wgsl,
}

impl Target {
    /// File extension (without a leading dot) used for this target's sources.
    pub fn extension(&self) -> &'static str {
        match *self {
            Self::Wgsl => "wgsl",
            Self::Metal => "metal",
            Self::Cuda => "ptx",
        }
    }

    /// Human-readable platform name of this target.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Wgsl => "WebGPU",
            Self::Metal => "Metal",
            Self::Cuda => "CUDA",
        }
    }
}
79
/// Configuration describing a single kernel.
#[derive(Debug, Clone)]
pub struct KernelConfig {
    /// Kernel identifier.
    pub id: String,
    /// Grid size (blocks).
    pub grid_size: u32,
    /// Block size (threads).
    pub block_size: u32,
    /// Shared memory size in bytes.
    pub shared_memory: usize,
    /// Input message types.
    pub input_types: Vec<String>,
    /// Output message types.
    pub output_types: Vec<String>,
}

impl Default for KernelConfig {
    /// Builds the default configuration: id `"kernel"`, a grid of 1 block,
    /// 256 threads per block, no shared memory and no message types.
    fn default() -> Self {
        let id = String::from("kernel");
        KernelConfig {
            id,
            grid_size: 1,
            block_size: 256,
            shared_memory: 0,
            input_types: Vec::new(),
            output_types: Vec::new(),
        }
    }
}
109
/// Code generator for GPU kernels.
///
/// Holds the custom template variables that are substituted into a target's
/// template when kernel source is generated.
#[derive(Debug, Clone)]
pub struct CodeGenerator {
    /// Custom template variables, keyed by placeholder name (the text between
    /// `{{` and `}}` in a template).
    variables: HashMap<String, String>,
}
115
116impl CodeGenerator {
117    /// Create a new code generator.
118    pub fn new() -> Self {
119        Self {
120            variables: HashMap::new(),
121        }
122    }
123
124    /// Set a template variable.
125    pub fn set_variable(&mut self, key: impl Into<String>, value: impl Into<String>) {
126        self.variables.insert(key.into(), value.into());
127    }
128
129    /// Generate kernel source code for the specified target.
130    pub fn generate_kernel_source(
131        &self,
132        kernel_id: &str,
133        user_code: &str,
134        target: Target,
135    ) -> Result<String> {
136        let template = self.get_template(target);
137        let source = self.substitute_template(template, kernel_id, user_code);
138        Ok(source)
139    }
140
141    /// Generate complete kernel file.
142    pub fn generate_kernel_file(
143        &self,
144        config: &KernelConfig,
145        user_code: &str,
146        target: Target,
147    ) -> Result<GeneratedFile> {
148        let source = self.generate_kernel_source(&config.id, user_code, target)?;
149        Ok(GeneratedFile {
150            filename: format!("{}.{}", config.id, target.extension()),
151            content: source,
152            target,
153        })
154    }
155
156    /// Generate for all targets.
157    pub fn generate_all_targets(
158        &self,
159        config: &KernelConfig,
160        user_code: &str,
161    ) -> Result<Vec<GeneratedFile>> {
162        let targets = [Target::Cuda, Target::Metal, Target::Wgsl];
163        let mut files = Vec::with_capacity(targets.len());
164
165        for target in targets {
166            files.push(self.generate_kernel_file(config, user_code, target)?);
167        }
168
169        Ok(files)
170    }
171
172    fn get_template(&self, target: Target) -> &'static str {
173        match target {
174            Target::Cuda => include_str!("templates/cuda.ptx.template"),
175            Target::Metal => include_str!("templates/metal.msl.template"),
176            Target::Wgsl => include_str!("templates/wgsl.template"),
177        }
178    }
179
180    fn substitute_template(&self, template: &str, kernel_id: &str, user_code: &str) -> String {
181        let mut result = template.to_string();
182        result = result.replace("{{KERNEL_ID}}", kernel_id);
183        result = result.replace("{{USER_CODE}}", user_code);
184        result = result.replace("// USER_KERNEL_CODE", user_code);
185
186        // Apply custom variables
187        for (key, value) in &self.variables {
188            result = result.replace(&format!("{{{{{}}}}}", key), value);
189        }
190
191        result
192    }
193}
194
195impl Default for CodeGenerator {
196    fn default() -> Self {
197        Self::new()
198    }
199}
200
/// A generated kernel source file: its name, its contents and the target it
/// was generated for.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// Output filename.
    pub filename: String,
    /// Generated source code.
    pub content: String,
    /// Target platform the source was generated for.
    pub target: Target,
}
211
212/// Intrinsic mapping from Rust to GPU code.
213#[derive(Debug, Clone)]
214pub struct IntrinsicMap {
215    /// Rust function name.
216    pub rust_name: String,
217    /// CUDA equivalent.
218    pub cuda: String,
219    /// Metal equivalent.
220    pub metal: String,
221    /// WGSL equivalent.
222    pub wgsl: String,
223}
224
225impl IntrinsicMap {
226    /// Get intrinsic for the specified target.
227    pub fn get(&self, target: Target) -> &str {
228        match target {
229            Target::Cuda => &self.cuda,
230            Target::Metal => &self.metal,
231            Target::Wgsl => &self.wgsl,
232        }
233    }
234}
235
236/// Standard intrinsic mappings.
237pub fn standard_intrinsics() -> Vec<IntrinsicMap> {
238    vec![
239        IntrinsicMap {
240            rust_name: "sync_threads".to_string(),
241            cuda: "__syncthreads()".to_string(),
242            metal: "threadgroup_barrier(mem_flags::mem_threadgroup)".to_string(),
243            wgsl: "workgroupBarrier()".to_string(),
244        },
245        IntrinsicMap {
246            rust_name: "thread_fence_block".to_string(),
247            cuda: "__threadfence_block()".to_string(),
248            metal: "threadgroup_barrier(mem_flags::mem_device)".to_string(),
249            wgsl: "storageBarrier()".to_string(),
250        },
251        IntrinsicMap {
252            rust_name: "thread_fence".to_string(),
253            cuda: "__threadfence()".to_string(),
254            metal: "threadgroup_barrier(mem_flags::mem_device)".to_string(),
255            wgsl: "storageBarrier()".to_string(),
256        },
257        IntrinsicMap {
258            rust_name: "atomic_add".to_string(),
259            cuda: "atomicAdd".to_string(),
260            metal: "atomic_fetch_add_explicit".to_string(),
261            wgsl: "atomicAdd".to_string(),
262        },
263        IntrinsicMap {
264            rust_name: "atomic_cas".to_string(),
265            cuda: "atomicCAS".to_string(),
266            metal: "atomic_compare_exchange_weak_explicit".to_string(),
267            wgsl: "atomicCompareExchangeWeak".to_string(),
268        },
269    ]
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Generating CUDA source succeeds and the output mentions either the
    /// requested kernel id or the `ring_kernel` name.
    #[test]
    fn test_code_generator() {
        // Named `generator` rather than `gen`: `gen` is a reserved keyword in
        // the Rust 2024 edition.
        let generator = CodeGenerator::new();
        let source = generator
            .generate_kernel_source("test_kernel", "// test code", Target::Cuda)
            .unwrap();

        assert!(source.contains("test_kernel") || source.contains("ring_kernel"));
    }

    /// Each target reports its expected file extension.
    #[test]
    fn test_target_extension() {
        assert_eq!(Target::Cuda.extension(), "ptx");
        assert_eq!(Target::Metal.extension(), "metal");
        assert_eq!(Target::Wgsl.extension(), "wgsl");
    }

    /// The `sync_threads` intrinsic maps to the expected per-target spelling.
    #[test]
    fn test_intrinsic_mapping() {
        let intrinsics = standard_intrinsics();
        let sync = intrinsics
            .iter()
            .find(|i| i.rust_name == "sync_threads")
            .unwrap();

        assert_eq!(sync.get(Target::Cuda), "__syncthreads()");
        assert!(sync.get(Target::Metal).contains("barrier"));
    }

    /// Generating for all targets yields one file per target, named after the
    /// kernel id and the target's extension, in CUDA / Metal / WGSL order.
    #[test]
    fn test_generate_all_targets_filenames() {
        let generator = CodeGenerator::new();
        let files = generator
            .generate_all_targets(&KernelConfig::default(), "// test code")
            .unwrap();

        let names: Vec<&str> = files.iter().map(|f| f.filename.as_str()).collect();
        assert_eq!(names, ["kernel.ptx", "kernel.metal", "kernel.wgsl"]);
    }
}