kronos_compute/api/
context.rs

1//! Main entry point for Kronos Compute
2
3use super::*;
4use crate::*; // Need all the type definitions
5use crate::implementation::initialize_kronos;
6
7// Explicitly import Vulkan functions from implementation when available
8#[cfg(feature = "implementation")]
9use crate::implementation::{
10    vkCreateInstance, vkDestroyInstance, vkEnumeratePhysicalDevices,
11    vkGetPhysicalDeviceProperties, vkGetPhysicalDeviceMemoryProperties,
12    vkGetPhysicalDeviceQueueFamilyProperties,
13    vkCreateDevice, vkDestroyDevice, vkGetDeviceQueue,
14};
15use std::ffi::CString;
16use std::ptr;
17use std::sync::{Arc, Mutex};
18
19/// Internal state for ComputeContext
20pub(super) struct ContextInner {
21    pub(super) instance: VkInstance,
22    pub(super) physical_device: VkPhysicalDevice,
23    pub(super) device: VkDevice,
24    pub(super) queue: VkQueue,
25    pub(super) queue_family_index: u32,
26    
27    // Optimization managers
28    pub(super) descriptor_pool: VkDescriptorPool,
29    pub(super) command_pool: VkCommandPool,
30    
31    // Device properties
32    pub(super) device_properties: VkPhysicalDeviceProperties,
33    pub(super) memory_properties: VkPhysicalDeviceMemoryProperties,
34}
35
36/// Main context for compute operations
37/// 
38/// This is the primary entry point for the Kronos Compute API.
39/// It manages the Vulkan instance, device, and queue, and provides
40/// methods to create buffers, pipelines, and execute commands.
41#[derive(Clone)]
42pub struct ComputeContext {
43    pub(super) inner: Arc<Mutex<ContextInner>>,
44}
45
46// Send + Sync for thread safety
47unsafe impl Send for ComputeContext {}
48unsafe impl Sync for ComputeContext {}
49
50impl ComputeContext {
51    pub(super) fn new_with_config(config: ContextConfig) -> Result<Self> {
52        unsafe {
53            // Apply preferred ICD selection (process-wide for now)
54            if let Some(ref p) = config.preferred_icd_path {
55                crate::implementation::icd_loader::set_preferred_icd_path(p.clone());
56            } else if let Some(i) = config.preferred_icd_index {
57                crate::implementation::icd_loader::set_preferred_icd_index(i);
58            }
59
60            // Initialize Kronos ICD loader
61            initialize_kronos()
62                .map_err(|e| KronosError::InitializationFailed(e.to_string()))?;
63            
64            // Create instance
65            let instance = Self::create_instance(&config)?;
66            
67            // Find compute-capable device
68            let (physical_device, queue_family_index) = Self::find_compute_device(instance)?;
69            
70            // Get device properties
71            let mut device_properties = VkPhysicalDeviceProperties::default();
72            vkGetPhysicalDeviceProperties(physical_device, &mut device_properties);
73            
74            let mut memory_properties = VkPhysicalDeviceMemoryProperties::default();
75            vkGetPhysicalDeviceMemoryProperties(physical_device, &mut memory_properties);
76            
77            // Log selected device info
78            let device_name = std::ffi::CStr::from_ptr(device_properties.deviceName.as_ptr())
79                .to_string_lossy();
80            let device_type_str = match device_properties.deviceType {
81                VkPhysicalDeviceType::DiscreteGpu => "Discrete GPU",
82                VkPhysicalDeviceType::IntegratedGpu => "Integrated GPU",
83                VkPhysicalDeviceType::VirtualGpu => "Virtual GPU",
84                VkPhysicalDeviceType::Cpu => "CPU (Software Renderer)",
85                _ => "Unknown",
86            };
87            log::info!("Selected Vulkan device: {} ({})", device_name, device_type_str);
88            
89            // Create logical device
90            let (device, queue) = Self::create_device(physical_device, queue_family_index)?;
91            
92            // Create descriptor pool for persistent descriptors
93            let descriptor_pool = Self::create_descriptor_pool(device)?;
94            
95            // Create command pool
96            let command_pool = Self::create_command_pool(device, queue_family_index)?;
97            
98            let inner = ContextInner {
99                instance,
100                physical_device,
101                device,
102                queue,
103                queue_family_index,
104                descriptor_pool,
105                command_pool,
106                device_properties,
107                memory_properties,
108            };
109            
110            // Log selected ICD info
111            if let Some(info) = crate::implementation::icd_loader::selected_icd_info() {
112                log::info!(
113                    "ComputeContext bound to ICD: {} ({}), api=0x{:x}",
114                    info.library_path.display(),
115                    if info.is_software { "software" } else { "hardware" },
116                    info.api_version
117                );
118            }
119
120            Ok(Self {
121                inner: Arc::new(Mutex::new(inner)),
122            })
123        }
124    }
125    
126    /// Create a Vulkan instance
127    ///
128    /// # Safety
129    ///
130    /// This function is unsafe because:
131    /// - It calls vkCreateInstance which requires the Vulkan loader to be initialized
132    /// - The returned instance must be destroyed with vkDestroyInstance to avoid leaks
133    /// - The config strings must remain valid for the lifetime of the instance creation
134    /// - Null or invalid pointers in the create info will cause undefined behavior
135    unsafe fn create_instance(config: &ContextConfig) -> Result<VkInstance> {
136        let app_name = CString::new(config.app_name.clone())
137            .unwrap_or_else(|_| CString::new("Kronos App").unwrap());
138        let engine_name = CString::new("Kronos Compute").unwrap();
139        
140        let app_info = VkApplicationInfo {
141            sType: VkStructureType::ApplicationInfo,
142            pNext: ptr::null(),
143            pApplicationName: app_name.as_ptr(),
144            applicationVersion: VK_MAKE_VERSION(1, 0, 0),
145            pEngineName: engine_name.as_ptr(),
146            engineVersion: VK_MAKE_VERSION(1, 0, 0),
147            apiVersion: VK_API_VERSION_1_0,
148        };
149        
150        let create_info = VkInstanceCreateInfo {
151            sType: VkStructureType::InstanceCreateInfo,
152            pNext: ptr::null(),
153            flags: 0,
154            pApplicationInfo: &app_info,
155            enabledLayerCount: 0,
156            ppEnabledLayerNames: ptr::null(),
157            enabledExtensionCount: 0,
158            ppEnabledExtensionNames: ptr::null(),
159        };
160        
161        let mut instance = VkInstance::NULL;
162        let result = vkCreateInstance(&create_info, ptr::null(), &mut instance);
163        
164        if result != VkResult::Success {
165            return Err(KronosError::from(result));
166        }
167        
168        Ok(instance)
169    }
170    
171    /// Find a physical device with compute capabilities
172    ///
173    /// # Safety
174    ///
175    /// This function is unsafe because:
176    /// - The instance must be a valid VkInstance handle
177    /// - Calls vkEnumeratePhysicalDevices which may fail with invalid instance
178    /// - The returned physical device is tied to the instance lifetime
179    /// - Accessing the device after instance destruction is undefined behavior
180    unsafe fn find_compute_device(instance: VkInstance) -> Result<(VkPhysicalDevice, u32)> {
181        let mut device_count = 0;
182        vkEnumeratePhysicalDevices(instance, &mut device_count, ptr::null_mut());
183        
184        if device_count == 0 {
185            return Err(KronosError::DeviceNotFound);
186        }
187        
188        let mut devices = vec![VkPhysicalDevice::NULL; device_count as usize];
189        vkEnumeratePhysicalDevices(instance, &mut device_count, devices.as_mut_ptr());
190        
191        // Collect all devices with compute support and their properties
192        let mut candidates = Vec::new();
193        
194        for device in devices {
195            let queue_family = Self::find_compute_queue_family(device)?;
196            if let Some(index) = queue_family {
197                // Get device properties to determine device type
198                let mut properties = VkPhysicalDeviceProperties::default();
199                vkGetPhysicalDeviceProperties(device, &mut properties);
200                
201                candidates.push((device, index, properties.deviceType));
202            }
203        }
204        
205        if candidates.is_empty() {
206            return Err(KronosError::DeviceNotFound);
207        }
208        
209        // Sort by device type preference: DiscreteGpu > IntegratedGpu > VirtualGpu > Cpu
210        candidates.sort_by_key(|(_, _, device_type)| {
211            match *device_type {
212                VkPhysicalDeviceType::DiscreteGpu => 0,
213                VkPhysicalDeviceType::IntegratedGpu => 1,
214                VkPhysicalDeviceType::VirtualGpu => 2,
215                VkPhysicalDeviceType::Cpu => 3,
216                VkPhysicalDeviceType::Other => 4,
217                _ => 5,
218            }
219        });
220        
221        // Return the best device
222        let (device, queue_index, _) = candidates[0];
223        Ok((device, queue_index))
224    }
225    
226    /// Find a queue family with compute support
227    ///
228    /// # Safety
229    ///
230    /// This function is unsafe because:
231    /// - The device must be a valid VkPhysicalDevice handle
232    /// - Calls vkGetPhysicalDeviceQueueFamilyProperties with the device
233    /// - Invalid device handle will cause undefined behavior
234    /// - The device must remain valid during the function execution
235    unsafe fn find_compute_queue_family(device: VkPhysicalDevice) -> Result<Option<u32>> {
236        let mut queue_family_count = 0;
237        vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, ptr::null_mut());
238        
239        let mut queue_families = vec![VkQueueFamilyProperties {
240            queueFlags: VkQueueFlags::empty(),
241            queueCount: 0,
242            timestampValidBits: 0,
243            minImageTransferGranularity: VkExtent3D { width: 0, height: 0, depth: 0 },
244        }; queue_family_count as usize];
245        vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, queue_families.as_mut_ptr());
246        
247        for (index, family) in queue_families.iter().enumerate() {
248            if family.queueFlags.contains(VkQueueFlags::COMPUTE) {
249                return Ok(Some(index as u32));
250            }
251        }
252        
253        Ok(None)
254    }
255    
256    /// Create a logical device and get its compute queue
257    ///
258    /// # Safety
259    ///
260    /// This function is unsafe because:
261    /// - The physical_device must be a valid VkPhysicalDevice handle
262    /// - The queue_family_index must be valid for the physical device
263    /// - Calls vkCreateDevice and vkGetDeviceQueue which require valid handles
264    /// - The returned device and queue must be properly destroyed
265    /// - Queue family index out of bounds will cause undefined behavior
266    unsafe fn create_device(physical_device: VkPhysicalDevice, queue_family_index: u32) -> Result<(VkDevice, VkQueue)> {
267        let queue_priority = 1.0f32;
268        
269        let queue_create_info = VkDeviceQueueCreateInfo {
270            sType: VkStructureType::DeviceQueueCreateInfo,
271            pNext: ptr::null(),
272            flags: 0,
273            queueFamilyIndex: queue_family_index,
274            queueCount: 1,
275            pQueuePriorities: &queue_priority,
276        };
277        
278        let features = VkPhysicalDeviceFeatures::default();
279        
280        let device_create_info = VkDeviceCreateInfo {
281            sType: VkStructureType::DeviceCreateInfo,
282            pNext: ptr::null(),
283            flags: 0,
284            queueCreateInfoCount: 1,
285            pQueueCreateInfos: &queue_create_info,
286            enabledLayerCount: 0,
287            ppEnabledLayerNames: ptr::null(),
288            enabledExtensionCount: 0,
289            ppEnabledExtensionNames: ptr::null(),
290            pEnabledFeatures: &features,
291        };
292        
293        let mut device = VkDevice::NULL;
294        let result = vkCreateDevice(physical_device, &device_create_info, ptr::null(), &mut device);
295        
296        if result != VkResult::Success {
297            return Err(KronosError::from(result));
298        }
299        
300        let mut queue = VkQueue::NULL;
301        vkGetDeviceQueue(device, queue_family_index, 0, &mut queue);
302        
303        Ok((device, queue))
304    }
305    
306    /// Create a descriptor pool for persistent descriptors
307    ///
308    /// # Safety
309    ///
310    /// This function is unsafe because:
311    /// - The device must be a valid VkDevice handle
312    /// - Calls vkCreateDescriptorPool which requires valid device
313    /// - The returned pool must be destroyed with vkDestroyDescriptorPool
314    /// - Invalid device handle will cause undefined behavior
315    /// - Pool creation may fail if device limits are exceeded
316    unsafe fn create_descriptor_pool(device: VkDevice) -> Result<VkDescriptorPool> {
317        // Create a large pool for persistent descriptors
318        let pool_size = VkDescriptorPoolSize {
319            type_: VkDescriptorType::StorageBuffer,
320            descriptorCount: 10000, // Should be enough for most use cases
321        };
322        
323        let pool_info = VkDescriptorPoolCreateInfo {
324            sType: VkStructureType::DescriptorPoolCreateInfo,
325            pNext: ptr::null(),
326            flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
327            maxSets: 1000,
328            poolSizeCount: 1,
329            pPoolSizes: &pool_size,
330        };
331        
332        let mut pool = VkDescriptorPool::NULL;
333        let result = vkCreateDescriptorPool(device, &pool_info, ptr::null(), &mut pool);
334        
335        if result != VkResult::Success {
336            return Err(KronosError::from(result));
337        }
338        
339        Ok(pool)
340    }
341    
342    /// Create a command pool for allocating command buffers
343    ///
344    /// # Safety
345    ///
346    /// This function is unsafe because:
347    /// - The device must be a valid VkDevice handle
348    /// - The queue_family_index must be valid for the device
349    /// - Calls vkCreateCommandPool which requires valid parameters
350    /// - The returned pool must be destroyed with vkDestroyCommandPool
351    /// - Invalid queue family index will cause undefined behavior
352    unsafe fn create_command_pool(device: VkDevice, queue_family_index: u32) -> Result<VkCommandPool> {
353        let pool_info = VkCommandPoolCreateInfo {
354            sType: VkStructureType::CommandPoolCreateInfo,
355            pNext: ptr::null(),
356            flags: VkCommandPoolCreateFlags::RESET_COMMAND_BUFFER,
357            queueFamilyIndex: queue_family_index,
358        };
359        
360        let mut pool = VkCommandPool::NULL;
361        let result = vkCreateCommandPool(device, &pool_info, ptr::null(), &mut pool);
362        
363        if result != VkResult::Success {
364            return Err(KronosError::from(result));
365        }
366        
367        Ok(pool)
368    }
369    
370    /// Get the underlying Vulkan device (for advanced usage)
371    pub fn device(&self) -> VkDevice {
372        self.inner.lock().unwrap().device
373    }
374    
375    /// Get the compute queue
376    pub fn queue(&self) -> VkQueue {
377        self.inner.lock().unwrap().queue
378    }
379    
380    /// Get device properties
381    pub fn device_properties(&self) -> VkPhysicalDeviceProperties {
382        self.inner.lock().unwrap().device_properties
383    }
384    
385    /// Get information about the ICD bound to this context (process-wide)
386    pub fn icd_info(&self) -> Option<crate::implementation::icd_loader::IcdInfo> {
387        crate::implementation::icd_loader::selected_icd_info()
388    }
389
390    // Internal helper for other modules
391    pub(super) fn with_inner<F, R>(&self, f: F) -> R
392    where
393        F: FnOnce(&ContextInner) -> R,
394    {
395        let inner = self.inner.lock().unwrap();
396        f(&*inner)
397    }
398}
399
400impl Drop for ComputeContext {
401    fn drop(&mut self) {
402        // Only the last Clone should perform destruction to avoid double-free.
403        if std::sync::Arc::strong_count(&self.inner) != 1 {
404            return;
405        }
406        let inner = self.inner.lock().unwrap();
407        unsafe {
408            if inner.command_pool != VkCommandPool::NULL {
409                vkDestroyCommandPool(inner.device, inner.command_pool, ptr::null());
410            }
411            if inner.descriptor_pool != VkDescriptorPool::NULL {
412                vkDestroyDescriptorPool(inner.device, inner.descriptor_pool, ptr::null());
413            }
414            if inner.device != VkDevice::NULL {
415                vkDestroyDevice(inner.device, ptr::null());
416            }
417            if inner.instance != VkInstance::NULL {
418                vkDestroyInstance(inner.instance, ptr::null());
419            }
420        }
421    }
422}