use std::cell::RefCell;
use std::collections::HashMap;
use awsm_renderer_core::command::compute_pass::ComputePassDescriptor;
use crate::{
error::Result,
pipelines::compute_pipeline::ComputePipelineKey,
render::RenderContext,
render_passes::{
material_classify::{
bind_group::MaterialClassifyBindGroups, pipeline::MaterialClassifyPipelines,
},
RenderPassInitContext,
},
};
/// Render pass that records the "Material Classify Pass" compute dispatch.
///
/// Owns the bind groups used by that dispatch and a cache that maps the
/// current dynamic-material state to the compute pipeline to run.
pub struct MaterialClassifyRenderPass {
    /// Bind groups for the pass; `render` binds them at group index 0.
    pub bind_groups: MaterialClassifyBindGroups,
    /// Pipeline lookup keyed by `(dynamic-material dispatch hash, MSAA sample count)`.
    ///
    /// Wrapped in a `RefCell` so it can be read through `&self` while rendering.
    /// NOTE(review): entries are inserted outside this file section — confirm
    /// the writer.
    pub pipeline_cache: RefCell<
        HashMap<(u64, Option<u32>), ComputePipelineKey>,
    >,
}
impl MaterialClassifyRenderPass {
    /// Returns the number of entries currently held in `pipeline_cache`.
    pub fn dynamic_cache_len(&self) -> usize {
        self.pipeline_cache.borrow().len()
    }

    /// Evicts every cached pipeline whose dispatch hash differs from
    /// `current_dispatch_hash`, regardless of the MSAA component of the key.
    ///
    /// Returns the pipeline keys of the evicted entries so the caller can decide
    /// what to do with them.
    pub fn prune_dynamic_pipeline_cache(
        &mut self,
        current_dispatch_hash: u64,
    ) -> Vec<ComputePipelineKey> {
        let mut dropped = Vec::new();
        // `&mut self` already guarantees exclusive access, so use `get_mut`
        // instead of a runtime-checked `borrow_mut`.
        self.pipeline_cache
            .get_mut()
            .retain(|&(hash, _), key| {
                let keep = hash == current_dispatch_hash;
                if !keep {
                    dropped.push(*key);
                }
                keep
            });
        dropped
    }

    /// Creates the pass: builds its bind groups and warms the pipeline pool with
    /// the first-party material bucket entries.
    ///
    /// The dynamic pipeline cache starts empty.
    ///
    /// # Errors
    /// Propagates failures from bind-group creation or pipeline warm-up.
    pub async fn new(ctx: &mut RenderPassInitContext<'_>) -> Result<Self> {
        let bind_groups = MaterialClassifyBindGroups::new(ctx).await?;
        let first_party_entries = crate::dynamic_materials::first_party_bucket_entries();
        MaterialClassifyPipelines::warm_pool(ctx, &bind_groups, &first_party_entries).await?;
        Ok(Self {
            bind_groups,
            pipeline_cache: RefCell::new(HashMap::new()),
        })
    }

    /// Records the material-classify compute dispatch into `ctx.command_encoder`.
    ///
    /// The pipeline is chosen by `(dispatch hash, MSAA sample count)`. When no
    /// pipeline is cached for that key, nothing is recorded and `Ok(())` is
    /// returned.
    ///
    /// # Errors
    /// Returns an error when the cached pipeline key cannot be resolved, the
    /// bind group is unavailable, or binding it fails. Once the compute pass has
    /// been opened, it is always ended before an error is returned, so the
    /// command encoder is never left locked by an unclosed pass.
    pub fn render(&self, ctx: &RenderContext) -> Result<()> {
        let msaa = ctx.anti_aliasing.msaa_sample_count;
        let key = (ctx.dynamic_materials.dispatch_hash_cached(), msaa);
        // Copy the key out so the `RefCell` borrow is released immediately.
        let pipeline_key_opt = self.pipeline_cache.borrow().get(&key).copied();
        let Some(pipeline_key) = pipeline_key_opt else {
            return Ok(());
        };

        // Resolve fallible lookups *before* opening the pass: an early `?`
        // return after `begin_compute_pass` would leave the pass un-ended.
        let pipeline = ctx.pipelines.compute.get(pipeline_key)?;
        let bind_group = self.bind_groups.get_bind_group()?;

        // NOTE(review): divisor 8 presumably matches the classify shader's
        // @workgroup_size(8, 8) — confirm against the shader source.
        let workgroups_x = ctx.render_texture_views.width.div_ceil(8);
        let workgroups_y = ctx.render_texture_views.height.div_ceil(8);

        let compute_pass = ctx.command_encoder.begin_compute_pass(Some(
            &ComputePassDescriptor::new(Some("Material Classify Pass")).into(),
        ));
        compute_pass.set_pipeline(pipeline);
        // Only dispatch if binding succeeded; keep the error to report after
        // the pass has been closed.
        let recorded = compute_pass
            .set_bind_group(0, bind_group, None)
            .map(|_| {
                compute_pass.dispatch_workgroups(workgroups_x, Some(workgroups_y), Some(1));
            });
        compute_pass.end();
        recorded?;
        Ok(())
    }
}