#[cfg(all(target_os = "android", feature = "nnapi"))]
use crate::{MemoryOptimization, MobileConfig, MobileStats};
#[cfg(all(target_os = "android", feature = "nnapi"))]
use serde::{Deserialize, Serialize};
#[cfg(all(target_os = "android", feature = "nnapi"))]
use std::collections::HashMap;
#[cfg(all(target_os = "android", feature = "nnapi"))]
use std::time::Instant;
use trustformers_core::error::{CoreError, Result};
#[cfg(all(target_os = "android", feature = "nnapi"))]
use trustformers_core::Tensor;
use trustformers_core::TrustformersError;
#[cfg(all(target_os = "android", feature = "nnapi"))]
use jni::{
objects::{JByteArray, JClass, JObject, JString},
sys::{jbyteArray, jlong, jobject},
JNIEnv, JavaVM,
};
#[cfg(all(target_os = "android", feature = "nnapi"))]
/// Android NNAPI inference engine.
///
/// Owns the engine configuration, opaque model/compilation handles, runtime
/// statistics, detected device information, and an optional `JavaVM` handle
/// supplied later via `init_with_jvm`.
pub struct NNAPIEngine {
/// Validated engine configuration (see `NNAPIConfig::validate`).
config: NNAPIConfig,
/// Handle returned by `create_nnapi_model`; `None` until `load_model` runs.
model_handle: Option<usize>,
/// Execution/compilation statistics, updated on every `execute` call.
stats: NNAPIStats,
/// Device information captured once at construction time.
device_info: AndroidDeviceInfo,
/// Java VM handle; `None` until `init_with_jvm` is called.
jvm: Option<JavaVM>,
/// Handle returned by `compile_model`; `execute` requires this to be `Some`.
compilation_handle: Option<usize>,
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Configuration for the NNAPI engine; serializable so it can be passed as
/// JSON across the JNI boundary (see `Java_com_trustformers_NNAPIEngine_createEngine`).
pub struct NNAPIConfig {
/// Device types to try, in priority order. Must be non-empty.
pub preferred_devices: Vec<NNAPIDeviceType>,
/// Whether relaxed (lower-precision) computation is permitted.
pub allow_relaxed_computation: bool,
/// Whether to cache compiled models between runs.
pub enable_compilation_caching: bool,
/// Latency/throughput/power trade-off hint.
pub execution_preference: NNAPIExecutionPreference,
/// Maximum in-flight executions; validated to be in 1..=8.
pub max_concurrent_executions: usize,
/// Whether to use memory-mapped buffers for model/tensor data.
pub use_memory_mapping: bool,
/// Per-operation timeout in milliseconds; validated to be >= 100.
pub operation_timeout_ms: u32,
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
/// Kinds of compute devices NNAPI can dispatch to.
pub enum NNAPIDeviceType {
/// General-purpose CPU fallback.
CPU,
/// Graphics processor.
GPU,
/// Dedicated neural processing unit.
NPU,
/// Digital signal processor.
DSP,
/// Other dedicated ML accelerator.
Accelerator,
/// No preference — let NNAPI choose.
Any,
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
/// Execution trade-off hint, mirroring NNAPI's compilation preference
/// (low latency vs. sustained throughput vs. low power).
pub enum NNAPIExecutionPreference {
/// Minimize latency of a single inference.
FastSingleAnswer,
/// Maximize steady-state throughput over many inferences.
SustainedSpeed,
/// Minimize power draw.
LowPower,
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Snapshot of the Android device's capabilities, gathered by
/// `NNAPIEngine::detect_device_info` (currently hard-coded placeholder values).
pub struct AndroidDeviceInfo {
/// Android API level; drives compilation-caching/concurrency tuning
/// in `optimize_for_device`.
pub android_api_level: u32,
/// Device manufacturer name.
pub manufacturer: String,
/// Device model name.
pub device_model: String,
/// NNAPI-visible compute devices on this handset.
pub available_devices: Vec<NNAPIDeviceInfo>,
/// Total device memory in megabytes.
pub total_memory_mb: usize,
/// Currently available memory in megabytes; < 1024 selects LowPower mode.
pub available_memory_mb: usize,
/// Whether Vulkan is present; enables the GPU in device selection.
pub has_vulkan: bool,
/// OpenGL ES version string (e.g. "3.2").
pub opengl_es_version: String,
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Description of a single NNAPI-visible compute device.
pub struct NNAPIDeviceInfo {
/// Human-readable device name.
pub name: String,
/// Device category (CPU/GPU/NPU/...).
pub device_type: NNAPIDeviceType,
/// Device/driver version string.
pub version: String,
/// NNAPI operation names this device supports (e.g. "CONV_2D").
pub supported_operations: Vec<String>,
/// Relative performance characteristics of the device.
pub performance_info: NNAPIPerformanceInfo,
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Performance characteristics of an NNAPI device.
///
/// NOTE(review): `exec_time` and `power_usage` carry no units here; the only
/// values visible in this file are the placeholder 1.0s from
/// `detect_device_info`, so they are presumably relative factors (NNAPI-style
/// performance info) — TODO confirm against the real device-query code.
pub struct NNAPIPerformanceInfo {
/// Relative execution-time factor — units/scale unconfirmed.
pub exec_time: f32,
/// Relative power-usage factor — units/scale unconfirmed.
pub power_usage: f32,
/// Memory bandwidth in MB/s.
pub memory_bandwidth_mbps: usize,
/// Compute throughput in operations per second.
pub compute_throughput_ops: usize,
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Runtime statistics accumulated by the engine.
pub struct NNAPIStats {
/// Number of completed `execute` calls.
pub total_executions: usize,
/// Exponential moving average of execution time (alpha = 0.1); the first
/// sample seeds the average directly.
pub avg_execution_time_ms: f32,
/// Wall-clock time of the most recent model compilation.
pub compilation_time_ms: f32,
/// Memory usage in MB — never updated in this file; TODO confirm producer.
pub memory_usage_mb: usize,
/// Per-device utilization — never updated in this file; TODO confirm producer.
pub device_utilization: HashMap<String, f32>,
/// Estimated power draw in mW — never updated in this file.
pub estimated_power_mw: f32,
/// Compilation-cache hit rate — never updated in this file.
pub compilation_cache_hit_rate: f32,
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
impl NNAPIEngine {
    /// Creates an engine from a validated configuration.
    ///
    /// Device detection runs eagerly; JNI/NNAPI initialization is deferred
    /// until [`NNAPIEngine::init_with_jvm`] is called.
    ///
    /// # Errors
    /// Returns an error if the configuration fails validation or device
    /// detection fails.
    pub fn new(config: NNAPIConfig) -> Result<Self> {
        config.validate()?;
        let device_info = Self::detect_device_info()?;
        let stats = NNAPIStats::new();
        Ok(Self {
            config,
            model_handle: None,
            stats,
            device_info,
            jvm: None,
            compilation_handle: None,
        })
    }

