use std::sync::atomic::{AtomicBool, Ordering};

use ash::vk;
use thiserror::Error;

use crate::core::Device;
/// Errors produced while creating pipelines and pipeline layouts in this module.
#[derive(Debug, Error)]
pub enum PipelineError {
/// Pipeline creation failed with the contained Vulkan result code
/// (e.g. from `vkCreateGraphicsPipelines` in `PipelineBuilder::build_graphics`).
#[error("Pipeline creation failed: {0}")]
CreationFailed(vk::Result),
/// `vkCreatePipelineLayout` failed with the contained Vulkan result code.
#[error("Pipeline layout creation failed: {0}")]
LayoutCreationFailed(vk::Result),
/// A required builder setting was never provided; the payload names it
/// (e.g. `"vertex_shader"`).
#[error("Missing required field: {0}")]
MissingField(&'static str),
/// The requested pipeline configuration is not usable; the message describes why.
#[error("Invalid pipeline configuration: {0}")]
InvalidConfiguration(String),
}
pub struct PipelineLayout {
layout: vk::PipelineLayout,
}
impl PipelineLayout {
pub fn new(
device: &Device,
descriptor_set_layouts: &[vk::DescriptorSetLayout],
push_constant_ranges: &[vk::PushConstantRange],
) -> Result<Self, PipelineError> {
let layout_info = vk::PipelineLayoutCreateInfo {
set_layout_count: descriptor_set_layouts.len() as u32,
p_set_layouts: descriptor_set_layouts.as_ptr(),
push_constant_range_count: push_constant_ranges.len() as u32,
p_push_constant_ranges: push_constant_ranges.as_ptr(),
..Default::default()
};
let layout = unsafe {
device
.handle()
.create_pipeline_layout(&layout_info, None)
.map_err(PipelineError::LayoutCreationFailed)?
};
Ok(Self { layout })
}
#[inline]
pub fn handle(&self) -> vk::PipelineLayout {
self.layout
}
pub fn destroy(&self, device: &Device) {
unsafe {
device.handle().destroy_pipeline_layout(self.layout, None);
}
}
}
impl Drop for PipelineLayout {
fn drop(&mut self) {
if self.layout != vk::PipelineLayout::null() {
eprintln!(
"WARNING: PipelineLayout dropped without calling .destroy() - potential memory leak"
);
}
}
}
/// A Vulkan pipeline handle together with the [`PipelineLayout`] it was created with.
///
/// Dropping a `Pipeline` does not release any Vulkan objects; call [`Pipeline::destroy`]
/// explicitly. A non-null pipeline dropped without `destroy` prints a warning to stderr.
pub struct Pipeline {
    pipeline: vk::Pipeline,
    layout: PipelineLayout,
    bind_point: vk::PipelineBindPoint,
    // Set by `destroy` so repeated calls are no-ops and `Drop` does not report a leak that
    // did not happen. An atomic rather than a `Cell` avoids making the type `!Sync`.
    destroyed: AtomicBool,
}

impl Pipeline {
    /// Wraps an existing pipeline handle and the layout it was created with.
    ///
    /// The returned value takes ownership of both; release them with [`Pipeline::destroy`].
    pub(crate) fn from_handle(
        pipeline: vk::Pipeline,
        layout: PipelineLayout,
        bind_point: vk::PipelineBindPoint,
    ) -> Self {
        Self {
            pipeline,
            layout,
            bind_point,
            destroyed: AtomicBool::new(false),
        }
    }

    /// Returns the raw Vulkan pipeline handle.
    #[inline]
    pub fn handle(&self) -> vk::Pipeline {
        self.pipeline
    }

    /// Returns the layout this pipeline was created with.
    #[inline]
    pub fn layout(&self) -> &PipelineLayout {
        &self.layout
    }

    /// Returns the bind point to use when binding this pipeline.
    #[inline]
    pub fn bind_point(&self) -> vk::PipelineBindPoint {
        self.bind_point
    }

