kronos_compute/api/
pipeline.rs

1//! Pipeline and shader management
2
3use super::*;
4use crate::*; // Import all functions from the crate root
5use std::ffi::CString;
6use std::fs;
7use std::path::Path;
8use std::ptr;
9
/// Compiled shader module
///
/// Owns a `VkShaderModule` created by `ComputeContext::create_shader_from_spirv`.
/// The module is destroyed by this type's `Drop` impl, which reaches the device
/// through the stored `ComputeContext`.
pub struct Shader {
    // Context the module was created on; used again when the module is destroyed.
    context: ComputeContext,
    // Raw Vulkan shader-module handle owned by this value.
    module: VkShaderModule,
}
15
// SAFETY: NOTE(review): `Shader` holds only a cloned `ComputeContext` and a
// `VkShaderModule` handle, and the module is only destroyed from `Drop`, which has
// exclusive access (`&mut self`). These impls assume `ComputeContext` itself is safe
// to send and share across threads — TODO confirm against its definition.
unsafe impl Send for Shader {}
unsafe impl Sync for Shader {}
19
/// Compute pipeline with shader and layout
///
/// Owns a `VkPipeline`, its `VkPipelineLayout`, and the descriptor-set layout for
/// set 0 that the pipeline layout was created with. All three handles are destroyed
/// by this type's `Drop` impl.
pub struct Pipeline {
    // Context used to reach the device again when destroying the handles.
    pub(super) context: ComputeContext,
    // The compute pipeline handle.
    pub(super) pipeline: VkPipeline,
    // Pipeline layout: one descriptor-set layout plus an optional push-constant range.
    pub(super) layout: VkPipelineLayout,
    // Layout of descriptor set 0, built from `PipelineConfig::bindings`.
    pub(super) descriptor_set_layout: VkDescriptorSetLayout,
}
27
// SAFETY: NOTE(review): `Pipeline` holds a cloned `ComputeContext` plus raw Vulkan
// handles that are only destroyed from `Drop` (exclusive `&mut self` access); the
// accessors merely copy the handles out. These impls assume `ComputeContext` is
// itself safe to send and share across threads — TODO confirm against its definition.
unsafe impl Send for Pipeline {}
unsafe impl Sync for Pipeline {}
31
32/// Information about buffer bindings for a pipeline
33#[derive(Debug, Clone)]
34pub struct BufferBinding {
35    pub binding: u32,
36    pub descriptor_type: VkDescriptorType,
37}
38
39impl Default for BufferBinding {
40    fn default() -> Self {
41        Self {
42            binding: 0,
43            descriptor_type: VkDescriptorType::StorageBuffer,
44        }
45    }
46}
47
48/// Pipeline configuration
49pub struct PipelineConfig {
50    /// Entry point name (default: "main")
51    pub entry_point: String,
52    /// Local workgroup size (x, y, z)
53    pub local_size: (u32, u32, u32),
54    /// Buffer bindings
55    pub bindings: Vec<BufferBinding>,
56    /// Push constant size in bytes (max 128)
57    pub push_constant_size: u32,
58}
59
60impl Default for PipelineConfig {
61    fn default() -> Self {
62        Self {
63            entry_point: "main".to_string(),
64            local_size: (64, 1, 1),
65            bindings: Vec::new(),
66            push_constant_size: 0,
67        }
68    }
69}
70
71impl ComputeContext {
72    /// Load a shader from SPIR-V file
73    pub fn load_shader<P: AsRef<Path>>(&self, path: P) -> Result<Shader> {
74        let spv_data = fs::read(path)
75            .map_err(|e| KronosError::ShaderCompilationFailed(format!("Failed to read shader file: {}", e)))?;
76        
77        self.create_shader_from_spirv(&spv_data)
78    }
79    
80    /// Create a shader from SPIR-V bytes
81    pub fn create_shader_from_spirv(&self, spirv: &[u8]) -> Result<Shader> {
82        if spirv.len() % 4 != 0 {
83            return Err(KronosError::ShaderCompilationFailed(
84                "SPIR-V data must be 4-byte aligned".into()
85            ));
86        }
87        
88        unsafe {
89            self.with_inner(|inner| {
90                let create_info = VkShaderModuleCreateInfo {
91                    sType: VkStructureType::ShaderModuleCreateInfo,
92                    pNext: ptr::null(),
93                    flags: 0,
94                    codeSize: spirv.len(),
95                    pCode: spirv.as_ptr() as *const u32,
96                };
97                
98                let mut module = VkShaderModule::NULL;
99                let result = vkCreateShaderModule(inner.device, &create_info, ptr::null(), &mut module);
100                
101                if result != VkResult::Success {
102                    return Err(KronosError::ShaderCompilationFailed(
103                        format!("vkCreateShaderModule failed: {:?}", result)
104                    ));
105                }
106                
107                Ok(Shader {
108                    context: self.clone(),
109                    module,
110                })
111            })
112        }
113    }
114    
115    /// Create a compute pipeline with default configuration
116    pub fn create_pipeline(&self, shader: &Shader) -> Result<Pipeline> {
117        self.create_pipeline_with_config(shader, PipelineConfig::default())
118    }
119    
120    /// Create a compute pipeline with custom configuration
121    pub fn create_pipeline_with_config(&self, shader: &Shader, config: PipelineConfig) -> Result<Pipeline> {
122        if config.push_constant_size > 128 {
123            return Err(KronosError::ShaderCompilationFailed(
124                format!("Push constant size {} exceeds maximum 128 bytes", config.push_constant_size)
125            ));
126        }
127        
128        unsafe {
129            self.with_inner(|inner| {
130                // Create descriptor set layout for Set0 (persistent descriptors)
131                let bindings: Vec<VkDescriptorSetLayoutBinding> = config.bindings.iter().map(|b| {
132                    VkDescriptorSetLayoutBinding {
133                        binding: b.binding,
134                        descriptorType: b.descriptor_type,
135                        descriptorCount: 1,
136                        stageFlags: VkShaderStageFlags::COMPUTE,
137                        pImmutableSamplers: ptr::null(),
138                    }
139                }).collect();
140                
141                let layout_info = VkDescriptorSetLayoutCreateInfo {
142                    sType: VkStructureType::DescriptorSetLayoutCreateInfo,
143                    pNext: ptr::null(),
144                    flags: 0,
145                    bindingCount: bindings.len() as u32,
146                    pBindings: if bindings.is_empty() { ptr::null() } else { bindings.as_ptr() },
147                };
148                
149                let mut descriptor_set_layout = VkDescriptorSetLayout::NULL;
150                let result = vkCreateDescriptorSetLayout(inner.device, &layout_info, ptr::null(), &mut descriptor_set_layout);
151                
152                if result != VkResult::Success {
153                    return Err(KronosError::from(result));
154                }
155                
156                // Create pipeline layout
157                let push_constant_range = if config.push_constant_size > 0 {
158                    Some(VkPushConstantRange {
159                        stageFlags: VkShaderStageFlags::COMPUTE,
160                        offset: 0,
161                        size: config.push_constant_size,
162                    })
163                } else {
164                    None
165                };
166                
167                let pipeline_layout_info = VkPipelineLayoutCreateInfo {
168                    sType: VkStructureType::PipelineLayoutCreateInfo,
169                    pNext: ptr::null(),
170                    flags: 0,
171                    setLayoutCount: 1,
172                    pSetLayouts: &descriptor_set_layout,
173                    pushConstantRangeCount: if push_constant_range.is_some() { 1 } else { 0 },
174                    pPushConstantRanges: push_constant_range.as_ref().map_or(ptr::null(), |r| r as *const _),
175                };
176                
177                let mut pipeline_layout = VkPipelineLayout::NULL;
178                let result = vkCreatePipelineLayout(inner.device, &pipeline_layout_info, ptr::null(), &mut pipeline_layout);
179                
180                if result != VkResult::Success {
181                    vkDestroyDescriptorSetLayout(inner.device, descriptor_set_layout, ptr::null());
182                    return Err(KronosError::from(result));
183                }
184                
185                // Create compute pipeline
186                let entry_point = CString::new(config.entry_point.clone())
187                    .map_err(|_| KronosError::ShaderCompilationFailed("Invalid entry point name".into()))?;
188                
189                let stage_info = VkPipelineShaderStageCreateInfo {
190                    sType: VkStructureType::PipelineShaderStageCreateInfo,
191                    pNext: ptr::null(),
192                    flags: VkPipelineShaderStageCreateFlags::empty(),
193                    stage: VkShaderStageFlagBits::Compute,
194                    module: shader.module,
195                    pName: entry_point.as_ptr(),
196                    pSpecializationInfo: ptr::null(),
197                };
198                
199                let pipeline_info = VkComputePipelineCreateInfo {
200                    sType: VkStructureType::ComputePipelineCreateInfo,
201                    pNext: ptr::null(),
202                    flags: VkPipelineCreateFlags::empty(),
203                    stage: stage_info,
204                    layout: pipeline_layout,
205                    basePipelineHandle: VkPipeline::NULL,
206                    basePipelineIndex: -1,
207                };
208                
209                let mut pipeline = VkPipeline::NULL;
210                let result = vkCreateComputePipelines(
211                    inner.device,
212                    VkPipelineCache::NULL,
213                    1,
214                    &pipeline_info,
215                    ptr::null(),
216                    &mut pipeline,
217                );
218                
219                if result != VkResult::Success {
220                    vkDestroyPipelineLayout(inner.device, pipeline_layout, ptr::null());
221                    vkDestroyDescriptorSetLayout(inner.device, descriptor_set_layout, ptr::null());
222                    return Err(KronosError::from(result));
223                }
224                
225                Ok(Pipeline {
226                    context: self.clone(),
227                    pipeline,
228                    layout: pipeline_layout,
229                    descriptor_set_layout,
230                })
231            })
232        }
233    }
234}
235
236impl Pipeline {
237    /// Get the raw Vulkan pipeline handle (for advanced usage)
238    pub fn raw(&self) -> VkPipeline {
239        self.pipeline
240    }
241    
242    /// Get the pipeline layout
243    pub fn layout(&self) -> VkPipelineLayout {
244        self.layout
245    }
246    
247    /// Get the descriptor set layout
248    pub fn descriptor_set_layout(&self) -> VkDescriptorSetLayout {
249        self.descriptor_set_layout
250    }
251}
252
253impl Drop for Shader {
254    fn drop(&mut self) {
255        unsafe {
256            self.context.with_inner(|inner| {
257                vkDestroyShaderModule(inner.device, self.module, ptr::null());
258            });
259        }
260    }
261}
262
263impl Drop for Pipeline {
264    fn drop(&mut self) {
265        unsafe {
266            self.context.with_inner(|inner| {
267                vkDestroyPipeline(inner.device, self.pipeline, ptr::null());
268                vkDestroyPipelineLayout(inner.device, self.layout, ptr::null());
269                vkDestroyDescriptorSetLayout(inner.device, self.descriptor_set_layout, ptr::null());
270            });
271        }
272    }
273}