use crate::{
AdaptiveError, Precision, PrecisionCapabilities, PrecisionProfile, ProfilePreset, Result,
TransitionStrategy,
};
use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
/// Input configuration consumed by [`PrecisionSchedule::generate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScheduleConfig {
/// Number of steps in the schedule; `generate` rejects `0`.
pub total_steps: u32,
/// Preset used to build the precision profile when `custom_profile` is `None`.
pub preset: ProfilePreset,
/// Explicit profile; when present it is used instead of `preset`.
pub custom_profile: Option<PrecisionProfile>,
/// Capability set queried (`supports` / `best_supported`) when
/// `strict_capabilities` is enabled.
pub capabilities: PrecisionCapabilities,
/// Lower bound on the schedule's average quality factor; `generate` returns
/// `AdaptiveError::QualityConstraint` when the average falls below it.
pub min_quality: f32,
/// Determines the `blend_factor` recorded on steps where the precision changes.
pub transition_strategy: TransitionStrategy,
/// When `true`, a precision that `capabilities` does not support is replaced
/// with `capabilities.best_supported(..)`; when `false`, the profile's choice
/// is kept as-is.
pub strict_capabilities: bool,
}
impl Default for ScheduleConfig {
fn default() -> Self {
Self {
total_steps: 30,
preset: ProfilePreset::Balanced,
custom_profile: None,
capabilities: PrecisionCapabilities::default(),
min_quality: 0.90,
transition_strategy: TransitionStrategy::Immediate,
strict_capabilities: true,
}
}
}
impl ScheduleConfig {
    /// Creates a configuration that uses `preset` and the default value for
    /// every other field.
    pub fn with_preset(preset: ProfilePreset) -> Self {
        let defaults = Self::default();
        Self { preset, ..defaults }
    }

    /// Builder-style setter: returns the configuration with `total_steps`
    /// replaced by `steps`.
    pub fn steps(self, steps: u32) -> Self {
        Self {
            total_steps: steps,
            ..self
        }
    }

    /// Builder-style setter: returns the configuration with `capabilities`
    /// replaced by `caps`.
    pub fn capabilities(self, caps: PrecisionCapabilities) -> Self {
        Self {
            capabilities: caps,
            ..self
        }
    }

