use crate::error::{GpuAdvancedError, Result};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use wgpu::{BindGroupLayout, ComputePipeline, Device};
/// Builder that validates a list of [`PipelineStage`]s and compiles each one
/// into a wgpu compute pipeline.
pub struct PipelineBuilder {
    /// Device used to create shader modules, pipeline layouts and pipelines.
    device: Arc<Device>,
    /// Stages in insertion order; stage dependencies refer to these indices.
    stages: Vec<PipelineStage>,
    /// Bind group layouts; a stage uses the first `bind_group_count` of them.
    bind_group_layouts: Vec<BindGroupLayout>,
    /// Compiled-pipeline cache; private to this builder unless replaced via `with_cache`.
    cache: Arc<RwLock<PipelineCache>>,
    /// Build options (optimization, caching, stage limit).
    config: PipelineConfig,
}
impl PipelineBuilder {
    /// Creates an empty builder for `device`.
    ///
    /// The builder starts with no stages and no bind group layouts, its own
    /// private [`PipelineCache`], and [`PipelineConfig::default`].
    pub fn new(device: Arc<Device>) -> Self {
        Self {
            device,
            stages: Vec::new(),
            bind_group_layouts: Vec::new(),
            cache: Arc::new(RwLock::new(PipelineCache::new())),
            config: PipelineConfig::default(),
        }
    }

    /// Appends `stage`.
    ///
    /// Stages are compiled in insertion order. Stage dependencies refer to
    /// stages by their position in this list.
    pub fn add_stage(mut self, stage: PipelineStage) -> Self {
        self.stages.push(stage);
        self
    }

    /// Appends a bind group layout.
    ///
    /// Each stage is compiled against the first `bind_group_count` layouts,
    /// in the order they were added.
    pub fn add_bind_group_layout(mut self, layout: BindGroupLayout) -> Self {
        self.bind_group_layouts.push(layout);
        self
    }

    /// Replaces the build configuration.
    pub fn with_config(mut self, config: PipelineConfig) -> Self {
        self.config = config;
        self
    }

    /// Replaces the pipeline cache, e.g. to share compiled pipelines between builders.
    ///
    /// Cache keys identify a stage by label, shader source, entry point and
    /// bind group count only. They do NOT encode the bind group layouts
    /// themselves, so only share a cache between builders whose layouts are
    /// interchangeable.
    pub fn with_cache(mut self, cache: Arc<RwLock<PipelineCache>>) -> Self {
        self.cache = cache;
        self
    }

    /// Validates, optionally optimizes, and compiles the stages into a [`Pipeline`].
    ///
    /// When `config.cache` is set, compiled pipelines are looked up in, and
    /// stored into, this builder's [`PipelineCache`]. wgpu object creation
    /// returns no `Result`, so shader problems are not turned into `Err` here.
    ///
    /// # Errors
    ///
    /// Returns [`GpuAdvancedError::ConfigError`] in any of these cases:
    /// - no stages were added;
    /// - there are more stages than `config.max_stages`;
    /// - a stage depends on itself or on a later stage;
    /// - a stage requests more bind groups than there are layouts.
    pub fn build(mut self) -> Result<Pipeline> {
        if self.stages.is_empty() {
            return Err(GpuAdvancedError::ConfigError(
                "Pipeline must have at least one stage".to_string(),
            ));
        }
        self.validate()?;
        if self.config.optimize {
            self.stages = self.optimize_stages_immutable(&self.stages)?;
        }
        // Compile through the shared helper so that `config.cache` is honoured
        // instead of duplicating the descriptor setup here.
        let compute_pipelines = self
            .stages
            .iter()
            .map(|stage| self.build_compute_pipeline(stage))
            .collect::<Result<Vec<_>>>()?;
        Ok(Pipeline {
            device: self.device,
            stages: self.stages,
            compute_pipelines,
            cache: self.cache,
            config: self.config,
        })
    }

    /// Checks the stage list against the configuration and the supplied layouts.
    fn validate(&self) -> Result<()> {
        if self.stages.len() > self.config.max_stages {
            return Err(GpuAdvancedError::ConfigError(format!(
                "Pipeline has {} stages but at most {} are allowed",
                self.stages.len(),
                self.config.max_stages
            )));
        }
        // Dependencies may only point at strictly earlier stages, which also
        // rules out cycles.
        for (i, stage) in self.stages.iter().enumerate() {
            for dep in &stage.dependencies {
                if *dep >= i {
                    return Err(GpuAdvancedError::ConfigError(format!(
                        "Stage {} has forward or circular dependency on stage {}",
                        i, dep
                    )));
                }
            }
        }
        for stage in &self.stages {
            if stage.bind_group_count > self.bind_group_layouts.len() {
                return Err(GpuAdvancedError::ConfigError(format!(
                    "Stage '{}' requires {} bind groups but only {} layouts provided",
                    stage.label,
                    stage.bind_group_count,
                    self.bind_group_layouts.len()
                )));
            }
        }
        Ok(())
    }

