trustformers-mobile 0.1.1

Mobile deployment support for TrustformeRS (iOS, Android)
Documentation
//! Android GPU Acceleration Support
//!
//! This module provides GPU compute acceleration for Android devices using
//! both Vulkan and OpenGL ES backends for neural network inference.

use std::os::raw::{c_char, c_float, c_int, c_uint, c_void};
use trustformers_core::Result;

// Android GPU Backend Selection
#[cfg(target_os = "android")]
#[derive(Debug, Clone, Copy)]
pub enum AndroidGPUBackend {
    OpenGLES,
    Vulkan,
}

// OpenGL ES structures
#[cfg(target_os = "android")]
#[repr(C)]
pub struct EGLDisplay(*mut c_void);

#[cfg(target_os = "android")]
#[repr(C)]
pub struct EGLContext(*mut c_void);

#[cfg(target_os = "android")]
#[repr(C)]
pub struct EGLSurface(*mut c_void);

// Vulkan structures
#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkInstance(*mut c_void);

#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkDevice(*mut c_void);

#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkPhysicalDevice(*mut c_void);

#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkQueue(*mut c_void);

#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkCommandBuffer(*mut c_void);

#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkBuffer(*mut c_void);

#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkDeviceMemory(*mut c_void);

#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkDescriptorSet(*mut c_void);

#[cfg(target_os = "android")]
#[repr(C)]
pub struct VkPipeline(*mut c_void);

// Android GPU Compute State
#[cfg(target_os = "android")]
pub struct AndroidGPUComputeState {
    pub backend: AndroidGPUBackend,
    // OpenGL ES state
    pub egl_display: Option<EGLDisplay>,
    pub egl_context: Option<EGLContext>,
    pub egl_surface: Option<EGLSurface>,
    pub compute_program: Option<u32>,
    // Vulkan state
    pub vk_instance: Option<VkInstance>,
    pub vk_device: Option<VkDevice>,
    pub vk_physical_device: Option<VkPhysicalDevice>,
    pub vk_queue: Option<VkQueue>,
    pub vk_command_buffer: Option<VkCommandBuffer>,
    pub vk_conv2d_pipeline: Option<VkPipeline>,
    pub vk_relu_pipeline: Option<VkPipeline>,
    pub vk_matmul_pipeline: Option<VkPipeline>,
}

// OpenGL ES C API bindings
#[cfg(target_os = "android")]
extern "C" {
    pub fn eglGetDisplay(display_id: *mut c_void) -> EGLDisplay;
    pub fn eglInitialize(display: EGLDisplay, major: *mut c_int, minor: *mut c_int) -> c_int;
    pub fn eglCreateContext(
        display: EGLDisplay,
        config: *mut c_void,
        share_context: EGLContext,
        attrib_list: *const c_int,
    ) -> EGLContext;
    pub fn eglMakeCurrent(
        display: EGLDisplay,
        draw: EGLSurface,
        read: EGLSurface,
        context: EGLContext,
    ) -> c_int;
    pub fn eglTerminate(display: EGLDisplay) -> c_int;
    pub fn eglDestroyContext(display: EGLDisplay, context: EGLContext) -> c_int;

    pub fn glCreateProgram() -> c_uint;
    pub fn glCreateShader(shader_type: c_uint) -> c_uint;
    pub fn glShaderSource(
        shader: c_uint,
        count: c_int,
        string: *const *const c_char,
        length: *const c_int,
    );
    pub fn glCompileShader(shader: c_uint);
    pub fn glAttachShader(program: c_uint, shader: c_uint);
    pub fn glLinkProgram(program: c_uint);
    pub fn glUseProgram(program: c_uint);
    pub fn glDeleteProgram(program: c_uint);
    pub fn glDeleteShader(shader: c_uint);
    pub fn glGenBuffers(n: c_int, buffers: *mut c_uint);
    pub fn glBindBuffer(target: c_uint, buffer: c_uint);
    pub fn glBufferData(target: c_uint, size: isize, data: *const c_void, usage: c_uint);
    pub fn glBindBufferBase(target: c_uint, index: c_uint, buffer: c_uint);
    pub fn glDispatchCompute(num_groups_x: c_uint, num_groups_y: c_uint, num_groups_z: c_uint);
    pub fn glMemoryBarrier(barriers: c_uint);
    pub fn glDeleteBuffers(n: c_int, buffers: *const c_uint);
}

