use crate::{AdaptiveError, Precision, Result};
use serde::{Deserialize, Serialize};
/// How a [`PrecisionTransition`] moves from its source precision to its target precision.
///
/// Defaults to [`TransitionStrategy::Immediate`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransitionStrategy {
    /// Switch in one step; the blend factor at the transition step is `1.0`.
    #[default]
    Immediate,
    /// Blend linearly across the transition window.
    Gradual {
        /// Number of steps the linear blend is divided into.
        steps: u32,
    },
    /// Blend along a smoothstep (cubic ease-in/ease-out) curve across the transition window.
    StepAware,
}
/// A scheduled change of precision over the inclusive step window `start_step..=end_step`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrecisionTransition {
    /// Source precision, used while the blend is still below one half.
    pub from: Precision,
    /// Target precision, used once the blend reaches one half.
    pub to: Precision,
    /// First step of the transition window (inclusive).
    pub start_step: u32,
    /// Last step of the transition window (inclusive).
    pub end_step: u32,
    /// How the blend factor is computed inside the window.
    pub strategy: TransitionStrategy,
    /// Blend factors precomputed by the constructors, one per step of the window.
    /// `blend_at` recomputes the blend from `strategy` and does not read this field.
    pub blend_factors: Vec<f32>,
}
impl PrecisionTransition {
    /// Headroom multiplier applied to blended transitions in `peak_vram_ratio`.
    ///
    /// NOTE(review): value carried over unchanged; presumably it accounts for both precisions
    /// being resident during a blend — confirm.
    const BLEND_VRAM_OVERHEAD: f32 = 1.2;

    /// Creates a transition that switches from `from` to `to` entirely at `step`.
    pub fn immediate(from: Precision, to: Precision, step: u32) -> Self {
        Self {
            from,
            to,
            start_step: step,
            end_step: step,
            strategy: TransitionStrategy::Immediate,
            blend_factors: vec![1.0],
        }
    }

    /// Creates a transition that blends linearly from `from` to `to` over `duration` steps,
    /// covering `start_step..=start_step + duration`.
    ///
    /// `duration` is clamped so the window never runs past `u32::MAX`, rather than overflowing.
    /// A `duration` of zero yields a single-step window with blend factor `1.0`, which is
    /// equivalent to an immediate switch. This avoids the `0 / 0 = NaN` that the linear formula
    /// would otherwise produce.
    pub fn gradual(from: Precision, to: Precision, start_step: u32, duration: u32) -> Self {
        // `u32::MAX - start_step` cannot underflow, and it guarantees `start_step + duration`
        // fits in a u32.
        let duration = duration.min(u32::MAX - start_step);
        let blend_factors: Vec<f32> = if duration == 0 {
            vec![1.0]
        } else {
            (0..=duration).map(|i| i as f32 / duration as f32).collect()
        };
        Self {
            from,
            to,
            start_step,
            end_step: start_step + duration,
            strategy: TransitionStrategy::Gradual { steps: duration },
            blend_factors,
        }
    }

    /// Returns `true` when `step` lies inside the inclusive window `start_step..=end_step`.
    pub fn contains_step(&self, step: u32) -> bool {
        (self.start_step..=self.end_step).contains(&step)
    }

    /// Returns the blend factor at `step`, or `None` when `step` is outside the window.
    ///
    /// The value is `0.0` at the start of a blended window and `1.0` once it completes.
    /// Degenerate zero-length blends report `1.0` instead of NaN.
    pub fn blend_at(&self, step: u32) -> Option<f32> {
        if !self.contains_step(step) {
            return None;
        }
        // `contains_step` guarantees start_step <= step <= end_step, so neither subtraction
        // can underflow.
        let offset = step - self.start_step;
        let span = self.end_step - self.start_step;
        let blend = match self.strategy {
            TransitionStrategy::Immediate => 1.0,
            // A zero-step blend would compute 0 / 0 = NaN; treat it as already complete.
            TransitionStrategy::Gradual { steps: 0 } => 1.0,
            TransitionStrategy::Gradual { steps } => {
                (offset as f32 / steps as f32).clamp(0.0, 1.0)
            }
            // Same NaN hazard for a single-step smoothstep window.
            TransitionStrategy::StepAware if span == 0 => 1.0,
            TransitionStrategy::StepAware => smooth_step(offset as f32 / span as f32),
        };
        Some(blend)
    }

