1#![warn(missing_docs)]
26
27use std::collections::HashMap;
28use thiserror::Error;
29
30#[derive(Error, Debug)]
32pub enum CodegenError {
33 #[error("template error: {0}")]
35 TemplateError(String),
36
37 #[error("unsupported target: {0}")]
39 UnsupportedTarget(String),
40
41 #[error("invalid kernel: {0}")]
43 InvalidKernel(String),
44}
45
46pub type Result<T> = std::result::Result<T, CodegenError>;
48
/// Backend for which kernel source can be generated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Target {
    /// CUDA backend (extension `ptx`, display name `CUDA`).
    Cuda,
    /// Metal backend (extension `metal`, display name `Metal`).
    Metal,
    /// WGSL backend (extension `wgsl`, display name `WebGPU`).
    Wgsl,
}

impl Target {
    /// Returns the file extension for this target, without a leading dot.
    pub fn extension(&self) -> &'static str {
        match *self {
            Self::Cuda => "ptx",
            Self::Metal => "metal",
            Self::Wgsl => "wgsl",
        }
    }

    /// Returns the human-readable name of this target.
    ///
    /// Note that [`Target::Wgsl`] is reported as `"WebGPU"`, not `"WGSL"`.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Cuda => "CUDA",
            Self::Metal => "Metal",
            Self::Wgsl => "WebGPU",
        }
    }
}
79
/// Description of a kernel to generate.
#[derive(Debug, Clone)]
pub struct KernelConfig {
    /// Kernel identifier (defaults to `"kernel"`).
    pub id: String,
    /// Grid size (defaults to `1`).
    // NOTE(review): the unit (blocks vs. threads) is not established here — confirm.
    pub grid_size: u32,
    /// Block size (defaults to `256`).
    // NOTE(review): presumably threads per block — confirm against the launcher.
    pub block_size: u32,
    /// Shared memory amount (defaults to `0`).
    // NOTE(review): presumably bytes — confirm against the launcher.
    pub shared_memory: usize,
    /// Type names of the kernel inputs (empty by default).
    pub input_types: Vec<String>,
    /// Type names of the kernel outputs (empty by default).
    pub output_types: Vec<String>,
}