    /// Stores the Java VM handle and initializes the NNAPI context.
    pub fn init_with_jvm(&mut self, jvm: JavaVM) -> Result<()> {
        self.jvm = Some(jvm);
        self.init_nnapi_context()?;
        Ok(())
    }

    /// Creates and compiles an NNAPI model from raw model bytes, recording
    /// the compilation wall-clock time in the engine statistics.
    pub fn load_model(&mut self, model_data: &[u8]) -> Result<()> {
        let start_time = Instant::now();
        tracing::info!("Loading NNAPI model ({} bytes)", model_data.len());
        let model_handle = self.create_nnapi_model(model_data)?;
        let compilation_handle = self.compile_model(model_handle)?;
        self.model_handle = Some(model_handle);
        self.compilation_handle = Some(compilation_handle);
        let compilation_time = start_time.elapsed().as_millis() as f32;
        self.stats.compilation_time_ms = compilation_time;
        tracing::info!(
            "NNAPI model compiled successfully in {:.2}ms on {} devices",
            compilation_time,
            self.device_info.available_devices.len()
        );
        Ok(())
    }

    /// Runs one inference over the named input tensors.
    ///
    /// # Errors
    /// Returns a runtime error when no compiled model is available
    /// (`load_model` has not completed), or propagates any failure from the
    /// execution pipeline.
    pub fn execute(&mut self, input: &HashMap<String, Tensor>) -> Result<HashMap<String, Tensor>> {
        if self.compilation_handle.is_none() {
            return Err(TrustformersError::runtime_error("NNAPI model not compiled".into()).into());
        }
        let start_time = Instant::now();
        let execution_handle = self.create_execution()?;
        self.set_input_tensors(execution_handle, input)?;
        let output_tensors = self.prepare_output_tensors()?;
        self.execute_inference(execution_handle)?;
        let results = self.get_output_tensors(execution_handle, output_tensors)?;
        self.cleanup_execution(execution_handle)?;
        let execution_time = start_time.elapsed().as_millis() as f32;
        self.stats.update_execution(execution_time);
        Ok(results)
    }

