// kronos_compute/api/context.rs

//! Main entry point for Kronos Compute: defines [`ComputeContext`].

use std::ffi::CString;
use std::ptr;
use std::sync::{Arc, Mutex, MutexGuard};

use super::*;
use crate::implementation::initialize_kronos;
use crate::*; // Import all functions from the crate root

10/// Internal state for ComputeContext
11pub(super) struct ContextInner {
12    pub(super) instance: VkInstance,
13    pub(super) physical_device: VkPhysicalDevice,
14    pub(super) device: VkDevice,
15    pub(super) queue: VkQueue,
16    pub(super) queue_family_index: u32,
17    
18    // Optimization managers
19    pub(super) descriptor_pool: VkDescriptorPool,
20    pub(super) command_pool: VkCommandPool,
21    
22    // Device properties
23    pub(super) device_properties: VkPhysicalDeviceProperties,
24    pub(super) memory_properties: VkPhysicalDeviceMemoryProperties,
25}
26
27/// Main context for compute operations
28/// 
29/// This is the primary entry point for the Kronos Compute API.
30/// It manages the Vulkan instance, device, and queue, and provides
31/// methods to create buffers, pipelines, and execute commands.
32#[derive(Clone)]
33pub struct ComputeContext {
34    pub(super) inner: Arc<Mutex<ContextInner>>,
35}
36
37// Send + Sync for thread safety
38unsafe impl Send for ComputeContext {}
39unsafe impl Sync for ComputeContext {}
40
41impl ComputeContext {
42    pub(super) fn new_with_config(config: ContextConfig) -> Result<Self> {
43        unsafe {
44            // Apply preferred ICD selection (process-wide for now)
45            if let Some(ref p) = config.preferred_icd_path {
46                crate::implementation::icd_loader::set_preferred_icd_path(p.clone());
47            } else if let Some(i) = config.preferred_icd_index {
48                crate::implementation::icd_loader::set_preferred_icd_index(i);
49            }
50
51            // Initialize Kronos ICD loader
52            initialize_kronos()
53                .map_err(|e| KronosError::InitializationFailed(e.to_string()))?;
54            
55            // Create instance
56            let instance = Self::create_instance(&config)?;
57            
58            // Find compute-capable device
59            let (physical_device, queue_family_index) = Self::find_compute_device(instance)?;
60            
61            // Get device properties
62            let mut device_properties = VkPhysicalDeviceProperties::default();
63            vkGetPhysicalDeviceProperties(physical_device, &mut device_properties);
64            
65            let mut memory_properties = VkPhysicalDeviceMemoryProperties::default();
66            vkGetPhysicalDeviceMemoryProperties(physical_device, &mut memory_properties);
67            
68            // Log selected device info
69            let device_name = std::ffi::CStr::from_ptr(device_properties.deviceName.as_ptr())
70                .to_string_lossy();
71            let device_type_str = match device_properties.deviceType {
72                VkPhysicalDeviceType::DiscreteGpu => "Discrete GPU",
73                VkPhysicalDeviceType::IntegratedGpu => "Integrated GPU",
74                VkPhysicalDeviceType::VirtualGpu => "Virtual GPU",
75                VkPhysicalDeviceType::Cpu => "CPU (Software Renderer)",
76                _ => "Unknown",
77            };
78            log::info!("Selected Vulkan device: {} ({})", device_name, device_type_str);
79            
80            // Create logical device
81            let (device, queue) = Self::create_device(physical_device, queue_family_index)?;
82            
83            // Create descriptor pool for persistent descriptors
84            let descriptor_pool = Self::create_descriptor_pool(device)?;
85            
86            // Create command pool
87            let command_pool = Self::create_command_pool(device, queue_family_index)?;
88            
89            let inner = ContextInner {
90                instance,
91                physical_device,
92                device,
93                queue,
94                queue_family_index,
95                descriptor_pool,
96                command_pool,
97                device_properties,
98                memory_properties,
99            };
100            
101            // Log selected ICD info
102            if let Some(info) = crate::implementation::icd_loader::selected_icd_info() {
103                log::info!(
104                    "ComputeContext bound to ICD: {} ({}), api=0x{:x}",
105                    info.library_path.display(),
106                    if info.is_software { "software" } else { "hardware" },
107                    info.api_version
108                );
109            }
110
111            Ok(Self {
112                inner: Arc::new(Mutex::new(inner)),
113            })
114        }
115    }
116    
117    /// Create a Vulkan instance
118    ///
119    /// # Safety
120    ///
121    /// This function is unsafe because:
122    /// - It calls vkCreateInstance which requires the Vulkan loader to be initialized
123    /// - The returned instance must be destroyed with vkDestroyInstance to avoid leaks
124    /// - The config strings must remain valid for the lifetime of the instance creation
125    /// - Null or invalid pointers in the create info will cause undefined behavior
126    unsafe fn create_instance(config: &ContextConfig) -> Result<VkInstance> {
127        let app_name = CString::new(config.app_name.clone())
128            .unwrap_or_else(|_| CString::new("Kronos App").unwrap());
129        let engine_name = CString::new("Kronos Compute").unwrap();
130        
131        let app_info = VkApplicationInfo {
132            sType: VkStructureType::ApplicationInfo,
133            pNext: ptr::null(),
134            pApplicationName: app_name.as_ptr(),
135            applicationVersion: VK_MAKE_VERSION(1, 0, 0),
136            pEngineName: engine_name.as_ptr(),
137            engineVersion: VK_MAKE_VERSION(1, 0, 0),
138            apiVersion: VK_API_VERSION_1_0,
139        };
140        
141        let create_info = VkInstanceCreateInfo {
142            sType: VkStructureType::InstanceCreateInfo,
143            pNext: ptr::null(),
144            flags: 0,
145            pApplicationInfo: &app_info,
146            enabledLayerCount: 0,
147            ppEnabledLayerNames: ptr::null(),
148            enabledExtensionCount: 0,
149            ppEnabledExtensionNames: ptr::null(),
150        };
151        
152        let mut instance = VkInstance::NULL;
153        let result = vkCreateInstance(&create_info, ptr::null(), &mut instance);
154        
155        if result != VkResult::Success {
156            return Err(KronosError::from(result));
157        }
158        
159        Ok(instance)
160    }
161    
162    /// Find a physical device with compute capabilities
163    ///
164    /// # Safety
165    ///
166    /// This function is unsafe because:
167    /// - The instance must be a valid VkInstance handle
168    /// - Calls vkEnumeratePhysicalDevices which may fail with invalid instance
169    /// - The returned physical device is tied to the instance lifetime
170    /// - Accessing the device after instance destruction is undefined behavior
171    unsafe fn find_compute_device(instance: VkInstance) -> Result<(VkPhysicalDevice, u32)> {
172        let mut device_count = 0;
173        vkEnumeratePhysicalDevices(instance, &mut device_count, ptr::null_mut());
174        
175        if device_count == 0 {
176            return Err(KronosError::DeviceNotFound);
177        }
178        
179        let mut devices = vec![VkPhysicalDevice::NULL; device_count as usize];
180        vkEnumeratePhysicalDevices(instance, &mut device_count, devices.as_mut_ptr());
181        
182        // Collect all devices with compute support and their properties
183        let mut candidates = Vec::new();
184        
185        for device in devices {
186            let queue_family = Self::find_compute_queue_family(device)?;
187            if let Some(index) = queue_family {
188                // Get device properties to determine device type
189                let mut properties = VkPhysicalDeviceProperties::default();
190                vkGetPhysicalDeviceProperties(device, &mut properties);
191                
192                candidates.push((device, index, properties.deviceType));
193            }
194        }
195        
196        if candidates.is_empty() {
197            return Err(KronosError::DeviceNotFound);
198        }
199        
200        // Sort by device type preference: DiscreteGpu > IntegratedGpu > VirtualGpu > Cpu
201        candidates.sort_by_key(|(_, _, device_type)| {
202            match *device_type {
203                VkPhysicalDeviceType::DiscreteGpu => 0,
204                VkPhysicalDeviceType::IntegratedGpu => 1,
205                VkPhysicalDeviceType::VirtualGpu => 2,
206                VkPhysicalDeviceType::Cpu => 3,
207                VkPhysicalDeviceType::Other => 4,
208                _ => 5,
209            }
210        });
211        
212        // Return the best device
213        let (device, queue_index, _) = candidates[0];
214        Ok((device, queue_index))
215    }
216    
217    /// Find a queue family with compute support
218    ///
219    /// # Safety
220    ///
221    /// This function is unsafe because:
222    /// - The device must be a valid VkPhysicalDevice handle
223    /// - Calls vkGetPhysicalDeviceQueueFamilyProperties with the device
224    /// - Invalid device handle will cause undefined behavior
225    /// - The device must remain valid during the function execution
226    unsafe fn find_compute_queue_family(device: VkPhysicalDevice) -> Result<Option<u32>> {
227        let mut queue_family_count = 0;
228        vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, ptr::null_mut());
229        
230        let mut queue_families = vec![VkQueueFamilyProperties {
231            queueFlags: VkQueueFlags::empty(),
232            queueCount: 0,
233            timestampValidBits: 0,
234            minImageTransferGranularity: VkExtent3D { width: 0, height: 0, depth: 0 },
235        }; queue_family_count as usize];
236        vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, queue_families.as_mut_ptr());
237        
238        for (index, family) in queue_families.iter().enumerate() {
239            if family.queueFlags.contains(VkQueueFlags::COMPUTE) {
240                return Ok(Some(index as u32));
241            }
242        }
243        
244        Ok(None)
245    }
246    
247    /// Create a logical device and get its compute queue
248    ///
249    /// # Safety
250    ///
251    /// This function is unsafe because:
252    /// - The physical_device must be a valid VkPhysicalDevice handle
253    /// - The queue_family_index must be valid for the physical device
254    /// - Calls vkCreateDevice and vkGetDeviceQueue which require valid handles
255    /// - The returned device and queue must be properly destroyed
256    /// - Queue family index out of bounds will cause undefined behavior
257    unsafe fn create_device(physical_device: VkPhysicalDevice, queue_family_index: u32) -> Result<(VkDevice, VkQueue)> {
258        let queue_priority = 1.0f32;
259        
260        let queue_create_info = VkDeviceQueueCreateInfo {
261            sType: VkStructureType::DeviceQueueCreateInfo,
262            pNext: ptr::null(),
263            flags: 0,
264            queueFamilyIndex: queue_family_index,
265            queueCount: 1,
266            pQueuePriorities: &queue_priority,
267        };
268        
269        let features = VkPhysicalDeviceFeatures::default();
270        
271        let device_create_info = VkDeviceCreateInfo {
272            sType: VkStructureType::DeviceCreateInfo,
273            pNext: ptr::null(),
274            flags: 0,
275            queueCreateInfoCount: 1,
276            pQueueCreateInfos: &queue_create_info,
277            enabledLayerCount: 0,
278            ppEnabledLayerNames: ptr::null(),
279            enabledExtensionCount: 0,
280            ppEnabledExtensionNames: ptr::null(),
281            pEnabledFeatures: &features,
282        };
283        
284        let mut device = VkDevice::NULL;
285        let result = vkCreateDevice(physical_device, &device_create_info, ptr::null(), &mut device);
286        
287        if result != VkResult::Success {
288            return Err(KronosError::from(result));
289        }
290        
291        let mut queue = VkQueue::NULL;
292        vkGetDeviceQueue(device, queue_family_index, 0, &mut queue);
293        
294        Ok((device, queue))
295    }
296    
297    /// Create a descriptor pool for persistent descriptors
298    ///
299    /// # Safety
300    ///
301    /// This function is unsafe because:
302    /// - The device must be a valid VkDevice handle
303    /// - Calls vkCreateDescriptorPool which requires valid device
304    /// - The returned pool must be destroyed with vkDestroyDescriptorPool
305    /// - Invalid device handle will cause undefined behavior
306    /// - Pool creation may fail if device limits are exceeded
307    unsafe fn create_descriptor_pool(device: VkDevice) -> Result<VkDescriptorPool> {
308        // Create a large pool for persistent descriptors
309        let pool_size = VkDescriptorPoolSize {
310            type_: VkDescriptorType::StorageBuffer,
311            descriptorCount: 10000, // Should be enough for most use cases
312        };
313        
314        let pool_info = VkDescriptorPoolCreateInfo {
315            sType: VkStructureType::DescriptorPoolCreateInfo,
316            pNext: ptr::null(),
317            flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
318            maxSets: 1000,
319            poolSizeCount: 1,
320            pPoolSizes: &pool_size,
321        };
322        
323        let mut pool = VkDescriptorPool::NULL;
324        let result = vkCreateDescriptorPool(device, &pool_info, ptr::null(), &mut pool);
325        
326        if result != VkResult::Success {
327            return Err(KronosError::from(result));
328        }
329        
330        Ok(pool)
331    }
332    
333    /// Create a command pool for allocating command buffers
334    ///
335    /// # Safety
336    ///
337    /// This function is unsafe because:
338    /// - The device must be a valid VkDevice handle
339    /// - The queue_family_index must be valid for the device
340    /// - Calls vkCreateCommandPool which requires valid parameters
341    /// - The returned pool must be destroyed with vkDestroyCommandPool
342    /// - Invalid queue family index will cause undefined behavior
343    unsafe fn create_command_pool(device: VkDevice, queue_family_index: u32) -> Result<VkCommandPool> {
344        let pool_info = VkCommandPoolCreateInfo {
345            sType: VkStructureType::CommandPoolCreateInfo,
346            pNext: ptr::null(),
347            flags: VkCommandPoolCreateFlags::RESET_COMMAND_BUFFER,
348            queueFamilyIndex: queue_family_index,
349        };
350        
351        let mut pool = VkCommandPool::NULL;
352        let result = vkCreateCommandPool(device, &pool_info, ptr::null(), &mut pool);
353        
354        if result != VkResult::Success {
355            return Err(KronosError::from(result));
356        }
357        
358        Ok(pool)
359    }
360    
361    /// Get the underlying Vulkan device (for advanced usage)
362    pub fn device(&self) -> VkDevice {
363        self.inner.lock().unwrap().device
364    }
365    
366    /// Get the compute queue
367    pub fn queue(&self) -> VkQueue {
368        self.inner.lock().unwrap().queue
369    }
370    
371    /// Get device properties
372    pub fn device_properties(&self) -> VkPhysicalDeviceProperties {
373        self.inner.lock().unwrap().device_properties
374    }
375    
376    /// Get information about the ICD bound to this context (process-wide)
377    pub fn icd_info(&self) -> Option<crate::implementation::icd_loader::IcdInfo> {
378        crate::implementation::icd_loader::selected_icd_info()
379    }
380
381    // Internal helper for other modules
382    pub(super) fn with_inner<F, R>(&self, f: F) -> R
383    where
384        F: FnOnce(&ContextInner) -> R,
385    {
386        let inner = self.inner.lock().unwrap();
387        f(&*inner)
388    }
389}
390
391impl Drop for ComputeContext {
392    fn drop(&mut self) {
393        // Only the last Clone should perform destruction to avoid double-free.
394        if std::sync::Arc::strong_count(&self.inner) != 1 {
395            return;
396        }
397        let inner = self.inner.lock().unwrap();
398        unsafe {
399            if inner.command_pool != VkCommandPool::NULL {
400                vkDestroyCommandPool(inner.device, inner.command_pool, ptr::null());
401            }
402            if inner.descriptor_pool != VkDescriptorPool::NULL {
403                vkDestroyDescriptorPool(inner.device, inner.descriptor_pool, ptr::null());
404            }
405            if inner.device != VkDevice::NULL {
406                vkDestroyDevice(inner.device, ptr::null());
407            }
408            if inner.instance != VkInstance::NULL {
409                vkDestroyInstance(inner.instance, ptr::null());
410            }
411        }
412    }
413}