//! mumu-gpu 0.1.0
//!
//! GPU/Vulkan matrix and tensor operations for the mumu/lava language.
#![allow(dead_code)]

use anyhow::Result;
use ash::vk;
use lazy_static::lazy_static;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
use log;

/// Bundles the core Vulkan objects used by this crate: the loader entry
/// points, an instance, a logical device, and one queue taken from a
/// queue family that supports `COMPUTE`.
///
/// Dropping a value destroys `device` and then `instance` (see the `Drop`
/// impl for this type).
pub struct AshVulkanContext {
    /// Vulkan loader entry points, obtained in `new` via `ash::Entry::load`.
    /// Kept for the lifetime of the context (presumably so the loaded
    /// library outlives `instance` and `device` — confirm if refactoring).
    pub entry: ash::Entry,
    /// Vulkan instance created in `new` (requests API version 1.2).
    pub instance: ash::Instance,
    /// Logical device created with a single queue from `queue_family_index`.
    pub device: ash::Device,
    /// Queue index 0 of `queue_family_index` on `device`.
    pub queue: ash::vk::Queue,
    /// Index of the first queue family on the selected physical device whose
    /// flags include `COMPUTE`.
    pub queue_family_index: u32,
}

impl AshVulkanContext {
    /// Creates a Vulkan instance, selects a physical device that exposes a
    /// compute-capable queue family, and creates a logical device with a
    /// single queue from that family.
    ///
    /// Also tries to install `env_logger` as the global logger. If a logger is
    /// already installed (by the host application or a previous call), that
    /// step is skipped instead of panicking.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following happens:
    /// - the Vulkan loader cannot be loaded;
    /// - instance or device creation fails;
    /// - no physical device is present (`"No Vulkan physical devices found"`);
    /// - no physical device exposes a queue family with `COMPUTE` support
    ///   (`"No queue family supports COMPUTE"`).
    ///
    /// If a failure happens after the instance was created, the instance is
    /// destroyed before the error is returned.
    pub fn new() -> Result<Self> {
        // `env_logger::init` panics when a global logger is already set, which
        // would abort a second `new()` or a host that configured logging itself.
        let _ = env_logger::try_init();

        // SAFETY: loads the system Vulkan loader library. Its initialisation
        // code runs in-process, which is why ash marks this as unsafe.
        let entry = unsafe { ash::Entry::load()? };

        let app_name = std::ffi::CString::new("mumu-gpu")?;
        let app_info = vk::ApplicationInfo::default()
            .application_name(app_name.as_c_str())
            .application_version(0)
            .engine_name(app_name.as_c_str())
            .engine_version(0)
            .api_version(vk::make_api_version(0, 1, 2, 0));
        let create_info = vk::InstanceCreateInfo::default().application_info(&app_info);

        // SAFETY: `create_info` and the data it borrows (`app_info`, `app_name`)
        // are alive for the whole call; the builders tie those lifetimes.
        let instance = unsafe { entry.create_instance(&create_info, None)? };

        // Nothing owns `instance` yet, so any failure from here on must
        // destroy it explicitly or it would leak.
        match Self::create_compute_device(&instance) {
            Ok((device, queue, queue_family_index)) => {
                log::info!("AshVulkanContext => created successfully");
                Ok(Self {
                    entry,
                    instance,
                    device,
                    queue,
                    queue_family_index,
                })
            }
            Err(err) => {
                // SAFETY: no child objects of `instance` exist. Device creation
                // is the last fallible step of the helper, so it either failed
                // or was never attempted.
                unsafe { instance.destroy_instance(None) };
                Err(err)
            }
        }
    }

    /// Chooses a physical device and creates a logical device on it.
    ///
    /// The chosen physical device is the first one, in enumeration order,
    /// that has a queue family supporting `COMPUTE`; the first such family on
    /// it is used. The logical device gets one queue (priority 1.0) from that
    /// family.
    ///
    /// Returns the device, queue 0 of that family, and the family index.
    /// Does not destroy `instance` on failure; the caller owns that cleanup.
    fn create_compute_device(instance: &ash::Instance) -> Result<(ash::Device, vk::Queue, u32)> {
        // SAFETY: `instance` is a live instance owned by the caller.
        let pdevices = unsafe { instance.enumerate_physical_devices()? };
        if pdevices.is_empty() {
            return Err(anyhow::anyhow!("No Vulkan physical devices found"));
        }

        // Look past device 0 so a machine whose first device lacks a compute
        // family can still use another one.
        let (physical_device, family) = pdevices
            .iter()
            .copied()
            .find_map(|pdev| {
                // SAFETY: `pdev` was just enumerated from `instance`.
                let families =
                    unsafe { instance.get_physical_device_queue_family_properties(pdev) };
                families
                    .iter()
                    .position(|props| props.queue_flags.contains(vk::QueueFlags::COMPUTE))
                    .map(|family| (pdev, family))
            })
            .ok_or_else(|| anyhow::anyhow!("No queue family supports COMPUTE"))?;
        let queue_family_index = u32::try_from(family)?;

        // One priority entry means exactly one queue is requested
        // (the builder sets `queue_count` from the slice length).
        let priorities = [1.0f32];
        let queue_info = vk::DeviceQueueCreateInfo::default()
            .queue_family_index(queue_family_index)
            .queue_priorities(&priorities);
        let dev_create_info =
            vk::DeviceCreateInfo::default().queue_create_infos(std::slice::from_ref(&queue_info));

        // SAFETY: `physical_device` was enumerated from `instance`, and the
        // create-info chain borrows locals that outlive this call.
        let device = unsafe { instance.create_device(physical_device, &dev_create_info, None)? };
        // SAFETY: one queue was requested from `queue_family_index`, so queue
        // index 0 is valid.
        let queue = unsafe { device.get_device_queue(queue_family_index, 0) };

        Ok((device, queue, queue_family_index))
    }
}

impl Drop for AshVulkanContext {
    /// Waits for the device to go idle, then destroys the logical device and
    /// the instance, in that order (child before parent).
    fn drop(&mut self) {
        // SAFETY: `device` and `instance` are assumed to have been created
        // together (as `new` does) and are destroyed exactly once here.
        // The device is destroyed first, since it is a child of the instance.
        // Any other objects created on `device` elsewhere must already be
        // destroyed by their owners.
        unsafe {
            // Queue work submitted elsewhere may still be running; the Vulkan
            // spec recommends gating device destruction on an idle wait.
            // `drop` cannot propagate errors (e.g. device lost), so log and
            // continue with teardown.
            if let Err(err) = self.device.device_wait_idle() {
                log::warn!(
                    "AshVulkanContext => device_wait_idle failed: {:?}",
                    err
                );
            }
            self.device.destroy_device(None);
            self.instance.destroy_instance(None);
        }
        log::info!("AshVulkanContext => destroyed");
    }
}

lazy_static! {
    /// Process-wide, mutex-guarded slot for a shared `AshVulkanContext`.
    ///
    /// The initializer leaves the slot empty (`None`) and does not create a
    /// context. Presumably other code stores one on first use — confirm
    /// against the rest of the crate.
    pub static ref VULKAN_CONTEXT: Arc<Mutex<Option<AshVulkanContext>>> =
        Arc::new(Mutex::default());
}