    /// Returns a copy of `stages` with adjacent mergeable pairs fused into one stage.
    ///
    /// NOTE(review): dependency indices of later stages are not renumbered
    /// after a merge. This is harmless only while `can_merge_stages` never
    /// returns `true`.
    fn optimize_stages_immutable(&self, stages: &[PipelineStage]) -> Result<Vec<PipelineStage>> {
        let mut optimized = Vec::with_capacity(stages.len());
        let mut i = 0;
        while i < stages.len() {
            let stage = &stages[i];
            match stages.get(i + 1) {
                Some(next) if self.can_merge_stages(stage, next) => {
                    optimized.push(self.merge_stages(stage, next)?);
                    i += 2;
                }
                _ => {
                    optimized.push(stage.clone());
                    i += 1;
                }
            }
        }
        Ok(optimized)
    }

    /// Decides whether two adjacent stages may be fused.
    ///
    /// Stage fusion is currently disabled: this always returns `false`.
    fn can_merge_stages(&self, _stage1: &PipelineStage, _stage2: &PipelineStage) -> bool {
        false
    }

    /// Fuses two stages into one.
    ///
    /// The shader source, entry point, workgroup size and dependencies are
    /// taken from `stage1` only; `stage2`'s shader is dropped. The bind group
    /// count is the larger of the two.
    fn merge_stages(
        &self,
        stage1: &PipelineStage,
        stage2: &PipelineStage,
    ) -> Result<PipelineStage> {
        Ok(PipelineStage {
            label: format!("{}_merged_{}", stage1.label, stage2.label),
            shader_source: stage1.shader_source.clone(),
            entry_point: stage1.entry_point.clone(),
            workgroup_size: stage1.workgroup_size,
            bind_group_count: stage1.bind_group_count.max(stage2.bind_group_count),
            dependencies: stage1.dependencies.clone(),
        })
    }

    /// Compiles one stage into a compute pipeline.
    ///
    /// When `config.cache` is enabled, the cache is consulted first and a
    /// freshly compiled pipeline is stored afterwards. When it is disabled,
    /// the cache is neither read nor written.
    fn build_compute_pipeline(&self, stage: &PipelineStage) -> Result<Arc<ComputePipeline>> {
        let cache_key = self.config.cache.then(|| self.compute_cache_key(stage));
        if let Some(key) = cache_key.as_deref() {
            // The read guard is a temporary, released at the end of this statement.
            let hit = self.cache.read().get(key);
            if let Some(pipeline) = hit {
                return Ok(pipeline);
            }
        }
        let shader = self
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some(&format!("shader_{}", stage.label)),
                source: wgpu::ShaderSource::Wgsl(stage.shader_source.as_str().into()),
            });
        // `validate` guarantees there are at least `bind_group_count` layouts.
        let layout_refs: Vec<Option<&BindGroupLayout>> = self
            .bind_group_layouts
            .iter()
            .take(stage.bind_group_count)
            .map(Some)
            .collect();
        let pipeline_layout = self
            .device
            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some(&format!("layout_{}", stage.label)),
                bind_group_layouts: &layout_refs,
                immediate_size: 0,
            });
        let pipeline = self
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some(&stage.label),
                layout: Some(&pipeline_layout),
                module: &shader,
                entry_point: Some(&stage.entry_point),
                compilation_options: Default::default(),
                cache: None,
            });
        let pipeline = Arc::new(pipeline);
        if let Some(key) = cache_key {
            self.cache.write().insert(key, Arc::clone(&pipeline));
        }
        Ok(pipeline)
    }

    /// Builds the cache key for `stage`.
    ///
    /// The key is the label plus a hash of shader source, entry point and bind
    /// group count. The bind group layouts themselves are not part of the key.
    fn compute_cache_key(&self, stage: &PipelineStage) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        stage.shader_source.hash(&mut hasher);
        stage.entry_point.hash(&mut hasher);
        stage.bind_group_count.hash(&mut hasher);
        format!("{}_{}", stage.label, hasher.finish())
    }
}
/// A single compute stage: WGSL source plus the metadata needed to build it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineStage {
    /// Human-readable name; also used to derive wgpu object labels.
    pub label: String,
    /// WGSL shader source for this stage.
    pub shader_source: String,
    /// Name of the entry-point function in `shader_source`.
    pub entry_point: String,
    /// Workgroup dimensions `(x, y, z)`.
    ///
    /// Not passed to wgpu by the builder. Presumably it should match the
    /// shader's own `@workgroup_size` (TODO confirm).
    pub workgroup_size: (u32, u32, u32),
    /// Number of leading builder bind group layouts this stage is compiled against.
    pub bind_group_count: usize,
    /// Indices of earlier stages this stage depends on.
    pub dependencies: Vec<usize>,
}
impl PipelineStage {
    /// Creates a stage with the given label, shader source and entry point.
    ///
    /// Defaults: workgroup size `(8, 8, 1)`, one bind group, no dependencies.
    pub fn new(
        label: impl Into<String>,
        shader_source: impl Into<String>,
        entry_point: impl Into<String>,
    ) -> Self {
        let (label, shader_source, entry_point) =
            (label.into(), shader_source.into(), entry_point.into());
        Self {
            label,
            shader_source,
            entry_point,
            workgroup_size: (8, 8, 1),
            bind_group_count: 1,
            dependencies: Vec::new(),
        }
    }

