1use super::*;
4use crate::*; use std::ptr;
6
7pub struct CommandBuilder {
13 context: ComputeContext,
14 pipeline: Pipeline,
15 command_buffer: VkCommandBuffer,
16 descriptor_set: Option<VkDescriptorSet>,
17 bindings: Vec<(u32, Buffer)>,
18 push_constants: Vec<u8>,
19 workgroups: (u32, u32, u32),
20}
21
22impl ComputeContext {
23 pub fn dispatch(&self, pipeline: &Pipeline) -> CommandBuilder {
25 CommandBuilder {
26 context: self.clone(),
27 pipeline: Pipeline {
28 context: pipeline.context.clone(),
29 pipeline: pipeline.pipeline,
30 layout: pipeline.layout,
31 descriptor_set_layout: pipeline.descriptor_set_layout,
32 },
33 command_buffer: VkCommandBuffer::NULL,
34 descriptor_set: None,
35 bindings: Vec::new(),
36 push_constants: Vec::new(),
37 workgroups: (1, 1, 1),
38 }
39 }
40}
41
42impl CommandBuilder {
43 pub fn bind_buffer(mut self, binding: u32, buffer: &Buffer) -> Self {
45 self.bindings.push((binding, Buffer {
46 context: buffer.context.clone(),
47 buffer: buffer.buffer,
48 memory: buffer.memory,
49 size: buffer.size,
50 usage: buffer.usage,
51 _marker: std::marker::PhantomData,
52 }));
53 self
54 }
55
56 pub fn push_constants<T: Copy>(mut self, data: &T) -> Self {
58 let bytes = unsafe {
59 std::slice::from_raw_parts(
60 data as *const T as *const u8,
61 std::mem::size_of::<T>(),
62 )
63 };
64 self.push_constants = bytes.to_vec();
65 self
66 }
67
68 pub fn workgroups(mut self, x: u32, y: u32, z: u32) -> Self {
70 self.workgroups = (x, y, z);
71 self
72 }
73
74 pub fn execute(mut self) -> Result<()> {
76 unsafe {
77 self.context.with_inner(|inner| {
78 let alloc_info = VkCommandBufferAllocateInfo {
80 sType: VkStructureType::CommandBufferAllocateInfo,
81 pNext: ptr::null(),
82 commandPool: inner.command_pool,
83 level: VkCommandBufferLevel::Primary,
84 commandBufferCount: 1,
85 };
86
87 vkAllocateCommandBuffers(inner.device, &alloc_info, &mut self.command_buffer);
88
89 let begin_info = VkCommandBufferBeginInfo {
91 sType: VkStructureType::CommandBufferBeginInfo,
92 pNext: ptr::null(),
93 flags: VkCommandBufferUsageFlags::ONE_TIME_SUBMIT,
94 pInheritanceInfo: ptr::null(),
95 };
96
97 let result = vkBeginCommandBuffer(self.command_buffer, &begin_info);
98 if result != VkResult::Success {
99 return Err(KronosError::from(result));
100 }
101
102 if !self.bindings.is_empty() {
104 let alloc_info = VkDescriptorSetAllocateInfo {
106 sType: VkStructureType::DescriptorSetAllocateInfo,
107 pNext: ptr::null(),
108 descriptorPool: inner.descriptor_pool,
109 descriptorSetCount: 1,
110 pSetLayouts: &self.pipeline.descriptor_set_layout,
111 };
112
113 let mut descriptor_set = VkDescriptorSet::NULL;
114 let result = vkAllocateDescriptorSets(inner.device, &alloc_info, &mut descriptor_set);
115 if result != VkResult::Success {
116 return Err(KronosError::from(result));
117 }
118
119 self.descriptor_set = Some(descriptor_set);
120
121 let buffer_infos: Vec<VkDescriptorBufferInfo> = self.bindings.iter().map(|(_, buffer)| {
123 VkDescriptorBufferInfo {
124 buffer: buffer.buffer,
125 offset: 0,
126 range: buffer.size as VkDeviceSize,
127 }
128 }).collect();
129
130 let writes: Vec<VkWriteDescriptorSet> = self.bindings.iter().enumerate().map(|(i, (binding, _))| {
131 VkWriteDescriptorSet {
132 sType: VkStructureType::WriteDescriptorSet,
133 pNext: ptr::null(),
134 dstSet: descriptor_set,
135 dstBinding: *binding,
136 dstArrayElement: 0,
137 descriptorCount: 1,
138 descriptorType: VkDescriptorType::StorageBuffer,
139 pImageInfo: ptr::null(),
140 pBufferInfo: &buffer_infos[i],
141 pTexelBufferView: ptr::null(),
142 }
143 }).collect();
144
145 vkUpdateDescriptorSets(inner.device, writes.len() as u32, writes.as_ptr(), 0, ptr::null());
146 }
147
148 let barriers: Vec<VkBufferMemoryBarrier> = self.bindings.iter().map(|(_, buffer)| {
151 VkBufferMemoryBarrier {
152 sType: VkStructureType::BufferMemoryBarrier,
153 pNext: ptr::null(),
154 srcAccessMask: VkAccessFlags::TRANSFER_WRITE,
155 dstAccessMask: VkAccessFlags::SHADER_READ | VkAccessFlags::SHADER_WRITE,
156 srcQueueFamilyIndex: VK_QUEUE_FAMILY_IGNORED,
157 dstQueueFamilyIndex: VK_QUEUE_FAMILY_IGNORED,
158 buffer: buffer.buffer,
159 offset: 0,
160 size: buffer.size as VkDeviceSize,
161 }
162 }).collect();
163
164 if !barriers.is_empty() {
165 vkCmdPipelineBarrier(
166 self.command_buffer,
167 VkPipelineStageFlags::TOP_OF_PIPE,
168 VkPipelineStageFlags::COMPUTE_SHADER,
169 VkDependencyFlags::empty(),
170 0,
171 ptr::null(),
172 barriers.len() as u32,
173 barriers.as_ptr(),
174 0,
175 ptr::null(),
176 );
177 }
178
179 vkCmdBindPipeline(self.command_buffer, VkPipelineBindPoint::Compute, self.pipeline.pipeline);
181
182 if let Some(descriptor_set) = self.descriptor_set {
184 vkCmdBindDescriptorSets(
185 self.command_buffer,
186 VkPipelineBindPoint::Compute,
187 self.pipeline.layout,
188 0,
189 1,
190 &descriptor_set,
191 0,
192 ptr::null(),
193 );
194 }
195
196 if !self.push_constants.is_empty() {
198 vkCmdPushConstants(
199 self.command_buffer,
200 self.pipeline.layout,
201 VkShaderStageFlags::COMPUTE,
202 0,
203 self.push_constants.len() as u32,
204 self.push_constants.as_ptr() as *const _,
205 );
206 }
207
208 vkCmdDispatch(self.command_buffer, self.workgroups.0, self.workgroups.1, self.workgroups.2);
210
211 let result = vkEndCommandBuffer(self.command_buffer);
213 if result != VkResult::Success {
214 return Err(KronosError::from(result));
215 }
216
217 let submit_info = VkSubmitInfo {
219 sType: VkStructureType::SubmitInfo,
220 pNext: ptr::null(),
221 waitSemaphoreCount: 0,
222 pWaitSemaphores: ptr::null(),
223 pWaitDstStageMask: ptr::null(),
224 commandBufferCount: 1,
225 pCommandBuffers: &self.command_buffer,
226 signalSemaphoreCount: 0,
227 pSignalSemaphores: ptr::null(),
228 };
229
230 let result = vkQueueSubmit(inner.queue, 1, &submit_info, VkFence::NULL);
231 if result != VkResult::Success {
232 return Err(KronosError::CommandExecutionFailed(
233 format!("vkQueueSubmit failed: {:?}", result)
234 ));
235 }
236
237 vkQueueWaitIdle(inner.queue);
239
240 vkFreeCommandBuffers(inner.device, inner.command_pool, 1, &self.command_buffer);
242
243 if let Some(descriptor_set) = self.descriptor_set {
245 vkFreeDescriptorSets(inner.device, inner.descriptor_pool, 1, &descriptor_set);
246 }
247
248 Ok(())
249 })
250 }
251 }
252}