1use super::*;
4use crate::*; use std::ffi::CString;
6use std::fs;
7use std::path::Path;
8use std::ptr;
9
10pub struct Shader {
12 context: ComputeContext,
13 module: VkShaderModule,
14}
15
16unsafe impl Send for Shader {}
18unsafe impl Sync for Shader {}
19
20pub struct Pipeline {
22 pub(super) context: ComputeContext,
23 pub(super) pipeline: VkPipeline,
24 pub(super) layout: VkPipelineLayout,
25 pub(super) descriptor_set_layout: VkDescriptorSetLayout,
26}
27
28unsafe impl Send for Pipeline {}
30unsafe impl Sync for Pipeline {}
31
32#[derive(Debug, Clone)]
34pub struct BufferBinding {
35 pub binding: u32,
36 pub descriptor_type: VkDescriptorType,
37}
38
39impl Default for BufferBinding {
40 fn default() -> Self {
41 Self {
42 binding: 0,
43 descriptor_type: VkDescriptorType::StorageBuffer,
44 }
45 }
46}
47
48pub struct PipelineConfig {
50 pub entry_point: String,
52 pub local_size: (u32, u32, u32),
54 pub bindings: Vec<BufferBinding>,
56 pub push_constant_size: u32,
58}
59
60impl Default for PipelineConfig {
61 fn default() -> Self {
62 Self {
63 entry_point: "main".to_string(),
64 local_size: (64, 1, 1),
65 bindings: Vec::new(),
66 push_constant_size: 0,
67 }
68 }
69}
70
71impl ComputeContext {
72 pub fn load_shader<P: AsRef<Path>>(&self, path: P) -> Result<Shader> {
74 let spv_data = fs::read(path)
75 .map_err(|e| KronosError::ShaderCompilationFailed(format!("Failed to read shader file: {}", e)))?;
76
77 self.create_shader_from_spirv(&spv_data)
78 }
79
80 pub fn create_shader_from_spirv(&self, spirv: &[u8]) -> Result<Shader> {
82 if spirv.len() % 4 != 0 {
83 return Err(KronosError::ShaderCompilationFailed(
84 "SPIR-V data must be 4-byte aligned".into()
85 ));
86 }
87
88 unsafe {
89 self.with_inner(|inner| {
90 let create_info = VkShaderModuleCreateInfo {
91 sType: VkStructureType::ShaderModuleCreateInfo,
92 pNext: ptr::null(),
93 flags: 0,
94 codeSize: spirv.len(),
95 pCode: spirv.as_ptr() as *const u32,
96 };
97
98 let mut module = VkShaderModule::NULL;
99 let result = vkCreateShaderModule(inner.device, &create_info, ptr::null(), &mut module);
100
101 if result != VkResult::Success {
102 return Err(KronosError::ShaderCompilationFailed(
103 format!("vkCreateShaderModule failed: {:?}", result)
104 ));
105 }
106
107 Ok(Shader {
108 context: self.clone(),
109 module,
110 })
111 })
112 }
113 }
114
115 pub fn create_pipeline(&self, shader: &Shader) -> Result<Pipeline> {
117 self.create_pipeline_with_config(shader, PipelineConfig::default())
118 }
119
120 pub fn create_pipeline_with_config(&self, shader: &Shader, config: PipelineConfig) -> Result<Pipeline> {
122 if config.push_constant_size > 128 {
123 return Err(KronosError::ShaderCompilationFailed(
124 format!("Push constant size {} exceeds maximum 128 bytes", config.push_constant_size)
125 ));
126 }
127
128 unsafe {
129 self.with_inner(|inner| {
130 let bindings: Vec<VkDescriptorSetLayoutBinding> = config.bindings.iter().map(|b| {
132 VkDescriptorSetLayoutBinding {
133 binding: b.binding,
134 descriptorType: b.descriptor_type,
135 descriptorCount: 1,
136 stageFlags: VkShaderStageFlags::COMPUTE,
137 pImmutableSamplers: ptr::null(),
138 }
139 }).collect();
140
141 let layout_info = VkDescriptorSetLayoutCreateInfo {
142 sType: VkStructureType::DescriptorSetLayoutCreateInfo,
143 pNext: ptr::null(),
144 flags: 0,
145 bindingCount: bindings.len() as u32,
146 pBindings: if bindings.is_empty() { ptr::null() } else { bindings.as_ptr() },
147 };
148
149 let mut descriptor_set_layout = VkDescriptorSetLayout::NULL;
150 let result = vkCreateDescriptorSetLayout(inner.device, &layout_info, ptr::null(), &mut descriptor_set_layout);
151
152 if result != VkResult::Success {
153 return Err(KronosError::from(result));
154 }
155
156 let push_constant_range = if config.push_constant_size > 0 {
158 Some(VkPushConstantRange {
159 stageFlags: VkShaderStageFlags::COMPUTE,
160 offset: 0,
161 size: config.push_constant_size,
162 })
163 } else {
164 None
165 };
166
167 let pipeline_layout_info = VkPipelineLayoutCreateInfo {
168 sType: VkStructureType::PipelineLayoutCreateInfo,
169 pNext: ptr::null(),
170 flags: 0,
171 setLayoutCount: 1,
172 pSetLayouts: &descriptor_set_layout,
173 pushConstantRangeCount: if push_constant_range.is_some() { 1 } else { 0 },
174 pPushConstantRanges: push_constant_range.as_ref().map_or(ptr::null(), |r| r as *const _),
175 };
176
177 let mut pipeline_layout = VkPipelineLayout::NULL;
178 let result = vkCreatePipelineLayout(inner.device, &pipeline_layout_info, ptr::null(), &mut pipeline_layout);
179
180 if result != VkResult::Success {
181 vkDestroyDescriptorSetLayout(inner.device, descriptor_set_layout, ptr::null());
182 return Err(KronosError::from(result));
183 }
184
185 let entry_point = CString::new(config.entry_point.clone())
187 .map_err(|_| KronosError::ShaderCompilationFailed("Invalid entry point name".into()))?;
188
189 let stage_info = VkPipelineShaderStageCreateInfo {
190 sType: VkStructureType::PipelineShaderStageCreateInfo,
191 pNext: ptr::null(),
192 flags: VkPipelineShaderStageCreateFlags::empty(),
193 stage: VkShaderStageFlagBits::Compute,
194 module: shader.module,
195 pName: entry_point.as_ptr(),
196 pSpecializationInfo: ptr::null(),
197 };
198
199 let pipeline_info = VkComputePipelineCreateInfo {
200 sType: VkStructureType::ComputePipelineCreateInfo,
201 pNext: ptr::null(),
202 flags: VkPipelineCreateFlags::empty(),
203 stage: stage_info,
204 layout: pipeline_layout,
205 basePipelineHandle: VkPipeline::NULL,
206 basePipelineIndex: -1,
207 };
208
209 let mut pipeline = VkPipeline::NULL;
210 let result = vkCreateComputePipelines(
211 inner.device,
212 VkPipelineCache::NULL,
213 1,
214 &pipeline_info,
215 ptr::null(),
216 &mut pipeline,
217 );
218
219 if result != VkResult::Success {
220 vkDestroyPipelineLayout(inner.device, pipeline_layout, ptr::null());
221 vkDestroyDescriptorSetLayout(inner.device, descriptor_set_layout, ptr::null());
222 return Err(KronosError::from(result));
223 }
224
225 Ok(Pipeline {
226 context: self.clone(),
227 pipeline,
228 layout: pipeline_layout,
229 descriptor_set_layout,
230 })
231 })
232 }
233 }
234}
235
236impl Pipeline {
237 pub fn raw(&self) -> VkPipeline {
239 self.pipeline
240 }
241
242 pub fn layout(&self) -> VkPipelineLayout {
244 self.layout
245 }
246
247 pub fn descriptor_set_layout(&self) -> VkDescriptorSetLayout {
249 self.descriptor_set_layout
250 }
251}
252
253impl Drop for Shader {
254 fn drop(&mut self) {
255 unsafe {
256 self.context.with_inner(|inner| {
257 vkDestroyShaderModule(inner.device, self.module, ptr::null());
258 });
259 }
260 }
261}
262
263impl Drop for Pipeline {
264 fn drop(&mut self) {
265 unsafe {
266 self.context.with_inner(|inner| {
267 vkDestroyPipeline(inner.device, self.pipeline, ptr::null());
268 vkDestroyPipelineLayout(inner.device, self.layout, ptr::null());
269 vkDestroyDescriptorSetLayout(inner.device, self.descriptor_set_layout, ptr::null());
270 });
271 }
272 }
273}