kronos_compute/api/
context.rs

1//! Main entry point for Kronos Compute
2
3use super::*;
4use crate::*; // Need all the type definitions
5use crate::implementation::initialize_kronos;
6
7// Explicitly import Vulkan functions from implementation when available
8#[cfg(feature = "implementation")]
9use crate::implementation::{
10    vkCreateInstance, vkDestroyInstance, vkEnumeratePhysicalDevices,
11    vkGetPhysicalDeviceProperties, vkGetPhysicalDeviceMemoryProperties,
12    vkGetPhysicalDeviceQueueFamilyProperties,
13    vkCreateDevice, vkDestroyDevice, vkGetDeviceQueue,
14    vkCreateDescriptorPool, vkDestroyDescriptorPool,
15    vkCreateCommandPool, vkDestroyCommandPool,
16};
17use std::ffi::CString;
18use std::ptr;
19use std::sync::{Arc, Mutex};
20
/// Internal state for ComputeContext
///
/// Holds the raw Vulkan handles and the cached device information that
/// `ComputeContext::new_with_config` produces. It is shared between all
/// clones of a `ComputeContext` through an `Arc<Mutex<..>>`.
pub(super) struct ContextInner {
    /// Vulkan instance handle created by `create_instance`.
    pub(super) instance: VkInstance,
    /// Physical device chosen by `find_compute_device`.
    pub(super) physical_device: VkPhysicalDevice,
    /// Logical device created on `physical_device` by `create_device`.
    pub(super) device: VkDevice,
    /// Queue 0 of family `queue_family_index`, fetched via `vkGetDeviceQueue`.
    pub(super) queue: VkQueue,
    /// Index of the compute-capable queue family `queue` belongs to.
    pub(super) queue_family_index: u32,
    
    // Optimization managers
    /// Pool intended for persistent descriptors. `new_with_config` currently
    /// skips its creation, so this is `VkDescriptorPool::NULL`.
    pub(super) descriptor_pool: VkDescriptorPool,
    /// Command pool created for `queue_family_index` by `create_command_pool`.
    pub(super) command_pool: VkCommandPool,
    
