#[allow(unused_imports)]
use super::functions::*;
use std::collections::HashMap;
/// Bounded cache of compiled shader variants with least-recently-used eviction.
///
/// Each entry stores the variant's key, the variant itself, and a logical timestamp taken
/// from `clock`; the entry with the smallest timestamp is evicted first.
pub struct VariantCache {
    /// Maximum number of entries. `new` clamps it to at least 1, but because the field is
    /// public, `insert` re-applies that clamp in case it was changed afterwards.
    pub capacity: usize,
    pub(super) entries: Vec<(ShaderKey, CompiledVariant, u64)>,
    pub(super) clock: u64,
}
impl VariantCache {
    /// Creates an empty cache; a `capacity` of 0 is treated as 1.
    pub fn new(capacity: usize) -> Self {
        Self {
            capacity: capacity.max(1),
            entries: Vec::new(),
            clock: 0,
        }
    }

    /// Inserts `variant`, marking it most recently used.
    ///
    /// An entry with the same key is replaced in place. Otherwise, least-recently-used
    /// entries are evicted until the new entry fits within `capacity`.
    pub fn insert(&mut self, variant: CompiledVariant) {
        self.clock += 1;
        let stamp = self.clock;
        if let Some(pos) = self.entries.iter().position(|(k, _, _)| *k == variant.key) {
            self.entries[pos] = (variant.key.clone(), variant, stamp);
            return;
        }
        // Evict in a loop: if `capacity` was lowered after construction the cache can be
        // more than one entry over budget. Clamping to 1 also guarantees the vector is
        // non-empty whenever we evict, so `swap_remove` can never hit an empty Vec.
        let limit = self.capacity.max(1);
        while self.entries.len() >= limit {
            let lru = self
                .entries
                .iter()
                .enumerate()
                .min_by_key(|(_, (_, _, t))| *t)
                .map(|(i, _)| i);
            match lru {
                Some(pos) => {
                    self.entries.swap_remove(pos);
                }
                None => break,
            }
        }
        self.entries.push((variant.key.clone(), variant, stamp));
    }

    /// Looks up `key`, marking the entry most recently used on a hit.
    pub fn get(&mut self, key: &ShaderKey) -> Option<&CompiledVariant> {
        if let Some(pos) = self.entries.iter().position(|(k, _, _)| k == key) {
            self.clock += 1;
            self.entries[pos].2 = self.clock;
            return Some(&self.entries[pos].1);
        }
        None
    }