// Vulkan C API bindings
#[cfg(target_os = "android")]
extern "C" {
    pub fn vkCreateInstance(
        create_info: *const c_void,
        allocator: *const c_void,
        instance: *mut VkInstance,
    ) -> i32;
    pub fn vkEnumeratePhysicalDevices(
        instance: VkInstance,
        device_count: *mut u32,
        physical_devices: *mut VkPhysicalDevice,
    ) -> i32;
    pub fn vkGetPhysicalDeviceProperties(
        physical_device: VkPhysicalDevice,
        properties: *mut c_void,
    );
    pub fn vkCreateDevice(
        physical_device: VkPhysicalDevice,
        create_info: *const c_void,
        allocator: *const c_void,
        device: *mut VkDevice,
    ) -> i32;
    pub fn vkGetDeviceQueue(
        device: VkDevice,
        queue_family_index: u32,
        queue_index: u32,
        queue: *mut VkQueue,
    );
    pub fn vkCreateBuffer(
        device: VkDevice,
        create_info: *const c_void,
        allocator: *const c_void,
        buffer: *mut VkBuffer,
    ) -> i32;
    pub fn vkAllocateMemory(
        device: VkDevice,
        allocate_info: *const c_void,
        allocator: *const c_void,
        memory: *mut VkDeviceMemory,
    ) -> i32;
    pub fn vkBindBufferMemory(
        device: VkDevice,
        buffer: VkBuffer,
        memory: VkDeviceMemory,
        memory_offset: u64,
    ) -> i32;
    pub fn vkMapMemory(
        device: VkDevice,
        memory: VkDeviceMemory,
        offset: u64,
        size: u64,
        flags: u32,
        data: *mut *mut c_void,
    ) -> i32;
    pub fn vkUnmapMemory(device: VkDevice, memory: VkDeviceMemory);
    pub fn vkCreateComputePipelines(
        device: VkDevice,
        pipeline_cache: *mut c_void,
        create_info_count: u32,
        create_infos: *const c_void,
        allocator: *const c_void,
        pipelines: *mut VkPipeline,
    ) -> i32;
    pub fn vkAllocateCommandBuffers(
        device: VkDevice,
        allocate_info: *const c_void,
        command_buffers: *mut VkCommandBuffer,
    ) -> i32;
    pub fn vkBeginCommandBuffer(command_buffer: VkCommandBuffer, begin_info: *const c_void) -> i32;
    pub fn vkCmdBindPipeline(
        command_buffer: VkCommandBuffer,
        pipeline_bind_point: u32,
        pipeline: VkPipeline,
    );
    pub fn vkCmdBindDescriptorSets(
        command_buffer: VkCommandBuffer,
        pipeline_bind_point: u32,
        layout: *mut c_void,
        first_set: u32,
        descriptor_set_count: u32,
        descriptor_sets: *const VkDescriptorSet,
        dynamic_offset_count: u32,
        dynamic_offsets: *const u32,
    );
    pub fn vkCmdDispatch(
        command_buffer: VkCommandBuffer,
        group_count_x: u32,
        group_count_y: u32,
        group_count_z: u32,
    );
    pub fn vkEndCommandBuffer(command_buffer: VkCommandBuffer) -> i32;
    pub fn vkQueueSubmit(
        queue: VkQueue,
        submit_count: u32,
        submits: *const c_void,
        fence: *mut c_void,
    ) -> i32;
    pub fn vkQueueWaitIdle(queue: VkQueue) -> i32;
    pub fn vkDestroyDevice(device: VkDevice, allocator: *const c_void);
    pub fn vkDestroyInstance(instance: VkInstance, allocator: *const c_void);
}

#[cfg(target_os = "android")]
impl AndroidGPUComputeState {
    pub fn new(backend: AndroidGPUBackend) -> Result<Self> {
        Ok(Self {
            backend,
            egl_display: None,
            egl_context: None,
            egl_surface: None,
            compute_program: None,
            vk_instance: None,
            vk_device: None,
            vk_physical_device: None,
            vk_queue: None,
            vk_command_buffer: None,
            vk_conv2d_pipeline: None,
            vk_relu_pipeline: None,
            vk_matmul_pipeline: None,
        })
    }