    /// Destroys the pipeline, then its layout.
    ///
    /// Only the first call has an effect; later calls return immediately. `device` must be the
    /// device the pipeline was created on, and (per the Vulkan spec) all submitted work that
    /// uses the pipeline must have completed.
    pub fn destroy(&self, device: &Device) {
        if self.destroyed.swap(true, Ordering::SeqCst) {
            return;
        }
        // SAFETY: the swap above guarantees this wrapper destroys the handle at most once;
        // GPU-idle requirements are the caller's responsibility (see above).
        unsafe {
            device.handle().destroy_pipeline(self.pipeline, None);
        }
        self.layout.destroy(device);
    }
}

impl Drop for Pipeline {
    fn drop(&mut self) {
        // Only a non-null handle that never went through `destroy` is an actual leak.
        if !*self.destroyed.get_mut() && self.pipeline != vk::Pipeline::null() {
            eprintln!(
                "WARNING: Pipeline dropped without calling .destroy() - potential memory leak"
            );
        }
    }
}

/// One programmable shader stage: which stage it is, its module, and its entry point.
#[derive(Clone, Debug)]
pub struct ShaderStageInfo {
    pub stage: vk::ShaderStageFlags,
    pub module: vk::ShaderModule,
    /// Name of the entry-point function; must not contain NUL bytes.
    pub entry_point: String,
}

/// Fluent builder for graphics [`Pipeline`]s.
///
/// Defaults (see [`PipelineBuilder::new`]):
/// - triangle-list topology, filled polygons, back-face culling with counter-clockwise front
///   faces, and one sample per pixel;
/// - depth test and depth write enabled with `LESS`;
/// - a single non-blending RGBA color-blend attachment.
///
/// Only a vertex shader is mandatory.
#[must_use]
#[derive(Clone, Debug)]
pub struct PipelineBuilder {
    vertex_shader: Option<ShaderStageInfo>,
    fragment_shader: Option<ShaderStageInfo>,
    geometry_shader: Option<ShaderStageInfo>,
    // NOTE: no public setter populates the two tessellation stages yet; supporting them also
    // requires passing a tessellation state (patch control points) in `build_graphics`.
    tessellation_control_shader: Option<ShaderStageInfo>,
    tessellation_evaluation_shader: Option<ShaderStageInfo>,
    vertex_binding_descriptions: Vec<vk::VertexInputBindingDescription>,
    vertex_attribute_descriptions: Vec<vk::VertexInputAttributeDescription>,
    topology: vk::PrimitiveTopology,
    primitive_restart_enable: bool,
    viewports: Vec<vk::Viewport>,
    scissors: Vec<vk::Rect2D>,
    polygon_mode: vk::PolygonMode,
    cull_mode: vk::CullModeFlags,
    front_face: vk::FrontFace,
    depth_bias_enable: bool,
    line_width: f32,
    sample_count: vk::SampleCountFlags,
    depth_test_enable: bool,
    depth_write_enable: bool,
    depth_compare_op: vk::CompareOp,
    stencil_test_enable: bool,
    color_blend_attachments: Vec<vk::PipelineColorBlendAttachmentState>,
    blend_constants: [f32; 4],
    dynamic_states: Vec<vk::DynamicState>,
    render_pass: Option<vk::RenderPass>,
    subpass: u32,
    // Attachment formats for dynamic rendering (used when `render_pass` is `None`).
    color_attachment_formats: Vec<vk::Format>,
    depth_attachment_format: Option<vk::Format>,
    stencil_attachment_format: Option<vk::Format>,
}

impl Default for PipelineBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl PipelineBuilder {
    /// Creates a builder with the defaults documented on [`PipelineBuilder`].
    pub fn new() -> Self {
        Self {
            vertex_shader: None,
            fragment_shader: None,
            geometry_shader: None,
            tessellation_control_shader: None,
            tessellation_evaluation_shader: None,
            vertex_binding_descriptions: Vec::new(),
            vertex_attribute_descriptions: Vec::new(),
            topology: vk::PrimitiveTopology::TRIANGLE_LIST,
            primitive_restart_enable: false,
            viewports: Vec::new(),
            scissors: Vec::new(),
            polygon_mode: vk::PolygonMode::FILL,
            cull_mode: vk::CullModeFlags::BACK,
            front_face: vk::FrontFace::COUNTER_CLOCKWISE,
            depth_bias_enable: false,
            line_width: 1.0,
            sample_count: vk::SampleCountFlags::TYPE_1,
            depth_test_enable: true,
            depth_write_enable: true,
            depth_compare_op: vk::CompareOp::LESS,
            stencil_test_enable: false,
            color_blend_attachments: vec![vk::PipelineColorBlendAttachmentState {
                blend_enable: vk::FALSE,
                color_write_mask: vk::ColorComponentFlags::RGBA,
                ..Default::default()
            }],
            blend_constants: [0.0; 4],
            dynamic_states: Vec::new(),
            render_pass: None,
            subpass: 0,
            color_attachment_formats: Vec::new(),
            depth_attachment_format: None,
            stencil_attachment_format: None,
        }
    }

