use super::d3d;
use super::d3d12;
use super::dxgi;
use crate::dx12::RafxDeviceContextDx12;
use crate::{
RafxComputePipelineDef, RafxGraphicsPipelineDef, RafxPipelineType, RafxResult,
RafxRootSignature, RafxShaderStageFlags, RafxVertexAttributeRate,
MAX_RENDER_TARGET_ATTACHMENTS,
};
use std::ffi::CString;
#[derive(Debug)]
pub struct RafxPipelineDx12 {
pipeline_type: RafxPipelineType,
root_signature: RafxRootSignature,
pipeline: d3d12::ID3D12PipelineState,
topology: d3d::D3D_PRIMITIVE_TOPOLOGY,
vertex_buffer_strides: [u32; crate::MAX_VERTEX_INPUT_BINDINGS],
}
impl RafxPipelineDx12 {
pub fn pipeline_type(&self) -> RafxPipelineType {
self.pipeline_type
}
pub fn root_signature(&self) -> &RafxRootSignature {
&self.root_signature
}
pub fn pipeline(&self) -> &d3d12::ID3D12PipelineState {
&self.pipeline
}
pub fn topology(&self) -> d3d::D3D_PRIMITIVE_TOPOLOGY {
self.topology
}
pub fn vertex_buffer_strides(&self) -> &[u32; crate::MAX_VERTEX_INPUT_BINDINGS] {
&self.vertex_buffer_strides
}
pub fn set_debug_name(
&self,
name: impl AsRef<str>,
) {
if self
.root_signature
.dx12_root_signature()
.unwrap()
.inner
.device_context
.device_info()
.debug_names_enabled
{
unsafe {
let name: &str = name.as_ref();
let utf16: Vec<_> = name.encode_utf16().chain(std::iter::once(0)).collect();
self.pipeline
.SetName(windows::core::PCWSTR::from_raw(utf16.as_ptr()))
.unwrap();
}
}
}
pub fn new_graphics_pipeline(
device_context: &RafxDeviceContextDx12,
pipeline_def: &RafxGraphicsPipelineDef,
) -> RafxResult<Self> {
let mut vs_bytecode = None;
let mut ps_bytecode = None;
let mut ds_bytecode = None;
let mut hs_bytecode = None;
let mut gs_bytecode = None;
for stage in pipeline_def.shader.dx12_shader().unwrap().stages() {
let module = stage.shader_module.dx12_shader_module().unwrap();
if stage
.reflection
.shader_stage
.intersects(RafxShaderStageFlags::VERTEX)
{
vs_bytecode = Some(
module.get_or_compile_bytecode(&stage.reflection.entry_point_name, "vs_6_0")?,
);
}
if stage
.reflection
.shader_stage
.intersects(RafxShaderStageFlags::FRAGMENT)
{
ps_bytecode = Some(
module.get_or_compile_bytecode(&stage.reflection.entry_point_name, "ps_6_0")?,
);
}
if stage
.reflection
.shader_stage
.intersects(RafxShaderStageFlags::TESSELLATION_EVALUATION)
{
ds_bytecode = Some(
module.get_or_compile_bytecode(&stage.reflection.entry_point_name, "ds_6_0")?,
);
}
if stage
.reflection
.shader_stage
.intersects(RafxShaderStageFlags::TESSELLATION_CONTROL)
{
hs_bytecode = Some(
module.get_or_compile_bytecode(&stage.reflection.entry_point_name, "hs_6_0")?,
);
}
if stage
.reflection
.shader_stage
.intersects(RafxShaderStageFlags::GEOMETRY)
{
gs_bytecode = Some(
module.get_or_compile_bytecode(&stage.reflection.entry_point_name, "gs_6_0")?,
);
}
}
let stream_out_desc = d3d12::D3D12_STREAM_OUTPUT_DESC::default();
let depth_stencil_desc = if pipeline_def.depth_stencil_format.is_some() {
super::internal::conversions::depth_state_depth_stencil_desc(pipeline_def.depth_state)
} else {
d3d12::D3D12_DEPTH_STENCIL_DESC::default()
};
let mut input_elements = Vec::<d3d12::D3D12_INPUT_ELEMENT_DESC>::with_capacity(16);
let mut semantic_name_cstrings = Vec::default();
let mut vertex_buffer_strides = [0; crate::MAX_VERTEX_INPUT_BINDINGS];
for (i, vertex_buffer_binding) in pipeline_def.vertex_layout.buffers.iter().enumerate() {
vertex_buffer_strides[i] = vertex_buffer_binding.stride;
}
for vertex_attribute in &pipeline_def.vertex_layout.attributes {
let mut ending_digit_count = 0;
for char in vertex_attribute.hlsl_semantic.chars().rev() {
if char.is_ascii_digit() {
ending_digit_count += 1;
}
}
let (semantic_name, semantic_index) = if ending_digit_count > 0 {
let name_chars = &vertex_attribute.hlsl_semantic
[0..(vertex_attribute.hlsl_semantic.len() - ending_digit_count)];
let number_chars =
&vertex_attribute.hlsl_semantic[(vertex_attribute.hlsl_semantic.len()
- ending_digit_count)
..vertex_attribute.hlsl_semantic.len()];
let ending_number = number_chars.parse::<u32>().unwrap();
(CString::new(name_chars).unwrap(), ending_number)
} else {
(CString::new(&*vertex_attribute.hlsl_semantic).unwrap(), 0)
};
let input_format = vertex_attribute.format.into();
log::trace!(
"pushing input element {:?}/{} format {:?}={:?}",
semantic_name,
semantic_index,
vertex_attribute.format,
input_format
);
let semantic_name_ptr = semantic_name.as_ptr();
semantic_name_cstrings.push(semantic_name);
let mut input_element = d3d12::D3D12_INPUT_ELEMENT_DESC {
SemanticName: windows::core::PCSTR::from_raw(semantic_name_ptr as *const u8),
SemanticIndex: semantic_index,
Format: input_format,
InputSlot: vertex_attribute.buffer_index,
AlignedByteOffset: vertex_attribute.byte_offset,
InputSlotClass: d3d12::D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA,
InstanceDataStepRate: 0,
};
match pipeline_def.vertex_layout.buffers[vertex_attribute.buffer_index as usize].rate {
RafxVertexAttributeRate::Vertex => {
input_element.InputSlotClass =
d3d12::D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA;
input_element.InstanceDataStepRate = 0;
}
RafxVertexAttributeRate::Instance => {
input_element.InputSlotClass =
d3d12::D3D12_INPUT_CLASSIFICATION_PER_INSTANCE_DATA;
input_element.InstanceDataStepRate = 1;
}
}
input_elements.push(input_element);
}
let input_layout_desc = if input_elements.is_empty() {
d3d12::D3D12_INPUT_LAYOUT_DESC::default()
} else {
d3d12::D3D12_INPUT_LAYOUT_DESC {
pInputElementDescs: input_elements.as_ptr(),
NumElements: input_elements.len() as u32,
}
};
let sample_desc = dxgi::Common::DXGI_SAMPLE_DESC {
Count: pipeline_def.sample_count.as_u32(),
Quality: 0, };
let cached_pipeline_state = d3d12::D3D12_CACHED_PIPELINE_STATE::default();
let render_target_count = pipeline_def
.color_formats
.len()
.min(MAX_RENDER_TARGET_ATTACHMENTS);
let mut rtv_formats = [dxgi::Common::DXGI_FORMAT_UNKNOWN; MAX_RENDER_TARGET_ATTACHMENTS];
for i in 0..render_target_count {
rtv_formats[i] = pipeline_def.color_formats[i].into();
}
let pipeline_state_desc = d3d12::D3D12_GRAPHICS_PIPELINE_STATE_DESC {
pRootSignature: ::windows::core::ManuallyDrop::new(
&pipeline_def
.root_signature
.dx12_root_signature()
.unwrap()
.dx12_root_signature()
.clone(),
),
VS: vs_bytecode.map(|x| *x.bytecode()).unwrap_or_default(),
PS: ps_bytecode.map(|x| *x.bytecode()).unwrap_or_default(),
DS: ds_bytecode.map(|x| *x.bytecode()).unwrap_or_default(),
GS: gs_bytecode.map(|x| *x.bytecode()).unwrap_or_default(),
HS: hs_bytecode.map(|x| *x.bytecode()).unwrap_or_default(),
StreamOutput: stream_out_desc,
BlendState: super::internal::conversions::blend_state_blend_state_desc(
pipeline_def.blend_state,
render_target_count,
),
SampleMask: u32::MAX,
RasterizerState: super::internal::conversions::rasterizer_state_rasterizer_desc(
pipeline_def.rasterizer_state,
),
DepthStencilState: depth_stencil_desc, InputLayout: input_layout_desc,
IBStripCutValue: d3d12::D3D12_INDEX_BUFFER_STRIP_CUT_VALUE_DISABLED,
PrimitiveTopologyType: pipeline_def.primitive_topology.into(),
NumRenderTargets: render_target_count as u32,
RTVFormats: rtv_formats,
DSVFormat: pipeline_def
.depth_stencil_format
.map(|x| x.into())
.unwrap_or(dxgi::Common::DXGI_FORMAT_UNKNOWN),
SampleDesc: sample_desc,
CachedPSO: cached_pipeline_state,
Flags: d3d12::D3D12_PIPELINE_STATE_FLAG_NONE,
NodeMask: 0,
};
let pipeline: d3d12::ID3D12PipelineState = unsafe {
device_context
.d3d12_device()
.CreateGraphicsPipelineState(&pipeline_state_desc)?
};
let topology = pipeline_def.primitive_topology.into();
let pipeline = RafxPipelineDx12 {
root_signature: pipeline_def.root_signature.clone(),
pipeline_type: pipeline_def.root_signature.pipeline_type(),
pipeline,
topology,
vertex_buffer_strides,
};
if let Some(debug_name) = pipeline_def.debug_name {
pipeline.set_debug_name(debug_name)
}
Ok(pipeline)
}
pub fn new_compute_pipeline(
device_context: &RafxDeviceContextDx12,
pipeline_def: &RafxComputePipelineDef,
) -> RafxResult<Self> {
let mut cs_bytecode = None;
for stage in pipeline_def.shader.dx12_shader().unwrap().stages() {
let module = stage.shader_module.dx12_shader_module().unwrap();
if stage
.reflection
.shader_stage
.intersects(RafxShaderStageFlags::COMPUTE)
{
assert!(cs_bytecode.is_none());
cs_bytecode = Some(
module.get_or_compile_bytecode(&stage.reflection.entry_point_name, "cs_6_0")?,
);
} else {
Err("Tried to create compute pipeline with a non-compute shader stage specified")?;
}
}
let cached_pipeline_state = d3d12::D3D12_CACHED_PIPELINE_STATE::default();
log::info!(
"creating pipeline with root descriptor {:?}",
pipeline_def
.root_signature
.dx12_root_signature()
.unwrap()
.dx12_root_signature()
);
let pipeline_state_desc = d3d12::D3D12_COMPUTE_PIPELINE_STATE_DESC {
pRootSignature: ::windows::core::ManuallyDrop::new(
&pipeline_def
.root_signature
.dx12_root_signature()
.unwrap()
.dx12_root_signature()
.clone(),
),
CS: cs_bytecode.map(|x| *x.bytecode()).unwrap_or_default(),
CachedPSO: cached_pipeline_state,
Flags: d3d12::D3D12_PIPELINE_STATE_FLAG_NONE,
NodeMask: 0,
};
let pipeline: d3d12::ID3D12PipelineState = unsafe {
device_context
.d3d12_device()
.CreateComputePipelineState(&pipeline_state_desc)?
};
let pipeline = RafxPipelineDx12 {
root_signature: pipeline_def.root_signature.clone(),
pipeline_type: pipeline_def.root_signature.pipeline_type(),
pipeline,
topology: d3d::D3D_PRIMITIVE_TOPOLOGY_UNDEFINED,
vertex_buffer_strides: [0; crate::MAX_VERTEX_INPUT_BINDINGS],
};
if let Some(debug_name) = pipeline_def.debug_name {
pipeline.set_debug_name(debug_name)
}
Ok(pipeline)
}
}