use super::execution_provider::ExecutionProviderKind;
use super::session::{ONNXSession, SessionOptions};
use crate::device::capabilities::{detect_capabilities, ThermalState};
use crate::ir::{Envelope, EnvelopeKind};
use crate::runtime_adapter::tensor_utils::{envelope_to_tensors, tensors_to_envelope};
use crate::runtime_adapter::{
AdapterError, AdapterResult, ModelMetadata, RuntimeAdapter, RuntimeAdapterExt,
};
use std::collections::HashMap;
use std::path::Path;
/// The `onnx-mobile` runtime adapter.
///
/// Holds every loaded model twice over (its [`ModelMetadata`] and its live
/// [`ONNXSession`], both keyed by model id), remembers which model
/// `execute` should target, and tracks the device conditions that feed
/// throttling decisions.
pub struct ONNXMobileRuntimeAdapter {
    /// Metadata for each loaded model, keyed by model id.
    models: HashMap<String, ModelMetadata>,
    /// Inference session for each loaded model, keyed by the same model id.
    sessions: HashMap<String, ONNXSession>,
    /// Id of the model that `execute` runs; `None` until a model is loaded.
    current_model: Option<String>,

    /// NNAPI availability, captured from `detect_capabilities()` at construction.
    nnapi_available: bool,
    /// GPU availability, captured from `detect_capabilities()` at construction.
    gpu_available: bool,

    /// Battery level (0..=100 when set via the clamping setter).
    battery_level: u8,
    /// Last reported device thermal state.
    thermal_state: ThermalState,
}
impl ONNXMobileRuntimeAdapter {
pub fn new() -> Self {
let caps = detect_capabilities();
Self {
models: HashMap::new(),
sessions: HashMap::new(),
current_model: None,
nnapi_available: caps.has_nnapi,
gpu_available: caps.has_gpu,
battery_level: 100, thermal_state: ThermalState::Normal,
}
}
pub fn with_conditions(battery_level: u8, thermal_state: ThermalState) -> Self {
let caps = detect_capabilities();
Self {
models: HashMap::new(),
sessions: HashMap::new(),
current_model: None,
nnapi_available: caps.has_nnapi,
gpu_available: caps.has_gpu,
battery_level,
thermal_state,
}
}
pub fn has_nnapi(&self) -> bool {
self.nnapi_available
}
pub fn has_gpu(&self) -> bool {
self.gpu_available
}
pub fn battery_level(&self) -> u8 {
self.battery_level
}
pub fn set_battery_level(&mut self, level: u8) {
self.battery_level = level.min(100);
}
pub fn thermal_state(&self) -> ThermalState {
self.thermal_state
}
pub fn set_thermal_state(&mut self, state: ThermalState) {
self.thermal_state = state;
}
fn validate_model_file(&self, model_path: &str) -> AdapterResult<()> {
let path = Path::new(model_path);
if !path.exists() {
return Err(AdapterError::ModelNotFound(format!(
"Model file not found: {}",
model_path
)));
}
if !path.is_file() {
return Err(AdapterError::InvalidInput(format!(
"Path is not a file: {}",
model_path
)));
}
if let Some(ext) = path.extension() {
if ext != "onnx" && ext != "ONNX" {
}
}
Ok(())
}
fn extract_model_id(&self, path: &str) -> String {
Path::new(path)
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("unknown")
.to_string()
}
pub fn should_throttle(&self) -> bool {
if self.battery_level < 20 {
return true;
}
matches!(
self.thermal_state,
ThermalState::Hot | ThermalState::Critical
)
}
#[allow(dead_code)]
fn simulate_inference(&self, input: &Envelope) -> Envelope {
let output_text = if self.should_throttle() {
match &input.kind {
EnvelopeKind::Audio(_) => "onnx-mobile-throttled-transcribed text".to_string(),
EnvelopeKind::Text(text) => format!("onnx-mobile-throttled-{}-output", text),
EnvelopeKind::Embedding(_) => "onnx-mobile-throttled-similarity result".to_string(),
}
} else {
match &input.kind {
EnvelopeKind::Audio(_) => "onnx-mobile-transcribed text".to_string(),
EnvelopeKind::Text(text) => format!("onnx-mobile-{}-output", text),
EnvelopeKind::Embedding(_) => "onnx-mobile-similarity result".to_string(),
}
};
Envelope::new(EnvelopeKind::Text(output_text))
}
fn real_inference(&self, session: &ONNXSession, input: &Envelope) -> AdapterResult<Envelope> {
let input_shapes: Vec<Vec<i64>> = session.input_shapes().to_vec();
let input_names: Vec<String> = session.input_names().to_vec();
let input_tensors = envelope_to_tensors(input, &input_shapes, &input_names)?;
let output_tensors = session.run(input_tensors).map_err(|e| {
AdapterError::InferenceFailed(format!("ONNX Runtime inference failed: {}", e))
})?;
let output_names: Vec<String> = session.output_names().to_vec();
let output = tensors_to_envelope(&output_tensors, &output_names)?;
Ok(output)
}
}
impl Default for ONNXMobileRuntimeAdapter {
fn default() -> Self {
Self::new()
}
}
impl RuntimeAdapter for ONNXMobileRuntimeAdapter {
    /// Adapter identifier (also recorded as `runtime_type` in model metadata).
    fn name(&self) -> &str {
        "onnx-mobile"
    }