    /// Number of cached variants.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if nothing is cached.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Removes every entry and resets the LRU clock.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.clock = 0;
    }

    /// Sum of `binary_size_bytes` over all cached variants.
    pub fn total_binary_bytes(&self) -> usize {
        self.entries
            .iter()
            .map(|(_, v, _)| v.binary_size_bytes)
            .sum()
    }
}
/// A shader source instantiated for one particular [`ShaderKey`].
#[derive(Debug, Clone)]
pub struct CompiledVariant {
    /// Identifies the shader and define set this variant was built from.
    pub key: ShaderKey,
    /// Instantiated WGSL text.
    pub wgsl: String,
    pub workgroup_size: [u32; 3],
    /// Size attributed to this variant; `new` sets it to the byte length of `wgsl`.
    pub binary_size_bytes: usize,
}
impl CompiledVariant {
    /// Builds a variant, deriving `binary_size_bytes` from `wgsl.len()`.
    fn new(key: ShaderKey, wgsl: String, workgroup_size: [u32; 3]) -> Self {
        Self {
            // Measured before `wgsl` is moved into the struct below.
            binary_size_bytes: wgsl.len(),
            key,
            wgsl,
            workgroup_size,
        }
    }
}
/// Named collection of [`VariantProfile`]s, each of which may inherit defines from a base.
#[derive(Debug, Default)]
pub struct VariantProfileRegistry {
    pub(super) profiles: HashMap<String, VariantProfile>,
}
impl VariantProfileRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `profile`, replacing any existing profile with the same name.
    pub fn register(&mut self, profile: VariantProfile) {
        self.profiles.insert(profile.name.clone(), profile);
    }

    /// Resolves the full define set of profile `name` by following its `base` chain.
    ///
    /// Defines of a derived profile override those inherited from its bases.
    ///
    /// Returns `None` if `name` or any base in the chain is not registered, or if the chain
    /// is cyclic (e.g. `a -> b -> a`, or a profile naming itself as its base).
    pub fn resolve(&self, name: &str) -> Option<HashMap<String, String>> {
        // Walk from `name` towards the root iteratively, refusing to revisit a profile.
        // A naive recursive walk would overflow the stack on cyclic `base` links.
        let mut chain: Vec<&VariantProfile> = Vec::new();
        let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();
        let mut current: &str = name;
        loop {
            if !seen.insert(current) {
                return None;
            }
            let profile = self.profiles.get(current)?;
            chain.push(profile);
            match profile.base.as_deref() {
                Some(base) => current = base,
                None => break,
            }
        }
        // Apply the root-most profile first so each descendant overrides its ancestors.
        let mut merged = HashMap::new();
        for profile in chain.iter().rev() {
            for (k, v) in &profile.defines {
                merged.insert(k.clone(), v.clone());
            }
        }
        Some(merged)
    }

    /// Names of all registered profiles, sorted.
    pub fn profile_names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.profiles.keys().map(|s| s.as_str()).collect();
        names.sort_unstable();
        names
    }
}
#[derive(Debug, Default)]
pub struct ShaderDependencyGraph {
pub(super) depends_on: HashMap<String, Vec<String>>,
}
impl ShaderDependencyGraph {
pub fn new() -> Self {
Self::default()
}
pub fn add_dependency(&mut self, dependent: impl Into<String>, dependency: impl Into<String>) {
self.depends_on
.entry(dependent.into())
.or_default()
.push(dependency.into());
}
pub fn direct_dependents(&self, dependency: &str) -> Vec<&str> {
self.depends_on
.iter()
.filter(|(_, deps)| deps.iter().any(|d| d == dependency))
.map(|(name, _)| name.as_str())
.collect()
}
pub fn transitive_dependents(&self, changed_shader: &str) -> Vec<String> {
let mut visited = std::collections::HashSet::new();
let mut queue = std::collections::VecDeque::new();
queue.push_back(changed_shader.to_string());
while let Some(current) = queue.pop_front() {
for (name, deps) in &self.depends_on {
if deps.contains(¤t) && !visited.contains(name.as_str()) {
visited.insert(name.clone());
queue.push_back(name.clone());
}
}
}
let mut result: Vec<String> = visited.into_iter().collect();
result.sort_unstable();
result
}
pub fn direct_dependencies(&self, shader: &str) -> &[String] {
self.depends_on
.get(shader)
.map(Vec::as_slice)
.unwrap_or(&[])
}
pub fn shaders_with_deps(&self) -> Vec<&str> {
let mut names: Vec<&str> = self.depends_on.keys().map(|s| s.as_str()).collect();
names.sort_unstable();
names
}
}
/// Identifies one shader variant: a shader name plus its define set.
///
/// `new` sorts the defines by name, so equal define maps produce equal (and equally hashed)
/// keys regardless of map iteration order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ShaderKey {
    pub name: String,
    pub defines: Vec<(String, String)>,
}
impl ShaderKey {
    /// Builds a key from `name` and a copy of `defines` sorted by define name.
    pub fn new(name: impl Into<String>, defines: &HashMap<String, String>) -> Self {
        let mut pairs: Vec<(String, String)> = Vec::with_capacity(defines.len());
        for (k, v) in defines {
            pairs.push((k.clone(), v.clone()));
        }
        pairs.sort_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs));
        Self {
            name: name.into(),
            defines: pairs,
        }
    }

    /// Builds a key with no defines.
    pub fn bare(name: impl Into<String>) -> Self {
        Self {
            defines: Vec::new(),
            name: name.into(),
        }
    }

    /// Human-readable identifier: the name alone, or `name__K1_V1__K2_V2…` with defines.
    pub fn fingerprint(&self) -> String {
        let mut out = self.name.clone();
        for (k, v) in &self.defines {
            out.push_str("__");
            out.push_str(k);
            out.push('_');
            out.push_str(v);
        }
        out
    }
}
/// A named set of defines, optionally layered on top of another profile named by `base`.
#[derive(Debug, Clone)]
pub struct VariantProfile {
    pub name: String,
    pub defines: HashMap<String, String>,
    /// Name of the profile whose defines this one inherits (and may override).
    pub base: Option<String>,
}
impl VariantProfile {
    /// Creates a profile with no defines and no base.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Self {
            base: None,
            defines: HashMap::new(),
            name,
        }
    }

    /// Creates a profile with no defines that inherits from `base`.
    pub fn with_base(name: impl Into<String>, base: impl Into<String>) -> Self {
        let mut profile = Self::new(name);
        profile.base = Some(base.into());
        profile
    }

    /// Builder-style setter: sets define `key` to `value`, overwriting any previous value.
    pub fn set(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (k, v) = (key.into(), value.into());
        self.defines.insert(k, v);
        self
    }
}
/// Tracks per-shader modification times and compile times to decide what needs recompiling.
///
/// Times are caller-supplied `u64` values; only their relative ordering is used.
#[derive(Debug, Default)]
pub struct HotReloadTracker {
    /// Last modification time per shader name.
    pub(super) timestamps: HashMap<String, u64>,
    /// Last recorded compile time per shader name.
    pub(super) compiled_at: HashMap<String, u64>,
}
impl HotReloadTracker {
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records that shader `name` was modified at `time`.
    pub fn touch(&mut self, name: impl Into<String>, time: u64) {
        self.timestamps.insert(name.into(), time);
    }