    /// Bundles a stage flag, module and entry-point name into a [`ShaderStageInfo`].
    fn stage_info(
        stage: vk::ShaderStageFlags,
        module: vk::ShaderModule,
        entry_point: &str,
    ) -> ShaderStageInfo {
        ShaderStageInfo {
            stage,
            module,
            entry_point: entry_point.to_string(),
        }
    }

    /// Sets the (mandatory) vertex shader stage.
    pub fn vertex_shader(mut self, module: vk::ShaderModule, entry_point: &str) -> Self {
        self.vertex_shader = Some(Self::stage_info(
            vk::ShaderStageFlags::VERTEX,
            module,
            entry_point,
        ));
        self
    }

    /// Sets the fragment shader stage.
    pub fn fragment_shader(mut self, module: vk::ShaderModule, entry_point: &str) -> Self {
        self.fragment_shader = Some(Self::stage_info(
            vk::ShaderStageFlags::FRAGMENT,
            module,
            entry_point,
        ));
        self
    }

    /// Sets the geometry shader stage.
    pub fn geometry_shader(mut self, module: vk::ShaderModule, entry_point: &str) -> Self {
        self.geometry_shader = Some(Self::stage_info(
            vk::ShaderStageFlags::GEOMETRY,
            module,
            entry_point,
        ));
        self
    }

    /// Replaces the vertex input binding descriptions.
    pub fn vertex_bindings(mut self, bindings: Vec<vk::VertexInputBindingDescription>) -> Self {
        self.vertex_binding_descriptions = bindings;
        self
    }

    /// Replaces the vertex input attribute descriptions.
    pub fn vertex_attributes(
        mut self,
        attributes: Vec<vk::VertexInputAttributeDescription>,
    ) -> Self {
        self.vertex_attribute_descriptions = attributes;
        self
    }

    /// Sets the primitive topology.
    pub fn topology(mut self, topology: vk::PrimitiveTopology) -> Self {
        self.topology = topology;
        self
    }

    /// Enables or disables primitive restart for indexed draws.
    pub fn primitive_restart(mut self, enable: bool) -> Self {
        self.primitive_restart_enable = enable;
        self
    }

    /// Replaces all configured viewports with a single static viewport.
    pub fn viewport(mut self, viewport: vk::Viewport) -> Self {
        self.viewports = vec![viewport];
        self
    }

    /// Replaces all configured viewports.
    pub fn viewports(mut self, viewports: Vec<vk::Viewport>) -> Self {
        self.viewports = viewports;
        self
    }

    /// Replaces all configured scissors with a single static scissor rectangle.
    pub fn scissor(mut self, scissor: vk::Rect2D) -> Self {
        self.scissors = vec![scissor];
        self
    }

    /// Replaces all configured scissor rectangles.
    pub fn scissors(mut self, scissors: Vec<vk::Rect2D>) -> Self {
        self.scissors = scissors;
        self
    }

