1#![warn(missing_docs)]
26
27pub mod dsl_common;
28
29use std::collections::HashMap;
30use thiserror::Error;
31
32#[derive(Error, Debug)]
34pub enum CodegenError {
35 #[error("template error: {0}")]
37 TemplateError(String),
38
39 #[error("unsupported target: {0}")]
41 UnsupportedTarget(String),
42
43 #[error("invalid kernel: {0}")]
45 InvalidKernel(String),
46}
47
48pub type Result<T> = std::result::Result<T, CodegenError>;
50
/// Code generation target.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Target {
    /// CUDA; generated files use the `ptx` extension.
    Cuda,
    /// Metal; generated files use the `metal` extension.
    Metal,
    /// WGSL (displayed as "WebGPU"); generated files use the `wgsl` extension.
    Wgsl,
}

62impl Target {
63 pub fn extension(&self) -> &'static str {
65 match self {
66 Target::Cuda => "ptx",
67 Target::Metal => "metal",
68 Target::Wgsl => "wgsl",
69 }
70 }
71
72 pub fn name(&self) -> &'static str {
74 match self {
75 Target::Cuda => "CUDA",
76 Target::Metal => "Metal",
77 Target::Wgsl => "WebGPU",
78 }
79 }
80}
81
/// Configuration describing a kernel to generate.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KernelConfig {
    /// Kernel identifier; substituted into templates and used as the stem of
    /// the generated file name.
    pub id: String,
    /// Grid size (not consumed by the code in this file; units to be confirmed
    /// against the launcher).
    pub grid_size: u32,
    /// Block size (not consumed by the code in this file; presumably threads
    /// per block — confirm against the launcher).
    pub block_size: u32,
    /// Shared memory size (presumably in bytes — confirm against the launcher).
    pub shared_memory: usize,
    /// Names of the kernel's input types.
    pub input_types: Vec<String>,
    /// Names of the kernel's output types.
    pub output_types: Vec<String>,
}

99impl Default for KernelConfig {
100 fn default() -> Self {
101 Self {
102 id: "kernel".to_string(),
103 grid_size: 1,
104 block_size: 256,
105 shared_memory: 0,
106 input_types: vec![],
107 output_types: vec![],
108 }
109 }
110}
111
/// Generates kernel source code by filling in per-target templates.
#[derive(Debug, Clone)]
pub struct CodeGenerator {
    /// Custom template variables: placeholder name mapped to replacement text.
    variables: HashMap<String, String>,
}

118impl CodeGenerator {
119 pub fn new() -> Self {
121 Self {
122 variables: HashMap::new(),
123 }
124 }
125
126 pub fn set_variable(&mut self, key: impl Into<String>, value: impl Into<String>) {
128 self.variables.insert(key.into(), value.into());
129 }
130
131 pub fn generate_kernel_source(
133 &self,
134 kernel_id: &str,
135 user_code: &str,
136 target: Target,
137 ) -> Result<String> {
138 let template = self.get_template(target);
139 let source = self.substitute_template(template, kernel_id, user_code);
140 Ok(source)
141 }
142
143 pub fn generate_kernel_file(
145 &self,
146 config: &KernelConfig,
147 user_code: &str,
148 target: Target,
149 ) -> Result<GeneratedFile> {
150 let source = self.generate_kernel_source(&config.id, user_code, target)?;
151 Ok(GeneratedFile {
152 filename: format!("{}.{}", config.id, target.extension()),
153 content: source,
154 target,
155 })
156 }
157
158 pub fn generate_all_targets(
160 &self,
161 config: &KernelConfig,
162 user_code: &str,
163 ) -> Result<Vec<GeneratedFile>> {
164 let targets = [Target::Cuda, Target::Metal, Target::Wgsl];
165 let mut files = Vec::with_capacity(targets.len());
166
167 for target in targets {
168 files.push(self.generate_kernel_file(config, user_code, target)?);
169 }
170
171 Ok(files)
172 }
173
174 fn get_template(&self, target: Target) -> &'static str {
175 match target {
176 Target::Cuda => include_str!("templates/cuda.ptx.template"),
177 Target::Metal => include_str!("templates/metal.msl.template"),
178 Target::Wgsl => include_str!("templates/wgsl.template"),
179 }
180 }
181
182 fn substitute_template(&self, template: &str, kernel_id: &str, user_code: &str) -> String {
183 let mut result = template.to_string();
184 result = result.replace("{{KERNEL_ID}}", kernel_id);
185 result = result.replace("{{USER_CODE}}", user_code);
186 result = result.replace("// USER_KERNEL_CODE", user_code);
187
188 for (key, value) in &self.variables {
190 result = result.replace(&format!("{{{{{}}}}}", key), value);
191 }
192
193 result
194 }
195}
196
197impl Default for CodeGenerator {
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
203#[derive(Debug, Clone)]
205pub struct GeneratedFile {
206 pub filename: String,
208 pub content: String,
210 pub target: Target,
212}
213
/// Maps one Rust-side intrinsic name to its spelling on each target.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IntrinsicMap {
    /// Name of the intrinsic on the Rust side.
    pub rust_name: String,
    /// Spelling used for the CUDA target.
    pub cuda: String,
    /// Spelling used for the Metal target.
    pub metal: String,
    /// Spelling used for the WGSL target.
    pub wgsl: String,
}