    /// Executes a batch of inputs sequentially, stopping at the first error.
    pub fn batch_execute(
        &mut self,
        inputs: &[HashMap<String, Tensor>],
    ) -> Result<Vec<HashMap<String, Tensor>>> {
        let mut results = Vec::with_capacity(inputs.len());
        for input in inputs {
            let result = self.execute(input)?;
            results.push(result);
        }
        Ok(results)
    }

    /// Returns the accumulated execution/compilation statistics.
    pub fn get_stats(&self) -> &NNAPIStats {
        &self.stats
    }

    /// Returns the detected Android device information.
    pub fn get_device_info(&self) -> &AndroidDeviceInfo {
        &self.device_info
    }

    /// Tunes the configuration for the detected device: picks preferred
    /// devices and execution preference, and scales concurrency by API level
    /// (compilation caching is only enabled on API 29+).
    pub fn optimize_for_device(&mut self) -> Result<()> {
        self.config.preferred_devices = self.select_optimal_devices();
        self.config.execution_preference = self.select_execution_preference();
        if self.device_info.android_api_level >= 30 {
            self.config.enable_compilation_caching = true;
            self.config.max_concurrent_executions = 4;
        } else if self.device_info.android_api_level >= 29 {
            self.config.enable_compilation_caching = true;
            self.config.max_concurrent_executions = 2;
        } else {
            self.config.enable_compilation_caching = false;
            self.config.max_concurrent_executions = 1;
        }
        tracing::info!(
            "Optimized NNAPI configuration for {} (API {}) with {} devices",
            self.device_info.device_model,
            self.device_info.android_api_level,
            self.device_info.available_devices.len()
        );
        Ok(())
    }

    /// Placeholder device probe: returns hard-coded values (Pixel / API 30 /
    /// a single CPU device). TODO: query real device properties via JNI and
    /// NNAPI device enumeration.
    fn detect_device_info() -> Result<AndroidDeviceInfo> {
        Ok(AndroidDeviceInfo {
            android_api_level: 30,
            manufacturer: "Google".to_string(),
            device_model: "Pixel".to_string(),
            available_devices: vec![NNAPIDeviceInfo {
                name: "CPU".to_string(),
                device_type: NNAPIDeviceType::CPU,
                version: "1.0".to_string(),
                supported_operations: vec!["CONV_2D".to_string(), "FULLY_CONNECTED".to_string()],
                performance_info: NNAPIPerformanceInfo {
                    exec_time: 1.0,
                    power_usage: 1.0,
                    memory_bandwidth_mbps: 1000,
                    compute_throughput_ops: 1000000,
                },
            }],
            total_memory_mb: 4096,
            available_memory_mb: 2048,
            has_vulkan: true,
            opengl_es_version: "3.2".to_string(),
        })
    }

    // --- Placeholder NNAPI plumbing ------------------------------------
    // The helpers below return canned handles/values; they are stubs to be
    // wired to the real NNAPI C API through JNI.

    /// Stub: would initialize the NNAPI runtime via the stored JavaVM.
    fn init_nnapi_context(&self) -> Result<()> {
        Ok(())
    }

    /// Stub: would build an `ANeuralNetworksModel` from the model bytes.
    fn create_nnapi_model(&self, _model_data: &[u8]) -> Result<usize> {
        Ok(1)
    }

    /// Stub: would compile the model for the preferred devices.
    fn compile_model(&self, _model_handle: usize) -> Result<usize> {
        Ok(1)
    }

    /// Stub: would create an execution instance from the compilation.
    fn create_execution(&self) -> Result<usize> {
        Ok(1)
    }

    /// Stub: would bind input tensors to the execution.
    fn set_input_tensors(
        &self,
        _execution_handle: usize,
        _input: &HashMap<String, Tensor>,
    ) -> Result<()> {
        Ok(())
    }

    /// Stub: would enumerate the model's output names.
    fn prepare_output_tensors(&self) -> Result<Vec<String>> {
        Ok(vec!["output".to_string()])
    }

    /// Stub: would run the compiled model.
    fn execute_inference(&self, _execution_handle: usize) -> Result<()> {
        Ok(())
    }

    /// Stub: would read back the named output tensors.
    fn get_output_tensors(
        &self,
        _execution_handle: usize,
        _output_names: Vec<String>,
    ) -> Result<HashMap<String, Tensor>> {
        Ok(HashMap::new())
    }