    /// Sets the polygon rasterization mode.
    pub fn polygon_mode(mut self, mode: vk::PolygonMode) -> Self {
        self.polygon_mode = mode;
        self
    }

    /// Sets which faces are culled.
    pub fn cull_mode(mut self, mode: vk::CullModeFlags) -> Self {
        self.cull_mode = mode;
        self
    }

    /// Sets the winding order that counts as front-facing.
    pub fn front_face(mut self, front_face: vk::FrontFace) -> Self {
        self.front_face = front_face;
        self
    }

    /// Enables or disables depth bias.
    ///
    /// The bias factors themselves are left at zero in the pipeline, so they are typically
    /// supplied at record time via `vk::DynamicState::DEPTH_BIAS`.
    pub fn depth_bias(mut self, enable: bool) -> Self {
        self.depth_bias_enable = enable;
        self
    }

    /// Sets the rasterized line width.
    pub fn line_width(mut self, width: f32) -> Self {
        self.line_width = width;
        self
    }

    /// Sets the rasterization sample count.
    pub fn sample_count(mut self, count: vk::SampleCountFlags) -> Self {
        self.sample_count = count;
        self
    }

    /// Enables or disables the depth test.
    pub fn depth_test(mut self, enable: bool) -> Self {
        self.depth_test_enable = enable;
        self
    }

    /// Enables or disables depth writes.
    pub fn depth_write(mut self, enable: bool) -> Self {
        self.depth_write_enable = enable;
        self
    }

    /// Sets the depth comparison operator.
    pub fn depth_compare_op(mut self, op: vk::CompareOp) -> Self {
        self.depth_compare_op = op;
        self
    }

    /// Enables or disables the stencil test.
    pub fn stencil_test(mut self, enable: bool) -> Self {
        self.stencil_test_enable = enable;
        self
    }

    /// Replaces the per-attachment color blend states.
    pub fn color_blend_attachments(
        mut self,
        attachments: Vec<vk::PipelineColorBlendAttachmentState>,
    ) -> Self {
        self.color_blend_attachments = attachments;
        self
    }

    /// Sets the RGBA blend constants.
    pub fn blend_constants(mut self, constants: [f32; 4]) -> Self {
        self.blend_constants = constants;
        self
    }

    /// Replaces the list of states supplied at command-recording time.
    pub fn dynamic_states(mut self, states: Vec<vk::DynamicState>) -> Self {
        self.dynamic_states = states;
        self
    }

    /// Targets subpass `subpass` of `render_pass`.
    ///
    /// Without a render pass, the pipeline is created for dynamic rendering using the
    /// configured attachment formats.
    pub fn render_pass(mut self, render_pass: vk::RenderPass, subpass: u32) -> Self {
        self.render_pass = Some(render_pass);
        self.subpass = subpass;
        self
    }

    /// Sets the color attachment formats used with dynamic rendering.
    pub fn color_attachment_formats(mut self, formats: Vec<vk::Format>) -> Self {
        self.color_attachment_formats = formats;
        self
    }

    /// Sets the depth attachment format used with dynamic rendering.
    pub fn depth_attachment_format(mut self, format: vk::Format) -> Self {
        self.depth_attachment_format = Some(format);
        self
    }

    /// Sets the stencil attachment format used with dynamic rendering.
    pub fn stencil_attachment_format(mut self, format: vk::Format) -> Self {
        self.stencil_attachment_format = Some(format);
        self
    }

