1use crate::{
4 AdaptiveError, Precision, PrecisionCapabilities, PrecisionProfile, ProfilePreset, Result,
5 TransitionStrategy,
6};
7use serde::{Deserialize, Serialize};
8use smallvec::SmallVec;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ScheduleConfig {
13 pub total_steps: u32,
15 pub preset: ProfilePreset,
17 pub custom_profile: Option<PrecisionProfile>,
19 pub capabilities: PrecisionCapabilities,
21 pub min_quality: f32,
23 pub transition_strategy: TransitionStrategy,
25 pub strict_capabilities: bool,
27}
28
29impl Default for ScheduleConfig {
30 fn default() -> Self {
31 Self {
32 total_steps: 30,
33 preset: ProfilePreset::Balanced,
34 custom_profile: None,
35 capabilities: PrecisionCapabilities::default(),
36 min_quality: 0.90,
37 transition_strategy: TransitionStrategy::Immediate,
38 strict_capabilities: true,
39 }
40 }
41}
42
43impl ScheduleConfig {
44 pub fn with_preset(preset: ProfilePreset) -> Self {
46 Self {
47 preset,
48 ..Default::default()
49 }
50 }
51
52 pub fn steps(mut self, steps: u32) -> Self {
54 self.total_steps = steps;
55 self
56 }
57
58 pub fn capabilities(mut self, caps: PrecisionCapabilities) -> Self {
60 self.capabilities = caps;
61 self
62 }
63
64 pub fn min_quality(mut self, quality: f32) -> Self {
66 self.min_quality = quality;
67 self
68 }
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct StepPrecision {
74 pub step: u32,
76 pub precision: Precision,
78 pub is_transition: bool,
80 pub blend_factor: Option<f32>,
82 pub vram_ratio: f32,
84 pub quality_factor: f32,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct PrecisionSchedule {
91 pub config: ScheduleConfig,
93 pub profile_name: String,
95 pub steps: Vec<StepPrecision>,
97 pub transitions: Vec<(u32, Precision, Precision)>,
99 pub avg_vram_ratio: f32,
101 pub avg_quality_factor: f32,
103 pub estimated_speedup: f32,
105}
106
107impl PrecisionSchedule {
108 pub fn generate(config: ScheduleConfig) -> Result<Self> {
110 if config.total_steps == 0 {
111 return Err(AdaptiveError::ScheduleError(
112 "Total steps must be > 0".into(),
113 ));
114 }
115
116 let profile = config
117 .custom_profile
118 .clone()
119 .unwrap_or_else(|| config.preset.build());
120
121 profile.validate()?;
123
124 let mut steps = Vec::with_capacity(config.total_steps as usize);
126 let mut transitions = Vec::new();
127 let mut prev_precision: Option<Precision> = None;
128
129 for step in 0..config.total_steps {
130 let fraction = step as f32 / config.total_steps as f32;
131 let mut precision = profile.precision_at(fraction);
132
133 if config.strict_capabilities && !config.capabilities.supports(precision) {
135 precision = config.capabilities.best_supported(precision);
136 }
137
138 let is_transition = prev_precision.is_some_and(|p| p != precision);
139
140 if is_transition {
141 if let Some(prev) = prev_precision {
142 transitions.push((step, prev, precision));
143 }
144 }
145
146 let blend_factor = if is_transition {
147 match config.transition_strategy {
148 TransitionStrategy::Immediate => None,
149 TransitionStrategy::Gradual { steps: blend_steps } => {
150 Some(1.0 / blend_steps as f32)
151 }
152 TransitionStrategy::StepAware => Some(0.5),
153 }
154 } else {
155 None
156 };
157
158 steps.push(StepPrecision {
159 step,
160 precision,
161 is_transition,
162 blend_factor,
163 vram_ratio: precision.vram_ratio(),
164 quality_factor: precision.quality_factor(),
165 });
166
167 prev_precision = Some(precision);
168 }
169
170 let avg_vram_ratio = steps.iter().map(|s| s.vram_ratio).sum::<f32>() / steps.len() as f32;
172 let avg_quality_factor =
173 steps.iter().map(|s| s.quality_factor).sum::<f32>() / steps.len() as f32;
174 let estimated_speedup = steps
175 .iter()
176 .map(|s| s.precision.speedup_factor())
177 .sum::<f32>()
178 / steps.len() as f32;
179
180 if avg_quality_factor < config.min_quality {
182 return Err(AdaptiveError::QualityConstraint {
183 actual: avg_quality_factor,
184 threshold: config.min_quality,
185 });
186 }
187
188 Ok(Self {
189 config,
190 profile_name: profile.name,
191 steps,
192 transitions,
193 avg_vram_ratio,
194 avg_quality_factor,
195 estimated_speedup,
196 })
197 }
198
199 pub fn precision_at(&self, step: u32) -> Result<Precision> {
201 self.steps
202 .get(step as usize)
203 .map(|s| s.precision)
204 .ok_or(AdaptiveError::InvalidStep {
205 step,
206 total_steps: self.config.total_steps,
207 })
208 }
209
210 pub fn step_info(&self, step: u32) -> Result<&StepPrecision> {
212 self.steps
213 .get(step as usize)
214 .ok_or(AdaptiveError::InvalidStep {
215 step,
216 total_steps: self.config.total_steps,
217 })
218 }
219
220 pub fn next_transition(&self, after_step: u32) -> Option<(u32, Precision, Precision)> {
222 self.transitions
223 .iter()
224 .find(|(step, _, _)| *step > after_step)
225 .copied()
226 }
227
228 pub fn steps_at_precision(&self, precision: Precision) -> SmallVec<[u32; 32]> {
230 self.steps
231 .iter()
232 .filter(|s| s.precision == precision)
233 .map(|s| s.step)
234 .collect()
235 }
236
237 pub fn precision_distribution(&self) -> Vec<(Precision, usize, f32)> {
239 let mut counts: std::collections::HashMap<Precision, usize> =
240 std::collections::HashMap::new();
241
242 for step in &self.steps {
243 *counts.entry(step.precision).or_insert(0) += 1;
244 }
245
246 let total = self.steps.len() as f32;
247 let mut result: Vec<_> = counts
248 .into_iter()
249 .map(|(p, count)| (p, count, count as f32 / total))
250 .collect();
251
252 result.sort_by_key(|(p, _, _)| *p);
253 result
254 }
255
256 pub fn format_timeline(&self) -> String {
258 let mut result = String::new();
259
260 result.push_str("Step: ");
261 for step in &self.steps {
262 if step.step % 5 == 0 {
263 result.push_str(&format!("{:2} ", step.step));
264 }
265 }
266 result.push('\n');
267
268 result.push_str("Prec: ");
269 for step in &self.steps {
270 let symbol = match step.precision {
271 Precision::INT4 => '4',
272 Precision::INT8 => '8',
273 Precision::BF16 => 'B',
274 Precision::FP16 => 'H',
275 Precision::FP32 => 'F',
276 };
277 result.push(symbol);
278 }
279 result.push('\n');
280
281 result
282 }
283}
284
285pub fn quick_schedule(preset: ProfilePreset, steps: u32) -> Result<PrecisionSchedule> {
287 PrecisionSchedule::generate(ScheduleConfig::with_preset(preset).steps(steps))
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    /// A 20-step balanced schedule has 20 entries, at least one transition,
    /// starts at INT8 or lower and ends at FP16 or higher.
    #[test]
    fn test_generate_schedule() {
        let plan = quick_schedule(ProfilePreset::Balanced, 20).unwrap();

        assert_eq!(plan.steps.len(), 20);
        assert!(!plan.transitions.is_empty());

        let first = &plan.steps[0];
        let last = &plan.steps[19];
        assert!(first.precision <= Precision::INT8);
        assert!(last.precision >= Precision::FP16);
    }

    /// The distribution covers at least two precisions and its fractions
    /// add up to (approximately) one.
    #[test]
    fn test_precision_distribution() {
        let plan = quick_schedule(ProfilePreset::Performance, 30).unwrap();
        let distribution = plan.precision_distribution();

        assert!(distribution.len() >= 2);

        let fraction_sum = distribution
            .iter()
            .map(|&(_, _, share)| share)
            .sum::<f32>();
        assert!((fraction_sum - 1.0).abs() < 0.01);
    }

    /// In-range lookups return the matching step; out-of-range lookups fail.
    #[test]
    fn test_step_lookup() {
        let plan = quick_schedule(ProfilePreset::Balanced, 20).unwrap();

        let tenth = plan.step_info(10).unwrap();
        assert_eq!(tenth.step, 10);

        let out_of_range = plan.step_info(100);
        assert!(out_of_range.is_err());
    }

    /// A minimum quality of 0.999 is expected to be rejected for the
    /// performance preset with a `QualityConstraint` error.
    #[test]
    fn test_quality_constraint() {
        let demanding = ScheduleConfig::with_preset(ProfilePreset::Performance)
            .steps(20)
            .min_quality(0.999);

        let outcome = PrecisionSchedule::generate(demanding);
        assert!(matches!(
            outcome,
            Err(AdaptiveError::QualityConstraint { .. })
        ));
    }

    /// With `legacy_gpu` capabilities every step must end up at FP16 or higher.
    #[test]
    fn test_capabilities_adjustment() {
        let legacy = ScheduleConfig::with_preset(ProfilePreset::Performance)
            .steps(20)
            .capabilities(PrecisionCapabilities::legacy_gpu(4096));

        let plan = PrecisionSchedule::generate(legacy).unwrap();

        assert!(plan
            .steps
            .iter()
            .all(|info| info.precision >= Precision::FP16));
    }
}