use crate::scirs2_compat::random::legacy;
use crate::{MobileBackend, MobileConfig, PerformanceTier};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use trustformers_core::errors::Result;
/// Uniform random index in `[0, max)`; returns 0 when `max == 0`.
///
/// Assumes `legacy::f64()` yields values in `[0.0, 1.0)` — consistent with how
/// call sites use the result, but confirm against the scirs2_compat shim.
fn random_usize(max: usize) -> usize {
    if max == 0 {
        return 0;
    }
    let scaled = (legacy::f64() * max as f64) as usize;
    // Clamp defensively in case the RNG ever returns exactly 1.0.
    scaled.min(max - 1)
}
/// Random `f32` from the legacy RNG shim.
/// Presumably uniform in `[0.0, 1.0)` — call sites compare it against
/// probabilities like `0.3` — but confirm in `scirs2_compat::random`.
fn random_f32() -> f32 {
legacy::f32()
}
/// Mobile-targeted neural architecture search (NAS) engine.
///
/// Holds the search configuration, a candidate pool, a history of the
/// best-so-far evaluated architectures, and an epsilon-greedy RL agent used
/// when the `ReinforcementLearning` strategy is active.
#[derive(Debug, Clone)]
pub struct MobileNAS {
/// Budget, optimization targets, device constraints, and search strategy.
search_config: NASConfig,
/// Candidate pool (currently only ever initialized empty — not yet populated).
architecture_candidates: Vec<MobileArchitecture>,
/// Records pushed whenever a candidate beats the best score so far.
performance_history: Vec<PerformanceRecord>,
/// Agent driving action selection for the RL search strategy.
optimization_agent: ReinforcementLearningAgent,
}
/// Top-level configuration for a NAS run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NASConfig {
/// Maximum number of candidate-generation iterations.
pub max_iterations: usize,
/// Metrics folded into the fitness score (accuracy is weighted double).
pub optimization_targets: Vec<OptimizationTarget>,
/// Hard device limits; violations are penalized in the fitness score.
pub device_constraints: DeviceConstraints,
/// Which candidate-generation strategy to run.
pub search_strategy: SearchStrategy,
/// Patience-based early-stopping settings.
pub early_stopping: EarlyStoppingConfig,
}
/// A metric the search can optimize for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OptimizationTarget {
Latency,
Memory,
Power,
Accuracy,
ModelSize,
Energy,
}
/// Hard resource limits of the target device.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceConstraints {
pub max_memory_mb: usize,
pub max_latency_ms: f32,
pub performance_tier: PerformanceTier,
pub available_backends: Vec<MobileBackend>,
/// Power budget in milliwatts.
pub power_budget_mw: f32,
}
/// Candidate-generation strategy. Strategy-specific fields are carried here
/// but note: the search loop currently dispatches on the variant only and
/// ignores the inner parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SearchStrategy {
Random,
Evolutionary {
population_size: usize,
mutation_rate: f32,
crossover_rate: f32,
},
ReinforcementLearning {
learning_rate: f32,
exploration_rate: f32,
replay_buffer_size: usize,
},
Differentiable {
temperature: f32,
gumbel_softmax: bool,
},
Progressive {
stages: usize,
pruning_threshold: f32,
},
}
/// Early-stopping settings: stop after `patience` iterations without an
/// improvement of at least `min_improvement`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyStoppingConfig {
pub patience: usize,
pub min_improvement: f32,
pub monitor_metric: OptimizationTarget,
}
/// A candidate network: ordered layers, skip connections, and per-layer
/// quantization, plus (optionally) the last estimated metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MobileArchitecture {
pub id: String,
pub layers: Vec<LayerConfig>,
pub skip_connections: Vec<SkipConnection>,
pub quantization: QuantizationConfig,
pub estimated_metrics: Option<ArchitectureMetrics>,
}
/// One layer of a candidate architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LayerConfig {
pub layer_type: LayerType,
pub input_dim: Vec<usize>,
pub output_dim: Vec<usize>,
/// Free-form named hyperparameters (e.g. "channels"), mutated by the search.
pub parameters: HashMap<String, f32>,
pub activation: ActivationType,
}
/// Mobile-friendly layer families in the search space.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LayerType {
DepthwiseSeparableConv {
kernel_size: usize,
stride: usize,
dilation: usize,
},
MobileBottleneck {
expansion_ratio: f32,
kernel_size: usize,
squeeze_excitation: bool,
},
EfficientChannelAttention {
reduction_ratio: usize,
use_gating: bool,
},
MobileMultiHeadAttention {
num_heads: usize,
head_dim: usize,
sparse_attention: bool,
},
GroupNormalization { num_groups: usize },
MobileLinear { use_bias: bool, quantized: bool },
}
/// Activation functions available to candidate layers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ActivationType {
Swish,
HardSwish,
ReLU6,
GeluApprox,
Mish,
}
/// A skip edge between two layer indices.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkipConnection {
pub from_layer: usize,
pub to_layer: usize,
pub connection_type: ConnectionType,
}
/// How a skip connection combines its endpoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConnectionType {
Residual,
Dense,
Attention { num_heads: usize },
ChannelShuffle,
}
/// Per-layer quantization assignments for a candidate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
/// Layer index -> scheme; layers absent from the map use the default precision.
pub layer_schemes: HashMap<usize, QuantizationScheme>,
pub mixed_precision: bool,
pub dynamic_quantization: bool,
}
/// Numeric precision scheme for a layer's weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationScheme {
Int4 { symmetric: bool },
Int8 { symmetric: bool },
FP16,
BlockWise { block_size: usize },
FP32,
}
/// Estimated cost/quality metrics for one candidate architecture.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchitectureMetrics {
pub latency_ms: f32,
pub memory_mb: f32,
pub power_mw: f32,
/// `None` until a real accuracy evaluation runs; scoring falls back to 0.8.
pub accuracy: Option<f32>,
pub model_size_mb: f32,
/// Estimated energy per inference, in millijoules.
pub energy_per_inference_mj: f32,
pub throughput_fps: f32,
}
/// One entry of the search's performance history: the architecture, its
/// estimated metrics, and the context it was evaluated under.
#[derive(Debug, Clone)]
pub struct PerformanceRecord {
pub architecture: MobileArchitecture,
pub metrics: ArchitectureMetrics,
pub device_config: MobileConfig,
pub timestamp: std::time::SystemTime,
pub user_context: Option<UserContext>,
}
/// User-specific signals that bias the fitness score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserContext {
pub usage_patterns: Vec<UsagePattern>,
pub preferences: UserPreferences,
pub environment: DeviceEnvironment,
}
/// How often and how a particular task is run on the device.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsagePattern {
pub task_type: String,
/// Relative frequency; patterns above 0.5 penalize latency violations.
pub frequency: f32,
pub input_characteristics: InputCharacteristics,
pub performance_requirements: PerformanceRequirements,
}
/// Typical input shapes/batches/dtypes seen for a usage pattern.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputCharacteristics {
pub input_sizes: Vec<Vec<usize>>,
pub common_batch_sizes: Vec<usize>,
pub data_types: Vec<String>,
}
/// Per-pattern performance requirements.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceRequirements {
pub max_latency_ms: f32,
pub battery_importance: f32,
pub accuracy_importance: f32,
}
/// Which targets the user cares about most.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserPreferences {
pub primary_target: OptimizationTarget,
pub secondary_targets: Vec<OptimizationTarget>,
pub quality_tradeoffs: QualityTradeoffs,
}
/// Acceptable degradation bounds relative to a baseline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityTradeoffs {
pub max_accuracy_loss: f32,
pub max_latency_increase: f32,
pub max_memory_increase: f32,
}
/// Ambient device conditions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceEnvironment {
pub charging_status: ChargingPattern,
pub network_patterns: NetworkPattern,
pub thermal_environment: ThermalEnvironment,
}
/// How often the device is typically on charge.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChargingPattern {
FrequentCharging,
ModerateCharging,
InfrequentCharging,
}
/// Dominant network connectivity pattern.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum NetworkPattern {
PrimarilyWiFi,
Mixed,
PrimarilyCellular,
FrequentOffline,
}
/// Typical thermal conditions the device operates in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ThermalEnvironment {
Cool,
Moderate,
Warm,
Variable,
}
/// Epsilon-greedy agent used by the `ReinforcementLearning` search strategy.
#[derive(Debug, Clone)]
pub struct ReinforcementLearningAgent {
config: RLConfig,
/// Placeholder Q-network; weights are never trained in the current code.
q_network: QNetwork,
/// Experience buffer; currently never appended to.
replay_buffer: Vec<Experience>,
/// Current epsilon; decays toward `config.min_exploration_rate`.
exploration_rate: f32,
}
/// Hyperparameters for the RL agent.
#[derive(Debug, Clone)]
pub struct RLConfig {
pub learning_rate: f32,
pub discount_factor: f32,
pub initial_exploration_rate: f32,
/// Multiplicative decay applied to epsilon after each update.
pub exploration_decay: f32,
pub min_exploration_rate: f32,
}
/// Minimal Q-network representation (weights + layer sizes).
#[derive(Debug, Clone)]
pub struct QNetwork {
weights: Vec<Vec<f32>>,
architecture: Vec<usize>,
}
/// One (state, action, reward, next state) transition.
#[derive(Debug, Clone)]
pub struct Experience {
pub state: Vec<f32>,
pub action: ArchitectureAction,
pub reward: f32,
pub next_state: Vec<f32>,
pub done: bool,
}
/// Discrete edits the agent can apply to an architecture.
#[derive(Debug, Clone)]
pub enum ArchitectureAction {
AddLayer {
layer_type: LayerType,
position: usize,
},
RemoveLayer { position: usize },
ModifyLayer {
position: usize,
parameter: String,
value: f32,
},
ChangeQuantization {
layer: usize,
scheme: QuantizationScheme,
},
AddSkipConnection {
from: usize,
to: usize,
connection_type: ConnectionType,
},
RemoveSkipConnection { from: usize, to: usize },
}
impl MobileNAS {
/// Create a NAS engine with empty candidate/history pools and an RL agent
/// initialized with fixed default hyperparameters.
pub fn new(config: NASConfig) -> Self {
    Self {
        search_config: config,
        architecture_candidates: Vec::new(),
        performance_history: Vec::new(),
        optimization_agent: ReinforcementLearningAgent::new(RLConfig {
            learning_rate: 0.001,
            discount_factor: 0.99,
            initial_exploration_rate: 1.0,
            exploration_decay: 0.995,
            min_exploration_rate: 0.1,
        }),
    }
}
/// Run the configured search strategy for up to `max_iterations`, returning
/// the best-scoring architecture found (the base architecture if nothing
/// beats it).
///
/// Improvements smaller than `early_stopping.min_improvement` still update
/// the best candidate and are recorded in the history, but do not reset the
/// early-stopping patience counter.
pub fn search_optimal_architecture(
    &mut self,
    base_architecture: MobileArchitecture,
    user_context: Option<UserContext>,
) -> Result<MobileArchitecture> {
    let mut best_architecture = base_architecture.clone();
    let mut best_score = f32::NEG_INFINITY;
    let mut iterations_without_improvement = 0;
    for iteration in 0..self.search_config.max_iterations {
        // Generate one candidate per iteration according to the strategy.
        let candidate = match &self.search_config.search_strategy {
            SearchStrategy::Random => self.generate_random_architecture(&base_architecture)?,
            SearchStrategy::Evolutionary { .. } => {
                self.evolve_architecture(&best_architecture)?
            },
            SearchStrategy::ReinforcementLearning { .. } => {
                self.rl_generate_architecture(&best_architecture)?
            },
            SearchStrategy::Differentiable { .. } => {
                self.differentiable_search(&best_architecture)?
            },
            SearchStrategy::Progressive { .. } => {
                self.progressive_search(&best_architecture, iteration)?
            },
        };
        let metrics = self.evaluate_architecture(&candidate)?;
        let score = self.calculate_fitness_score(&metrics, &user_context)?;
        // BUGFIX: `min_improvement` was configured but never consulted. Only an
        // improvement exceeding that margin now resets the patience counter;
        // marginal gains still update the best candidate below.
        let significant_improvement =
            score > best_score + self.search_config.early_stopping.min_improvement;
        if score > best_score {
            best_score = score;
            best_architecture = candidate.clone();
            self.performance_history.push(PerformanceRecord {
                architecture: candidate,
                metrics,
                device_config: MobileConfig::default(),
                timestamp: std::time::SystemTime::now(),
                user_context: user_context.clone(),
            });
        }
        if significant_improvement {
            iterations_without_improvement = 0;
        } else {
            iterations_without_improvement += 1;
        }
        if iterations_without_improvement >= self.search_config.early_stopping.patience {
            println!(
                "Early stopping at iteration {} due to no improvement",
                iteration
            );
            break;
        }
        // Feed the score back so the agent's exploration rate decays.
        if matches!(
            self.search_config.search_strategy,
            SearchStrategy::ReinforcementLearning { .. }
        ) {
            self.optimization_agent.update_from_experience(score)?;
        }
    }
    Ok(best_architecture)
}
/// Produce a candidate by applying three independently-chosen random
/// mutations to a copy of the base architecture.
fn generate_random_architecture(
    &self,
    base: &MobileArchitecture,
) -> Result<MobileArchitecture> {
    let mut candidate = base.clone();
    for _ in 0..3 {
        let roll = random_usize(4);
        if roll == 0 {
            self.mutate_layer_params(&mut candidate)?;
        } else if roll == 1 {
            self.mutate_quantization(&mut candidate)?;
        } else if roll == 2 {
            self.mutate_skip_connections(&mut candidate)?;
        } else {
            self.mutate_architecture_structure(&mut candidate)?;
        }
    }
    Ok(candidate)
}
/// Produce an offspring by cloning the parent and firing each mutation
/// independently with its own probability (0.3 / 0.2 / 0.1).
fn evolve_architecture(&self, parent: &MobileArchitecture) -> Result<MobileArchitecture> {
    let mut offspring = parent.clone();
    // Each roll is drawn immediately before its mutation, preserving the
    // original RNG consumption order.
    let fires = |probability: f32| random_f32() < probability;
    if fires(0.3) {
        self.mutate_layer_params(&mut offspring)?;
    }
    if fires(0.2) {
        self.mutate_quantization(&mut offspring)?;
    }
    if fires(0.1) {
        self.mutate_skip_connections(&mut offspring)?;
    }
    Ok(offspring)
}
/// RL step: encode the current architecture as a state vector, let the agent
/// pick an action, and apply it to a fresh copy.
fn rl_generate_architecture(
    &mut self,
    current: &MobileArchitecture,
) -> Result<MobileArchitecture> {
    let encoded_state = self.encode_architecture_state(current)?;
    let chosen_action = self.optimization_agent.select_action(&encoded_state)?;
    let mut candidate = current.clone();
    self.apply_architecture_action(&mut candidate, chosen_action)?;
    Ok(candidate)
}
/// Stand-in for a differentiable (DARTS-style) step: jitter every layer's
/// "channels" parameter by up to ±5%.
fn differentiable_search(&self, base: &MobileArchitecture) -> Result<MobileArchitecture> {
    let mut candidate = base.clone();
    candidate
        .layers
        .iter_mut()
        .filter_map(|layer| layer.parameters.get_mut("channels"))
        .for_each(|channels| *channels *= 1.0 + (random_f32() - 0.5) * 0.1);
    Ok(candidate)
}
/// Progressive search: the mutation class advances through four stages as
/// iterations accumulate (params -> quantization -> skips -> structure).
fn progressive_search(
    &self,
    base: &MobileArchitecture,
    iteration: usize,
) -> Result<MobileArchitecture> {
    let mut candidate = base.clone();
    // BUGFIX: `max_iterations / 4` is 0 whenever max_iterations < 4, and the
    // subsequent integer division by it panicked. Clamp the stage length to 1.
    let stage_len = (self.search_config.max_iterations / 4).max(1);
    let stage = iteration / stage_len;
    match stage {
        0 => self.mutate_layer_params(&mut candidate)?,
        1 => self.mutate_quantization(&mut candidate)?,
        2 => self.mutate_skip_connections(&mut candidate)?,
        _ => self.mutate_architecture_structure(&mut candidate)?,
    }
    Ok(candidate)
}
/// Estimate cost metrics for a candidate by summing per-layer estimates and
/// applying simple analytical models for latency, power, size, and energy.
fn evaluate_architecture(
    &self,
    architecture: &MobileArchitecture,
) -> Result<ArchitectureMetrics> {
    let mut total_params = 0;
    let mut total_flops = 0;
    let mut memory_usage = 0;
    for layer in &architecture.layers {
        let (params, flops, memory) = self.estimate_layer_metrics(layer)?;
        total_params += params;
        total_flops += flops;
        memory_usage += memory;
    }
    let latency_ms = self.estimate_latency(total_flops, &architecture.quantization)?;
    let memory_mb = memory_usage as f32 / (1024.0 * 1024.0);
    let power_mw = self.estimate_power_consumption(total_flops, latency_ms)?;
    // Assumes 4 bytes per parameter (FP32); quantization is not yet reflected.
    let model_size_mb = (total_params * 4) as f32 / (1024.0 * 1024.0);
    // BUGFIX: mW * ms yields microjoules; divide by 1000 so the value matches
    // the field's millijoule unit.
    let energy_per_inference_mj = power_mw * latency_ms / 1000.0;
    // BUGFIX: an empty layer list gives 0 FLOPs and a 0 ms latency estimate,
    // which previously produced an infinite throughput. Report 0 fps instead.
    let throughput_fps = if latency_ms > 0.0 {
        1000.0 / latency_ms
    } else {
        0.0
    };
    Ok(ArchitectureMetrics {
        latency_ms,
        memory_mb,
        power_mw,
        accuracy: None, // no accuracy evaluation implemented yet
        model_size_mb,
        energy_per_inference_mj,
        throughput_fps,
    })
}
/// Combine metrics into a single fitness value.
///
/// Each configured target contributes a soft-normalized value in (0, 1]
/// (accuracy is weighted double); the weighted mean is then scaled down by
/// user-context adjustments and hard-constraint penalties.
fn calculate_fitness_score(
    &self,
    metrics: &ArchitectureMetrics,
    user_context: &Option<UserContext>,
) -> Result<f32> {
    let mut score = 0.0;
    let mut total_weight = 0.0;
    for &target in &self.search_config.optimization_targets {
        let (value, weight) = match target {
            OptimizationTarget::Latency => {
                // Soft normalization: e.g. 100 ms maps to 0.5.
                let normalized = 1.0 / (1.0 + metrics.latency_ms / 100.0);
                (normalized, 1.0)
            },
            OptimizationTarget::Memory => {
                let normalized = 1.0 / (1.0 + metrics.memory_mb / 512.0);
                (normalized, 1.0)
            },
            OptimizationTarget::Power => {
                let normalized = 1.0 / (1.0 + metrics.power_mw / 1000.0);
                (normalized, 1.0)
            },
            OptimizationTarget::ModelSize => {
                let normalized = 1.0 / (1.0 + metrics.model_size_mb / 100.0);
                (normalized, 1.0)
            },
            OptimizationTarget::Energy => {
                let normalized = 1.0 / (1.0 + metrics.energy_per_inference_mj / 10.0);
                (normalized, 1.0)
            },
            OptimizationTarget::Accuracy => {
                // Unknown accuracy falls back to an optimistic 0.8 estimate;
                // accuracy counts double relative to the cost targets.
                let normalized = metrics.accuracy.unwrap_or(0.8);
                (normalized, 2.0)
            },
        };
        score += value * weight;
        total_weight += weight;
    }
    if let Some(ref context) = user_context {
        score = self.adjust_score_for_user_context(score, metrics, context)?;
    }
    score = self.apply_constraint_penalties(score, metrics)?;
    // BUGFIX: with an empty `optimization_targets` list this divided
    // 0.0 / 0.0 and returned NaN, poisoning every score comparison.
    if total_weight > 0.0 {
        Ok(score / total_weight)
    } else {
        Ok(score)
    }
}
/// Scale the score down when the user's primary target is over budget (×0.8)
/// and again for each frequent usage pattern whose latency bound is missed
/// (×0.9 per pattern).
fn adjust_score_for_user_context(
    &self,
    base_score: f32,
    metrics: &ArchitectureMetrics,
    context: &UserContext,
) -> Result<f32> {
    let mut adjusted_score = base_score;
    // Fixed budgets: 50 ms latency, 256 MB memory, 500 mW power.
    let over_primary_budget = match context.preferences.primary_target {
        OptimizationTarget::Latency => metrics.latency_ms > 50.0,
        OptimizationTarget::Memory => metrics.memory_mb > 256.0,
        OptimizationTarget::Power => metrics.power_mw > 500.0,
        _ => false,
    };
    if over_primary_budget {
        adjusted_score *= 0.8;
    }
    for pattern in &context.usage_patterns {
        let is_frequent = pattern.frequency > 0.5;
        let misses_latency =
            metrics.latency_ms > pattern.performance_requirements.max_latency_ms;
        if is_frequent && misses_latency {
            adjusted_score *= 0.9;
        }
    }
    Ok(adjusted_score)
}
/// Multiplicatively penalize hard-constraint violations: ×0.5 for memory or
/// latency overruns, ×0.7 for exceeding the power budget.
fn apply_constraint_penalties(
    &self,
    base_score: f32,
    metrics: &ArchitectureMetrics,
) -> Result<f32> {
    let limits = &self.search_config.device_constraints;
    let mut penalized = base_score;
    if metrics.memory_mb > limits.max_memory_mb as f32 {
        penalized *= 0.5;
    }
    if metrics.latency_ms > limits.max_latency_ms {
        penalized *= 0.5;
    }
    if metrics.power_mw > limits.power_budget_mw {
        penalized *= 0.7;
    }
    Ok(penalized)
}
/// Pick a random layer, then a random named parameter of it, and jitter the
/// value by up to ±10%. No-op for empty architectures or parameter maps.
fn mutate_layer_params(&self, architecture: &mut MobileArchitecture) -> Result<()> {
    if architecture.layers.is_empty() {
        return Ok(());
    }
    let chosen = random_usize(architecture.layers.len());
    let layer = &mut architecture.layers[chosen];
    if layer.parameters.is_empty() {
        return Ok(());
    }
    let keys: Vec<String> = layer.parameters.keys().cloned().collect();
    let picked_key = &keys[random_usize(keys.len())];
    if let Some(value) = layer.parameters.get_mut(picked_key) {
        *value *= 1.0 + (random_f32() - 0.5) * 0.2;
    }
    Ok(())
}
/// Re-assign a random layer's quantization scheme from a fixed palette
/// (Int4 / Int8 / FP16 / FP32). No-op for empty architectures.
fn mutate_quantization(&self, architecture: &mut MobileArchitecture) -> Result<()> {
    if architecture.layers.is_empty() {
        return Ok(());
    }
    let target_layer = random_usize(architecture.layers.len());
    let palette = [
        QuantizationScheme::Int4 { symmetric: true },
        QuantizationScheme::Int8 { symmetric: true },
        QuantizationScheme::FP16,
        QuantizationScheme::FP32,
    ];
    let picked = palette[random_usize(palette.len())].clone();
    architecture
        .quantization
        .layer_schemes
        .insert(target_layer, picked);
    Ok(())
}
/// Mutate skip connections. Placeholder: currently a no-op.
fn mutate_skip_connections(&self, _architecture: &mut MobileArchitecture) -> Result<()> {
Ok(())
}
/// Add/remove layers or otherwise restructure. Placeholder: currently a no-op.
fn mutate_architecture_structure(&self, _architecture: &mut MobileArchitecture) -> Result<()> {
Ok(())
}
/// Rough cost model for one layer: parameter count ~= |inputs| * |outputs|
/// (dense-layer approximation, regardless of `layer_type`).
/// Returns `(params, flops, memory_bytes)`.
fn estimate_layer_metrics(&self, layer: &LayerConfig) -> Result<(usize, usize, usize)> {
    let fan_in: usize = layer.input_dim.iter().product();
    let fan_out: usize = layer.output_dim.iter().product();
    let params = fan_in * fan_out;
    // One multiply-accumulate per weight => 2 FLOPs; 4 bytes per FP32 weight.
    Ok((params, params * 2, params * 4))
}
/// Crude latency model: 1 ms per MFLOP. Quantization is accepted but not yet
/// factored into the estimate.
fn estimate_latency(
    &self,
    total_flops: usize,
    _quantization: &QuantizationConfig,
) -> Result<f32> {
    Ok(total_flops as f32 / 1_000_000.0)
}
/// Heuristic power model (mW): 100 mW per MFLOP of compute plus 10 mW per
/// millisecond of runtime.
fn estimate_power_consumption(&self, total_flops: usize, latency_ms: f32) -> Result<f32> {
    let compute_term = (total_flops as f32 / 1_000_000.0) * 100.0;
    let duration_term = latency_ms * 10.0;
    Ok(compute_term + duration_term)
}
/// Encode an architecture as a fixed-length state vector for the RL agent.
/// Placeholder: returns a constant 128-dim vector; no real features yet.
fn encode_architecture_state(&self, _architecture: &MobileArchitecture) -> Result<Vec<f32>> {
Ok(vec![0.5; 128]) }
/// Apply an agent-chosen `ArchitectureAction` to the architecture in place.
/// Placeholder: currently a no-op, so RL-generated candidates equal their input.
fn apply_architecture_action(
&self,
_architecture: &mut MobileArchitecture,
_action: ArchitectureAction,
) -> Result<()> {
Ok(())
}
}
impl ReinforcementLearningAgent {
/// Build an agent whose epsilon starts at the configured initial
/// exploration rate, with a zero-initialized placeholder Q-network.
fn new(config: RLConfig) -> Self {
    let starting_epsilon = config.initial_exploration_rate;
    Self {
        config,
        q_network: QNetwork {
            // NOTE(review): a 64x128 weight matrix next to a [128, 64, 32, 16]
            // layer layout looks inconsistent — confirm the intended shape.
            weights: vec![vec![0.0; 128]; 64],
            architecture: vec![128, 64, 32, 16],
        },
        replay_buffer: Vec::new(),
        exploration_rate: starting_epsilon,
    }
}
/// Epsilon-greedy action selection over a placeholder action space.
/// With only one action defined, explore and exploit currently coincide.
fn select_action(&mut self, _state: &[f32]) -> Result<ArchitectureAction> {
    let actions = vec![ArchitectureAction::ModifyLayer {
        position: 0,
        parameter: "channels".to_string(),
        value: 64.0,
    }];
    // Explore with probability `exploration_rate`; otherwise exploit slot 0.
    let exploring = random_f32() < self.exploration_rate;
    let action_idx = if exploring {
        random_usize(actions.len())
    } else {
        0
    };
    Ok(actions[action_idx].clone())
}
/// Decay the exploration rate after an episode, clamped at the configured
/// minimum.
///
/// The reward is currently unused — no Q-network update is implemented yet.
/// The parameter is kept (underscored to silence the unused-variable
/// warning) so the signature is ready for a real learning step.
fn update_from_experience(&mut self, _reward: f32) -> Result<()> {
    self.exploration_rate = (self.exploration_rate * self.config.exploration_decay)
        .max(self.config.min_exploration_rate);
    Ok(())
}
}
impl Default for NASConfig {
/// Mid-tier defaults: 100 iterations of evolutionary search optimizing
/// latency, memory, and power under a 512 MB / 100 ms / 1 W budget.
fn default() -> Self {
    let device_constraints = DeviceConstraints {
        max_memory_mb: 512,
        max_latency_ms: 100.0,
        performance_tier: PerformanceTier::Mid,
        available_backends: vec![MobileBackend::CPU, MobileBackend::GPU],
        power_budget_mw: 1000.0,
    };
    let early_stopping = EarlyStoppingConfig {
        patience: 10,
        min_improvement: 0.01,
        monitor_metric: OptimizationTarget::Latency,
    };
    Self {
        max_iterations: 100,
        optimization_targets: vec![
            OptimizationTarget::Latency,
            OptimizationTarget::Memory,
            OptimizationTarget::Power,
        ],
        device_constraints,
        search_strategy: SearchStrategy::Evolutionary {
            population_size: 20,
            mutation_rate: 0.1,
            crossover_rate: 0.7,
        },
        early_stopping,
    }
}
}
// Unit tests: mostly construction/variant-count smoke tests for the NAS types.
#[cfg(test)]
mod tests {
use super::*;
// NAS engine starts with an empty candidate pool.
#[test]
fn test_mobile_nas_creation() {
let config = NASConfig::default();
let nas = MobileNAS::new(config);
assert_eq!(nas.architecture_candidates.len(), 0);
}
#[test]
fn test_architecture_metrics() {
let metrics = ArchitectureMetrics {
latency_ms: 50.0,
memory_mb: 128.0,
power_mw: 500.0,
accuracy: Some(0.9),
model_size_mb: 25.0,
energy_per_inference_mj: 25.0,
throughput_fps: 20.0,
};
assert_eq!(metrics.latency_ms, 50.0);
assert_eq!(metrics.throughput_fps, 20.0);
}
#[test]
fn test_nas_config_default() {
let config = NASConfig::default();
assert_eq!(config.max_iterations, 100);
assert!(config.optimization_targets.contains(&OptimizationTarget::Latency));
}
#[test]
fn test_optimization_target_variants() {
let targets = vec![
OptimizationTarget::Latency,
OptimizationTarget::Memory,
OptimizationTarget::Power,
OptimizationTarget::Accuracy,
OptimizationTarget::ModelSize,
OptimizationTarget::Energy,
];
assert_eq!(targets.len(), 6);
}
#[test]
fn test_search_strategy_random() {
let strategy = SearchStrategy::Random;
assert!(matches!(strategy, SearchStrategy::Random));
}
#[test]
fn test_search_strategy_evolutionary() {
let strategy = SearchStrategy::Evolutionary {
population_size: 50,
mutation_rate: 0.1,
crossover_rate: 0.5,
};
if let SearchStrategy::Evolutionary {
population_size, ..
} = strategy
{
assert_eq!(population_size, 50);
}
}
#[test]
fn test_activation_type_variants() {
let activations = vec![
ActivationType::Swish,
ActivationType::HardSwish,
ActivationType::ReLU6,
ActivationType::GeluApprox,
ActivationType::Mish,
];
assert_eq!(activations.len(), 5);
}
#[test]
fn test_connection_type_variants() {
let connections = vec![
ConnectionType::Residual,
ConnectionType::Dense,
ConnectionType::Attention { num_heads: 4 },
ConnectionType::ChannelShuffle,
];
assert_eq!(connections.len(), 4);
}
#[test]
fn test_quantization_scheme_variants() {
let schemes = vec![
QuantizationScheme::Int4 { symmetric: true },
QuantizationScheme::Int8 { symmetric: false },
QuantizationScheme::FP16,
QuantizationScheme::BlockWise { block_size: 32 },
QuantizationScheme::FP32,
];
assert_eq!(schemes.len(), 5);
}
#[test]
fn test_layer_type_depthwise_conv() {
let layer = LayerType::DepthwiseSeparableConv {
kernel_size: 3,
stride: 1,
dilation: 1,
};
if let LayerType::DepthwiseSeparableConv { kernel_size, .. } = layer {
assert_eq!(kernel_size, 3);
}
}
#[test]
fn test_layer_type_mobile_bottleneck() {
let layer = LayerType::MobileBottleneck {
expansion_ratio: 6.0,
kernel_size: 3,
squeeze_excitation: true,
};
if let LayerType::MobileBottleneck {
expansion_ratio, ..
} = layer
{
assert_eq!(expansion_ratio, 6.0);
}
}
#[test]
fn test_layer_config_creation() {
let config = LayerConfig {
layer_type: LayerType::MobileLinear {
use_bias: true,
quantized: false,
},
input_dim: vec![768],
output_dim: vec![256],
parameters: HashMap::new(),
activation: ActivationType::ReLU6,
};
assert_eq!(config.input_dim, vec![768]);
assert_eq!(config.output_dim, vec![256]);
}
#[test]
fn test_skip_connection_creation() {
let skip = SkipConnection {
from_layer: 0,
to_layer: 2,
connection_type: ConnectionType::Residual,
};
assert_eq!(skip.from_layer, 0);
assert_eq!(skip.to_layer, 2);
}
// Checks internal consistency: throughput_fps == 1000 / latency_ms.
#[test]
fn test_architecture_metrics_throughput() {
let metrics = ArchitectureMetrics {
latency_ms: 10.0,
memory_mb: 64.0,
power_mw: 200.0,
accuracy: Some(0.95),
model_size_mb: 10.0,
energy_per_inference_mj: 2.0,
throughput_fps: 100.0,
};
assert!((metrics.throughput_fps - 1000.0 / metrics.latency_ms).abs() < 1e-3);
}
#[test]
fn test_early_stopping_config() {
let config = EarlyStoppingConfig {
patience: 10,
min_improvement: 0.001,
monitor_metric: OptimizationTarget::Accuracy,
};
assert_eq!(config.patience, 10);
assert!(config.min_improvement > 0.0);
}
#[test]
fn test_device_constraints_creation() {
let constraints = DeviceConstraints {
max_memory_mb: 256,
max_latency_ms: 50.0,
// NOTE(review): `PerformanceTier::Medium` here vs `PerformanceTier::Mid`
// in NASConfig::default — confirm which variant(s) the enum actually has.
performance_tier: PerformanceTier::Medium,
available_backends: vec![MobileBackend::CPU],
power_budget_mw: 1000.0,
};
assert_eq!(constraints.max_memory_mb, 256);
assert!(!constraints.available_backends.is_empty());
}
#[test]
fn test_quantization_config_creation() {
let config = QuantizationConfig {
layer_schemes: HashMap::new(),
mixed_precision: true,
dynamic_quantization: false,
};
assert!(config.mixed_precision);
assert!(!config.dynamic_quantization);
assert!(config.layer_schemes.is_empty());
}
#[test]
fn test_mobile_architecture_creation() {
let arch = MobileArchitecture {
id: "arch_001".to_string(),
layers: vec![],
skip_connections: vec![],
quantization: QuantizationConfig {
layer_schemes: HashMap::new(),
mixed_precision: false,
dynamic_quantization: false,
},
estimated_metrics: None,
};
assert_eq!(arch.id, "arch_001");
assert!(arch.layers.is_empty());
assert!(arch.estimated_metrics.is_none());
}
// Fresh engine has empty candidate pool and history.
#[test]
fn test_nas_with_added_candidates() {
let config = NASConfig::default();
let nas = MobileNAS::new(config);
assert_eq!(nas.architecture_candidates.len(), 0);
assert_eq!(nas.performance_history.len(), 0);
}
}