1use crate::{GpuDevice, GpuError, Result};
7use std::collections::HashMap;
8use wgpu::ShaderModule;
9
/// A shader compilation diagnostic with an optional source location.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompilationError {
    /// Human-readable description of the problem.
    pub message: String,
    /// Line the diagnostic refers to, if known.
    pub line: Option<usize>,
    /// Column the diagnostic refers to, if known.
    pub column: Option<usize>,
}

impl CompilationError {
    /// Creates an error carrying `message` and no location information.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
            line: None,
            column: None,
        }
    }

    /// Returns this error with its line set to `line`.
    #[must_use]
    pub fn with_line(self, line: usize) -> Self {
        Self {
            line: Some(line),
            ..self
        }
    }

    /// Returns this error with its column set to `column`.
    #[must_use]
    pub fn with_column(self, column: usize) -> Self {
        Self {
            column: Some(column),
            ..self
        }
    }
}

impl std::fmt::Display for CompilationError {
    /// Renders `line:column: message` when both coordinates are known,
    /// `Line N: message` when only the line is known, and the bare message
    /// otherwise (a column without a line is not shown).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match (self.line, self.column) {
            (Some(line), Some(column)) => write!(f, "{line}:{column}: {}", self.message),
            (Some(line), None) => write!(f, "Line {line}: {}", self.message),
            _ => f.write_str(&self.message),
        }
    }
}

