use crate::pipeline::early_exit::{EarlyExitConfig, EarlyExitPipeline, EarlyExitResult};
use crate::pipeline::{Pipeline, PipelineOutput};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{Duration, Instant};
use trustformers_core::errors::Result;
/// Numeric precision at which model layers may be executed.
///
/// `Eq + Hash` are derived because the mode is used as a `HashMap` key
/// (`PrecisionController::calibration_data`); without them that map can
/// never be inserted into or queried. `Copy` is free for a fieldless enum
/// and is backward compatible.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PrecisionMode {
    /// Full precision — highest quality, highest cost.
    Full,
    /// Half precision.
    Half,
    /// Mixed precision (some ops full, some reduced).
    Mixed,
    /// 8-bit integer quantization.
    Int8,
    /// 4-bit integer quantization — cheapest, lowest fidelity.
    Int4,
    /// Precision chosen dynamically at runtime.
    Dynamic,
    /// Precision adapted per input/layer by the engine.
    Adaptive,
}
/// Strategy used to decide which parts of the network may be skipped.
///
/// Only the first four variants influence `calculate_skip_probability`
/// in this file; the remaining three currently fall through to a 0.0
/// base skip probability.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ConditionalStrategy {
    /// Skip attention sub-layers (applied to even layer indices here).
    AttentionSkipping,
    /// Skip feed-forward sub-layers (applied to odd layer indices here).
    FeedForwardSkipping,
    /// Skip whole blocks when input complexity is very low.
    BlockSkipping,
    /// Pick an effective depth per input; layers past it are skipped.
    DynamicDepth,
    /// Sparse unit activation. NOTE(review): not yet exercised in this file.
    SparseActivation,
    /// Per-token conditional compute. NOTE(review): not yet exercised in this file.
    TokenConditional,
    /// Per-layer conditional compute. NOTE(review): not yet exercised in this file.
    LayerConditional,
}
/// High-level objective steering how runtime resources are traded off.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ResourceStrategy {
    /// Minimize end-to-end latency.
    MinLatency,
    /// Minimize memory footprint.
    MinMemory,
    /// Minimize energy consumption.
    MinEnergy,
    /// Balance latency, memory, energy, and quality.
    Balanced,
    /// Maximize tokens processed per second.
    MaxThroughput,
    /// Explicit, user-supplied budgets.
    Custom(ResourceAllocation),
}
/// Explicit resource budget carried by `ResourceStrategy::Custom`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceAllocation {
    /// Number of CPU cores the engine may use.
    pub cpu_cores: u32,
    /// Host memory budget, in megabytes.
    pub memory_limit_mb: u64,
    /// GPU memory budget, in megabytes.
    pub gpu_memory_limit_mb: u64,
    /// Power budget, in watts.
    pub energy_budget_watts: f32,
    /// End-to-end latency budget, in milliseconds.
    pub latency_budget_ms: u64,
    /// Minimum acceptable quality score (tests assume (0, 1]).
    pub quality_threshold: f32,
}
/// Configuration for `AdaptiveInferenceEngine`.
///
/// NOTE(review): some fields (`quality_threshold`,
/// `skip_probability_threshold`, `dynamic_batch_size`) are mutated at
/// runtime by `adapt_resource_allocation`, so a config instance is not
/// read-only once handed to the engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveInferenceConfig {
    /// Initial/global precision mode.
    pub precision_mode: PrecisionMode,
    /// Which conditional-computation strategy to apply.
    pub conditional_strategy: ConditionalStrategy,
    /// Resource trade-off objective.
    pub resource_strategy: ResourceStrategy,
    /// Configuration forwarded to the early-exit wrapper pipeline.
    pub early_exit_config: EarlyExitConfig,
    /// Early-exit results at or above this confidence are accepted as-is.
    pub quality_threshold: f32,
    /// Latency budget in ms; used to normalize the latency/quality tradeoff.
    pub latency_budget_ms: u64,
    /// Memory budget in megabytes.
    pub memory_budget_mb: u64,
    /// Energy budget in watts.
    pub energy_budget_watts: f32,
    /// Threshold for adaptive precision decisions.
    pub adaptive_precision_threshold: f32,
    /// Layers whose skip probability exceeds this are marked as skipped.
    pub skip_probability_threshold: f32,
    /// Allow batch size to vary at runtime.
    pub dynamic_batch_size: bool,
    /// Try the early-exit pipeline first before a full computation.
    pub progressive_inference: bool,
    /// Enable uncertainty estimation on outputs.
    pub uncertainty_estimation: bool,
    /// Enable precision calibration.
    pub calibration_enabled: bool,
}
impl Default for AdaptiveInferenceConfig {
    /// Balanced defaults: mixed precision, dynamic depth, and all
    /// adaptive features enabled. The test module pins several of these
    /// values, so change them in tandem with the tests.
    fn default() -> Self {
        Self {
            precision_mode: PrecisionMode::Mixed,
            conditional_strategy: ConditionalStrategy::DynamicDepth,
            resource_strategy: ResourceStrategy::Balanced,
            early_exit_config: EarlyExitConfig::default(),
            quality_threshold: 0.8,
            latency_budget_ms: 100,
            memory_budget_mb: 2048,
            energy_budget_watts: 50.0,
            adaptive_precision_threshold: 0.7,
            skip_probability_threshold: 0.3,
            // Feature flags: everything adaptive is on by default.
            dynamic_batch_size: true,
            progressive_inference: true,
            uncertainty_estimation: true,
            calibration_enabled: true,
        }
    }
}
/// Output of one adaptive inference run, bundling the prediction with
/// bookkeeping about how it was produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveInferenceResult {
    /// The model output itself.
    pub prediction: PipelineOutput,
    /// Present when the early-exit pipeline produced a result.
    pub early_exit_result: Option<EarlyExitResult>,
    /// Global precision mode that was active for this run.
    pub precision_used: PrecisionMode,
    /// Layers actually executed (0 when no early exit occurred — see
    /// `execute_adaptive_inference`, which only sets it on early exit).
    pub layers_computed: usize,
    /// Layers skipped relative to the assumed 24-layer model.
    pub layers_skipped: usize,
    /// Count of conditional computations. NOTE(review): never incremented
    /// in this file; always 0.
    pub conditional_computations: usize,
    /// Wall-clock time for the run, in milliseconds.
    pub total_computation_time_ms: u64,
    /// Peak memory observed, in megabytes.
    pub memory_peak_mb: f64,
    /// Energy consumed, in watts (as reported by `ResourceMonitor`).
    pub energy_consumed_watts: f32,
    /// Heuristic quality estimate of the prediction.
    pub quality_score: f32,
    /// Heuristic uncertainty estimate (normalized entropy for classification).
    pub uncertainty_score: f32,
    /// Composite time/memory/energy efficiency score.
    pub resource_efficiency: f32,
    /// Quality divided by normalized latency — higher is better.
    pub latency_vs_quality_tradeoff: f32,
    /// Log of per-layer adaptation decisions made during the run.
    pub adaptation_decisions: Vec<AdaptationDecision>,
    /// Aggregated performance metrics.
    pub performance_metrics: PerformanceMetrics,
}
/// One recorded adaptation decision (precision choice, layer skip, ...).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptationDecision {
    /// Layer the decision applies to.
    pub layer_index: usize,
    /// Kind of decision, e.g. "layer_skip" or "precision_selection".
    pub decision_type: String,
    /// Human-readable explanation of why the decision was made.
    pub reason: String,
    /// Confidence in the decision (0.0..=1.0 by convention).
    pub confidence: f32,
    /// Estimated resource saving/cost of the decision.
    pub resource_impact: f32,
    /// Estimated quality cost of the decision.
    pub quality_impact: f32,
}
/// Aggregated performance summary for a run.
///
/// NOTE(review): in `calculate_performance_metrics` the percentiles are
/// synthesized from a single measurement and `quality_preservation` /
/// `speedup_factor` are fixed placeholders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    /// Throughput derived as 1000 / latency_ms.
    pub throughput_tokens_per_second: f32,
    /// Latency percentiles keyed "p50"/"p90"/"p99", in milliseconds.
    pub latency_percentiles: HashMap<String, f64>,
    /// Inverse-memory efficiency score in (0, 1].
    pub memory_efficiency: f32,
    /// Inverse-energy efficiency score in (0, 1].
    pub energy_efficiency: f32,
    /// Fraction of quality preserved vs. a full run.
    pub quality_preservation: f32,
    /// Speedup relative to a full run.
    pub speedup_factor: f32,
}
/// Per-layer analysis results.
///
/// NOTE(review): this type is declared but never constructed or consumed
/// anywhere visible in this file — confirm it is used elsewhere or remove.
#[derive(Debug, Clone)]
pub struct LayerAnalysis {
    /// Index of the analyzed layer.
    pub layer_index: usize,
    /// How important the layer is to output quality.
    pub importance_score: f32,
    /// How computationally complex the layer is.
    pub complexity_score: f32,
    /// Probability that this layer can be skipped.
    pub skip_probability: f32,
    /// Minimum precision the layer needs.
    pub precision_requirement: PrecisionMode,
    /// Memory footprint of the layer (units unspecified here — presumably MB).
    pub memory_footprint: f64,
    /// Relative computation cost.
    pub computation_cost: f32,
    /// Contribution of the layer to final quality.
    pub quality_contribution: f32,
}
/// Input characterization used to drive global adaptation decisions.
///
/// Currently produced by `InputAnalyzer::analyze_input`, which returns
/// fixed placeholder values regardless of the input.
#[derive(Debug, Clone)]
pub struct InputAnalysis {
    /// Token/sequence length of the input.
    pub sequence_length: usize,
    /// Complexity estimate in [0, 1]; drives precision selection.
    pub complexity_score: f32,
    /// Difficulty estimate in [0, 1]; drives dynamic depth.
    pub difficulty_estimate: f32,
    /// Per-head (presumably) attention pattern summary — TODO confirm semantics.
    pub attention_patterns: Vec<f32>,
    /// Per-token importance weights.
    pub token_importance: Vec<f32>,
    /// Estimated total computation cost for this input.
    pub estimated_computation_cost: f32,
    /// Precision the analyzer recommends.
    pub recommended_precision: PrecisionMode,
    /// Network depth the analyzer recommends (<= 24 per the tests).
    pub recommended_depth: usize,
}
/// Engine that wraps a base pipeline `P` and adapts precision, layer
/// skipping, resource budgets, and early exit on a per-input basis.
#[derive(Clone)]
pub struct AdaptiveInferenceEngine<P> {
    /// Pipeline used for full (non-early-exit) computation.
    base_pipeline: P,
    /// Early-exit wrapper built around a clone of `base_pipeline`.
    early_exit_pipeline: EarlyExitPipeline<P>,
    /// Active configuration; some strategies mutate it at runtime.
    config: AdaptiveInferenceConfig,
    /// Per-layer analyzers. NOTE(review): never populated in this file.
    layer_analyzers: Vec<LayerAnalyzer>,
    /// Produces `InputAnalysis` for incoming requests.
    input_analyzer: InputAnalyzer,
    /// Tracks resource usage and remaining budgets.
    resource_monitor: ResourceMonitor,
    /// Holds global and per-layer precision choices.
    precision_controller: PrecisionController,
    /// Holds per-layer skip probabilities and decisions.
    conditional_controller: ConditionalController,
    /// Accumulates timing/quality history across runs.
    performance_tracker: PerformanceTracker,
    /// Decision log. NOTE(review): never appended to in this file;
    /// per-run decisions live in `AdaptiveInferenceResult` instead.
    adaptation_history: Vec<AdaptationDecision>,
}
/// Per-layer analyzer bundling the models used to score one layer.
///
/// NOTE(review): instances are never created in this file
/// (`layer_analyzers` stays empty) — confirm intended usage.
#[derive(Clone)]
pub struct LayerAnalyzer {
    layer_index: usize,
    importance_weights: Vec<f32>,
    complexity_model: ComplexityModel,
    skip_predictor: SkipPredictor,
    precision_selector: PrecisionSelector,
}
/// Analyzer producing an `InputAnalysis` for each request.
///
/// All four sub-components are currently unit-struct placeholders; the
/// analysis itself returns fixed values (see `analyze_input`).
#[derive(Clone)]
pub struct InputAnalyzer {
    complexity_estimator: ComplexityEstimator,
    attention_pattern_analyzer: AttentionPatternAnalyzer,
    token_importance_ranker: TokenImportanceRanker,
    difficulty_predictor: DifficultyPredictor,
}
/// Snapshot of resource usage and remaining budgets.
///
/// NOTE(review): values are initialized to constants in `new()` and are
/// only ever overwritten by `apply_custom_resource_allocation` — there is
/// no live sampling in this file.
#[derive(Clone)]
pub struct ResourceMonitor {
    cpu_usage: f32,
    /// Current memory usage, in MB.
    memory_usage: f64,
    gpu_utilization: f32,
    /// Current draw, in watts.
    energy_consumption: f32,
    temperature: f32,
    bandwidth_usage: f32,
    /// Remaining latency budget, in ms.
    latency_budget_remaining: u64,
    /// Remaining memory budget, in MB; skip probabilities rise below 512.
    memory_budget_remaining: u64,
    /// Remaining energy budget, in watts.
    energy_budget_remaining: f32,
}
/// Tracks the globally active precision plus per-layer overrides.
#[derive(Clone)]
pub struct PrecisionController {
    /// Precision currently applied engine-wide.
    current_precision: PrecisionMode,
    /// Per-layer precision overrides, keyed by layer index.
    layer_precisions: HashMap<usize, PrecisionMode>,
    /// History of (layer, precision, score) decisions.
    precision_history: Vec<(usize, PrecisionMode, f32)>,
    /// Calibration scores per precision mode. NOTE(review): usable as a
    /// map key only if `PrecisionMode` derives `Eq + Hash`.
    calibration_data: HashMap<PrecisionMode, f32>,
}
/// Tracks per-layer skip probabilities and decisions.
#[derive(Clone)]
pub struct ConditionalController {
    /// Final skip decision per layer.
    skip_decisions: HashMap<usize, bool>,
    /// Skip probability per layer, filled by `adapt_conditional_strategy`.
    conditional_probabilities: HashMap<usize, f32>,
    /// Recorded activation patterns per layer.
    activation_patterns: HashMap<usize, Vec<f32>>,
    /// History of (layer, skipped, probability) tuples.
    skip_history: Vec<(usize, bool, f32)>,
}
/// Accumulates timing, memory, energy, and quality history across runs.
#[derive(Clone)]
pub struct PerformanceTracker {
    /// Start of the current run; reset in `adaptive_inference`.
    start_time: Instant,
    /// Per-layer timings. NOTE(review): never pushed to in this file.
    layer_times: Vec<Duration>,
    /// Peak-memory samples, one per completed run.
    memory_snapshots: Vec<f64>,
    /// Energy samples, one per completed run.
    energy_snapshots: Vec<f32>,
    /// Quality scores, one per completed run.
    quality_scores: Vec<f32>,
    /// Throughput samples, one per completed run.
    throughput_history: Vec<f32>,
}
// Placeholder unit structs for the analyzer/predictor components.
// They carry no state or behavior yet; they exist so the surrounding
// structs have stable field types while the real models are developed.
#[derive(Clone)]
pub struct ComplexityModel;
#[derive(Clone)]
pub struct SkipPredictor;
#[derive(Clone)]
pub struct PrecisionSelector;
#[derive(Clone)]
pub struct ComplexityEstimator;
#[derive(Clone)]
pub struct AttentionPatternAnalyzer;
#[derive(Clone)]
pub struct TokenImportanceRanker;
#[derive(Clone)]
pub struct DifficultyPredictor;
impl<P> AdaptiveInferenceEngine<P>
where
P: Pipeline<Output = PipelineOutput> + Clone,
{
/// Builds an engine around `base_pipeline`.
///
/// The early-exit wrapper receives a clone of the base pipeline and a
/// clone of the early-exit config; all controllers/monitors start from
/// their default state.
pub fn new(base_pipeline: P, config: AdaptiveInferenceConfig) -> Self {
    let early_exit_pipeline =
        EarlyExitPipeline::new(base_pipeline.clone(), config.early_exit_config.clone());
    Self {
        base_pipeline,
        early_exit_pipeline,
        config,
        layer_analyzers: Vec::new(),
        input_analyzer: InputAnalyzer::new(),
        resource_monitor: ResourceMonitor::new(),
        precision_controller: PrecisionController::new(),
        conditional_controller: ConditionalController::new(),
        performance_tracker: PerformanceTracker::new(),
        adaptation_history: Vec::new(),
    }
}
/// Runs one adaptive inference: analyze the input, adapt global
/// strategies, execute, and record the run's metrics.
///
/// # Errors
/// Propagates any error from input analysis, adaptation, or execution.
pub fn adaptive_inference(&mut self, input: P::Input) -> Result<AdaptiveInferenceResult>
where
    P::Input: Clone,
{
    let start_time = Instant::now();
    self.performance_tracker.start_time = start_time;
    // 1. Characterize the input (currently placeholder values).
    let input_analysis = self.input_analyzer.analyze_input(&input)?;
    // 2. Adjust precision / skipping / budgets / early-exit thresholds.
    self.make_global_adaptations(&input_analysis)?;
    // 3. Execute with those adaptations in effect.
    let result = self.execute_adaptive_inference(input, &input_analysis)?;
    // 4. Fold the run's results into the tracker history.
    self.performance_tracker.update_final_metrics(&result);
    Ok(result)
}
/// Applies all global adaptation passes, in order: precision, layer
/// skipping, resource allocation, then early-exit thresholds.
fn make_global_adaptations(&mut self, input_analysis: &InputAnalysis) -> Result<()> {
    self.adapt_precision_strategy(input_analysis)?;
    self.adapt_conditional_strategy(input_analysis)?;
    self.adapt_resource_allocation(input_analysis)?;
    self.adapt_early_exit_thresholds(input_analysis)?;
    Ok(())
}
/// Selects a global precision from the input's complexity score and
/// precomputes a precision for each of the 24 assumed model layers.
fn adapt_precision_strategy(&mut self, input_analysis: &InputAnalysis) -> Result<()> {
    let score = input_analysis.complexity_score;
    // Higher complexity => higher precision; thresholds at 0.8 / 0.6 / 0.4.
    self.precision_controller.current_precision = match score {
        s if s > 0.8 => PrecisionMode::Full,
        s if s > 0.6 => PrecisionMode::Mixed,
        s if s > 0.4 => PrecisionMode::Half,
        _ => PrecisionMode::Int8,
    };
    // Cache a per-layer precision choice for later lookup.
    for layer_idx in 0..24 {
        let layer_precision = self.determine_layer_precision(layer_idx, input_analysis)?;
        self.precision_controller
            .layer_precisions
            .insert(layer_idx, layer_precision);
    }
    Ok(())
}
/// Chooses a precision for one layer: the first 6 layers always run
/// quantized, the middle band (6..18) scales with input complexity, and
/// the final layers receive the highest precision budget.
fn determine_layer_precision(
    &self,
    layer_idx: usize,
    input_analysis: &InputAnalysis,
) -> Result<PrecisionMode> {
    let complexity = input_analysis.complexity_score;
    let precision = match layer_idx {
        // Early layers: cheap quantized compute is always sufficient.
        0..=5 => PrecisionMode::Int8,
        // Middle layers: spend more only on complex inputs.
        6..=17 => {
            if complexity > 0.7 {
                PrecisionMode::Half
            } else {
                PrecisionMode::Int8
            }
        },
        // Final layers: most precision-sensitive.
        _ => {
            if complexity > 0.8 {
                PrecisionMode::Full
            } else {
                PrecisionMode::Mixed
            }
        },
    };
    Ok(precision)
}
/// Computes and caches a skip probability for each of the 24 layers.
fn adapt_conditional_strategy(&mut self, input_analysis: &InputAnalysis) -> Result<()> {
    for layer_idx in 0..24 {
        let skip_prob = self.calculate_skip_probability(layer_idx, input_analysis)?;
        self.conditional_controller
            .conditional_probabilities
            .insert(layer_idx, skip_prob);
    }
    Ok(())
}
/// Skip probability for one layer under the configured strategy.
///
/// Strategies not matched below (SparseActivation, TokenConditional,
/// LayerConditional, and guard misses) fall through to 0.0. The result
/// is inflated by 1.5x under memory pressure and capped at 0.9 so no
/// layer is ever skipped unconditionally.
fn calculate_skip_probability(
    &self,
    layer_idx: usize,
    input_analysis: &InputAnalysis,
) -> Result<f32> {
    let base_skip_prob = match self.config.conditional_strategy {
        // Attention skipping targets even layers on simple inputs.
        ConditionalStrategy::AttentionSkipping
            if layer_idx.is_multiple_of(2) && input_analysis.complexity_score < 0.5 =>
        {
            0.3
        },
        // Feed-forward skipping targets odd layers.
        ConditionalStrategy::FeedForwardSkipping
            if !layer_idx.is_multiple_of(2) && input_analysis.complexity_score < 0.6 =>
        {
            0.4
        },
        // Whole-block skipping only for very simple inputs.
        ConditionalStrategy::BlockSkipping
            if input_analysis.complexity_score < 0.3 =>
        {
            0.2
        },
        // Dynamic depth: layers past the difficulty-scaled target depth
        // (difficulty * 24 layers) are very likely to be skipped.
        ConditionalStrategy::DynamicDepth => {
            let target_depth = (input_analysis.difficulty_estimate * 24.0) as usize;
            if layer_idx > target_depth {
                0.8
            } else {
                0.0
            }
        },
        _ => 0.0,
    };
    // Under memory pressure (< 512 MB remaining), skip more aggressively.
    let resource_factor = if self.resource_monitor.memory_budget_remaining < 512 {
        1.5
    } else {
        1.0
    };
    Ok((base_skip_prob as f32 * resource_factor as f32).min(0.9f32))
}
/// Applies the configured resource strategy by steering the global
/// precision and adaptation thresholds.
///
/// The parameter is currently unused (underscored to silence the
/// compiler warning) but kept so input-aware allocation can be added
/// without changing the signature.
///
/// NOTE(review): the non-`Custom` arms mutate `self.config` in place, so
/// these overrides persist across subsequent inferences — confirm that
/// drift is intended.
fn adapt_resource_allocation(&mut self, _input_analysis: &InputAnalysis) -> Result<()> {
    // Clone so the Custom payload can be used while `self` is mutably borrowed.
    let strategy = self.config.resource_strategy.clone();
    match strategy {
        ResourceStrategy::MinLatency => {
            self.precision_controller.current_precision = PrecisionMode::Half;
            self.config.quality_threshold = 0.6;
        },
        ResourceStrategy::MinMemory => {
            self.precision_controller.current_precision = PrecisionMode::Int8;
            self.config.skip_probability_threshold = 0.5;
        },
        ResourceStrategy::MinEnergy => {
            self.precision_controller.current_precision = PrecisionMode::Int4;
            self.config.skip_probability_threshold = 0.4;
        },
        ResourceStrategy::Balanced => {
            self.precision_controller.current_precision = PrecisionMode::Mixed;
            self.config.quality_threshold = 0.75;
        },
        ResourceStrategy::MaxThroughput => {
            self.precision_controller.current_precision = PrecisionMode::Half;
            self.config.dynamic_batch_size = true;
        },
        ResourceStrategy::Custom(allocation) => {
            self.apply_custom_resource_allocation(&allocation)?;
        },
    }
    Ok(())
}
/// Copies an explicit `ResourceAllocation` into the monitor's remaining
/// budgets and the config's quality threshold. `cpu_cores` and
/// `gpu_memory_limit_mb` are currently not consumed here.
fn apply_custom_resource_allocation(&mut self, allocation: &ResourceAllocation) -> Result<()> {
    self.resource_monitor.memory_budget_remaining = allocation.memory_limit_mb;
    self.resource_monitor.energy_budget_remaining = allocation.energy_budget_watts;
    self.resource_monitor.latency_budget_remaining = allocation.latency_budget_ms;
    self.config.quality_threshold = allocation.quality_threshold;
    Ok(())
}
/// Tunes the early-exit confidence threshold to the input: simple inputs
/// (< 0.3) may exit at 0.7 confidence, hard inputs (> 0.8) require 0.95.
/// Mid-range complexity leaves the current strategy untouched.
fn adapt_early_exit_thresholds(&mut self, input_analysis: &InputAnalysis) -> Result<()> {
    if input_analysis.complexity_score < 0.3 {
        self.early_exit_pipeline.exit_predictor_mut().config_mut().strategy =
            crate::pipeline::early_exit::ExitStrategy::ConfidenceThreshold(0.7);
    } else if input_analysis.complexity_score > 0.8 {
        self.early_exit_pipeline.exit_predictor_mut().config_mut().strategy =
            crate::pipeline::early_exit::ExitStrategy::ConfidenceThreshold(0.95);
    }
    Ok(())
}
/// Executes one inference, preferring an early-exit result that meets
/// the quality threshold and otherwise falling back to a full adaptive
/// computation, then assembles the result record.
///
/// # Errors
/// Propagates errors from the full computation and metric calculations;
/// early-exit failures are swallowed (`.ok()`) and trigger the fallback.
fn execute_adaptive_inference(
    &mut self,
    input: P::Input,
    input_analysis: &InputAnalysis,
) -> Result<AdaptiveInferenceResult>
where
    P::Input: Clone,
{
    let start_time = Instant::now();
    let mut adaptation_decisions = Vec::new();
    let mut layers_computed = 0;
    let mut layers_skipped = 0;
    // NOTE(review): never incremented anywhere yet; reported as 0.
    let conditional_computations = 0;
    // Progressive mode: try the early-exit pipeline first.
    let early_exit_result = if self.config.progressive_inference {
        self.early_exit_pipeline.__call__(input.clone()).ok()
    } else {
        None
    };
    let prediction = if let Some(ref early_result) = early_exit_result {
        if early_result.confidence_score >= self.config.quality_threshold {
            // Accept the early-exit prediction as-is.
            layers_computed = early_result.total_layers_computed;
            // saturating_sub: guards against an early-exit pipeline that
            // reports more than the assumed 24 layers, which would
            // otherwise underflow (panic in debug builds).
            layers_skipped = 24usize.saturating_sub(layers_computed);
            early_result.prediction.clone()
        } else {
            // Early exit was not confident enough: run the full path.
            self.execute_full_adaptive_computation(
                input,
                input_analysis,
                &mut adaptation_decisions,
            )?
        }
    } else {
        self.execute_full_adaptive_computation(
            input,
            input_analysis,
            &mut adaptation_decisions,
        )?
    };
    let total_time = start_time.elapsed().as_millis() as u64;
    // Resource figures come from the monitor's current (static) snapshot.
    let memory_peak = self.resource_monitor.memory_usage;
    let energy_consumed = self.resource_monitor.energy_consumption;
    let quality_score = self.estimate_quality_score(&prediction, &early_exit_result)?;
    let uncertainty_score = self.estimate_uncertainty_score(&prediction)?;
    let resource_efficiency =
        self.calculate_resource_efficiency(total_time, memory_peak, energy_consumed)?;
    let latency_vs_quality_tradeoff =
        self.calculate_latency_quality_tradeoff(total_time, quality_score)?;
    Ok(AdaptiveInferenceResult {
        prediction,
        early_exit_result,
        precision_used: self.precision_controller.current_precision.clone(),
        layers_computed,
        layers_skipped,
        conditional_computations,
        total_computation_time_ms: total_time,
        memory_peak_mb: memory_peak,
        energy_consumed_watts: energy_consumed,
        quality_score,
        uncertainty_score,
        resource_efficiency,
        latency_vs_quality_tradeoff,
        adaptation_decisions,
        performance_metrics: self.calculate_performance_metrics(
            total_time,
            memory_peak,
            energy_consumed,
        )?,
    })
}
/// Runs the base pipeline while logging the adaptation decisions that
/// would apply per layer.
///
/// NOTE(review): the skip/precision decisions are only *recorded* here;
/// the base pipeline is still invoked once, in full, at the end — actual
/// layer skipping / per-layer precision is not wired into execution yet.
fn execute_full_adaptive_computation(
    &mut self,
    input: P::Input,
    input_analysis: &InputAnalysis,
    adaptation_decisions: &mut Vec<AdaptationDecision>,
) -> Result<PipelineOutput>
where
    P::Input: Clone,
{
    // Record the global precision choice made during adaptation.
    adaptation_decisions.push(AdaptationDecision {
        layer_index: 0,
        decision_type: "precision_adaptation".to_string(),
        reason: format!(
            "Adapted to {:?} based on complexity score {:.2}",
            self.precision_controller.current_precision, input_analysis.complexity_score
        ),
        confidence: 0.9,
        resource_impact: 0.2,
        quality_impact: 0.1,
    });
    // Record a per-layer decision: skip when the cached skip probability
    // exceeds the configured threshold, otherwise log the precision used.
    for layer_idx in 0..24 {
        let skip_prob = self
            .conditional_controller
            .conditional_probabilities
            .get(&layer_idx)
            .unwrap_or(&0.0);
        if *skip_prob > self.config.skip_probability_threshold {
            adaptation_decisions.push(AdaptationDecision {
                layer_index: layer_idx,
                decision_type: "layer_skip".to_string(),
                reason: format!(
                    "Skipped layer {} with probability {:.2}",
                    layer_idx, skip_prob
                ),
                confidence: *skip_prob,
                resource_impact: 0.3,
                quality_impact: 0.05,
            });
        } else {
            let precision = self
                .precision_controller
                .layer_precisions
                .get(&layer_idx)
                .unwrap_or(&PrecisionMode::Mixed);
            adaptation_decisions.push(AdaptationDecision {
                layer_index: layer_idx,
                decision_type: "precision_selection".to_string(),
                reason: format!("Used {:?} precision for layer {}", precision, layer_idx),
                confidence: 0.8,
                resource_impact: 0.1,
                quality_impact: 0.02,
            });
        }
    }
    // Single full invocation of the wrapped pipeline.
    self.base_pipeline.__call__(input).map_err(Into::into)
}
/// Heuristic quality estimate for a prediction.
///
/// Prefers the early-exit pipeline's own quality score when available;
/// otherwise derives a proxy from the output variant.
fn estimate_quality_score(
    &self,
    prediction: &PipelineOutput,
    early_exit_result: &Option<EarlyExitResult>,
) -> Result<f32> {
    if let Some(early_result) = early_exit_result {
        return Ok(early_result.quality_score);
    }
    let score = match prediction {
        // Top-1 classification score; 0.0 for an empty result set.
        PipelineOutput::Classification(results) => {
            results.first().map(|r| r.score).unwrap_or(0.0)
        },
        PipelineOutput::QuestionAnswering(result) => result.score,
        // Longer generations score higher: length capped at 100 chars,
        // scaled by a fixed 0.8 ceiling.
        PipelineOutput::Generation(result) => {
            let length_factor = (result.generated_text.len() as f32 / 100.0).min(1.0);
            length_factor * 0.8
        },
        // Other output kinds get a fixed optimistic default.
        _ => 0.8,
    };
    Ok(score)
}
/// Estimates predictive uncertainty as normalized Shannon entropy over
/// classification scores; non-classification outputs get a fixed 0.3.
fn estimate_uncertainty_score(&self, prediction: &PipelineOutput) -> Result<f32> {
    let PipelineOutput::Classification(results) = prediction else {
        return Ok(0.3);
    };
    // Degenerate cases: fewer than two classes yields a neutral 0.5;
    // an all-zero score vector is treated as maximal uncertainty.
    if results.len() < 2 {
        return Ok(0.5);
    }
    let total: f32 = results.iter().map(|r| r.score).sum();
    if total == 0.0 {
        return Ok(1.0);
    }
    // Shannon entropy of the normalized score distribution; zero-probability
    // terms contribute nothing, so they are filtered out.
    let entropy: f32 = results
        .iter()
        .map(|r| r.score / total)
        .filter(|&p| p > 0.0)
        .map(|p| -p * p.ln())
        .sum();
    // Normalize by the maximum possible entropy, ln(n), to land in [0, 1].
    Ok(entropy / (results.len() as f32).ln())
}
/// Composite efficiency score: the mean of three inverse-cost factors
/// (time in seconds, memory in GiB, energy in watts), each mapped into
/// (0, 1] via 1 / (x + 1). Higher is better.
fn calculate_resource_efficiency(
    &self,
    time_ms: u64,
    memory_mb: f64,
    energy_watts: f32,
) -> Result<f32> {
    let factors = [
        1.0 / (time_ms as f32 / 1000.0 + 1.0),
        1.0 / (memory_mb as f32 / 1024.0 + 1.0),
        1.0 / (energy_watts + 1.0),
    ];
    Ok(factors.iter().sum::<f32>() / 3.0)
}
/// Quality-per-latency score: quality divided by (1 + latency as a
/// fraction of the configured budget). Higher quality or lower latency
/// both improve the score.
fn calculate_latency_quality_tradeoff(&self, time_ms: u64, quality_score: f32) -> Result<f32> {
    let normalized_latency = time_ms as f32 / self.config.latency_budget_ms as f32;
    Ok(quality_score / (normalized_latency + 1.0))
}
/// Assembles summary performance metrics for one run.
///
/// The percentiles are synthetic estimates derived from the single
/// measured latency (p90 = 1.2x, p99 = 1.5x); `quality_preservation`
/// and `speedup_factor` are fixed placeholders pending real measurement.
fn calculate_performance_metrics(
    &self,
    time_ms: u64,
    memory_mb: f64,
    energy_watts: f32,
) -> Result<PerformanceMetrics> {
    let mut latency_percentiles = HashMap::new();
    latency_percentiles.insert("p50".to_string(), time_ms as f64);
    latency_percentiles.insert("p90".to_string(), time_ms as f64 * 1.2);
    latency_percentiles.insert("p99".to_string(), time_ms as f64 * 1.5);
    // Clamp the divisor to >= 1 ms so a sub-millisecond run does not
    // divide by zero and report infinite throughput.
    let throughput_tokens_per_second = 1000.0 / (time_ms.max(1) as f32);
    Ok(PerformanceMetrics {
        throughput_tokens_per_second,
        latency_percentiles,
        memory_efficiency: 1.0 / (memory_mb as f32 / 1024.0 + 1.0),
        energy_efficiency: 1.0 / (energy_watts + 1.0),
        // Placeholder constants.
        quality_preservation: 0.9,
        speedup_factor: 2.0,
    })
}
}
impl InputAnalyzer {
    /// Creates an analyzer from the (currently stateless) placeholder
    /// sub-components.
    fn new() -> Self {
        Self {
            complexity_estimator: ComplexityEstimator,
            attention_pattern_analyzer: AttentionPatternAnalyzer,
            token_importance_ranker: TokenImportanceRanker,
            difficulty_predictor: DifficultyPredictor,
        }
    }
    /// Produces sizing/difficulty estimates for an input.
    ///
    /// NOTE(review): the input is currently ignored and fixed placeholder
    /// values are returned. The parameter is underscored to silence the
    /// unused-variable warning while keeping the signature stable for
    /// when real analysis is implemented.
    fn analyze_input<T>(&self, _input: &T) -> Result<InputAnalysis> {
        Ok(InputAnalysis {
            sequence_length: 256,
            complexity_score: 0.6,
            difficulty_estimate: 0.7,
            attention_patterns: vec![0.5; 12],
            token_importance: vec![0.3; 256],
            estimated_computation_cost: 100.0,
            recommended_precision: PrecisionMode::Mixed,
            recommended_depth: 18,
        })
    }
}
impl ResourceMonitor {
    /// Initial snapshot with fixed starting values and the default
    /// budgets (100 ms / 2048 MB / 50 W). These mirror the engine's
    /// default config but are not derived from it.
    fn new() -> Self {
        Self {
            cpu_usage: 0.0,
            memory_usage: 512.0,
            gpu_utilization: 0.0,
            energy_consumption: 10.0,
            temperature: 45.0,
            bandwidth_usage: 0.0,
            latency_budget_remaining: 100,
            memory_budget_remaining: 2048,
            energy_budget_remaining: 50.0,
        }
    }
}
impl PrecisionController {
    /// Starts in `Mixed` precision with no per-layer overrides and empty
    /// history/calibration tables.
    fn new() -> Self {
        Self {
            current_precision: PrecisionMode::Mixed,
            layer_precisions: HashMap::new(),
            precision_history: Vec::new(),
            calibration_data: HashMap::new(),
        }
    }
}
impl ConditionalController {
    /// Starts with all decision/probability/pattern tables empty.
    fn new() -> Self {
        Self {
            skip_decisions: HashMap::new(),
            conditional_probabilities: HashMap::new(),
            activation_patterns: HashMap::new(),
            skip_history: Vec::new(),
        }
    }
}
impl PerformanceTracker {
    /// Fresh tracker; `start_time` is re-stamped at the start of every
    /// `adaptive_inference` call.
    fn new() -> Self {
        Self {
            start_time: Instant::now(),
            layer_times: Vec::new(),
            memory_snapshots: Vec::new(),
            energy_snapshots: Vec::new(),
            quality_scores: Vec::new(),
            throughput_history: Vec::new(),
        }
    }
    /// Appends one completed run's quality, throughput, memory, and
    /// energy figures to the rolling history.
    fn update_final_metrics(&mut self, result: &AdaptiveInferenceResult) {
        self.quality_scores.push(result.quality_score);
        self.throughput_history
            .push(result.performance_metrics.throughput_tokens_per_second);
        self.memory_snapshots.push(result.memory_peak_mb);
        self.energy_snapshots.push(result.energy_consumed_watts);
    }
}
/// Thin convenience wrapper around `AdaptiveInferenceEngine::new`.
pub fn create_adaptive_inference_pipeline<P>(
    base_pipeline: P,
    config: AdaptiveInferenceConfig,
) -> AdaptiveInferenceEngine<P>
where
    P: Pipeline<Output = PipelineOutput> + Clone,
{
    AdaptiveInferenceEngine::new(base_pipeline, config)
}
pub fn create_latency_optimized_pipeline<P>(
base_pipeline: P,
latency_budget_ms: u64,
) -> AdaptiveInferenceEngine<P>
where
P: Pipeline<Output = PipelineOutput> + Clone,
{
let mut config = AdaptiveInferenceConfig::default();
config.resource_strategy = ResourceStrategy::MinLatency;
config.latency_budget_ms = latency_budget_ms;
config.precision_mode = PrecisionMode::Half;
config.conditional_strategy = ConditionalStrategy::DynamicDepth;
AdaptiveInferenceEngine::new(base_pipeline, config)
}
pub fn create_memory_efficient_pipeline<P>(
base_pipeline: P,
memory_budget_mb: u64,
) -> AdaptiveInferenceEngine<P>
where
P: Pipeline<Output = PipelineOutput> + Clone,
{
let mut config = AdaptiveInferenceConfig::default();
config.resource_strategy = ResourceStrategy::MinMemory;
config.memory_budget_mb = memory_budget_mb;
config.precision_mode = PrecisionMode::Int8;
config.conditional_strategy = ConditionalStrategy::BlockSkipping;
AdaptiveInferenceEngine::new(base_pipeline, config)
}
pub fn create_energy_efficient_pipeline<P>(
base_pipeline: P,
energy_budget_watts: f32,
) -> AdaptiveInferenceEngine<P>
where
P: Pipeline<Output = PipelineOutput> + Clone,
{
let mut config = AdaptiveInferenceConfig::default();
config.resource_strategy = ResourceStrategy::MinEnergy;
config.energy_budget_watts = energy_budget_watts;
config.precision_mode = PrecisionMode::Int4;
config.conditional_strategy = ConditionalStrategy::AttentionSkipping;
AdaptiveInferenceEngine::new(base_pipeline, config)
}
pub fn create_balanced_adaptive_pipeline<P>(
base_pipeline: P,
quality_threshold: f32,
) -> AdaptiveInferenceEngine<P>
where
P: Pipeline<Output = PipelineOutput> + Clone,
{
let mut config = AdaptiveInferenceConfig::default();
config.resource_strategy = ResourceStrategy::Balanced;
config.quality_threshold = quality_threshold;
config.precision_mode = PrecisionMode::Adaptive;
config.conditional_strategy = ConditionalStrategy::DynamicDepth;
config.progressive_inference = true;
config.uncertainty_estimation = true;
AdaptiveInferenceEngine::new(base_pipeline, config)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the adaptive-inference configuration types and the
    //! heuristic formulas. Note that the formula tests (resource
    //! efficiency, latency/quality tradeoff) re-implement the arithmetic
    //! locally rather than calling the private engine methods, so they
    //! must be kept in sync with the implementations by hand.
    use super::*;
    // --- Default configuration ---------------------------------------
    #[test]
    fn test_adaptive_inference_config_default() {
        let config = AdaptiveInferenceConfig::default();
        assert!(matches!(config.precision_mode, PrecisionMode::Mixed));
        assert!(matches!(
            config.conditional_strategy,
            ConditionalStrategy::DynamicDepth
        ));
        assert!(matches!(
            config.resource_strategy,
            ResourceStrategy::Balanced
        ));
        assert_eq!(config.quality_threshold, 0.8);
    }
    #[test]
    fn test_precision_mode_selection() {
        let config = AdaptiveInferenceConfig::default();
        assert!(matches!(config.precision_mode, PrecisionMode::Mixed));
    }
    #[test]
    fn test_resource_allocation() {
        let allocation = ResourceAllocation {
            cpu_cores: 4,
            memory_limit_mb: 1024,
            gpu_memory_limit_mb: 2048,
            energy_budget_watts: 25.0,
            latency_budget_ms: 50,
            quality_threshold: 0.9,
        };
        assert_eq!(allocation.cpu_cores, 4);
        assert_eq!(allocation.memory_limit_mb, 1024);
        assert_eq!(allocation.quality_threshold, 0.9);
    }
    // --- Input analysis (placeholder implementation) -------------------
    #[test]
    fn test_input_analysis() {
        let analyzer = InputAnalyzer::new();
        let input = "test input";
        let analysis = analyzer.analyze_input(&input).expect("operation failed in test");
        assert_eq!(analysis.sequence_length, 256);
        assert!(analysis.complexity_score > 0.0);
        assert!(analysis.difficulty_estimate > 0.0);
    }
    #[test]
    fn test_performance_metrics() {
        let metrics = PerformanceMetrics {
            throughput_tokens_per_second: 100.0,
            latency_percentiles: HashMap::new(),
            memory_efficiency: 0.8,
            energy_efficiency: 0.9,
            quality_preservation: 0.95,
            speedup_factor: 2.5,
        };
        assert_eq!(metrics.throughput_tokens_per_second, 100.0);
        assert_eq!(metrics.speedup_factor, 2.5);
    }
    // --- Default-config invariants -------------------------------------
    #[test]
    fn test_config_default_latency_budget_positive() {
        let cfg = AdaptiveInferenceConfig::default();
        assert!(
            cfg.latency_budget_ms > 0,
            "latency_budget_ms must be positive"
        );
    }
    #[test]
    fn test_config_default_memory_budget_positive() {
        let cfg = AdaptiveInferenceConfig::default();
        assert!(
            cfg.memory_budget_mb > 0,
            "memory_budget_mb must be positive"
        );
    }
    #[test]
    fn test_config_default_energy_budget_positive() {
        let cfg = AdaptiveInferenceConfig::default();
        assert!(
            cfg.energy_budget_watts > 0.0,
            "energy_budget_watts must be positive"
        );
    }
    #[test]
    fn test_config_default_quality_threshold_in_range() {
        let cfg = AdaptiveInferenceConfig::default();
        assert!(
            cfg.quality_threshold > 0.0 && cfg.quality_threshold <= 1.0,
            "quality_threshold must be in (0, 1]"
        );
    }
    #[test]
    fn test_config_default_skip_probability_threshold_in_range() {
        let cfg = AdaptiveInferenceConfig::default();
        assert!(
            cfg.skip_probability_threshold >= 0.0 && cfg.skip_probability_threshold <= 1.0,
            "skip_probability_threshold must be in [0, 1]"
        );
    }
    #[test]
    fn test_config_default_adaptive_precision_threshold_in_range() {
        let cfg = AdaptiveInferenceConfig::default();
        assert!(cfg.adaptive_precision_threshold >= 0.0 && cfg.adaptive_precision_threshold <= 1.0);
    }
    #[test]
    fn test_config_flags_enabled_by_default() {
        let cfg = AdaptiveInferenceConfig::default();
        assert!(cfg.dynamic_batch_size);
        assert!(cfg.progressive_inference);
        assert!(cfg.uncertainty_estimation);
        assert!(cfg.calibration_enabled);
    }
    // --- Enum coverage --------------------------------------------------
    #[test]
    fn test_precision_mode_all_variants_constructable() {
        let modes = [
            PrecisionMode::Full,
            PrecisionMode::Half,
            PrecisionMode::Mixed,
            PrecisionMode::Int8,
            PrecisionMode::Int4,
            PrecisionMode::Dynamic,
            PrecisionMode::Adaptive,
        ];
        assert_eq!(modes.len(), 7);
    }
    #[test]
    fn test_conditional_strategy_all_variants_constructable() {
        let strategies = [
            ConditionalStrategy::AttentionSkipping,
            ConditionalStrategy::FeedForwardSkipping,
            ConditionalStrategy::BlockSkipping,
            ConditionalStrategy::DynamicDepth,
            ConditionalStrategy::SparseActivation,
            ConditionalStrategy::TokenConditional,
            ConditionalStrategy::LayerConditional,
        ];
        assert_eq!(strategies.len(), 7);
    }
    #[test]
    fn test_resource_strategy_balanced_is_default() {
        let cfg = AdaptiveInferenceConfig::default();
        assert!(matches!(cfg.resource_strategy, ResourceStrategy::Balanced));
    }
    #[test]
    fn test_adaptation_decision_struct_fields() {
        let decision = AdaptationDecision {
            layer_index: 3,
            decision_type: "layer_skip".to_string(),
            reason: "low complexity".to_string(),
            confidence: 0.95,
            resource_impact: 0.1,
            quality_impact: 0.02,
        };
        assert_eq!(decision.layer_index, 3);
        assert!(decision.confidence > 0.0 && decision.confidence <= 1.0);
        assert!(decision.resource_impact >= 0.0);
        assert!(decision.quality_impact >= 0.0);
    }
    #[test]
    fn test_resource_allocation_custom_fields() {
        let alloc = ResourceAllocation {
            cpu_cores: 8,
            memory_limit_mb: 4096,
            gpu_memory_limit_mb: 8192,
            energy_budget_watts: 100.0,
            latency_budget_ms: 200,
            quality_threshold: 0.95,
        };
        assert_eq!(alloc.cpu_cores, 8);
        assert_eq!(alloc.memory_limit_mb, 4096);
        assert!(alloc.quality_threshold > 0.0 && alloc.quality_threshold <= 1.0);
    }
    #[test]
    fn test_input_analysis_recommended_precision_valid() {
        let analyzer = InputAnalyzer::new();
        let analysis = analyzer.analyze_input(&"sample").expect("analyze_input should succeed");
        let _mode = &analysis.recommended_precision;
        assert!(analysis.estimated_computation_cost > 0.0);
    }
    #[test]
    fn test_input_analysis_recommended_depth_within_model_bounds() {
        let analyzer = InputAnalyzer::new();
        let analysis = analyzer.analyze_input(&"test input").expect("analyze_input should succeed");
        assert!(
            analysis.recommended_depth <= 24,
            "recommended depth must not exceed model depth"
        );
        assert!(
            analysis.recommended_depth > 0,
            "recommended depth must be positive"
        );
    }
    // --- Formula properties (mirrors of private engine methods) ---------
    #[test]
    fn test_performance_metrics_latency_throughput_relationship() {
        let latency_ms = 10.0_f32;
        let expected_throughput = 1000.0 / latency_ms;
        let metrics = PerformanceMetrics {
            throughput_tokens_per_second: expected_throughput,
            latency_percentiles: HashMap::new(),
            memory_efficiency: 0.9,
            energy_efficiency: 0.8,
            quality_preservation: 0.95,
            speedup_factor: 2.0,
        };
        assert!(
            (metrics.throughput_tokens_per_second - 100.0).abs() < 1e-4,
            "throughput should be 1000/latency_ms"
        );
    }
    #[test]
    fn test_resource_efficiency_higher_speed_is_better() {
        let time_fast = 10_u64;
        let time_slow = 1000_u64;
        let mem = 512.0_f64;
        let energy = 50.0_f32;
        let ef_fast = {
            let t = 1.0 / (time_fast as f32 / 1000.0 + 1.0);
            let m = 1.0 / (mem as f32 / 1024.0 + 1.0);
            let e = 1.0 / (energy + 1.0);
            (t + m + e) / 3.0
        };
        let ef_slow = {
            let t = 1.0 / (time_slow as f32 / 1000.0 + 1.0);
            let m = 1.0 / (mem as f32 / 1024.0 + 1.0);
            let e = 1.0 / (energy + 1.0);
            (t + m + e) / 3.0
        };
        assert!(
            ef_fast > ef_slow,
            "faster execution should yield higher resource efficiency"
        );
    }
    #[test]
    fn test_latency_quality_tradeoff_higher_quality_better() {
        let latency_ms = 50_u64;
        let budget_ms = 100_u64;
        let quality_high = 0.95_f32;
        let quality_low = 0.5_f32;
        let tradeoff = |q: f32| {
            let l_norm = latency_ms as f32 / budget_ms as f32;
            q / (l_norm + 1.0)
        };
        assert!(
            tradeoff(quality_high) > tradeoff(quality_low),
            "higher quality should produce better tradeoff score"
        );
    }
}