pub mod analyzer;
pub mod cache;
pub mod optimizer;
use crate::error::{GpuAdvancedError, Result};
use blake3::Hash;
use naga::{Module, valid::Validator};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
/// Result of compiling a WGSL shader with [`ShaderCompiler`].
#[derive(Debug)]
pub struct CompiledShader {
    /// The WGSL source text the shader was compiled from.
    pub source: String,
    /// The parsed and validated naga module (replaced by the optimizer's output
    /// when produced by `ShaderCompiler::compile_optimized`).
    pub module: Module,
    /// Names of the module's entry points, in the order naga lists them.
    pub entry_points: Vec<String>,
    /// BLAKE3 hash of `source`; `ShaderCompiler` also uses it as the cache key.
    pub hash: Hash,
    /// Whether `module` has been run through the shader optimizer.
    pub optimized: bool,
}
/// WGSL shader compiler built on naga, with a hash-keyed result cache,
/// an optimizer, and lock-protected usage statistics.
pub struct ShaderCompiler {
    // Cache of compiled shaders keyed by the BLAKE3 hash of their source.
    cache: Arc<cache::ShaderCache>,
    // Optimizer applied by `compile_optimized`.
    optimizer: Arc<optimizer::ShaderOptimizer>,
    // Counters updated by `compile` and `compile_optimized`; read via `get_stats`.
    stats: Arc<RwLock<CompilerStats>>,
}
/// Counters collected by [`ShaderCompiler`]; obtain a snapshot with `get_stats`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct CompilerStats {
    /// Number of `compile` calls, including those made by `compile_optimized`.
    pub total_compilations: u64,
    /// `compile` calls answered from the cache.
    pub cache_hits: u64,
    /// `compile` calls that had to parse and validate the source.
    pub cache_misses: u64,
    /// Successful `compile_optimized` calls.
    pub optimizations: u64,
    /// Compilations rejected by naga validation (WGSL parse errors are not counted).
    pub validation_failures: u64,
}
impl ShaderCompiler {
pub fn new() -> Self {
Self {
cache: Arc::new(cache::ShaderCache::new(1000)),
optimizer: Arc::new(optimizer::ShaderOptimizer::new()),
stats: Arc::new(RwLock::new(CompilerStats::default())),
}
}
pub fn compile(&self, source: &str) -> Result<CompiledShader> {
{
let mut stats = self.stats.write();
stats.total_compilations += 1;
}
let hash = blake3::hash(source.as_bytes());
if let Some(cached) = self.cache.get(&hash) {
let mut stats = self.stats.write();
stats.cache_hits += 1;
return Ok(cached);
}
{
let mut stats = self.stats.write();
stats.cache_misses += 1;
}
let module = naga::front::wgsl::parse_str(source).map_err(|e| {
GpuAdvancedError::ShaderCompilerError(format!("WGSL parse error: {:?}", e))
})?;
let mut validator = Validator::new(
naga::valid::ValidationFlags::all(),
naga::valid::Capabilities::all(),
);
let _module_info = validator.validate(&module).map_err(|e| {
let mut stats = self.stats.write();
stats.validation_failures += 1;
GpuAdvancedError::ShaderValidationError(format!("Validation error: {:?}", e))
})?;
let entry_points: Vec<String> = module
.entry_points
.iter()
.map(|ep| ep.name.clone())
.collect();
let compiled = CompiledShader {
source: source.to_string(),
module,
entry_points,
hash,
optimized: false,
};
self.cache.insert(hash, compiled.clone());
Ok(compiled)
}
pub fn compile_optimized(&self, source: &str) -> Result<CompiledShader> {
let mut compiled = self.compile(source)?;
compiled.module = self.optimizer.optimize(&compiled.module)?;
compiled.optimized = true;
{
let mut stats = self.stats.write();
stats.optimizations += 1;
}
Ok(compiled)
}
pub fn validate(&self, source: &str) -> Result<()> {
let module = naga::front::wgsl::parse_str(source).map_err(|e| {
GpuAdvancedError::ShaderCompilerError(format!("WGSL parse error: {:?}", e))
})?;
let mut validator = Validator::new(
naga::valid::ValidationFlags::all(),
naga::valid::Capabilities::all(),
);
validator.validate(&module).map_err(|e| {
GpuAdvancedError::ShaderValidationError(format!("Validation error: {:?}", e))
})?;
Ok(())
}
pub fn get_stats(&self) -> CompilerStats {
self.stats.read().clone()
}
pub fn print_stats(&self) {
let stats = self.stats.read();
println!("\nShader Compiler Statistics:");
println!(" Total compilations: {}", stats.total_compilations);
println!(
" Cache hits: {} ({:.1}%)",
stats.cache_hits,
if stats.total_compilations > 0 {
(stats.cache_hits as f64 / stats.total_compilations as f64) * 100.0
} else {
0.0
}
);
println!(" Cache misses: {}", stats.cache_misses);
println!(" Optimizations: {}", stats.optimizations);
println!(" Validation failures: {}", stats.validation_failures);
}
pub fn clear_cache(&self) {
self.cache.clear();
}
pub fn cache(&self) -> Arc<cache::ShaderCache> {
self.cache.clone()
}
pub fn optimizer(&self) -> Arc<optimizer::ShaderOptimizer> {
self.optimizer.clone()
}
}
impl Default for ShaderCompiler {
fn default() -> Self {
Self::new()
}
}
impl Clone for CompiledShader {
    /// Deep-copies the shader: the source text, naga module and entry-point
    /// names are cloned, while the hash and `optimized` flag are copied by value.
    fn clone(&self) -> Self {
        let Self {
            source,
            module,
            entry_points,
            hash,
            optimized,
        } = self;
        Self {
            hash: *hash,
            optimized: *optimized,
            entry_points: entry_points.clone(),
            module: module.clone(),
            source: source.clone(),
        }
    }
}
/// Textual pre-processor that substitutes `$NAME` placeholders in shader
/// source with values registered via [`ShaderPreprocessor::define`].
#[derive(Debug, Default, Clone)]
pub struct ShaderPreprocessor {
    // Placeholder name (without the leading `$`) -> replacement text.
    defines: HashMap<String, String>,
}

