use std::os::raw::{c_char, c_float, c_int, c_uint, c_void};
use trustformers_core::Result;
#[cfg(target_os = "android")]
#[derive(Debug, Clone, Copy)]
pub enum AndroidGPUBackend {
OpenGLES,
Vulkan,
}
#[cfg(target_os = "android")]
#[repr(C)]
pub struct EGLDisplay(*mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
pub struct EGLContext(*mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
pub struct EGLSurface(*mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkInstance(*mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkDevice(*mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkPhysicalDevice(*mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkQueue(*mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkCommandBuffer(*mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkBuffer(*mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkDeviceMemory(*mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkDescriptorSet(*mut c_void);
#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkPipeline(*mut c_void);
#[cfg(target_os = "android")]
pub struct AndroidGPUComputeState {
pub backend: AndroidGPUBackend,
pub egl_display: Option<EGLDisplay>,
pub egl_context: Option<EGLContext>,
pub egl_surface: Option<EGLSurface>,
pub compute_program: Option<u32>,
pub vk_instance: Option<VkInstance>,
pub vk_device: Option<VkDevice>,
pub vk_physical_device: Option<VkPhysicalDevice>,
pub vk_queue: Option<VkQueue>,
pub vk_command_buffer: Option<VkCommandBuffer>,
pub vk_conv2d_pipeline: Option<VkPipeline>,
pub vk_relu_pipeline: Option<VkPipeline>,
pub vk_matmul_pipeline: Option<VkPipeline>,
}
#[cfg(target_os = "android")]
extern "C" {
pub fn eglGetDisplay(display_id: *mut c_void) -> EGLDisplay;
pub fn eglInitialize(display: EGLDisplay, major: *mut c_int, minor: *mut c_int) -> c_int;
pub fn eglCreateContext(
display: EGLDisplay,
config: *mut c_void,
share_context: EGLContext,
attrib_list: *const c_int,
) -> EGLContext;
pub fn eglMakeCurrent(
display: EGLDisplay,
draw: EGLSurface,
read: EGLSurface,
context: EGLContext,
) -> c_int;
pub fn eglTerminate(display: EGLDisplay) -> c_int;
pub fn eglDestroyContext(display: EGLDisplay, context: EGLContext) -> c_int;
pub fn glCreateProgram() -> c_uint;
pub fn glCreateShader(shader_type: c_uint) -> c_uint;
pub fn glShaderSource(
shader: c_uint,
count: c_int,
string: *const *const c_char,
length: *const c_int,
);
pub fn glCompileShader(shader: c_uint);
pub fn glAttachShader(program: c_uint, shader: c_uint);
pub fn glLinkProgram(program: c_uint);
pub fn glUseProgram(program: c_uint);
pub fn glDeleteProgram(program: c_uint);
pub fn glDeleteShader(shader: c_uint);
pub fn glGenBuffers(n: c_int, buffers: *mut c_uint);
pub fn glBindBuffer(target: c_uint, buffer: c_uint);
pub fn glBufferData(target: c_uint, size: isize, data: *const c_void, usage: c_uint);
pub fn glBindBufferBase(target: c_uint, index: c_uint, buffer: c_uint);
pub fn glDispatchCompute(num_groups_x: c_uint, num_groups_y: c_uint, num_groups_z: c_uint);
pub fn glMemoryBarrier(barriers: c_uint);
pub fn glDeleteBuffers(n: c_int, buffers: *const c_uint);
}
#[cfg(target_os = "android")]
extern "C" {
pub fn vkCreateInstance(
create_info: *const c_void,
allocator: *const c_void,
instance: *mut VkInstance,
) -> i32;
pub fn vkEnumeratePhysicalDevices(
instance: VkInstance,
device_count: *mut u32,
physical_devices: *mut VkPhysicalDevice,
) -> i32;
pub fn vkGetPhysicalDeviceProperties(
physical_device: VkPhysicalDevice,
properties: *mut c_void,
);
pub fn vkCreateDevice(
physical_device: VkPhysicalDevice,
create_info: *const c_void,
allocator: *const c_void,
device: *mut VkDevice,
) -> i32;
pub fn vkGetDeviceQueue(
device: VkDevice,
queue_family_index: u32,
queue_index: u32,
queue: *mut VkQueue,
);
pub fn vkCreateBuffer(
device: VkDevice,
create_info: *const c_void,
allocator: *const c_void,
buffer: *mut VkBuffer,
) -> i32;
pub fn vkAllocateMemory(
device: VkDevice,
allocate_info: *const c_void,
allocator: *const c_void,
memory: *mut VkDeviceMemory,
) -> i32;
pub fn vkBindBufferMemory(
device: VkDevice,
buffer: VkBuffer,
memory: VkDeviceMemory,
memory_offset: u64,
) -> i32;
pub fn vkMapMemory(
device: VkDevice,
memory: VkDeviceMemory,
offset: u64,
size: u64,
flags: u32,
data: *mut *mut c_void,
) -> i32;
pub fn vkUnmapMemory(device: VkDevice, memory: VkDeviceMemory);
pub fn vkCreateComputePipelines(
device: VkDevice,
pipeline_cache: *mut c_void,
create_info_count: u32,
create_infos: *const c_void,
allocator: *const c_void,
pipelines: *mut VkPipeline,
) -> i32;
pub fn vkAllocateCommandBuffers(
device: VkDevice,
allocate_info: *const c_void,
command_buffers: *mut VkCommandBuffer,
) -> i32;
pub fn vkBeginCommandBuffer(command_buffer: VkCommandBuffer, begin_info: *const c_void) -> i32;
pub fn vkCmdBindPipeline(
command_buffer: VkCommandBuffer,
pipeline_bind_point: u32,
pipeline: VkPipeline,
);
pub fn vkCmdBindDescriptorSets(
command_buffer: VkCommandBuffer,
pipeline_bind_point: u32,
layout: *mut c_void,
first_set: u32,
descriptor_set_count: u32,
descriptor_sets: *const VkDescriptorSet,
dynamic_offset_count: u32,
dynamic_offsets: *const u32,
);
pub fn vkCmdDispatch(
command_buffer: VkCommandBuffer,
group_count_x: u32,
group_count_y: u32,
group_count_z: u32,
);
pub fn vkEndCommandBuffer(command_buffer: VkCommandBuffer) -> i32;
pub fn vkQueueSubmit(
queue: VkQueue,
submit_count: u32,
submits: *const c_void,
fence: *mut c_void,
) -> i32;
pub fn vkQueueWaitIdle(queue: VkQueue) -> i32;
pub fn vkDestroyDevice(device: VkDevice, allocator: *const c_void);
pub fn vkDestroyInstance(instance: VkInstance, allocator: *const c_void);
}
#[cfg(target_os = "android")]
impl AndroidGPUComputeState {
pub fn new(backend: AndroidGPUBackend) -> Result<Self> {
Ok(Self {
backend,
egl_display: None,
egl_context: None,
egl_surface: None,
compute_program: None,
vk_instance: None,
vk_device: None,
vk_physical_device: None,
vk_queue: None,
vk_command_buffer: None,
vk_conv2d_pipeline: None,
vk_relu_pipeline: None,
vk_matmul_pipeline: None,
})
}
pub fn is_initialized(&self) -> bool {
match self.backend {
AndroidGPUBackend::OpenGLES => self.egl_display.is_some() && self.egl_context.is_some(),
AndroidGPUBackend::Vulkan => self.vk_instance.is_some() && self.vk_device.is_some(),
}
}
pub fn get_backend(&self) -> AndroidGPUBackend {
self.backend
}
pub fn supports_compute(&self) -> bool {
match self.backend {
AndroidGPUBackend::OpenGLES => self.compute_program.is_some(),
AndroidGPUBackend::Vulkan => {
self.vk_conv2d_pipeline.is_some()
|| self.vk_relu_pipeline.is_some()
|| self.vk_matmul_pipeline.is_some()
},
}
}
pub fn cleanup(&mut self) {
match self.backend {
AndroidGPUBackend::OpenGLES => {
if let Some(program) = self.compute_program {
unsafe {
glDeleteProgram(program);
}
}
if let (Some(display), Some(context)) = (self.egl_display, self.egl_context) {
unsafe {
eglDestroyContext(display, context);
eglTerminate(display);
}
}
},
AndroidGPUBackend::Vulkan => {
if let Some(device) = self.vk_device {
unsafe {
vkDestroyDevice(device, std::ptr::null());
}
}
if let Some(instance) = self.vk_instance {
unsafe {
vkDestroyInstance(instance, std::ptr::null());
}
}
},
}
}
}
#[cfg(target_os = "android")]
impl Drop for AndroidGPUComputeState {
fn drop(&mut self) {
self.cleanup();
}
}
#[cfg(not(target_os = "android"))]
#[derive(Debug, Clone, Copy)]
pub enum AndroidGPUBackend {
OpenGLES,
Vulkan,
}
#[cfg(not(target_os = "android"))]
pub struct AndroidGPUComputeState {
backend: AndroidGPUBackend,
}
#[cfg(not(target_os = "android"))]
impl AndroidGPUComputeState {
pub fn new(backend: AndroidGPUBackend) -> Result<Self> {
Ok(Self { backend })
}
pub fn is_initialized(&self) -> bool {
false
}
pub fn get_backend(&self) -> AndroidGPUBackend {
self.backend
}
pub fn supports_compute(&self) -> bool {
false
}
pub fn cleanup(&mut self) {
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_android_gpu_backend_selection() {
let opengl_state = AndroidGPUComputeState::new(AndroidGPUBackend::OpenGLES);
assert!(opengl_state.is_ok());
let vulkan_state = AndroidGPUComputeState::new(AndroidGPUBackend::Vulkan);
assert!(vulkan_state.is_ok());
}
#[test]
fn test_gpu_state_default_state() {
let state = AndroidGPUComputeState::new(AndroidGPUBackend::OpenGLES)
.expect("operation failed in test");
assert!(!state.is_initialized());
assert!(!state.supports_compute());
match state.get_backend() {
AndroidGPUBackend::OpenGLES => assert!(true),
AndroidGPUBackend::Vulkan => panic!("Expected OpenGLES backend"),
}
}
#[test]
fn test_vulkan_backend_properties() {
let state = AndroidGPUComputeState::new(AndroidGPUBackend::Vulkan)
.expect("operation failed in test");
assert!(!state.is_initialized());
assert!(!state.supports_compute());
match state.get_backend() {
AndroidGPUBackend::Vulkan => assert!(true),
AndroidGPUBackend::OpenGLES => panic!("Expected Vulkan backend"),
}
}
#[test]
fn test_cleanup_safety() {
let mut state = AndroidGPUComputeState::new(AndroidGPUBackend::OpenGLES)
.expect("operation failed in test");
state.cleanup();
state.cleanup();
}
}