    /// Records that shader `name` was compiled at `time`.
    pub fn record_compile(&mut self, name: impl Into<String>, time: u64) {
        self.compiled_at.insert(name.into(), time);
    }

    /// Returns `true` if `name` was modified after its last recorded compile, or was
    /// modified but never compiled. Names that were never touched return `false`.
    pub fn needs_recompile(&self, name: &str) -> bool {
        match (self.timestamps.get(name), self.compiled_at.get(name)) {
            (Some(modified), Some(compiled)) => modified > compiled,
            // Handle "never compiled" explicitly instead of treating it as time 0, so a
            // shader touched at time 0 is still reported.
            (Some(_), None) => true,
            (None, _) => false,
        }
    }

    /// Names of all touched shaders that need recompiling, sorted by name.
    pub fn stale_shaders(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self
            .timestamps
            .keys()
            .map(String::as_str)
            .filter(|n| self.needs_recompile(n))
            .collect();
        // HashMap iteration order is unspecified; sort for deterministic output.
        names.sort_unstable();
        names
    }
}
impl HotReloadTracker {
    /// Records every name in `names` as modified at the same `time`.
    pub fn touch_batch(&mut self, names: &[&str], time: u64) {
        names.iter().for_each(|name| self.touch(*name, time));
    }

    /// Marks every shader that currently needs recompiling as compiled at `time`.
    pub fn flush_stale(&mut self, time: u64) {
        // Collected first: `needs_recompile` borrows `self`, which `compiled_at` mutation cannot overlap.
        let stale: Vec<String> = self
            .timestamps
            .keys()
            .filter(|n| self.needs_recompile(n))
            .cloned()
            .collect();
        self.compiled_at
            .extend(stale.into_iter().map(|name| (name, time)));
    }

