1use super::*;
4use crate::*; use crate::implementation::initialize_kronos;
6
7#[cfg(feature = "implementation")]
9use crate::implementation::{
10 vkCreateInstance, vkDestroyInstance, vkEnumeratePhysicalDevices,
11 vkGetPhysicalDeviceProperties, vkGetPhysicalDeviceMemoryProperties,
12 vkGetPhysicalDeviceQueueFamilyProperties,
13 vkCreateDevice, vkDestroyDevice, vkGetDeviceQueue,
14 vkCreateDescriptorPool, vkDestroyDescriptorPool,
15 vkCreateCommandPool, vkDestroyCommandPool,
16};
17use std::ffi::CString;
18use std::ptr;
19use std::sync::{Arc, Mutex};
20
21pub(super) struct ContextInner {
23 pub(super) instance: VkInstance,
24 pub(super) physical_device: VkPhysicalDevice,
25 pub(super) device: VkDevice,
26 pub(super) queue: VkQueue,
27 pub(super) queue_family_index: u32,
28
29 pub(super) descriptor_pool: VkDescriptorPool,
31 pub(super) command_pool: VkCommandPool,
32
33 pub(super) device_properties: VkPhysicalDeviceProperties,
35 pub(super) memory_properties: VkPhysicalDeviceMemoryProperties,
36}
37
38#[derive(Clone)]
44pub struct ComputeContext {
45 pub(super) inner: Arc<Mutex<ContextInner>>,
46}
47
48unsafe impl Send for ComputeContext {}
50unsafe impl Sync for ComputeContext {}
51
52impl ComputeContext {
53 pub(super) fn new_with_config(config: ContextConfig) -> Result<Self> {
54 log::info!("[SAFE API] ComputeContext::new_with_config() called");
55 unsafe {
56 log::info!("[SAFE API] Applying ICD preferences");
59 if let Some(ref p) = config.preferred_icd_path {
60 log::info!("[SAFE API] Setting preferred ICD path: {:?}", p);
61 crate::implementation::icd_loader::set_preferred_icd_path(p.clone());
62 } else if let Some(i) = config.preferred_icd_index {
63 log::info!("[SAFE API] Setting preferred ICD index: {}", i);
64 crate::implementation::icd_loader::set_preferred_icd_index(i);
65 }
66
67 log::info!("[SAFE API] Initializing Kronos ICD loader");
69 log::info!("[SAFE API] KRONOS_AGGREGATE_ICD = {:?}", std::env::var("KRONOS_AGGREGATE_ICD").ok());
70 initialize_kronos()
71 .map_err(|e| {
72 log::error!("[SAFE API] Failed to initialize Kronos: {:?}", e);
73 KronosError::InitializationFailed(e.to_string())
74 })?;
75 log::info!("[SAFE API] Kronos initialized successfully");
76
77 log::info!("[SAFE API] Creating Vulkan instance");
79 let instance = Self::create_instance(&config)?;
80 log::info!("[SAFE API] Instance created: {:?}", instance);
81
82 log::info!("[SAFE API] Finding compute-capable device");
84 let (physical_device, queue_family_index) = Self::find_compute_device(instance)?;
85 log::info!("[SAFE API] Found device: {:?}, queue family: {}", physical_device, queue_family_index);
86
87 log::info!("[SAFE API] find_compute_device returned successfully");
88
89 log::info!("[SAFE API] Getting device properties");
91 let mut device_properties = VkPhysicalDeviceProperties::default();
92 vkGetPhysicalDeviceProperties(physical_device, &mut device_properties);
93 log::info!("[SAFE API] Got device properties successfully");
94
95 let mut memory_properties = VkPhysicalDeviceMemoryProperties::default();
96 log::info!("[SAFE API] Getting memory properties");
97 vkGetPhysicalDeviceMemoryProperties(physical_device, &mut memory_properties);
98 log::info!("[SAFE API] Got memory properties successfully");
99
100 let device_name_bytes = &device_properties.deviceName;
103 let null_pos = device_name_bytes.iter().position(|&c| c == 0).unwrap_or(device_name_bytes.len());
104 let device_name_u8: Vec<u8> = device_name_bytes[..null_pos]
106 .iter()
107 .map(|&c| c as u8)
108 .collect();
109 let device_name = std::str::from_utf8(&device_name_u8)
110 .unwrap_or("Unknown Device");
111 let device_type_str = match device_properties.deviceType {
112 VkPhysicalDeviceType::DiscreteGpu => "Discrete GPU",
113 VkPhysicalDeviceType::IntegratedGpu => "Integrated GPU",
114 VkPhysicalDeviceType::VirtualGpu => "Virtual GPU",
115 VkPhysicalDeviceType::Cpu => "CPU (Software Renderer)",
116 _ => "Unknown",
117 };
118 log::info!("Selected Vulkan device: {} ({})", device_name, device_type_str);
119
120 log::info!("[SAFE API] Creating logical device");
122 let (device, queue) = Self::create_device(physical_device, queue_family_index)?;
123 log::info!("[SAFE API] Device created: {:?}, queue: {:?}", device, queue);
124
125 log::info!("[SAFE API] Skipping descriptor pool creation temporarily");
128 let descriptor_pool = VkDescriptorPool::NULL;
129 log::info!("[SAFE API] Creating command pool");
134 let command_pool = Self::create_command_pool(device, queue_family_index)?;
135 log::info!("[SAFE API] Command pool created: {:?}", command_pool);
136
137 let inner = ContextInner {
138 instance,
139 physical_device,
140 device,
141 queue,
142 queue_family_index,
143 descriptor_pool,
144 command_pool,
145 device_properties,
146 memory_properties,
147 };
148
149 if let Some(info) = crate::implementation::icd_loader::selected_icd_info() {
151 log::info!(
152 "ComputeContext bound to ICD: {} ({}), api=0x{:x}",
153 info.library_path.display(),
154 if info.is_software { "software" } else { "hardware" },
155 info.api_version
156 );
157 }
158
159 let result = Self {
160 inner: Arc::new(Mutex::new(inner)),
161 };
162 log::info!("[SAFE API] ComputeContext created successfully");
163 Ok(result)
164 }
165 }
166
167 unsafe fn create_instance(config: &ContextConfig) -> Result<VkInstance> {
177 log::info!("[SAFE API] create_instance called with app_name: {}", config.app_name);
178 let app_name = CString::new(config.app_name.clone())
179 .unwrap_or_else(|_| CString::new("Kronos App").unwrap());
180 let engine_name = CString::new("Kronos Compute").unwrap();
181 log::info!("[SAFE API] CStrings created successfully");
182
183 let app_info = VkApplicationInfo {
184 sType: VkStructureType::ApplicationInfo,
185 pNext: ptr::null(),
186 pApplicationName: app_name.as_ptr(),
187 applicationVersion: VK_MAKE_VERSION(1, 0, 0),
188 pEngineName: engine_name.as_ptr(),
189 engineVersion: VK_MAKE_VERSION(1, 0, 0),
190 apiVersion: VK_API_VERSION_1_0,
191 };
192
193 let create_info = VkInstanceCreateInfo {
194 sType: VkStructureType::InstanceCreateInfo,
195 pNext: ptr::null(),
196 flags: 0,
197 pApplicationInfo: &app_info,
198 enabledLayerCount: 0,
199 ppEnabledLayerNames: ptr::null(),
200 enabledExtensionCount: 0,
201 ppEnabledExtensionNames: ptr::null(),
202 };
203
204 let mut instance = VkInstance::NULL;
205 log::info!("[SAFE API] Calling vkCreateInstance");
208 let result = vkCreateInstance(&create_info, ptr::null(), &mut instance);
209 log::info!("[SAFE API] vkCreateInstance returned: {:?}", result);
210
211 if result != VkResult::Success {
212 log::error!("[SAFE API] vkCreateInstance failed with: {:?}", result);
213 return Err(KronosError::from(result));
214 }
215
216 log::info!("[SAFE API] Instance created successfully: {:?}", instance);
217 Ok(instance)
218 }
219
220 unsafe fn find_compute_device(instance: VkInstance) -> Result<(VkPhysicalDevice, u32)> {
230 let mut device_count = 0;
231 log::info!("[SAFE API] Enumerating physical devices...");
232
233 let result = vkEnumeratePhysicalDevices(instance, &mut device_count, ptr::null_mut());
235 if result != VkResult::Success {
236 log::error!("[SAFE API] Failed to get device count: {:?}", result);
237 return Err(KronosError::from(result));
238 }
239 log::info!("[SAFE API] Found {} physical devices", device_count);
240
241 if device_count == 0 {
242 return Err(KronosError::DeviceNotFound);
243 }
244
245 let mut devices = vec![VkPhysicalDevice::NULL; device_count as usize];
246 let result = vkEnumeratePhysicalDevices(instance, &mut device_count, devices.as_mut_ptr());
247 if result != VkResult::Success {
248 log::error!("[SAFE API] Failed to enumerate devices: {:?}", result);
249 return Err(KronosError::from(result));
250 }
251 log::info!("[SAFE API] Successfully enumerated {} devices", device_count);
252
253 let mut candidates = Vec::new();
255
256 for (dev_idx, device) in devices.iter().enumerate() {
257 log::info!("[SAFE API] Checking device {} for compute support", dev_idx);
258 let queue_family = Self::find_compute_queue_family(*device)?;
259 if let Some(index) = queue_family {
260 let mut properties = VkPhysicalDeviceProperties::default();
262 vkGetPhysicalDeviceProperties(*device, &mut properties);
263
264 candidates.push((*device, index, properties.deviceType));
265 }
266 }
267
268 if candidates.is_empty() {
269 return Err(KronosError::DeviceNotFound);
270 }
271
272 candidates.sort_by_key(|(_, _, device_type)| {
274 match *device_type {
275 VkPhysicalDeviceType::DiscreteGpu => 0,
276 VkPhysicalDeviceType::IntegratedGpu => 1,
277 VkPhysicalDeviceType::VirtualGpu => 2,
278 VkPhysicalDeviceType::Cpu => 3,
279 VkPhysicalDeviceType::Other => 4,
280 _ => 5,
281 }
282 });
283
284 let (device, queue_index, device_type) = candidates[0];
286 log::info!("[SAFE API] Selected device with queue index {}, type {:?}", queue_index, device_type);
287 Ok((device, queue_index))
288 }
289
290 unsafe fn find_compute_queue_family(device: VkPhysicalDevice) -> Result<Option<u32>> {
300 log::info!("[SAFE API] Finding queue families for device {:?}", device);
301 let mut queue_family_count = 0;
302 log::info!("[SAFE API] Calling vkGetPhysicalDeviceQueueFamilyProperties (count query)...");
303 vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, ptr::null_mut());
304 log::info!("[SAFE API] Device has {} queue families", queue_family_count);
305
306 let mut queue_families = vec![
307 VkQueueFamilyProperties {
308 queueFlags: VkQueueFlags::empty(),
309 queueCount: 0,
310 timestampValidBits: 0,
311 minImageTransferGranularity: VkExtent3D { width: 0, height: 0, depth: 0 },
312 };
313 queue_family_count as usize
314 ];
315 log::info!("[SAFE API] Getting queue family properties...");
316 vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, queue_families.as_mut_ptr());
317 log::info!("[SAFE API] Got queue family properties, checking for compute support");
318
319 for (index, family) in queue_families.iter().enumerate() {
320 log::info!("[SAFE API] Queue family {}: flags={:?}", index, family.queueFlags);
321 if family.queueFlags.contains(VkQueueFlags::COMPUTE) {
322 log::info!("[SAFE API] Found compute queue at index {}", index);
323 return Ok(Some(index as u32));
324 }
325 }
326
327 log::info!("[SAFE API] No compute queue family found");
328 Ok(None)
329 }
330
331 unsafe fn create_device(physical_device: VkPhysicalDevice, queue_family_index: u32) -> Result<(VkDevice, VkQueue)> {
342 let queue_priority = 1.0f32;
343
344 let queue_create_info = VkDeviceQueueCreateInfo {
345 sType: VkStructureType::DeviceQueueCreateInfo,
346 pNext: ptr::null(),
347 flags: 0,
348 queueFamilyIndex: queue_family_index,
349 queueCount: 1,
350 pQueuePriorities: &queue_priority,
351 };
352
353 log::info!("[SAFE API] Creating device with NULL features pointer (no features requested)");
355
356 let device_create_info = VkDeviceCreateInfo {
357 sType: VkStructureType::DeviceCreateInfo,
358 pNext: ptr::null(),
359 flags: 0,
360 queueCreateInfoCount: 1,
361 pQueueCreateInfos: &queue_create_info,
362 enabledLayerCount: 0,
363 ppEnabledLayerNames: ptr::null(),
364 enabledExtensionCount: 0,
365 ppEnabledExtensionNames: ptr::null(),
366 pEnabledFeatures: ptr::null(), };
368
369 let mut device = VkDevice::NULL;
370 log::info!("[SAFE API] Calling vkCreateDevice with queue family index {}", queue_family_index);
371 let result = vkCreateDevice(physical_device, &device_create_info, ptr::null(), &mut device);
372 log::info!("[SAFE API] vkCreateDevice returned: {:?}", result);
373
374 if result != VkResult::Success {
375 log::error!("[SAFE API] Failed to create device: {:?}", result);
376 return Err(KronosError::from(result));
377 }
378
379 let mut queue = VkQueue::NULL;
380 vkGetDeviceQueue(device, queue_family_index, 0, &mut queue);
381
382 Ok((device, queue))
383 }
384
385 unsafe fn create_descriptor_pool(device: VkDevice) -> Result<VkDescriptorPool> {
396 log::info!("[SAFE API] Creating descriptor pool with device: {:?}", device);
397 let pool_size = VkDescriptorPoolSize {
399 type_: VkDescriptorType::StorageBuffer,
400 descriptorCount: 10000, };
402
403 let pool_info = VkDescriptorPoolCreateInfo {
404 sType: VkStructureType::DescriptorPoolCreateInfo,
405 pNext: ptr::null(),
406 flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
407 maxSets: 1000,
408 poolSizeCount: 1,
409 pPoolSizes: &pool_size,
410 };
411
412 let mut pool = VkDescriptorPool::NULL;
413 log::info!("[SAFE API] Calling vkCreateDescriptorPool");
414 let result = vkCreateDescriptorPool(device, &pool_info, ptr::null(), &mut pool);
415 log::info!("[SAFE API] vkCreateDescriptorPool returned: {:?}", result);
416
417 if result != VkResult::Success {
418 log::error!("[SAFE API] Failed to create descriptor pool: {:?}", result);
419 return Err(KronosError::from(result));
420 }
421
422 Ok(pool)
423 }
424
425 unsafe fn create_command_pool(device: VkDevice, queue_family_index: u32) -> Result<VkCommandPool> {
436 let pool_info = VkCommandPoolCreateInfo {
437 sType: VkStructureType::CommandPoolCreateInfo,
438 pNext: ptr::null(),
439 flags: VkCommandPoolCreateFlags::RESET_COMMAND_BUFFER,
440 queueFamilyIndex: queue_family_index,
441 };
442
443 let mut pool = VkCommandPool::NULL;
444 log::info!("[SAFE API] Calling vkCreateCommandPool with device {:?}, queue family {}", device, queue_family_index);
445 let result = vkCreateCommandPool(device, &pool_info, ptr::null(), &mut pool);
446 log::info!("[SAFE API] vkCreateCommandPool returned: {:?}", result);
447
448 if result != VkResult::Success {
449 log::error!("[SAFE API] Failed to create command pool: {:?}", result);
450 return Err(KronosError::from(result));
451 }
452
453 Ok(pool)
454 }
455
456 pub fn device(&self) -> VkDevice {
458 self.inner.lock().unwrap().device
459 }
460
461 pub fn queue(&self) -> VkQueue {
463 self.inner.lock().unwrap().queue
464 }
465
466 pub fn device_properties(&self) -> VkPhysicalDeviceProperties {
468 self.inner.lock().unwrap().device_properties
469 }
470
471 pub fn icd_info(&self) -> Option<crate::implementation::icd_loader::IcdInfo> {
473 crate::implementation::icd_loader::selected_icd_info()
474 }
475
476 pub(super) fn with_inner<F, R>(&self, f: F) -> R
478 where
479 F: FnOnce(&ContextInner) -> R,
480 {
481 let inner = self.inner.lock().unwrap();
482 f(&*inner)
483 }
484}
485
486impl Drop for ComputeContext {
487 fn drop(&mut self) {
488 if std::sync::Arc::strong_count(&self.inner) != 1 {
490 return;
491 }
492 let inner = self.inner.lock().unwrap();
493 unsafe {
494 if inner.command_pool != VkCommandPool::NULL {
495 vkDestroyCommandPool(inner.device, inner.command_pool, ptr::null());
496 }
497 if inner.device != VkDevice::NULL {
502 vkDestroyDevice(inner.device, ptr::null());
503 }
504 if inner.instance != VkInstance::NULL {
505 vkDestroyInstance(inner.instance, ptr::null());
506 }
507 }
508 }
509}