    /// File-name suffixes this adapter advertises as loadable.
    fn supported_formats(&self) -> Vec<&'static str> {
        vec!["onnx", "onnx.gz", "onnx.quantized"]
    }

    /// Loads the model at `path` and makes it the current model for `execute`.
    ///
    /// The model id is derived from the file stem of `path`. If a model with
    /// that id is already loaded, no new session is built; the existing model
    /// simply becomes current. A warning is logged, and it is more specific
    /// when the existing model came from a different path (stem collision).
    ///
    /// # Errors
    /// Propagates validation errors for `path` and any error from `ONNXSession::build`.
    fn load_model(&mut self, path: &str) -> AdapterResult<()> {
        self.validate_model_file(path)?;
        let model_id = self.extract_model_id(path);

        if let Some(existing) = self.models.get(&model_id) {
            if existing.model_path == path {
                log::warn!("Model '{}' is already loaded, skipping reload", model_id);
            } else {
                log::warn!(
                    "Model id '{}' is already loaded from '{}'; not loading '{}'",
                    model_id,
                    existing.model_path,
                    path
                );
            }
            // The caller asked to use this model, so make it current. Otherwise
            // `execute` would keep running whichever model was loaded last.
            self.current_model = Some(model_id);
            return Ok(());
        }

        // NOTE(review): always uses the CPU provider; `nnapi_available` /
        // `gpu_available` are not consulted here.
        let session =
            ONNXSession::build(path, ExecutionProviderKind::Cpu, SessionOptions::default())?;

        let input_shapes = session.input_shapes();
        let output_shapes = session.output_shapes();
        let input_names = session.input_names();
        let output_names = session.output_names();

        // Pair each tensor name with its shape. `zip` stops at the shorter
        // list, so names without a shape are skipped.
        // NOTE(review): `as u64` wraps negative dimensions to huge values. If
        // ONNXSession reports dynamic axes as negative numbers (e.g. -1), confirm
        // how ModelMetadata consumers expect them to be represented.
        let mut input_schema = HashMap::new();
        for (name, shape) in input_names.iter().zip(input_shapes.iter()) {
            input_schema.insert(name.clone(), shape.iter().map(|&s| s as u64).collect());
        }
        let mut output_schema = HashMap::new();
        for (name, shape) in output_names.iter().zip(output_shapes.iter()) {
            output_schema.insert(name.clone(), shape.iter().map(|&s| s as u64).collect());
        }

        let metadata = ModelMetadata {
            model_id: model_id.clone(),
            version: "1.0.0".to_string(),
            runtime_type: "onnx-mobile".to_string(),
            model_path: path.to_string(),
            input_schema,
            output_schema,
        };
        self.sessions.insert(model_id.clone(), session);
        self.models.insert(model_id.clone(), metadata);
        self.current_model = Some(model_id);
        Ok(())
    }

    /// Runs `input` through the current model (as set by `load_model`).
    ///
    /// # Errors
    /// [`AdapterError::ModelNotLoaded`] if no model is current or its session
    /// is missing. Otherwise returns whatever inference returns.
    fn execute(&self, input: &Envelope) -> AdapterResult<Envelope> {
        let model_id = self.current_model.as_ref().ok_or_else(|| {
            AdapterError::ModelNotLoaded("No model loaded. Call load_model() first.".to_string())
        })?;
        let session = self.sessions.get(model_id).ok_or_else(|| {
            AdapterError::ModelNotLoaded(format!("Session for model '{}' not found", model_id))
        })?;
        self.real_inference(session, input)
    }
}
impl RuntimeAdapterExt for ONNXMobileRuntimeAdapter {
    /// Returns `true` if metadata for `model_id` is registered.
    fn is_loaded(&self, model_id: &str) -> bool {
        self.models.contains_key(model_id)
    }

    /// Returns the metadata recorded for `model_id` when it was loaded.
    ///
    /// # Errors
    /// [`AdapterError::ModelNotLoaded`] if no such model is loaded.
    fn get_metadata(&self, model_id: &str) -> AdapterResult<&ModelMetadata> {
        self.models.get(model_id).ok_or_else(|| {
            AdapterError::ModelNotLoaded(format!("Model '{}' is not loaded", model_id))
        })
    }

    /// Runs `input` through the named model, regardless of which model is current.
    ///
    /// # Errors
    /// [`AdapterError::ModelNotLoaded`] if the model or its session is missing.
    /// Otherwise returns whatever inference returns.
    fn infer(&self, model_id: &str, input: &Envelope) -> AdapterResult<Envelope> {
        if !self.is_loaded(model_id) {
            return Err(AdapterError::ModelNotLoaded(format!(
                "Model '{}' is not loaded. Call load_model() first.",
                model_id
            )));
        }
        let session = self.sessions.get(model_id).ok_or_else(|| {
            AdapterError::ModelNotLoaded(format!("Session for model '{}' not found", model_id))
        })?;
        self.real_inference(session, input)
    }

