// kronos_compute/api/context.rs

//! Main entry point for Kronos Compute

3use super::*;
4use crate::*; // Need all the type definitions
5use crate::implementation::initialize_kronos;
6
7// Explicitly import Vulkan functions from implementation when available
8#[cfg(feature = "implementation")]
9use crate::implementation::{
10    vkCreateInstance, vkDestroyInstance, vkEnumeratePhysicalDevices,
11    vkGetPhysicalDeviceProperties, vkGetPhysicalDeviceMemoryProperties,
12    vkGetPhysicalDeviceQueueFamilyProperties,
13    vkCreateDevice, vkDestroyDevice, vkGetDeviceQueue,
14    vkCreateDescriptorPool, vkDestroyDescriptorPool,
15    vkCreateCommandPool, vkDestroyCommandPool,
16};
17use std::ffi::CString;
18use std::ptr;
19use std::sync::{Arc, Mutex};
20
/// Internal state for ComputeContext
///
/// Holds the Vulkan handles created by `ComputeContext::new_with_config`
/// together with properties of the selected physical device that are queried
/// once at construction. Clones of a `ComputeContext` share one instance of
/// this struct through `Arc<Mutex<_>>`.
pub(super) struct ContextInner {
    /// Vulkan instance created by `ComputeContext::create_instance`.
    pub(super) instance: VkInstance,
    /// Physical device chosen by `ComputeContext::find_compute_device`.
    pub(super) physical_device: VkPhysicalDevice,
    /// Logical device created on `physical_device`.
    pub(super) device: VkDevice,
    /// Queue 0 of `queue_family_index`, obtained with `vkGetDeviceQueue`.
    pub(super) queue: VkQueue,
    /// Index of the first queue family on `physical_device` whose flags contain `COMPUTE`.
    pub(super) queue_family_index: u32,

    // Optimization managers
    /// Descriptor pool for persistent descriptors; currently left NULL because
    /// its creation is disabled in `new_with_config`.
    pub(super) descriptor_pool: VkDescriptorPool,
    /// Command pool on `queue_family_index`, created with `RESET_COMMAND_BUFFER`.
    pub(super) command_pool: VkCommandPool,

    // Device properties
    /// Properties of `physical_device`, queried once at construction.
    pub(super) device_properties: VkPhysicalDeviceProperties,
    /// Memory properties of `physical_device`, queried once at construction.
    pub(super) memory_properties: VkPhysicalDeviceMemoryProperties,
}

/// Main context for compute operations
///
/// This is the primary entry point for the Kronos Compute API.
/// It manages the Vulkan instance, device, and queue, and provides
/// methods to create buffers, pipelines, and execute commands.
///
/// Cloning is cheap: every clone shares the same `ContextInner` through an
/// `Arc<Mutex<_>>`. The Vulkan objects are intended to be released when the
/// last clone is dropped.
#[derive(Clone)]
pub struct ComputeContext {
    /// Shared state; locked for every access from this module's getters.
    pub(super) inner: Arc<Mutex<ContextInner>>,
}

// SAFETY: `ContextInner` stores raw Vulkan handles, which are presumably not
// auto-`Send`/`Sync` (hence these manual impls). Access to the struct itself is
// serialized by the `Mutex` inside `ComputeContext`.
// NOTE(review): getters such as `device()` and `queue()` return copies of the
// handles, after which the mutex no longer guards their use. Vulkan requires
// external synchronization for host access to a queue (e.g. vkQueueSubmit), so
// confirm that callers serialize those operations themselves.
unsafe impl Send for ComputeContext {}
unsafe impl Sync for ComputeContext {}

