1use super::*;
4use crate::*; use crate::implementation::initialize_kronos;
6use std::ffi::CString;
7use std::ptr;
8use std::sync::{Arc, Mutex};
9
10pub(super) struct ContextInner {
12 pub(super) instance: VkInstance,
13 pub(super) physical_device: VkPhysicalDevice,
14 pub(super) device: VkDevice,
15 pub(super) queue: VkQueue,
16 pub(super) queue_family_index: u32,
17
18 pub(super) descriptor_pool: VkDescriptorPool,
20 pub(super) command_pool: VkCommandPool,
21
22 pub(super) device_properties: VkPhysicalDeviceProperties,
24 pub(super) memory_properties: VkPhysicalDeviceMemoryProperties,
25}
26
27#[derive(Clone)]
33pub struct ComputeContext {
34 pub(super) inner: Arc<Mutex<ContextInner>>,
35}
36
37unsafe impl Send for ComputeContext {}
39unsafe impl Sync for ComputeContext {}
40
41impl ComputeContext {
42 pub(super) fn new_with_config(config: ContextConfig) -> Result<Self> {
43 unsafe {
44 initialize_kronos()
46 .map_err(|e| KronosError::InitializationFailed(e.to_string()))?;
47
48 let instance = Self::create_instance(&config)?;
50
51 let (physical_device, queue_family_index) = Self::find_compute_device(instance)?;
53
54 let mut device_properties = VkPhysicalDeviceProperties::default();
56 vkGetPhysicalDeviceProperties(physical_device, &mut device_properties);
57
58 let mut memory_properties = VkPhysicalDeviceMemoryProperties::default();
59 vkGetPhysicalDeviceMemoryProperties(physical_device, &mut memory_properties);
60
61 let device_name = std::ffi::CStr::from_ptr(device_properties.deviceName.as_ptr())
63 .to_string_lossy();
64 let device_type_str = match device_properties.deviceType {
65 VkPhysicalDeviceType::DiscreteGpu => "Discrete GPU",
66 VkPhysicalDeviceType::IntegratedGpu => "Integrated GPU",
67 VkPhysicalDeviceType::VirtualGpu => "Virtual GPU",
68 VkPhysicalDeviceType::Cpu => "CPU (Software Renderer)",
69 _ => "Unknown",
70 };
71 log::info!("Selected Vulkan device: {} ({})", device_name, device_type_str);
72
73 let (device, queue) = Self::create_device(physical_device, queue_family_index)?;
75
76 let descriptor_pool = Self::create_descriptor_pool(device)?;
78
79 let command_pool = Self::create_command_pool(device, queue_family_index)?;
81
82 let inner = ContextInner {
83 instance,
84 physical_device,
85 device,
86 queue,
87 queue_family_index,
88 descriptor_pool,
89 command_pool,
90 device_properties,
91 memory_properties,
92 };
93
94 Ok(Self {
95 inner: Arc::new(Mutex::new(inner)),
96 })
97 }
98 }
99
100 unsafe fn create_instance(config: &ContextConfig) -> Result<VkInstance> {
110 let app_name = CString::new(config.app_name.clone())
111 .unwrap_or_else(|_| CString::new("Kronos App").unwrap());
112 let engine_name = CString::new("Kronos Compute").unwrap();
113
114 let app_info = VkApplicationInfo {
115 sType: VkStructureType::ApplicationInfo,
116 pNext: ptr::null(),
117 pApplicationName: app_name.as_ptr(),
118 applicationVersion: VK_MAKE_VERSION(1, 0, 0),
119 pEngineName: engine_name.as_ptr(),
120 engineVersion: VK_MAKE_VERSION(1, 0, 0),
121 apiVersion: VK_API_VERSION_1_0,
122 };
123
124 let create_info = VkInstanceCreateInfo {
125 sType: VkStructureType::InstanceCreateInfo,
126 pNext: ptr::null(),
127 flags: 0,
128 pApplicationInfo: &app_info,
129 enabledLayerCount: 0,
130 ppEnabledLayerNames: ptr::null(),
131 enabledExtensionCount: 0,
132 ppEnabledExtensionNames: ptr::null(),
133 };
134
135 let mut instance = VkInstance::NULL;
136 let result = vkCreateInstance(&create_info, ptr::null(), &mut instance);
137
138 if result != VkResult::Success {
139 return Err(KronosError::from(result));
140 }
141
142 Ok(instance)
143 }
144
145 unsafe fn find_compute_device(instance: VkInstance) -> Result<(VkPhysicalDevice, u32)> {
155 let mut device_count = 0;
156 vkEnumeratePhysicalDevices(instance, &mut device_count, ptr::null_mut());
157
158 if device_count == 0 {
159 return Err(KronosError::DeviceNotFound);
160 }
161
162 let mut devices = vec![VkPhysicalDevice::NULL; device_count as usize];
163 vkEnumeratePhysicalDevices(instance, &mut device_count, devices.as_mut_ptr());
164
165 let mut candidates = Vec::new();
167
168 for device in devices {
169 let queue_family = Self::find_compute_queue_family(device)?;
170 if let Some(index) = queue_family {
171 let mut properties = VkPhysicalDeviceProperties::default();
173 vkGetPhysicalDeviceProperties(device, &mut properties);
174
175 candidates.push((device, index, properties.deviceType));
176 }
177 }
178
179 if candidates.is_empty() {
180 return Err(KronosError::DeviceNotFound);
181 }
182
183 candidates.sort_by_key(|(_, _, device_type)| {
185 match *device_type {
186 VkPhysicalDeviceType::DiscreteGpu => 0,
187 VkPhysicalDeviceType::IntegratedGpu => 1,
188 VkPhysicalDeviceType::VirtualGpu => 2,
189 VkPhysicalDeviceType::Cpu => 3,
190 VkPhysicalDeviceType::Other => 4,
191 _ => 5,
192 }
193 });
194
195 let (device, queue_index, _) = candidates[0];
197 Ok((device, queue_index))
198 }
199
200 unsafe fn find_compute_queue_family(device: VkPhysicalDevice) -> Result<Option<u32>> {
210 let mut queue_family_count = 0;
211 vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, ptr::null_mut());
212
213 let mut queue_families = vec![VkQueueFamilyProperties {
214 queueFlags: VkQueueFlags::empty(),
215 queueCount: 0,
216 timestampValidBits: 0,
217 minImageTransferGranularity: VkExtent3D { width: 0, height: 0, depth: 0 },
218 }; queue_family_count as usize];
219 vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, queue_families.as_mut_ptr());
220
221 for (index, family) in queue_families.iter().enumerate() {
222 if family.queueFlags.contains(VkQueueFlags::COMPUTE) {
223 return Ok(Some(index as u32));
224 }
225 }
226
227 Ok(None)
228 }
229
230 unsafe fn create_device(physical_device: VkPhysicalDevice, queue_family_index: u32) -> Result<(VkDevice, VkQueue)> {
241 let queue_priority = 1.0f32;
242
243 let queue_create_info = VkDeviceQueueCreateInfo {
244 sType: VkStructureType::DeviceQueueCreateInfo,
245 pNext: ptr::null(),
246 flags: 0,
247 queueFamilyIndex: queue_family_index,
248 queueCount: 1,
249 pQueuePriorities: &queue_priority,
250 };
251
252 let features = VkPhysicalDeviceFeatures::default();
253
254 let device_create_info = VkDeviceCreateInfo {
255 sType: VkStructureType::DeviceCreateInfo,
256 pNext: ptr::null(),
257 flags: 0,
258 queueCreateInfoCount: 1,
259 pQueueCreateInfos: &queue_create_info,
260 enabledLayerCount: 0,
261 ppEnabledLayerNames: ptr::null(),
262 enabledExtensionCount: 0,
263 ppEnabledExtensionNames: ptr::null(),
264 pEnabledFeatures: &features,
265 };
266
267 let mut device = VkDevice::NULL;
268 let result = vkCreateDevice(physical_device, &device_create_info, ptr::null(), &mut device);
269
270 if result != VkResult::Success {
271 return Err(KronosError::from(result));
272 }
273
274 let mut queue = VkQueue::NULL;
275 vkGetDeviceQueue(device, queue_family_index, 0, &mut queue);
276
277 Ok((device, queue))
278 }
279
280 unsafe fn create_descriptor_pool(device: VkDevice) -> Result<VkDescriptorPool> {
291 let pool_size = VkDescriptorPoolSize {
293 type_: VkDescriptorType::StorageBuffer,
294 descriptorCount: 10000, };
296
297 let pool_info = VkDescriptorPoolCreateInfo {
298 sType: VkStructureType::DescriptorPoolCreateInfo,
299 pNext: ptr::null(),
300 flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
301 maxSets: 1000,
302 poolSizeCount: 1,
303 pPoolSizes: &pool_size,
304 };
305
306 let mut pool = VkDescriptorPool::NULL;
307 let result = vkCreateDescriptorPool(device, &pool_info, ptr::null(), &mut pool);
308
309 if result != VkResult::Success {
310 return Err(KronosError::from(result));
311 }
312
313 Ok(pool)
314 }
315
316 unsafe fn create_command_pool(device: VkDevice, queue_family_index: u32) -> Result<VkCommandPool> {
327 let pool_info = VkCommandPoolCreateInfo {
328 sType: VkStructureType::CommandPoolCreateInfo,
329 pNext: ptr::null(),
330 flags: VkCommandPoolCreateFlags::RESET_COMMAND_BUFFER,
331 queueFamilyIndex: queue_family_index,
332 };
333
334 let mut pool = VkCommandPool::NULL;
335 let result = vkCreateCommandPool(device, &pool_info, ptr::null(), &mut pool);
336
337 if result != VkResult::Success {
338 return Err(KronosError::from(result));
339 }
340
341 Ok(pool)
342 }
343
344 pub fn device(&self) -> VkDevice {
346 self.inner.lock().unwrap().device
347 }
348
349 pub fn queue(&self) -> VkQueue {
351 self.inner.lock().unwrap().queue
352 }
353
354 pub fn device_properties(&self) -> VkPhysicalDeviceProperties {
356 self.inner.lock().unwrap().device_properties
357 }
358
359 pub(super) fn with_inner<F, R>(&self, f: F) -> R
361 where
362 F: FnOnce(&ContextInner) -> R,
363 {
364 let inner = self.inner.lock().unwrap();
365 f(&*inner)
366 }
367}
368
369impl Drop for ComputeContext {
370 fn drop(&mut self) {
371 let inner = self.inner.lock().unwrap();
372 unsafe {
373 if inner.command_pool != VkCommandPool::NULL {
374 vkDestroyCommandPool(inner.device, inner.command_pool, ptr::null());
375 }
376 if inner.descriptor_pool != VkDescriptorPool::NULL {
377 vkDestroyDescriptorPool(inner.device, inner.descriptor_pool, ptr::null());
378 }
379 if inner.device != VkDevice::NULL {
380 vkDestroyDevice(inner.device, ptr::null());
381 }
382 if inner.instance != VkInstance::NULL {
383 vkDestroyInstance(inner.instance, ptr::null());
384 }
385 }
386 }
387}