use crate::{
optimization::MobileOptimizationEngine, MobileBackend, MobileConfig, MobilePlatform,
MobileStats,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::time::Instant;
use trustformers_core::errors::{invalid_input, runtime_error, Result};
use trustformers_core::Tensor;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelFormat {
SafeTensors,
PyTorch,
ONNX,
TensorFlow,
Unknown,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutionStrategy {
Sequential,
LayerParallel,
FullParallel,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionPlan {
pub strategy: ExecutionStrategy,
pub num_layers: usize,
pub batch_size: usize,
pub checkpoint_interval: usize,
}
impl ExecutionPlan {
pub fn new(strategy: ExecutionStrategy, num_layers: usize) -> Self {
Self {
strategy,
num_layers,
batch_size: 1,
checkpoint_interval: 0,
}
}
}
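// Example (a minimal sketch): a plan for a 12-layer model, processed
// sequentially one sample at a time with checkpointing disabled:
//
//     let plan = ExecutionPlan::new(ExecutionStrategy::Sequential, 12);
//     assert_eq!(plan.batch_size, 1);
//     assert_eq!(plan.checkpoint_interval, 0);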
#[derive(Debug)]
pub struct MobileInferenceEngine {
config: MobileConfig,
optimizer: MobileOptimizationEngine,
execution_plan: ExecutionPlan,
stats: MobileStats,
model_loaded: bool,
model_weights: Option<HashMap<String, Tensor>>,
cache: Option<InferenceCache>,
}
impl MobileInferenceEngine {
pub fn new(config: MobileConfig) -> Result<Self> {
config.validate()?;
let optimizer = MobileOptimizationEngine::new(config.clone())?;
let execution_plan = ExecutionPlan::new(ExecutionStrategy::Sequential, 12);
let stats = MobileStats::new(&config);
Ok(Self {
config,
optimizer,
execution_plan,
stats,
model_loaded: false,
model_weights: None,
cache: None,
})
}
pub fn load_model(&mut self, weights: HashMap<String, Tensor>) -> Result<()> {
tracing::info!("Loading model with {} parameters", weights.len());
let optimized_weights = self.optimizer.optimize_model_weights(&weights)?;
let total_params: usize =
optimized_weights.values().map(|t| t.shape().iter().product::<usize>()).sum();
let footprint = self.optimizer.estimate_memory_footprint(total_params);
if footprint.total_memory_bytes > self.config.max_memory_mb * 1024 * 1024 {
return Err(runtime_error(format!(
"Model requires {}MB but limit is {}MB",
footprint.memory_usage_mb(),
self.config.max_memory_mb
)));
}
self.model_weights = Some(optimized_weights);
self.model_loaded = true;
if self.should_use_cache() {
self.cache = Some(InferenceCache::new(self.config.max_memory_mb / 4));
}
tracing::info!(
"Model loaded successfully. Memory footprint: {:.1}MB ({:.1}% savings)",
footprint.memory_usage_mb(),
footprint.memory_savings_percent
);
Ok(())
}
pub fn load_model_from_file(&mut self, model_path: &str) -> Result<()> {
use std::fs;
let path = Path::new(model_path);
let model_data = fs::read(model_path)
.map_err(|e| runtime_error(format!("Failed to read model file: {}", e)))?;
let weights = self.parse_model_format(&model_data, path)?;
self.load_model(weights)
}
fn parse_model_format(&self, data: &[u8], path: &Path) -> Result<HashMap<String, Tensor>> {
let mut weights = HashMap::new();
let format = self.detect_model_format(data, path)?;
match format {
ModelFormat::SafeTensors => {
tracing::info!("Loading SafeTensors format model");
self.parse_safetensors(data, &mut weights)?;
},
ModelFormat::PyTorch => {
tracing::info!("Loading PyTorch format model");
self.parse_pytorch(data, &mut weights)?;
},
ModelFormat::ONNX => {
tracing::info!("Loading ONNX format model");
self.parse_onnx(data, &mut weights)?;
},
ModelFormat::TensorFlow => {
tracing::info!("Loading TensorFlow format model");
self.parse_tensorflow(data, &mut weights)?;
},
ModelFormat::Unknown => {
tracing::warn!("Unknown model format, creating placeholder weights");
self.create_placeholder_weights(&mut weights)?;
},
}
Ok(weights)
}
fn detect_model_format(&self, data: &[u8], path: &Path) -> Result<ModelFormat> {
if let Some(extension) = path.extension().and_then(|s| s.to_str()) {
match extension.to_lowercase().as_str() {
"safetensors" => return Ok(ModelFormat::SafeTensors),
"pt" | "pth" | "bin" => return Ok(ModelFormat::PyTorch),
"onnx" => return Ok(ModelFormat::ONNX),
"pb" => return Ok(ModelFormat::TensorFlow),
_ => {},
}
}
if data.len() >= 8 {
if data.starts_with(b"TFTSFT") {
return Ok(ModelFormat::SafeTensors);
}
if data.starts_with(b"\x80\x02") || data.starts_with(b"PK") {
return Ok(ModelFormat::PyTorch);
}
if data.starts_with(b"\x08\x01\x12") {
return Ok(ModelFormat::ONNX);
}
// Heuristic check for TensorFlow protobuf payloads.
if data.len() >= 14 && &data[12..14] == b"\x08\x01" {
return Ok(ModelFormat::TensorFlow);
}
}
Ok(ModelFormat::Unknown)
}
fn parse_safetensors(&self, _data: &[u8], weights: &mut HashMap<String, Tensor>) -> Result<()> {
// Stub: real SafeTensors parsing is not implemented; synthesize a
// GPT-2-sized transformer instead.
self.create_transformer_weights(weights, 768, 12, 50257)?;
Ok(())
}
fn parse_pytorch(&self, _data: &[u8], weights: &mut HashMap<String, Tensor>) -> Result<()> {
// Stub: synthesize a mid-sized transformer in place of real parsing.
self.create_transformer_weights(weights, 512, 8, 32000)?;
Ok(())
}
fn parse_onnx(&self, _data: &[u8], weights: &mut HashMap<String, Tensor>) -> Result<()> {
// Stub: synthesize a small transformer in place of real parsing.
self.create_transformer_weights(weights, 512, 6, 16000)?;
Ok(())
}
fn parse_tensorflow(&self, _data: &[u8], weights: &mut HashMap<String, Tensor>) -> Result<()> {
// Stub: synthesize a mid-sized transformer in place of real parsing.
self.create_transformer_weights(weights, 512, 8, 25000)?;
Ok(())
}
fn create_placeholder_weights(&self, weights: &mut HashMap<String, Tensor>) -> Result<()> {
weights.insert("embedding.weight".to_string(), Tensor::randn(&[1000, 512])?);
weights.insert("layer.0.weight".to_string(), Tensor::randn(&[512, 512])?);
weights.insert("layer.0.bias".to_string(), Tensor::randn(&[512])?);
Ok(())
}
fn create_transformer_weights(
&self,
weights: &mut HashMap<String, Tensor>,
hidden_size: usize,
num_layers: usize,
vocab_size: usize,
) -> Result<()> {
weights.insert(
"transformer.wte.weight".to_string(),
Tensor::randn(&[vocab_size, hidden_size])?,
);
weights.insert(
"transformer.wpe.weight".to_string(),
Tensor::randn(&[2048, hidden_size])?,
);
for layer_idx in 0..num_layers {
let prefix = format!("transformer.h.{}", layer_idx);
weights.insert(
format!("{}.attn.c_attn.weight", prefix),
Tensor::randn(&[hidden_size, 3 * hidden_size])?,
);
weights.insert(
format!("{}.attn.c_attn.bias", prefix),
Tensor::randn(&[3 * hidden_size])?,
);
weights.insert(
format!("{}.attn.c_proj.weight", prefix),
Tensor::randn(&[hidden_size, hidden_size])?,
);
weights.insert(
format!("{}.attn.c_proj.bias", prefix),
Tensor::randn(&[hidden_size])?,
);
weights.insert(
format!("{}.ln_1.weight", prefix),
Tensor::ones(&[hidden_size])?,
);
weights.insert(
format!("{}.ln_1.bias", prefix),
Tensor::zeros(&[hidden_size])?,
);
weights.insert(
format!("{}.ln_2.weight", prefix),
Tensor::ones(&[hidden_size])?,
);
weights.insert(
format!("{}.ln_2.bias", prefix),
Tensor::zeros(&[hidden_size])?,
);
let ff_size = hidden_size * 4;
weights.insert(
format!("{}.mlp.c_fc.weight", prefix),
Tensor::randn(&[hidden_size, ff_size])?,
);
weights.insert(
format!("{}.mlp.c_fc.bias", prefix),
Tensor::randn(&[ff_size])?,
);
weights.insert(
format!("{}.mlp.c_proj.weight", prefix),
Tensor::randn(&[ff_size, hidden_size])?,
);
weights.insert(
format!("{}.mlp.c_proj.bias", prefix),
Tensor::randn(&[hidden_size])?,
);
}
weights.insert(
"transformer.ln_f.weight".to_string(),
Tensor::ones(&[hidden_size])?,
);
weights.insert(
"transformer.ln_f.bias".to_string(),
Tensor::zeros(&[hidden_size])?,
);
weights.insert(
"lm_head.weight".to_string(),
Tensor::randn(&[hidden_size, vocab_size])?,
);
Ok(())
}
pub fn inference_f32(&mut self, input_data: &[f32], output_data: &mut [f32]) -> Result<usize> {
let input_tensor = Tensor::from_vec(input_data.to_vec(), &[1, input_data.len()])?;
let output_tensor = self.inference(&input_tensor)?;
let output_vec = output_tensor.data()?;
let output_size = output_vec.len().min(output_data.len());
output_data[..output_size].copy_from_slice(&output_vec[..output_size]);
Ok(output_size)
}
pub fn inference(&mut self, input: &Tensor) -> Result<Tensor> {
if !self.model_loaded {
return Err(runtime_error("Model not loaded"));
}
let start_time = Instant::now();
if let Some(ref cache) = self.cache {
if let Some(cached_result) = cache.get(input) {
let inference_time = start_time.elapsed().as_millis() as f32;
self.stats.update_inference(inference_time);
tracing::debug!("Cache hit for inference");
return Ok(cached_result);
}
}
let optimized_input = self.optimizer.optimize_tensor(input)?;
let result = match self.execution_plan.strategy {
ExecutionStrategy::Sequential => self.sequential_inference(&optimized_input),
ExecutionStrategy::LayerParallel => self.layer_parallel_inference(&optimized_input),
ExecutionStrategy::FullParallel => self.full_parallel_inference(&optimized_input),
}?;
if let Some(ref mut cache) = self.cache {
cache.put(input.clone(), result.clone());
}
let inference_time = start_time.elapsed().as_millis() as f32;
self.stats.update_inference(inference_time);
let current_memory = self.estimate_current_memory_usage();
self.stats.update_memory(current_memory);
Ok(result)
}
pub fn batch_inference(&mut self, inputs: Vec<Tensor>) -> Result<Vec<Tensor>> {
if !self.model_loaded {
return Err(runtime_error("Model not loaded"));
}
let optimized_inputs = self.optimizer.optimize_batch(&inputs)?;
let mut results = Vec::with_capacity(optimized_inputs.len());
for input in optimized_inputs {
let result = self.inference(&input)?;
results.push(result);
}
Ok(results)
}
pub fn get_stats(&self) -> &MobileStats {
&self.stats
}
pub fn get_memory_info(&self) -> MobileMemoryInfo {
let footprint = if let Some(ref weights) = self.model_weights {
let total_params: usize =
weights.values().map(|t| t.shape().iter().product::<usize>()).sum();
self.optimizer.estimate_memory_footprint(total_params)
} else {
self.optimizer.estimate_memory_footprint(0)
};
MobileMemoryInfo {
model_memory_mb: footprint.model_memory_bytes / (1024 * 1024),
runtime_memory_mb: footprint.runtime_overhead_bytes / (1024 * 1024),
total_memory_mb: footprint.total_memory_bytes / (1024 * 1024),
memory_limit_mb: self.config.max_memory_mb,
memory_savings_percent: footprint.memory_savings_percent,
cache_memory_mb: self.cache.as_ref().map(|c| c.memory_usage_mb()).unwrap_or(0),
}
}
pub fn update_config(&mut self, new_config: MobileConfig) -> Result<()> {
new_config.validate()?;
self.config = new_config.clone();
self.optimizer = MobileOptimizationEngine::new(new_config)?;
if let Some(weights) = self.model_weights.clone() {
self.load_model(weights)?;
}
Ok(())
}
pub fn set_power_mode(&mut self, power_mode: crate::optimization::PowerMode) -> Result<()> {
match power_mode {
crate::optimization::PowerMode::PowerSaving => {
// Halve the memory budget and prefer fp16 on the CPU backend.
self.config.use_fp16 = true;
self.config.max_memory_mb /= 2;
self.config.backend = MobileBackend::CPU;
},
crate::optimization::PowerMode::Balanced => {
self.config.use_fp16 = true;
},
crate::optimization::PowerMode::HighPerformance => {
// Full precision on the GPU backend.
self.config.use_fp16 = false;
self.config.backend = MobileBackend::GPU;
},
}
// Rebuild the optimizer so the new settings take effect.
self.optimizer = MobileOptimizationEngine::new(self.config.clone())?;
Ok(())
}
pub fn reduce_performance(&mut self, factor: f32) -> Result<()> {
// Scale the memory budget by `factor`, enabling cheaper settings as it shrinks.
let factor = factor.clamp(0.1, 1.0);
self.config.max_memory_mb = (self.config.max_memory_mb as f32 * factor) as usize;
if factor < 0.8 {
self.config.use_fp16 = true;
}
if factor < 0.5 {
self.config.backend = MobileBackend::CPU;
}
self.optimizer = MobileOptimizationEngine::new(self.config.clone())?;
Ok(())
}
pub fn set_batch_size(&mut self, batch_size: usize) -> Result<()> {
if batch_size == 0 {
return Err(invalid_input("Batch size must be greater than 0"));
}
let base_memory = 512;
let memory_per_batch = 64;
self.config.max_memory_mb = base_memory + (batch_size - 1) * memory_per_batch;
Ok(())
}
pub fn clear_cache(&mut self) {
if let Some(ref mut cache) = self.cache {
cache.clear();
}
}
pub fn force_gc(&mut self) {
self.clear_cache();
}
pub fn warm_up(&mut self) -> Result<()> {
if !self.model_loaded {
return Err(runtime_error("Cannot warm up: model not loaded"));
}
tracing::info!("Starting engine warm-up...");
let start_time = Instant::now();
let batch_size = 1;
let seq_length = 128;
let hidden_size = 512;
let warm_up_iterations = 3;
for i in 0..warm_up_iterations {
let dummy_input = Tensor::zeros(&[batch_size, seq_length, hidden_size])?;
let _result = self.inference(&dummy_input)?;
tracing::debug!(
"Warm-up iteration {}/{} completed",
i + 1,
warm_up_iterations
);
}
let warm_up_time = start_time.elapsed();
tracing::info!(
"Engine warm-up completed in {:.2}ms ({} iterations)",
warm_up_time.as_millis(),
warm_up_iterations
);
Ok(())
}
pub fn set_performance_mode(&mut self, mode: i32) -> Result<()> {
let power_mode = match mode {
0 => crate::optimization::PowerMode::PowerSaving,
1 => crate::optimization::PowerMode::Balanced,
2 => crate::optimization::PowerMode::HighPerformance,
_ => return Err(invalid_input(format!("Invalid performance mode: {}", mode))),
};
self.set_power_mode(power_mode)
}
fn sequential_inference(&self, input: &Tensor) -> Result<Tensor> {
let mut current = input.clone();
if let Some(ref weights) = self.model_weights {
for weight in weights.values() {
current = self.process_layer(&current, weight)?;
// Activation checkpointing would be applied here whenever
// `checkpoint_interval` is non-zero (not yet implemented).
}
}
Ok(current)
}
fn layer_parallel_inference(&self, input: &Tensor) -> Result<Tensor> {
let mut current = input.clone();
if let Some(ref weights) = self.model_weights {
// Groups are currently processed one after another; true layer-level
// parallelism is left as a placeholder.
let layer_groups = self.group_layers_for_parallel_processing(weights);
for group in layer_groups {
current = self.process_layer_group(&current, &group)?;
}
}
Ok(current)
}
fn full_parallel_inference(&self, input: &Tensor) -> Result<Tensor> {
let mut current = input.clone();
if let Some(ref weights) = self.model_weights {
current = self.process_all_layers_parallel(&current, weights)?;
}
Ok(current)
}
fn process_layer(&self, input: &Tensor, _weight: &Tensor) -> Result<Tensor> {
// Placeholder pass-through: real layer math (matmul, bias, activation)
// is not wired up yet, so the input is returned unchanged.
Ok(input.clone())
}
fn process_layer_group(&self, input: &Tensor, group: &[(&String, &Tensor)]) -> Result<Tensor> {
let mut current = input.clone();
for (_, weight) in group {
current = self.process_layer(&current, weight)?;
}
Ok(current)
}
fn process_all_layers_parallel(
&self,
input: &Tensor,
weights: &HashMap<String, Tensor>,
) -> Result<Tensor> {
// Despite the name, this is a sequential fallback until a parallel
// executor is available.
let mut current = input.clone();
for weight in weights.values() {
current = self.process_layer(&current, weight)?;
}
Ok(current)
}
fn group_layers_for_parallel_processing<'a>(
&self,
weights: &'a HashMap<String, Tensor>,
) -> Vec<Vec<(&'a String, &'a Tensor)>> {
let mut groups = Vec::new();
let mut current_group = Vec::new();
for (name, weight) in weights {
current_group.push((name, weight));
if current_group.len() >= 3 {
groups.push(current_group);
current_group = Vec::new();
}
}
if !current_group.is_empty() {
groups.push(current_group);
}
groups
}
fn estimate_current_memory_usage(&self) -> usize {
let mut total = 0;
if let Some(ref weights) = self.model_weights {
for weight in weights.values() {
total += weight.memory_usage();
}
}
if let Some(ref cache) = self.cache {
total += cache.memory_usage_mb() * 1024 * 1024;
}
total / (1024 * 1024)
}
fn should_use_cache(&self) -> bool {
self.config.max_memory_mb >= 512
&& self.config.memory_optimization != crate::MemoryOptimization::Maximum
}
}
#[derive(Debug)]
struct InferenceCache {
cache: HashMap<Vec<u8>, Tensor>,
max_size_mb: usize,
current_size_bytes: usize,
}
impl InferenceCache {
fn new(max_size_mb: usize) -> Self {
Self {
cache: HashMap::new(),
max_size_mb,
current_size_bytes: 0,
}
}
fn get(&self, input: &Tensor) -> Option<Tensor> {
let key = self.tensor_to_key(input);
self.cache.get(&key).cloned()
}
fn put(&mut self, input: Tensor, output: Tensor) {
let key = self.tensor_to_key(&input);
// Approximate the entry's footprint by the combined input/output size.
let entry_size = input.memory_usage() + output.memory_usage();
if self.current_size_bytes + entry_size > self.max_size_mb * 1024 * 1024 {
self.evict_lru();
}
self.cache.insert(key, output);
self.current_size_bytes += entry_size;
}
fn clear(&mut self) {
self.cache.clear();
self.current_size_bytes = 0;
}
fn memory_usage_mb(&self) -> usize {
self.current_size_bytes / (1024 * 1024)
}
fn tensor_to_key(&self, tensor: &Tensor) -> Vec<u8> {
// Key on both shape and contents so that differently valued inputs of
// the same shape do not collide.
let shape = tensor.shape();
let mut key = Vec::new();
for &dim in &shape {
key.extend_from_slice(&dim.to_le_bytes());
}
if let Ok(data) = tensor.data() {
for value in data {
key.extend_from_slice(&value.to_le_bytes());
}
}
key
}
fn evict_lru(&mut self) {
// Simple eviction placeholder: access order is not tracked, so an
// arbitrary entry is dropped and the size estimate reduced by 1MB.
if !self.cache.is_empty() {
let first_key = self.cache.keys().next().expect("checked non-empty above").clone();
self.cache.remove(&first_key);
self.current_size_bytes = self.current_size_bytes.saturating_sub(1024 * 1024);
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MobileMemoryInfo {
pub model_memory_mb: usize,
pub runtime_memory_mb: usize,
pub total_memory_mb: usize,
pub memory_limit_mb: usize,
pub memory_savings_percent: f32,
pub cache_memory_mb: usize,
}
impl MobileMemoryInfo {
pub fn is_within_limits(&self) -> bool {
self.total_memory_mb <= self.memory_limit_mb
}
pub fn memory_utilization_percent(&self) -> f32 {
(self.total_memory_mb as f32 / self.memory_limit_mb as f32) * 100.0
}
pub fn available_memory_mb(&self) -> usize {
self.memory_limit_mb.saturating_sub(self.total_memory_mb)
}
}
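// Fluent builder for `MobileInferenceEngine`. A minimal usage sketch
// (the chosen limits are illustrative, not recommendations):
//
//     let engine = MobileInferenceBuilder::new()
//         .memory_limit_mb(1024)
//         .fp16(true)
//         .build()?;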
pub struct MobileInferenceBuilder {
config: MobileConfig,
}
impl MobileInferenceBuilder {
pub fn new() -> Self {
Self {
config: MobileConfig::default(),
}
}
pub fn platform(mut self, platform: MobilePlatform) -> Self {
self.config.platform = platform;
self
}
pub fn backend(mut self, backend: MobileBackend) -> Self {
self.config.backend = backend;
self
}
pub fn memory_limit_mb(mut self, limit: usize) -> Self {
self.config.max_memory_mb = limit;
self
}
pub fn fp16(mut self, enable: bool) -> Self {
self.config.use_fp16 = enable;
self
}
pub fn quantization(mut self, scheme: crate::MobileQuantizationScheme) -> Self {
self.config.quantization = Some(crate::MobileQuantizationConfig {
scheme,
dynamic: true,
per_channel: false,
});
self
}
pub fn threads(mut self, count: usize) -> Self {
self.config.num_threads = count;
self
}
pub fn batching(mut self, enable: bool, max_batch_size: usize) -> Self {
self.config.enable_batching = enable;
self.config.max_batch_size = max_batch_size;
self
}
pub fn memory_optimization(mut self, level: crate::MemoryOptimization) -> Self {
self.config.memory_optimization = level;
self
}
pub fn build(self) -> Result<MobileInferenceEngine> {
MobileInferenceEngine::new(self.config)
}
}
impl Default for MobileInferenceBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_mobile_inference_engine_creation() {
let config = MobileConfig::default();
let engine = MobileInferenceEngine::new(config);
assert!(engine.is_ok());
}
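// Checks the documented defaults of `ExecutionPlan::new`: batch size 1
// and checkpointing disabled.
#[test]
fn test_execution_plan_defaults() {
let plan = ExecutionPlan::new(ExecutionStrategy::LayerParallel, 24);
assert_eq!(plan.strategy, ExecutionStrategy::LayerParallel);
assert_eq!(plan.num_layers, 24);
assert_eq!(plan.batch_size, 1);
assert_eq!(plan.checkpoint_interval, 0);
}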
#[test]
fn test_model_loading() {
let config = MobileConfig::default();
let mut engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
let mut weights = HashMap::new();
weights.insert(
"layer1".to_string(),
Tensor::ones(&[10, 10]).expect("Failed to create tensor"),
);
weights.insert(
"layer2".to_string(),
Tensor::ones(&[10, 5]).expect("Failed to create tensor"),
);
let result = engine.load_model(weights);
assert!(result.is_ok());
assert!(engine.model_loaded);
}
#[test]
fn test_inference() {
let config = MobileConfig::default();
let mut engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
let mut weights = HashMap::new();
weights.insert(
"layer1".to_string(),
Tensor::ones(&[5, 5]).expect("Failed to create tensor"),
);
engine.load_model(weights).expect("Failed to load model");
let input = Tensor::ones(&[5]).expect("Failed to create tensor");
let result = engine.inference(&input);
assert!(result.is_ok());
}
#[test]
fn test_batch_inference() {
let config = MobileConfig {
enable_batching: true,
max_batch_size: 3,
..Default::default()
};
let mut engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
let mut weights = HashMap::new();
weights.insert(
"layer1".to_string(),
Tensor::ones(&[5, 5]).expect("Failed to create tensor"),
);
engine.load_model(weights).expect("Failed to load model");
let inputs = vec![
Tensor::ones(&[5]).expect("Failed to create tensor"),
Tensor::ones(&[5]).expect("Failed to create tensor"),
];
let results = engine.batch_inference(inputs);
assert!(results.is_ok());
assert_eq!(results.expect("Batch inference failed").len(), 2);
}
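// Exercises the detection order in `detect_model_format`: the file
// extension wins, then magic bytes, then `Unknown`. The byte strings are
// minimal stand-ins, not real model files.
#[test]
fn test_model_format_detection() {
let config = MobileConfig::default();
let engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
let fmt = engine
.detect_model_format(&[], Path::new("model.safetensors"))
.expect("detection failed");
assert_eq!(fmt, ModelFormat::SafeTensors);
// A zip-style "PK" prefix is treated as a PyTorch archive.
let fmt = engine
.detect_model_format(b"PK\x03\x04\x00\x00\x00\x00", Path::new("model"))
.expect("detection failed");
assert_eq!(fmt, ModelFormat::PyTorch);
let fmt = engine
.detect_model_format(&[0u8; 8], Path::new("model.xyz"))
.expect("detection failed");
assert_eq!(fmt, ModelFormat::Unknown);
}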
#[test]
fn test_memory_info() {
let config = MobileConfig::default();
let engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
let memory_info = engine.get_memory_info();
assert!(memory_info.memory_limit_mb > 0);
assert!(memory_info.memory_utilization_percent() >= 0.0);
}
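// Checks the `MobileMemoryInfo` helper arithmetic on a hand-built value.
#[test]
fn test_memory_info_helpers() {
let info = MobileMemoryInfo {
model_memory_mb: 300,
runtime_memory_mb: 100,
total_memory_mb: 400,
memory_limit_mb: 512,
memory_savings_percent: 0.0,
cache_memory_mb: 0,
};
assert!(info.is_within_limits());
assert_eq!(info.available_memory_mb(), 112);
// 400 / 512 = 78.125% utilization
assert!((info.memory_utilization_percent() - 78.125).abs() < 1e-3);
}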
#[test]
fn test_inference_builder() {
let engine = MobileInferenceBuilder::new()
.platform(MobilePlatform::Ios)
.backend(MobileBackend::CoreML)
.memory_limit_mb(1024)
.fp16(true)
.quantization(crate::MobileQuantizationScheme::Int8)
.threads(4)
.batching(true, 2)
.memory_optimization(crate::MemoryOptimization::Balanced)
.build();
assert!(engine.is_ok());
}
#[test]
fn test_config_update() {
let config = MobileConfig::default();
let mut engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
let new_config = MobileConfig {
max_memory_mb: 1024,
num_threads: 8,
..Default::default()
};
let result = engine.update_config(new_config);
assert!(result.is_ok());
}
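// Verifies `set_batch_size`: zero is rejected, and a valid size scales
// the memory budget as 512MB base + 64MB per extra batch slot.
#[test]
fn test_set_batch_size() {
let config = MobileConfig::default();
let mut engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
assert!(engine.set_batch_size(0).is_err());
assert!(engine.set_batch_size(4).is_ok());
assert_eq!(engine.config.max_memory_mb, 512 + 3 * 64);
}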
#[test]
fn test_cache_operations() {
let config = MobileConfig {
max_memory_mb: 1024,
..Default::default()
};
let mut engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
let mut weights = HashMap::new();
weights.insert(
"layer1".to_string(),
Tensor::ones(&[5, 5]).expect("Failed to create tensor"),
);
engine.load_model(weights).expect("Failed to load model");
engine.clear_cache();
engine.force_gc();
}
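// Round-trips the internal `InferenceCache` directly: miss, hit after
// `put`, and miss again after `clear`.
#[test]
fn test_inference_cache_roundtrip() {
let mut cache = InferenceCache::new(16);
let input = Tensor::ones(&[4]).expect("Failed to create tensor");
let output = Tensor::zeros(&[4]).expect("Failed to create tensor");
assert!(cache.get(&input).is_none());
cache.put(input.clone(), output);
assert!(cache.get(&input).is_some());
cache.clear();
assert!(cache.get(&input).is_none());
assert_eq!(cache.memory_usage_mb(), 0);
}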
#[test]
fn test_warm_up() {
let config = MobileConfig::default();
let mut engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
let mut weights = HashMap::new();
weights.insert(
"embedding".to_string(),
Tensor::ones(&[100, 512]).expect("Operation failed"),
);
weights.insert(
"layer.0.weight".to_string(),
Tensor::ones(&[512, 512]).expect("Operation failed"),
);
engine.load_model(weights).expect("Failed to load model");
let result = engine.warm_up();
assert!(result.is_ok(), "Warm-up should succeed after model loading");
}
#[test]
fn test_warm_up_without_model() {
let config = MobileConfig::default();
let mut engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
let result = engine.warm_up();
assert!(result.is_err(), "Warm-up should fail without model");
}
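// Verifies `reduce_performance`: the memory budget scales by the factor
// and fp16 switches on below 0.8 (assumes a 1000MB limit passes config
// validation).
#[test]
fn test_reduce_performance() {
let config = MobileConfig {
max_memory_mb: 1000,
use_fp16: false,
..Default::default()
};
let mut engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
engine.reduce_performance(0.5).expect("Failed to reduce performance");
assert_eq!(engine.config.max_memory_mb, 500);
assert!(engine.config.use_fp16);
}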
#[test]
fn test_set_performance_mode() {
let config = MobileConfig::default();
let mut engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
assert!(
engine.set_performance_mode(0).is_ok(),
"Power Saving mode should work"
);
assert!(
engine.set_performance_mode(1).is_ok(),
"Balanced mode should work"
);
assert!(
engine.set_performance_mode(2).is_ok(),
"High Performance mode should work"
);
assert!(
engine.set_performance_mode(3).is_err(),
"Invalid mode should return error"
);
assert!(
engine.set_performance_mode(-1).is_err(),
"Negative mode should return error"
);
}
#[test]
fn test_performance_mode_changes_config() {
let config = MobileConfig {
use_fp16: false,
backend: crate::MobileBackend::CPU,
..Default::default()
};
let mut engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
engine.set_performance_mode(2).expect("Operation failed");
engine.set_performance_mode(0).expect("Operation failed");
let _ = engine.get_stats();
}
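// PowerSaving mode should halve the memory budget and enable fp16, per
// `set_power_mode`.
#[test]
fn test_power_saving_halves_memory() {
let config = MobileConfig::default();
let mut engine = MobileInferenceEngine::new(config).expect("Failed to create engine");
let before = engine.config.max_memory_mb;
engine
.set_power_mode(crate::optimization::PowerMode::PowerSaving)
.expect("Failed to set power mode");
assert_eq!(engine.config.max_memory_mb, before / 2);
assert!(engine.config.use_fp16);
}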
}