    /// Returns the precision to use at `step`.
    ///
    /// Inside the window, `from` is kept until the blend reaches `0.5`, and `to` is used after
    /// that. Outside the window, steps before `start_step` use `from` and steps after
    /// `end_step` use `to`.
    pub fn effective_precision(&self, step: u32) -> Precision {
        match self.blend_at(step) {
            Some(blend) if blend < 0.5 => self.from,
            Some(_) => self.to,
            // The transition has not begun yet, so the source precision is still in effect.
            None if step < self.start_step => self.from,
            None => self.to,
        }
    }

    /// Checks that the window is well-formed.
    ///
    /// # Errors
    ///
    /// Returns `AdaptiveError::InvalidTransition` when `end_step` precedes `start_step`.
    pub fn validate(&self) -> Result<()> {
        if self.end_step < self.start_step {
            return Err(AdaptiveError::InvalidTransition {
                from: self.from,
                to: self.to,
                reason: "End step before start step".into(),
            });
        }
        Ok(())
    }

    /// Estimated peak VRAM ratio while this transition runs, based on `Precision::vram_ratio`.
    ///
    /// An immediate transition needs only the target precision's ratio. A blended transition
    /// takes the larger of the two ratios and scales it by `BLEND_VRAM_OVERHEAD`.
    pub fn peak_vram_ratio(&self) -> f32 {
        match self.strategy {
            TransitionStrategy::Immediate => self.to.vram_ratio(),
            TransitionStrategy::Gradual { .. } | TransitionStrategy::StepAware => {
                self.from.vram_ratio().max(self.to.vram_ratio()) * Self::BLEND_VRAM_OVERHEAD
            }
        }
    }
}
/// Cubic smoothstep `3t² − 2t³`, with `t` first clamped to `[0, 1]`.
///
/// Maps 0 → 0, 0.5 → 0.5 and 1 → 1, with zero slope at both ends. A NaN input stays NaN.
fn smooth_step(t: f32) -> f32 {
    let clamped = t.clamp(0.0, 1.0);
    let slope = 3.0 - 2.0 * clamped;
    clamped * clamped * slope
}
// Kept as an alternative easing curve. Nothing in the code visible in this module calls it,
// so the dead-code lint is silenced.
#[allow(dead_code)]
/// Quintic "smootherstep" `6t⁵ − 15t⁴ + 10t³`, with `t` first clamped to `[0, 1]`.
///
/// Like `smooth_step`, but its second derivative is also zero at both ends.
fn smoother_step(t: f32) -> f32 {
    let clamped = t.clamp(0.0, 1.0);
    let cube = clamped * clamped * clamped;
    let poly = clamped * (clamped * 6.0 - 15.0) + 10.0;
    cube * poly
}
/// Plans [`PrecisionTransition`]s at the boundaries between consecutive precision zones.
#[derive(Debug, Clone)]
pub struct TransitionPlanner {
    /// VRAM budget in MB (as named). It selects the threshold used by `optimize_for_vram`.
    vram_mb: u64,
    /// Strategy applied to every transition the planner creates.
    preferred_strategy: TransitionStrategy,
    /// Minimum spacing, in steps, enforced between planned transitions.
    min_gap: u32,
}
impl TransitionPlanner {
pub fn new(vram_mb: u64) -> Self {
Self {
vram_mb,
preferred_strategy: TransitionStrategy::Immediate,
min_gap: 3,
}
}
pub fn with_strategy(mut self, strategy: TransitionStrategy) -> Self {
self.preferred_strategy = strategy;
self
}
pub fn plan_transitions(
&self,
zones: &[(u32, u32, Precision)], ) -> Vec<PrecisionTransition> {
if zones.len() <= 1 {
return Vec::new();
}
let mut transitions = Vec::new();
let mut last_transition_end = 0u32;
for window in zones.windows(2) {
let (_, end1, prec1) = window[0];
let (start2, _, prec2) = window[1];
if start2 < last_transition_end + self.min_gap {
continue;
}
if prec1 == prec2 {
continue;
}
let transition = match self.preferred_strategy {
TransitionStrategy::Immediate => {
PrecisionTransition::immediate(prec1, prec2, start2)
}
TransitionStrategy::Gradual { steps } => {
let safe_steps = steps.min(end1.saturating_sub(1));
PrecisionTransition::gradual(
prec1,
prec2,
end1.saturating_sub(safe_steps),
safe_steps,
)
}
TransitionStrategy::StepAware => {
PrecisionTransition::gradual(prec1, prec2, end1.saturating_sub(1), 2)
}
};
last_transition_end = transition.end_step;
transitions.push(transition);
}
transitions
}
pub fn vram_mb(&self) -> u64 {
self.vram_mb
}
pub fn optimize_for_vram(&self, transitions: &mut [PrecisionTransition]) {
let vram_threshold = if self.vram_mb < 8192 {
0.85 } else {
0.9
};
for transition in transitions {
if transition.peak_vram_ratio() > vram_threshold {
*transition = PrecisionTransition::immediate(
transition.from,
transition.to,
transition.start_step,
);
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An immediate transition occupies exactly one step and is fully blended there.
    #[test]
    fn test_immediate_transition() {
        let transition = PrecisionTransition::immediate(Precision::INT4, Precision::FP16, 10);
        assert_eq!((transition.start_step, transition.end_step), (10, 10));
        assert!(transition.contains_step(10));
        assert!(!transition.contains_step(9));
        assert_eq!(transition.blend_at(10), Some(1.0));
    }

    /// A four-step gradual transition blends linearly from 0.0 to 1.0.
    #[test]
    fn test_gradual_transition() {
        let transition = PrecisionTransition::gradual(Precision::INT8, Precision::FP16, 10, 4);
        assert_eq!((transition.start_step, transition.end_step), (10, 14));
        for &(step, blend) in &[(10, 0.0), (12, 0.5), (14, 1.0)] {
            assert_eq!(transition.blend_at(step), Some(blend), "blend at step {}", step);
        }
    }

    /// The source precision holds until the blend reaches one half.
    #[test]
    fn test_effective_precision() {
        let transition = PrecisionTransition::gradual(Precision::INT4, Precision::FP16, 10, 4);
        let expectations = [
            (10, Precision::INT4),
            (11, Precision::INT4),
            (12, Precision::FP16),
            (14, Precision::FP16),
        ];
        for &(step, precision) in &expectations {
            assert_eq!(transition.effective_precision(step), precision, "step {}", step);
        }
    }

    /// Smoothstep fixes its endpoints and passes through the midpoint.
    #[test]
    fn test_smooth_step() {
        assert_eq!(smooth_step(0.0), 0.0);
        assert_eq!(smooth_step(1.0), 1.0);
        assert!((smooth_step(0.5) - 0.5).abs() < 0.01);
    }

    /// Every precision change between adjacent zones yields one transition, in order.
    #[test]
    fn test_transition_planner() {
        let zones = vec![
            (0, 10, Precision::INT4),
            (10, 20, Precision::INT8),
            (20, 30, Precision::FP16),
        ];
        let transitions = TransitionPlanner::new(8192).plan_transitions(&zones);
        let pairs: Vec<(Precision, Precision)> =
            transitions.iter().map(|t| (t.from, t.to)).collect();
        assert_eq!(
            pairs,
            vec![
                (Precision::INT4, Precision::INT8),
                (Precision::INT8, Precision::FP16),
            ]
        );
    }
}