    /// Stub: would free the execution's native resources.
    fn cleanup_execution(&self, _execution_handle: usize) -> Result<()> {
        Ok(())
    }

    /// Ranks devices for `preferred_devices`: dedicated accelerators
    /// (NPU/DSP/Accelerator) first, then the GPU when Vulkan is present, and
    /// always the CPU as the final fallback.
    fn select_optimal_devices(&self) -> Vec<NNAPIDeviceType> {
        let mut devices = Vec::new();
        for device in &self.device_info.available_devices {
            match device.device_type {
                NNAPIDeviceType::NPU => devices.push(NNAPIDeviceType::NPU),
                NNAPIDeviceType::DSP => devices.push(NNAPIDeviceType::DSP),
                NNAPIDeviceType::Accelerator => devices.push(NNAPIDeviceType::Accelerator),
                _ => {},
            }
        }
        if self.device_info.has_vulkan {
            devices.push(NNAPIDeviceType::GPU);
        }
        // CPU is the unconditional fallback, so the returned list can never
        // be empty. (A former `if devices.is_empty() { push(Any) }` branch
        // after this push was unreachable dead code and has been removed.)
        devices.push(NNAPIDeviceType::CPU);
        devices
    }

    /// Chooses an execution preference from the device snapshot: low power
    /// when memory is tight (< 1 GiB free), sustained speed when several
    /// devices are available, fast-single-answer otherwise.
    fn select_execution_preference(&self) -> NNAPIExecutionPreference {
        if self.device_info.available_memory_mb < 1024 {
            NNAPIExecutionPreference::LowPower
        } else if self.device_info.available_devices.len() > 2 {
            NNAPIExecutionPreference::SustainedSpeed
        } else {
            NNAPIExecutionPreference::FastSingleAnswer
        }
    }
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
impl Default for NNAPIConfig {
/// Conservative defaults: any device, relaxed computation allowed,
/// compilation caching on, a single concurrent execution, memory mapping
/// enabled, and a 5-second operation timeout.
fn default() -> Self {
Self {
preferred_devices: vec![NNAPIDeviceType::Any],
allow_relaxed_computation: true,
enable_compilation_caching: true,
execution_preference: NNAPIExecutionPreference::FastSingleAnswer,
max_concurrent_executions: 1,
use_memory_mapping: true,
operation_timeout_ms: 5000,
}
}
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
impl NNAPIConfig {
    /// Builds a config-validation error with the standard E4001 context.
    /// Centralizes the boilerplate previously duplicated four times in
    /// `validate`.
    fn validation_error(message: &str) -> TrustformersError {
        TrustformersError::config_error {
            message: message.to_string(),
            context: trustformers_core::error::ErrorContext::new(
                trustformers_core::error::ErrorCode::E4001,
                "validate".to_string(),
            ),
        }
    }

    /// Checks invariants on the configuration.
    ///
    /// # Errors
    /// Fails when no preferred device is listed, concurrency is 0 or above 8,
    /// or the operation timeout is below 100 ms.
    pub fn validate(&self) -> Result<()> {
        if self.preferred_devices.is_empty() {
            return Err(Self::validation_error(
                "Must specify at least one preferred device",
            ));
        }
        if self.max_concurrent_executions == 0 {
            return Err(Self::validation_error("Concurrent executions must be > 0"));
        }
        if self.max_concurrent_executions > 8 {
            return Err(Self::validation_error(
                "Too many concurrent executions for Android",
            ));
        }
        if self.operation_timeout_ms < 100 {
            return Err(Self::validation_error("Operation timeout too short"));
        }
        Ok(())
    }

    /// Preset tuned for minimum power draw: low-power accelerators first,
    /// single execution, no memory mapping, generous 10 s timeout.
    pub fn power_optimized() -> Self {
        Self {
            preferred_devices: vec![
                NNAPIDeviceType::NPU,
                NNAPIDeviceType::DSP,
                NNAPIDeviceType::CPU,
            ],
            allow_relaxed_computation: true,
            enable_compilation_caching: true,
            execution_preference: NNAPIExecutionPreference::LowPower,
            max_concurrent_executions: 1,
            use_memory_mapping: false,
            operation_timeout_ms: 10000,
        }
    }

