Skip to main content

oximedia_gpu/
compiler.rs

1//! Runtime shader compilation and management
2//!
3//! This module provides utilities for compiling shaders at runtime,
4//! including error handling, preprocessor support, and optimization.
5
6use crate::{GpuDevice, GpuError, Result};
7use std::collections::HashMap;
8use wgpu::ShaderModule;
9
/// A shader compilation error with an optional source location.
///
/// `line` and `column` are stored exactly as provided by the caller; this type
/// does not decide whether they are 0- or 1-based.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompilationError {
    /// Human-readable error message.
    pub message: String,
    /// Line number where the error occurred, if known.
    pub line: Option<usize>,
    /// Column number where the error occurred, if known.
    pub column: Option<usize>,
}

impl CompilationError {
    /// Create a new compilation error with no location information.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
            line: None,
            column: None,
        }
    }

    /// Set the line number.
    #[must_use]
    pub fn with_line(mut self, line: usize) -> Self {
        self.line = Some(line);
        self
    }

    /// Set the column number.
    #[must_use]
    pub fn with_column(mut self, column: usize) -> Self {
        self.column = Some(column);
        self
    }
}

impl std::fmt::Display for CompilationError {
    /// Formats as `line:column: message` when both are known, `Line N: message`
    /// when only the line is known, and just `message` otherwise (a column
    /// without a line is not shown).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match (self.line, self.column) {
            (Some(line), Some(column)) => write!(f, "{}:{}: {}", line, column, self.message),
            (Some(line), None) => write!(f, "Line {}: {}", line, self.message),
            (None, _) => write!(f, "{}", self.message),
        }
    }
}

// Lets `CompilationError` be boxed as `dyn Error` and propagated with `?`.
impl std::error::Error for CompilationError {}
57
/// Kind of source a shader is supplied as.
///
/// The fully upper-case variant names are kept for API compatibility.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShaderSourceType {
    /// WGSL source code.
    WGSL,
    /// SPIR-V binary.
    SPIRV,
    /// GLSL source code (requires translation to WGSL).
    GLSL,
}
68
69/// Shader preprocessor for handling defines and includes
70pub struct ShaderPreprocessor {
71    defines: HashMap<String, String>,
72}
73
74impl ShaderPreprocessor {
75    /// Create a new shader preprocessor
76    #[must_use]
77    pub fn new() -> Self {
78        Self {
79            defines: HashMap::new(),
80        }
81    }
82
83    /// Add a preprocessor define
84    pub fn define(&mut self, name: impl Into<String>, value: impl Into<String>) {
85        self.defines.insert(name.into(), value.into());
86    }
87
88    /// Process shader source with preprocessor directives
89    pub fn process(&self, source: &str) -> Result<String> {
90        let mut output = String::new();
91        let mut lines = source.lines();
92
93        while let Some(line) = lines.next() {
94            let trimmed = line.trim();
95
96            // Handle #define directives
97            if trimmed.starts_with("#define") {
98                let parts: Vec<&str> = trimmed.split_whitespace().collect();
99                if parts.len() >= 2 {
100                    let name = parts[1];
101                    let _value = parts.get(2).unwrap_or(&"1");
102                    if let Some(defined_value) = self.defines.get(name) {
103                        output.push_str(&format!("#define {name} {defined_value}\n"));
104                    } else {
105                        output.push_str(line);
106                        output.push('\n');
107                    }
108                } else {
109                    output.push_str(line);
110                    output.push('\n');
111                }
112            }
113            // Handle #ifdef directives
114            else if trimmed.starts_with("#ifdef") {
115                let parts: Vec<&str> = trimmed.split_whitespace().collect();
116                if parts.len() >= 2 {
117                    let name = parts[1];
118                    if self.defines.contains_key(name) {
119                        // Include this block
120                        continue;
121                    }
122                    // Skip until #endif
123                    for inner_line in lines.by_ref() {
124                        if inner_line.trim().starts_with("#endif") {
125                            break;
126                        }
127                    }
128                    continue;
129                }
130                output.push_str(line);
131                output.push('\n');
132            }
133            // Pass through other lines
134            else {
135                output.push_str(line);
136                output.push('\n');
137            }
138        }
139
140        Ok(output)
141    }
142
143    /// Get all defines
144    #[must_use]
145    pub fn defines(&self) -> &HashMap<String, String> {
146        &self.defines
147    }
148}
149
150impl Default for ShaderPreprocessor {
151    fn default() -> Self {
152        Self::new()
153    }
154}
155
156/// Shader compiler with caching and optimization
157pub struct ShaderCompiler {
158    device: std::sync::Arc<wgpu::Device>,
159    preprocessor: ShaderPreprocessor,
160}
161
162impl ShaderCompiler {
163    /// Create a new shader compiler
164    #[must_use]
165    pub fn new(device: &GpuDevice) -> Self {
166        Self {
167            device: std::sync::Arc::clone(device.device()),
168            preprocessor: ShaderPreprocessor::new(),
169        }
170    }
171
172    /// Compile WGSL shader source
173    ///
174    /// # Arguments
175    ///
176    /// * `label` - Shader label for debugging
177    /// * `source` - WGSL source code
178    ///
179    /// # Errors
180    ///
181    /// Returns an error if compilation fails.
182    pub fn compile_wgsl(&self, label: &str, source: &str) -> Result<ShaderModule> {
183        // Process with preprocessor
184        let processed_source = self.preprocessor.process(source)?;
185
186        // Compile the shader
187        Ok(self
188            .device
189            .create_shader_module(wgpu::ShaderModuleDescriptor {
190                label: Some(label),
191                source: wgpu::ShaderSource::Wgsl(processed_source.into()),
192            }))
193    }
194
195    /// Compile shader from file
196    ///
197    /// # Arguments
198    ///
199    /// * `label` - Shader label for debugging
200    /// * `path` - Path to shader file
201    ///
202    /// # Errors
203    ///
204    /// Returns an error if file reading or compilation fails.
205    pub fn compile_file(
206        &self,
207        label: &str,
208        path: impl AsRef<std::path::Path>,
209    ) -> Result<ShaderModule> {
210        let source = std::fs::read_to_string(path.as_ref())
211            .map_err(|e| GpuError::ShaderCompilation(format!("Failed to read shader file: {e}")))?;
212
213        self.compile_wgsl(label, &source)
214    }
215
216    /// Get the preprocessor
217    #[must_use]
218    pub fn preprocessor(&self) -> &ShaderPreprocessor {
219        &self.preprocessor
220    }
221
222    /// Get a mutable reference to the preprocessor
223    pub fn preprocessor_mut(&mut self) -> &mut ShaderPreprocessor {
224        &mut self.preprocessor
225    }
226
227    /// Validate shader source without compiling
228    ///
229    /// # Arguments
230    ///
231    /// * `source` - Shader source code
232    ///
233    /// # Returns
234    ///
235    /// Ok(()) if the shader is valid, Err otherwise
236    pub fn validate(&self, source: &str) -> Result<()> {
237        // wgpu performs validation during shader module creation
238        // We can create a temporary shader module to validate
239        let _ = self
240            .device
241            .create_shader_module(wgpu::ShaderModuleDescriptor {
242                label: Some("Validation"),
243                source: wgpu::ShaderSource::Wgsl(source.into()),
244            });
245
246        Ok(())
247    }
248}
249
/// Requested shader optimisation level.
///
/// Defaults to [`OptimizationLevel::Basic`], the same level that
/// [`CompilationOptions::default`] uses.
// NOTE(review): nothing in this module currently reads this value when
// compiling; confirm whether other code consumes it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum OptimizationLevel {
    /// No optimization.
    None,
    /// Basic optimization.
    #[default]
    Basic,
    /// Full optimization.
    Full,
}