    /// Creates a graphics pipeline plus a fresh [`PipelineLayout`] for it.
    ///
    /// The returned [`Pipeline`] owns the layout; release both with [`Pipeline::destroy`].
    ///
    /// When no viewports (or scissors) are configured, the matching state must be dynamic.
    /// `VIEWPORT` declares one dynamic viewport. `VIEWPORT_WITH_COUNT` defers the count to
    /// record time as well. Scissors work the same way.
    ///
    /// # Errors
    ///
    /// - [`PipelineError::MissingField`] if no vertex shader was set.
    /// - [`PipelineError::InvalidConfiguration`] if a shader entry point contains a NUL byte,
    ///   or if viewports/scissors are neither provided nor declared dynamic.
    /// - [`PipelineError::LayoutCreationFailed`] if the layout cannot be created.
    /// - [`PipelineError::CreationFailed`] if `vkCreateGraphicsPipelines` fails; the layout
    ///   created by this call is destroyed before returning.
    pub fn build_graphics(
        self,
        device: &Device,
        descriptor_set_layouts: &[vk::DescriptorSetLayout],
        push_constant_ranges: &[vk::PushConstantRange],
    ) -> Result<Pipeline, PipelineError> {
        let vertex_shader = self
            .vertex_shader
            .as_ref()
            .ok_or(PipelineError::MissingField("vertex_shader"))?;

        // Every configured stage, vertex first; unset optional stages are skipped.
        let stage_infos: Vec<&ShaderStageInfo> = std::iter::once(vertex_shader)
            .chain(self.fragment_shader.as_ref())
            .chain(self.geometry_shader.as_ref())
            .chain(self.tessellation_control_shader.as_ref())
            .chain(self.tessellation_evaluation_shader.as_ref())
            .collect();

        // `p_name` points into these CStrings, so they must stay alive past the create call.
        let entry_points = stage_infos
            .iter()
            .map(|info| {
                std::ffi::CString::new(info.entry_point.as_str()).map_err(|_| {
                    PipelineError::InvalidConfiguration(format!(
                        "shader entry point {:?} contains an interior NUL byte",
                        info.entry_point
                    ))
                })
            })
            .collect::<Result<Vec<_>, _>>()?;

        let shader_stages: Vec<_> = stage_infos
            .iter()
            .zip(&entry_points)
            .map(|(info, name)| vk::PipelineShaderStageCreateInfo {
                stage: info.stage,
                module: info.module,
                p_name: name.as_ptr(),
                ..Default::default()
            })
            .collect();

        let (viewport_count, p_viewports) = fixed_or_dynamic_array(
            &self.viewports,
            &self.dynamic_states,
            vk::DynamicState::VIEWPORT,
            vk::DynamicState::VIEWPORT_WITH_COUNT,
            "viewports",
        )?;
        let (scissor_count, p_scissors) = fixed_or_dynamic_array(
            &self.scissors,
            &self.dynamic_states,
            vk::DynamicState::SCISSOR,
            vk::DynamicState::SCISSOR_WITH_COUNT,
            "scissors",
        )?;

        let vertex_input_state = vk::PipelineVertexInputStateCreateInfo {
            vertex_binding_description_count: self.vertex_binding_descriptions.len() as u32,
            p_vertex_binding_descriptions: self.vertex_binding_descriptions.as_ptr(),
            vertex_attribute_description_count: self.vertex_attribute_descriptions.len() as u32,
            p_vertex_attribute_descriptions: self.vertex_attribute_descriptions.as_ptr(),
            ..Default::default()
        };
        let input_assembly_state = vk::PipelineInputAssemblyStateCreateInfo {
            topology: self.topology,
            primitive_restart_enable: bool32(self.primitive_restart_enable),
            ..Default::default()
        };
        let viewport_state = vk::PipelineViewportStateCreateInfo {
            viewport_count,
            p_viewports,
            scissor_count,
            p_scissors,
            ..Default::default()
        };
        let rasterization_state = vk::PipelineRasterizationStateCreateInfo {
            polygon_mode: self.polygon_mode,
            cull_mode: self.cull_mode,
            front_face: self.front_face,
            depth_bias_enable: bool32(self.depth_bias_enable),
            line_width: self.line_width,
            ..Default::default()
        };
        let multisample_state = vk::PipelineMultisampleStateCreateInfo {
            rasterization_samples: self.sample_count,
            ..Default::default()
        };
        let depth_stencil_state = vk::PipelineDepthStencilStateCreateInfo {
            depth_test_enable: bool32(self.depth_test_enable),
            depth_write_enable: bool32(self.depth_write_enable),
            depth_compare_op: self.depth_compare_op,
            stencil_test_enable: bool32(self.stencil_test_enable),
            ..Default::default()
        };
        let color_blend_state = vk::PipelineColorBlendStateCreateInfo {
            attachment_count: self.color_blend_attachments.len() as u32,
            p_attachments: self.color_blend_attachments.as_ptr(),
            blend_constants: self.blend_constants,
            ..Default::default()
        };
        let dynamic_state = if self.dynamic_states.is_empty() {
            None
        } else {
            Some(vk::PipelineDynamicStateCreateInfo {
                dynamic_state_count: self.dynamic_states.len() as u32,
                p_dynamic_states: self.dynamic_states.as_ptr(),
                ..Default::default()
            })
        };

        // Dynamic-rendering formats. Attach them whenever any format is configured, so that
        // depth/stencil-only pipelines (e.g. shadow passes) still declare their attachments.
        // Vulkan ignores this struct when a render pass is supplied.
        let needs_rendering_info = !self.color_attachment_formats.is_empty()
            || self.depth_attachment_format.is_some()
            || self.stencil_attachment_format.is_some();
        let rendering_info = needs_rendering_info.then(|| vk::PipelineRenderingCreateInfo {
            color_attachment_count: self.color_attachment_formats.len() as u32,
            p_color_attachment_formats: self.color_attachment_formats.as_ptr(),
            depth_attachment_format: self
                .depth_attachment_format
                .unwrap_or(vk::Format::UNDEFINED),
            stencil_attachment_format: self
                .stencil_attachment_format
                .unwrap_or(vk::Format::UNDEFINED),
            ..Default::default()
        });

        // Created only after every fallible validation step above, so those early returns
        // cannot leak it.
        let layout = PipelineLayout::new(device, descriptor_set_layouts, push_constant_ranges)?;

        let pipeline_info = vk::GraphicsPipelineCreateInfo {
            p_next: rendering_info
                .as_ref()
                .map(|info| info as *const _ as *const std::ffi::c_void)
                .unwrap_or(std::ptr::null()),
            stage_count: shader_stages.len() as u32,
            p_stages: shader_stages.as_ptr(),
            p_vertex_input_state: &vertex_input_state,
            p_input_assembly_state: &input_assembly_state,
            p_viewport_state: &viewport_state,
            p_rasterization_state: &rasterization_state,
            p_multisample_state: &multisample_state,
            p_depth_stencil_state: &depth_stencil_state,
            p_color_blend_state: &color_blend_state,
            p_dynamic_state: dynamic_state
                .as_ref()
                .map(|s| s as *const _)
                .unwrap_or(std::ptr::null()),
            layout: layout.handle(),
            render_pass: self.render_pass.unwrap_or(vk::RenderPass::null()),
            subpass: self.subpass,
            ..Default::default()
        };

        // SAFETY: every pointer reachable from `pipeline_info` refers to locals or to `self`'s
        // vectors, all of which live until this function returns.
        let created = unsafe {
            device.handle().create_graphics_pipelines(
                vk::PipelineCache::null(),
                &[pipeline_info],
                None,
            )
        };
        match created {
            // One create-info in, one handle out.
            Ok(pipelines) => Ok(Pipeline::from_handle(
                pipelines[0],
                layout,
                vk::PipelineBindPoint::GRAPHICS,
            )),
            Err((partial, result)) => {
                // Release everything this call created so a failed build leaks nothing.
                for handle in partial {
                    if handle != vk::Pipeline::null() {
                        // SAFETY: `handle` was just returned by the driver and is unused.
                        unsafe { device.handle().destroy_pipeline(handle, None) };
                    }
                }
                layout.destroy(device);
                Err(PipelineError::CreationFailed(result))
            }
        }
    }
}

/// Converts a Rust `bool` into a Vulkan `Bool32`.
fn bool32(value: bool) -> vk::Bool32 {
    if value {
        vk::TRUE
    } else {
        vk::FALSE
    }
}

/// Chooses the `(count, pointer)` pair for viewport-like state (viewports or scissors).
///
/// - If `with_count_state` is dynamic, the count is reported as 0 and both count and contents
///   come from command recording (the Vulkan spec requires a zero count in that case).
/// - Otherwise, explicitly configured `items` are used as-is.
/// - With no items, `fixed_state` must be dynamic; one entry is declared and its value is
///   supplied at record time.
///
/// # Errors
///
/// Returns [`PipelineError::InvalidConfiguration`] when there are no items and neither dynamic
/// state is enabled, because Vulkan would otherwise read the state through a null pointer.
fn fixed_or_dynamic_array<T>(
    items: &[T],
    dynamic_states: &[vk::DynamicState],
    fixed_state: vk::DynamicState,
    with_count_state: vk::DynamicState,
    name: &str,
) -> Result<(u32, *const T), PipelineError> {
    if dynamic_states.contains(&with_count_state) {
        return Ok((0, std::ptr::null()));
    }
    if !items.is_empty() {
        return Ok((items.len() as u32, items.as_ptr()));
    }
    if dynamic_states.contains(&fixed_state) {
        return Ok((1, std::ptr::null()));
    }
    Err(PipelineError::InvalidConfiguration(format!(
        "no {} configured and {:?} is not a dynamic state",
        name, fixed_state
    )))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Device, DeviceCreateInfo, Instance, InstanceCreateInfo, QueueCreateInfo};

