kronos_compute/api/
context.rs

1//! Main entry point for Kronos Compute
2
3use super::*;
4use crate::*; // Import all functions from the crate root
5use crate::implementation::initialize_kronos;
6use std::ffi::CString;
7use std::ptr;
8use std::sync::{Arc, Mutex};
9
10/// Internal state for ComputeContext
11pub(super) struct ContextInner {
12    pub(super) instance: VkInstance,
13    pub(super) physical_device: VkPhysicalDevice,
14    pub(super) device: VkDevice,
15    pub(super) queue: VkQueue,
16    pub(super) queue_family_index: u32,
17    
18    // Optimization managers
19    pub(super) descriptor_pool: VkDescriptorPool,
20    pub(super) command_pool: VkCommandPool,
21    
22    // Device properties
23    pub(super) device_properties: VkPhysicalDeviceProperties,
24    pub(super) memory_properties: VkPhysicalDeviceMemoryProperties,
25}
26
27/// Main context for compute operations
28/// 
29/// This is the primary entry point for the Kronos Compute API.
30/// It manages the Vulkan instance, device, and queue, and provides
31/// methods to create buffers, pipelines, and execute commands.
32#[derive(Clone)]
33pub struct ComputeContext {
34    pub(super) inner: Arc<Mutex<ContextInner>>,
35}
36
37// Send + Sync for thread safety
38unsafe impl Send for ComputeContext {}
39unsafe impl Sync for ComputeContext {}
40
41impl ComputeContext {
42    pub(super) fn new_with_config(config: ContextConfig) -> Result<Self> {
43        unsafe {
44            // Initialize Kronos ICD loader
45            initialize_kronos()
46                .map_err(|e| KronosError::InitializationFailed(e.to_string()))?;
47            
48            // Create instance
49            let instance = Self::create_instance(&config)?;
50            
51            // Find compute-capable device
52            let (physical_device, queue_family_index) = Self::find_compute_device(instance)?;
53            
54            // Get device properties
55            let mut device_properties = VkPhysicalDeviceProperties::default();
56            vkGetPhysicalDeviceProperties(physical_device, &mut device_properties);
57            
58            let mut memory_properties = VkPhysicalDeviceMemoryProperties::default();
59            vkGetPhysicalDeviceMemoryProperties(physical_device, &mut memory_properties);
60            
61            // Log selected device info
62            let device_name = std::ffi::CStr::from_ptr(device_properties.deviceName.as_ptr())
63                .to_string_lossy();
64            let device_type_str = match device_properties.deviceType {
65                VkPhysicalDeviceType::DiscreteGpu => "Discrete GPU",
66                VkPhysicalDeviceType::IntegratedGpu => "Integrated GPU",
67                VkPhysicalDeviceType::VirtualGpu => "Virtual GPU",
68                VkPhysicalDeviceType::Cpu => "CPU (Software Renderer)",
69                _ => "Unknown",
70            };
71            log::info!("Selected Vulkan device: {} ({})", device_name, device_type_str);
72            
73            // Create logical device
74            let (device, queue) = Self::create_device(physical_device, queue_family_index)?;
75            
76            // Create descriptor pool for persistent descriptors
77            let descriptor_pool = Self::create_descriptor_pool(device)?;
78            
79            // Create command pool
80            let command_pool = Self::create_command_pool(device, queue_family_index)?;
81            
82            let inner = ContextInner {
83                instance,
84                physical_device,
85                device,
86                queue,
87                queue_family_index,
88                descriptor_pool,
89                command_pool,
90                device_properties,
91                memory_properties,
92            };
93            
94            Ok(Self {
95                inner: Arc::new(Mutex::new(inner)),
96            })
97        }
98    }
99    
100    /// Create a Vulkan instance
101    ///
102    /// # Safety
103    ///
104    /// This function is unsafe because:
105    /// - It calls vkCreateInstance which requires the Vulkan loader to be initialized
106    /// - The returned instance must be destroyed with vkDestroyInstance to avoid leaks
107    /// - The config strings must remain valid for the lifetime of the instance creation
108    /// - Null or invalid pointers in the create info will cause undefined behavior
109    unsafe fn create_instance(config: &ContextConfig) -> Result<VkInstance> {
110        let app_name = CString::new(config.app_name.clone())
111            .unwrap_or_else(|_| CString::new("Kronos App").unwrap());
112        let engine_name = CString::new("Kronos Compute").unwrap();
113        
114        let app_info = VkApplicationInfo {
115            sType: VkStructureType::ApplicationInfo,
116            pNext: ptr::null(),
117            pApplicationName: app_name.as_ptr(),
118            applicationVersion: VK_MAKE_VERSION(1, 0, 0),
119            pEngineName: engine_name.as_ptr(),
120            engineVersion: VK_MAKE_VERSION(1, 0, 0),
121            apiVersion: VK_API_VERSION_1_0,
122        };
123        
124        let create_info = VkInstanceCreateInfo {
125            sType: VkStructureType::InstanceCreateInfo,
126            pNext: ptr::null(),
127            flags: 0,
128            pApplicationInfo: &app_info,
129            enabledLayerCount: 0,
130            ppEnabledLayerNames: ptr::null(),
131            enabledExtensionCount: 0,
132            ppEnabledExtensionNames: ptr::null(),
133        };
134        
135        let mut instance = VkInstance::NULL;
136        let result = vkCreateInstance(&create_info, ptr::null(), &mut instance);
137        
138        if result != VkResult::Success {
139            return Err(KronosError::from(result));
140        }
141        
142        Ok(instance)
143    }
144    
145    /// Find a physical device with compute capabilities
146    ///
147    /// # Safety
148    ///
149    /// This function is unsafe because:
150    /// - The instance must be a valid VkInstance handle
151    /// - Calls vkEnumeratePhysicalDevices which may fail with invalid instance
152    /// - The returned physical device is tied to the instance lifetime
153    /// - Accessing the device after instance destruction is undefined behavior
154    unsafe fn find_compute_device(instance: VkInstance) -> Result<(VkPhysicalDevice, u32)> {
155        let mut device_count = 0;
156        vkEnumeratePhysicalDevices(instance, &mut device_count, ptr::null_mut());
157        
158        if device_count == 0 {
159            return Err(KronosError::DeviceNotFound);
160        }
161        
162        let mut devices = vec![VkPhysicalDevice::NULL; device_count as usize];
163        vkEnumeratePhysicalDevices(instance, &mut device_count, devices.as_mut_ptr());
164        
165        // Collect all devices with compute support and their properties
166        let mut candidates = Vec::new();
167        
168        for device in devices {
169            let queue_family = Self::find_compute_queue_family(device)?;
170            if let Some(index) = queue_family {
171                // Get device properties to determine device type
172                let mut properties = VkPhysicalDeviceProperties::default();
173                vkGetPhysicalDeviceProperties(device, &mut properties);
174                
175                candidates.push((device, index, properties.deviceType));
176            }
177        }
178        
179        if candidates.is_empty() {
180            return Err(KronosError::DeviceNotFound);
181        }
182        
183        // Sort by device type preference: DiscreteGpu > IntegratedGpu > VirtualGpu > Cpu
184        candidates.sort_by_key(|(_, _, device_type)| {
185            match *device_type {
186                VkPhysicalDeviceType::DiscreteGpu => 0,
187                VkPhysicalDeviceType::IntegratedGpu => 1,
188                VkPhysicalDeviceType::VirtualGpu => 2,
189                VkPhysicalDeviceType::Cpu => 3,
190                VkPhysicalDeviceType::Other => 4,
191                _ => 5,
192            }
193        });
194        
195        // Return the best device
196        let (device, queue_index, _) = candidates[0];
197        Ok((device, queue_index))
198    }
199    
200    /// Find a queue family with compute support
201    ///
202    /// # Safety
203    ///
204    /// This function is unsafe because:
205    /// - The device must be a valid VkPhysicalDevice handle
206    /// - Calls vkGetPhysicalDeviceQueueFamilyProperties with the device
207    /// - Invalid device handle will cause undefined behavior
208    /// - The device must remain valid during the function execution
209    unsafe fn find_compute_queue_family(device: VkPhysicalDevice) -> Result<Option<u32>> {
210        let mut queue_family_count = 0;
211        vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, ptr::null_mut());
212        
213        let mut queue_families = vec![VkQueueFamilyProperties {
214            queueFlags: VkQueueFlags::empty(),
215            queueCount: 0,
216            timestampValidBits: 0,
217            minImageTransferGranularity: VkExtent3D { width: 0, height: 0, depth: 0 },
218        }; queue_family_count as usize];
219        vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, queue_families.as_mut_ptr());
220        
221        for (index, family) in queue_families.iter().enumerate() {
222            if family.queueFlags.contains(VkQueueFlags::COMPUTE) {
223                return Ok(Some(index as u32));
224            }
225        }
226        
227        Ok(None)
228    }
229    
230    /// Create a logical device and get its compute queue
231    ///
232    /// # Safety
233    ///
234    /// This function is unsafe because:
235    /// - The physical_device must be a valid VkPhysicalDevice handle
236    /// - The queue_family_index must be valid for the physical device
237    /// - Calls vkCreateDevice and vkGetDeviceQueue which require valid handles
238    /// - The returned device and queue must be properly destroyed
239    /// - Queue family index out of bounds will cause undefined behavior
240    unsafe fn create_device(physical_device: VkPhysicalDevice, queue_family_index: u32) -> Result<(VkDevice, VkQueue)> {
241        let queue_priority = 1.0f32;
242        
243        let queue_create_info = VkDeviceQueueCreateInfo {
244            sType: VkStructureType::DeviceQueueCreateInfo,
245            pNext: ptr::null(),
246            flags: 0,
247            queueFamilyIndex: queue_family_index,
248            queueCount: 1,
249            pQueuePriorities: &queue_priority,
250        };
251        
252        let features = VkPhysicalDeviceFeatures::default();
253        
254        let device_create_info = VkDeviceCreateInfo {
255            sType: VkStructureType::DeviceCreateInfo,
256            pNext: ptr::null(),
257            flags: 0,
258            queueCreateInfoCount: 1,
259            pQueueCreateInfos: &queue_create_info,
260            enabledLayerCount: 0,
261            ppEnabledLayerNames: ptr::null(),
262            enabledExtensionCount: 0,
263            ppEnabledExtensionNames: ptr::null(),
264            pEnabledFeatures: &features,
265        };
266        
267        let mut device = VkDevice::NULL;
268        let result = vkCreateDevice(physical_device, &device_create_info, ptr::null(), &mut device);
269        
270        if result != VkResult::Success {
271            return Err(KronosError::from(result));
272        }
273        
274        let mut queue = VkQueue::NULL;
275        vkGetDeviceQueue(device, queue_family_index, 0, &mut queue);
276        
277        Ok((device, queue))
278    }
279    
280    /// Create a descriptor pool for persistent descriptors
281    ///
282    /// # Safety
283    ///
284    /// This function is unsafe because:
285    /// - The device must be a valid VkDevice handle
286    /// - Calls vkCreateDescriptorPool which requires valid device
287    /// - The returned pool must be destroyed with vkDestroyDescriptorPool
288    /// - Invalid device handle will cause undefined behavior
289    /// - Pool creation may fail if device limits are exceeded
290    unsafe fn create_descriptor_pool(device: VkDevice) -> Result<VkDescriptorPool> {
291        // Create a large pool for persistent descriptors
292        let pool_size = VkDescriptorPoolSize {
293            type_: VkDescriptorType::StorageBuffer,
294            descriptorCount: 10000, // Should be enough for most use cases
295        };
296        
297        let pool_info = VkDescriptorPoolCreateInfo {
298            sType: VkStructureType::DescriptorPoolCreateInfo,
299            pNext: ptr::null(),
300            flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
301            maxSets: 1000,
302            poolSizeCount: 1,
303            pPoolSizes: &pool_size,
304        };
305        
306        let mut pool = VkDescriptorPool::NULL;
307        let result = vkCreateDescriptorPool(device, &pool_info, ptr::null(), &mut pool);
308        
309        if result != VkResult::Success {
310            return Err(KronosError::from(result));
311        }
312        
313        Ok(pool)
314    }
315    
316    /// Create a command pool for allocating command buffers
317    ///
318    /// # Safety
319    ///
320    /// This function is unsafe because:
321    /// - The device must be a valid VkDevice handle
322    /// - The queue_family_index must be valid for the device
323    /// - Calls vkCreateCommandPool which requires valid parameters
324    /// - The returned pool must be destroyed with vkDestroyCommandPool
325    /// - Invalid queue family index will cause undefined behavior
326    unsafe fn create_command_pool(device: VkDevice, queue_family_index: u32) -> Result<VkCommandPool> {
327        let pool_info = VkCommandPoolCreateInfo {
328            sType: VkStructureType::CommandPoolCreateInfo,
329            pNext: ptr::null(),
330            flags: VkCommandPoolCreateFlags::RESET_COMMAND_BUFFER,
331            queueFamilyIndex: queue_family_index,
332        };
333        
334        let mut pool = VkCommandPool::NULL;
335        let result = vkCreateCommandPool(device, &pool_info, ptr::null(), &mut pool);
336        
337        if result != VkResult::Success {
338            return Err(KronosError::from(result));
339        }
340        
341        Ok(pool)
342    }
343    
344    /// Get the underlying Vulkan device (for advanced usage)
345    pub fn device(&self) -> VkDevice {
346        self.inner.lock().unwrap().device
347    }
348    
349    /// Get the compute queue
350    pub fn queue(&self) -> VkQueue {
351        self.inner.lock().unwrap().queue
352    }
353    
354    /// Get device properties
355    pub fn device_properties(&self) -> VkPhysicalDeviceProperties {
356        self.inner.lock().unwrap().device_properties
357    }
358    
359    // Internal helper for other modules
360    pub(super) fn with_inner<F, R>(&self, f: F) -> R
361    where
362        F: FnOnce(&ContextInner) -> R,
363    {
364        let inner = self.inner.lock().unwrap();
365        f(&*inner)
366    }
367}
368
369impl Drop for ComputeContext {
370    fn drop(&mut self) {
371        let inner = self.inner.lock().unwrap();
372        unsafe {
373            if inner.command_pool != VkCommandPool::NULL {
374                vkDestroyCommandPool(inner.device, inner.command_pool, ptr::null());
375            }
376            if inner.descriptor_pool != VkDescriptorPool::NULL {
377                vkDestroyDescriptorPool(inner.device, inner.descriptor_pool, ptr::null());
378            }
379            if inner.device != VkDevice::NULL {
380                vkDestroyDevice(inner.device, ptr::null());
381            }
382            if inner.instance != VkInstance::NULL {
383                vkDestroyInstance(inner.instance, ptr::null());
384            }
385        }
386    }
387}