oxigdal_gpu_advanced/shader_compiler/
mod.rs1pub mod analyzer;
4pub mod cache;
5pub mod optimizer;
6
7use crate::error::{GpuAdvancedError, Result};
8use blake3::Hash;
9use naga::{Module, valid::Validator};
10use parking_lot::RwLock;
11use std::collections::HashMap;
12use std::sync::Arc;
13
14pub struct CompiledShader {
16 pub source: String,
18 pub module: Module,
20 pub entry_points: Vec<String>,
22 pub hash: Hash,
24 pub optimized: bool,
26}
27
28pub struct ShaderCompiler {
30 cache: Arc<cache::ShaderCache>,
32 optimizer: Arc<optimizer::ShaderOptimizer>,
34 stats: Arc<RwLock<CompilerStats>>,
36}
37
/// Running counters describing the work a shader compiler has performed.
///
/// All counters start at zero (see the derived `Default`).
#[derive(Clone, Debug, Default)]
pub struct CompilerStats {
    /// Number of compile requests received, whether or not they succeeded.
    pub total_compilations: u64,
    /// Compile requests answered from the cache.
    pub cache_hits: u64,
    /// Compile requests that were not found in the cache.
    pub cache_misses: u64,
    /// Number of successful optimization passes.
    pub optimizations: u64,
    /// Compile requests whose module was rejected by the validator.
    pub validation_failures: u64,
}
52
53impl ShaderCompiler {
54 pub fn new() -> Self {
56 Self {
57 cache: Arc::new(cache::ShaderCache::new(1000)),
58 optimizer: Arc::new(optimizer::ShaderOptimizer::new()),
59 stats: Arc::new(RwLock::new(CompilerStats::default())),
60 }
61 }
62
63 pub fn compile(&self, source: &str) -> Result<CompiledShader> {
65 {
67 let mut stats = self.stats.write();
68 stats.total_compilations += 1;
69 }
70
71 let hash = blake3::hash(source.as_bytes());
73
74 if let Some(cached) = self.cache.get(&hash) {
76 let mut stats = self.stats.write();
77 stats.cache_hits += 1;
78 return Ok(cached);
79 }
80
81 {
83 let mut stats = self.stats.write();
84 stats.cache_misses += 1;
85 }
86
87 let module = naga::front::wgsl::parse_str(source).map_err(|e| {
89 GpuAdvancedError::ShaderCompilerError(format!("WGSL parse error: {:?}", e))
90 })?;
91
92 let mut validator = Validator::new(
94 naga::valid::ValidationFlags::all(),
95 naga::valid::Capabilities::all(),
96 );
97
98 let _module_info = validator.validate(&module).map_err(|e| {
99 let mut stats = self.stats.write();
100 stats.validation_failures += 1;
101 GpuAdvancedError::ShaderValidationError(format!("Validation error: {:?}", e))
102 })?;
103
104 let entry_points: Vec<String> = module
106 .entry_points
107 .iter()
108 .map(|ep| ep.name.clone())
109 .collect();
110
111 let compiled = CompiledShader {
112 source: source.to_string(),
113 module,
114 entry_points,
115 hash,
116 optimized: false,
117 };
118
119 self.cache.insert(hash, compiled.clone());
121
122 Ok(compiled)
123 }
124
125 pub fn compile_optimized(&self, source: &str) -> Result<CompiledShader> {
127 let mut compiled = self.compile(source)?;
128
129 compiled.module = self.optimizer.optimize(&compiled.module)?;
131 compiled.optimized = true;
132
133 {
135 let mut stats = self.stats.write();
136 stats.optimizations += 1;
137 }
138
139 Ok(compiled)
140 }
141
142 pub fn validate(&self, source: &str) -> Result<()> {
144 let module = naga::front::wgsl::parse_str(source).map_err(|e| {
145 GpuAdvancedError::ShaderCompilerError(format!("WGSL parse error: {:?}", e))
146 })?;
147
148 let mut validator = Validator::new(
149 naga::valid::ValidationFlags::all(),
150 naga::valid::Capabilities::all(),
151 );
152
153 validator.validate(&module).map_err(|e| {
154 GpuAdvancedError::ShaderValidationError(format!("Validation error: {:?}", e))
155 })?;
156
157 Ok(())
158 }
159
160 pub fn get_stats(&self) -> CompilerStats {
162 self.stats.read().clone()
163 }
164
165 pub fn print_stats(&self) {
167 let stats = self.stats.read();
168 println!("\nShader Compiler Statistics:");
169 println!(" Total compilations: {}", stats.total_compilations);
170 println!(
171 " Cache hits: {} ({:.1}%)",
172 stats.cache_hits,
173 if stats.total_compilations > 0 {
174 (stats.cache_hits as f64 / stats.total_compilations as f64) * 100.0
175 } else {
176 0.0
177 }
178 );
179 println!(" Cache misses: {}", stats.cache_misses);
180 println!(" Optimizations: {}", stats.optimizations);
181 println!(" Validation failures: {}", stats.validation_failures);
182 }
183
184 pub fn clear_cache(&self) {
186 self.cache.clear();
187 }
188
189 pub fn cache(&self) -> Arc<cache::ShaderCache> {
191 self.cache.clone()
192 }
193
194 pub fn optimizer(&self) -> Arc<optimizer::ShaderOptimizer> {
196 self.optimizer.clone()
197 }
198}
199
200impl Default for ShaderCompiler {
201 fn default() -> Self {
202 Self::new()
203 }
204}
205
206impl Clone for CompiledShader {
207 fn clone(&self) -> Self {
208 Self {
209 source: self.source.clone(),
210 module: self.module.clone(),
211 entry_points: self.entry_points.clone(),
212 hash: self.hash,
213 optimized: self.optimized,
214 }
215 }
216}
217
/// Textual preprocessor that substitutes `$NAME` placeholders in shader source.
pub struct ShaderPreprocessor {
    /// Placeholder name (without the leading `$`) mapped to its replacement text.
    defines: HashMap<String, String>,
}

