rafx_api/extra/
indirect.rs

1//! Contents of this file are to help emulate gl_InstanceIndex on DX12. Metal follows the same
2//! convention as vulkan.
3
4use crate::{
5    RafxBuffer, RafxCommandBuffer, RafxDeviceContext, RafxDrawIndexedIndirectCommand,
6    RafxDrawIndirectCommand, RafxResult, RafxRootSignature, RafxShaderStageFlags,
7};
8
9#[cfg(feature = "rafx-dx12")]
10use windows::Win32::Graphics::Direct3D12 as d3d12;
11
/// Creates a D3D12 command signature for `ExecuteIndirect` where every command record is one
/// 32-bit root constant followed by the draw arguments.
///
/// A command signature that sets root arguments must be compatible with the root signature that
/// will be bound when the indirect draws execute, so that root signature is passed in here.
///
/// * `device` - device used to create the command signature
/// * `root_signature` - root signature the command signature must be compatible with
/// * `indexed` - `true` for `D3D12_DRAW_INDEXED_ARGUMENTS`, `false` for `D3D12_DRAW_ARGUMENTS`
///
/// # Errors
/// Returns an error if `CreateCommandSignature` fails or does not produce a signature.
#[cfg(feature = "rafx-dx12")]
fn create_indirect_draw_with_push_constant_command_signature(
    device: &d3d12::ID3D12Device,
    root_signature: &d3d12::ID3D12RootSignature,
    indexed: bool,
) -> RafxResult<d3d12::ID3D12CommandSignature> {
    // Each record is [u32 root constant][draw arguments], so the stride is the size of the draw
    // arguments plus one 32-bit value.
    let push_constant_size = std::mem::size_of::<u32>();
    let (draw_arg_type, draw_args_size) = if indexed {
        (
            d3d12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW_INDEXED,
            std::mem::size_of::<RafxDrawIndexedIndirectCommand>(),
        )
    } else {
        (
            d3d12::D3D12_INDIRECT_ARGUMENT_TYPE_DRAW,
            std::mem::size_of::<RafxDrawIndirectCommand>(),
        )
    };

    let mut draw_arg = d3d12::D3D12_INDIRECT_ARGUMENT_DESC::default();
    draw_arg.Type = draw_arg_type;

    // Sets a single 32-bit root constant at offset 0 of root parameter 0.
    // NOTE(review): parameter index 0 assumes push constants are the first root parameter of the
    // root signature - confirm against the DX12 root signature construction.
    let mut root_constant_arg = d3d12::D3D12_INDIRECT_ARGUMENT_DESC::default();
    root_constant_arg.Type = d3d12::D3D12_INDIRECT_ARGUMENT_TYPE_CONSTANT;
    root_constant_arg.Anonymous.Constant.RootParameterIndex = 0;
    root_constant_arg.Anonymous.Constant.DestOffsetIn32BitValues = 0;
    root_constant_arg.Anonymous.Constant.Num32BitValuesToSet = 1;

    // Argument order defines the record layout: root constant first, then the draw arguments.
    // `sig` stores a raw pointer into `args`, so `args` must stay alive until after the create call.
    let args = [root_constant_arg, draw_arg];

    let mut sig = d3d12::D3D12_COMMAND_SIGNATURE_DESC::default();
    sig.ByteStride = (push_constant_size + draw_args_size) as u32;
    sig.NumArgumentDescs = args.len() as u32;
    sig.pArgumentDescs = args.as_ptr();

    let mut result: Option<d3d12::ID3D12CommandSignature> = None;

    // SAFETY: `sig` and the `args` array it points to are live for the whole call, and `result`
    // is a valid out-parameter.
    unsafe {
        device.CreateCommandSignature(&sig, root_signature, &mut result)?;
    }

    result.ok_or_else(|| {
        crate::RafxError::StringError(
            "CreateCommandSignature succeeded but returned no command signature".to_string(),
        )
    })
}
49
50// Corresponds 1:1 with VkDrawIndirectCommand, MTLDrawPrimitivesIndirectArguments,
51// D3D12_DRAW_ARGUMENTS, but adds a push constant for DX12
52pub struct RafxDrawIndirectCommandWithPushConstant {
53    pub push_constant: u32,
54    pub command: RafxDrawIndirectCommand,
55}
56
57// Corresponds 1:1 with VkDrawIndexedIndirectCommand, MTLDrawIndexedPrimitivesIndirectArguments,
58// D3D12_DRAW_INDEXED_ARGUMENTS, but adds a push constant for DX12
59pub struct RafxDrawIndexedIndirectCommandWithPushConstant {
60    pub push_constant: u32,
61    pub command: RafxDrawIndexedIndirectCommand,
62}
63
64// Size of an indirect draw command compatible with the given device context
65pub fn indirect_command_size(_device_context: &RafxDeviceContext) -> u64 {
66    #[cfg(feature = "rafx-dx12")]
67    if _device_context.is_dx12() {
68        return std::mem::size_of::<RafxDrawIndirectCommand>() as u64 + 4;
69    }
70
71    std::mem::size_of::<RafxDrawIndirectCommand>() as u64
72}
73
74// Size of an indexed indirect draw command compatible with the given device context
75pub fn indexed_indirect_command_size(_device_context: &RafxDeviceContext) -> u64 {
76    #[cfg(feature = "rafx-dx12")]
77    if _device_context.is_dx12() {
78        return std::mem::size_of::<RafxDrawIndexedIndirectCommand>() as u64 + 4;
79    }
80
81    std::mem::size_of::<RafxDrawIndexedIndirectCommand>() as u64
82}
83
84//TODO: Support a non-indexed version of RafxIndexedIndirectCommandSignature and
85// RafxIndexedIndirectCommandEncoder
86
87/// A helper object for doing indirect draw on DX12/Metal/Vulkan in a compatible way. We supply a
88/// push constant on DX12 only to emulate gl_InstanceIndex on DX12. This helper object is mostly
89/// a no-op for vulkan/metal.
90#[derive(Clone)]
91pub struct RafxIndexedIndirectCommandSignature {
92    _root_signature: RafxRootSignature,
93    #[cfg(feature = "rafx-dx12")]
94    dx12_indirect_command_signature: Option<d3d12::ID3D12CommandSignature>,
95}
96
97impl RafxIndexedIndirectCommandSignature {
98    pub fn new(
99        root_signature: &RafxRootSignature,
100        _shader_flags: RafxShaderStageFlags,
101    ) -> RafxResult<Self> {
102        #[cfg(feature = "rafx-dx12")]
103        if let Some(root_signature_dx12) = root_signature.dx12_root_signature() {
104            let descriptor = root_signature_dx12.find_push_constant_descriptor(_shader_flags).ok_or_else(|| crate::RafxError::StringError(format!(
105                "Tried to create a RafxIndexedIndirectCommandSignature for shader flags {:?} but no push constants were found",
106                _shader_flags
107            )))?;
108
109            let command_signature = create_indirect_draw_with_push_constant_command_signature(
110                root_signature_dx12.device_context().d3d12_device(),
111                root_signature_dx12.dx12_root_signature(),
112                true,
113            )?;
114
115            return Ok(RafxIndexedIndirectCommandSignature {
116                _root_signature: root_signature.clone(),
117                dx12_indirect_command_signature: Some(command_signature),
118            });
119        }
120
121        Ok(RafxIndexedIndirectCommandSignature {
122            _root_signature: root_signature.clone(),
123            #[cfg(feature = "rafx-dx12")]
124            dx12_indirect_command_signature: None,
125        })
126    }
127
128    // equivalent to cmd_draw_indexed_indirect
129    pub fn draw_indexed_indirect(
130        &self,
131        command_buffer: &RafxCommandBuffer,
132        indirect_buffer: &RafxBuffer,
133        indirect_buffer_offset_in_bytes: u32,
134        draw_count: u32,
135    ) -> RafxResult<()> {
136        // Special DX12 path
137        #[cfg(feature = "rafx-dx12")]
138        if let Some(dx12_command_buffer) = command_buffer.dx12_command_buffer() {
139            let command_list = dx12_command_buffer.dx12_graphics_command_list();
140            unsafe {
141                let command_signature = self.dx12_indirect_command_signature.as_ref().unwrap();
142                assert!(
143                    indirect_buffer.buffer_def().size as u32 - indirect_buffer_offset_in_bytes
144                        >= 24 * draw_count
145                );
146
147                command_list.ExecuteIndirect(
148                    command_signature,
149                    draw_count,
150                    indirect_buffer.dx12_buffer().unwrap().dx12_resource(),
151                    indirect_buffer_offset_in_bytes as u64,
152                    None,
153                    0,
154                );
155            }
156
157            return Ok(());
158        }
159
160        // Path for non-DX12
161        command_buffer.cmd_draw_indexed_indirect(
162            indirect_buffer,
163            indirect_buffer_offset_in_bytes,
164            draw_count,
165        )
166    }
167}
168
169/// Helper object for writing indirect draws into a buffer. Abstracts over DX12 requiring an
170/// extra 4 bytes to set a push constant
171pub struct RafxIndexedIndirectCommandEncoder<'a> {
172    // We keep a ref to the buffer because we write into mapped memory behind a cached pointer
173    _buffer: &'a RafxBuffer,
174    #[cfg(feature = "rafx-dx12")]
175    is_dx12: bool,
176    mapped_memory: *mut u8,
177    command_count: usize,
178}
179
180impl<'a> RafxIndexedIndirectCommandEncoder<'a> {
181    pub fn new(buffer: &'a RafxBuffer) -> Self {
182        #[cfg(not(feature = "rafx-dx12"))]
183        let is_dx12 = false;
184
185        #[cfg(feature = "rafx-dx12")]
186        let is_dx12 = buffer.dx12_buffer().is_some();
187
188        let command_size = if is_dx12 {
189            std::mem::size_of::<RafxDrawIndexedIndirectCommand>() + 4
190        } else {
191            std::mem::size_of::<RafxDrawIndexedIndirectCommand>()
192        };
193
194        let command_count = (buffer.buffer_def().size as usize / command_size) as usize;
195        RafxIndexedIndirectCommandEncoder {
196            _buffer: buffer,
197            #[cfg(feature = "rafx-dx12")]
198            is_dx12,
199            mapped_memory: buffer.mapped_memory().unwrap(),
200            command_count,
201        }
202    }
203
204    pub fn set_command(
205        &self,
206        index: usize,
207        command: RafxDrawIndexedIndirectCommand,
208    ) {
209        assert!(index < self.command_count);
210        unsafe {
211            #[cfg(feature = "rafx-dx12")]
212            if self.is_dx12 {
213                let mut ptr =
214                    self.mapped_memory as *mut RafxDrawIndexedIndirectCommandWithPushConstant;
215                let push_constant = command.first_instance;
216                *ptr.add(index) = RafxDrawIndexedIndirectCommandWithPushConstant {
217                    command,
218                    push_constant,
219                };
220
221                return;
222            }
223
224            // If we don't have the special dx12 case, use the default command type
225            let ptr = self.mapped_memory as *mut RafxDrawIndexedIndirectCommand;
226            *ptr.add(index) = command;
227        }
228    }
229}