1use super::*;
4use crate::*; use crate::implementation::initialize_kronos;
6
7#[cfg(feature = "implementation")]
9use crate::implementation::{
10 vkCreateInstance, vkDestroyInstance, vkEnumeratePhysicalDevices,
11 vkGetPhysicalDeviceProperties, vkGetPhysicalDeviceMemoryProperties,
12 vkGetPhysicalDeviceQueueFamilyProperties,
13 vkCreateDevice, vkDestroyDevice, vkGetDeviceQueue,
14};
15use std::ffi::CString;
16use std::ptr;
17use std::sync::{Arc, Mutex};
18
19pub(super) struct ContextInner {
21 pub(super) instance: VkInstance,
22 pub(super) physical_device: VkPhysicalDevice,
23 pub(super) device: VkDevice,
24 pub(super) queue: VkQueue,
25 pub(super) queue_family_index: u32,
26
27 pub(super) descriptor_pool: VkDescriptorPool,
29 pub(super) command_pool: VkCommandPool,
30
31 pub(super) device_properties: VkPhysicalDeviceProperties,
33 pub(super) memory_properties: VkPhysicalDeviceMemoryProperties,
34}
35
36#[derive(Clone)]
42pub struct ComputeContext {
43 pub(super) inner: Arc<Mutex<ContextInner>>,
44}
45
46unsafe impl Send for ComputeContext {}
48unsafe impl Sync for ComputeContext {}
49
50impl ComputeContext {
51 pub(super) fn new_with_config(config: ContextConfig) -> Result<Self> {
52 unsafe {
53 if let Some(ref p) = config.preferred_icd_path {
55 crate::implementation::icd_loader::set_preferred_icd_path(p.clone());
56 } else if let Some(i) = config.preferred_icd_index {
57 crate::implementation::icd_loader::set_preferred_icd_index(i);
58 }
59
60 initialize_kronos()
62 .map_err(|e| KronosError::InitializationFailed(e.to_string()))?;
63
64 let instance = Self::create_instance(&config)?;
66
67 let (physical_device, queue_family_index) = Self::find_compute_device(instance)?;
69
70 let mut device_properties = VkPhysicalDeviceProperties::default();
72 vkGetPhysicalDeviceProperties(physical_device, &mut device_properties);
73
74 let mut memory_properties = VkPhysicalDeviceMemoryProperties::default();
75 vkGetPhysicalDeviceMemoryProperties(physical_device, &mut memory_properties);
76
77 let device_name = std::ffi::CStr::from_ptr(device_properties.deviceName.as_ptr())
79 .to_string_lossy();
80 let device_type_str = match device_properties.deviceType {
81 VkPhysicalDeviceType::DiscreteGpu => "Discrete GPU",
82 VkPhysicalDeviceType::IntegratedGpu => "Integrated GPU",
83 VkPhysicalDeviceType::VirtualGpu => "Virtual GPU",
84 VkPhysicalDeviceType::Cpu => "CPU (Software Renderer)",
85 _ => "Unknown",
86 };
87 log::info!("Selected Vulkan device: {} ({})", device_name, device_type_str);
88
89 let (device, queue) = Self::create_device(physical_device, queue_family_index)?;
91
92 let descriptor_pool = Self::create_descriptor_pool(device)?;
94
95 let command_pool = Self::create_command_pool(device, queue_family_index)?;
97
98 let inner = ContextInner {
99 instance,
100 physical_device,
101 device,
102 queue,
103 queue_family_index,
104 descriptor_pool,
105 command_pool,
106 device_properties,
107 memory_properties,
108 };
109
110 if let Some(info) = crate::implementation::icd_loader::selected_icd_info() {
112 log::info!(
113 "ComputeContext bound to ICD: {} ({}), api=0x{:x}",
114 info.library_path.display(),
115 if info.is_software { "software" } else { "hardware" },
116 info.api_version
117 );
118 }
119
120 Ok(Self {
121 inner: Arc::new(Mutex::new(inner)),
122 })
123 }
124 }
125
126 unsafe fn create_instance(config: &ContextConfig) -> Result<VkInstance> {
136 let app_name = CString::new(config.app_name.clone())
137 .unwrap_or_else(|_| CString::new("Kronos App").unwrap());
138 let engine_name = CString::new("Kronos Compute").unwrap();
139
140 let app_info = VkApplicationInfo {
141 sType: VkStructureType::ApplicationInfo,
142 pNext: ptr::null(),
143 pApplicationName: app_name.as_ptr(),
144 applicationVersion: VK_MAKE_VERSION(1, 0, 0),
145 pEngineName: engine_name.as_ptr(),
146 engineVersion: VK_MAKE_VERSION(1, 0, 0),
147 apiVersion: VK_API_VERSION_1_0,
148 };
149
150 let create_info = VkInstanceCreateInfo {
151 sType: VkStructureType::InstanceCreateInfo,
152 pNext: ptr::null(),
153 flags: 0,
154 pApplicationInfo: &app_info,
155 enabledLayerCount: 0,
156 ppEnabledLayerNames: ptr::null(),
157 enabledExtensionCount: 0,
158 ppEnabledExtensionNames: ptr::null(),
159 };
160
161 let mut instance = VkInstance::NULL;
162 let result = vkCreateInstance(&create_info, ptr::null(), &mut instance);
163
164 if result != VkResult::Success {
165 return Err(KronosError::from(result));
166 }
167
168 Ok(instance)
169 }
170
171 unsafe fn find_compute_device(instance: VkInstance) -> Result<(VkPhysicalDevice, u32)> {
181 let mut device_count = 0;
182 vkEnumeratePhysicalDevices(instance, &mut device_count, ptr::null_mut());
183
184 if device_count == 0 {
185 return Err(KronosError::DeviceNotFound);
186 }
187
188 let mut devices = vec![VkPhysicalDevice::NULL; device_count as usize];
189 vkEnumeratePhysicalDevices(instance, &mut device_count, devices.as_mut_ptr());
190
191 let mut candidates = Vec::new();
193
194 for device in devices {
195 let queue_family = Self::find_compute_queue_family(device)?;
196 if let Some(index) = queue_family {
197 let mut properties = VkPhysicalDeviceProperties::default();
199 vkGetPhysicalDeviceProperties(device, &mut properties);
200
201 candidates.push((device, index, properties.deviceType));
202 }
203 }
204
205 if candidates.is_empty() {
206 return Err(KronosError::DeviceNotFound);
207 }
208
209 candidates.sort_by_key(|(_, _, device_type)| {
211 match *device_type {
212 VkPhysicalDeviceType::DiscreteGpu => 0,
213 VkPhysicalDeviceType::IntegratedGpu => 1,
214 VkPhysicalDeviceType::VirtualGpu => 2,
215 VkPhysicalDeviceType::Cpu => 3,
216 VkPhysicalDeviceType::Other => 4,
217 _ => 5,
218 }
219 });
220
221 let (device, queue_index, _) = candidates[0];
223 Ok((device, queue_index))
224 }
225
226 unsafe fn find_compute_queue_family(device: VkPhysicalDevice) -> Result<Option<u32>> {
236 let mut queue_family_count = 0;
237 vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, ptr::null_mut());
238
239 let mut queue_families = vec![VkQueueFamilyProperties {
240 queueFlags: VkQueueFlags::empty(),
241 queueCount: 0,
242 timestampValidBits: 0,
243 minImageTransferGranularity: VkExtent3D { width: 0, height: 0, depth: 0 },
244 }; queue_family_count as usize];
245 vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, queue_families.as_mut_ptr());
246
247 for (index, family) in queue_families.iter().enumerate() {
248 if family.queueFlags.contains(VkQueueFlags::COMPUTE) {
249 return Ok(Some(index as u32));
250 }
251 }
252
253 Ok(None)
254 }
255
256 unsafe fn create_device(physical_device: VkPhysicalDevice, queue_family_index: u32) -> Result<(VkDevice, VkQueue)> {
267 let queue_priority = 1.0f32;
268
269 let queue_create_info = VkDeviceQueueCreateInfo {
270 sType: VkStructureType::DeviceQueueCreateInfo,
271 pNext: ptr::null(),
272 flags: 0,
273 queueFamilyIndex: queue_family_index,
274 queueCount: 1,
275 pQueuePriorities: &queue_priority,
276 };
277
278 let features = VkPhysicalDeviceFeatures::default();
279
280 let device_create_info = VkDeviceCreateInfo {
281 sType: VkStructureType::DeviceCreateInfo,
282 pNext: ptr::null(),
283 flags: 0,
284 queueCreateInfoCount: 1,
285 pQueueCreateInfos: &queue_create_info,
286 enabledLayerCount: 0,
287 ppEnabledLayerNames: ptr::null(),
288 enabledExtensionCount: 0,
289 ppEnabledExtensionNames: ptr::null(),
290 pEnabledFeatures: &features,
291 };
292
293 let mut device = VkDevice::NULL;
294 let result = vkCreateDevice(physical_device, &device_create_info, ptr::null(), &mut device);
295
296 if result != VkResult::Success {
297 return Err(KronosError::from(result));
298 }
299
300 let mut queue = VkQueue::NULL;
301 vkGetDeviceQueue(device, queue_family_index, 0, &mut queue);
302
303 Ok((device, queue))
304 }
305
306 unsafe fn create_descriptor_pool(device: VkDevice) -> Result<VkDescriptorPool> {
317 let pool_size = VkDescriptorPoolSize {
319 type_: VkDescriptorType::StorageBuffer,
320 descriptorCount: 10000, };
322
323 let pool_info = VkDescriptorPoolCreateInfo {
324 sType: VkStructureType::DescriptorPoolCreateInfo,
325 pNext: ptr::null(),
326 flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
327 maxSets: 1000,
328 poolSizeCount: 1,
329 pPoolSizes: &pool_size,
330 };
331
332 let mut pool = VkDescriptorPool::NULL;
333 let result = vkCreateDescriptorPool(device, &pool_info, ptr::null(), &mut pool);
334
335 if result != VkResult::Success {
336 return Err(KronosError::from(result));
337 }
338
339 Ok(pool)
340 }
341
342 unsafe fn create_command_pool(device: VkDevice, queue_family_index: u32) -> Result<VkCommandPool> {
353 let pool_info = VkCommandPoolCreateInfo {
354 sType: VkStructureType::CommandPoolCreateInfo,
355 pNext: ptr::null(),
356 flags: VkCommandPoolCreateFlags::RESET_COMMAND_BUFFER,
357 queueFamilyIndex: queue_family_index,
358 };
359
360 let mut pool = VkCommandPool::NULL;
361 let result = vkCreateCommandPool(device, &pool_info, ptr::null(), &mut pool);
362
363 if result != VkResult::Success {
364 return Err(KronosError::from(result));
365 }
366
367 Ok(pool)
368 }
369
370 pub fn device(&self) -> VkDevice {
372 self.inner.lock().unwrap().device
373 }
374
375 pub fn queue(&self) -> VkQueue {
377 self.inner.lock().unwrap().queue
378 }
379
380 pub fn device_properties(&self) -> VkPhysicalDeviceProperties {
382 self.inner.lock().unwrap().device_properties
383 }
384
385 pub fn icd_info(&self) -> Option<crate::implementation::icd_loader::IcdInfo> {
387 crate::implementation::icd_loader::selected_icd_info()
388 }
389
390 pub(super) fn with_inner<F, R>(&self, f: F) -> R
392 where
393 F: FnOnce(&ContextInner) -> R,
394 {
395 let inner = self.inner.lock().unwrap();
396 f(&*inner)
397 }
398}
399
400impl Drop for ComputeContext {
401 fn drop(&mut self) {
402 if std::sync::Arc::strong_count(&self.inner) != 1 {
404 return;
405 }
406 let inner = self.inner.lock().unwrap();
407 unsafe {
408 if inner.command_pool != VkCommandPool::NULL {
409 vkDestroyCommandPool(inner.device, inner.command_pool, ptr::null());
410 }
411 if inner.descriptor_pool != VkDescriptorPool::NULL {
412 vkDestroyDescriptorPool(inner.device, inner.descriptor_pool, ptr::null());
413 }
414 if inner.device != VkDevice::NULL {
415 vkDestroyDevice(inner.device, ptr::null());
416 }
417 if inner.instance != VkInstance::NULL {
418 vkDestroyInstance(inner.instance, ptr::null());
419 }
420 }
421 }
422}