1use super::*;
4use crate::*; use crate::implementation::initialize_kronos;
6use std::ffi::CString;
7use std::ptr;
8use std::sync::{Arc, Mutex};
9
/// Vulkan handles and cached device data that back a `ComputeContext`.
///
/// One value of this type is built by `ComputeContext::new_with_config` and
/// then shared behind an `Arc<Mutex<_>>`.
pub(super) struct ContextInner {
    /// Instance handle produced by `vkCreateInstance`.
    pub(super) instance: VkInstance,
    /// Physical device chosen by `find_compute_device`.
    pub(super) physical_device: VkPhysicalDevice,
    /// Logical device created on `physical_device`.
    pub(super) device: VkDevice,
    /// Queue 0 of family `queue_family_index`, fetched with `vkGetDeviceQueue`.
    pub(super) queue: VkQueue,
    /// Index of the compute-capable queue family the device was created with.
    pub(super) queue_family_index: u32,

    /// Descriptor pool created on `device` (storage-buffer descriptors).
    pub(super) descriptor_pool: VkDescriptorPool,
    /// Command pool created on `device` for `queue_family_index`.
    pub(super) command_pool: VkCommandPool,

    /// Snapshot filled in by `vkGetPhysicalDeviceProperties` at construction.
    pub(super) device_properties: VkPhysicalDeviceProperties,
    /// Snapshot filled in by `vkGetPhysicalDeviceMemoryProperties` at construction.
    pub(super) memory_properties: VkPhysicalDeviceMemoryProperties,
}
26
/// Cloneable handle to a shared compute context.
///
/// Cloning copies the `Arc`, so every clone refers to the same
/// `ContextInner` guarded by the same mutex.
#[derive(Clone)]
pub struct ComputeContext {
    /// Shared, mutex-guarded Vulkan state.
    pub(super) inner: Arc<Mutex<ContextInner>>,
}
36
// SAFETY / NOTE(review): these impls assert that a `ComputeContext` may be
// moved to and shared between threads. Its only field is
// `Arc<Mutex<ContextInner>>`, so the manual impls matter only if the handle
// types inside `ContextInner` are not `Send` by themselves (presumably
// raw-pointer handles) — confirm against the handle definitions.
unsafe impl Send for ComputeContext {}
unsafe impl Sync for ComputeContext {}
40
41impl ComputeContext {
42 pub(super) fn new_with_config(config: ContextConfig) -> Result<Self> {
43 unsafe {
44 if let Some(ref p) = config.preferred_icd_path {
46 crate::implementation::icd_loader::set_preferred_icd_path(p.clone());
47 } else if let Some(i) = config.preferred_icd_index {
48 crate::implementation::icd_loader::set_preferred_icd_index(i);
49 }
50
51 initialize_kronos()
53 .map_err(|e| KronosError::InitializationFailed(e.to_string()))?;
54
55 let instance = Self::create_instance(&config)?;
57
58 let (physical_device, queue_family_index) = Self::find_compute_device(instance)?;
60
61 let mut device_properties = VkPhysicalDeviceProperties::default();
63 vkGetPhysicalDeviceProperties(physical_device, &mut device_properties);
64
65 let mut memory_properties = VkPhysicalDeviceMemoryProperties::default();
66 vkGetPhysicalDeviceMemoryProperties(physical_device, &mut memory_properties);
67
68 let device_name = std::ffi::CStr::from_ptr(device_properties.deviceName.as_ptr())
70 .to_string_lossy();
71 let device_type_str = match device_properties.deviceType {
72 VkPhysicalDeviceType::DiscreteGpu => "Discrete GPU",
73 VkPhysicalDeviceType::IntegratedGpu => "Integrated GPU",
74 VkPhysicalDeviceType::VirtualGpu => "Virtual GPU",
75 VkPhysicalDeviceType::Cpu => "CPU (Software Renderer)",
76 _ => "Unknown",
77 };
78 log::info!("Selected Vulkan device: {} ({})", device_name, device_type_str);
79
80 let (device, queue) = Self::create_device(physical_device, queue_family_index)?;
82
83 let descriptor_pool = Self::create_descriptor_pool(device)?;
85
86 let command_pool = Self::create_command_pool(device, queue_family_index)?;
88
89 let inner = ContextInner {
90 instance,
91 physical_device,
92 device,
93 queue,
94 queue_family_index,
95 descriptor_pool,
96 command_pool,
97 device_properties,
98 memory_properties,
99 };
100
101 if let Some(info) = crate::implementation::icd_loader::selected_icd_info() {
103 log::info!(
104 "ComputeContext bound to ICD: {} ({}), api=0x{:x}",
105 info.library_path.display(),
106 if info.is_software { "software" } else { "hardware" },
107 info.api_version
108 );
109 }
110
111 Ok(Self {
112 inner: Arc::new(Mutex::new(inner)),
113 })
114 }
115 }
116
117 unsafe fn create_instance(config: &ContextConfig) -> Result<VkInstance> {
127 let app_name = CString::new(config.app_name.clone())
128 .unwrap_or_else(|_| CString::new("Kronos App").unwrap());
129 let engine_name = CString::new("Kronos Compute").unwrap();
130
131 let app_info = VkApplicationInfo {
132 sType: VkStructureType::ApplicationInfo,
133 pNext: ptr::null(),
134 pApplicationName: app_name.as_ptr(),
135 applicationVersion: VK_MAKE_VERSION(1, 0, 0),
136 pEngineName: engine_name.as_ptr(),
137 engineVersion: VK_MAKE_VERSION(1, 0, 0),
138 apiVersion: VK_API_VERSION_1_0,
139 };
140
141 let create_info = VkInstanceCreateInfo {
142 sType: VkStructureType::InstanceCreateInfo,
143 pNext: ptr::null(),
144 flags: 0,
145 pApplicationInfo: &app_info,
146 enabledLayerCount: 0,
147 ppEnabledLayerNames: ptr::null(),
148 enabledExtensionCount: 0,
149 ppEnabledExtensionNames: ptr::null(),
150 };
151
152 let mut instance = VkInstance::NULL;
153 let result = vkCreateInstance(&create_info, ptr::null(), &mut instance);
154
155 if result != VkResult::Success {
156 return Err(KronosError::from(result));
157 }
158
159 Ok(instance)
160 }
161
162 unsafe fn find_compute_device(instance: VkInstance) -> Result<(VkPhysicalDevice, u32)> {
172 let mut device_count = 0;
173 vkEnumeratePhysicalDevices(instance, &mut device_count, ptr::null_mut());
174
175 if device_count == 0 {
176 return Err(KronosError::DeviceNotFound);
177 }
178
179 let mut devices = vec![VkPhysicalDevice::NULL; device_count as usize];
180 vkEnumeratePhysicalDevices(instance, &mut device_count, devices.as_mut_ptr());
181
182 let mut candidates = Vec::new();
184
185 for device in devices {
186 let queue_family = Self::find_compute_queue_family(device)?;
187 if let Some(index) = queue_family {
188 let mut properties = VkPhysicalDeviceProperties::default();
190 vkGetPhysicalDeviceProperties(device, &mut properties);
191
192 candidates.push((device, index, properties.deviceType));
193 }
194 }
195
196 if candidates.is_empty() {
197 return Err(KronosError::DeviceNotFound);
198 }
199
200 candidates.sort_by_key(|(_, _, device_type)| {
202 match *device_type {
203 VkPhysicalDeviceType::DiscreteGpu => 0,
204 VkPhysicalDeviceType::IntegratedGpu => 1,
205 VkPhysicalDeviceType::VirtualGpu => 2,
206 VkPhysicalDeviceType::Cpu => 3,
207 VkPhysicalDeviceType::Other => 4,
208 _ => 5,
209 }
210 });
211
212 let (device, queue_index, _) = candidates[0];
214 Ok((device, queue_index))
215 }
216
217 unsafe fn find_compute_queue_family(device: VkPhysicalDevice) -> Result<Option<u32>> {
227 let mut queue_family_count = 0;
228 vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, ptr::null_mut());
229
230 let mut queue_families = vec![VkQueueFamilyProperties {
231 queueFlags: VkQueueFlags::empty(),
232 queueCount: 0,
233 timestampValidBits: 0,
234 minImageTransferGranularity: VkExtent3D { width: 0, height: 0, depth: 0 },
235 }; queue_family_count as usize];
236 vkGetPhysicalDeviceQueueFamilyProperties(device, &mut queue_family_count, queue_families.as_mut_ptr());
237
238 for (index, family) in queue_families.iter().enumerate() {
239 if family.queueFlags.contains(VkQueueFlags::COMPUTE) {
240 return Ok(Some(index as u32));
241 }
242 }
243
244 Ok(None)
245 }
246
247 unsafe fn create_device(physical_device: VkPhysicalDevice, queue_family_index: u32) -> Result<(VkDevice, VkQueue)> {
258 let queue_priority = 1.0f32;
259
260 let queue_create_info = VkDeviceQueueCreateInfo {
261 sType: VkStructureType::DeviceQueueCreateInfo,
262 pNext: ptr::null(),
263 flags: 0,
264 queueFamilyIndex: queue_family_index,
265 queueCount: 1,
266 pQueuePriorities: &queue_priority,
267 };
268
269 let features = VkPhysicalDeviceFeatures::default();
270
271 let device_create_info = VkDeviceCreateInfo {
272 sType: VkStructureType::DeviceCreateInfo,
273 pNext: ptr::null(),
274 flags: 0,
275 queueCreateInfoCount: 1,
276 pQueueCreateInfos: &queue_create_info,
277 enabledLayerCount: 0,
278 ppEnabledLayerNames: ptr::null(),
279 enabledExtensionCount: 0,
280 ppEnabledExtensionNames: ptr::null(),
281 pEnabledFeatures: &features,
282 };
283
284 let mut device = VkDevice::NULL;
285 let result = vkCreateDevice(physical_device, &device_create_info, ptr::null(), &mut device);
286
287 if result != VkResult::Success {
288 return Err(KronosError::from(result));
289 }
290
291 let mut queue = VkQueue::NULL;
292 vkGetDeviceQueue(device, queue_family_index, 0, &mut queue);
293
294 Ok((device, queue))
295 }
296
297 unsafe fn create_descriptor_pool(device: VkDevice) -> Result<VkDescriptorPool> {
308 let pool_size = VkDescriptorPoolSize {
310 type_: VkDescriptorType::StorageBuffer,
311 descriptorCount: 10000, };
313
314 let pool_info = VkDescriptorPoolCreateInfo {
315 sType: VkStructureType::DescriptorPoolCreateInfo,
316 pNext: ptr::null(),
317 flags: VkDescriptorPoolCreateFlags::FREE_DESCRIPTOR_SET,
318 maxSets: 1000,
319 poolSizeCount: 1,
320 pPoolSizes: &pool_size,
321 };
322
323 let mut pool = VkDescriptorPool::NULL;
324 let result = vkCreateDescriptorPool(device, &pool_info, ptr::null(), &mut pool);
325
326 if result != VkResult::Success {
327 return Err(KronosError::from(result));
328 }
329
330 Ok(pool)
331 }
332
333 unsafe fn create_command_pool(device: VkDevice, queue_family_index: u32) -> Result<VkCommandPool> {
344 let pool_info = VkCommandPoolCreateInfo {
345 sType: VkStructureType::CommandPoolCreateInfo,
346 pNext: ptr::null(),
347 flags: VkCommandPoolCreateFlags::RESET_COMMAND_BUFFER,
348 queueFamilyIndex: queue_family_index,
349 };
350
351 let mut pool = VkCommandPool::NULL;
352 let result = vkCreateCommandPool(device, &pool_info, ptr::null(), &mut pool);
353
354 if result != VkResult::Success {
355 return Err(KronosError::from(result));
356 }
357
358 Ok(pool)
359 }
360
361 pub fn device(&self) -> VkDevice {
363 self.inner.lock().unwrap().device
364 }
365
366 pub fn queue(&self) -> VkQueue {
368 self.inner.lock().unwrap().queue
369 }
370
371 pub fn device_properties(&self) -> VkPhysicalDeviceProperties {
373 self.inner.lock().unwrap().device_properties
374 }
375
376 pub fn icd_info(&self) -> Option<crate::implementation::icd_loader::IcdInfo> {
378 crate::implementation::icd_loader::selected_icd_info()
379 }
380
381 pub(super) fn with_inner<F, R>(&self, f: F) -> R
383 where
384 F: FnOnce(&ContextInner) -> R,
385 {
386 let inner = self.inner.lock().unwrap();
387 f(&*inner)
388 }
389}
390
391impl Drop for ComputeContext {
392 fn drop(&mut self) {
393 if std::sync::Arc::strong_count(&self.inner) != 1 {
395 return;
396 }
397 let inner = self.inner.lock().unwrap();
398 unsafe {
399 if inner.command_pool != VkCommandPool::NULL {
400 vkDestroyCommandPool(inner.device, inner.command_pool, ptr::null());
401 }
402 if inner.descriptor_pool != VkDescriptorPool::NULL {
403 vkDestroyDescriptorPool(inner.device, inner.descriptor_pool, ptr::null());
404 }
405 if inner.device != VkDevice::NULL {
406 vkDestroyDevice(inner.device, ptr::null());
407 }
408 if inner.instance != VkInstance::NULL {
409 vkDestroyInstance(inner.instance, ptr::null());
410 }
411 }
412 }
413}