// Lets the diagnostic be boxed as `dyn Error`, used with `?` into error-erasing
// results, and chained as a `source()` by other error types.
impl std::error::Error for CompilationError {}
57
/// Source language of a shader.
///
/// NOTE(review): `ShaderCompiler` in this module only handles WGSL; this enum is
/// not consumed here.
// The all-caps variant names are public API; renaming them would break callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderSourceType {
    /// WebGPU Shading Language source.
    WGSL,
    /// SPIR-V binary.
    SPIRV,
    /// OpenGL Shading Language source.
    GLSL,
}
68
69pub struct ShaderPreprocessor {
71 defines: HashMap<String, String>,
72}
73
74impl ShaderPreprocessor {
75 #[must_use]
77 pub fn new() -> Self {
78 Self {
79 defines: HashMap::new(),
80 }
81 }
82
83 pub fn define(&mut self, name: impl Into<String>, value: impl Into<String>) {
85 self.defines.insert(name.into(), value.into());
86 }
87
88 pub fn process(&self, source: &str) -> Result<String> {
90 let mut output = String::new();
91 let mut lines = source.lines();
92
93 while let Some(line) = lines.next() {
94 let trimmed = line.trim();
95
96 if trimmed.starts_with("#define") {
98 let parts: Vec<&str> = trimmed.split_whitespace().collect();
99 if parts.len() >= 2 {
100 let name = parts[1];
101 let _value = parts.get(2).unwrap_or(&"1");
102 if let Some(defined_value) = self.defines.get(name) {
103 output.push_str(&format!("#define {name} {defined_value}\n"));
104 } else {
105 output.push_str(line);
106 output.push('\n');
107 }
108 } else {
109 output.push_str(line);
110 output.push('\n');
111 }
112 }
113 else if trimmed.starts_with("#ifdef") {
115 let parts: Vec<&str> = trimmed.split_whitespace().collect();
116 if parts.len() >= 2 {
117 let name = parts[1];
118 if self.defines.contains_key(name) {
119 continue;
121 }
122 for inner_line in lines.by_ref() {
124 if inner_line.trim().starts_with("#endif") {
125 break;
126 }
127 }
128 continue;
129 }
130 output.push_str(line);
131 output.push('\n');
132 }
133 else {
135 output.push_str(line);
136 output.push('\n');
137 }
138 }
139
140 Ok(output)
141 }
142
143 #[must_use]
145 pub fn defines(&self) -> &HashMap<String, String> {
146 &self.defines
147 }
148}
149
150impl Default for ShaderPreprocessor {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
156pub struct ShaderCompiler {
158 device: std::sync::Arc<wgpu::Device>,
159 preprocessor: ShaderPreprocessor,
160}
161
162impl ShaderCompiler {
163 #[must_use]
165 pub fn new(device: &GpuDevice) -> Self {
166 Self {
167 device: std::sync::Arc::clone(device.device()),
168 preprocessor: ShaderPreprocessor::new(),
169 }
170 }
171
172 pub fn compile_wgsl(&self, label: &str, source: &str) -> Result<ShaderModule> {
183 let processed_source = self.preprocessor.process(source)?;
185
186 Ok(self
188 .device
189 .create_shader_module(wgpu::ShaderModuleDescriptor {
190 label: Some(label),
191 source: wgpu::ShaderSource::Wgsl(processed_source.into()),
192 }))
193 }
194
195 pub fn compile_file(
206 &self,
207 label: &str,
208 path: impl AsRef<std::path::Path>,
209 ) -> Result<ShaderModule> {
210 let source = std::fs::read_to_string(path.as_ref())
211 .map_err(|e| GpuError::ShaderCompilation(format!("Failed to read shader file: {e}")))?;
212
213 self.compile_wgsl(label, &source)
214 }
215
216 #[must_use]
218 pub fn preprocessor(&self) -> &ShaderPreprocessor {
219 &self.preprocessor
220 }
221
222 pub fn preprocessor_mut(&mut self) -> &mut ShaderPreprocessor {
224 &mut self.preprocessor
225 }
226
227 pub fn validate(&self, source: &str) -> Result<()> {
237 let _ = self
240 .device
241 .create_shader_module(wgpu::ShaderModuleDescriptor {
242 label: Some("Validation"),
243 source: wgpu::ShaderSource::Wgsl(source.into()),
244 });
245
246 Ok(())
247 }
248}
249
/// How aggressively a shader should be optimized.
///
/// `CompilationOptions::default()` selects [`OptimizationLevel::Basic`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OptimizationLevel {
    /// No optimization requested.
    None,
    /// A moderate level of optimization.
    Basic,
    /// The most aggressive optimization level.
    Full,
}
260
261pub struct CompilationOptions {
263 pub optimization: OptimizationLevel,
265 pub debug_info: bool,
267 pub defines: HashMap<String, String>,
269}
270
271impl Default for CompilationOptions {
272 fn default() -> Self {
273 Self {
274 optimization: OptimizationLevel::Basic,
275 debug_info: false,
276 defines: HashMap::new(),
277 }
278 }
279}
280
281impl CompilationOptions {
282 #[must_use]
284 pub fn new() -> Self {
285 Self::default()
286 }
287
288 #[must_use]
290 pub fn with_optimization(mut self, level: OptimizationLevel) -> Self {
291 self.optimization = level;
292 self
293 }
294
295 #[must_use]
297 pub fn with_debug_info(mut self, enabled: bool) -> Self {
298 self.debug_info = enabled;
299 self
300 }
301
302 pub fn with_define(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
304 self.defines.insert(name.into(), value.into());
305 self
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// A define registered on the preprocessor replaces the value written in the source.
    #[test]
    fn test_preprocessor() {
        let mut pp = ShaderPreprocessor::new();
        pp.define("WORKGROUP_SIZE", "256");

        let result = pp.process("#define WORKGROUP_SIZE 64\nfn main() {}");
        let processed = result.expect("preprocessing should succeed");

        assert!(processed.contains("256"));
    }

    /// Builder methods record the location, and `Display` includes both coordinates.
    #[test]
    fn test_compilation_error() {
        let err = CompilationError::new("Syntax error")
            .with_line(42)
            .with_column(10);

        assert_eq!((err.line, err.column), (Some(42), Some(10)));

        let rendered = err.to_string();
        assert!(rendered.contains("42"));
        assert!(rendered.contains("10"));
    }

    /// Each builder method sets its own field independently of call order.
    #[test]
    fn test_compilation_options() {
        let opts = CompilationOptions::new()
            .with_define("TEST", "1")
            .with_debug_info(true)
            .with_optimization(OptimizationLevel::Full);

        assert_eq!(opts.optimization, OptimizationLevel::Full);
        assert!(opts.debug_info);
        assert_eq!(opts.defines.get("TEST").map(String::as_str), Some("1"));
    }
}