    /// Removes the model's metadata and session. If it was the current model,
    /// no model is current afterwards.
    ///
    /// # Errors
    /// [`AdapterError::ModelNotLoaded`] if the model is not loaded; nothing is
    /// removed in that case.
    fn unload_model(&mut self, model_id: &str) -> AdapterResult<()> {
        // Removing the metadata doubles as the existence check (one lookup, not two).
        if self.models.remove(model_id).is_none() {
            return Err(AdapterError::ModelNotLoaded(format!(
                "Model '{}' is not loaded",
                model_id
            )));
        }
        self.sessions.remove(model_id);
        // Compare as &str; no temporary String allocation needed.
        if self.current_model.as_deref() == Some(model_id) {
            self.current_model = None;
        }
        Ok(())
    }

    /// Ids of all loaded models, sorted so the output is deterministic
    /// (HashMap iteration order is not).
    fn list_loaded_models(&self) -> Vec<String> {
        let mut ids: Vec<String> = self.models.keys().cloned().collect();
        ids.sort_unstable();
        ids
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Adapter in its default conditions (battery 100, normal thermal state).
    fn fresh_adapter() -> ONNXMobileRuntimeAdapter {
        ONNXMobileRuntimeAdapter::new()
    }

    /// Minimal text envelope used as inference input.
    fn text_envelope() -> Envelope {
        Envelope::new(EnvelopeKind::Text("test".to_string()))
    }

    /// Whether a default adapter throttles once switched to `state`.
    fn throttles_with(state: ThermalState) -> bool {
        let mut adapter = fresh_adapter();
        adapter.set_thermal_state(state);
        adapter.should_throttle()
    }

    #[test]
    fn test_create_adapter() {
        let adapter = fresh_adapter();
        assert!(adapter.list_loaded_models().is_empty());
    }

    #[test]
    fn test_adapter_name() {
        let adapter = fresh_adapter();
        assert_eq!(adapter.name(), "onnx-mobile");
    }

    #[test]
    fn test_supported_formats() {
        let formats = fresh_adapter().supported_formats();
        for expected in ["onnx", "onnx.gz", "onnx.quantized"] {
            assert!(formats.contains(&expected), "missing format {}", expected);
        }
    }

    #[test]
    fn test_nnapi_detection() {
        let adapter = fresh_adapter();
        let detected = crate::device::capabilities::detect_capabilities().has_nnapi;
        assert_eq!(adapter.has_nnapi(), detected);
    }

    #[test]
    fn test_gpu_detection() {
        let adapter = fresh_adapter();
        let detected = crate::device::capabilities::detect_capabilities().has_gpu;
        assert_eq!(adapter.has_gpu(), detected);
    }

    #[test]
    fn test_battery_level() {
        let mut adapter = fresh_adapter();
        assert_eq!(adapter.battery_level(), 100);
        // (value written, value expected back): the second case checks clamping.
        for (written, expected) in [(50, 50), (150, 100)] {
            adapter.set_battery_level(written);
            assert_eq!(adapter.battery_level(), expected);
        }
    }

    #[test]
    fn test_thermal_state() {
        let mut adapter = fresh_adapter();
        assert_eq!(adapter.thermal_state(), ThermalState::Normal);
        adapter.set_thermal_state(ThermalState::Hot);
        assert_eq!(adapter.thermal_state(), ThermalState::Hot);
    }

    #[test]
    fn test_should_throttle_low_battery() {
        let mut adapter = fresh_adapter();
        adapter.set_battery_level(15);
        assert!(adapter.should_throttle());
    }

    #[test]
    fn test_should_throttle_hot_device() {
        assert!(throttles_with(ThermalState::Hot));
    }

    #[test]
    fn test_should_throttle_critical_device() {
        assert!(throttles_with(ThermalState::Critical));
    }

    #[test]
    fn test_should_not_throttle_normal() {
        assert!(!fresh_adapter().should_throttle());
    }

    #[test]
    fn test_load_model_not_found() {
        let mut adapter = fresh_adapter();
        let outcome = adapter.load_model("/nonexistent/model.onnx");
        assert!(matches!(outcome, Err(AdapterError::ModelNotFound(_))));
    }

    #[test]
    fn test_execute_no_model_loaded() {
        let adapter = fresh_adapter();
        let outcome = adapter.execute(&text_envelope());
        assert!(matches!(outcome, Err(AdapterError::ModelNotLoaded(_))));
    }

    #[test]
    fn test_infer_model_not_loaded() {
        let adapter = fresh_adapter();
        let outcome = adapter.infer("nonexistent-model", &text_envelope());
        assert!(matches!(outcome, Err(AdapterError::ModelNotLoaded(_))));
    }
}