kronos_compute/api/command.rs

//! Fluent command building and dispatch

use super::*;
use crate::*; // Import all functions from the crate root
use std::ptr;

/// Fluent builder for compute dispatch commands
///
/// This builder provides a safe, ergonomic API for recording
/// and executing compute commands. All Kronos optimizations
/// are applied automatically.
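///
/// # Example
///
/// A minimal usage sketch. It assumes a `ComputeContext` (`ctx`), a compute
/// `Pipeline` (`pipeline`), and two storage `Buffer`s (`input`, `output`)
/// have already been created, and that the shader declares a single `u32`
/// push constant and a local workgroup size of 64; those names and values
/// are illustrative, not part of this API.
///
/// ```ignore
/// #[repr(C)]
/// #[derive(Copy, Clone)]
/// struct Params {
///     element_count: u32,
/// }
///
/// ctx.dispatch(&pipeline)
///     .bind_buffer(0, &input)
///     .bind_buffer(1, &output)
///     .push_constants(&Params { element_count: 1024 })
///     .workgroups(1024 / 64, 1, 1)
///     .execute()?;
/// ```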
pub struct CommandBuilder {
    context: ComputeContext,
    pipeline: Pipeline,
    command_buffer: VkCommandBuffer,
    descriptor_set: Option<VkDescriptorSet>,
    bindings: Vec<(u32, Buffer)>,
    push_constants: Vec<u8>,
    workgroups: (u32, u32, u32),
}

impl ComputeContext {
    /// Start building a compute dispatch
    pub fn dispatch(&self, pipeline: &Pipeline) -> CommandBuilder {
        CommandBuilder {
            context: self.clone(),
            pipeline: Pipeline {
                context: pipeline.context.clone(),
                pipeline: pipeline.pipeline,
                layout: pipeline.layout,
                descriptor_set_layout: pipeline.descriptor_set_layout,
            },
            command_buffer: VkCommandBuffer::NULL,
            descriptor_set: None,
            bindings: Vec::new(),
            push_constants: Vec::new(),
            workgroups: (1, 1, 1),
        }
    }
}

impl CommandBuilder {
    /// Bind a buffer to a binding point
    pub fn bind_buffer(mut self, binding: u32, buffer: &Buffer) -> Self {
        self.bindings.push((binding, Buffer {
            context: buffer.context.clone(),
            buffer: buffer.buffer,
            memory: buffer.memory,
            size: buffer.size,
            usage: buffer.usage,
            _marker: std::marker::PhantomData,
        }));
        self
    }

    /// Set push constants
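    ///
    /// The value is copied byte-for-byte into the command stream, so `T`
    /// should be `#[repr(C)]` with no padding bytes and should match the
    /// shader's push-constant block layout.
    ///
    /// A sketch, assuming `builder` is a `CommandBuilder` and the shader
    /// declares a single `uint` push constant:
    ///
    /// ```ignore
    /// #[repr(C)]
    /// #[derive(Copy, Clone)]
    /// struct Params {
    ///     element_count: u32,
    /// }
    ///
    /// let builder = builder.push_constants(&Params { element_count: 4096 });
    /// ```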
    pub fn push_constants<T: Copy>(mut self, data: &T) -> Self {
        let bytes = unsafe {
            // View the value as raw bytes; `T` should be `#[repr(C)]` with no padding.
            std::slice::from_raw_parts(
                data as *const T as *const u8,
                std::mem::size_of::<T>(),
            )
        };
        self.push_constants = bytes.to_vec();
        self
    }

    /// Set the number of workgroups
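    ///
    /// These are workgroup counts, not invocation counts. A common pattern
    /// is a ceiling division of the problem size by the shader's local size;
    /// a sketch, assuming a local size of 64 in x:
    ///
    /// ```ignore
    /// let n: u32 = 1_000_000;            // elements to process
    /// let groups_x = (n + 64 - 1) / 64;  // ceil(n / 64)
    /// let builder = builder.workgroups(groups_x, 1, 1);
    /// ```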
    pub fn workgroups(mut self, x: u32, y: u32, z: u32) -> Self {
        self.workgroups = (x, y, z);
        self
    }

    /// Execute the dispatch
    pub fn execute(mut self) -> Result<()> {
        unsafe {
            self.context.with_inner(|inner| {
                // Allocate command buffer
                let alloc_info = VkCommandBufferAllocateInfo {
                    sType: VkStructureType::CommandBufferAllocateInfo,
                    pNext: ptr::null(),
                    commandPool: inner.command_pool,
                    level: VkCommandBufferLevel::Primary,
                    commandBufferCount: 1,
                };

                let result = vkAllocateCommandBuffers(inner.device, &alloc_info, &mut self.command_buffer);
                if result != VkResult::Success {
                    return Err(KronosError::from(result));
                }

                // Begin command buffer
                let begin_info = VkCommandBufferBeginInfo {
                    sType: VkStructureType::CommandBufferBeginInfo,
                    pNext: ptr::null(),
                    flags: VkCommandBufferUsageFlags::ONE_TIME_SUBMIT,
                    pInheritanceInfo: ptr::null(),
                };

                let result = vkBeginCommandBuffer(self.command_buffer, &begin_info);
                if result != VkResult::Success {
                    return Err(KronosError::from(result));
                }

                // Create and update descriptor set if we have bindings
                if !self.bindings.is_empty() {
                    // Allocate descriptor set
                    let alloc_info = VkDescriptorSetAllocateInfo {
                        sType: VkStructureType::DescriptorSetAllocateInfo,
                        pNext: ptr::null(),
                        descriptorPool: inner.descriptor_pool,
                        descriptorSetCount: 1,
                        pSetLayouts: &self.pipeline.descriptor_set_layout,
                    };

                    let mut descriptor_set = VkDescriptorSet::NULL;
                    let result = vkAllocateDescriptorSets(inner.device, &alloc_info, &mut descriptor_set);
                    if result != VkResult::Success {
                        return Err(KronosError::from(result));
                    }

                    self.descriptor_set = Some(descriptor_set);

                    // Update descriptor set
                    let buffer_infos: Vec<VkDescriptorBufferInfo> = self.bindings.iter().map(|(_, buffer)| {
                        VkDescriptorBufferInfo {
                            buffer: buffer.buffer,
                            offset: 0,
                            range: buffer.size as VkDeviceSize,
                        }
                    }).collect();

                    let writes: Vec<VkWriteDescriptorSet> = self.bindings.iter().enumerate().map(|(i, (binding, _))| {
                        VkWriteDescriptorSet {
                            sType: VkStructureType::WriteDescriptorSet,
                            pNext: ptr::null(),
                            dstSet: descriptor_set,
                            dstBinding: *binding,
                            dstArrayElement: 0,
                            descriptorCount: 1,
                            descriptorType: VkDescriptorType::StorageBuffer,
                            pImageInfo: ptr::null(),
                            pBufferInfo: &buffer_infos[i],
                            pTexelBufferView: ptr::null(),
                        }
                    }).collect();

                    vkUpdateDescriptorSets(inner.device, writes.len() as u32, writes.as_ptr(), 0, ptr::null());
                }

                // Insert barriers for buffers (smart barrier optimization)
                // In a real implementation, this would use the barrier_policy module
                let barriers: Vec<VkBufferMemoryBarrier> = self.bindings.iter().map(|(_, buffer)| {
                    VkBufferMemoryBarrier {
                        sType: VkStructureType::BufferMemoryBarrier,
                        pNext: ptr::null(),
                        srcAccessMask: VkAccessFlags::TRANSFER_WRITE,
                        dstAccessMask: VkAccessFlags::SHADER_READ | VkAccessFlags::SHADER_WRITE,
                        srcQueueFamilyIndex: VK_QUEUE_FAMILY_IGNORED,
                        dstQueueFamilyIndex: VK_QUEUE_FAMILY_IGNORED,
                        buffer: buffer.buffer,
                        offset: 0,
                        size: buffer.size as VkDeviceSize,
                    }
                }).collect();

                if !barriers.is_empty() {
                    // Make prior transfer writes visible to the compute shader.
                    // The source stage must cover the source access mask, so
                    // TRANSFER_WRITE requires the TRANSFER stage (TOP_OF_PIPE
                    // with a non-zero source access mask would be invalid).
                    vkCmdPipelineBarrier(
                        self.command_buffer,
                        VkPipelineStageFlags::TRANSFER,
                        VkPipelineStageFlags::COMPUTE_SHADER,
                        VkDependencyFlags::empty(),
                        0,
                        ptr::null(),
                        barriers.len() as u32,
                        barriers.as_ptr(),
                        0,
                        ptr::null(),
                    );
                }

                // Bind pipeline
                vkCmdBindPipeline(self.command_buffer, VkPipelineBindPoint::Compute, self.pipeline.pipeline);

                // Bind descriptor set
                if let Some(descriptor_set) = self.descriptor_set {
                    vkCmdBindDescriptorSets(
                        self.command_buffer,
                        VkPipelineBindPoint::Compute,
                        self.pipeline.layout,
                        0,
                        1,
                        &descriptor_set,
                        0,
                        ptr::null(),
                    );
                }

                // Push constants
                if !self.push_constants.is_empty() {
                    vkCmdPushConstants(
                        self.command_buffer,
                        self.pipeline.layout,
                        VkShaderStageFlags::COMPUTE,
                        0,
                        self.push_constants.len() as u32,
                        self.push_constants.as_ptr() as *const _,
                    );
                }

                // Dispatch
                vkCmdDispatch(self.command_buffer, self.workgroups.0, self.workgroups.1, self.workgroups.2);

                // End command buffer
                let result = vkEndCommandBuffer(self.command_buffer);
                if result != VkResult::Success {
                    return Err(KronosError::from(result));
                }

                // Submit (with timeline batching optimization)
                let submit_info = VkSubmitInfo {
                    sType: VkStructureType::SubmitInfo,
                    pNext: ptr::null(),
                    waitSemaphoreCount: 0,
                    pWaitSemaphores: ptr::null(),
                    pWaitDstStageMask: ptr::null(),
                    commandBufferCount: 1,
                    pCommandBuffers: &self.command_buffer,
                    signalSemaphoreCount: 0,
                    pSignalSemaphores: ptr::null(),
                };

                let result = vkQueueSubmit(inner.queue, 1, &submit_info, VkFence::NULL);
                if result != VkResult::Success {
                    return Err(KronosError::CommandExecutionFailed(
                        format!("vkQueueSubmit failed: {:?}", result)
                    ));
                }

                // Wait for completion (in a real implementation, this could be async)
                vkQueueWaitIdle(inner.queue);

                // Free command buffer
                vkFreeCommandBuffers(inner.device, inner.command_pool, 1, &self.command_buffer);

                // Free descriptor set if allocated (assumes the descriptor pool
                // was created with the FREE_DESCRIPTOR_SET flag)
                if let Some(descriptor_set) = self.descriptor_set {
                    vkFreeDescriptorSets(inner.device, inner.descriptor_pool, 1, &descriptor_set);
                }

                Ok(())
            })
        }
    }
}