/// Shader compilation options, built with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct CompilationOptions {
    /// Optimization level.
    pub optimization: OptimizationLevel,
    /// Enable debug information.
    pub debug_info: bool,
    /// Preprocessor defines (name -> value).
    pub defines: HashMap<String, String>,
}

impl Default for CompilationOptions {
    /// Basic optimization, no debug information, no defines.
    fn default() -> Self {
        Self {
            optimization: OptimizationLevel::Basic,
            debug_info: false,
            defines: HashMap::new(),
        }
    }
}

impl CompilationOptions {
    /// Create new compilation options with default settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the optimization level.
    #[must_use]
    pub fn with_optimization(mut self, level: OptimizationLevel) -> Self {
        self.optimization = level;
        self
    }

    /// Enable or disable debug information.
    #[must_use]
    pub fn with_debug_info(mut self, enabled: bool) -> Self {
        self.debug_info = enabled;
        self
    }

    /// Add (or replace) a preprocessor define.
    #[must_use]
    pub fn with_define(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        self.defines.insert(name.into(), value.into());
        self
    }
}
308
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_preprocessor() {
        let mut preprocessor = ShaderPreprocessor::new();
        preprocessor.define("WORKGROUP_SIZE", "256");

        let source = "#define WORKGROUP_SIZE 64\nfn main() {}";
        let processed = preprocessor.process(source).unwrap();

        assert!(processed.contains("256"));
    }

    #[test]
    fn test_preprocessor_passes_through_unregistered_define() {
        let processed = ShaderPreprocessor::new()
            .process("#define OTHER 3\nfn main() {}")
            .unwrap();
        assert_eq!(processed, "#define OTHER 3\nfn main() {}\n");
    }

    #[test]
    fn test_preprocessor_empty_source() {
        assert_eq!(ShaderPreprocessor::new().process("").unwrap(), "");
    }

    #[test]
    fn test_preprocessor_skips_undefined_ifdef_block() {
        let source = "#ifdef FEATURE\nfeature_code\n#endif\nfn main() {}";
        let processed = ShaderPreprocessor::new().process(source).unwrap();
        assert_eq!(processed, "fn main() {}\n");
    }

    #[test]
    fn test_preprocessor_keeps_defined_ifdef_block() {
        let mut preprocessor = ShaderPreprocessor::new();
        preprocessor.define("FEATURE", "1");

        let source = "#ifdef FEATURE\nfeature_code\n#endif\nfn main() {}";
        let processed = preprocessor.process(source).unwrap();

        assert!(processed.contains("feature_code"));
        assert!(processed.contains("fn main() {}"));
        assert!(!processed.contains("#ifdef"));
    }

    #[test]
    fn test_compilation_error() {
        let error = CompilationError::new("Syntax error")
            .with_line(42)
            .with_column(10);

        assert_eq!(error.line, Some(42));
        assert_eq!(error.column, Some(10));

        let formatted = format!("{error}");
        assert!(formatted.contains("42"));
        assert!(formatted.contains("10"));
    }

    #[test]
    fn test_compilation_error_display_variants() {
        assert_eq!(CompilationError::new("oops").with_line(7).to_string(), "Line 7: oops");
        assert_eq!(CompilationError::new("oops").to_string(), "oops");
    }

    #[test]
    fn test_compilation_options() {
        let options = CompilationOptions::new()
            .with_optimization(OptimizationLevel::Full)
            .with_debug_info(true)
            .with_define("TEST", "1");

        assert_eq!(options.optimization, OptimizationLevel::Full);
        assert!(options.debug_info);
        assert_eq!(options.defines.get("TEST"), Some(&"1".to_string()));
    }
}