use std::os::raw::c_void;
use trustformers_core::error::{CoreError, Result};
#[cfg(target_os = "android")]
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VkInstance(pub *mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VkDevice(pub *mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VkPhysicalDevice(pub *mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VkQueue(pub *mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VkCommandBuffer(pub *mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VkBuffer(pub *mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VkDeviceMemory(pub *mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VkDescriptorSet(pub *mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VkPipeline(pub *mut c_void);
pub const VK_SUCCESS: i32 = 0;
pub const VK_NOT_READY: i32 = 1;
pub const VK_TIMEOUT: i32 = 2;
pub const VK_EVENT_SET: i32 = 3;
pub const VK_EVENT_RESET: i32 = 4;
pub const VK_INCOMPLETE: i32 = 5;
pub const VK_ERROR_OUT_OF_HOST_MEMORY: i32 = -1;
pub const VK_ERROR_OUT_OF_DEVICE_MEMORY: i32 = -2;
pub const VK_ERROR_INITIALIZATION_FAILED: i32 = -3;
pub const VK_ERROR_DEVICE_LOST: i32 = -4;
pub const VK_ERROR_MEMORY_MAP_FAILED: i32 = -5;
pub const VK_PIPELINE_BIND_POINT_COMPUTE: u32 = 1;
pub const VK_COMMAND_BUFFER_LEVEL_PRIMARY: u32 = 0;
pub const VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT: u32 = 0x00000002;
pub const VK_MEMORY_PROPERTY_HOST_COHERENT_BIT: u32 = 0x00000004;
pub const VK_BUFFER_USAGE_STORAGE_BUFFER_BIT: u32 = 0x00000020;
#[cfg(target_os = "android")]
extern "C" {
pub fn vkCreateInstance(
create_info: *const c_void,
allocator: *const c_void,
instance: *mut VkInstance,
) -> i32;
pub fn vkDestroyInstance(instance: VkInstance, allocator: *const c_void);
pub fn vkEnumeratePhysicalDevices(
instance: VkInstance,
device_count: *mut u32,
physical_devices: *mut VkPhysicalDevice,
) -> i32;
pub fn vkGetPhysicalDeviceProperties(physical_device: VkPhysicalDevice, properties: *mut c_void);
pub fn vkCreateDevice(
physical_device: VkPhysicalDevice,
create_info: *const c_void,
allocator: *const c_void,
device: *mut VkDevice,
) -> i32;
pub fn vkDestroyDevice(device: VkDevice, allocator: *const c_void);
pub fn vkGetDeviceQueue(
device: VkDevice,
queue_family_index: u32,
queue_index: u32,
queue: *mut VkQueue,
);
pub fn vkCreateBuffer(
device: VkDevice,
create_info: *const c_void,
allocator: *const c_void,
buffer: *mut VkBuffer,
) -> i32;
pub fn vkDestroyBuffer(device: VkDevice, buffer: VkBuffer, allocator: *const c_void);
pub fn vkAllocateMemory(
device: VkDevice,
allocate_info: *const c_void,
allocator: *const c_void,
memory: *mut VkDeviceMemory,
) -> i32;
pub fn vkFreeMemory(device: VkDevice, memory: VkDeviceMemory, allocator: *const c_void);
pub fn vkBindBufferMemory(
device: VkDevice,
buffer: VkBuffer,
memory: VkDeviceMemory,
memory_offset: u64,
) -> i32;
pub fn vkMapMemory(
device: VkDevice,
memory: VkDeviceMemory,
offset: u64,
size: u64,
flags: u32,
data: *mut *mut c_void,
) -> i32;
pub fn vkUnmapMemory(device: VkDevice, memory: VkDeviceMemory);
pub fn vkCreateComputePipelines(
device: VkDevice,
pipeline_cache: *mut c_void,
create_info_count: u32,
create_infos: *const c_void,
allocator: *const c_void,
pipelines: *mut VkPipeline,
) -> i32;
pub fn vkDestroyPipeline(device: VkDevice, pipeline: VkPipeline, allocator: *const c_void);
pub fn vkAllocateCommandBuffers(
device: VkDevice,
allocate_info: *const c_void,
command_buffers: *mut VkCommandBuffer,
) -> i32;
pub fn vkFreeCommandBuffers(
device: VkDevice,
command_pool: *mut c_void,
command_buffer_count: u32,
command_buffers: *const VkCommandBuffer,
);
pub fn vkBeginCommandBuffer(command_buffer: VkCommandBuffer, begin_info: *const c_void) -> i32;
pub fn vkEndCommandBuffer(command_buffer: VkCommandBuffer) -> i32;
pub fn vkCmdBindPipeline(
command_buffer: VkCommandBuffer,
pipeline_bind_point: u32,
pipeline: VkPipeline,
);
pub fn vkCmdBindDescriptorSets(
command_buffer: VkCommandBuffer,
pipeline_bind_point: u32,
layout: *mut c_void,
first_set: u32,
descriptor_set_count: u32,
descriptor_sets: *const VkDescriptorSet,
dynamic_offset_count: u32,
dynamic_offsets: *const u32,
);
pub fn vkCmdDispatch(
command_buffer: VkCommandBuffer,
group_count_x: u32,
group_count_y: u32,
group_count_z: u32,
);
pub fn vkQueueSubmit(
queue: VkQueue,
submit_count: u32,
submits: *const c_void,
fence: *mut c_void,
) -> i32;
pub fn vkQueueWaitIdle(queue: VkQueue) -> i32;
pub fn vkDeviceWaitIdle(device: VkDevice) -> i32;
}
#[cfg(target_os = "android")]
pub struct VulkanComputeContext {
instance: VkInstance,
physical_device: VkPhysicalDevice,
device: VkDevice,
queue: VkQueue,
command_buffer: VkCommandBuffer,
}
#[cfg(target_os = "android")]
impl VulkanComputeContext {
pub fn new() -> Result<Self> {
let instance = Self::create_instance()?;
let physical_device = Self::select_physical_device(instance)?;
let device = Self::create_device(physical_device)?;
let queue = Self::get_compute_queue(device);
let command_buffer = Self::create_command_buffer(device)?;
Ok(Self {
instance,
physical_device,
device,
queue,
command_buffer,
})
}
fn create_instance() -> Result<VkInstance> {
let mut instance = VkInstance(std::ptr::null_mut());
let result = unsafe {
vkCreateInstance(
std::ptr::null(), std::ptr::null(), &mut instance,
)
};
if result != VK_SUCCESS {
return Err(TrustformersError::runtime_error(format!(
"Failed to create Vulkan instance: {}",
result
)).into());
}
tracing::info!("Vulkan instance created successfully");
Ok(instance)
}
fn select_physical_device(instance: VkInstance) -> Result<VkPhysicalDevice> {
let mut device_count: u32 = 0;
let mut devices: Vec<VkPhysicalDevice> = Vec::new();
let result = unsafe {
vkEnumeratePhysicalDevices(instance, &mut device_count, std::ptr::null_mut())
};
if result != VK_SUCCESS || device_count == 0 {
return Err(TrustformersError::runtime_error(
"No Vulkan physical devices found".into(),
).into());
}
devices.resize(device_count as usize, VkPhysicalDevice(std::ptr::null_mut()));
let result = unsafe {
vkEnumeratePhysicalDevices(instance, &mut device_count, devices.as_mut_ptr())
};
if result != VK_SUCCESS {
return Err(TrustformersError::runtime_error(
"Failed to enumerate Vulkan physical devices".into(),
).into());
}
tracing::info!("Selected Vulkan physical device 0 of {}", device_count);
Ok(devices[0])
}
fn create_device(physical_device: VkPhysicalDevice) -> Result<VkDevice> {
let mut device = VkDevice(std::ptr::null_mut());
let result = unsafe {
vkCreateDevice(
physical_device,
std::ptr::null(), std::ptr::null(), &mut device,
)
};
if result != VK_SUCCESS {
return Err(TrustformersError::runtime_error(format!(
"Failed to create Vulkan device: {}",
result
)).into());
}
tracing::info!("Vulkan logical device created successfully");
Ok(device)
}
fn get_compute_queue(device: VkDevice) -> VkQueue {
let mut queue = VkQueue(std::ptr::null_mut());
unsafe {
vkGetDeviceQueue(device, 0, 0, &mut queue); }
queue
}
fn create_command_buffer(device: VkDevice) -> Result<VkCommandBuffer> {
let mut command_buffer = VkCommandBuffer(std::ptr::null_mut());
let result = unsafe {
vkAllocateCommandBuffers(
device,
std::ptr::null(), &mut command_buffer,
)
};
if result != VK_SUCCESS {
return Err(TrustformersError::runtime_error(format!(
"Failed to create Vulkan command buffer: {}",
result
)).into());
}
Ok(command_buffer)
}
pub fn create_compute_pipeline(&self, operation: ComputeOperation) -> Result<VkPipeline> {
let mut pipeline = VkPipeline(std::ptr::null_mut());
let result = unsafe {
vkCreateComputePipelines(
self.device,
std::ptr::null_mut(), 1, std::ptr::null(), std::ptr::null(), &mut pipeline,
)
};
if result != VK_SUCCESS {
return Err(TrustformersError::runtime_error(format!(
"Failed to create Vulkan {:?} pipeline: {}",
operation, result
)).into());
}
tracing::info!("Created Vulkan {:?} compute pipeline", operation);
Ok(pipeline)
}
pub fn execute_compute(
&self,
pipeline: VkPipeline,
group_count_x: u32,
group_count_y: u32,
group_count_z: u32,
) -> Result<()> {
let result = unsafe {
vkBeginCommandBuffer(self.command_buffer, std::ptr::null())
};
if result != VK_SUCCESS {
return Err(TrustformersError::runtime_error(format!(
"Failed to begin command buffer: {}",
result
)).into());
}
unsafe {
vkCmdBindPipeline(
self.command_buffer,
VK_PIPELINE_BIND_POINT_COMPUTE,
pipeline,
);
}
unsafe {
vkCmdDispatch(self.command_buffer, group_count_x, group_count_y, group_count_z);
}
let result = unsafe { vkEndCommandBuffer(self.command_buffer) };
if result != VK_SUCCESS {
return Err(TrustformersError::runtime_error(format!(
"Failed to end command buffer: {}",
result
)).into());
}
let result = unsafe {
vkQueueSubmit(
self.queue,
1, std::ptr::null(), std::ptr::null_mut(), )
};
if result != VK_SUCCESS {
return Err(TrustformersError::runtime_error(format!(
"Failed to submit to queue: {}",
result
)).into());
}
let result = unsafe { vkQueueWaitIdle(self.queue) };
if result != VK_SUCCESS {
return Err(TrustformersError::runtime_error(format!(
"Failed to wait for queue idle: {}",
result
)).into());
}
Ok(())
}
pub fn create_buffer(&self, size: u64, usage: u32) -> Result<(VkBuffer, VkDeviceMemory)> {
let mut buffer = VkBuffer(std::ptr::null_mut());
let mut memory = VkDeviceMemory(std::ptr::null_mut());
let result = unsafe {
vkCreateBuffer(
self.device,
std::ptr::null(), std::ptr::null(), &mut buffer,
)
};
if result != VK_SUCCESS {
return Err(TrustformersError::runtime_error(format!(
"Failed to create Vulkan buffer: {}",
result
)).into());
}
let result = unsafe {
vkAllocateMemory(
self.device,
std::ptr::null(), std::ptr::null(), &mut memory,
)
};
if result != VK_SUCCESS {
unsafe {
vkDestroyBuffer(self.device, buffer, std::ptr::null());
}
return Err(TrustformersError::runtime_error(format!(
"Failed to allocate Vulkan memory: {}",
result
)).into());
}
let result = unsafe { vkBindBufferMemory(self.device, buffer, memory, 0) };
if result != VK_SUCCESS {
unsafe {
vkFreeMemory(self.device, memory, std::ptr::null());
vkDestroyBuffer(self.device, buffer, std::ptr::null());
}
return Err(TrustformersError::runtime_error(format!(
"Failed to bind buffer memory: {}",
result
)).into());
}
Ok((buffer, memory))
}
pub fn map_memory(&self, memory: VkDeviceMemory, size: u64) -> Result<*mut c_void> {
let mut data: *mut c_void = std::ptr::null_mut();
let result = unsafe {
vkMapMemory(self.device, memory, 0, size, 0, &mut data)
};
if result != VK_SUCCESS {
return Err(TrustformersError::runtime_error(format!(
"Failed to map Vulkan memory: {}",
result
)).into());
}
Ok(data)
}
pub fn unmap_memory(&self, memory: VkDeviceMemory) {
unsafe {
vkUnmapMemory(self.device, memory);
}
}
pub fn destroy_buffer(&self, buffer: VkBuffer, memory: VkDeviceMemory) {
unsafe {
vkDestroyBuffer(self.device, buffer, std::ptr::null());
vkFreeMemory(self.device, memory, std::ptr::null());
}
}
pub fn destroy_pipeline(&self, pipeline: VkPipeline) {
unsafe {
vkDestroyPipeline(self.device, pipeline, std::ptr::null());
}
}
pub fn get_device(&self) -> VkDevice {
self.device
}
pub fn get_queue(&self) -> VkQueue {
self.queue
}
pub fn get_command_buffer(&self) -> VkCommandBuffer {
self.command_buffer
}
}
#[cfg(target_os = "android")]
impl Drop for VulkanComputeContext {
fn drop(&mut self) {
unsafe {
vkDeviceWaitIdle(self.device);
vkDestroyDevice(self.device, std::ptr::null());
vkDestroyInstance(self.instance, std::ptr::null());
}
tracing::info!("Vulkan context destroyed");
}
}
#[derive(Debug, Clone, Copy)]
pub enum ComputeOperation {
Conv2D,
ReLU,
MatMul,
Add,
Pool2D,
}
pub fn is_vulkan_available() -> bool {
#[cfg(target_os = "android")]
{
true
}
#[cfg(not(target_os = "android"))]
{
false
}
}
pub mod utils {
use super::*;
pub fn vk_result_to_string(result: i32) -> &'static str {
match result {
VK_SUCCESS => "Success",
VK_NOT_READY => "Not ready",
VK_TIMEOUT => "Timeout",
VK_ERROR_OUT_OF_HOST_MEMORY => "Out of host memory",
VK_ERROR_OUT_OF_DEVICE_MEMORY => "Out of device memory",
VK_ERROR_INITIALIZATION_FAILED => "Initialization failed",
VK_ERROR_DEVICE_LOST => "Device lost",
VK_ERROR_MEMORY_MAP_FAILED => "Memory map failed",
_ => "Unknown error",
}
}
pub fn vk_is_success(result: i32) -> bool {
result == VK_SUCCESS
}
}
#[cfg(not(target_os = "android"))]
pub struct VulkanComputeContext;
#[cfg(not(target_os = "android"))]
impl VulkanComputeContext {
pub fn new() -> Result<Self> {
Err(TrustformersError::runtime_error(
"Vulkan is only available on Android".into(),
))
}
}
#[cfg(target_os = "android")]
unsafe impl Send for VkInstance {}
#[cfg(target_os = "android")]
unsafe impl Sync for VkInstance {}
#[cfg(target_os = "android")]
unsafe impl Send for VkDevice {}
#[cfg(target_os = "android")]
unsafe impl Sync for VkDevice {}
#[cfg(target_os = "android")]
unsafe impl Send for VkQueue {}
#[cfg(target_os = "android")]
unsafe impl Sync for VkQueue {}
#[cfg(test)]
mod tests {
use super::*;
use super::utils::*;
#[test]
fn test_vulkan_constants() {
assert_eq!(VK_SUCCESS, 0);
assert_eq!(VK_PIPELINE_BIND_POINT_COMPUTE, 1);
}
#[test]
fn test_result_handling() {
assert!(vk_is_success(VK_SUCCESS).into());
assert!(!vk_is_success(VK_ERROR_OUT_OF_HOST_MEMORY));
assert_eq!(vk_result_to_string(VK_SUCCESS), "Success");
assert_eq!(vk_result_to_string(VK_ERROR_OUT_OF_HOST_MEMORY), "Out of host memory");
}
#[test]
fn test_availability() {
let _available = is_vulkan_available();
}
#[cfg(target_os = "android")]
#[test]
fn test_context_creation() {
let context = VulkanComputeContext::new();
if context.is_err() {
return;
}
let context = context.expect("operation failed in test");
assert!(!context.get_device().0.is_null());
assert!(!context.get_queue().0.is_null());
}
}