kronos_compute/api/
context.rs

1//! Main entry point for Kronos Compute
2
3use super::*;
4use crate::*; // Import all functions from the crate root
5use crate::implementation::initialize_kronos;
6use std::ffi::CString;
7use std::ptr;
8use std::sync::{Arc, Mutex};
9
10/// Internal state for ComputeContext
11pub(super) struct ContextInner {
12    pub(super) instance: VkInstance,
13    pub(super) physical_device: VkPhysicalDevice,
14    pub(super) device: VkDevice,
15    pub(super) queue: VkQueue,
16    pub(super) queue_family_index: u32,
17    
18    // Optimization managers
19    pub(super) descriptor_pool: VkDescriptorPool,
20    pub(super) command_pool: VkCommandPool,
21    
22    // Device properties
23    pub(super) device_properties: VkPhysicalDeviceProperties,
24    pub(super) memory_properties: VkPhysicalDeviceMemoryProperties,
25}
26
27/// Main context for compute operations
28/// 
29/// This is the primary entry point for the Kronos Compute API.
30/// It manages the Vulkan instance, device, and queue, and provides
31/// methods to create buffers, pipelines, and execute commands.
32#[derive(Clone)]
33pub struct ComputeContext {
34    pub(super) inner: Arc<Mutex<ContextInner>>,
35}
36
37// Send + Sync for thread safety
38unsafe impl Send for ComputeContext {}
39unsafe impl Sync for ComputeContext {}
40
41impl ComputeContext {
42    pub(super) fn new_with_config(config: ContextConfig) -> Result<Self> {
43        unsafe {
44            // Initialize Kronos ICD loader
45            initialize_kronos()
46                .map_err(|e| KronosError::InitializationFailed(e.to_string()))?;
47            
48            // Create instance
49            let instance = Self::create_instance(&config)?;
50            
51            // Find compute-capable device
52            let (physical_device, queue_family_index) = Self::find_compute_device(instance)?;
53            
54            // Get device properties
55            let mut device_properties = VkPhysicalDeviceProperties::default();
56            vkGetPhysicalDeviceProperties(physical_device, &mut device_properties);
57            
58            let mut memory_properties = VkPhysicalDeviceMemoryProperties::default();
59            vkGetPhysicalDeviceMemoryProperties(physical_device, &mut memory_properties);
60            
61            // Create logical device
62            let (device, queue) = Self::create_device(physical_device, queue_family_index)?;
63            
64            // Create descriptor pool for persistent descriptors
65            let descriptor_pool = Self::create_descriptor_pool(device)?;
66            
67            // Create command pool
68            let command_pool = Self::create_command_pool(device, queue_family_index)?;
69            
70            let inner = ContextInner {
71                instance,
72                physical_device,
73                device,
74                queue,
75                queue_family_index,
76                descriptor_pool,
77                command_pool,
78                device_properties,
79                memory_properties,
80            };
81            
82            Ok(Self {
83                inner: Arc::new(Mutex::new(inner)),
84            })
85        }
86    }
87    
88    /// Create a Vulkan instance
89    ///
90    /// # Safety
91    ///
92    /// This function is unsafe because:
93    /// - It calls vkCreateInstance which requires the Vulkan loader to be initialized
94    /// - The returned instance must be destroyed with vkDestroyInstance to avoid leaks
95    /// - The config strings must remain valid for the lifetime of the instance creation
96    /// - Null or invalid pointers in the create info will cause undefined behavior
97    unsafe fn create_instance(config: &ContextConfig) -> Result<VkInstance> {
98        let app_name = CString::new(config.app_name.clone())
99            .unwrap_or_else(|_| CString::new("Kronos App").unwrap());
100        let engine_name = CString::new("Kronos Compute").unwrap();
101        
102        let app_info = VkApplicationInfo {
103            sType: VkStructureType::ApplicationInfo,
104            pNext: ptr::null(),
105            pApplicationName: app_name.as_ptr(),
106            applicationVersion: VK_MAKE_VERSION(1, 0, 0),
107            pEngineName: engine_name.as_ptr(),
108            engineVersion: VK_MAKE_VERSION(1, 0, 0),
109            apiVersion: VK_API_VERSION_1_0,
110        };
111        
112        let create_info = VkInstanceCreateInfo {
113            sType: VkStructureType::InstanceCreateInfo,
114            pNext: ptr::null(),
115            flags: 0,
116            pApplicationInfo: &app_info,
117            enabledLayerCount: 0,
118            ppEnabledLayerNames: ptr::null(),
119            enabledExtensionCount: 0,
120            ppEnabledExtensionNames: ptr::null(),
121        };
122        
123        let mut instance = VkInstance::NULL;
124        let result = vkCreateInstance(&create_info, ptr::null(), &mut instance);
125        
126        if result != VkResult::Success {
127            return Err(KronosError::from(result));
128        }
129        
130        Ok(instance)
131    }
132    
133    /// Find a physical device with compute capabilities
134    ///
135    /// # Safety
136    ///
137    /// This function is unsafe because:
138    /// - The instance must be a valid VkInstance handle
139    /// - Calls vkEnumeratePhysicalDevices which may fail with invalid instance
140    /// - The returned physical device is tied to the instance lifetime
141    /// - Accessing the device after instance destruction is undefined behavior
142    unsafe fn find_compute_device(instance: VkInstance) -> Result<(VkPhysicalDevice, u32)> {
143        let mut device_count = 0;
144        vkEnumeratePhysicalDevices(instance, &mut device_count, ptr::null_mut());
145        
146        if device_count == 0 {
147            return Err(KronosError::DeviceNotFound);
148        }
149        
150        let mut devices = vec![VkPhysicalDevice::NULL; device_count as usize];
151        vkEnumeratePhysicalDevices(instance, &mut device_count, devices.as_mut_ptr());
152        
153        // Find first device with compute queue
154        for device in devices {
155            let queue_family = Self::find_compute_queue_family(device)?;
156            if let Some(index) = queue_family {
157                return Ok((device, index));
158            }
159        }
160        
161        Err(KronosError::DeviceNotFound)
162    }
163    
164    /// Find a queue family with compute support
165    ///
166    /// # Safety
167    ///
168    /// This function is unsafe because:
169    /// - The device must be a valid VkPhysicalDevice handle
170    /// - Calls vkGetPhysicalDeviceQueueFamilyProperties with the device
171    /// - Invalid device handle will cause undefined behavior
172    /// - The device must remain valid during the function execution
173    unsafe fn find_compute_queue_family(device: VkPhysicalDevice) -> Result<Option<u32>> {
174        let mut queue_family_count = 0;
175        vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, ptr::null_mut());
176        
177        let mut queue_families = vec![VkQueueFamilyProperties {
178            queueFlags: VkQueueFlags::empty(),
179            queueCount: 0,
180            timestampValidBits: 0,
181            minImageTransferGranularity: VkExtent3D { width: 0, height: 0, depth: 0 },
182        }; queue_family_count as usize];
183        vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, queue_families.as_mut_ptr());
184        
185        for (index, family) in queue_families.iter().enumerate() {
186            if family.queueFlags.contains(VkQueueFlags::COMPUTE) {
187                return Ok(Some(index as u32));
188            }
189        }
190        
191        Ok(None)
192    }
193    
194    /// Create a logical device and get its compute queue
195    ///
196    /// # Safety
197    ///
198    /// This function is unsafe because:
199    /// - The physical_device must be a valid VkPhysicalDevice handle
200    /// - The queue_family_index must be valid for the physical device
201    /// - Calls vkCreateDevice and vkGetDeviceQueue which require valid handles
202    /// - The returned device and queue must be properly destroyed
203    /// - Queue family index out of bounds will cause undefined behavior
204    unsafe fn create_device(physical_device: VkPhysicalDevice, queue_family_index: u32) -> Result<(VkDevice, VkQueue)> {
205        let queue_priority = 1.0f32;
206        
207        let queue_create_info = VkDeviceQueueCreateInfo {
208            sType: VkStructureType::DeviceQueueCreateInfo,
209            pNext: ptr::null(),
210            flags: 0,
211            queueFamilyIndex: queue_family_index,
212            queueCount: 1,
213            pQueuePriorities: &queue_priority,
214        };
215        
216        let features = VkPhysicalDeviceFeatures::default();
217        
218        let device_create_info = VkDeviceCreateInfo {
219            sType: VkStructureType::DeviceCreateInfo,
220            pNext: ptr::null(),
221            flags: 0,
222            queueCreateInfoCount: 1,
223            pQueueCreateInfos: &queue_create_info,
224            enabledLayerCount: 0,
225            ppEnabledLayerNames: ptr::null(),
226            enabledExtensionCount: 0,
227            ppEnabledExtensionNames: ptr::null(),
228            pEnabledFeatures: &features,
229        };
230        
231        let mut device = VkDevice::NULL;
232        let result = vkCreateDevice(physical_device, &device_create_info, ptr::null(), &mut device);
233        
234        if result != VkResult::Success {
235            return Err(KronosError::from(result));
236        }
237        
238        let mut queue = VkQueue::NULL;
239        vkGetDeviceQueue(device, queue_family_index, 0, &mut queue);
240        
241        Ok((device, queue))
242    }
243    
244    /// Create a descriptor pool for persistent descriptors
245    ///
246    /// # Safety
247    ///
248    /// This function is unsafe because:
249    /// - The device must be a valid VkDevice handle
250    /// - Calls vkCreateDescriptorPool which requires valid device
251    /// - The returned pool must be destroyed with vkDestroyDescriptorPool
252    /// - Invalid device handle will cause undefined behavior
253    /// - Pool creation may fail if device limits are exceeded
254    unsafe fn create_descriptor_pool(device: VkDevice) -> Result<VkDescriptorPool> {
255        // Create a large pool for persistent descriptors
256        let pool_size = VkDescriptorPoolSize {
257            type_: VkDescriptorType::StorageBuffer,
258            descriptorCount: 10000, // Should be enough for most use cases
259        };
260        
261        let pool_info = VkDescriptorPoolCreateInfo {
262            sType: VkStructureType::DescriptorPoolCreateInfo,
263            pNext: ptr::null(),
264            flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
265            maxSets: 1000,
266            poolSizeCount: 1,
267            pPoolSizes: &pool_size,
268        };
269        
270        let mut pool = VkDescriptorPool::NULL;
271        let result = vkCreateDescriptorPool(device, &pool_info, ptr::null(), &mut pool);
272        
273        if result != VkResult::Success {
274            return Err(KronosError::from(result));
275        }
276        
277        Ok(pool)
278    }
279    
280    /// Create a command pool for allocating command buffers
281    ///
282    /// # Safety
283    ///
284    /// This function is unsafe because:
285    /// - The device must be a valid VkDevice handle
286    /// - The queue_family_index must be valid for the device
287    /// - Calls vkCreateCommandPool which requires valid parameters
288    /// - The returned pool must be destroyed with vkDestroyCommandPool
289    /// - Invalid queue family index will cause undefined behavior
290    unsafe fn create_command_pool(device: VkDevice, queue_family_index: u32) -> Result<VkCommandPool> {
291        let pool_info = VkCommandPoolCreateInfo {
292            sType: VkStructureType::CommandPoolCreateInfo,
293            pNext: ptr::null(),
294            flags: VkCommandPoolCreateFlags::RESET_COMMAND_BUFFER,
295            queueFamilyIndex: queue_family_index,
296        };
297        
298        let mut pool = VkCommandPool::NULL;
299        let result = vkCreateCommandPool(device, &pool_info, ptr::null(), &mut pool);
300        
301        if result != VkResult::Success {
302            return Err(KronosError::from(result));
303        }
304        
305        Ok(pool)
306    }
307    
308    /// Get the underlying Vulkan device (for advanced usage)
309    pub fn device(&self) -> VkDevice {
310        self.inner.lock().unwrap().device
311    }
312    
313    /// Get the compute queue
314    pub fn queue(&self) -> VkQueue {
315        self.inner.lock().unwrap().queue
316    }
317    
318    /// Get device properties
319    pub fn device_properties(&self) -> VkPhysicalDeviceProperties {
320        self.inner.lock().unwrap().device_properties
321    }
322    
323    // Internal helper for other modules
324    pub(super) fn with_inner<F, R>(&self, f: F) -> R
325    where
326        F: FnOnce(&ContextInner) -> R,
327    {
328        let inner = self.inner.lock().unwrap();
329        f(&*inner)
330    }
331}
332
333impl Drop for ComputeContext {
334    fn drop(&mut self) {
335        let inner = self.inner.lock().unwrap();
336        unsafe {
337            if inner.command_pool != VkCommandPool::NULL {
338                vkDestroyCommandPool(inner.device, inner.command_pool, ptr::null());
339            }
340            if inner.descriptor_pool != VkDescriptorPool::NULL {
341                vkDestroyDescriptorPool(inner.device, inner.descriptor_pool, ptr::null());
342            }
343            if inner.device != VkDevice::NULL {
344                vkDestroyDevice(inner.device, ptr::null());
345            }
346            if inner.instance != VkInstance::NULL {
347                vkDestroyInstance(inner.instance, ptr::null());
348            }
349        }
350    }
351}