impl ShaderPreprocessor {
    /// Creates a preprocessor with no definitions.
    pub fn new() -> Self {
        Self {
            defines: HashMap::new(),
        }
    }

    /// Registers `value` as the replacement for `$name`, overwriting any
    /// earlier definition of the same name.
    pub fn define(&mut self, name: impl Into<String>, value: impl Into<String>) {
        self.defines.insert(name.into(), value.into());
    }

    /// Removes the definition of `name`; does nothing if it was not defined.
    pub fn undefine(&mut self, name: &str) {
        self.defines.remove(name);
    }

    /// Returns `source` with every `$NAME` occurrence replaced by the value
    /// defined for `NAME`.
    ///
    /// Definitions are applied longest name first (ties broken
    /// alphabetically). Without that ordering, a name that is a prefix of
    /// another (`SIZE` vs `SIZE_X`) could be substituted first and corrupt the
    /// longer placeholder, and the outcome would depend on the hash map's
    /// iteration order. Placeholders with no matching definition are left
    /// untouched.
    pub fn preprocess(&self, source: &str) -> String {
        let mut ordered: Vec<(&String, &String)> = self.defines.iter().collect();
        ordered.sort_unstable_by(|&(a, _), &(b, _)| {
            b.len().cmp(&a.len()).then_with(|| a.cmp(b))
        });

        let mut result = source.to_string();

        for (name, value) in ordered {
            let pattern = format!("${}", name);
            result = result.replace(&pattern, value);
        }

        result
    }
}

impl Default for ShaderPreprocessor {
    /// Equivalent to [`ShaderPreprocessor::new`].
    fn default() -> Self {
        Self::new()
    }
}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// A `$NAME` placeholder is replaced by the value registered for it.
    #[test]
    fn test_shader_preprocessor() {
        let mut macros = ShaderPreprocessor::new();
        macros.define("WORKGROUP_SIZE", "64");

        let expanded =
            macros.preprocess("@compute @workgroup_size($WORKGROUP_SIZE, 1, 1)\nfn main() {}");

        assert!(expanded.contains("64"));
    }

    /// A freshly created compiler has recorded no compilations.
    #[test]
    fn test_compiler_creation() {
        let fresh_stats = ShaderCompiler::new().get_stats();
        assert_eq!(fresh_stats.total_compilations, 0);
    }

    /// A minimal compute shader compiles and exposes its `main` entry point.
    #[test]
    fn test_simple_shader_compilation() {
        let source = r#"
@compute @workgroup_size(1, 1, 1)
fn main() {
    // Empty compute shader
}
        "#;

        let outcome = ShaderCompiler::new().compile(source);
        assert!(outcome.is_ok());

        if let Ok(shader) = outcome {
            assert!(shader.entry_points.iter().any(|name| name == "main"));
        }
    }
}
301}