Skip to main content

haagenti_adaptive/
profile.rs

1//! Precision profiles for different use cases
2
3use crate::{AdaptiveError, Precision, Result};
4use serde::{Deserialize, Serialize};
5
6/// A precision profile defining step-to-precision mapping rules
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct PrecisionProfile {
9    /// Profile name
10    pub name: String,
11    /// Description
12    pub description: String,
13    /// Precision zones (sorted by start_step)
14    pub zones: Vec<PrecisionZone>,
15    /// Quality target (0.0 - 1.0)
16    pub quality_target: f32,
17    /// VRAM target percentage (0.0 - 1.0)
18    pub vram_target: f32,
19}
20
21/// A zone where a specific precision is used
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct PrecisionZone {
24    /// Start step (inclusive)
25    pub start_step: f32,
26    /// End step (exclusive), as fraction of total steps
27    pub end_step: f32,
28    /// Precision to use in this zone
29    pub precision: Precision,
30    /// Reason for this choice
31    pub rationale: String,
32}
33
34impl PrecisionProfile {
35    /// Create a new profile
36    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
37        Self {
38            name: name.into(),
39            description: description.into(),
40            zones: Vec::new(),
41            quality_target: 0.95,
42            vram_target: 0.8,
43        }
44    }
45
46    /// Add a precision zone
47    pub fn add_zone(
48        mut self,
49        start: f32,
50        end: f32,
51        precision: Precision,
52        rationale: impl Into<String>,
53    ) -> Self {
54        self.zones.push(PrecisionZone {
55            start_step: start,
56            end_step: end,
57            precision,
58            rationale: rationale.into(),
59        });
60        self.zones.sort_by(|a, b| {
61            a.start_step
62                .partial_cmp(&b.start_step)
63                .unwrap_or(std::cmp::Ordering::Equal)
64        });
65        self
66    }
67
68    /// Get precision for a given step fraction (0.0 - 1.0)
69    pub fn precision_at(&self, step_fraction: f32) -> Precision {
70        for zone in &self.zones {
71            if step_fraction >= zone.start_step && step_fraction < zone.end_step {
72                return zone.precision;
73            }
74        }
75        // Default to FP16 if no zone matches
76        Precision::FP16
77    }
78
79    /// Validate the profile
80    pub fn validate(&self) -> Result<()> {
81        if self.zones.is_empty() {
82            return Err(AdaptiveError::ProfileError(
83                "Profile must have at least one zone".into(),
84            ));
85        }
86
87        // Check for gaps
88        let mut expected_start = 0.0f32;
89        for zone in &self.zones {
90            if (zone.start_step - expected_start).abs() > 0.001 {
91                return Err(AdaptiveError::ProfileError(format!(
92                    "Gap in zones at step fraction {}",
93                    expected_start
94                )));
95            }
96            expected_start = zone.end_step;
97        }
98
99        if (expected_start - 1.0).abs() > 0.001 {
100            return Err(AdaptiveError::ProfileError(
101                "Zones must cover the full range [0.0, 1.0)".into(),
102            ));
103        }
104
105        Ok(())
106    }
107
108    /// Estimate average VRAM usage ratio
109    pub fn estimated_vram_ratio(&self) -> f32 {
110        let mut total_weight = 0.0f32;
111        let mut weighted_ratio = 0.0f32;
112
113        for zone in &self.zones {
114            let weight = zone.end_step - zone.start_step;
115            total_weight += weight;
116            weighted_ratio += weight * zone.precision.vram_ratio();
117        }
118
119        if total_weight > 0.0 {
120            weighted_ratio / total_weight
121        } else {
122            1.0
123        }
124    }
125
126    /// Estimate average quality factor
127    pub fn estimated_quality(&self) -> f32 {
128        let mut total_weight = 0.0f32;
129        let mut weighted_quality = 0.0f32;
130
131        for zone in &self.zones {
132            let weight = zone.end_step - zone.start_step;
133            // Later steps matter more for quality
134            let quality_weight = weight * (0.5 + zone.start_step * 0.5);
135            total_weight += quality_weight;
136            weighted_quality += quality_weight * zone.precision.quality_factor();
137        }
138
139        if total_weight > 0.0 {
140            weighted_quality / total_weight
141        } else {
142            1.0
143        }
144    }
145
146    /// Count precision transitions
147    pub fn transition_count(&self) -> usize {
148        if self.zones.len() <= 1 {
149            return 0;
150        }
151
152        self.zones
153            .windows(2)
154            .filter(|w| w[0].precision != w[1].precision)
155            .count()
156    }
157}
158
159/// Preset profiles for common use cases
160#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
161pub enum ProfilePreset {
162    /// Maximum performance, lower quality
163    Performance,
164    /// Balanced performance and quality
165    #[default]
166    Balanced,
167    /// Maximum quality, slower
168    Quality,
169    /// Aggressive VRAM savings
170    LowVram,
171    /// Conservative, mostly FP16
172    Conservative,
173    /// Custom noise-aware schedule
174    NoiseAdaptive,
175}
176
177impl ProfilePreset {
178    /// Build the profile for this preset
179    pub fn build(self) -> PrecisionProfile {
180        match self {
181            ProfilePreset::Performance => {
182                PrecisionProfile::new("Performance", "Maximum speed with aggressive INT4 usage")
183                    .add_zone(
184                        0.0,
185                        0.4,
186                        Precision::INT4,
187                        "High noise masks quantization errors",
188                    )
189                    .add_zone(0.4, 0.7, Precision::INT8, "Medium noise tolerates INT8")
190                    .add_zone(0.7, 1.0, Precision::FP16, "Low noise requires precision")
191            }
192
193            ProfilePreset::Balanced => {
194                PrecisionProfile::new("Balanced", "Good balance of speed and quality")
195                    .add_zone(0.0, 0.25, Precision::INT4, "Early steps: high noise")
196                    .add_zone(
197                        0.25,
198                        0.5,
199                        Precision::INT8,
200                        "Mid-early steps: moderate noise",
201                    )
202                    .add_zone(
203                        0.5,
204                        1.0,
205                        Precision::FP16,
206                        "Later steps: detail preservation",
207                    )
208            }
209
210            ProfilePreset::Quality => {
211                PrecisionProfile::new("Quality", "Maximum quality with minimal quantization")
212                    .add_zone(0.0, 0.15, Precision::INT8, "Only very early INT8")
213                    .add_zone(0.15, 1.0, Precision::FP16, "FP16 for most steps")
214            }
215
216            ProfilePreset::LowVram => PrecisionProfile::new("LowVRAM", "Aggressive memory savings")
217                .add_zone(0.0, 0.5, Precision::INT4, "Extended INT4 zone")
218                .add_zone(0.5, 0.85, Precision::INT8, "INT8 for refinement")
219                .add_zone(0.85, 1.0, Precision::FP16, "FP16 only for final details"),
220
221            ProfilePreset::Conservative => {
222                PrecisionProfile::new("Conservative", "Minimal quantization for maximum quality")
223                    .add_zone(0.0, 0.1, Precision::INT8, "Brief INT8 at start")
224                    .add_zone(0.1, 1.0, Precision::FP16, "FP16 throughout")
225            }
226
227            ProfilePreset::NoiseAdaptive => PrecisionProfile::new(
228                "NoiseAdaptive",
229                "Precision matched to noise level at each step",
230            )
231            .add_zone(0.0, 0.2, Precision::INT4, "Noise sigma > 5.0")
232            .add_zone(0.2, 0.35, Precision::INT8, "Noise sigma 2.0-5.0")
233            .add_zone(0.35, 0.6, Precision::INT8, "Noise sigma 0.5-2.0")
234            .add_zone(0.6, 0.8, Precision::FP16, "Noise sigma 0.1-0.5")
235            .add_zone(0.8, 1.0, Precision::FP16, "Noise sigma < 0.1"),
236        }
237    }
238
239    /// Get description
240    pub fn description(&self) -> &'static str {
241        match self {
242            ProfilePreset::Performance => "Maximum speed (4x faster, ~8% quality loss)",
243            ProfilePreset::Balanced => "Balanced (2.5x faster, ~3% quality loss)",
244            ProfilePreset::Quality => "Maximum quality (1.5x faster, ~1% quality loss)",
245            ProfilePreset::LowVram => "Low VRAM (30% less memory, ~5% quality loss)",
246            ProfilePreset::Conservative => "Conservative (1.2x faster, minimal quality loss)",
247            ProfilePreset::NoiseAdaptive => "Noise-aware scheduling (optimal quality/speed)",
248        }
249    }
250
251    /// List all presets
252    pub fn all() -> &'static [ProfilePreset] {
253        &[
254            ProfilePreset::Performance,
255            ProfilePreset::Balanced,
256            ProfilePreset::Quality,
257            ProfilePreset::LowVram,
258            ProfilePreset::Conservative,
259            ProfilePreset::NoiseAdaptive,
260        ]
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_balanced_profile() {
        let profile = ProfilePreset::Balanced.build();
        assert_eq!(profile.zones.len(), 3);

        // Early step should be INT4
        assert_eq!(profile.precision_at(0.1), Precision::INT4);
        // Mid step should be INT8
        assert_eq!(profile.precision_at(0.3), Precision::INT8);
        // Late step should be FP16
        assert_eq!(profile.precision_at(0.8), Precision::FP16);
    }

    #[test]
    fn test_profile_validation() {
        let valid = ProfilePreset::Balanced.build();
        assert!(valid.validate().is_ok());

        // Invalid: no zones at all
        assert!(PrecisionProfile::new("Empty", "No zones").validate().is_err());

        // Invalid: gap in zones
        let gap = PrecisionProfile::new("Invalid", "Has gaps")
            .add_zone(0.0, 0.3, Precision::INT4, "")
            .add_zone(0.5, 1.0, Precision::FP16, ""); // Gap at 0.3-0.5
        assert!(gap.validate().is_err());

        // Invalid: overlapping zones
        let overlap = PrecisionProfile::new("Invalid", "Overlaps")
            .add_zone(0.0, 0.6, Precision::INT4, "")
            .add_zone(0.5, 1.0, Precision::FP16, ""); // Overlap at 0.5-0.6
        assert!(overlap.validate().is_err());

        // Invalid: does not reach 1.0
        let short = PrecisionProfile::new("Invalid", "Too short")
            .add_zone(0.0, 0.9, Precision::FP16, "");
        assert!(short.validate().is_err());
    }

    #[test]
    fn test_all_presets_validate() {
        for &preset in ProfilePreset::all() {
            let profile = preset.build();
            assert!(
                profile.validate().is_ok(),
                "preset {:?} should produce a valid profile",
                preset
            );
        }
    }

    #[test]
    fn test_add_zone_keeps_zones_sorted() {
        let profile = PrecisionProfile::new("Unordered", "Zones added out of order")
            .add_zone(0.5, 1.0, Precision::FP16, "")
            .add_zone(0.0, 0.5, Precision::INT4, "");
        let starts: Vec<f32> = profile.zones.iter().map(|z| z.start_step).collect();
        assert_eq!(starts, vec![0.0, 0.5]);
        assert!(profile.validate().is_ok());
    }

    #[test]
    fn test_precision_at_outside_zones_defaults_to_fp16() {
        let profile = PrecisionProfile::new("Partial", "Only covers the first half")
            .add_zone(0.0, 0.5, Precision::INT4, "");
        assert_eq!(profile.precision_at(0.25), Precision::INT4);
        // End bounds are exclusive.
        assert_eq!(profile.precision_at(0.5), Precision::FP16);
        assert_eq!(profile.precision_at(0.75), Precision::FP16);
    }

    #[test]
    fn test_estimated_metrics() {
        let performance = ProfilePreset::Performance.build();
        let quality = ProfilePreset::Quality.build();

        // Performance should use less VRAM
        assert!(performance.estimated_vram_ratio() < quality.estimated_vram_ratio());

        // Quality should have better quality factor
        assert!(quality.estimated_quality() > performance.estimated_quality());

        // Empty profiles fall back to 1.0 for both estimates
        let empty = PrecisionProfile::new("Empty", "No zones");
        assert_eq!(empty.estimated_vram_ratio(), 1.0);
        assert_eq!(empty.estimated_quality(), 1.0);
    }

    #[test]
    fn test_transition_count() {
        let balanced = ProfilePreset::Balanced.build();
        assert_eq!(balanced.transition_count(), 2); // INT4->INT8, INT8->FP16

        let conservative = ProfilePreset::Conservative.build();
        assert_eq!(conservative.transition_count(), 1); // INT8->FP16

        // Adjacent zones sharing a precision are not transitions.
        let noise = ProfilePreset::NoiseAdaptive.build();
        assert_eq!(noise.transition_count(), 2); // INT4->INT8, INT8->FP16

        assert_eq!(PrecisionProfile::new("Empty", "").transition_count(), 0);
    }
}