    /// Touched shaders that have no recorded compile, sorted by name.
    pub fn never_compiled(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self
            .timestamps
            .keys()
            .map(String::as_str)
            .filter(|n| !self.compiled_at.contains_key(*n))
            .collect();
        // HashMap iteration order is unspecified; sort for deterministic output.
        names.sort_unstable();
        names
    }
}
/// Owns shader sources and an LRU cache of the variants compiled from them.
pub struct ShaderRegistry {
    /// Registered sources keyed by shader name.
    pub(super) sources: HashMap<String, ShaderSource>,
    pub(super) cache: VariantCache,
    /// Number of `get_or_compile` calls served from the variant cache.
    pub cache_hits: u64,
    /// Number of variants instantiated.
    pub compilations: u64,
}
impl ShaderRegistry {
    /// Creates an empty registry. `cache_capacity` is passed to `VariantCache::new`,
    /// which treats 0 as 1.
    pub fn new(cache_capacity: usize) -> Self {
        Self {
            sources: HashMap::new(),
            cache: VariantCache::new(cache_capacity),
            cache_hits: 0,
            compilations: 0,
        }
    }

    /// Registers (or replaces) `source` and drops any variants cached under its name.
    pub fn register(&mut self, source: ShaderSource) {
        let name = source.name.clone();
        self.sources.insert(name.clone(), source);
        self.invalidate(&name);
    }

    /// Returns the cached variant of `name` for `defines`, compiling and caching it on a miss.
    ///
    /// # Errors
    /// - `UnknownShader` if `name` is not registered (only checked on a cache miss).
    /// - `MissingDefine` for the first placeholder of the source absent from `defines`.
    pub fn get_or_compile(
        &mut self,
        name: &str,
        defines: &HashMap<String, String>,
    ) -> Result<&CompiledVariant, RegistryError> {
        let key = ShaderKey::new(name, defines);
        // Probe without touching LRU state so that the single `get` below bumps recency once.
        // Returning straight out of `if let Some(v) = self.cache.get(..)` is rejected by
        // the current borrow checker, hence the two-step shape.
        let is_cached = self.cache.entries.iter().any(|(k, _, _)| *k == key);
        if is_cached {
            self.cache_hits += 1;
        } else {
            let source = self
                .sources
                .get(name)
                .ok_or_else(|| RegistryError::UnknownShader(name.to_string()))?;
            if let Some(missing) = source
                .placeholders
                .iter()
                .find(|ph| !defines.contains_key(ph.as_str()))
            {
                return Err(RegistryError::MissingDefine {
                    shader: name.to_string(),
                    define: missing.clone(),
                });
            }
            let wgsl = source.instantiate(defines);
            let variant = CompiledVariant::new(key.clone(), wgsl, source.workgroup_size);
            self.compilations += 1;
            self.cache.insert(variant);
        }
        Ok(self
            .cache
            .get(&key)
            .expect("variant is cached: found by the probe or inserted just above"))
    }

    /// Names of all registered shaders, sorted.
    pub fn shader_names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.sources.keys().map(String::as_str).collect();
        // HashMap iteration order is unspecified; sort like the file's other listing methods.
        names.sort_unstable();
        names
    }

    /// Number of variants currently cached.
    pub fn cached_count(&self) -> usize {
        self.cache.len()
    }

    /// Drops every cached variant of shader `name`; the source stays registered.
    pub fn invalidate(&mut self, name: &str) {
        self.cache.entries.retain(|(k, _, _)| k.name != name);
    }

    /// Drops all cached variants via `VariantCache::clear`.
    pub fn invalidate_all(&mut self) {
        self.cache.clear();
    }
}
impl ShaderRegistry {
pub fn compile_variant(
&mut self,
name: &str,
defines: &HashMap<String, String>,
opts: &ShaderCompileOptions,
) -> Result<CompiledVariant, RegistryError> {
let source = self
.sources
.get(name)
.ok_or_else(|| RegistryError::UnknownShader(name.to_string()))?;
let mut merged = defines.clone();
for (k, v) in &opts.extra_defines {
merged.entry(k.clone()).or_insert_with(|| v.clone());
}
for ph in &source.placeholders {
if !merged.contains_key(ph.as_str()) {
return Err(RegistryError::MissingDefine {
shader: name.to_string(),
define: ph.clone(),
});
}
}
let wgsl = source.instantiate(&merged);
if opts.max_source_bytes > 0 && wgsl.len() > opts.max_source_bytes {
return Err(RegistryError::SourceTooLarge {
shader: name.to_string(),
size: wgsl.len(),
limit: opts.max_source_bytes,
});
}
let workgroup_size = opts.workgroup_size;
let key = ShaderKey::new(name, &merged);
self.compilations += 1;
Ok(CompiledVariant::new(key, wgsl, workgroup_size))
}
}
impl ShaderRegistry {
    /// Drops cached variants of every shader `tracker` reports as stale and returns those
    /// names, in the order `tracker.stale_shaders()` produced them.
    ///
    /// Only cache entries are removed; nothing is recorded on `tracker`.
    pub fn apply_hot_reload(&mut self, tracker: &HotReloadTracker) -> Vec<String> {
        let mut invalidated = Vec::new();
        for stale_name in tracker.stale_shaders() {
            self.invalidate(stale_name);
            invalidated.push(stale_name.to_string());
        }
        invalidated
    }