    /// Builds an instance (validation off) and a device with one queue from the first
    /// graphics-capable queue family of the first enumerated physical device.
    ///
    /// Panics on any failure, including when no physical device is available.
    fn create_test_device() -> (Instance, Device) {
        let instance = Instance::new(InstanceCreateInfo {
            enable_validation: false,
            ..Default::default()
        })
        .unwrap();

        let gpus = instance.enumerate_physical_devices().unwrap();
        let gpu = gpus[0];

        // SAFETY: `gpu` was enumerated from `instance`, which is still alive.
        let family_properties =
            unsafe { instance.get_physical_device_queue_family_properties(gpu) };
        let graphics_family = family_properties
            .iter()
            .position(|family| family.queue_flags.contains(vk::QueueFlags::GRAPHICS))
            .map(|index| index as u32)
            .unwrap();

        let device_info = DeviceCreateInfo {
            queue_create_infos: vec![QueueCreateInfo {
                queue_family_index: graphics_family,
                queue_count: 1,
                queue_priorities: vec![1.0],
            }],
            ..Default::default()
        };
        let device = Device::new(&instance, gpu, device_info).unwrap();
        (instance, device)
    }

    #[test]
    fn test_pipeline_layout_creation() {
        let (_instance, device) = create_test_device();
        let empty_layout = PipelineLayout::new(&device, &[], &[]).unwrap();
        assert_ne!(empty_layout.handle(), vk::PipelineLayout::null());
        empty_layout.destroy(&device);
    }

    #[test]
    fn test_pipeline_builder_defaults() {
        let defaults = PipelineBuilder::new();
        assert_eq!(defaults.topology, vk::PrimitiveTopology::TRIANGLE_LIST);
        assert_eq!(defaults.polygon_mode, vk::PolygonMode::FILL);
        assert_eq!(defaults.cull_mode, vk::CullModeFlags::BACK);
        assert_eq!(defaults.front_face, vk::FrontFace::COUNTER_CLOCKWISE);
        assert!(defaults.depth_test_enable);
        assert!(defaults.depth_write_enable);
    }

    #[test]
    fn test_pipeline_builder_fluent_api() {
        let configured = PipelineBuilder::new()
            .topology(vk::PrimitiveTopology::LINE_LIST)
            .cull_mode(vk::CullModeFlags::NONE)
            .depth_test(false);
        assert_eq!(configured.topology, vk::PrimitiveTopology::LINE_LIST);
        assert_eq!(configured.cull_mode, vk::CullModeFlags::NONE);
        assert!(!configured.depth_test_enable);
    }
}