    /// Builder-style setter: returns the configuration with `min_quality`
    /// replaced by `quality`.
    pub fn min_quality(self, quality: f32) -> Self {
        Self {
            min_quality: quality,
            ..self
        }
    }
}
/// Precision decision recorded for a single step of a [`PrecisionSchedule`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepPrecision {
/// Zero-based step index (`generate` iterates `0..total_steps`).
pub step: u32,
/// Precision chosen for this step, after any capability adjustment.
pub precision: Precision,
/// `true` when this step's precision differs from the previous step's;
/// always `false` for the first step.
pub is_transition: bool,
/// Set only on transition steps: `None` for `Immediate`, `1 / steps` for
/// `Gradual { steps }`, and `0.5` for `StepAware`.
pub blend_factor: Option<f32>,
/// Value of `precision.vram_ratio()` for this step.
pub vram_ratio: f32,
/// Value of `precision.quality_factor()` for this step.
pub quality_factor: f32,
}
/// A generated per-step precision plan together with summary statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrecisionSchedule {
/// The configuration the schedule was generated from.
pub config: ScheduleConfig,
/// Name of the profile (`profile.name`) that was sampled.
pub profile_name: String,
/// One entry per step, in step order.
pub steps: Vec<StepPrecision>,
/// `(step, from, to)` for every step whose precision differs from the
/// previous step's.
pub transitions: Vec<(u32, Precision, Precision)>,
/// Arithmetic mean of `vram_ratio` over all steps.
pub avg_vram_ratio: f32,
/// Arithmetic mean of `quality_factor` over all steps.
pub avg_quality_factor: f32,
/// Arithmetic mean of `precision.speedup_factor()` over all steps.
pub estimated_speedup: f32,
}
impl PrecisionSchedule {
/// Builds a per-step precision schedule from `config`.
///
/// The profile is `config.custom_profile` when present, otherwise the one
/// built from `config.preset`. For every step `i` in `0..total_steps` the
/// profile is sampled at the fraction `i / total_steps`. When
/// `strict_capabilities` is set, a precision that `config.capabilities` does
/// not support is replaced by `capabilities.best_supported(..)`.
///
/// A step is a transition when its precision differs from the previous
/// step's; such steps are listed in `transitions` and carry the blend factor
/// dictated by `config.transition_strategy`.
///
/// # Errors
///
/// - [`AdaptiveError::ScheduleError`] if `total_steps` is `0`, or if the
///   transition strategy is `Gradual` with zero blend steps (the blend factor
///   `1 / steps` would otherwise be infinite).
/// - Whatever error the profile's `validate()` reports.
/// - [`AdaptiveError::QualityConstraint`] if the average quality factor of
///   the schedule is below `config.min_quality`.
pub fn generate(config: ScheduleConfig) -> Result<Self> {
    if config.total_steps == 0 {
        return Err(AdaptiveError::ScheduleError(
            "Total steps must be > 0".into(),
        ));
    }
    let profile = config
        .custom_profile
        .clone()
        .unwrap_or_else(|| config.preset.build());
    profile.validate()?;

    // The blend factor depends only on the strategy, so resolve it once
    // instead of on every transition step. A gradual blend over zero steps
    // would divide by zero and record an infinite factor, so it is rejected.
    let transition_blend = match config.transition_strategy {
        TransitionStrategy::Immediate => None,
        TransitionStrategy::Gradual { steps: blend_steps } => {
            let span = blend_steps as f32;
            if span <= 0.0 {
                return Err(AdaptiveError::ScheduleError(
                    "Gradual transition steps must be > 0".into(),
                ));
            }
            Some(1.0 / span)
        }
        TransitionStrategy::StepAware => Some(0.5),
    };

    let total = config.total_steps;
    let mut steps = Vec::with_capacity(total as usize);
    let mut transitions = Vec::new();
    let mut prev_precision: Option<Precision> = None;
    for step in 0..total {
        let fraction = step as f32 / total as f32;
        let mut precision = profile.precision_at(fraction);
        if config.strict_capabilities && !config.capabilities.supports(precision) {
            precision = config.capabilities.best_supported(precision);
        }

        // `Some(prev)` only when there is a previous step and its precision
        // differs from this one; the first step is never a transition.
        let changed_from = prev_precision.filter(|prev| *prev != precision);
        if let Some(prev) = changed_from {
            transitions.push((step, prev, precision));
        }
        let is_transition = changed_from.is_some();

        steps.push(StepPrecision {
            step,
            precision,
            is_transition,
            blend_factor: if is_transition { transition_blend } else { None },
            vram_ratio: precision.vram_ratio(),
            quality_factor: precision.quality_factor(),
        });
        prev_precision = Some(precision);
    }

    // `total_steps > 0` was checked above, so `count` is never zero here.
    let count = steps.len() as f32;
    let avg_vram_ratio = steps.iter().map(|s| s.vram_ratio).sum::<f32>() / count;
    let avg_quality_factor = steps.iter().map(|s| s.quality_factor).sum::<f32>() / count;
    let estimated_speedup = steps
        .iter()
        .map(|s| s.precision.speedup_factor())
        .sum::<f32>()
        / count;

    if avg_quality_factor < config.min_quality {
        return Err(AdaptiveError::QualityConstraint {
            actual: avg_quality_factor,
            threshold: config.min_quality,
        });
    }
    Ok(Self {
        config,
        profile_name: profile.name,
        steps,
        transitions,
        avg_vram_ratio,
        avg_quality_factor,
        estimated_speedup,
    })
}
/// Returns the precision scheduled for `step`.
///
/// # Errors
///
/// [`AdaptiveError::InvalidStep`] when the schedule has no entry at index
/// `step`.
pub fn precision_at(&self, step: u32) -> Result<Precision> {
    match self.steps.get(step as usize) {
        Some(info) => Ok(info.precision),
        None => Err(AdaptiveError::InvalidStep {
            step,
            total_steps: self.config.total_steps,
        }),
    }
}
/// Returns the full [`StepPrecision`] record stored at index `step`.
///
/// # Errors
///
/// [`AdaptiveError::InvalidStep`] when the schedule has no entry at index
/// `step`.
pub fn step_info(&self, step: u32) -> Result<&StepPrecision> {
    let Some(info) = self.steps.get(step as usize) else {
        return Err(AdaptiveError::InvalidStep {
            step,
            total_steps: self.config.total_steps,
        });
    };
    Ok(info)
}
/// Returns the first recorded transition whose step index is strictly greater
/// than `after_step`, as `(step, from, to)`, or `None` if there is none.
pub fn next_transition(&self, after_step: u32) -> Option<(u32, Precision, Precision)> {
    for &transition in &self.transitions {
        if transition.0 > after_step {
            return Some(transition);
        }
    }
    None
}
pub fn steps_at_precision(&self, precision: Precision) -> SmallVec<[u32; 32]> {
self.steps
.iter()
.filter(|s| s.precision == precision)
.map(|s| s.step)
.collect()
}
/// Summarises how the steps are spread across precisions.
///
/// Returns one `(precision, step_count, share_of_all_steps)` tuple per
/// precision that occurs in the schedule, sorted by precision.
pub fn precision_distribution(&self) -> Vec<(Precision, usize, f32)> {
    let tally = self.steps.iter().fold(
        std::collections::HashMap::<Precision, usize>::new(),
        |mut acc, info| {
            *acc.entry(info.precision).or_insert(0) += 1;
            acc
        },
    );

    let denominator = self.steps.len() as f32;
    let mut distribution = Vec::with_capacity(tally.len());
    for (precision, count) in tally {
        distribution.push((precision, count, count as f32 / denominator));
    }
    // Hash-map iteration order is arbitrary; sort so the output is stable.
    distribution.sort_by_key(|&(precision, _, _)| precision);
    distribution
}
/// Renders the schedule as a two-row text timeline, each row ending in `\n`.
///
/// The `Step:` row labels every step whose index is a multiple of five. The
/// `Prec:` row holds one symbol per step: `4` = INT4, `8` = INT8, `B` = BF16,
/// `H` = FP16, `F` = FP32.
///
/// Each label is left-aligned in a five-column cell. For a schedule whose
/// entries are numbered consecutively from zero (as `generate` produces),
/// every label therefore starts in the same column as the symbol of the step
/// it names.
pub fn format_timeline(&self) -> String {
    // Both rows carry a six-character prefix plus roughly one column per step.
    let mut result = String::with_capacity(2 * (self.steps.len() + 7));

    result.push_str("Step: ");
    for info in self.steps.iter().filter(|info| info.step % 5 == 0) {
        // One label per five steps, five columns wide: this keeps the label
        // row in lock-step with the one-character-per-step precision row.
        result.push_str(&format!("{:<5}", info.step));
    }
    result.push('\n');

    result.push_str("Prec: ");
    for info in &self.steps {
        let symbol = match info.precision {
            Precision::INT4 => '4',
            Precision::INT8 => '8',
            Precision::BF16 => 'B',
            Precision::FP16 => 'H',
            Precision::FP32 => 'F',
        };
        result.push(symbol);
    }
    result.push('\n');
    result
}
}
/// Convenience wrapper: generates a schedule for `preset` with `steps` total
/// steps, leaving every other configuration field at its default.
///
/// # Errors
///
/// Propagates any error returned by [`PrecisionSchedule::generate`].
pub fn quick_schedule(preset: ProfilePreset, steps: u32) -> Result<PrecisionSchedule> {
    let config = ScheduleConfig::with_preset(preset).steps(steps);
    PrecisionSchedule::generate(config)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 20-step `Balanced` schedule has 20 entries and at least one
    /// transition, starts at `INT8` or below and ends at `FP16` or above.
    #[test]
    fn test_generate_schedule() {
        let schedule = quick_schedule(ProfilePreset::Balanced, 20).unwrap();
        assert_eq!(schedule.steps.len(), 20);
        assert!(!schedule.transitions.is_empty());

        let first = &schedule.steps[0];
        let last = &schedule.steps[19];
        assert!(first.precision <= Precision::INT8);
        assert!(last.precision >= Precision::FP16);
    }

    /// The distribution of a `Performance` schedule covers at least two
    /// precisions and its shares add up to (approximately) one.
    #[test]
    fn test_precision_distribution() {
        let schedule = quick_schedule(ProfilePreset::Performance, 30).unwrap();
        let dist = schedule.precision_distribution();
        assert!(dist.len() >= 2);

        let total = dist.iter().map(|&(_, _, share)| share).sum::<f32>();
        assert!((total - 1.0).abs() < 0.01);
    }

    /// In-range lookups return the matching entry; out-of-range lookups fail.
    #[test]
    fn test_step_lookup() {
        let schedule = quick_schedule(ProfilePreset::Balanced, 20).unwrap();

        let info = schedule.step_info(10).unwrap();
        assert_eq!(info.step, 10);

        let out_of_range = schedule.step_info(100);
        assert!(out_of_range.is_err());
    }

    /// A minimum quality of 0.999 makes a `Performance` schedule fail with
    /// `QualityConstraint`.
    #[test]
    fn test_quality_constraint() {
        let config = ScheduleConfig::with_preset(ProfilePreset::Performance)
            .steps(20)
            .min_quality(0.999);

        let rejected = matches!(
            PrecisionSchedule::generate(config),
            Err(AdaptiveError::QualityConstraint { .. })
        );
        assert!(rejected);
    }

    /// With `legacy_gpu` capabilities every scheduled precision is `FP16` or
    /// higher.
    #[test]
    fn test_capabilities_adjustment() {
        let config = ScheduleConfig::with_preset(ProfilePreset::Performance)
            .steps(20)
            .capabilities(PrecisionCapabilities::legacy_gpu(4096));
        let schedule = PrecisionSchedule::generate(config).unwrap();

        for info in schedule.steps.iter() {
            assert!(info.precision >= Precision::FP16);
        }
    }
}