Skip to main content

haagenti_adaptive/
schedule.rs

1//! Precision schedule for a specific generation
2
3use crate::{
4    AdaptiveError, Precision, PrecisionCapabilities, PrecisionProfile, ProfilePreset, Result,
5    TransitionStrategy,
6};
7use serde::{Deserialize, Serialize};
8use smallvec::SmallVec;
9
10/// Configuration for schedule generation
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ScheduleConfig {
13    /// Total number of denoising steps
14    pub total_steps: u32,
15    /// Profile preset to use
16    pub preset: ProfilePreset,
17    /// Custom profile (overrides preset if set)
18    pub custom_profile: Option<PrecisionProfile>,
19    /// Hardware capabilities
20    pub capabilities: PrecisionCapabilities,
21    /// Minimum quality threshold (0.0 - 1.0)
22    pub min_quality: f32,
23    /// Transition strategy between precisions
24    pub transition_strategy: TransitionStrategy,
25    /// Whether to force capabilities check
26    pub strict_capabilities: bool,
27}
28
29impl Default for ScheduleConfig {
30    fn default() -> Self {
31        Self {
32            total_steps: 30,
33            preset: ProfilePreset::Balanced,
34            custom_profile: None,
35            capabilities: PrecisionCapabilities::default(),
36            min_quality: 0.90,
37            transition_strategy: TransitionStrategy::Immediate,
38            strict_capabilities: true,
39        }
40    }
41}
42
43impl ScheduleConfig {
44    /// Create config for a specific preset
45    pub fn with_preset(preset: ProfilePreset) -> Self {
46        Self {
47            preset,
48            ..Default::default()
49        }
50    }
51
52    /// Set total steps
53    pub fn steps(mut self, steps: u32) -> Self {
54        self.total_steps = steps;
55        self
56    }
57
58    /// Set hardware capabilities
59    pub fn capabilities(mut self, caps: PrecisionCapabilities) -> Self {
60        self.capabilities = caps;
61        self
62    }
63
64    /// Set minimum quality
65    pub fn min_quality(mut self, quality: f32) -> Self {
66        self.min_quality = quality;
67        self
68    }
69}
70
71/// Precision assignment for a single step
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct StepPrecision {
74    /// Step number
75    pub step: u32,
76    /// Assigned precision
77    pub precision: Precision,
78    /// Whether this is a transition step
79    pub is_transition: bool,
80    /// Blending factor (for gradual transitions)
81    pub blend_factor: Option<f32>,
82    /// Expected VRAM usage ratio
83    pub vram_ratio: f32,
84    /// Expected quality impact
85    pub quality_factor: f32,
86}
87
88/// Complete precision schedule for a generation
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct PrecisionSchedule {
91    /// Configuration used to generate this schedule
92    pub config: ScheduleConfig,
93    /// Profile used
94    pub profile_name: String,
95    /// Per-step precision assignments
96    pub steps: Vec<StepPrecision>,
97    /// Precision transition points
98    pub transitions: Vec<(u32, Precision, Precision)>,
99    /// Estimated average VRAM ratio
100    pub avg_vram_ratio: f32,
101    /// Estimated quality factor
102    pub avg_quality_factor: f32,
103    /// Estimated speedup factor
104    pub estimated_speedup: f32,
105}
106
107impl PrecisionSchedule {
108    /// Generate a schedule from configuration
109    pub fn generate(config: ScheduleConfig) -> Result<Self> {
110        if config.total_steps == 0 {
111            return Err(AdaptiveError::ScheduleError(
112                "Total steps must be > 0".into(),
113            ));
114        }
115
116        let profile = config
117            .custom_profile
118            .clone()
119            .unwrap_or_else(|| config.preset.build());
120
121        // Validate profile
122        profile.validate()?;
123
124        // Generate step assignments
125        let mut steps = Vec::with_capacity(config.total_steps as usize);
126        let mut transitions = Vec::new();
127        let mut prev_precision: Option<Precision> = None;
128
129        for step in 0..config.total_steps {
130            let fraction = step as f32 / config.total_steps as f32;
131            let mut precision = profile.precision_at(fraction);
132
133            // Adjust for hardware capabilities
134            if config.strict_capabilities && !config.capabilities.supports(precision) {
135                precision = config.capabilities.best_supported(precision);
136            }
137
138            let is_transition = prev_precision.is_some_and(|p| p != precision);
139
140            if is_transition {
141                if let Some(prev) = prev_precision {
142                    transitions.push((step, prev, precision));
143                }
144            }
145
146            let blend_factor = if is_transition {
147                match config.transition_strategy {
148                    TransitionStrategy::Immediate => None,
149                    TransitionStrategy::Gradual { steps: blend_steps } => {
150                        Some(1.0 / blend_steps as f32)
151                    }
152                    TransitionStrategy::StepAware => Some(0.5),
153                }
154            } else {
155                None
156            };
157
158            steps.push(StepPrecision {
159                step,
160                precision,
161                is_transition,
162                blend_factor,
163                vram_ratio: precision.vram_ratio(),
164                quality_factor: precision.quality_factor(),
165            });
166
167            prev_precision = Some(precision);
168        }
169
170        // Calculate averages
171        let avg_vram_ratio = steps.iter().map(|s| s.vram_ratio).sum::<f32>() / steps.len() as f32;
172        let avg_quality_factor =
173            steps.iter().map(|s| s.quality_factor).sum::<f32>() / steps.len() as f32;
174        let estimated_speedup = steps
175            .iter()
176            .map(|s| s.precision.speedup_factor())
177            .sum::<f32>()
178            / steps.len() as f32;
179
180        // Check quality constraint
181        if avg_quality_factor < config.min_quality {
182            return Err(AdaptiveError::QualityConstraint {
183                actual: avg_quality_factor,
184                threshold: config.min_quality,
185            });
186        }
187
188        Ok(Self {
189            config,
190            profile_name: profile.name,
191            steps,
192            transitions,
193            avg_vram_ratio,
194            avg_quality_factor,
195            estimated_speedup,
196        })
197    }
198
199    /// Get precision for a specific step
200    pub fn precision_at(&self, step: u32) -> Result<Precision> {
201        self.steps
202            .get(step as usize)
203            .map(|s| s.precision)
204            .ok_or(AdaptiveError::InvalidStep {
205                step,
206                total_steps: self.config.total_steps,
207            })
208    }
209
210    /// Get step precision info
211    pub fn step_info(&self, step: u32) -> Result<&StepPrecision> {
212        self.steps
213            .get(step as usize)
214            .ok_or(AdaptiveError::InvalidStep {
215                step,
216                total_steps: self.config.total_steps,
217            })
218    }
219
220    /// Get next transition after a given step
221    pub fn next_transition(&self, after_step: u32) -> Option<(u32, Precision, Precision)> {
222        self.transitions
223            .iter()
224            .find(|(step, _, _)| *step > after_step)
225            .copied()
226    }
227
228    /// Get all steps using a specific precision
229    pub fn steps_at_precision(&self, precision: Precision) -> SmallVec<[u32; 32]> {
230        self.steps
231            .iter()
232            .filter(|s| s.precision == precision)
233            .map(|s| s.step)
234            .collect()
235    }
236
237    /// Total time at each precision
238    pub fn precision_distribution(&self) -> Vec<(Precision, usize, f32)> {
239        let mut counts: std::collections::HashMap<Precision, usize> =
240            std::collections::HashMap::new();
241
242        for step in &self.steps {
243            *counts.entry(step.precision).or_insert(0) += 1;
244        }
245
246        let total = self.steps.len() as f32;
247        let mut result: Vec<_> = counts
248            .into_iter()
249            .map(|(p, count)| (p, count, count as f32 / total))
250            .collect();
251
252        result.sort_by_key(|(p, _, _)| *p);
253        result
254    }
255
256    /// Format as a visual timeline
257    pub fn format_timeline(&self) -> String {
258        let mut result = String::new();
259
260        result.push_str("Step: ");
261        for step in &self.steps {
262            if step.step % 5 == 0 {
263                result.push_str(&format!("{:2} ", step.step));
264            }
265        }
266        result.push('\n');
267
268        result.push_str("Prec: ");
269        for step in &self.steps {
270            let symbol = match step.precision {
271                Precision::INT4 => '4',
272                Precision::INT8 => '8',
273                Precision::BF16 => 'B',
274                Precision::FP16 => 'H',
275                Precision::FP32 => 'F',
276            };
277            result.push(symbol);
278        }
279        result.push('\n');
280
281        result
282    }
283}
284
285/// Quick schedule generation for common cases
286pub fn quick_schedule(preset: ProfilePreset, steps: u32) -> Result<PrecisionSchedule> {
287    PrecisionSchedule::generate(ScheduleConfig::with_preset(preset).steps(steps))
288}
289
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_schedule() {
        let schedule = quick_schedule(ProfilePreset::Balanced, 20).unwrap();

        assert_eq!(schedule.steps.len(), 20);
        assert!(!schedule.transitions.is_empty());

        let first = &schedule.steps[0];
        let last = &schedule.steps[19];
        // The balanced profile opens at low precision...
        assert!(first.precision <= Precision::INT8);
        // ...and finishes at high precision.
        assert!(last.precision >= Precision::FP16);
    }

    #[test]
    fn test_precision_distribution() {
        let dist = quick_schedule(ProfilePreset::Performance, 30)
            .unwrap()
            .precision_distribution();

        // More than one precision is in use.
        assert!(dist.len() >= 2);

        // The per-precision shares cover the whole schedule.
        let share_sum: f32 = dist.iter().map(|&(_, _, share)| share).sum();
        assert!((share_sum - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_step_lookup() {
        let schedule = quick_schedule(ProfilePreset::Balanced, 20).unwrap();

        let middle = schedule.step_info(10).unwrap();
        assert_eq!(middle.step, 10);

        // Out-of-range lookups report an error instead of panicking.
        let out_of_range = schedule.step_info(100);
        assert!(out_of_range.is_err());
    }

    #[test]
    fn test_quality_constraint() {
        // A 0.999 floor cannot be met by the Performance preset.
        let demanding = ScheduleConfig::with_preset(ProfilePreset::Performance)
            .steps(20)
            .min_quality(0.999);

        let outcome = PrecisionSchedule::generate(demanding);
        assert!(matches!(
            outcome,
            Err(AdaptiveError::QualityConstraint { .. })
        ));
    }

    #[test]
    fn test_capabilities_adjustment() {
        // A legacy GPU has no INT4 support, so low precisions must be lifted.
        let legacy = ScheduleConfig::with_preset(ProfilePreset::Performance)
            .steps(20)
            .capabilities(PrecisionCapabilities::legacy_gpu(4096));

        let schedule = PrecisionSchedule::generate(legacy).unwrap();

        assert!(schedule
            .steps
            .iter()
            .all(|s| s.precision >= Precision::FP16));
    }
}