1use super::*;
4use crate::*; use crate::implementation::initialize_kronos;
6use std::ffi::CString;
7use std::ptr;
8use std::sync::{Arc, Mutex};
9
10pub(super) struct ContextInner {
12 pub(super) instance: VkInstance,
13 pub(super) physical_device: VkPhysicalDevice,
14 pub(super) device: VkDevice,
15 pub(super) queue: VkQueue,
16 pub(super) queue_family_index: u32,
17
18 pub(super) descriptor_pool: VkDescriptorPool,
20 pub(super) command_pool: VkCommandPool,
21
22 pub(super) device_properties: VkPhysicalDeviceProperties,
24 pub(super) memory_properties: VkPhysicalDeviceMemoryProperties,
25}
26
/// Shareable handle to a Vulkan compute context.
///
/// Cloning is cheap: every clone points at the same [`ContextInner`] through an
/// `Arc<Mutex<..>>`, so all clones use the same instance, device and queue.
#[derive(Clone)]
pub struct ComputeContext {
    /// Shared Vulkan state; locked by the accessor methods of this type.
    pub(super) inner: Arc<Mutex<ContextInner>>,
}

// SAFETY: NOTE(review) — these manual impls are presumably needed because the
// Vk* handle types stored in `ContextInner` are not `Send`/`Sync` on their own
// (e.g. raw-pointer handles). The handles are stored behind a `Mutex`, but
// accessors such as `device()` and `queue()` copy handles out of the lock, so
// soundness relies on (a) the kronos implementation allowing these handles to
// be used from any thread and (b) callers honouring Vulkan's
// external-synchronization rules (notably for `VkQueue` submission and
// `VkCommandPool` use). Confirm both before relying on cross-thread use.
unsafe impl Send for ComputeContext {}
// SAFETY: same reasoning as the `Send` impl directly above.
unsafe impl Sync for ComputeContext {}
40
41impl ComputeContext {
42 pub(super) fn new_with_config(config: ContextConfig) -> Result<Self> {
43 unsafe {
44 initialize_kronos()
46 .map_err(|e| KronosError::InitializationFailed(e.to_string()))?;
47
48 let instance = Self::create_instance(&config)?;
50
51 let (physical_device, queue_family_index) = Self::find_compute_device(instance)?;
53
54 let mut device_properties = VkPhysicalDeviceProperties::default();
56 vkGetPhysicalDeviceProperties(physical_device, &mut device_properties);
57
58 let mut memory_properties = VkPhysicalDeviceMemoryProperties::default();
59 vkGetPhysicalDeviceMemoryProperties(physical_device, &mut memory_properties);
60
61 let (device, queue) = Self::create_device(physical_device, queue_family_index)?;
63
64 let descriptor_pool = Self::create_descriptor_pool(device)?;
66
67 let command_pool = Self::create_command_pool(device, queue_family_index)?;
69
70 let inner = ContextInner {
71 instance,
72 physical_device,
73 device,
74 queue,
75 queue_family_index,
76 descriptor_pool,
77 command_pool,
78 device_properties,
79 memory_properties,
80 };
81
82 Ok(Self {
83 inner: Arc::new(Mutex::new(inner)),
84 })
85 }
86 }
87
88 unsafe fn create_instance(config: &ContextConfig) -> Result<VkInstance> {
98 let app_name = CString::new(config.app_name.clone())
99 .unwrap_or_else(|_| CString::new("Kronos App").unwrap());
100 let engine_name = CString::new("Kronos Compute").unwrap();
101
102 let app_info = VkApplicationInfo {
103 sType: VkStructureType::ApplicationInfo,
104 pNext: ptr::null(),
105 pApplicationName: app_name.as_ptr(),
106 applicationVersion: VK_MAKE_VERSION(1, 0, 0),
107 pEngineName: engine_name.as_ptr(),
108 engineVersion: VK_MAKE_VERSION(1, 0, 0),
109 apiVersion: VK_API_VERSION_1_0,
110 };
111
112 let create_info = VkInstanceCreateInfo {
113 sType: VkStructureType::InstanceCreateInfo,
114 pNext: ptr::null(),
115 flags: 0,
116 pApplicationInfo: &app_info,
117 enabledLayerCount: 0,
118 ppEnabledLayerNames: ptr::null(),
119 enabledExtensionCount: 0,
120 ppEnabledExtensionNames: ptr::null(),
121 };
122
123 let mut instance = VkInstance::NULL;
124 let result = vkCreateInstance(&create_info, ptr::null(), &mut instance);
125
126 if result != VkResult::Success {
127 return Err(KronosError::from(result));
128 }
129
130 Ok(instance)
131 }
132
133 unsafe fn find_compute_device(instance: VkInstance) -> Result<(VkPhysicalDevice, u32)> {
143 let mut device_count = 0;
144 vkEnumeratePhysicalDevices(instance, &mut device_count, ptr::null_mut());
145
146 if device_count == 0 {
147 return Err(KronosError::DeviceNotFound);
148 }
149
150 let mut devices = vec![VkPhysicalDevice::NULL; device_count as usize];
151 vkEnumeratePhysicalDevices(instance, &mut device_count, devices.as_mut_ptr());
152
153 for device in devices {
155 let queue_family = Self::find_compute_queue_family(device)?;
156 if let Some(index) = queue_family {
157 return Ok((device, index));
158 }
159 }
160
161 Err(KronosError::DeviceNotFound)
162 }
163
164 unsafe fn find_compute_queue_family(device: VkPhysicalDevice) -> Result<Option<u32>> {
174 let mut queue_family_count = 0;
175 vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, ptr::null_mut());
176
177 let mut queue_families = vec![VkQueueFamilyProperties {
178 queueFlags: VkQueueFlags::empty(),
179 queueCount: 0,
180 timestampValidBits: 0,
181 minImageTransferGranularity: VkExtent3D { width: 0, height: 0, depth: 0 },
182 }; queue_family_count as usize];
183 vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, queue_families.as_mut_ptr());
184
185 for (index, family) in queue_families.iter().enumerate() {
186 if family.queueFlags.contains(VkQueueFlags::COMPUTE) {
187 return Ok(Some(index as u32));
188 }
189 }
190
191 Ok(None)
192 }
193
194 unsafe fn create_device(physical_device: VkPhysicalDevice, queue_family_index: u32) -> Result<(VkDevice, VkQueue)> {
205 let queue_priority = 1.0f32;
206
207 let queue_create_info = VkDeviceQueueCreateInfo {
208 sType: VkStructureType::DeviceQueueCreateInfo,
209 pNext: ptr::null(),
210 flags: 0,
211 queueFamilyIndex: queue_family_index,
212 queueCount: 1,
213 pQueuePriorities: &queue_priority,
214 };
215
216 let features = VkPhysicalDeviceFeatures::default();
217
218 let device_create_info = VkDeviceCreateInfo {
219 sType: VkStructureType::DeviceCreateInfo,
220 pNext: ptr::null(),
221 flags: 0,
222 queueCreateInfoCount: 1,
223 pQueueCreateInfos: &queue_create_info,
224 enabledLayerCount: 0,
225 ppEnabledLayerNames: ptr::null(),
226 enabledExtensionCount: 0,
227 ppEnabledExtensionNames: ptr::null(),
228 pEnabledFeatures: &features,
229 };
230
231 let mut device = VkDevice::NULL;
232 let result = vkCreateDevice(physical_device, &device_create_info, ptr::null(), &mut device);
233
234 if result != VkResult::Success {
235 return Err(KronosError::from(result));
236 }
237
238 let mut queue = VkQueue::NULL;
239 vkGetDeviceQueue(device, queue_family_index, 0, &mut queue);
240
241 Ok((device, queue))
242 }
243
244 unsafe fn create_descriptor_pool(device: VkDevice) -> Result<VkDescriptorPool> {
255 let pool_size = VkDescriptorPoolSize {
257 type_: VkDescriptorType::StorageBuffer,
258 descriptorCount: 10000, };
260
261 let pool_info = VkDescriptorPoolCreateInfo {
262 sType: VkStructureType::DescriptorPoolCreateInfo,
263 pNext: ptr::null(),
264 flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
265 maxSets: 1000,
266 poolSizeCount: 1,
267 pPoolSizes: &pool_size,
268 };
269
270 let mut pool = VkDescriptorPool::NULL;
271 let result = vkCreateDescriptorPool(device, &pool_info, ptr::null(), &mut pool);
272
273 if result != VkResult::Success {
274 return Err(KronosError::from(result));
275 }
276
277 Ok(pool)
278 }
279
280 unsafe fn create_command_pool(device: VkDevice, queue_family_index: u32) -> Result<VkCommandPool> {
291 let pool_info = VkCommandPoolCreateInfo {
292 sType: VkStructureType::CommandPoolCreateInfo,
293 pNext: ptr::null(),
294 flags: VkCommandPoolCreateFlags::RESET_COMMAND_BUFFER,
295 queueFamilyIndex: queue_family_index,
296 };
297
298 let mut pool = VkCommandPool::NULL;
299 let result = vkCreateCommandPool(device, &pool_info, ptr::null(), &mut pool);
300
301 if result != VkResult::Success {
302 return Err(KronosError::from(result));
303 }
304
305 Ok(pool)
306 }
307
308 pub fn device(&self) -> VkDevice {
310 self.inner.lock().unwrap().device
311 }
312
313 pub fn queue(&self) -> VkQueue {
315 self.inner.lock().unwrap().queue
316 }
317
318 pub fn device_properties(&self) -> VkPhysicalDeviceProperties {
320 self.inner.lock().unwrap().device_properties
321 }
322
323 pub(super) fn with_inner<F, R>(&self, f: F) -> R
325 where
326 F: FnOnce(&ContextInner) -> R,
327 {
328 let inner = self.inner.lock().unwrap();
329 f(&*inner)
330 }
331}
332
333impl Drop for ComputeContext {
334 fn drop(&mut self) {
335 let inner = self.inner.lock().unwrap();
336 unsafe {
337 if inner.command_pool != VkCommandPool::NULL {
338 vkDestroyCommandPool(inner.device, inner.command_pool, ptr::null());
339 }
340 if inner.descriptor_pool != VkDescriptorPool::NULL {
341 vkDestroyDescriptorPool(inner.device, inner.descriptor_pool, ptr::null());
342 }
343 if inner.device != VkDevice::NULL {
344 vkDestroyDevice(inner.device, ptr::null());
345 }
346 if inner.instance != VkInstance::NULL {
347 vkDestroyInstance(inner.instance, ptr::null());
348 }
349 }
350 }
351}