    /// Returns the stage with its workgroup size set to `(x, y, z)`.
    pub fn with_workgroup_size(self, x: u32, y: u32, z: u32) -> Self {
        Self {
            workgroup_size: (x, y, z),
            ..self
        }
    }

    /// Returns the stage with its bind group count set to `count`.
    pub fn with_bind_groups(self, count: usize) -> Self {
        Self {
            bind_group_count: count,
            ..self
        }
    }

    /// Returns the stage with `stage_index` appended to its dependency list.
    pub fn depends_on(self, stage_index: usize) -> Self {
        let mut updated = self;
        updated.dependencies.push(stage_index);
        updated
    }
}
/// Options controlling how a [`PipelineBuilder`] builds a [`Pipeline`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineConfig {
    /// Run the stage optimization pass before compiling.
    pub optimize: bool,
    /// Whether compiled pipelines should go through the builder's cache.
    pub cache: bool,
    /// Upper bound on the number of stages.
    pub max_stages: usize,
}
impl Default for PipelineConfig {
fn default() -> Self {
Self {
optimize: true,
cache: true,
max_stages: 16,
}
}
}
/// A built pipeline: the stages plus one compiled compute pipeline per stage.
pub struct Pipeline {
    /// Device the compute pipelines were created on.
    device: Arc<Device>,
    /// Final (possibly optimized) stages; index `i` matches `compute_pipelines[i]`.
    stages: Vec<PipelineStage>,
    /// Compiled compute pipelines, one per stage, in stage order.
    compute_pipelines: Vec<Arc<ComputePipeline>>,
    // Never read after construction, hence the allow. Holding the Arc keeps
    // the builder's cache alive for as long as this pipeline exists.
    #[allow(dead_code)]
    cache: Arc<RwLock<PipelineCache>>,
    /// Configuration the builder used.
    config: PipelineConfig,
}
impl Pipeline {
    /// Returns the device the pipelines were created on.
    pub fn device(&self) -> &Device {
        &*self.device
    }

    /// Returns the number of stages.
    pub fn stage_count(&self) -> usize {
        self.stages.len()
    }

    /// Returns the stage at `index`, or `None` if out of range.
    pub fn get_stage(&self, index: usize) -> Option<&PipelineStage> {
        self.stages.get(index)
    }

    /// Returns the compiled pipeline for the stage at `index`, or `None` if out of range.
    pub fn get_compute_pipeline(&self, index: usize) -> Option<&ComputePipeline> {
        self.compute_pipelines
            .get(index)
            .map(|pipeline| &**pipeline)
    }

    /// Returns all compiled pipelines, in stage order.
    pub fn compute_pipelines(&self) -> &[Arc<ComputePipeline>] {
        self.compute_pipelines.as_slice()
    }

    /// Summarizes the pipeline.
    ///
    /// `optimized` and `cached` echo the corresponding configuration flags.
    pub fn info(&self) -> PipelineInfo {
        let total_bind_groups = self
            .stages
            .iter()
            .fold(0, |total, stage| total + stage.bind_group_count);
        PipelineInfo {
            stage_count: self.stage_count(),
            total_bind_groups,
            optimized: self.config.optimize,
            cached: self.config.cache,
        }
    }