227impl IntrinsicMap {
228 pub fn get(&self, target: Target) -> &str {
230 match target {
231 Target::Cuda => &self.cuda,
232 Target::Metal => &self.metal,
233 Target::Wgsl => &self.wgsl,
234 }
235 }
236}
237
238pub fn standard_intrinsics() -> Vec<IntrinsicMap> {
240 vec![
241 IntrinsicMap {
242 rust_name: "sync_threads".to_string(),
243 cuda: "__syncthreads()".to_string(),
244 metal: "threadgroup_barrier(mem_flags::mem_threadgroup)".to_string(),
245 wgsl: "workgroupBarrier()".to_string(),
246 },
247 IntrinsicMap {
248 rust_name: "thread_fence_block".to_string(),
249 cuda: "__threadfence_block()".to_string(),
250 metal: "threadgroup_barrier(mem_flags::mem_device)".to_string(),
251 wgsl: "storageBarrier()".to_string(),
252 },
253 IntrinsicMap {
254 rust_name: "thread_fence".to_string(),
255 cuda: "__threadfence()".to_string(),
256 metal: "threadgroup_barrier(mem_flags::mem_device)".to_string(),
257 wgsl: "storageBarrier()".to_string(),
258 },
259 IntrinsicMap {
260 rust_name: "atomic_add".to_string(),
261 cuda: "atomicAdd".to_string(),
262 metal: "atomic_fetch_add_explicit".to_string(),
263 wgsl: "atomicAdd".to_string(),
264 },
265 IntrinsicMap {
266 rust_name: "atomic_cas".to_string(),
267 cuda: "atomicCAS".to_string(),
268 metal: "atomic_compare_exchange_weak_explicit".to_string(),
269 wgsl: "atomicCompareExchangeWeak".to_string(),
270 },
271 ]
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Generating CUDA source succeeds and the output names a kernel.
    #[test]
    fn test_code_generator() {
        // Not named `gen`: that identifier is a reserved keyword in the 2024 edition.
        let generator = CodeGenerator::new();
        let source = generator
            .generate_kernel_source("test_kernel", "// test code", Target::Cuda)
            .unwrap();

        // Accept either the substituted id or a literal `ring_kernel`
        // (presumably a name fixed in the template — confirm against the template).
        assert!(source.contains("test_kernel") || source.contains("ring_kernel"));
    }

    /// Each target reports its expected file extension.
    #[test]
    fn test_target_extension() {
        assert_eq!(Target::Cuda.extension(), "ptx");
        assert_eq!(Target::Metal.extension(), "metal");
        assert_eq!(Target::Wgsl.extension(), "wgsl");
    }

    /// The standard table maps `sync_threads` to a barrier on CUDA and Metal.
    #[test]
    fn test_intrinsic_mapping() {
        let intrinsics = standard_intrinsics();
        let sync = intrinsics
            .iter()
            .find(|i| i.rust_name == "sync_threads")
            .unwrap();

        assert_eq!(sync.get(Target::Cuda), "__syncthreads()");
        assert!(sync.get(Target::Metal).contains("barrier"));
    }
}