    pub fn is_initialized(&self) -> bool {
        match self.backend {
            AndroidGPUBackend::OpenGLES => self.egl_display.is_some() && self.egl_context.is_some(),
            AndroidGPUBackend::Vulkan => self.vk_instance.is_some() && self.vk_device.is_some(),
        }
    }

    pub fn get_backend(&self) -> AndroidGPUBackend {
        self.backend
    }

    pub fn supports_compute(&self) -> bool {
        match self.backend {
            AndroidGPUBackend::OpenGLES => self.compute_program.is_some(),
            AndroidGPUBackend::Vulkan => {
                self.vk_conv2d_pipeline.is_some()
                    || self.vk_relu_pipeline.is_some()
                    || self.vk_matmul_pipeline.is_some()
            },
        }
    }

    pub fn cleanup(&mut self) {
        match self.backend {
            AndroidGPUBackend::OpenGLES => {
                // Cleanup OpenGL ES resources
                if let Some(program) = self.compute_program {
                    unsafe {
                        glDeleteProgram(program);
                    }
                }
                if let (Some(display), Some(context)) = (self.egl_display, self.egl_context) {
                    unsafe {
                        eglDestroyContext(display, context);
                        eglTerminate(display);
                    }
                }
            },
            AndroidGPUBackend::Vulkan => {
                // Cleanup Vulkan resources
                if let Some(device) = self.vk_device {
                    unsafe {
                        vkDestroyDevice(device, std::ptr::null());
                    }
                }
                if let Some(instance) = self.vk_instance {
                    unsafe {
                        vkDestroyInstance(instance, std::ptr::null());
                    }
                }
            },
        }
    }
}

#[cfg(target_os = "android")]
impl Drop for AndroidGPUComputeState {
    fn drop(&mut self) {
        self.cleanup();
    }
}

// Non-Android stub implementations
#[cfg(not(target_os = "android"))]
#[derive(Debug, Clone, Copy)]
pub enum AndroidGPUBackend {
    OpenGLES,
    Vulkan,
}

#[cfg(not(target_os = "android"))]
pub struct AndroidGPUComputeState {
    backend: AndroidGPUBackend,
}

#[cfg(not(target_os = "android"))]
impl AndroidGPUComputeState {
    pub fn new(backend: AndroidGPUBackend) -> Result<Self> {
        Ok(Self { backend })
    }

    pub fn is_initialized(&self) -> bool {
        false
    }

    pub fn get_backend(&self) -> AndroidGPUBackend {
        self.backend
    }

    pub fn supports_compute(&self) -> bool {
        false
    }

    pub fn cleanup(&mut self) {
        // No-op for non-Android
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_android_gpu_backend_selection() {
        let opengl_state = AndroidGPUComputeState::new(AndroidGPUBackend::OpenGLES);
        assert!(opengl_state.is_ok());

        let vulkan_state = AndroidGPUComputeState::new(AndroidGPUBackend::Vulkan);
        assert!(vulkan_state.is_ok());
    }

    #[test]
    fn test_gpu_state_default_state() {
        let state = AndroidGPUComputeState::new(AndroidGPUBackend::OpenGLES)
            .expect("operation failed in test");
        assert!(!state.is_initialized());
        assert!(!state.supports_compute());

        match state.get_backend() {
            AndroidGPUBackend::OpenGLES => assert!(true),
            AndroidGPUBackend::Vulkan => panic!("Expected OpenGLES backend"),
        }
    }

    #[test]
    fn test_vulkan_backend_properties() {
        let state = AndroidGPUComputeState::new(AndroidGPUBackend::Vulkan)
            .expect("operation failed in test");
        assert!(!state.is_initialized());
        assert!(!state.supports_compute());

        match state.get_backend() {
            AndroidGPUBackend::Vulkan => assert!(true),
            AndroidGPUBackend::OpenGLES => panic!("Expected Vulkan backend"),
        }
    }

    #[test]
    fn test_cleanup_safety() {
        let mut state = AndroidGPUComputeState::new(AndroidGPUBackend::OpenGLES)
            .expect("operation failed in test");
        // Should not panic even without initialization
        state.cleanup();
        // Should be safe to call multiple times
        state.cleanup();
    }
}