    // Device properties
    /// Result of `vkGetPhysicalDeviceProperties` for `physical_device`,
    /// captured once at context creation.
    pub(super) device_properties: VkPhysicalDeviceProperties,
    /// Result of `vkGetPhysicalDeviceMemoryProperties` for `physical_device`,
    /// captured once at context creation.
    pub(super) memory_properties: VkPhysicalDeviceMemoryProperties,
}
37
/// Main context for compute operations
/// 
/// This is the primary entry point for the Kronos Compute API.
/// It manages the Vulkan instance, device, and queue, and provides
/// methods to create buffers, pipelines, and execute commands.
///
/// Cloning is cheap: every clone holds another reference to the same
/// `Arc<Mutex<ContextInner>>`, so all clones refer to the same Vulkan
/// instance, device and queue.
#[derive(Clone)]
pub struct ComputeContext {
    /// Shared, mutex-guarded state; visible to sibling modules of the API.
    pub(super) inner: Arc<Mutex<ContextInner>>,
}
47
// Send + Sync for thread safety
//
// These are manual `unsafe` impls, i.e. the compiler does not derive them
// automatically for `Arc<Mutex<ContextInner>>`.
// NOTE(review): presumably this is because the Vk* handle types inside
// `ContextInner` are raw pointers / opaque handles that are not `Send` on
// their own — confirm against the handle definitions.
// SAFETY (assumed, TODO confirm): all access to the handles goes through the
// `Mutex`, and the handles themselves are plain identifiers that may be used
// from any thread as long as the Vulkan external-synchronization rules are
// respected by the code that uses them.
unsafe impl Send for ComputeContext {}
unsafe impl Sync for ComputeContext {}
51
52impl ComputeContext {
53    pub(super) fn new_with_config(config: ContextConfig) -> Result<Self> {
54        log::info!("[SAFE API] ComputeContext::new_with_config() called");
55        unsafe {
56            // Initialize Kronos pure Rust implementation
57            log::info!("[SAFE API] Initializing Kronos pure Rust implementation");
58            initialize_kronos()
59                .map_err(|e| {
60                    log::error!("[SAFE API] Failed to initialize Kronos: {:?}", e);
61                    KronosError::InitializationFailed(e.to_string())
62                })?;
63            log::info!("[SAFE API] Kronos initialized successfully");
64            
65            // Create instance
66            log::info!("[SAFE API] Creating Vulkan instance");
67            let instance = Self::create_instance(&config)?;
68            log::info!("[SAFE API] Instance created: {:?}", instance);
69            
70            // Find compute-capable device
71            log::info!("[SAFE API] Finding compute-capable device");
72            let (physical_device, queue_family_index) = Self::find_compute_device(instance)?;
73            log::info!("[SAFE API] Found device: {:?}, queue family: {}", physical_device, queue_family_index);
74            
75            log::info!("[SAFE API] find_compute_device returned successfully");
76            
77            // Get device properties
78            log::info!("[SAFE API] Getting device properties");
79            let mut device_properties = VkPhysicalDeviceProperties::default();
80            vkGetPhysicalDeviceProperties(physical_device, &mut device_properties);
81            log::info!("[SAFE API] Got device properties successfully");
82            
83            let mut memory_properties = VkPhysicalDeviceMemoryProperties::default();
84            log::info!("[SAFE API] Getting memory properties");
85            vkGetPhysicalDeviceMemoryProperties(physical_device, &mut memory_properties);
86            log::info!("[SAFE API] Got memory properties successfully");
87            
88            // Log selected device info
89            // deviceName is a fixed-size array, ensure it's null-terminated
90            let device_name_bytes = &device_properties.deviceName;
91            let null_pos = device_name_bytes.iter().position(|&c| c == 0).unwrap_or(device_name_bytes.len());
92            // Convert from &[i8] to &[u8] for from_utf8
93            let device_name_u8: Vec<u8> = device_name_bytes[..null_pos]
94                .iter()
95                .map(|&c| c as u8)
96                .collect();
97            let device_name = std::str::from_utf8(&device_name_u8)
98                .unwrap_or("Unknown Device");
99            let device_type_str = match device_properties.deviceType {
100                VkPhysicalDeviceType::DiscreteGpu => "Discrete GPU",
101                VkPhysicalDeviceType::IntegratedGpu => "Integrated GPU",
102                VkPhysicalDeviceType::VirtualGpu => "Virtual GPU",
103                VkPhysicalDeviceType::Cpu => "CPU (Software Renderer)",
104                _ => "Unknown",
105            };
106            log::info!("Selected Vulkan device: {} ({})", device_name, device_type_str);
107            
108            // Create logical device
109            log::info!("[SAFE API] Creating logical device");
110            let (device, queue) = Self::create_device(physical_device, queue_family_index)?;
111            log::info!("[SAFE API] Device created: {:?}, queue: {:?}", device, queue);
112            
113            // Create descriptor pool for persistent descriptors
114            // TODO: Fix descriptor pool creation - temporarily skip
115            log::info!("[SAFE API] Skipping descriptor pool creation temporarily");
116            let descriptor_pool = VkDescriptorPool::NULL;
117            // let descriptor_pool = Self::create_descriptor_pool(device)?;
118            // log::info!("[SAFE API] Descriptor pool created: {:?}", descriptor_pool);
119            
120            // Create command pool
121            log::info!("[SAFE API] Creating command pool");
122            let command_pool = Self::create_command_pool(device, queue_family_index)?;
123            log::info!("[SAFE API] Command pool created: {:?}", command_pool);
124            
125            let inner = ContextInner {
126                instance,
127                physical_device,
128                device,
129                queue,
130                queue_family_index,
131                descriptor_pool,
132                command_pool,
133                device_properties,
134                memory_properties,
135            };
136            
137            // Log that we're using pure Rust implementation
138            log::info!("ComputeContext using Kronos pure Rust implementation");
139
140            let result = Self {
141                inner: Arc::new(Mutex::new(inner)),
142            };
143            log::info!("[SAFE API] ComputeContext created successfully");
144            Ok(result)
145        }
146    }
147    
148    /// Create a Vulkan instance
149    ///
150    /// # Safety
151    ///
152    /// This function is unsafe because:
153    /// - It calls vkCreateInstance which requires the Vulkan loader to be initialized
154    /// - The returned instance must be destroyed with vkDestroyInstance to avoid leaks
155    /// - The config strings must remain valid for the lifetime of the instance creation
156    /// - Null or invalid pointers in the create info will cause undefined behavior
157    unsafe fn create_instance(config: &ContextConfig) -> Result<VkInstance> {
158        log::info!("[SAFE API] create_instance called with app_name: {}", config.app_name);
159        let app_name = CString::new(config.app_name.clone())
160            .unwrap_or_else(|_| CString::new("Kronos App").unwrap());
161        let engine_name = CString::new("Kronos Compute").unwrap();
162        log::info!("[SAFE API] CStrings created successfully");
163        
164        let app_info = VkApplicationInfo {
165            sType: VkStructureType::ApplicationInfo,
166            pNext: ptr::null(),
167            pApplicationName: app_name.as_ptr(),
168            applicationVersion: VK_MAKE_VERSION(1, 0, 0),
169            pEngineName: engine_name.as_ptr(),
170            engineVersion: VK_MAKE_VERSION(1, 0, 0),
171            apiVersion: VK_API_VERSION_1_0,
172        };
173        
174        let create_info = VkInstanceCreateInfo {
175            sType: VkStructureType::InstanceCreateInfo,
176            pNext: ptr::null(),
177            flags: 0,
178            pApplicationInfo: &app_info,
179            enabledLayerCount: 0,
180            ppEnabledLayerNames: ptr::null(),
181            enabledExtensionCount: 0,
182            ppEnabledExtensionNames: ptr::null(),
183        };
184        
185        let mut instance = VkInstance::NULL;
186        // IMPORTANT: CStrings must remain alive during vkCreateInstance call
187        // They are dropped at the end of this function, which is safe
188        log::info!("[SAFE API] Calling vkCreateInstance");
189        let result = vkCreateInstance(&create_info, ptr::null(), &mut instance);
190        log::info!("[SAFE API] vkCreateInstance returned: {:?}", result);
191        
192        if result != VkResult::Success {
193            log::error!("[SAFE API] vkCreateInstance failed with: {:?}", result);
194            return Err(KronosError::from(result));
195        }
196        
197        log::info!("[SAFE API] Instance created successfully: {:?}", instance);
198        Ok(instance)
199    }
200    
201    /// Find a physical device with compute capabilities
202    ///
203    /// # Safety
204    ///
205    /// This function is unsafe because:
206    /// - The instance must be a valid VkInstance handle
207    /// - Calls vkEnumeratePhysicalDevices which may fail with invalid instance
208    /// - The returned physical device is tied to the instance lifetime
209    /// - Accessing the device after instance destruction is undefined behavior
210    unsafe fn find_compute_device(instance: VkInstance) -> Result<(VkPhysicalDevice, u32)> {
211        let mut device_count = 0;
212        log::info!("[SAFE API] Enumerating physical devices...");
213        
214        // First call to get count
215        let result = vkEnumeratePhysicalDevices(instance, &mut device_count, ptr::null_mut());
216        if result != VkResult::Success {
217            log::error!("[SAFE API] Failed to get device count: {:?}", result);
218            return Err(KronosError::from(result));
219        }
220        log::info!("[SAFE API] Found {} physical devices", device_count);
221        
222        if device_count == 0 {
223            return Err(KronosError::DeviceNotFound);
224        }
225        
226        let mut devices = vec![VkPhysicalDevice::NULL; device_count as usize];
227        let result = vkEnumeratePhysicalDevices(instance, &mut device_count, devices.as_mut_ptr());
228        if result != VkResult::Success {
229            log::error!("[SAFE API] Failed to enumerate devices: {:?}", result);
230            return Err(KronosError::from(result));
231        }
232        log::info!("[SAFE API] Successfully enumerated {} devices", device_count);
233        
234        // Collect all devices with compute support and their properties
235        let mut candidates = Vec::new();
236        
237        for (dev_idx, device) in devices.iter().enumerate() {
238            log::info!("[SAFE API] Checking device {} for compute support", dev_idx);
239            let queue_family = Self::find_compute_queue_family(*device)?;
240            if let Some(index) = queue_family {
241                // Get device properties to determine device type
242                let mut properties = VkPhysicalDeviceProperties::default();
243                vkGetPhysicalDeviceProperties(*device, &mut properties);
244                
245                candidates.push((*device, index, properties.deviceType));
246            }
247        }
248        
249        if candidates.is_empty() {
250            return Err(KronosError::DeviceNotFound);
251        }
252        
253        // Sort by device type preference: DiscreteGpu > IntegratedGpu > VirtualGpu > Cpu
254        candidates.sort_by_key(|(_, _, device_type)| {
255            match *device_type {
256                VkPhysicalDeviceType::DiscreteGpu => 0,
257                VkPhysicalDeviceType::IntegratedGpu => 1,
258                VkPhysicalDeviceType::VirtualGpu => 2,
259                VkPhysicalDeviceType::Cpu => 3,
260                VkPhysicalDeviceType::Other => 4,
261                _ => 5,
262            }
263        });
264        
265        // Return the best device
266        let (device, queue_index, device_type) = candidates[0];
267        log::info!("[SAFE API] Selected device with queue index {}, type {:?}", queue_index, device_type);
268        Ok((device, queue_index))
269    }
270    
271    /// Find a queue family with compute support
272    ///
273    /// # Safety
274    ///
275    /// This function is unsafe because:
276    /// - The device must be a valid VkPhysicalDevice handle
277    /// - Calls vkGetPhysicalDeviceQueueFamilyProperties with the device
278    /// - Invalid device handle will cause undefined behavior
279    /// - The device must remain valid during the function execution
280    unsafe fn find_compute_queue_family(device: VkPhysicalDevice) -> Result<Option<u32>> {
281        log::info!("[SAFE API] Finding queue families for device {:?}", device);
282        let mut queue_family_count = 0;
283        log::info!("[SAFE API] Calling vkGetPhysicalDeviceQueueFamilyProperties (count query)...");
284        vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, ptr::null_mut());
285        log::info!("[SAFE API] Device has {} queue families", queue_family_count);
286        
287        let mut queue_families = vec![
288            VkQueueFamilyProperties {
289                queueFlags: VkQueueFlags::empty(),
290                queueCount: 0,
291                timestampValidBits: 0,
292                minImageTransferGranularity: VkExtent3D { width: 0, height: 0, depth: 0 },
293            };
294            queue_family_count as usize
295        ];
296        log::info!("[SAFE API] Getting queue family properties...");
297        vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, queue_families.as_mut_ptr());
298        log::info!("[SAFE API] Got queue family properties, checking for compute support");
299        
300        for (index, family) in queue_families.iter().enumerate() {
301            log::info!("[SAFE API] Queue family {}: flags={:?}", index, family.queueFlags);
302            if family.queueFlags.contains(VkQueueFlags::COMPUTE) {
303                log::info!("[SAFE API] Found compute queue at index {}", index);
304                return Ok(Some(index as u32));
305            }
306        }
307        
308        log::info!("[SAFE API] No compute queue family found");
309        Ok(None)
310    }
311    
312    /// Create a logical device and get its compute queue
313    ///
314    /// # Safety
315    ///
316    /// This function is unsafe because:
317    /// - The physical_device must be a valid VkPhysicalDevice handle
318    /// - The queue_family_index must be valid for the physical device
319    /// - Calls vkCreateDevice and vkGetDeviceQueue which require valid handles
320    /// - The returned device and queue must be properly destroyed
321    /// - Queue family index out of bounds will cause undefined behavior
322    unsafe fn create_device(physical_device: VkPhysicalDevice, queue_family_index: u32) -> Result<(VkDevice, VkQueue)> {
323        let queue_priority = 1.0f32;
324        
325        let queue_create_info = VkDeviceQueueCreateInfo {
326            sType: VkStructureType::DeviceQueueCreateInfo,
327            pNext: ptr::null(),
328            flags: 0,
329            queueFamilyIndex: queue_family_index,
330            queueCount: 1,
331            pQueuePriorities: &queue_priority,
332        };
333        
334        // Don't request any features - use NULL pointer like the working example
335        log::info!("[SAFE API] Creating device with NULL features pointer (no features requested)");
336        
337        let device_create_info = VkDeviceCreateInfo {
338            sType: VkStructureType::DeviceCreateInfo,
339            pNext: ptr::null(),
340            flags: 0,
341            queueCreateInfoCount: 1,
342            pQueueCreateInfos: &queue_create_info,
343            enabledLayerCount: 0,
344            ppEnabledLayerNames: ptr::null(),
345            enabledExtensionCount: 0,
346            ppEnabledExtensionNames: ptr::null(),
347            pEnabledFeatures: ptr::null(),  // Use NULL like the working example
348        };
349        
350        let mut device = VkDevice::NULL;
351        log::info!("[SAFE API] Calling vkCreateDevice with queue family index {}", queue_family_index);
352        let result = vkCreateDevice(physical_device, &device_create_info, ptr::null(), &mut device);
353        log::info!("[SAFE API] vkCreateDevice returned: {:?}", result);
354        
355        if result != VkResult::Success {
356            log::error!("[SAFE API] Failed to create device: {:?}", result);
357            return Err(KronosError::from(result));
358        }
359        
360        let mut queue = VkQueue::NULL;
361        vkGetDeviceQueue(device, queue_family_index, 0, &mut queue);
362        
363        // No ICD registration needed - we're using pure Rust implementation!
364        log::info!("[SAFE API] Device created with pure Rust implementation");
365        
366        Ok((device, queue))
367    }
368    
369    /// Create a descriptor pool for persistent descriptors
370    ///
371    /// # Safety
372    ///
373    /// This function is unsafe because:
374    /// - The device must be a valid VkDevice handle
375    /// - Calls vkCreateDescriptorPool which requires valid device
376    /// - The returned pool must be destroyed with vkDestroyDescriptorPool
377    /// - Invalid device handle will cause undefined behavior
378    /// - Pool creation may fail if device limits are exceeded
379    unsafe fn create_descriptor_pool(device: VkDevice) -> Result<VkDescriptorPool> {
380        log::info!("[SAFE API] Creating descriptor pool with device: {:?}", device);
381        // Create a large pool for persistent descriptors
382        let pool_size = VkDescriptorPoolSize {
383            type_: VkDescriptorType::StorageBuffer,
384            descriptorCount: 10000, // Should be enough for most use cases
385        };
386        
387        let pool_info = VkDescriptorPoolCreateInfo {
388            sType: VkStructureType::DescriptorPoolCreateInfo,
389            pNext: ptr::null(),
390            flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
391            maxSets: 1000,
392            poolSizeCount: 1,
393            pPoolSizes: &pool_size,
394        };
395        
396        let mut pool = VkDescriptorPool::NULL;
397        log::info!("[SAFE API] Calling vkCreateDescriptorPool");
398        let result = vkCreateDescriptorPool(device, &pool_info, ptr::null(), &mut pool);
399        log::info!("[SAFE API] vkCreateDescriptorPool returned: {:?}", result);
400        
401        if result != VkResult::Success {
402            log::error!("[SAFE API] Failed to create descriptor pool: {:?}", result);
403            return Err(KronosError::from(result));
404        }
405        
406        Ok(pool)
407    }
408    
409    /// Create a command pool for allocating command buffers
410    ///
411    /// # Safety
412    ///
413    /// This function is unsafe because:
414    /// - The device must be a valid VkDevice handle
415    /// - The queue_family_index must be valid for the device
416    /// - Calls vkCreateCommandPool which requires valid parameters
417    /// - The returned pool must be destroyed with vkDestroyCommandPool
418    /// - Invalid queue family index will cause undefined behavior
419    unsafe fn create_command_pool(device: VkDevice, queue_family_index: u32) -> Result<VkCommandPool> {
420        let pool_info = VkCommandPoolCreateInfo {
421            sType: VkStructureType::CommandPoolCreateInfo,
422            pNext: ptr::null(),
423            flags: VkCommandPoolCreateFlags::RESET_COMMAND_BUFFER,
424            queueFamilyIndex: queue_family_index,
425        };
426        
427        let mut pool = VkCommandPool::NULL;
428        log::info!("[SAFE API] Calling vkCreateCommandPool with device {:?}, queue family {}", device, queue_family_index);
429        let result = vkCreateCommandPool(device, &pool_info, ptr::null(), &mut pool);
430        log::info!("[SAFE API] vkCreateCommandPool returned: {:?}", result);
431        
432        if result != VkResult::Success {
433            log::error!("[SAFE API] Failed to create command pool: {:?}", result);
434            return Err(KronosError::from(result));
435        }
436        
437        Ok(pool)
438    }
439    
440    /// Get the underlying Vulkan device (for advanced usage)
441    pub fn device(&self) -> VkDevice {
442        self.inner.lock().unwrap().device
443    }
444    
445    /// Get the compute queue
446    pub fn queue(&self) -> VkQueue {
447        self.inner.lock().unwrap().queue
448    }
449    
450    /// Get device properties
451    pub fn device_properties(&self) -> VkPhysicalDeviceProperties {
452        self.inner.lock().unwrap().device_properties
453    }
454    
455    // REMOVED: icd_info() - Kronos is now a pure Rust implementation
456
457    // Internal helper for other modules
458    pub(super) fn with_inner<F, R>(&self, f: F) -> R
459    where
460        F: FnOnce(&ContextInner) -> R,
461    {
462        let inner = self.inner.lock().unwrap();
463        f(&*inner)
464    }
465}
466
467impl Drop for ComputeContext {
468    fn drop(&mut self) {
469        // Only the last Clone should perform destruction to avoid double-free.
470        if std::sync::Arc::strong_count(&self.inner) != 1 {
471            return;
472        }
473        let inner = self.inner.lock().unwrap();
474        unsafe {
475            if inner.command_pool != VkCommandPool::NULL {
476                vkDestroyCommandPool(inner.device, inner.command_pool, ptr::null());
477            }
478            // TODO: Re-enable when descriptor pool creation is fixed
479            // if inner.descriptor_pool != VkDescriptorPool::NULL {
480            //     vkDestroyDescriptorPool(inner.device, inner.descriptor_pool, ptr::null());
481            // }
482            if inner.device != VkDevice::NULL {
483                vkDestroyDevice(inner.device, ptr::null());
484            }
485            if inner.instance != VkInstance::NULL {
486                vkDestroyInstance(inner.instance, ptr::null());
487            }
488        }
489    }
490}