1#![warn(missing_docs)]
26#![warn(clippy::unwrap_used)]
27
28pub mod dsl_common;
29
30use std::collections::HashMap;
31use thiserror::Error;
32
33#[derive(Error, Debug)]
35pub enum CodegenError {
36 #[error("template error: {0}")]
38 TemplateError(String),
39
40 #[error("unsupported target: {0}")]
42 UnsupportedTarget(String),
43
44 #[error("invalid kernel: {0}")]
46 InvalidKernel(String),
47}
48
49pub type Result<T> = std::result::Result<T, CodegenError>;
51
/// A GPU backend that kernel source can be generated for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Target {
    /// NVIDIA CUDA; generated files use the `ptx` extension.
    Cuda,
    /// Apple Metal; generated files use the `metal` extension.
    Metal,
    /// WebGPU; generated files use the `wgsl` extension.
    Wgsl,
}

impl Target {
    /// File extension, without the leading dot, for files generated for this target.
    pub fn extension(&self) -> &'static str {
        let (ext, _) = self.labels();
        ext
    }

    /// Human-readable name of this backend.
    pub fn name(&self) -> &'static str {
        let (_, display) = self.labels();
        display
    }

    /// `(file extension, display name)` for this target.
    ///
    /// Both strings are kept in one table so a new variant only needs a single entry.
    fn labels(&self) -> (&'static str, &'static str) {
        match *self {
            Self::Cuda => ("ptx", "CUDA"),
            Self::Metal => ("metal", "Metal"),
            Self::Wgsl => ("wgsl", "WebGPU"),
        }
    }
}
82
/// Parameters describing a single kernel to generate.
///
/// [`CodeGenerator`] only reads `id`: it becomes the `{{KERNEL_ID}}` value and the output
/// file stem. The remaining fields are carried along for callers.
#[derive(Debug, Clone)]
pub struct KernelConfig {
    /// Kernel identifier. Defaults to `"kernel"`.
    pub id: String,
    /// Grid size. Defaults to `1`. NOTE(review): dimensionality/units not established here.
    pub grid_size: u32,
    /// Block size. Defaults to `256`. NOTE(review): presumably threads per block — confirm.
    pub block_size: u32,
    /// Shared memory amount. Defaults to `0`. NOTE(review): presumably bytes — confirm.
    pub shared_memory: usize,
    /// Type names of the kernel inputs. Empty by default.
    pub input_types: Vec<String>,
    /// Type names of the kernel outputs. Empty by default.
    pub output_types: Vec<String>,
}