    /// Preset tuned for throughput: GPU first, sustained-speed preference,
    /// four concurrent executions, tight 2 s timeout.
    pub fn performance_optimized() -> Self {
        Self {
            preferred_devices: vec![
                NNAPIDeviceType::GPU,
                NNAPIDeviceType::NPU,
                NNAPIDeviceType::CPU,
            ],
            allow_relaxed_computation: true,
            enable_compilation_caching: true,
            execution_preference: NNAPIExecutionPreference::SustainedSpeed,
            max_concurrent_executions: 4,
            use_memory_mapping: true,
            operation_timeout_ms: 2000,
        }
    }
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
impl NNAPIStats {
    /// Returns zero-initialized statistics.
    fn new() -> Self {
        Self {
            total_executions: 0,
            avg_execution_time_ms: 0.0,
            compilation_time_ms: 0.0,
            memory_usage_mb: 0,
            estimated_power_mw: 0.0,
            compilation_cache_hit_rate: 0.0,
            device_utilization: HashMap::new(),
        }
    }

    /// Records one completed execution and folds its duration into an
    /// exponential moving average. The first sample seeds the average
    /// directly; later samples are blended with weight ALPHA.
    fn update_execution(&mut self, execution_time_ms: f32) {
        const ALPHA: f32 = 0.1;
        self.total_executions += 1;
        self.avg_execution_time_ms = match self.total_executions {
            1 => execution_time_ms,
            _ => ALPHA * execution_time_ms + (1.0 - ALPHA) * self.avg_execution_time_ms,
        };
    }
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
/// Maps a generic `MobileConfig` onto an `NNAPIConfig`.
///
/// The memory-optimization level selects a base preset (power-optimized for
/// Maximum, default for Balanced, performance-optimized for Minimal) with a
/// few per-level overrides; `use_fp16` then decides whether relaxed
/// computation is allowed.
pub fn mobile_config_to_nnapi(mobile_config: &MobileConfig) -> NNAPIConfig {
    let mut nnapi_config = match mobile_config.memory_optimization {
        // Tightest memory budget: power preset, strictly serial, no mmap.
        MemoryOptimization::Maximum => NNAPIConfig {
            max_concurrent_executions: 1,
            use_memory_mapping: false,
            ..NNAPIConfig::power_optimized()
        },
        // Middle ground: defaults with modest concurrency and mmap enabled.
        MemoryOptimization::Balanced => NNAPIConfig {
            execution_preference: NNAPIExecutionPreference::FastSingleAnswer,
            max_concurrent_executions: 2,
            use_memory_mapping: true,
            ..NNAPIConfig::default()
        },
        // Memory is plentiful: performance preset, concurrency follows the
        // configured thread count (at least 1).
        MemoryOptimization::Minimal => NNAPIConfig {
            max_concurrent_executions: mobile_config.num_threads.max(1),
            ..NNAPIConfig::performance_optimized()
        },
    };
    nnapi_config.allow_relaxed_computation = mobile_config.use_fp16;
    nnapi_config
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
#[no_mangle]
/// JNI entry point: parses an `NNAPIConfig` from a JSON string and returns a
/// boxed engine pointer as `jlong`, or 0 on any failure (bad Java string,
/// invalid JSON, or engine construction error). Ownership of the pointer
/// transfers to the Java caller, which is expected to release it later —
/// TODO(review): no matching destroy/free JNI function is visible here.
pub extern "system" fn Java_com_trustformers_NNAPIEngine_createEngine(
env: JNIEnv,
_class: JClass,
config_json: JString,
) -> jlong {
// Copy the Java string into Rust; any JNI failure maps to the 0 sentinel.
let config_str: String = match env.get_string(config_json) {
Ok(s) => s.into(),
Err(_) => return 0,
};
match serde_json::from_str::<NNAPIConfig>(&config_str) {
Ok(config) => match NNAPIEngine::new(config) {
// Intentionally leak the box: the raw pointer is handed to Java.
Ok(engine) => Box::into_raw(Box::new(engine)) as jlong,
Err(_) => 0,
},
Err(_) => 0,
}
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
#[no_mangle]
/// JNI entry point: loads the model bytes passed from Java into the engine
/// identified by `engine_ptr`. Returns 1 on success, 0 on failure (null
/// engine pointer, JNI array conversion failure, or load error).
pub extern "system" fn Java_com_trustformers_NNAPIEngine_loadModel(
    env: JNIEnv,
    _class: JClass,
    engine_ptr: jlong,
    model_data: jbyteArray,
) -> jlong {
    if engine_ptr == 0 {
        return 0;
    }
    // SAFETY assumption: engine_ptr came from createEngine's Box::into_raw
    // and the Java side does not use the engine concurrently — TODO confirm
    // the Java wrapper guarantees exclusive access during this call.
    let engine = unsafe { &mut *(engine_ptr as *mut NNAPIEngine) };
    // Copy the Java byte[] into a Rust Vec. (Previously the argument was
    // ignored and a 1 KiB zero buffer was loaded instead — a bug.)
    let model_bytes: Vec<u8> = match env.convert_byte_array(model_data) {
        Ok(bytes) => bytes,
        Err(_) => return 0,
    };
    match engine.load_model(&model_bytes) {
        Ok(_) => 1,
        Err(_) => 0,
    }
}
#[cfg(all(target_os = "android", feature = "nnapi"))]
#[no_mangle]
/// JNI entry point for inference. Currently a stub: `input_data` is ignored,
/// an empty input map is passed to the engine, and the result is never
/// marshalled back to Java — the function always returns null.
/// TODO(review): implement input/output marshalling between Java objects
/// and `Tensor` maps.
pub extern "system" fn Java_com_trustformers_NNAPIEngine_execute(
_env: JNIEnv,
_class: JClass,
engine_ptr: jlong,
input_data: jobject,
) -> jobject {
if engine_ptr == 0 {
return std::ptr::null_mut();
}
// SAFETY assumption: engine_ptr came from createEngine's Box::into_raw and
// is not used concurrently — TODO confirm the Java side guarantees this.
let engine = unsafe { &mut *(engine_ptr as *mut NNAPIEngine) };
// Placeholder: real input conversion from `input_data` is not implemented.
let input = HashMap::new();
match engine.execute(&input) {
Ok(_output) => {
// Output conversion not implemented yet; success still yields null.
std::ptr::null_mut() },
Err(_) => std::ptr::null_mut(),
}
}
#[cfg(not(all(target_os = "android", feature = "nnapi")))]
/// Stub engine used on non-Android targets or when the `nnapi` feature is
/// disabled; construction always fails.
pub struct NNAPIEngine;
#[cfg(not(all(target_os = "android", feature = "nnapi")))]
impl NNAPIEngine {
/// Always returns an error: NNAPI is only available on Android builds
/// with the `nnapi` feature enabled.
pub fn new(_config: ()) -> Result<Self> {
Err(TrustformersError::runtime_error("NNAPI only available on Android".into()).into())
}
}
#[cfg(test)]
mod tests {
use super::*;
// Valid defaults pass validation; an empty device list and a zero
// concurrency limit are each rejected.
#[cfg(all(target_os = "android", feature = "nnapi"))]
#[test]
fn test_nnapi_config_validation() {
let config = NNAPIConfig::default();
assert!(config.validate().is_ok());
let mut invalid_config = config.clone();
invalid_config.preferred_devices.clear();
assert!(invalid_config.validate().is_err());
invalid_config.preferred_devices.push(NNAPIDeviceType::CPU);
invalid_config.max_concurrent_executions = 0;
assert!(invalid_config.validate().is_err());
}
// The two presets expose their distinguishing knobs: low power / serial /
// no mmap vs. sustained speed / 4-way concurrency / mmap on.
#[cfg(all(target_os = "android", feature = "nnapi"))]
#[test]
fn test_optimized_configs() {
let power_config = NNAPIConfig::power_optimized();
assert_eq!(
power_config.execution_preference,
NNAPIExecutionPreference::LowPower
);
assert_eq!(power_config.max_concurrent_executions, 1);
assert!(!power_config.use_memory_mapping);
let perf_config = NNAPIConfig::performance_optimized();
assert_eq!(
perf_config.execution_preference,
NNAPIExecutionPreference::SustainedSpeed
);
assert_eq!(perf_config.max_concurrent_executions, 4);
assert!(perf_config.use_memory_mapping);
}
// MemoryOptimization::Maximum maps to the power preset (LowPower, serial,
// no mmap) and use_fp16 flows into allow_relaxed_computation.
#[cfg(all(target_os = "android", feature = "nnapi"))]
#[test]
fn test_mobile_to_nnapi_config_conversion() {
let mobile_config = crate::MobileConfig {
memory_optimization: MemoryOptimization::Maximum,
num_threads: 1,
use_fp16: true,
..Default::default()
};
let nnapi_config = mobile_config_to_nnapi(&mobile_config);
assert_eq!(
nnapi_config.execution_preference,
NNAPIExecutionPreference::LowPower
);
assert_eq!(nnapi_config.max_concurrent_executions, 1);
assert!(nnapi_config.allow_relaxed_computation);
assert!(!nnapi_config.use_memory_mapping);
}
}