Skip to main content

oxigdal_gpu_advanced/shader_compiler/
mod.rs

1//! WGSL shader compiler and optimizer.
2
3pub mod analyzer;
4pub mod cache;
5pub mod optimizer;
6
7use crate::error::{GpuAdvancedError, Result};
8use blake3::Hash;
9use naga::{Module, valid::Validator};
10use parking_lot::RwLock;
11use std::collections::HashMap;
12use std::sync::Arc;
13
14/// Shader compilation result
15pub struct CompiledShader {
16    /// Original source code
17    pub source: String,
18    /// Naga module
19    pub module: Module,
20    /// Entry points
21    pub entry_points: Vec<String>,
22    /// Compilation hash
23    pub hash: Hash,
24    /// Whether shader was optimized
25    pub optimized: bool,
26}
27
28/// Shader compiler with optimization and caching
29pub struct ShaderCompiler {
30    /// Shader cache
31    cache: Arc<cache::ShaderCache>,
32    /// Optimizer
33    optimizer: Arc<optimizer::ShaderOptimizer>,
34    /// Compilation statistics
35    stats: Arc<RwLock<CompilerStats>>,
36}
37
/// Counters describing the work done by a shader compiler.
///
/// Snapshots are plain values, so they can be cloned, compared and printed.
/// Nothing in this module resets the counters once they are incremented.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct CompilerStats {
    /// Number of `compile` calls, including those made by `compile_optimized`
    pub total_compilations: u64,
    /// Compilations answered straight from the shader cache
    pub cache_hits: u64,
    /// Compilations that had to parse and validate the source
    pub cache_misses: u64,
    /// Successful optimizer runs performed by `compile_optimized`
    pub optimizations: u64,
    /// Compilations rejected by naga's validator (WGSL parse errors are not counted)
    pub validation_failures: u64,
}

impl CompilerStats {
    /// Percentage (0.0 to 100.0) of compilations that were served from the cache.
    ///
    /// Returns `0.0` when nothing has been compiled yet instead of dividing by zero.
    pub fn cache_hit_percentage(&self) -> f64 {
        if self.total_compilations == 0 {
            return 0.0;
        }
        self.cache_hits as f64 / self.total_compilations as f64 * 100.0
    }
}
52
53impl ShaderCompiler {
54    /// Create a new shader compiler
55    pub fn new() -> Self {
56        Self {
57            cache: Arc::new(cache::ShaderCache::new(1000)),
58            optimizer: Arc::new(optimizer::ShaderOptimizer::new()),
59            stats: Arc::new(RwLock::new(CompilerStats::default())),
60        }
61    }
62
63    /// Compile WGSL source code
64    pub fn compile(&self, source: &str) -> Result<CompiledShader> {
65        // Update stats
66        {
67            let mut stats = self.stats.write();
68            stats.total_compilations += 1;
69        }
70
71        // Calculate hash
72        let hash = blake3::hash(source.as_bytes());
73
74        // Check cache
75        if let Some(cached) = self.cache.get(&hash) {
76            let mut stats = self.stats.write();
77            stats.cache_hits += 1;
78            return Ok(cached);
79        }
80
81        // Cache miss
82        {
83            let mut stats = self.stats.write();
84            stats.cache_misses += 1;
85        }
86
87        // Parse WGSL
88        let module = naga::front::wgsl::parse_str(source).map_err(|e| {
89            GpuAdvancedError::ShaderCompilerError(format!("WGSL parse error: {:?}", e))
90        })?;
91
92        // Validate
93        let mut validator = Validator::new(
94            naga::valid::ValidationFlags::all(),
95            naga::valid::Capabilities::all(),
96        );
97
98        let _module_info = validator.validate(&module).map_err(|e| {
99            let mut stats = self.stats.write();
100            stats.validation_failures += 1;
101            GpuAdvancedError::ShaderValidationError(format!("Validation error: {:?}", e))
102        })?;
103
104        // Extract entry points
105        let entry_points: Vec<String> = module
106            .entry_points
107            .iter()
108            .map(|ep| ep.name.clone())
109            .collect();
110
111        let compiled = CompiledShader {
112            source: source.to_string(),
113            module,
114            entry_points,
115            hash,
116            optimized: false,
117        };
118
119        // Cache the result
120        self.cache.insert(hash, compiled.clone());
121
122        Ok(compiled)
123    }
124
125    /// Compile and optimize
126    pub fn compile_optimized(&self, source: &str) -> Result<CompiledShader> {
127        let mut compiled = self.compile(source)?;
128
129        // Apply optimizations
130        compiled.module = self.optimizer.optimize(&compiled.module)?;
131        compiled.optimized = true;
132
133        // Update stats
134        {
135            let mut stats = self.stats.write();
136            stats.optimizations += 1;
137        }
138
139        Ok(compiled)
140    }
141
142    /// Validate shader without compilation
143    pub fn validate(&self, source: &str) -> Result<()> {
144        let module = naga::front::wgsl::parse_str(source).map_err(|e| {
145            GpuAdvancedError::ShaderCompilerError(format!("WGSL parse error: {:?}", e))
146        })?;
147
148        let mut validator = Validator::new(
149            naga::valid::ValidationFlags::all(),
150            naga::valid::Capabilities::all(),
151        );
152
153        validator.validate(&module).map_err(|e| {
154            GpuAdvancedError::ShaderValidationError(format!("Validation error: {:?}", e))
155        })?;
156
157        Ok(())
158    }
159
160    /// Get compiler statistics
161    pub fn get_stats(&self) -> CompilerStats {
162        self.stats.read().clone()
163    }
164
165    /// Print compiler statistics
166    pub fn print_stats(&self) {
167        let stats = self.stats.read();
168        println!("\nShader Compiler Statistics:");
169        println!("  Total compilations: {}", stats.total_compilations);
170        println!(
171            "  Cache hits: {} ({:.1}%)",
172            stats.cache_hits,
173            if stats.total_compilations > 0 {
174                (stats.cache_hits as f64 / stats.total_compilations as f64) * 100.0
175            } else {
176                0.0
177            }
178        );
179        println!("  Cache misses: {}", stats.cache_misses);
180        println!("  Optimizations: {}", stats.optimizations);
181        println!("  Validation failures: {}", stats.validation_failures);
182    }
183
184    /// Clear cache
185    pub fn clear_cache(&self) {
186        self.cache.clear();
187    }
188
189    /// Get cache
190    pub fn cache(&self) -> Arc<cache::ShaderCache> {
191        self.cache.clone()
192    }
193
194    /// Get optimizer
195    pub fn optimizer(&self) -> Arc<optimizer::ShaderOptimizer> {
196        self.optimizer.clone()
197    }
198}
199
200impl Default for ShaderCompiler {
201    fn default() -> Self {
202        Self::new()
203    }
204}
205
206impl Clone for CompiledShader {
207    fn clone(&self) -> Self {
208        Self {
209            source: self.source.clone(),
210            module: self.module.clone(),
211            entry_points: self.entry_points.clone(),
212            hash: self.hash,
213            optimized: self.optimized,
214        }
215    }
216}
217
/// Shader preprocessor that expands `$NAME` macro references.
pub struct ShaderPreprocessor {
    /// Macro name (without the leading `$`) mapped to its replacement text
    defines: HashMap<String, String>,
}