    /// Renders a human-readable, line-oriented description of every stage.
    pub fn visualize(&self) -> String {
        let header = String::from("Pipeline Structure:\n");
        self.stages
            .iter()
            .enumerate()
            .fold(header, |mut text, (index, stage)| {
                let (wx, wy, wz) = stage.workgroup_size;
                text.push_str(&format!(" Stage {}: {}\n", index, stage.label));
                text.push_str(&format!(" Entry: {}\n", stage.entry_point));
                text.push_str(&format!(" Workgroup: {}x{}x{}\n", wx, wy, wz));
                text.push_str(&format!(" Bind groups: {}\n", stage.bind_group_count));
                // The dependency line is only emitted for stages that have dependencies.
                if !stage.dependencies.is_empty() {
                    text.push_str(&format!(" Dependencies: {:?}\n", stage.dependencies));
                }
                text
            })
    }
}
/// Summary of a built [`Pipeline`], as returned by `Pipeline::info`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PipelineInfo {
    /// Number of stages.
    pub stage_count: usize,
    /// Sum of `bind_group_count` over all stages.
    pub total_bind_groups: usize,
    /// Value of the `optimize` config flag.
    pub optimized: bool,
    /// Value of the `cache` config flag.
    pub cached: bool,
}
/// Size-bounded map from cache key to compiled compute pipeline.
pub struct PipelineCache {
    /// Cached pipelines, keyed by cache key.
    pipelines: HashMap<String, Arc<ComputePipeline>>,
    /// Number of entries at which `insert` starts evicting.
    max_size: usize,
}
impl PipelineCache {
    /// Capacity used by [`PipelineCache::new`].
    const DEFAULT_MAX_SIZE: usize = 128;

    /// Creates an empty cache that holds at most 128 pipelines.
    pub fn new() -> Self {
        Self::with_max_size(Self::DEFAULT_MAX_SIZE)
    }

    /// Creates an empty cache that holds at most `max_size` pipelines.
    ///
    /// A `max_size` of 0 disables storage: [`insert`](Self::insert) becomes a no-op.
    pub fn with_max_size(max_size: usize) -> Self {
        Self {
            pipelines: HashMap::new(),
            max_size,
        }
    }

    /// Returns a clone of the cached pipeline handle for `key`, if present.
    pub fn get(&self, key: &str) -> Option<Arc<ComputePipeline>> {
        self.pipelines.get(key).cloned()
    }

    /// Stores `pipeline` under `key`, replacing any previous entry for that key.
    ///
    /// Inserting a *new* key into a full cache first evicts one existing
    /// entry. The victim is whichever key `HashMap` iteration yields first,
    /// which is effectively arbitrary. This is not an LRU policy.
    pub fn insert(&mut self, key: String, pipeline: Arc<ComputePipeline>) {
        if self.max_size == 0 {
            return;
        }
        // Replacing an existing key does not grow the map, so nothing needs
        // evicting. Previously an unrelated entry was dropped in that case.
        if !self.pipelines.contains_key(&key) && self.pipelines.len() >= self.max_size {
            if let Some(victim) = self.pipelines.keys().next().cloned() {
                self.pipelines.remove(&victim);
            }
        }
        self.pipelines.insert(key, pipeline);
    }

    /// Removes every cached pipeline.
    pub fn clear(&mut self) {
        self.pipelines.clear();
    }

    /// Returns the number of cached pipelines.
    pub fn size(&self) -> usize {
        self.pipelines.len()
    }
}
impl Default for PipelineCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pipeline_stage_creation() {
        let stage = PipelineStage::new("test", "shader code", "main")
            .with_workgroup_size(16, 16, 1)
            .with_bind_groups(2);
        assert_eq!(stage.label, "test");
        assert_eq!(stage.workgroup_size, (16, 16, 1));
        assert_eq!(stage.bind_group_count, 2);
    }

    #[test]
    fn test_pipeline_stage_defaults() {
        let stage = PipelineStage::new("defaults", "shader code", "main");
        assert_eq!(stage.shader_source, "shader code");
        assert_eq!(stage.entry_point, "main");
        assert_eq!(stage.workgroup_size, (8, 8, 1));
        assert_eq!(stage.bind_group_count, 1);
        assert!(stage.dependencies.is_empty());
    }

    #[test]
    fn test_pipeline_stage_dependencies_accumulate() {
        let stage = PipelineStage::new("deps", "shader code", "main")
            .depends_on(0)
            .depends_on(1);
        assert_eq!(stage.dependencies, vec![0, 1]);
    }

    #[test]
    fn test_pipeline_cache() {
        let cache = PipelineCache::new();
        assert_eq!(cache.size(), 0);
        assert!(cache.get("nonexistent").is_none());
    }

    #[test]
    fn test_pipeline_cache_default_and_clear() {
        let mut cache = PipelineCache::default();
        assert_eq!(cache.size(), 0);
        cache.clear();
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn test_pipeline_config() {
        let config = PipelineConfig::default();
        assert!(config.optimize);
        assert!(config.cache);
        assert_eq!(config.max_stages, 16);
    }
}