impl Default for KernelConfig {
    /// Builds a configuration named `"kernel"` with a grid size of 1, a block
    /// size of 256, no shared memory and no declared input/output types.
    fn default() -> Self {
        KernelConfig {
            id: String::from("kernel"),
            grid_size: 1,
            block_size: 256,
            shared_memory: 0,
            input_types: Vec::new(),
            output_types: Vec::new(),
        }
    }
}
109
110pub struct CodeGenerator {
112 variables: HashMap<String, String>,
114}
115
116impl CodeGenerator {
117 pub fn new() -> Self {
119 Self {
120 variables: HashMap::new(),
121 }
122 }
123
124 pub fn set_variable(&mut self, key: impl Into<String>, value: impl Into<String>) {
126 self.variables.insert(key.into(), value.into());
127 }
128
129 pub fn generate_kernel_source(
131 &self,
132 kernel_id: &str,
133 user_code: &str,
134 target: Target,
135 ) -> Result<String> {
136 let template = self.get_template(target);
137 let source = self.substitute_template(template, kernel_id, user_code);
138 Ok(source)
139 }
140
141 pub fn generate_kernel_file(
143 &self,
144 config: &KernelConfig,
145 user_code: &str,
146 target: Target,
147 ) -> Result<GeneratedFile> {
148 let source = self.generate_kernel_source(&config.id, user_code, target)?;
149 Ok(GeneratedFile {
150 filename: format!("{}.{}", config.id, target.extension()),
151 content: source,
152 target,
153 })
154 }
155
156 pub fn generate_all_targets(
158 &self,
159 config: &KernelConfig,
160 user_code: &str,
161 ) -> Result<Vec<GeneratedFile>> {
162 let targets = [Target::Cuda, Target::Metal, Target::Wgsl];
163 let mut files = Vec::with_capacity(targets.len());
164
165 for target in targets {
166 files.push(self.generate_kernel_file(config, user_code, target)?);
167 }
168
169 Ok(files)
170 }
171
172 fn get_template(&self, target: Target) -> &'static str {
173 match target {
174 Target::Cuda => include_str!("templates/cuda.ptx.template"),
175 Target::Metal => include_str!("templates/metal.msl.template"),
176 Target::Wgsl => include_str!("templates/wgsl.template"),
177 }
178 }
179
180 fn substitute_template(&self, template: &str, kernel_id: &str, user_code: &str) -> String {
181 let mut result = template.to_string();
182 result = result.replace("{{KERNEL_ID}}", kernel_id);
183 result = result.replace("{{USER_CODE}}", user_code);
184 result = result.replace("// USER_KERNEL_CODE", user_code);
185
186 for (key, value) in &self.variables {
188 result = result.replace(&format!("{{{{{}}}}}", key), value);
189 }
190
191 result
192 }
193}
194
195impl Default for CodeGenerator {
196 fn default() -> Self {
197 Self::new()
198 }
199}
200
201#[derive(Debug, Clone)]
203pub struct GeneratedFile {
204 pub filename: String,
206 pub content: String,
208 pub target: Target,
210}
211
212#[derive(Debug, Clone)]
214pub struct IntrinsicMap {
215 pub rust_name: String,
217 pub cuda: String,
219 pub metal: String,
221 pub wgsl: String,
223}
224
225impl IntrinsicMap {
226 pub fn get(&self, target: Target) -> &str {
228 match target {
229 Target::Cuda => &self.cuda,
230 Target::Metal => &self.metal,
231 Target::Wgsl => &self.wgsl,
232 }
233 }
234}
235
236pub fn standard_intrinsics() -> Vec<IntrinsicMap> {
238 vec![
239 IntrinsicMap {
240 rust_name: "sync_threads".to_string(),
241 cuda: "__syncthreads()".to_string(),
242 metal: "threadgroup_barrier(mem_flags::mem_threadgroup)".to_string(),
243 wgsl: "workgroupBarrier()".to_string(),
244 },
245 IntrinsicMap {
246 rust_name: "thread_fence_block".to_string(),
247 cuda: "__threadfence_block()".to_string(),
248 metal: "threadgroup_barrier(mem_flags::mem_device)".to_string(),
249 wgsl: "storageBarrier()".to_string(),
250 },
251 IntrinsicMap {
252 rust_name: "thread_fence".to_string(),
253 cuda: "__threadfence()".to_string(),
254 metal: "threadgroup_barrier(mem_flags::mem_device)".to_string(),
255 wgsl: "storageBarrier()".to_string(),
256 },
257 IntrinsicMap {
258 rust_name: "atomic_add".to_string(),
259 cuda: "atomicAdd".to_string(),
260 metal: "atomic_fetch_add_explicit".to_string(),
261 wgsl: "atomicAdd".to_string(),
262 },
263 IntrinsicMap {
264 rust_name: "atomic_cas".to_string(),
265 cuda: "atomicCAS".to_string(),
266 metal: "atomic_compare_exchange_weak_explicit".to_string(),
267 wgsl: "atomicCompareExchangeWeak".to_string(),
268 },
269 ]
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Rendering the CUDA template succeeds and the output mentions either
    /// the requested kernel id or `ring_kernel`.
    #[test]
    fn test_code_generator() {
        // `gen` is a reserved keyword in the Rust 2024 edition, so the
        // binding uses a different name.
        let generator = CodeGenerator::new();
        let source = generator
            .generate_kernel_source("test_kernel", "// test code", Target::Cuda)
            .unwrap();

        assert!(source.contains("test_kernel") || source.contains("ring_kernel"));
    }

    /// Every target reports its expected file extension.
    #[test]
    fn test_target_extension() {
        assert_eq!(Target::Cuda.extension(), "ptx");
        assert_eq!(Target::Metal.extension(), "metal");
        assert_eq!(Target::Wgsl.extension(), "wgsl");
    }

    /// Every target reports its expected display name.
    #[test]
    fn test_target_name() {
        assert_eq!(Target::Cuda.name(), "CUDA");
        assert_eq!(Target::Metal.name(), "Metal");
        assert_eq!(Target::Wgsl.name(), "WebGPU");
    }

    /// The `sync_threads` intrinsic maps to the expected per-target spellings.
    #[test]
    fn test_intrinsic_mapping() {
        let intrinsics = standard_intrinsics();
        let sync = intrinsics
            .iter()
            .find(|i| i.rust_name == "sync_threads")
            .unwrap();

        assert_eq!(sync.get(Target::Cuda), "__syncthreads()");
        assert!(sync.get(Target::Metal).contains("barrier"));
    }
}