    /// Number of registered shader sources.
    pub fn registered_count(&self) -> usize {
        self.sources.len()
    }

    /// Byte length of the raw template WGSL registered as `name`, or 0 if not registered.
    pub fn source_bytes(&self, name: &str) -> usize {
        self.sources.get(name).map_or(0, |src| src.wgsl.len())
    }
}
/// Template WGSL for one shader, filled in per variant by `instantiate`.
#[derive(Debug, Clone)]
pub struct ShaderSource {
    pub name: String,
    /// Template text; `instantiate` replaces `{{KEY}}` tokens in it.
    pub wgsl: String,
    pub workgroup_size: [u32; 3],
    /// Define names extracted from `wgsl` by `collect_placeholders`.
    pub placeholders: Vec<String>,
}
impl ShaderSource {
    /// Creates a source and precomputes its placeholder list with `collect_placeholders`.
    pub fn new(name: impl Into<String>, wgsl: impl Into<String>, workgroup_size: [u32; 3]) -> Self {
        let wgsl_str: String = wgsl.into();
        let placeholders = collect_placeholders(&wgsl_str);
        Self {
            name: name.into(),
            wgsl: wgsl_str,
            workgroup_size,
            placeholders,
        }
    }

    /// Returns `wgsl` with every `{{KEY}}` token replaced by the value of define `KEY`.
    ///
    /// Defines are applied one at a time in ascending key order. If a value contains
    /// another `{{…}}` token, that token is expanded only when its key sorts later. Tokens
    /// with no matching define are left untouched.
    pub fn instantiate(&self, defines: &HashMap<String, String>) -> String {
        // HashMap iteration order varies between runs. Sorting makes the output a pure
        // function of the define set, and therefore of the variant's `ShaderKey`.
        let mut ordered: Vec<(&String, &String)> = defines.iter().collect();
        ordered.sort_unstable_by(|a, b| a.0.cmp(b.0));
        let mut out = self.wgsl.clone();
        for (k, v) in ordered {
            let token = format!("{{{{{}}}}}", k);
            // Skip the reallocation `replace` would do when the token is absent.
            if out.contains(token.as_str()) {
                out = out.replace(token.as_str(), v);
            }
        }
        out
    }