impl ShaderPreprocessor {
    /// Creates a pre-processor with no defines.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers (or overwrites) the replacement `value` for `$name`.
    pub fn define(&mut self, name: impl Into<String>, value: impl Into<String>) {
        self.defines.insert(name.into(), value.into());
    }

    /// Removes the define for `name`; unknown names are ignored.
    pub fn undefine(&mut self, name: &str) {
        self.defines.remove(name);
    }

    /// Returns `source` with every `$NAME` occurrence replaced by the value of
    /// the matching define.
    ///
    /// At each `$`, the longest matching define name wins. For example, `$SIZE_X`
    /// resolves to `SIZE_X` even when `SIZE` is also defined. The output is
    /// therefore deterministic regardless of `HashMap` iteration order.
    /// Substituted values are emitted verbatim and are not re-scanned for
    /// placeholders. A `$` not followed by a defined name is kept unchanged.
    pub fn preprocess(&self, source: &str) -> String {
        let mut names: Vec<&str> = self.defines.keys().map(String::as_str).collect();
        // Longest first; ties are broken lexicographically only to make the order total.
        names.sort_unstable_by(|a, b| b.len().cmp(&a.len()).then_with(|| a.cmp(b)));

        let mut out = String::with_capacity(source.len());
        let mut rest = source;
        while let Some(pos) = rest.find('$') {
            out.push_str(&rest[..pos]);
            // `$` is a single ASCII byte, so `pos + 1` is a char boundary.
            let after = &rest[pos + 1..];
            match names.iter().copied().find(|name| after.starts_with(*name)) {
                Some(name) => {
                    out.push_str(&self.defines[name]);
                    rest = &after[name.len()..];
                }
                None => {
                    out.push('$');
                    rest = after;
                }
            }
        }
        out.push_str(rest);
        out
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shader_preprocessor() {
        let mut preprocessor = ShaderPreprocessor::new();
        preprocessor.define("WORKGROUP_SIZE", "64");
        let source = "@compute @workgroup_size($WORKGROUP_SIZE, 1, 1)\nfn main() {}";
        // Pin the full output: a `contains("64")` check would also pass if the
        // placeholder were left in place next to the value.
        assert_eq!(
            preprocessor.preprocess(source),
            "@compute @workgroup_size(64, 1, 1)\nfn main() {}"
        );

        // Once undefined, the placeholder must be left untouched.
        preprocessor.undefine("WORKGROUP_SIZE");
        assert_eq!(preprocessor.preprocess(source), source);
    }

    #[test]
    fn test_compiler_creation() {
        let stats = ShaderCompiler::new().get_stats();
        assert_eq!(stats.total_compilations, 0);
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 0);
        assert_eq!(stats.validation_failures, 0);
    }

    #[test]
    fn test_simple_shader_compilation() {
        let compiler = ShaderCompiler::new();
        let source = r#"
            @compute @workgroup_size(1, 1, 1)
            fn main() {
                // Empty compute shader
            }
        "#;
        // `expect` surfaces the compiler error on failure, unlike a bare `is_ok` assert.
        let compiled = compiler
            .compile(source)
            .expect("simple compute shader should compile");
        assert!(compiled.entry_points.iter().any(|ep| ep == "main"));
        assert!(!compiled.optimized);

        let stats = compiler.get_stats();
        assert_eq!(stats.total_compilations, 1);
        assert_eq!(stats.cache_misses, 1);
        assert_eq!(stats.validation_failures, 0);
    }
}