impl ShaderPreprocessor {
    /// Create a preprocessor with no macros defined.
    pub fn new() -> Self {
        Self {
            defines: HashMap::new(),
        }
    }

    /// Define a macro, replacing any previous value for the same name.
    pub fn define(&mut self, name: impl Into<String>, value: impl Into<String>) {
        self.defines.insert(name.into(), value.into());
    }

    /// Remove a macro definition; does nothing if `name` is not defined.
    pub fn undefine(&mut self, name: &str) {
        self.defines.remove(name);
    }

    /// Expand `$NAME` references in `source` using the defined macros.
    ///
    /// The source is scanned left to right in a single pass:
    ///
    /// - At every `$`, the longest defined name that immediately follows it is
    ///   substituted. With both `SIZE` and `SIZE_X` defined, `$SIZE_X` expands
    ///   `SIZE_X`, not `SIZE` followed by `_X`.
    /// - A `$` that is not followed by any defined name is kept verbatim.
    /// - Substituted values are not rescanned, so a value containing `$OTHER`
    ///   is emitted literally.
    ///
    /// The output is deterministic. Running a per-macro `str::replace` over the
    /// `HashMap` instead would depend on its unspecified iteration order.
    pub fn preprocess(&self, source: &str) -> String {
        let mut result = String::with_capacity(source.len());
        let mut rest = source;

        while let Some(pos) = rest.find('$') {
            result.push_str(&rest[..pos]);
            // `$` is a single byte, so `pos + 1` is always a char boundary.
            let after = &rest[pos + 1..];

            // Two distinct names of equal length cannot both be prefixes of
            // `after`, so the longest match is unique.
            let longest = self
                .defines
                .iter()
                .filter(|(name, _)| after.starts_with(name.as_str()))
                .max_by_key(|(name, _)| name.len());

            match longest {
                Some((name, value)) => {
                    result.push_str(value);
                    rest = &after[name.len()..];
                }
                None => {
                    result.push('$');
                    rest = after;
                }
            }
        }

        result.push_str(rest);
        result
    }
}

impl Default for ShaderPreprocessor {
    fn default() -> Self {
        Self::new()
    }
}
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shader_preprocessor() {
        let mut preprocessor = ShaderPreprocessor::new();
        preprocessor.define("WORKGROUP_SIZE", "64");

        let source = "@compute @workgroup_size($WORKGROUP_SIZE, 1, 1)\nfn main() {}";
        let result = preprocessor.preprocess(source);

        // Pin the full expansion: `contains("64")` would also accept a
        // partially expanded or otherwise corrupted output.
        assert_eq!(result, "@compute @workgroup_size(64, 1, 1)\nfn main() {}");
    }

    #[test]
    fn test_compiler_creation() {
        let compiler = ShaderCompiler::new();
        let stats = compiler.get_stats();
        assert_eq!(stats.total_compilations, 0);
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 0);
        assert_eq!(stats.optimizations, 0);
        assert_eq!(stats.validation_failures, 0);
    }

    #[test]
    fn test_simple_shader_compilation() {
        let compiler = ShaderCompiler::new();
        let source = r#"
@compute @workgroup_size(1, 1, 1)
fn main() {
    // Empty compute shader
}
        "#;

        let compiled = match compiler.compile(source) {
            Ok(compiled) => compiled,
            Err(_) => panic!("trivial compute shader should compile"),
        };

        assert_eq!(compiled.entry_points, vec!["main".to_string()]);
        assert_eq!(compiled.source, source);
        assert!(!compiled.optimized);
        assert_eq!(compiler.get_stats().total_compilations, 1);
    }

    #[test]
    fn test_invalid_shader_is_rejected() {
        let compiler = ShaderCompiler::new();
        let broken = "fn main( {";
        assert!(compiler.validate(broken).is_err());
        assert!(compiler.compile(broken).is_err());
    }
}