use crate::error::Result;
use crate::pipeline::{Pipeline, PipelineOutput};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;
/// Criterion the predictor applies to decide whether inference may stop at the
/// current layer.
///
/// Variants are dispatched in `EarlyExitPredictor::evaluate_exit_strategy`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExitStrategy {
    /// Exit once the confidence score is at least this threshold.
    ConfidenceThreshold(f32),
    /// Exit once the entropy-derived score is at least this threshold.
    EntropyThreshold(f32),
    /// Exit once the variance-derived score is at least this threshold.
    VarianceThreshold(f32),
    /// Exit once the consistency score is at least this threshold.
    ConsistencyThreshold(f32),
    /// Exit once a layer's reported computation time (ms) reaches this budget.
    ComputationalBudget(u64),
    /// Exit once the tracked remaining energy budget is at or below this value.
    EnergyBudget(f32),
    /// Exit when a weighted composite of the scores reaches a threshold
    /// derived from the exit history.
    AdaptiveThreshold,
    /// Exit after this many consecutive low-confidence history entries, or
    /// when the current confidence is very high.
    Patience(u32),
    /// Exit when a strict majority of the nested strategies vote to exit.
    Combined(Vec<ExitStrategy>),
    /// Exit when a fixed-weight linear score over the exit-point features
    /// reaches a fixed cut-off.
    LearnedExit,
}
/// Configuration for [`EarlyExitPredictor`].
///
/// Only `strategy`, `min_layers`, `max_layers` and
/// `dynamic_threshold_adjustment` are read by the code visible in this file;
/// the remaining flags are stored but not consulted here.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyExitConfig {
    /// Exit criterion evaluated at every layer.
    pub strategy: ExitStrategy,
    /// Layers with an index below this value never exit.
    pub min_layers: usize,
    /// Layers with an index at or above this value always exit.
    pub max_layers: usize,
    /// Not read in this file.
    // NOTE(review): presumably a default for `ExitStrategy::Patience` — confirm.
    pub patience_threshold: u32,
    /// Not read in this file.
    pub confidence_calibration: bool,
    /// When `false`, strategy thresholds are used exactly as configured.
    pub dynamic_threshold_adjustment: bool,
    /// Not read in this file.
    pub performance_tracking: bool,
    /// Not read in this file.
    pub energy_aware: bool,
    /// Not read in this file.
    pub memory_aware: bool,
    /// Not read in this file.
    pub context_aware: bool,
    /// Not read in this file.
    pub fallback_to_full: bool,
    /// Not read in this file.
    pub exit_point_optimization: bool,
}
impl Default for EarlyExitConfig {
    /// Returns a configuration using a 0.9 confidence threshold over a
    /// 6..12 layer window, with every flag enabled except `energy_aware`.
    fn default() -> Self {
        Self {
            strategy: ExitStrategy::ConfidenceThreshold(0.9),
            min_layers: 6,
            max_layers: 12,
            patience_threshold: 3,
            confidence_calibration: true,
            dynamic_threshold_adjustment: true,
            performance_tracking: true,
            energy_aware: false,
            memory_aware: true,
            context_aware: true,
            fallback_to_full: true,
            exit_point_optimization: true,
        }
    }
}
/// Snapshot of the exit decision taken at one layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExitPoint {
    /// Index of the layer this snapshot describes.
    pub layer_index: usize,
    /// Confidence estimate for the layer's output.
    pub confidence_score: f32,
    /// Score derived from output entropy (or from hidden-state variance when
    /// no logits are available).
    pub entropy_score: f32,
    /// `1 / (1 + variance)` of the layer's hidden states.
    pub variance_score: f32,
    /// `1 / (1 + variance)` of recent confidence scores in the history.
    pub consistency_score: f32,
    /// Computation time copied from the layer output, in milliseconds.
    pub computation_time_ms: u64,
    /// Energy tracker reading taken when the snapshot was created.
    pub energy_consumed: f32,
    /// Memory usage copied from the layer output, in megabytes.
    pub memory_used_mb: f64,
    /// Final decision for this layer.
    pub should_exit: bool,
    /// Human-readable explanation of the decision.
    pub exit_reason: String,
}
/// Output of an [`EarlyExitPipeline`] run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EarlyExitResult {
    /// Prediction returned to the caller.
    pub prediction: PipelineOutput,
    /// Exit point at which processing stopped.
    pub exit_point: ExitPoint,
    /// Number of layers that were processed.
    pub total_layers_computed: usize,
    /// Share of layers skipped, as a percentage of `max_layers`.
    pub computation_saved_percent: f32,
    /// Estimated energy saving, as a percentage.
    pub energy_saved_percent: f32,
    /// Confidence score of the exit point.
    pub confidence_score: f32,
    /// Heuristic quality estimate for the returned prediction.
    pub quality_score: f32,
    /// Exit points recorded for every processed layer, in order.
    pub exit_path: Vec<ExitPoint>,
    /// Explanation of why processing stopped.
    pub final_decision_reason: String,
}
/// Per-layer data handed to [`EarlyExitPredictor::should_exit`].
#[derive(Debug, Clone)]
pub struct LayerOutput {
    /// Index of the layer that produced this output.
    pub layer_index: usize,
    /// Flattened hidden-state values; their variance feeds several scores.
    pub hidden_states: Vec<f32>,
    /// Optional attention weights (not read by the predictor in this file).
    pub attention_weights: Option<Vec<f32>>,
    /// Optional class logits; preferred source for confidence and entropy.
    pub logits: Option<Vec<f32>>,
    /// Optional prediction used for confidence when logits are absent.
    pub intermediate_prediction: Option<PipelineOutput>,
    /// Time spent computing the layer, in milliseconds.
    pub computation_time_ms: u64,
    /// Memory in use after the layer, in megabytes.
    pub memory_usage_mb: f64,
}
/// Stateful decision maker that scores each layer output and applies the
/// configured [`ExitStrategy`].
#[derive(Clone)]
pub struct EarlyExitPredictor {
    /// Active configuration.
    config: EarlyExitConfig,
    /// Exit points recorded by `should_exit`, oldest first (bounded in size).
    exit_history: Vec<ExitPoint>,
    /// Running statistics keyed by layer index.
    performance_stats: HashMap<usize, PerformanceStats>,
    /// Optional per-strategy threshold overrides, keyed by strategy name.
    /// Only read in this file; nothing visible here inserts into it.
    adaptive_thresholds: HashMap<String, f32>,
    /// Running energy accounting.
    energy_tracker: EnergyTracker,
    /// Running memory accounting.
    memory_tracker: MemoryTracker,
    /// Running estimates of input complexity and task difficulty.
    context_analyzer: ContextAnalyzer,
}
/// Running statistics for the decisions taken at one layer index.
#[derive(Debug, Clone)]
struct PerformanceStats {
    /// Number of decisions recorded for the layer.
    total_exits: u64,
    /// Number of those decisions that were "exit".
    successful_exits: u64,
    /// Exponential moving average of the confidence score.
    average_confidence: f32,
    /// Exponential moving average of the computation time (ms).
    average_computation_time: f64,
    /// Initialised to zero and never updated in this file.
    accuracy_loss: f32,
}
/// Simple energy accounting updated once per processed layer.
#[derive(Debug, Clone)]
struct EnergyTracker {
    /// Fixed energy cost charged for every layer.
    baseline_energy_per_layer: f32,
    /// Energy consumed so far.
    current_energy_consumption: f32,
    /// Budget left; compared against `ExitStrategy::EnergyBudget`.
    energy_budget_remaining: f32,
}
/// Memory accounting updated once per processed layer.
#[derive(Debug, Clone)]
struct MemoryTracker {
    /// Highest memory usage observed so far (MB).
    peak_memory_usage: f64,
    /// Memory usage reported by the most recent layer (MB).
    current_memory_usage: f64,
    /// Current usage relative to a fixed limit, capped at 1.0.
    memory_pressure_level: f32,
}
/// Running estimates describing how hard the current input appears to be.
#[derive(Debug, Clone)]
struct ContextAnalyzer {
    /// Smoothed hidden-state variance, clamped to `[0, 1]`.
    input_complexity_score: f32,
    /// `1 - convergence rate`, clamped to `[0, 1]`.
    task_difficulty_estimate: f32,
    /// Initialised in `new`; not read anywhere in this file.
    domain_specific_threshold: f32,
}
impl EarlyExitPredictor {
/// Creates a predictor with empty history and statistics and all trackers
/// at their initial values.
pub fn new(config: EarlyExitConfig) -> Self {
    let energy_tracker = EnergyTracker {
        baseline_energy_per_layer: 1.0,
        current_energy_consumption: 0.0,
        energy_budget_remaining: 100.0,
    };
    let memory_tracker = MemoryTracker {
        peak_memory_usage: 0.0,
        current_memory_usage: 0.0,
        memory_pressure_level: 0.0,
    };
    let context_analyzer = ContextAnalyzer {
        input_complexity_score: 0.5,
        task_difficulty_estimate: 0.5,
        domain_specific_threshold: 0.8,
    };

    Self {
        config,
        exit_history: Vec::new(),
        performance_stats: HashMap::new(),
        adaptive_thresholds: HashMap::new(),
        energy_tracker,
        memory_tracker,
        context_analyzer,
    }
}
/// Returns a mutable reference to the predictor's configuration.
pub fn config_mut(&mut self) -> &mut EarlyExitConfig {
    &mut self.config
}
/// Returns a shared reference to the predictor's configuration.
pub fn config(&self) -> &EarlyExitConfig {
    &self.config
}
/// Scores `layer_output`, updates the internal trackers and decides whether
/// inference should stop at this layer.
///
/// Decision precedence:
/// 1. `layer_index >= max_layers` always exits.
/// 2. `layer_index < min_layers` never exits.
/// 3. Otherwise the configured strategy decides.
///
/// The returned [`ExitPoint`] always carries a non-empty `exit_reason`
/// describing which rule produced the decision. The exit point is also
/// appended to the bounded history and folded into the per-layer statistics.
///
/// # Errors
///
/// Propagates any error from score computation or strategy evaluation.
pub fn should_exit(&mut self, layer_output: &LayerOutput) -> Result<ExitPoint> {
    // Upper bound on retained history entries; the oldest are dropped first.
    const MAX_HISTORY: usize = 1000;

    let mut exit_point = self.create_base_exit_point(layer_output)?;
    self.update_energy_tracking(layer_output);
    self.update_memory_tracking(layer_output);
    self.update_context_analysis(layer_output);

    let strategy_wants_exit = self.evaluate_exit_strategy(&exit_point, layer_output)?;
    let layer_index = layer_output.layer_index;
    let (min_layers, max_layers) = (self.config.min_layers, self.config.max_layers);

    // The layer bounds override the strategy; the maximum wins over the minimum.
    let (decision, reason) = if layer_index >= max_layers {
        (true, "Reached maximum layers".to_string())
    } else if layer_index < min_layers {
        (false, format!("Below minimum layers ({})", min_layers))
    } else if strategy_wants_exit {
        (true, "Exit strategy criteria met".to_string())
    } else {
        (false, "Exit strategy criteria not met".to_string())
    };
    exit_point.should_exit = decision;
    exit_point.exit_reason = reason;

    self.exit_history.push(exit_point.clone());
    if self.exit_history.len() > MAX_HISTORY {
        let overflow = self.exit_history.len() - MAX_HISTORY;
        self.exit_history.drain(..overflow);
    }

    self.update_performance_stats(&exit_point);
    Ok(exit_point)
}
/// Builds an undecided [`ExitPoint`] (`should_exit == false`, empty reason)
/// holding the four scores for `layer_output` plus its timing, memory and
/// the current energy reading.
fn create_base_exit_point(&self, layer_output: &LayerOutput) -> Result<ExitPoint> {
    Ok(ExitPoint {
        layer_index: layer_output.layer_index,
        confidence_score: self.calculate_confidence_score(layer_output)?,
        entropy_score: self.calculate_entropy_score(layer_output)?,
        variance_score: self.calculate_variance_score(layer_output)?,
        consistency_score: self.calculate_consistency_score(layer_output)?,
        computation_time_ms: layer_output.computation_time_ms,
        energy_consumed: self.energy_tracker.current_energy_consumption,
        memory_used_mb: layer_output.memory_usage_mb,
        should_exit: false,
        exit_reason: String::new(),
    })
}
/// Estimates how confident the model is at this layer.
///
/// Sources, in order of preference:
/// 1. Non-empty `logits`: the maximum softmax probability.
/// 2. `intermediate_prediction`: the best classification score, the QA
///    score, or a fixed 0.8 for other output kinds.
/// 3. Otherwise a depth proxy in `[0.5, 0.8]` that grows with the layer index.
///
/// Empty logits are treated as absent; previously they produced an infinite
/// confidence (`1 / 0`). A `max_layers` of zero no longer divides by zero.
fn calculate_confidence_score(&self, layer_output: &LayerOutput) -> Result<f32> {
    let usable_logits = layer_output.logits.as_deref().filter(|l| !l.is_empty());
    if let Some(logits) = usable_logits {
        // Subtracting the maximum keeps `exp` from overflowing; the largest
        // term becomes exp(0) = 1, so the top probability is 1 / sum.
        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
        return Ok(1.0 / exp_sum);
    }

    if let Some(prediction) = &layer_output.intermediate_prediction {
        let score = match prediction {
            PipelineOutput::Classification(results) => {
                results.iter().map(|r| r.score).fold(0.0f32, f32::max)
            }
            PipelineOutput::QuestionAnswering(result) => result.score,
            _ => 0.8,
        };
        return Ok(score);
    }

    // Guard the divisor and cap the ratio so the proxy stays within [0.5, 0.8].
    let max_layers = self.config.max_layers.max(1) as f32;
    let depth_factor = (layer_output.layer_index as f32 / max_layers).min(1.0);
    Ok(0.5 + 0.3 * depth_factor)
}
fn calculate_entropy_score(&self, layer_output: &LayerOutput) -> Result<f32> {
if let Some(ref logits) = layer_output.logits {
let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
let entropy: f32 = logits
.iter()
.map(|&x| {
let prob = (x - max_logit).exp() / exp_sum;
if prob > 0.0 {
-prob * prob.ln()
} else {
0.0
}
})
.sum();
let max_entropy = (logits.len() as f32).ln();
Ok(1.0 - entropy / max_entropy)
} else {
let variance = self.calculate_hidden_state_variance(&layer_output.hidden_states);
Ok(1.0 / (1.0 + variance)) }
}
/// Maps the hidden-state variance to `(0, 1]` via `1 / (1 + variance)`;
/// constant hidden states score exactly `1.0`.
fn calculate_variance_score(&self, layer_output: &LayerOutput) -> Result<f32> {
    let spread = self.calculate_hidden_state_variance(&layer_output.hidden_states);
    let score = 1.0 / (1.0 + spread);
    Ok(score)
}
/// Scores how stable the most recent confidence values are.
///
/// Uses up to the three newest history entries and returns
/// `1 / (1 + variance)` of their confidence scores, or a neutral `0.5` when
/// fewer than two entries exist. The layer output itself is not inspected.
fn calculate_consistency_score(&self, _layer_output: &LayerOutput) -> Result<f32> {
    let window: Vec<f32> = self
        .exit_history
        .iter()
        .rev()
        .take(3)
        .map(|entry| entry.confidence_score)
        .collect();

    // Fewer than two samples carry no stability information.
    if window.len() < 2 {
        return Ok(0.5);
    }

    let count = window.len() as f32;
    let mean = window.iter().sum::<f32>() / count;
    let spread = window.iter().map(|&c| (c - mean).powi(2)).sum::<f32>() / count;
    Ok(1.0 / (1.0 + spread))
}
/// Population variance of `hidden_states`; `0.0` for an empty slice.
fn calculate_hidden_state_variance(&self, hidden_states: &[f32]) -> f32 {
    let count = hidden_states.len();
    if count == 0 {
        return 0.0;
    }

    let n = count as f32;
    let mean = hidden_states.iter().sum::<f32>() / n;
    let squared_deviations: f32 = hidden_states
        .iter()
        .map(|&value| {
            let deviation = value - mean;
            deviation.powi(2)
        })
        .sum();
    squared_deviations / n
}
/// Applies the configured [`ExitStrategy`] to `exit_point` and reports
/// whether it votes to exit.
///
/// Threshold strategies compare the matching score against the (possibly
/// adjusted) threshold; budget strategies compare against the layer time or
/// the remaining energy; the other variants delegate to dedicated
/// evaluators.
fn evaluate_exit_strategy(
    &mut self,
    exit_point: &ExitPoint,
    layer_output: &LayerOutput,
) -> Result<bool> {
    match &self.config.strategy {
        ExitStrategy::ConfidenceThreshold(base) => {
            let required = self.get_adjusted_threshold("confidence", *base);
            Ok(exit_point.confidence_score >= required)
        }
        ExitStrategy::EntropyThreshold(base) => {
            let required = self.get_adjusted_threshold("entropy", *base);
            Ok(exit_point.entropy_score >= required)
        }
        ExitStrategy::VarianceThreshold(base) => {
            let required = self.get_adjusted_threshold("variance", *base);
            Ok(exit_point.variance_score >= required)
        }
        ExitStrategy::ConsistencyThreshold(base) => {
            let required = self.get_adjusted_threshold("consistency", *base);
            Ok(exit_point.consistency_score >= required)
        }
        ExitStrategy::ComputationalBudget(budget_ms) => {
            let spent_ms = exit_point.computation_time_ms;
            Ok(spent_ms >= *budget_ms)
        }
        ExitStrategy::EnergyBudget(budget) => {
            let remaining = self.energy_tracker.energy_budget_remaining;
            Ok(remaining <= *budget)
        }
        ExitStrategy::AdaptiveThreshold => {
            self.evaluate_adaptive_threshold(exit_point, layer_output)
        }
        ExitStrategy::Patience(max_patience) => {
            self.evaluate_patience_strategy(exit_point, *max_patience)
        }
        ExitStrategy::Combined(strategies) => {
            self.evaluate_combined_strategies(strategies, exit_point, layer_output)
        }
        ExitStrategy::LearnedExit => self.evaluate_learned_exit(exit_point, layer_output),
    }
}
/// Returns the threshold to use for `strategy_type`.
///
/// With dynamic adjustment disabled, `base_threshold` is returned unchanged.
/// Otherwise it is lowered for complex inputs and difficult tasks, raised
/// under memory pressure, averaged with any stored adaptive threshold for
/// the strategy, and finally clamped to `[0.1, 0.99]`.
fn get_adjusted_threshold(&self, strategy_type: &str, base_threshold: f32) -> f32 {
    if !self.config.dynamic_threshold_adjustment {
        return base_threshold;
    }

    // (condition, multiplier) pairs, applied in this order.
    let adjustments = [
        (self.context_analyzer.input_complexity_score > 0.7, 0.9),
        (self.context_analyzer.task_difficulty_estimate > 0.8, 0.85),
        (self.memory_tracker.memory_pressure_level > 0.8, 1.1),
    ];
    let scaled = adjustments
        .iter()
        .filter(|(applies, _)| *applies)
        .fold(base_threshold, |threshold, (_, factor)| threshold * factor);

    let blended = match self.adaptive_thresholds.get(strategy_type) {
        Some(&stored) => (scaled + stored) / 2.0,
        None => scaled,
    };
    blended.clamp(0.1, 0.99)
}
/// Votes to exit when a weighted composite score reaches the adaptive
/// threshold.
///
/// The composite is 40% confidence, 20% entropy score, 20% consistency and
/// 20% "input simplicity" (`1 - input_complexity_score`). The threshold is
/// the mean of 0.8 and the history-derived threshold.
fn evaluate_adaptive_threshold(
    &mut self,
    exit_point: &ExitPoint,
    _layer_output: &LayerOutput,
) -> Result<bool> {
    let simplicity = 1.0 - self.context_analyzer.input_complexity_score;

    let confidence_term = 0.4 * exit_point.confidence_score;
    let entropy_term = 0.2 * exit_point.entropy_score;
    let consistency_term = 0.2 * exit_point.consistency_score;
    let context_term = 0.2 * simplicity;
    let composite = confidence_term + entropy_term + consistency_term + context_term;

    let required = (0.8 + self.calculate_historical_threshold()) / 2.0;
    Ok(composite >= required)
}
/// Votes to exit when the current confidence is at least 0.95, or when the
/// newest run of history entries with confidence below 0.8 is at least
/// `max_patience` long.
fn evaluate_patience_strategy(
    &self,
    exit_point: &ExitPoint,
    max_patience: u32,
) -> Result<bool> {
    if exit_point.confidence_score >= 0.95 {
        return Ok(true);
    }

    // Length of the trailing streak of low-confidence entries, newest first.
    let low_confidence_streak = self
        .exit_history
        .iter()
        .rev()
        .take_while(|previous| previous.confidence_score < 0.8)
        .count();
    Ok(low_confidence_streak >= max_patience as usize)
}
/// Evaluates every strategy in `strategies` and votes to exit when a strict
/// majority of them do. An empty list never exits.
///
/// Each sub-strategy is evaluated against a copy of *this* predictor's
/// state (history, energy, memory and context trackers, adaptive
/// thresholds). Evaluating against a freshly constructed predictor, as was
/// done before, made state-dependent strategies such as `EnergyBudget` and
/// `Patience` always observe pristine trackers and an empty history.
///
/// # Errors
///
/// Propagates the first error returned by a sub-strategy evaluation.
fn evaluate_combined_strategies(
    &self,
    strategies: &[ExitStrategy],
    exit_point: &ExitPoint,
    layer_output: &LayerOutput,
) -> Result<bool> {
    if strategies.is_empty() {
        return Ok(false);
    }

    // One scratch copy is enough: only its strategy is swapped per vote.
    let mut evaluator = self.clone();
    let mut exit_votes = 0usize;
    for strategy in strategies {
        evaluator.config.strategy = strategy.clone();
        if evaluator.evaluate_exit_strategy(exit_point, layer_output)? {
            exit_votes += 1;
        }
    }
    Ok(exit_votes > strategies.len() / 2)
}
/// Votes to exit when a fixed-weight linear combination of six features
/// reaches 0.7.
///
/// Features and weights: confidence (0.3), entropy score (0.2), consistency
/// (0.2), relative depth `layer_index / max_layers` (0.1), input complexity
/// (0.1) and memory pressure (0.1). The weights are hard-coded constants.
fn evaluate_learned_exit(
    &self,
    exit_point: &ExitPoint,
    _layer_output: &LayerOutput,
) -> Result<bool> {
    let relative_depth = exit_point.layer_index as f32 / self.config.max_layers as f32;
    let weighted_features: [(f32, f32); 6] = [
        (0.3, exit_point.confidence_score),
        (0.2, exit_point.entropy_score),
        (0.2, exit_point.consistency_score),
        (0.1, relative_depth),
        (0.1, self.context_analyzer.input_complexity_score),
        (0.1, self.memory_tracker.memory_pressure_level),
    ];

    let score: f32 = weighted_features.iter().map(|(weight, feature)| feature * weight).sum();
    Ok(score >= 0.7)
}
fn calculate_historical_threshold(&self) -> f32 {
if self.exit_history.is_empty() {
return 0.8;
}
let successful_exits: Vec<&ExitPoint> =
self.exit_history.iter().filter(|ep| ep.should_exit).collect();
if successful_exits.is_empty() {
return 0.8;
}
let avg_confidence = successful_exits.iter().map(|ep| ep.confidence_score).sum::<f32>()
/ successful_exits.len() as f32;
avg_confidence * 0.9 }
/// Charges one layer's energy to the tracker.
///
/// Consumption grows by the per-layer baseline plus a size-dependent term
/// (`hidden_states.len() / 1000`); the remaining budget shrinks by the
/// baseline only.
// NOTE(review): the size-dependent term is not deducted from the remaining
// budget — confirm whether that asymmetry is intended.
fn update_energy_tracking(&mut self, layer_output: &LayerOutput) {
    let tracker = &mut self.energy_tracker;
    let baseline = tracker.baseline_energy_per_layer;
    let size_cost = layer_output.hidden_states.len() as f32 / 1000.0;

    tracker.current_energy_consumption += baseline;
    tracker.current_energy_consumption += size_cost;
    tracker.energy_budget_remaining -= baseline;
}
/// Records the layer's memory usage, raises the peak if exceeded, and
/// recomputes memory pressure as usage relative to a fixed 2048 MB limit,
/// capped at 1.0.
fn update_memory_tracking(&mut self, layer_output: &LayerOutput) {
    const MEMORY_LIMIT_MB: f64 = 2048.0;

    let usage = layer_output.memory_usage_mb;
    let tracker = &mut self.memory_tracker;

    tracker.current_memory_usage = usage;
    if usage > tracker.peak_memory_usage {
        tracker.peak_memory_usage = usage;
    }

    let pressure = (tracker.current_memory_usage / MEMORY_LIMIT_MB).min(1.0);
    tracker.memory_pressure_level = pressure as f32;
}
/// Refreshes the context estimates from the latest layer.
///
/// Input complexity is an exponential moving average (weight 0.1) of the
/// hidden-state variance, clamped to `[0, 1]`. For every layer after the
/// first, task difficulty becomes `1 - convergence rate`, clamped likewise.
fn update_context_analysis(&mut self, layer_output: &LayerOutput) {
    let spread = self.calculate_hidden_state_variance(&layer_output.hidden_states);
    let blended = self.context_analyzer.input_complexity_score * 0.9 + spread * 0.1;
    self.context_analyzer.input_complexity_score = blended.clamp(0.0, 1.0);

    if layer_output.layer_index == 0 {
        return;
    }

    let difficulty = 1.0 - self.calculate_convergence_rate();
    self.context_analyzer.task_difficulty_estimate = difficulty.clamp(0.0, 1.0);
}
/// Maps the confidence change across the last three history entries to a
/// rate: `(newest - third_newest + 1) / 2`. Returns a neutral 0.5 when fewer
/// than three entries exist.
fn calculate_convergence_rate(&self) -> f32 {
    let history = &self.exit_history;
    let len = history.len();
    if len < 3 {
        return 0.5;
    }

    let newest = history[len - 1].confidence_score;
    let third_newest = history[len - 3].confidence_score;
    let improvement = newest - third_newest;
    (improvement + 1.0) / 2.0
}
/// Folds `exit_point` into the statistics of its layer, creating the entry
/// on first use.
///
/// Counters are incremented and the confidence / computation-time averages
/// are updated as exponential moving averages with a smoothing factor of 0.1.
fn update_performance_stats(&mut self, exit_point: &ExitPoint) {
    let alpha = 0.1f32;
    let alpha_wide = alpha as f64;

    let entry = self
        .performance_stats
        .entry(exit_point.layer_index)
        .or_insert_with(|| PerformanceStats {
            total_exits: 0,
            successful_exits: 0,
            average_confidence: 0.0,
            average_computation_time: 0.0,
            accuracy_loss: 0.0,
        });

    entry.total_exits += 1;
    entry.successful_exits += u64::from(exit_point.should_exit);

    entry.average_confidence =
        entry.average_confidence * (1.0 - alpha) + exit_point.confidence_score * alpha;
    entry.average_computation_time = entry.average_computation_time * (1.0 - alpha_wide)
        + exit_point.computation_time_ms as f64 * alpha_wide;
}
/// Returns the per-layer running statistics, keyed by layer index.
pub fn get_performance_stats(&self) -> &HashMap<usize, PerformanceStats> {
    &self.performance_stats
}
/// Restores the per-run state to its initial values.
///
/// Clears the exit history and resets the energy, memory and context
/// trackers to the values used by `new`. The memory pressure level is reset
/// together with the memory usage it is derived from; leaving it stale would
/// keep skewing adjusted thresholds after a reset. Per-layer performance
/// statistics and adaptive thresholds are left untouched.
pub fn reset(&mut self) {
    self.exit_history.clear();

    self.energy_tracker.current_energy_consumption = 0.0;
    self.energy_tracker.energy_budget_remaining = 100.0;

    self.memory_tracker.current_memory_usage = 0.0;
    self.memory_tracker.peak_memory_usage = 0.0;
    self.memory_tracker.memory_pressure_level = 0.0;

    self.context_analyzer.input_complexity_score = 0.5;
    self.context_analyzer.task_difficulty_estimate = 0.5;
}
}
/// Wraps a base pipeline together with an [`EarlyExitPredictor`].
#[derive(Clone)]
pub struct EarlyExitPipeline<P> {
    /// The wrapped pipeline; stored but not invoked by the code in this file.
    base_pipeline: P,
    /// Predictor consulted after every simulated layer.
    exit_predictor: EarlyExitPredictor,
}
impl<P> EarlyExitPipeline<P>
where
P: Pipeline,
{
/// Wraps `base_pipeline` with a fresh predictor built from `config`.
pub fn new(base_pipeline: P, config: EarlyExitConfig) -> Self {
    let exit_predictor = EarlyExitPredictor::new(config);
    Self {
        base_pipeline,
        exit_predictor,
    }
}
/// Returns a mutable reference to the pipeline's predictor.
pub fn exit_predictor_mut(&mut self) -> &mut EarlyExitPredictor {
    &mut self.exit_predictor
}
/// Returns a shared reference to the pipeline's predictor.
pub fn exit_predictor(&self) -> &EarlyExitPredictor {
    &self.exit_predictor
}
/// Simulates a layer-by-layer forward pass and stops at the first layer the
/// predictor accepts.
///
/// For every layer a synthetic [`LayerOutput`] is produced (logits and an
/// intermediate prediction only from `min_layers` onwards) and passed to a
/// working copy of the predictor. That copy is created once per call so the
/// exit history and energy/memory/context trackers accumulate across
/// layers; cloning it inside the loop, as was done before, reset that state
/// at every layer. The pipeline's own predictor is left unmodified.
///
/// The input is not inspected: the computation is simulated.
///
/// # Errors
///
/// Propagates errors from the simulated prediction or from the predictor.
fn simulate_layer_by_layer_processing(&self, _input: &P::Input) -> Result<EarlyExitResult> {
    let start_time = Instant::now();
    let min_layers = self.exit_predictor.config.min_layers;
    let max_layers = self.exit_predictor.config.max_layers;

    let mut predictor = self.exit_predictor.clone();
    let mut exit_path = Vec::new();

    for current_layer in 0..max_layers {
        let layer_start = Instant::now();
        let hidden_states = self.simulate_layer_computation(current_layer);

        // Prediction heads are only simulated once the minimum depth is reached.
        let heads_available = current_layer >= min_layers;
        let logits = heads_available.then(|| self.simulate_logits(&hidden_states));
        let intermediate_prediction = if heads_available {
            Some(self.simulate_intermediate_prediction(&hidden_states)?)
        } else {
            None
        };

        let layer_output = LayerOutput {
            layer_index: current_layer,
            hidden_states,
            attention_weights: Some(vec![0.5; 64]),
            logits,
            intermediate_prediction,
            computation_time_ms: layer_start.elapsed().as_millis() as u64,
            memory_usage_mb: 100.0 + current_layer as f64 * 10.0,
        };

        let exit_point = predictor.should_exit(&layer_output)?;
        exit_path.push(exit_point.clone());
        if !exit_point.should_exit {
            continue;
        }

        let layers_skipped = max_layers - current_layer - 1;
        let computation_saved = (layers_skipped as f32 / max_layers as f32) * 100.0;
        let energy_saved = computation_saved * 0.8;
        let confidence_score = exit_point.confidence_score;
        let final_decision_reason = exit_point.exit_reason.clone();
        let quality_score = self.estimate_quality_score(&exit_point);
        let prediction = layer_output.intermediate_prediction.unwrap_or_else(|| {
            PipelineOutput::Summarization("Fallback prediction due to early exit".to_string())
        });

        return Ok(EarlyExitResult {
            prediction,
            exit_point,
            total_layers_computed: current_layer + 1,
            computation_saved_percent: computation_saved,
            energy_saved_percent: energy_saved,
            confidence_score,
            quality_score,
            exit_path,
            final_decision_reason,
        });
    }

    // No layer triggered an exit: report a full-depth run.
    let total_time = start_time.elapsed().as_millis() as u64;
    Ok(EarlyExitResult {
        prediction: PipelineOutput::Summarization(
            "Full pipeline prediction completed".to_string(),
        ),
        exit_point: ExitPoint {
            // saturating_sub avoids an underflow when max_layers is 0.
            layer_index: max_layers.saturating_sub(1),
            confidence_score: 1.0,
            entropy_score: 1.0,
            variance_score: 1.0,
            consistency_score: 1.0,
            computation_time_ms: total_time,
            energy_consumed: 100.0,
            memory_used_mb: 100.0 + max_layers as f64 * 10.0,
            should_exit: true,
            exit_reason: "Completed all layers".to_string(),
        },
        total_layers_computed: max_layers,
        computation_saved_percent: 0.0,
        energy_saved_percent: 0.0,
        confidence_score: 1.0,
        quality_score: 1.0,
        exit_path,
        final_decision_reason: "Full computation completed".to_string(),
    })
}
/// Produces 768 deterministic pseudo hidden-state values for `layer_index`
/// from a sine of the layer index and element position, scaled to
/// `[-0.5, 0.5]`.
fn simulate_layer_computation(&self, layer_index: usize) -> Vec<f32> {
    const HIDDEN_SIZE: usize = 768;

    let layer_phase = layer_index as f32 * 0.1;
    (0..HIDDEN_SIZE)
        .map(|position| (layer_phase + position as f32 * 0.001).sin() * 0.5)
        .collect()
}
/// Derives ten synthetic class logits from the leading hidden-state values
/// (wrapping around if fewer than ten exist), plus a small per-class offset.
///
/// # Panics
///
/// Panics if `hidden_states` is empty (modulo by zero).
fn simulate_logits(&self, hidden_states: &[f32]) -> Vec<f32> {
    const NUM_CLASSES: usize = 10;

    (0..NUM_CLASSES)
        .map(|class| {
            let source = hidden_states[class % hidden_states.len()];
            source * 2.0 + (class as f32 * 0.1)
        })
        .collect()
}
/// Builds a synthetic three-class classification result from the leading
/// hidden-state values, sorted by descending score.
///
/// Each score is the absolute hidden-state value capped at 1.0.
///
/// # Panics
///
/// Panics if `hidden_states` is empty (modulo by zero).
fn simulate_intermediate_prediction(&self, hidden_states: &[f32]) -> Result<PipelineOutput> {
    const NUM_CLASSES: usize = 3;

    let mut ranked: Vec<_> = (0..NUM_CLASSES)
        .map(|class| crate::pipeline::ClassificationOutput {
            label: format!("Class_{}", class),
            score: hidden_states[class % hidden_states.len()].abs().min(1.0),
        })
        .collect();

    // Highest score first; incomparable (NaN) pairs are treated as equal.
    ranked.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
    Ok(PipelineOutput::Classification(ranked))
}
/// Heuristic quality estimate for an exit: 30% relative depth, 50%
/// confidence and 20% consistency, capped at 1.0.
fn estimate_quality_score(&self, exit_point: &ExitPoint) -> f32 {
    let max_layers = self.exit_predictor.config.max_layers as f32;
    let relative_depth = exit_point.layer_index as f32 / max_layers;

    let weighted = relative_depth * 0.3
        + exit_point.confidence_score * 0.5
        + exit_point.consistency_score * 0.2;
    weighted.min(1.0)
}
}
impl<P> Pipeline for EarlyExitPipeline<P>
where
    P: Pipeline,
    P::Input: Clone,
{
    type Input = P::Input;
    type Output = EarlyExitResult;

    /// Runs the simulated early-exit pass for a single input.
    fn __call__(&self, input: Self::Input) -> Result<Self::Output> {
        self.simulate_layer_by_layer_processing(&input)
    }

    /// Runs `__call__` on every input in order, stopping at the first error.
    fn batch(&self, inputs: Vec<Self::Input>) -> Result<Vec<Self::Output>> {
        let mut outputs = Vec::with_capacity(inputs.len());
        for input in inputs {
            outputs.push(self.__call__(input)?);
        }
        Ok(outputs)
    }
}
/// Wraps `base_pipeline` in an [`EarlyExitPipeline`] using `config`.
pub fn create_early_exit_pipeline<P>(
    base_pipeline: P,
    config: EarlyExitConfig,
) -> EarlyExitPipeline<P>
where
    P: Pipeline,
{
    EarlyExitPipeline::new(base_pipeline, config)
}
pub fn create_confidence_based_early_exit<P>(
base_pipeline: P,
confidence_threshold: f32,
) -> EarlyExitPipeline<P>
where
P: Pipeline,
{
let mut config = EarlyExitConfig::default();
config.strategy = ExitStrategy::ConfidenceThreshold(confidence_threshold);
EarlyExitPipeline::new(base_pipeline, config)
}
pub fn create_adaptive_early_exit<P>(base_pipeline: P) -> EarlyExitPipeline<P>
where
P: Pipeline,
{
let mut config = EarlyExitConfig::default();
config.strategy = ExitStrategy::AdaptiveThreshold;
config.dynamic_threshold_adjustment = true;
config.context_aware = true;
config.performance_tracking = true;
EarlyExitPipeline::new(base_pipeline, config)
}
pub fn create_budget_constrained_early_exit<P>(
base_pipeline: P,
computation_budget_ms: u64,
energy_budget: f32,
) -> EarlyExitPipeline<P>
where
P: Pipeline,
{
let mut config = EarlyExitConfig::default();
config.strategy = ExitStrategy::Combined(vec![
ExitStrategy::ComputationalBudget(computation_budget_ms),
ExitStrategy::EnergyBudget(energy_budget),
ExitStrategy::ConfidenceThreshold(0.8),
]);
config.energy_aware = true;
config.memory_aware = true;
EarlyExitPipeline::new(base_pipeline, config)
}
// Unit tests for the early-exit configuration, predictor scoring and
// strategies, the simulated pipeline and the factory helpers.
#[cfg(test)]
mod tests {
use super::*;
// Builds a `LayerOutput` for `layer_index` with the given logits and hidden
// states, no attention weights or intermediate prediction, and computation
// time / memory usage that grow with the layer index.
fn make_layer_output(
layer_index: usize,
logits: Option<Vec<f32>>,
hidden_states: Vec<f32>,
) -> LayerOutput {
LayerOutput {
layer_index,
hidden_states,
attention_weights: None,
logits,
intermediate_prediction: None,
computation_time_ms: 10 + layer_index as u64,
memory_usage_mb: 100.0 + layer_index as f64 * 10.0,
}
}
// The default configuration uses a 6..12 layer window and a confidence
// threshold strategy.
#[test]
fn test_early_exit_config_default() {
let config = EarlyExitConfig::default();
assert_eq!(config.min_layers, 6);
assert_eq!(config.max_layers, 12);
assert!(matches!(
config.strategy,
ExitStrategy::ConfidenceThreshold(_)
));
}
// The default layer window is non-empty.
#[test]
fn test_config_min_layers_less_than_max() {
let config = EarlyExitConfig::default();
assert!(config.min_layers < config.max_layers);
}
// The default patience threshold is strictly positive.
#[test]
fn test_config_patience_threshold_positive() {
let config = EarlyExitConfig::default();
assert!(config.patience_threshold > 0);
}
// A 0.9 confidence threshold lies strictly between 0 and 1.
#[test]
fn test_confidence_threshold_strategy_validates_range() {
if let ExitStrategy::ConfidenceThreshold(threshold) = ExitStrategy::ConfidenceThreshold(0.9)
{
assert!(threshold > 0.0 && threshold < 1.0);
}
}
// A base exit point keeps the layer index and has a positive confidence
// when logits are supplied.
#[test]
fn test_exit_point_creation() {
let layer_output = make_layer_output(5, Some(vec![1.0, 2.0, 0.5]), vec![0.1, 0.2, 0.3]);
let config = EarlyExitConfig::default();
let predictor = EarlyExitPredictor::new(config);
let exit_point = predictor
.create_base_exit_point(&layer_output)
.expect("create_base_exit_point should succeed");
assert_eq!(exit_point.layer_index, 5);
assert!(exit_point.confidence_score > 0.0);
}
// Without logits or an intermediate prediction, confidence comes from the
// depth proxy and stays within (0, 1].
#[test]
fn test_exit_point_no_logits_uses_depth_proxy() {
let layer_output = make_layer_output(8, None, vec![0.1; 50]);
let config = EarlyExitConfig::default();
let predictor = EarlyExitPredictor::new(config);
let ep = predictor
.create_base_exit_point(&layer_output)
.expect("create_base_exit_point should succeed");
assert!(
ep.confidence_score > 0.0 && ep.confidence_score <= 1.0,
"depth-based confidence should be in (0,1]"
);
}
// Strongly peaked logits past `min_layers` trigger an exit under a 0.9
// confidence threshold.
#[test]
fn test_confidence_threshold_strategy() {
let config = EarlyExitConfig {
strategy: ExitStrategy::ConfidenceThreshold(0.9),
min_layers: 2,
..Default::default()
};
let mut predictor = EarlyExitPredictor::new(config);
let output = make_layer_output(3, Some(vec![5.0, 0.5]), vec![0.1, 0.2, 0.3]);
let ep = predictor.should_exit(&output).expect("should_exit should succeed");
assert!(
ep.should_exit,
"high-confidence output should trigger early exit"
);
}
// Even with a zero threshold, a layer below `min_layers` must not exit.
#[test]
fn test_no_exit_before_min_layers() {
let config = EarlyExitConfig {
strategy: ExitStrategy::ConfidenceThreshold(0.0), min_layers: 6,
max_layers: 12,
..Default::default()
};
let mut predictor = EarlyExitPredictor::new(config);
let output = make_layer_output(2, Some(vec![10.0, 0.1]), vec![0.1; 10]);
let ep = predictor.should_exit(&output).expect("should_exit should succeed");
assert!(!ep.should_exit, "must not exit before min_layers");
}
// A layer index equal to `max_layers` forces an exit even though uniform
// logits cannot reach a confidence threshold of 1.0.
#[test]
fn test_forced_exit_at_max_layers() {
let config = EarlyExitConfig {
strategy: ExitStrategy::ConfidenceThreshold(1.0), min_layers: 2,
max_layers: 6,
dynamic_threshold_adjustment: false,
..Default::default()
};
let mut predictor = EarlyExitPredictor::new(config.clone());
let output = make_layer_output(config.max_layers, Some(vec![0.1; 10]), vec![0.0; 10]);
let ep = predictor.should_exit(&output).expect("should_exit should succeed");
assert!(ep.should_exit, "must exit when max_layers reached");
}
// The depth-proxy confidence is larger for a deeper layer.
#[test]
fn test_confidence_increases_with_layer_depth_no_logits() {
let config = EarlyExitConfig {
max_layers: 20,
min_layers: 0,
..Default::default()
};
let predictor = EarlyExitPredictor::new(config);
let ep_early = predictor
.create_base_exit_point(&make_layer_output(0, None, vec![0.0; 5]))
.expect("early exit point should succeed");
let ep_late = predictor
.create_base_exit_point(&make_layer_output(18, None, vec![0.0; 5]))
.expect("late exit point should succeed");
assert!(
ep_late.confidence_score > ep_early.confidence_score,
"confidence should grow with layer depth"
);
}
// Uniform logits have maximal entropy, so the entropy score stays below
// the 0.5 threshold and no exit happens.
#[test]
fn test_entropy_threshold_strategy_uniform_distribution() {
let config = EarlyExitConfig {
strategy: ExitStrategy::EntropyThreshold(0.5),
min_layers: 0,
max_layers: 10,
dynamic_threshold_adjustment: false,
..Default::default()
};
let mut predictor = EarlyExitPredictor::new(config);
let output = make_layer_output(5, Some(vec![1.0_f32; 10]), vec![0.1; 5]);
let ep = predictor.should_exit(&output).expect("should_exit should succeed");
assert!(
!ep.should_exit,
"uniform distribution should NOT meet entropy threshold"
);
}
// Constant hidden states have zero variance, giving a variance score of 1.
#[test]
fn test_variance_score_constant_hidden_states() {
let config = EarlyExitConfig::default();
let predictor = EarlyExitPredictor::new(config);
let layer_output = make_layer_output(8, None, vec![0.5_f32; 20]);
let ep = predictor
.create_base_exit_point(&layer_output)
.expect("create_base_exit_point should succeed");
assert!(
(ep.variance_score - 1.0).abs() < 1e-5,
"constant hidden states should yield variance_score = 1.0"
);
}
// No layer below `min_layers` exits, even with a near-zero threshold and
// very confident logits.
#[test]
fn test_minimum_layers_enforcement_respected() {
let config = EarlyExitConfig {
strategy: ExitStrategy::ConfidenceThreshold(0.001), min_layers: 5,
max_layers: 12,
dynamic_threshold_adjustment: false,
..Default::default()
};
let mut predictor = EarlyExitPredictor::new(config);
for layer_idx in 0..5 {
let output = make_layer_output(layer_idx, Some(vec![10.0, 0.01]), vec![0.1; 5]);
let ep = predictor.should_exit(&output).expect("should_exit should succeed");
assert!(
!ep.should_exit,
"layer {} < min_layers=5 should not exit",
layer_idx
);
}
}
// The helper reports 10 ms for layer 0, which exceeds the 5 ms budget.
#[test]
fn test_computational_budget_strategy() {
let config = EarlyExitConfig {
strategy: ExitStrategy::ComputationalBudget(5), min_layers: 0,
max_layers: 12,
..Default::default()
};
let mut predictor = EarlyExitPredictor::new(config);
let output = make_layer_output(0, None, vec![0.1; 5]);
let ep = predictor.should_exit(&output).expect("should_exit should succeed");
assert!(
ep.should_exit,
"should exit when computation_time exceeds budget"
);
}
// Four low-confidence layers are recorded first (their results are
// discarded); the fifth call must then exit because the patience of 3 has
// been exceeded.
#[test]
fn test_patience_strategy_triggers_after_patience_exceeded() {
let config = EarlyExitConfig {
strategy: ExitStrategy::Patience(3),
min_layers: 0,
max_layers: 20,
dynamic_threshold_adjustment: false,
..Default::default()
};
let mut predictor = EarlyExitPredictor::new(config);
for i in 0..4 {
let output = make_layer_output(i, Some(vec![1.0_f32; 5]), vec![0.1; 5]);
let _ = predictor.should_exit(&output);
}
let output = make_layer_output(4, Some(vec![1.0_f32; 5]), vec![0.1; 5]);
let ep = predictor.should_exit(&output).expect("should_exit should succeed");
assert!(ep.should_exit, "patience counter should trigger exit");
}
// End-to-end run through the pipeline with a stub base pipeline: the saved
// computation is non-negative and the layer count is consistent with the
// exit point's layer index.
#[test]
fn test_depth_reduction_computation_saved_percent() {
let config = EarlyExitConfig {
strategy: ExitStrategy::ConfidenceThreshold(0.0), min_layers: 3,
max_layers: 12,
dynamic_threshold_adjustment: false,
..Default::default()
};
struct DummyPipeline;
impl crate::pipeline::Pipeline for DummyPipeline {
type Input = String;
type Output = crate::pipeline::PipelineOutput;
fn __call__(&self, _: Self::Input) -> Result<Self::Output> {
Ok(crate::pipeline::PipelineOutput::Text("x".to_string()))
}
}
let pipeline = create_early_exit_pipeline(DummyPipeline, config);
let result = pipeline
.__call__("test".to_string())
.expect("early exit pipeline should succeed");
assert!(result.computation_saved_percent >= 0.0);
assert!(result.total_layers_computed <= result.exit_point.layer_index + 1);
}
// The confidence factory stores the requested threshold in the strategy.
#[test]
fn test_confidence_based_factory() {
struct Dummy;
impl crate::pipeline::Pipeline for Dummy {
type Input = String;
type Output = crate::pipeline::PipelineOutput;
fn __call__(&self, _: Self::Input) -> Result<Self::Output> {
Ok(crate::pipeline::PipelineOutput::Text("x".to_string()))
}
}
let pipeline = create_confidence_based_early_exit(Dummy, 0.7_f32);
let config = pipeline.exit_predictor().config();
assert!(
matches!(config.strategy, ExitStrategy::ConfidenceThreshold(t) if (t - 0.7).abs() < 1e-5)
);
}
// The adaptive factory selects `AdaptiveThreshold` and enables dynamic
// threshold adjustment.
#[test]
fn test_adaptive_factory() {
struct Dummy;
impl crate::pipeline::Pipeline for Dummy {
type Input = String;
type Output = crate::pipeline::PipelineOutput;
fn __call__(&self, _: Self::Input) -> Result<Self::Output> {
Ok(crate::pipeline::PipelineOutput::Text("x".to_string()))
}
}
let pipeline = create_adaptive_early_exit(Dummy);
let config = pipeline.exit_predictor().config();
assert!(matches!(config.strategy, ExitStrategy::AdaptiveThreshold));
assert!(config.dynamic_threshold_adjustment);
}
// The budget factory selects a `Combined` strategy and sets `energy_aware`.
#[test]
fn test_budget_constrained_factory() {
struct Dummy;
impl crate::pipeline::Pipeline for Dummy {
type Input = String;
type Output = crate::pipeline::PipelineOutput;
fn __call__(&self, _: Self::Input) -> Result<Self::Output> {
Ok(crate::pipeline::PipelineOutput::Text("x".to_string()))
}
}
let pipeline = create_budget_constrained_early_exit(Dummy, 100, 50.0);
let config = pipeline.exit_predictor().config();
assert!(matches!(config.strategy, ExitStrategy::Combined(_)));
assert!(config.energy_aware);
}
// After `reset`, building an exit point still succeeds and yields a
// positive consistency score.
#[test]
fn test_predictor_reset_clears_history() {
let config = EarlyExitConfig {
min_layers: 0,
max_layers: 10,
..Default::default()
};
let mut predictor = EarlyExitPredictor::new(config);
let output = make_layer_output(5, Some(vec![1.0, 2.0]), vec![0.1; 5]);
let _ = predictor.should_exit(&output);
predictor.reset();
let ep2 = predictor.create_base_exit_point(&output).expect("create_base_exit_point ok");
assert!(ep2.consistency_score > 0.0);
}
}