52impl ComputeContext {
53    pub(super) fn new_with_config(config: ContextConfig) -> Result<Self> {
54        log::info!("[SAFE API] ComputeContext::new_with_config() called");
55        unsafe {
56            // Apply preferred ICD selection BEFORE initialization
57            // This is crucial - preferences must be set before initialize_kronos()
58            log::info!("[SAFE API] Applying ICD preferences");
59            if let Some(ref p) = config.preferred_icd_path {
60                log::info!("[SAFE API] Setting preferred ICD path: {:?}", p);
61                crate::implementation::icd_loader::set_preferred_icd_path(p.clone());
62            } else if let Some(i) = config.preferred_icd_index {
63                log::info!("[SAFE API] Setting preferred ICD index: {}", i);
64                crate::implementation::icd_loader::set_preferred_icd_index(i);
65            }
66
67            // Initialize Kronos ICD loader
68            log::info!("[SAFE API] Initializing Kronos ICD loader");
69            log::info!("[SAFE API] KRONOS_AGGREGATE_ICD = {:?}", std::env::var("KRONOS_AGGREGATE_ICD").ok());
70            initialize_kronos()
71                .map_err(|e| {
72                    log::error!("[SAFE API] Failed to initialize Kronos: {:?}", e);
73                    KronosError::InitializationFailed(e.to_string())
74                })?;
75            log::info!("[SAFE API] Kronos initialized successfully");
76            
77            // Create instance
78            log::info!("[SAFE API] Creating Vulkan instance");
79            let instance = Self::create_instance(&config)?;
80            log::info!("[SAFE API] Instance created: {:?}", instance);
81            
82            // Find compute-capable device
83            log::info!("[SAFE API] Finding compute-capable device");
84            let (physical_device, queue_family_index) = Self::find_compute_device(instance)?;
85            log::info!("[SAFE API] Found device: {:?}, queue family: {}", physical_device, queue_family_index);
86            
87            log::info!("[SAFE API] find_compute_device returned successfully");
88            
89            // Get device properties
90            log::info!("[SAFE API] Getting device properties");
91            let mut device_properties = VkPhysicalDeviceProperties::default();
92            vkGetPhysicalDeviceProperties(physical_device, &mut device_properties);
93            log::info!("[SAFE API] Got device properties successfully");
94            
95            let mut memory_properties = VkPhysicalDeviceMemoryProperties::default();
96            log::info!("[SAFE API] Getting memory properties");
97            vkGetPhysicalDeviceMemoryProperties(physical_device, &mut memory_properties);
98            log::info!("[SAFE API] Got memory properties successfully");
99            
100            // Log selected device info
101            // deviceName is a fixed-size array, ensure it's null-terminated
102            let device_name_bytes = &device_properties.deviceName;
103            let null_pos = device_name_bytes.iter().position(|&c| c == 0).unwrap_or(device_name_bytes.len());
104            // Convert from &[i8] to &[u8] for from_utf8
105            let device_name_u8: Vec<u8> = device_name_bytes[..null_pos]
106                .iter()
107                .map(|&c| c as u8)
108                .collect();
109            let device_name = std::str::from_utf8(&device_name_u8)
110                .unwrap_or("Unknown Device");
111            let device_type_str = match device_properties.deviceType {
112                VkPhysicalDeviceType::DiscreteGpu => "Discrete GPU",
113                VkPhysicalDeviceType::IntegratedGpu => "Integrated GPU",
114                VkPhysicalDeviceType::VirtualGpu => "Virtual GPU",
115                VkPhysicalDeviceType::Cpu => "CPU (Software Renderer)",
116                _ => "Unknown",
117            };
118            log::info!("Selected Vulkan device: {} ({})", device_name, device_type_str);
119            
120            // Create logical device
121            log::info!("[SAFE API] Creating logical device");
122            let (device, queue) = Self::create_device(physical_device, queue_family_index)?;
123            log::info!("[SAFE API] Device created: {:?}, queue: {:?}", device, queue);
124            
125            // Create descriptor pool for persistent descriptors
126            // TODO: Fix descriptor pool creation - temporarily skip
127            log::info!("[SAFE API] Skipping descriptor pool creation temporarily");
128            let descriptor_pool = VkDescriptorPool::NULL;
129            // let descriptor_pool = Self::create_descriptor_pool(device)?;
130            // log::info!("[SAFE API] Descriptor pool created: {:?}", descriptor_pool);
131            
132            // Create command pool
133            log::info!("[SAFE API] Creating command pool");
134            let command_pool = Self::create_command_pool(device, queue_family_index)?;
135            log::info!("[SAFE API] Command pool created: {:?}", command_pool);
136            
137            let inner = ContextInner {
138                instance,
139                physical_device,
140                device,
141                queue,
142                queue_family_index,
143                descriptor_pool,
144                command_pool,
145                device_properties,
146                memory_properties,
147            };
148            
149            // Log selected ICD info
150            if let Some(info) = crate::implementation::icd_loader::selected_icd_info() {
151                log::info!(
152                    "ComputeContext bound to ICD: {} ({}), api=0x{:x}",
153                    info.library_path.display(),
154                    if info.is_software { "software" } else { "hardware" },
155                    info.api_version
156                );
157            }
158
159            let result = Self {
160                inner: Arc::new(Mutex::new(inner)),
161            };
162            log::info!("[SAFE API] ComputeContext created successfully");
163            Ok(result)
164        }
165    }
166    
167    /// Create a Vulkan instance
168    ///
169    /// # Safety
170    ///
171    /// This function is unsafe because:
172    /// - It calls vkCreateInstance which requires the Vulkan loader to be initialized
173    /// - The returned instance must be destroyed with vkDestroyInstance to avoid leaks
174    /// - The config strings must remain valid for the lifetime of the instance creation
175    /// - Null or invalid pointers in the create info will cause undefined behavior
176    unsafe fn create_instance(config: &ContextConfig) -> Result<VkInstance> {
177        log::info!("[SAFE API] create_instance called with app_name: {}", config.app_name);
178        let app_name = CString::new(config.app_name.clone())
179            .unwrap_or_else(|_| CString::new("Kronos App").unwrap());
180        let engine_name = CString::new("Kronos Compute").unwrap();
181        log::info!("[SAFE API] CStrings created successfully");
182        
183        let app_info = VkApplicationInfo {
184            sType: VkStructureType::ApplicationInfo,
185            pNext: ptr::null(),
186            pApplicationName: app_name.as_ptr(),
187            applicationVersion: VK_MAKE_VERSION(1, 0, 0),
188            pEngineName: engine_name.as_ptr(),
189            engineVersion: VK_MAKE_VERSION(1, 0, 0),
190            apiVersion: VK_API_VERSION_1_0,
191        };
192        
193        let create_info = VkInstanceCreateInfo {
194            sType: VkStructureType::InstanceCreateInfo,
195            pNext: ptr::null(),
196            flags: 0,
197            pApplicationInfo: &app_info,
198            enabledLayerCount: 0,
199            ppEnabledLayerNames: ptr::null(),
200            enabledExtensionCount: 0,
201            ppEnabledExtensionNames: ptr::null(),
202        };
203        
204        let mut instance = VkInstance::NULL;
205        // IMPORTANT: CStrings must remain alive during vkCreateInstance call
206        // They are dropped at the end of this function, which is safe
207        log::info!("[SAFE API] Calling vkCreateInstance");
208        let result = vkCreateInstance(&create_info, ptr::null(), &mut instance);
209        log::info!("[SAFE API] vkCreateInstance returned: {:?}", result);
210        
211        if result != VkResult::Success {
212            log::error!("[SAFE API] vkCreateInstance failed with: {:?}", result);
213            return Err(KronosError::from(result));
214        }
215        
216        log::info!("[SAFE API] Instance created successfully: {:?}", instance);
217        Ok(instance)
218    }
219    
220    /// Find a physical device with compute capabilities
221    ///
222    /// # Safety
223    ///
224    /// This function is unsafe because:
225    /// - The instance must be a valid VkInstance handle
226    /// - Calls vkEnumeratePhysicalDevices which may fail with invalid instance
227    /// - The returned physical device is tied to the instance lifetime
228    /// - Accessing the device after instance destruction is undefined behavior
229    unsafe fn find_compute_device(instance: VkInstance) -> Result<(VkPhysicalDevice, u32)> {
230        let mut device_count = 0;
231        log::info!("[SAFE API] Enumerating physical devices...");
232        
233        // First call to get count
234        let result = vkEnumeratePhysicalDevices(instance, &mut device_count, ptr::null_mut());
235        if result != VkResult::Success {
236            log::error!("[SAFE API] Failed to get device count: {:?}", result);
237            return Err(KronosError::from(result));
238        }
239        log::info!("[SAFE API] Found {} physical devices", device_count);
240        
241        if device_count == 0 {
242            return Err(KronosError::DeviceNotFound);
243        }
244        
245        let mut devices = vec![VkPhysicalDevice::NULL; device_count as usize];
246        let result = vkEnumeratePhysicalDevices(instance, &mut device_count, devices.as_mut_ptr());
247        if result != VkResult::Success {
248            log::error!("[SAFE API] Failed to enumerate devices: {:?}", result);
249            return Err(KronosError::from(result));
250        }
251        log::info!("[SAFE API] Successfully enumerated {} devices", device_count);
252        
253        // Collect all devices with compute support and their properties
254        let mut candidates = Vec::new();
255        
256        for (dev_idx, device) in devices.iter().enumerate() {
257            log::info!("[SAFE API] Checking device {} for compute support", dev_idx);
258            let queue_family = Self::find_compute_queue_family(*device)?;
259            if let Some(index) = queue_family {
260                // Get device properties to determine device type
261                let mut properties = VkPhysicalDeviceProperties::default();
262                vkGetPhysicalDeviceProperties(*device, &mut properties);
263                
264                candidates.push((*device, index, properties.deviceType));
265            }
266        }
267        
268        if candidates.is_empty() {
269            return Err(KronosError::DeviceNotFound);
270        }
271        
272        // Sort by device type preference: DiscreteGpu > IntegratedGpu > VirtualGpu > Cpu
273        candidates.sort_by_key(|(_, _, device_type)| {
274            match *device_type {
275                VkPhysicalDeviceType::DiscreteGpu => 0,
276                VkPhysicalDeviceType::IntegratedGpu => 1,
277                VkPhysicalDeviceType::VirtualGpu => 2,
278                VkPhysicalDeviceType::Cpu => 3,
279                VkPhysicalDeviceType::Other => 4,
280                _ => 5,
281            }
282        });
283        
284        // Return the best device
285        let (device, queue_index, device_type) = candidates[0];
286        log::info!("[SAFE API] Selected device with queue index {}, type {:?}", queue_index, device_type);
287        Ok((device, queue_index))
288    }
289    
290    /// Find a queue family with compute support
291    ///
292    /// # Safety
293    ///
294    /// This function is unsafe because:
295    /// - The device must be a valid VkPhysicalDevice handle
296    /// - Calls vkGetPhysicalDeviceQueueFamilyProperties with the device
297    /// - Invalid device handle will cause undefined behavior
298    /// - The device must remain valid during the function execution
299    unsafe fn find_compute_queue_family(device: VkPhysicalDevice) -> Result<Option<u32>> {
300        log::info!("[SAFE API] Finding queue families for device {:?}", device);
301        let mut queue_family_count = 0;
302        log::info!("[SAFE API] Calling vkGetPhysicalDeviceQueueFamilyProperties (count query)...");
303        vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, ptr::null_mut());
304        log::info!("[SAFE API] Device has {} queue families", queue_family_count);
305        
306        let mut queue_families = vec![
307            VkQueueFamilyProperties {
308                queueFlags: VkQueueFlags::empty(),
309                queueCount: 0,
310                timestampValidBits: 0,
311                minImageTransferGranularity: VkExtent3D { width: 0, height: 0, depth: 0 },
312            };
313            queue_family_count as usize
314        ];
315        log::info!("[SAFE API] Getting queue family properties...");
316        vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, queue_families.as_mut_ptr());
317        log::info!("[SAFE API] Got queue family properties, checking for compute support");
318        
319        for (index, family) in queue_families.iter().enumerate() {
320            log::info!("[SAFE API] Queue family {}: flags={:?}", index, family.queueFlags);
321            if family.queueFlags.contains(VkQueueFlags::COMPUTE) {
322                log::info!("[SAFE API] Found compute queue at index {}", index);
323                return Ok(Some(index as u32));
324            }
325        }
326        
327        log::info!("[SAFE API] No compute queue family found");
328        Ok(None)
329    }
330    
331    /// Create a logical device and get its compute queue
332    ///
333    /// # Safety
334    ///
335    /// This function is unsafe because:
336    /// - The physical_device must be a valid VkPhysicalDevice handle
337    /// - The queue_family_index must be valid for the physical device
338    /// - Calls vkCreateDevice and vkGetDeviceQueue which require valid handles
339    /// - The returned device and queue must be properly destroyed
340    /// - Queue family index out of bounds will cause undefined behavior
341    unsafe fn create_device(physical_device: VkPhysicalDevice, queue_family_index: u32) -> Result<(VkDevice, VkQueue)> {
342        let queue_priority = 1.0f32;
343        
344        let queue_create_info = VkDeviceQueueCreateInfo {
345            sType: VkStructureType::DeviceQueueCreateInfo,
346            pNext: ptr::null(),
347            flags: 0,
348            queueFamilyIndex: queue_family_index,
349            queueCount: 1,
350            pQueuePriorities: &queue_priority,
351        };
352        
353        // Don't request any features - use NULL pointer like the working example
354        log::info!("[SAFE API] Creating device with NULL features pointer (no features requested)");
355        
356        let device_create_info = VkDeviceCreateInfo {
357            sType: VkStructureType::DeviceCreateInfo,
358            pNext: ptr::null(),
359            flags: 0,
360            queueCreateInfoCount: 1,
361            pQueueCreateInfos: &queue_create_info,
362            enabledLayerCount: 0,
363            ppEnabledLayerNames: ptr::null(),
364            enabledExtensionCount: 0,
365            ppEnabledExtensionNames: ptr::null(),
366            pEnabledFeatures: ptr::null(),  // Use NULL like the working example
367        };
368        
369        let mut device = VkDevice::NULL;
370        log::info!("[SAFE API] Calling vkCreateDevice with queue family index {}", queue_family_index);
371        let result = vkCreateDevice(physical_device, &device_create_info, ptr::null(), &mut device);
372        log::info!("[SAFE API] vkCreateDevice returned: {:?}", result);
373        
374        if result != VkResult::Success {
375            log::error!("[SAFE API] Failed to create device: {:?}", result);
376            return Err(KronosError::from(result));
377        }
378        
379        let mut queue = VkQueue::NULL;
380        vkGetDeviceQueue(device, queue_family_index, 0, &mut queue);
381        
382        Ok((device, queue))
383    }
384    
385    /// Create a descriptor pool for persistent descriptors
386    ///
387    /// # Safety
388    ///
389    /// This function is unsafe because:
390    /// - The device must be a valid VkDevice handle
391    /// - Calls vkCreateDescriptorPool which requires valid device
392    /// - The returned pool must be destroyed with vkDestroyDescriptorPool
393    /// - Invalid device handle will cause undefined behavior
394    /// - Pool creation may fail if device limits are exceeded
395    unsafe fn create_descriptor_pool(device: VkDevice) -> Result<VkDescriptorPool> {
396        log::info!("[SAFE API] Creating descriptor pool with device: {:?}", device);
397        // Create a large pool for persistent descriptors
398        let pool_size = VkDescriptorPoolSize {
399            type_: VkDescriptorType::StorageBuffer,
400            descriptorCount: 10000, // Should be enough for most use cases
401        };
402        
403        let pool_info = VkDescriptorPoolCreateInfo {
404            sType: VkStructureType::DescriptorPoolCreateInfo,
405            pNext: ptr::null(),
406            flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
407            maxSets: 1000,
408            poolSizeCount: 1,
409            pPoolSizes: &pool_size,
410        };
411        
412        let mut pool = VkDescriptorPool::NULL;
413        log::info!("[SAFE API] Calling vkCreateDescriptorPool");
414        let result = vkCreateDescriptorPool(device, &pool_info, ptr::null(), &mut pool);
415        log::info!("[SAFE API] vkCreateDescriptorPool returned: {:?}", result);
416        
417        if result != VkResult::Success {
418            log::error!("[SAFE API] Failed to create descriptor pool: {:?}", result);
419            return Err(KronosError::from(result));
420        }
421        
422        Ok(pool)
423    }
424    
425    /// Create a command pool for allocating command buffers
426    ///
427    /// # Safety
428    ///
429    /// This function is unsafe because:
430    /// - The device must be a valid VkDevice handle
431    /// - The queue_family_index must be valid for the device
432    /// - Calls vkCreateCommandPool which requires valid parameters
433    /// - The returned pool must be destroyed with vkDestroyCommandPool
434    /// - Invalid queue family index will cause undefined behavior
435    unsafe fn create_command_pool(device: VkDevice, queue_family_index: u32) -> Result<VkCommandPool> {
436        let pool_info = VkCommandPoolCreateInfo {
437            sType: VkStructureType::CommandPoolCreateInfo,
438            pNext: ptr::null(),
439            flags: VkCommandPoolCreateFlags::RESET_COMMAND_BUFFER,
440            queueFamilyIndex: queue_family_index,
441        };
442        
443        let mut pool = VkCommandPool::NULL;
444        log::info!("[SAFE API] Calling vkCreateCommandPool with device {:?}, queue family {}", device, queue_family_index);
445        let result = vkCreateCommandPool(device, &pool_info, ptr::null(), &mut pool);
446        log::info!("[SAFE API] vkCreateCommandPool returned: {:?}", result);
447        
448        if result != VkResult::Success {
449            log::error!("[SAFE API] Failed to create command pool: {:?}", result);
450            return Err(KronosError::from(result));
451        }
452        
453        Ok(pool)
454    }
455    
456    /// Get the underlying Vulkan device (for advanced usage)
457    pub fn device(&self) -> VkDevice {
458        self.inner.lock().unwrap().device
459    }
460    
461    /// Get the compute queue
462    pub fn queue(&self) -> VkQueue {
463        self.inner.lock().unwrap().queue
464    }
465    
466    /// Get device properties
467    pub fn device_properties(&self) -> VkPhysicalDeviceProperties {
468        self.inner.lock().unwrap().device_properties
469    }
470    
471    /// Get information about the ICD bound to this context (process-wide)
472    pub fn icd_info(&self) -> Option<crate::implementation::icd_loader::IcdInfo> {
473        crate::implementation::icd_loader::selected_icd_info()
474    }
475
476    // Internal helper for other modules
477    pub(super) fn with_inner<F, R>(&self, f: F) -> R
478    where
479        F: FnOnce(&ContextInner) -> R,
480    {
481        let inner = self.inner.lock().unwrap();
482        f(&*inner)
483    }
484}
485
486impl Drop for ComputeContext {
487    fn drop(&mut self) {
488        // Only the last Clone should perform destruction to avoid double-free.
489        if std::sync::Arc::strong_count(&self.inner) != 1 {
490            return;
491        }
492        let inner = self.inner.lock().unwrap();
493        unsafe {
494            if inner.command_pool != VkCommandPool::NULL {
495                vkDestroyCommandPool(inner.device, inner.command_pool, ptr::null());
496            }
497            // TODO: Re-enable when descriptor pool creation is fixed
498            // if inner.descriptor_pool != VkDescriptorPool::NULL {
499            //     vkDestroyDescriptorPool(inner.device, inner.descriptor_pool, ptr::null());
500            // }
501            if inner.device != VkDevice::NULL {
502                vkDestroyDevice(inner.device, ptr::null());
503            }
504            if inner.instance != VkInstance::NULL {
505                vkDestroyInstance(inner.instance, ptr::null());
506            }
507        }
508    }
509}