1use crate::{AdaptiveError, Precision, Result};
4use serde::{Deserialize, Serialize};
5
/// How a [`PrecisionTransition`] moves from its source to its target precision.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum TransitionStrategy {
    /// Switch to the target precision in a single step (the default strategy).
    #[default]
    Immediate,
    /// Blend linearly from the source to the target precision.
    Gradual {
        /// Number of steps over which the blend factor ramps from `0.0` to `1.0`.
        steps: u32,
    },
    /// Blend along a smoothstep curve across the transition's step range.
    ///
    /// NOTE(review): `TransitionPlanner::plan_transitions` maps this preference to a
    /// two-step `Gradual` transition rather than constructing a `StepAware` one.
    StepAware,
}
20
/// A planned change of precision covering the inclusive step range `start_step..=end_step`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrecisionTransition {
    /// Source precision being transitioned away from.
    pub from: Precision,
    /// Target precision being transitioned to.
    pub to: Precision,
    /// First step covered by the transition (inclusive).
    pub start_step: u32,
    /// Last step covered by the transition (inclusive).
    pub end_step: u32,
    /// How the blend factor evolves across the covered steps.
    pub strategy: TransitionStrategy,
    /// Per-step blend factors filled in by the constructors: `immediate` stores `[1.0]`,
    /// `gradual` stores evenly spaced values from `0.0` to `1.0`.
    /// NOTE(review): `blend_at` derives its result from `strategy` and does not read
    /// this vector.
    pub blend_factors: Vec<f32>,
}
37
38impl PrecisionTransition {
39 pub fn immediate(from: Precision, to: Precision, step: u32) -> Self {
41 Self {
42 from,
43 to,
44 start_step: step,
45 end_step: step,
46 strategy: TransitionStrategy::Immediate,
47 blend_factors: vec![1.0],
48 }
49 }
50
51 pub fn gradual(from: Precision, to: Precision, start_step: u32, duration: u32) -> Self {
53 let blend_factors: Vec<f32> = (0..=duration).map(|i| i as f32 / duration as f32).collect();
54
55 Self {
56 from,
57 to,
58 start_step,
59 end_step: start_step + duration,
60 strategy: TransitionStrategy::Gradual { steps: duration },
61 blend_factors,
62 }
63 }
64
65 pub fn contains_step(&self, step: u32) -> bool {
67 step >= self.start_step && step <= self.end_step
68 }
69
70 pub fn blend_at(&self, step: u32) -> Option<f32> {
72 if !self.contains_step(step) {
73 return None;
74 }
75
76 match self.strategy {
77 TransitionStrategy::Immediate => Some(1.0),
78 TransitionStrategy::Gradual { steps } => {
79 let progress = (step - self.start_step) as f32 / steps as f32;
80 Some(progress.clamp(0.0, 1.0))
81 }
82 TransitionStrategy::StepAware => {
83 let t = (step - self.start_step) as f32 / (self.end_step - self.start_step) as f32;
85 Some(smooth_step(t))
86 }
87 }
88 }
89
90 pub fn effective_precision(&self, step: u32) -> Precision {
92 match self.blend_at(step) {
93 Some(blend) if blend < 0.5 => self.from,
94 Some(_) => self.to,
95 None => self.to,
96 }
97 }
98
99 pub fn validate(&self) -> Result<()> {
101 if self.end_step < self.start_step {
102 return Err(AdaptiveError::InvalidTransition {
103 from: self.from,
104 to: self.to,
105 reason: "End step before start step".into(),
106 });
107 }
108
109 Ok(())
113 }
114
115 pub fn peak_vram_ratio(&self) -> f32 {
117 match self.strategy {
118 TransitionStrategy::Immediate => self.to.vram_ratio(),
119 TransitionStrategy::Gradual { .. } | TransitionStrategy::StepAware => {
120 self.from.vram_ratio().max(self.to.vram_ratio()) * 1.2
122 }
123 }
124 }
125}
126
/// Cubic "smoothstep" easing, `3t² − 2t³`, evaluated after clamping `t` to `[0, 1]`.
///
/// Maps `0 → 0` and `1 → 1`. A `NaN` input propagates to a `NaN` output.
fn smooth_step(t: f32) -> f32 {
    let x = t.clamp(0.0, 1.0);
    let x_squared = x * x;
    let ramp = 3.0 - 2.0 * x;
    x_squared * ramp
}
132
/// Quintic "smootherstep" easing, `6t⁵ − 15t⁴ + 10t³`, evaluated after clamping `t`
/// to `[0, 1]`.
///
/// Not referenced by the planner code in this module. `dead_code` is allowed so the
/// curve can stay available as an alternative to [`smooth_step`] without a warning.
#[allow(dead_code)]
fn smoother_step(t: f32) -> f32 {
    let x = t.clamp(0.0, 1.0);
    let cubed = x * x * x;
    let shape = x * (x * 6.0 - 15.0) + 10.0;
    cubed * shape
}
139
/// Plans precision transitions at the boundaries between consecutive precision zones.
#[derive(Debug, Clone)]
pub struct TransitionPlanner {
    /// VRAM size in megabytes (per the field name). `optimize_for_vram` compares it
    /// against 8192 to pick its threshold.
    vram_mb: u64,
    /// Strategy requested for transitions produced by `plan_transitions`.
    preferred_strategy: TransitionStrategy,
    /// Minimum number of steps between the end of one planned transition and the
    /// start of the next zone for another transition to be planned.
    min_gap: u32,
}
150
151impl TransitionPlanner {
152 pub fn new(vram_mb: u64) -> Self {
154 Self {
155 vram_mb,
156 preferred_strategy: TransitionStrategy::Immediate,
157 min_gap: 3,
158 }
159 }
160
161 pub fn with_strategy(mut self, strategy: TransitionStrategy) -> Self {
163 self.preferred_strategy = strategy;
164 self
165 }
166
167 pub fn plan_transitions(
169 &self,
170 zones: &[(u32, u32, Precision)], ) -> Vec<PrecisionTransition> {
172 if zones.len() <= 1 {
173 return Vec::new();
174 }
175
176 let mut transitions = Vec::new();
177 let mut last_transition_end = 0u32;
178
179 for window in zones.windows(2) {
180 let (_, end1, prec1) = window[0];
181 let (start2, _, prec2) = window[1];
182
183 if start2 < last_transition_end + self.min_gap {
185 continue;
186 }
187
188 if prec1 == prec2 {
189 continue;
190 }
191
192 let transition = match self.preferred_strategy {
193 TransitionStrategy::Immediate => {
194 PrecisionTransition::immediate(prec1, prec2, start2)
195 }
196 TransitionStrategy::Gradual { steps } => {
197 let safe_steps = steps.min(end1.saturating_sub(1));
199 PrecisionTransition::gradual(
200 prec1,
201 prec2,
202 end1.saturating_sub(safe_steps),
203 safe_steps,
204 )
205 }
206 TransitionStrategy::StepAware => {
207 PrecisionTransition::gradual(prec1, prec2, end1.saturating_sub(1), 2)
209 }
210 };
211
212 last_transition_end = transition.end_step;
213 transitions.push(transition);
214 }
215
216 transitions
217 }
218
219 pub fn vram_mb(&self) -> u64 {
221 self.vram_mb
222 }
223
224 pub fn optimize_for_vram(&self, transitions: &mut [PrecisionTransition]) {
226 let vram_threshold = if self.vram_mb < 8192 {
228 0.85 } else {
230 0.9
231 };
232
233 for transition in transitions {
234 if transition.peak_vram_ratio() > vram_threshold {
236 *transition = PrecisionTransition::immediate(
237 transition.from,
238 transition.to,
239 transition.start_step,
240 );
241 }
242 }
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_immediate_transition() {
        let t = PrecisionTransition::immediate(Precision::INT4, Precision::FP16, 10);

        assert_eq!((t.start_step, t.end_step), (10, 10));
        assert!(t.contains_step(10));
        assert!(!t.contains_step(9));
        assert_eq!(t.blend_at(10), Some(1.0));
    }

    #[test]
    fn test_gradual_transition() {
        let t = PrecisionTransition::gradual(Precision::INT8, Precision::FP16, 10, 4);

        assert_eq!((t.start_step, t.end_step), (10, 14));
        for (step, expected) in [(10, 0.0), (12, 0.5), (14, 1.0)] {
            assert_eq!(t.blend_at(step), Some(expected), "blend at step {step}");
        }
    }

    #[test]
    fn test_effective_precision() {
        let t = PrecisionTransition::gradual(Precision::INT4, Precision::FP16, 10, 4);

        let cases = [
            (10, Precision::INT4),
            (11, Precision::INT4),
            (12, Precision::FP16),
            (14, Precision::FP16),
        ];
        for (step, expected) in cases {
            assert_eq!(t.effective_precision(step), expected, "precision at step {step}");
        }
    }

    #[test]
    fn test_smooth_step() {
        assert_eq!(smooth_step(0.0), 0.0);
        assert_eq!(smooth_step(1.0), 1.0);
        let midpoint_error = (smooth_step(0.5) - 0.5).abs();
        assert!(midpoint_error < 0.01);
    }

    #[test]
    fn test_transition_planner() {
        let planner = TransitionPlanner::new(8192);
        let zones = [
            (0, 10, Precision::INT4),
            (10, 20, Precision::INT8),
            (20, 30, Precision::FP16),
        ];

        let planned: Vec<_> = planner
            .plan_transitions(&zones)
            .iter()
            .map(|t| (t.from, t.to))
            .collect();

        assert_eq!(
            planned,
            vec![
                (Precision::INT4, Precision::INT8),
                (Precision::INT8, Precision::FP16),
            ]
        );
    }
}