    /// Total invocations per workgroup: the product of the three dimensions.
    ///
    /// The product saturates at `u32::MAX` rather than overflowing.
    pub fn threads_per_group(&self) -> u32 {
        self.workgroup_size
            .iter()
            .fold(1u32, |acc, &dim| acc.saturating_mul(dim))
    }
}
/// Value of a specialization constant, rendered into WGSL text by `to_wgsl`.
#[derive(Debug, Clone, PartialEq)]
pub enum SpecConstValue {
    Int(i64),
    Uint(u64),
    Float(f64),
    Bool(bool),
}
impl SpecConstValue {
    /// Renders the value as a WGSL literal.
    ///
    /// - `Int`: decimal, e.g. `-3`.
    /// - `Uint`: decimal with a `u` suffix, e.g. `7u`.
    /// - `Bool`: `true` or `false`.
    /// - `Float`: Rust's `Display` output, with `.0` appended when that output has no
    ///   fractional part (`1.0`, not `1`), so WGSL parses it as a float rather than an
    ///   integer literal. Non-finite values (`NaN`, `inf`) are emitted unchanged; WGSL has
    ///   no literal for them.
    pub fn to_wgsl(&self) -> String {
        match self {
            SpecConstValue::Int(v) => v.to_string(),
            SpecConstValue::Uint(v) => format!("{v}u"),
            SpecConstValue::Float(v) => {
                let text = v.to_string();
                // `Display` prints integral floats without a decimal point (1.0 -> "1"),
                // which WGSL reads as an integer literal.
                if text.bytes().all(|b| b.is_ascii_digit() || b == b'-') {
                    format!("{text}.0")
                } else {
                    text
                }
            }
            SpecConstValue::Bool(v) => v.to_string(),
        }
    }
}
/// A named specialization constant with a default value and an optional override.
#[derive(Debug, Clone, PartialEq)]
pub struct SpecializationConstant {
    pub name: String,
    pub default_value: SpecConstValue,
    /// When set, takes precedence over `default_value`.
    pub override_value: Option<SpecConstValue>,
}
impl SpecializationConstant {
    /// Creates a constant with the given default and no override.
    pub fn new(name: impl Into<String>, default: SpecConstValue) -> Self {
        let name = name.into();
        Self {
            name,
            override_value: None,
            default_value: default,
        }
    }

    /// Builder-style: returns this constant with `value` as its override.
    pub fn with_override(self, value: SpecConstValue) -> Self {
        Self {
            override_value: Some(value),
            ..self
        }
    }

    /// The override if present, otherwise the default.
    pub fn effective_value(&self) -> &SpecConstValue {
        match &self.override_value {
            Some(overridden) => overridden,
            None => &self.default_value,
        }
    }
}
/// Maps [`PipelineCacheKey`]s to pipeline labels and counts lookup hits and misses.
#[derive(Debug, Default)]
pub struct PipelineCache {
    pub(super) entries: HashMap<PipelineCacheKey, String>,
    /// `get` calls that found an entry.
    pub hits: u64,
    /// `get` calls that found nothing.
    pub misses: u64,
}
impl PipelineCache {
    /// Creates an empty cache with zeroed counters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `label` under `key`, replacing any previous label.
    pub fn insert(&mut self, key: PipelineCacheKey, label: impl Into<String>) {
        let label: String = label.into();
        self.entries.insert(key, label);
    }

    /// Looks up `key`, counting the lookup as a hit or a miss.
    pub fn get(&mut self, key: &PipelineCacheKey) -> Option<&str> {
        match self.entries.get(key) {
            Some(label) => {
                self.hits += 1;
                Some(label.as_str())
            }
            None => {
                self.misses += 1;
                None
            }
        }
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Removes all entries; `hits` and `misses` are left unchanged.
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Fraction of `get` calls that were hits, or `0.0` if `get` was never called.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }
}
/// Describes a pipeline built from the shader variant identified by `key`.
#[derive(Debug, Clone)]
pub struct PipelineDescriptor {
    pub key: ShaderKey,
    pub bind_group_count: u32,
    pub push_constant_bytes: u32,
    pub label: String,
}
impl PipelineDescriptor {
    /// Creates a descriptor; no validation is performed here (see `validate`).
    pub fn new(
        key: ShaderKey,
        bind_group_count: u32,
        push_constant_bytes: u32,
        label: impl Into<String>,
    ) -> Self {
        let label = label.into();
        Self {
            key,
            bind_group_count,
            push_constant_bytes,
            label,
        }
    }