impl Default for KernelConfig {
    /// A kernel named `"kernel"` with grid size 1, block size 256, no shared memory, and no
    /// declared input or output types.
    fn default() -> Self {
        let id = String::from("kernel");
        Self {
            id,
            grid_size: 1,
            block_size: 256,
            shared_memory: 0,
            input_types: Vec::new(),
            output_types: Vec::new(),
        }
    }
}
112
113pub struct CodeGenerator {
115 variables: HashMap<String, String>,
117}
118
119impl CodeGenerator {
120 pub fn new() -> Self {
122 Self {
123 variables: HashMap::new(),
124 }
125 }
126
127 pub fn set_variable(&mut self, key: impl Into<String>, value: impl Into<String>) {
129 self.variables.insert(key.into(), value.into());
130 }
131
132 pub fn generate_kernel_source(
134 &self,
135 kernel_id: &str,
136 user_code: &str,
137 target: Target,
138 ) -> Result<String> {
139 let template = self.get_template(target);
140 let source = self.substitute_template(template, kernel_id, user_code);
141 Ok(source)
142 }
143
144 pub fn generate_kernel_file(
146 &self,
147 config: &KernelConfig,
148 user_code: &str,
149 target: Target,
150 ) -> Result<GeneratedFile> {
151 let source = self.generate_kernel_source(&config.id, user_code, target)?;
152 Ok(GeneratedFile {
153 filename: format!("{}.{}", config.id, target.extension()),
154 content: source,
155 target,
156 })
157 }
158
159 pub fn generate_all_targets(
161 &self,
162 config: &KernelConfig,
163 user_code: &str,
164 ) -> Result<Vec<GeneratedFile>> {
165 let targets = [Target::Cuda, Target::Metal, Target::Wgsl];
166 let mut files = Vec::with_capacity(targets.len());
167
168 for target in targets {
169 files.push(self.generate_kernel_file(config, user_code, target)?);
170 }
171
172 Ok(files)
173 }
174
175 fn get_template(&self, target: Target) -> &'static str {
176 match target {
177 Target::Cuda => include_str!("templates/cuda.ptx.template"),
178 Target::Metal => include_str!("templates/metal.msl.template"),
179 Target::Wgsl => include_str!("templates/wgsl.template"),
180 }
181 }
182
183 fn substitute_template(&self, template: &str, kernel_id: &str, user_code: &str) -> String {
184 let mut result = template.to_string();
185 result = result.replace("{{KERNEL_ID}}", kernel_id);
186 result = result.replace("{{USER_CODE}}", user_code);
187 result = result.replace("// USER_KERNEL_CODE", user_code);
188
189 for (key, value) in &self.variables {
191 result = result.replace(&format!("{{{{{}}}}}", key), value);
192 }
193
194 result
195 }
196}
197
198impl Default for CodeGenerator {
199 fn default() -> Self {
200 Self::new()
201 }
202}
203
204#[derive(Debug, Clone)]
206pub struct GeneratedFile {
207 pub filename: String,
209 pub content: String,
211 pub target: Target,
213}
214
215#[derive(Debug, Clone)]
217pub struct IntrinsicMap {
218 pub rust_name: String,
220 pub cuda: String,
222 pub metal: String,
224 pub wgsl: String,
226}
227
228impl IntrinsicMap {
229 pub fn get(&self, target: Target) -> &str {
231 match target {
232 Target::Cuda => &self.cuda,
233 Target::Metal => &self.metal,
234 Target::Wgsl => &self.wgsl,
235 }
236 }
237}
238
239pub fn standard_intrinsics() -> Vec<IntrinsicMap> {
241 vec![
242 IntrinsicMap {
243 rust_name: "sync_threads".to_string(),
244 cuda: "__syncthreads()".to_string(),
245 metal: "threadgroup_barrier(mem_flags::mem_threadgroup)".to_string(),
246 wgsl: "workgroupBarrier()".to_string(),
247 },
248 IntrinsicMap {
249 rust_name: "thread_fence_block".to_string(),
250 cuda: "__threadfence_block()".to_string(),
251 metal: "threadgroup_barrier(mem_flags::mem_device)".to_string(),
252 wgsl: "storageBarrier()".to_string(),
253 },
254 IntrinsicMap {
255 rust_name: "thread_fence".to_string(),
256 cuda: "__threadfence()".to_string(),
257 metal: "threadgroup_barrier(mem_flags::mem_device)".to_string(),
258 wgsl: "storageBarrier()".to_string(),
259 },
260 IntrinsicMap {
261 rust_name: "atomic_add".to_string(),
262 cuda: "atomicAdd".to_string(),
263 metal: "atomic_fetch_add_explicit".to_string(),
264 wgsl: "atomicAdd".to_string(),
265 },
266 IntrinsicMap {
267 rust_name: "atomic_cas".to_string(),
268 cuda: "atomicCAS".to_string(),
269 metal: "atomic_compare_exchange_weak_explicit".to_string(),
270 wgsl: "atomicCompareExchangeWeak".to_string(),
271 },
272 ]
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    // `expect` instead of `unwrap` throughout: the crate enables `clippy::unwrap_used`, and the
    // message states the invariant that a failure would violate.

    #[test]
    fn test_code_generator() {
        // `gen` is a reserved keyword in Rust 2024, so use a longer name.
        let generator = CodeGenerator::new();
        let source = generator
            .generate_kernel_source("test_kernel", "// test code", Target::Cuda)
            .expect("a plain identifier must render against the built-in CUDA template");

        assert!(source.contains("test_kernel") || source.contains("ring_kernel"));
    }

    #[test]
    fn test_target_extension() {
        assert_eq!(Target::Cuda.extension(), "ptx");
        assert_eq!(Target::Metal.extension(), "metal");
        assert_eq!(Target::Wgsl.extension(), "wgsl");
    }

    #[test]
    fn test_target_name() {
        assert_eq!(Target::Cuda.name(), "CUDA");
        assert_eq!(Target::Metal.name(), "Metal");
        assert_eq!(Target::Wgsl.name(), "WebGPU");
    }

    #[test]
    fn test_intrinsic_mapping() {
        let intrinsics = standard_intrinsics();
        let sync = intrinsics
            .iter()
            .find(|i| i.rust_name == "sync_threads")
            .expect("sync_threads is part of the standard intrinsic table");

        assert_eq!(sync.get(Target::Cuda), "__syncthreads()");
        assert!(sync.get(Target::Metal).contains("barrier"));
    }

    #[test]
    fn test_standard_intrinsics_cover_all_targets() {
        for intrinsic in standard_intrinsics() {
            for target in [Target::Cuda, Target::Metal, Target::Wgsl] {
                assert!(
                    !intrinsic.get(target).is_empty(),
                    "{} has no {} mapping",
                    intrinsic.rust_name,
                    target.name()
                );
            }
        }
    }
}