use crate::android::device_info::{AndroidDeviceInfo, NNAPIDeviceInfo};
#[cfg(target_os = "android")]
use crate::android::gpu::{AndroidGPUBackend, AndroidGPUComputeState};
#[cfg(target_os = "android")]
use crate::android::nnapi::{
    ANeuralNetworksPerformanceInfo, ANeuralNetworks_getDevice, ANeuralNetworks_getDeviceCount,
    ANeuralNetworks_getDeviceFeatureLevel, ANeuralNetworks_getDeviceName,
    ANeuralNetworks_getDevicePerformanceInfo, ANeuralNetworks_getDeviceType, NNAPIModel,
    ANEURALNETWORKS_NO_ERROR,
};
use crate::android::nnapi::{
    ANEURALNETWORKS_DEVICE_ACCELERATOR, ANEURALNETWORKS_DEVICE_CPU, ANEURALNETWORKS_DEVICE_GPU,
};
use crate::{MobileBackend, MobileConfig, MobilePlatform, MobileStats};
#[cfg(target_os = "android")]
use std::ffi::CStr;
#[cfg(target_os = "android")]
use std::os::raw::{c_char, c_void};
// The body of this module refers to the error type as `TrustformersError`;
// alias `CoreError` so those call sites resolve.
use trustformers_core::error::{CoreError as TrustformersError, Result};
use trustformers_core::Tensor;
#[cfg(target_os = "android")]
use jni::JavaVM;
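/// On-device inference engine for Android. Dispatches model loading and
/// inference to the NNAPI, GPU (Vulkan), or CPU backend selected in the
/// [`MobileConfig`], and records per-inference latency in [`MobileStats`].
///
/// Illustrative usage (marked `ignore` because it needs an Android target
/// and a real model file; the path below is a placeholder):
///
/// ```ignore
/// let config = MobileConfig::android_optimized();
/// let mut engine = AndroidInferenceEngine::new(config)?;
/// engine.load_model("model.tflite")?;
/// let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4])?;
/// let output = engine.inference(&input)?;
/// ```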
pub struct AndroidInferenceEngine {
config: MobileConfig,
stats: MobileStats,
model_loaded: bool,
#[cfg(target_os = "android")]
nnapi_model: Option<NNAPIModel>,
#[cfg(target_os = "android")]
jvm: Option<JavaVM>,
#[cfg(target_os = "android")]
gpu_state: Option<AndroidGPUComputeState>,
}
impl AndroidInferenceEngine {
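    /// Creates an engine for `config`, rejecting configurations whose
    /// `platform` is anything other than [`MobilePlatform::Android`].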
pub fn new(config: MobileConfig) -> Result<Self> {
if config.platform != MobilePlatform::Android {
return Err(TrustformersError::config_error {
message: "Android inference engine requires Android platform configuration"
.to_string(),
context: trustformers_core::error::ErrorContext::new(
trustformers_core::error::ErrorCode::E4001,
"new".to_string(),
),
});
}
let stats = MobileStats::new(&config);
Ok(Self {
config,
stats,
model_loaded: false,
#[cfg(target_os = "android")]
nnapi_model: None,
#[cfg(target_os = "android")]
jvm: None,
#[cfg(target_os = "android")]
gpu_state: None,
})
}
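    /// Stores the hosting process's [`JavaVM`] (typically obtained in
    /// `JNI_OnLoad`) for later JNI interaction.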
#[cfg(target_os = "android")]
pub fn init_jvm(&mut self, jvm: JavaVM) {
self.jvm = Some(jvm);
}
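    /// Loads a model using the configured backend. Only the `NNAPI`, `CPU`,
    /// and `GPU` backends are supported on Android; anything else errors.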
pub fn load_model(&mut self, model_path: &str) -> Result<()> {
match self.config.backend {
MobileBackend::NNAPI => self.load_nnapi_model(model_path),
MobileBackend::CPU => self.load_cpu_model(model_path),
MobileBackend::GPU => self.load_gpu_model(model_path),
_ => Err(TrustformersError::runtime_error(format!(
"Backend {:?} not supported on Android",
self.config.backend
))),
}
}
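    /// Runs a single inference on the configured backend, recording the
    /// wall-clock latency (milliseconds) in this engine's [`MobileStats`].
    /// Fails if no model has been loaded.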
pub fn inference(&mut self, input: &Tensor) -> Result<Tensor> {
if !self.model_loaded {
            return Err(TrustformersError::runtime_error("Model not loaded".into()));
}
let start_time = std::time::Instant::now();
let result = match self.config.backend {
MobileBackend::NNAPI => self.nnapi_inference(input),
MobileBackend::CPU => self.cpu_inference(input),
MobileBackend::GPU => self.gpu_inference(input),
_ => Err(TrustformersError::runtime_error(
"Unsupported backend".into(),
)),
};
let inference_time = start_time.elapsed().as_millis() as f32;
self.stats.update_inference(inference_time);
result
}
pub fn get_stats(&self) -> &MobileStats {
&self.stats
}
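    /// Replaces the configuration and resets the accumulated statistics.
    /// The new config must still target [`MobilePlatform::Android`].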
pub fn update_config(&mut self, config: MobileConfig) -> Result<()> {
if config.platform != MobilePlatform::Android {
return Err(TrustformersError::config_error {
message: "Android inference engine requires Android platform configuration"
.to_string(),
context: trustformers_core::error::ErrorContext::new(
trustformers_core::error::ErrorCode::E4001,
"update_config".to_string(),
),
});
}
self.config = config;
self.stats = MobileStats::new(&self.config);
Ok(())
}
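    /// Detects the capabilities of the current device via
    /// [`AndroidDeviceInfo::detect`].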
pub fn check_device_capabilities() -> AndroidDeviceInfo {
AndroidDeviceInfo::detect()
}
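    /// Enumerates the NNAPI devices visible to this process. Returns an
    /// empty `Vec` when compiled for non-Android targets.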
pub fn detect_nnapi_devices() -> Vec<NNAPIDeviceInfo> {
#[cfg(target_os = "android")]
{
Self::detect_nnapi_devices_impl()
}
#[cfg(not(target_os = "android"))]
{
Vec::new()
}
}
#[cfg(target_os = "android")]
fn detect_nnapi_devices_impl() -> Vec<NNAPIDeviceInfo> {
let mut devices = Vec::new();
let mut device_count: u32 = 0;
let result = unsafe { ANeuralNetworks_getDeviceCount(&mut device_count) };
if result != ANEURALNETWORKS_NO_ERROR {
tracing::warn!("Failed to get NNAPI device count: {}", result);
return devices;
}
tracing::info!("Found {} NNAPI devices", device_count);
for device_index in 0..device_count {
if let Some(device_info) = Self::query_nnapi_device(device_index) {
devices.push(device_info);
}
}
devices
}
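    /// Queries name, type, feature level, and performance info for the NNAPI
    /// device at `device_index`, substituting conservative defaults for any
    /// individual query that fails.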
#[cfg(target_os = "android")]
fn query_nnapi_device(device_index: u32) -> Option<NNAPIDeviceInfo> {
let mut device_ptr: *mut c_void = std::ptr::null_mut();
let result = unsafe { ANeuralNetworks_getDevice(device_index, &mut device_ptr) };
if result != ANEURALNETWORKS_NO_ERROR || device_ptr.is_null() {
tracing::warn!("Failed to get NNAPI device {}: {}", device_index, result);
return None;
}
let name = {
let mut name_ptr: *const c_char = std::ptr::null();
let result = unsafe { ANeuralNetworks_getDeviceName(device_ptr, &mut name_ptr) };
if result == ANEURALNETWORKS_NO_ERROR && !name_ptr.is_null() {
unsafe { CStr::from_ptr(name_ptr) }.to_string_lossy().into_owned()
} else {
format!("Device {}", device_index)
}
};
let mut device_type: i32 = 0;
let result = unsafe { ANeuralNetworks_getDeviceType(device_ptr, &mut device_type) };
        if result != ANEURALNETWORKS_NO_ERROR {
            // Conservatively treat the device as CPU when the type query fails.
            device_type = ANEURALNETWORKS_DEVICE_CPU;
        }
        // Default to feature level 27 (Android 8.1, where NNAPI was
        // introduced) when the device query fails.
        let mut feature_level: i32 = 27;
        let result =
            unsafe { ANeuralNetworks_getDeviceFeatureLevel(device_ptr, &mut feature_level) };
        if result != ANEURALNETWORKS_NO_ERROR {
            feature_level = 27;
        }
        // 1.0 serves as the neutral default when the performance query fails.
        let mut performance_info = ANeuralNetworksPerformanceInfo {
            exec_time: 1.0,
            power_usage: 1.0,
        };
        let result =
            unsafe { ANeuralNetworks_getDevicePerformanceInfo(device_ptr, &mut performance_info) };
        if result != ANEURALNETWORKS_NO_ERROR {
            // Restore the defaults in case the failed call wrote partial data.
            performance_info = ANeuralNetworksPerformanceInfo {
                exec_time: 1.0,
                power_usage: 1.0,
            };
        }
Some(NNAPIDeviceInfo {
index: device_index,
device_ptr,
name,
device_type,
feature_level,
performance_info,
            vendor_extensions: Vec::new(),
        })
}
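    /// Picks the preferred NNAPI device: dedicated accelerators rank above
    /// GPUs, which rank above everything else; ties are broken by the
    /// device's reported relative execution time (lower is faster).
    ///
    /// ```ignore
    /// if let Some(device) = AndroidInferenceEngine::get_best_nnapi_device() {
    ///     tracing::info!("Using NNAPI device: {}", device.name);
    /// }
    /// ```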
pub fn get_best_nnapi_device() -> Option<NNAPIDeviceInfo> {
let devices = Self::detect_nnapi_devices();
if devices.is_empty() {
return None;
}
        // Lower rank is better; see the doc comment above for the ordering.
        fn type_order(device_type: i32) -> u8 {
            match device_type {
                t if t == ANEURALNETWORKS_DEVICE_ACCELERATOR => 0,
                t if t == ANEURALNETWORKS_DEVICE_GPU => 1,
                _ => 2,
            }
        }
        devices.into_iter().min_by(|a, b| {
            type_order(a.device_type)
                .cmp(&type_order(b.device_type))
                .then_with(|| {
                    a.performance_info
                        .exec_time
                        .partial_cmp(&b.performance_info.exec_time)
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
        })
}
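    /// Maps an NNAPI device-type constant to a human-readable label;
    /// unrecognized values map to `"Unknown"`.
    ///
    /// ```ignore
    /// assert_eq!(
    ///     AndroidInferenceEngine::device_type_to_string(ANEURALNETWORKS_DEVICE_GPU),
    ///     "GPU"
    /// );
    /// ```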
pub fn device_type_to_string(device_type: i32) -> &'static str {
match device_type {
t if t == ANEURALNETWORKS_DEVICE_CPU => "CPU",
t if t == ANEURALNETWORKS_DEVICE_GPU => "GPU",
t if t == ANEURALNETWORKS_DEVICE_ACCELERATOR => "Accelerator",
_ => "Unknown",
}
}
fn load_nnapi_model(&mut self, model_path: &str) -> Result<()> {
#[cfg(target_os = "android")]
{
let model = NNAPIModel::new()?;
self.nnapi_model = Some(model);
self.model_loaded = true;
tracing::info!("Loaded NNAPI model from: {}", model_path);
Ok(())
}
        #[cfg(not(target_os = "android"))]
        {
            let _ = model_path; // only used on the Android path
            tracing::warn!("NNAPI not available on non-Android platforms");
            Err(TrustformersError::runtime_error(
                "NNAPI not available".into(),
            ))
        }
}
fn load_cpu_model(&mut self, model_path: &str) -> Result<()> {
        // No weights are parsed yet; the CPU path simply marks the model as
        // available for the placeholder inference below.
        self.model_loaded = true;
tracing::info!("Loaded CPU model from: {}", model_path);
Ok(())
}
fn load_gpu_model(&mut self, model_path: &str) -> Result<()> {
#[cfg(target_os = "android")]
{
let gpu_state = AndroidGPUComputeState::new(AndroidGPUBackend::Vulkan)?;
self.gpu_state = Some(gpu_state);
self.model_loaded = true;
tracing::info!("Loaded GPU model from: {}", model_path);
Ok(())
}
        #[cfg(not(target_os = "android"))]
        {
            let _ = model_path; // only used on the Android path
            tracing::warn!("GPU compute not available on non-Android platforms");
            Err(TrustformersError::runtime_error(
                "GPU compute not available".into(),
            ))
        }
}
    fn nnapi_inference(&mut self, _input: &Tensor) -> Result<Tensor> {
#[cfg(target_os = "android")]
{
if let Some(ref _model) = self.nnapi_model {
tracing::debug!("Performing NNAPI inference");
                // Placeholder output; a real implementation would bind the
                // input tensor to the compiled model and read back its output.
                let output_data = vec![0.5f32; 1000];
                let shape = [1, 1000];
Tensor::from_vec(output_data, &shape)
} else {
Err(TrustformersError::runtime_error(
"NNAPI model not loaded".into(),
))
}
}
#[cfg(not(target_os = "android"))]
{
Err(TrustformersError::runtime_error(
"NNAPI not available".into(),
))
}
}
fn cpu_inference(&mut self, input: &Tensor) -> Result<Tensor> {
tracing::debug!("Performing CPU inference");
let input_data = input.data();
let output_data: Vec<f32> = input_data.iter().map(|x| x * 0.5).collect();
let shape = input.shape();
Tensor::from_vec(output_data, shape)
}
fn gpu_inference(&mut self, input: &Tensor) -> Result<Tensor> {
#[cfg(target_os = "android")]
{
if let Some(ref _gpu_state) = self.gpu_state {
tracing::debug!("Performing GPU inference");
let input_data = input.data();
let output_data: Vec<f32> = input_data.iter().map(|x| x * 2.0).collect();
let shape = input.shape();
Tensor::from_vec(output_data, shape)
} else {
Err(TrustformersError::runtime_error(
"GPU state not initialized".into(),
))
}
}
        #[cfg(not(target_os = "android"))]
        {
            let _ = input; // only used on the Android path
            Err(TrustformersError::runtime_error(
                "GPU compute not available".into(),
            ))
        }
}
pub fn is_model_loaded(&self) -> bool {
self.model_loaded
}
pub fn get_config(&self) -> &MobileConfig {
&self.config
}
pub fn reset_stats(&mut self) {
self.stats = MobileStats::new(&self.config);
}
}
impl Drop for AndroidInferenceEngine {
fn drop(&mut self) {
#[cfg(target_os = "android")]
if let Some(ref _model) = self.nnapi_model {
tracing::debug!("Cleaning up NNAPI model resources");
}
#[cfg(target_os = "android")]
if let Some(ref mut gpu_state) = self.gpu_state {
gpu_state.cleanup();
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_android_inference_engine_creation() {
let config = MobileConfig::android_optimized();
let engine = AndroidInferenceEngine::new(config);
assert!(engine.is_ok());
}
#[test]
fn test_wrong_platform_config() {
let mut config = MobileConfig::android_optimized();
config.platform = MobilePlatform::iOS;
let engine = AndroidInferenceEngine::new(config);
assert!(engine.is_err());
}
#[test]
fn test_device_type_string_conversion() {
assert_eq!(
AndroidInferenceEngine::device_type_to_string(ANEURALNETWORKS_DEVICE_CPU),
"CPU"
);
assert_eq!(
AndroidInferenceEngine::device_type_to_string(ANEURALNETWORKS_DEVICE_GPU),
"GPU"
);
assert_eq!(
AndroidInferenceEngine::device_type_to_string(ANEURALNETWORKS_DEVICE_ACCELERATOR),
"Accelerator"
);
assert_eq!(AndroidInferenceEngine::device_type_to_string(-1), "Unknown");
}
#[test]
fn test_engine_state_management() {
let config = MobileConfig::android_optimized();
        let mut engine = AndroidInferenceEngine::new(config).expect("engine creation should succeed");
assert!(!engine.is_model_loaded());
assert_eq!(engine.get_config().platform, MobilePlatform::Android);
engine.reset_stats();
        assert_eq!(engine.get_stats().total_inferences, 0);
}
#[test]
fn test_cpu_model_loading() {
let config = MobileConfig::android_optimized();
        let mut engine = AndroidInferenceEngine::new(config).expect("engine creation should succeed");
let result = engine.load_cpu_model("test_model.tflite");
assert!(result.is_ok());
assert!(engine.is_model_loaded());
}
#[test]
fn test_cpu_inference() {
let config = MobileConfig::android_optimized();
        let mut engine = AndroidInferenceEngine::new(config).expect("engine creation should succeed");
        engine
            .load_cpu_model("test_model.tflite")
            .expect("CPU model loading should succeed");
        let input_data = vec![1.0, 2.0, 3.0, 4.0];
        let input_tensor =
            Tensor::from_vec(input_data.clone(), &[4]).expect("tensor creation should succeed");
let result = engine.cpu_inference(&input_tensor);
assert!(result.is_ok());
        let output = result.expect("inference should succeed");
let output_data = output.data();
for (i, &value) in output_data.iter().enumerate() {
assert_eq!(value, input_data[i] * 0.5);
}
}
#[test]
fn test_config_update() {
let config = MobileConfig::android_optimized();
        let mut engine = AndroidInferenceEngine::new(config).expect("engine creation should succeed");
let mut new_config = MobileConfig::android_optimized();
new_config.max_memory_mb = 2048;
let result = engine.update_config(new_config);
assert!(result.is_ok());
assert_eq!(engine.get_config().max_memory_mb, 2048);
}
#[test]
fn test_inference_without_model() {
let config = MobileConfig::android_optimized();
        let mut engine = AndroidInferenceEngine::new(config).expect("engine creation should succeed");
        let input_data = vec![1.0, 2.0, 3.0, 4.0];
        let input_tensor = Tensor::from_vec(input_data, &[4]).expect("tensor creation should succeed");
let result = engine.inference(&input_tensor);
assert!(result.is_err());
}
}