    /// Rejects descriptors with zero bind groups.
    ///
    /// # Errors
    /// Returns `RegistryError::UnknownShader` carrying a message when `bind_group_count` is 0.
    pub fn validate(&self) -> Result<(), RegistryError> {
        // NOTE(review): this reuses `UnknownShader` to carry a message; a dedicated error
        // variant would describe the failure more accurately.
        match self.bind_group_count {
            0 => Err(RegistryError::UnknownShader(
                "pipeline has no bind groups".into(),
            )),
            _ => Ok(()),
        }
    }
}
/// Options consumed by `ShaderRegistry::compile_variant`.
#[derive(Debug, Clone, PartialEq)]
pub struct ShaderCompileOptions {
    /// Workgroup size recorded on the compiled variant.
    pub workgroup_size: [u32; 3],
    /// Fallback defines, used only for keys the caller's defines do not contain.
    pub extra_defines: HashMap<String, String>,
    /// Maximum instantiated source length in bytes; `0` disables the check.
    pub max_source_bytes: usize,
}
impl ShaderCompileOptions {
    /// Options with a `[64, 1, 1]` workgroup, no extra defines, and no size limit.
    pub fn new() -> Self {
        let workgroup_size = [64, 1, 1];
        Self {
            max_source_bytes: 0,
            extra_defines: HashMap::new(),
            workgroup_size,
        }
    }
}
/// A set of specialization constants keyed by name.
#[derive(Debug, Clone, Default)]
pub struct SpecConstSet {
    pub constants: HashMap<String, SpecializationConstant>,
}
impl SpecConstSet {
    /// Creates an empty set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `constant`, replacing any constant with the same name.
    pub fn add(&mut self, constant: SpecializationConstant) {
        let key = constant.name.clone();
        self.constants.insert(key, constant);
    }

    /// Renders every constant's effective value to WGSL text, keyed by constant name.
    pub fn to_defines(&self) -> HashMap<String, String> {
        let mut out = HashMap::with_capacity(self.constants.len());
        for (name, constant) in &self.constants {
            out.insert(name.clone(), constant.effective_value().to_wgsl());
        }
        out
    }

    /// Returns `true` if a constant named `name` is present.
    pub fn has(&self, name: &str) -> bool {
        self.constants.get(name).is_some()
    }

    /// WGSL text of the effective value of constant `name`, if present.
    pub fn get_wgsl(&self, name: &str) -> Option<String> {
        let constant = self.constants.get(name)?;
        Some(constant.effective_value().to_wgsl())
    }
}
/// Errors produced by shader registry operations.
///
/// NOTE(review): no `Display` / `std::error::Error` impl is visible in this file; if none
/// exists elsewhere, consider adding one so this type composes with `Box<dyn Error>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RegistryError {
    /// Normally the name of a shader that is not registered. `PipelineDescriptor::validate`
    /// also uses this variant to carry a free-form message.
    UnknownShader(String),
    /// A placeholder (`define`) of shader `shader` had no value in the supplied defines.
    MissingDefine {
        shader: String,
        define: String,
    },
    /// The instantiated source of `shader` is `size` bytes, exceeding `limit`.
    SourceTooLarge {
        shader: String,
        size: usize,
        limit: usize,
    },
}
/// Key identifying a cached pipeline: a shader variant plus pipeline-level parameters.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PipelineCacheKey {
    pub shader_key: ShaderKey,
    pub workgroup_size: [u32; 3],
    pub push_constant_bytes: u32,
    pub layout_hash: u64,
}
impl PipelineCacheKey {
    /// Bundles the given fields into a key.
    pub fn new(
        shader_key: ShaderKey,
        workgroup_size: [u32; 3],
        push_constant_bytes: u32,
        layout_hash: u64,
    ) -> Self {
        Self { shader_key, workgroup_size, push_constant_bytes, layout_hash }
    }

    /// Derives a `u64` key by passing a textual rendering of all fields to
    /// `compute_cache_key`.
    ///
    /// NOTE(review): the rendering is not injective. `ShaderKey::fingerprint` joins
    /// define names and values with `_`, so distinct define sets can render identically.
    pub fn hash_key(&self) -> u64 {
        let Self {
            shader_key,
            workgroup_size,
            push_constant_bytes,
            layout_hash,
        } = self;
        let repr = format!(
            "{}_{:?}_{}_{}",
            shader_key.fingerprint(),
            workgroup_size,
            push_constant_bytes,
            layout_hash
        );
        compute_cache_key(&repr)
    }
}