1use crate::{AdaptiveError, Precision, Result};
4use serde::{Deserialize, Serialize};
5
/// A named schedule that assigns a [`Precision`] to each portion of the step range.
///
/// Positions are expressed as fractions of the total number of steps (see
/// [`PrecisionProfile::precision_at`]). A valid profile tiles `[0.0, 1.0)` with
/// contiguous zones (see [`PrecisionProfile::validate`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrecisionProfile {
    /// Short human-readable name of the profile (e.g. `"Balanced"`).
    pub name: String,
    /// Human-readable summary of the profile's intent.
    pub description: String,
    /// Precision zones; [`PrecisionProfile::add_zone`] keeps them sorted by `start_step`.
    pub zones: Vec<PrecisionZone>,
    /// Target quality score; [`PrecisionProfile::new`] initialises it to `0.95`.
    // NOTE(review): not read by any code in this file — confirm how consumers use it.
    pub quality_target: f32,
    /// Target VRAM ratio; [`PrecisionProfile::new`] initialises it to `0.8`.
    // NOTE(review): not read by any code in this file — confirm how consumers use it.
    pub vram_target: f32,
}
20
/// One contiguous slice of a [`PrecisionProfile`], covering the half-open
/// step-fraction range `[start_step, end_step)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrecisionZone {
    /// Inclusive lower bound, as a fraction of the total step count.
    pub start_step: f32,
    /// Exclusive upper bound, as a fraction of the total step count.
    pub end_step: f32,
    /// Precision used for every step that falls inside this zone.
    pub precision: Precision,
    /// Human-readable justification for choosing `precision` here.
    pub rationale: String,
}
33
34impl PrecisionProfile {
35 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
37 Self {
38 name: name.into(),
39 description: description.into(),
40 zones: Vec::new(),
41 quality_target: 0.95,
42 vram_target: 0.8,
43 }
44 }
45
46 pub fn add_zone(
48 mut self,
49 start: f32,
50 end: f32,
51 precision: Precision,
52 rationale: impl Into<String>,
53 ) -> Self {
54 self.zones.push(PrecisionZone {
55 start_step: start,
56 end_step: end,
57 precision,
58 rationale: rationale.into(),
59 });
60 self.zones.sort_by(|a, b| {
61 a.start_step
62 .partial_cmp(&b.start_step)
63 .unwrap_or(std::cmp::Ordering::Equal)
64 });
65 self
66 }
67
68 pub fn precision_at(&self, step_fraction: f32) -> Precision {
70 for zone in &self.zones {
71 if step_fraction >= zone.start_step && step_fraction < zone.end_step {
72 return zone.precision;
73 }
74 }
75 Precision::FP16
77 }
78
79 pub fn validate(&self) -> Result<()> {
81 if self.zones.is_empty() {
82 return Err(AdaptiveError::ProfileError(
83 "Profile must have at least one zone".into(),
84 ));
85 }
86
87 let mut expected_start = 0.0f32;
89 for zone in &self.zones {
90 if (zone.start_step - expected_start).abs() > 0.001 {
91 return Err(AdaptiveError::ProfileError(format!(
92 "Gap in zones at step fraction {}",
93 expected_start
94 )));
95 }
96 expected_start = zone.end_step;
97 }
98
99 if (expected_start - 1.0).abs() > 0.001 {
100 return Err(AdaptiveError::ProfileError(
101 "Zones must cover the full range [0.0, 1.0)".into(),
102 ));
103 }
104
105 Ok(())
106 }
107
108 pub fn estimated_vram_ratio(&self) -> f32 {
110 let mut total_weight = 0.0f32;
111 let mut weighted_ratio = 0.0f32;
112
113 for zone in &self.zones {
114 let weight = zone.end_step - zone.start_step;
115 total_weight += weight;
116 weighted_ratio += weight * zone.precision.vram_ratio();
117 }
118
119 if total_weight > 0.0 {
120 weighted_ratio / total_weight
121 } else {
122 1.0
123 }
124 }
125
126 pub fn estimated_quality(&self) -> f32 {
128 let mut total_weight = 0.0f32;
129 let mut weighted_quality = 0.0f32;
130
131 for zone in &self.zones {
132 let weight = zone.end_step - zone.start_step;
133 let quality_weight = weight * (0.5 + zone.start_step * 0.5);
135 total_weight += quality_weight;
136 weighted_quality += quality_weight * zone.precision.quality_factor();
137 }
138
139 if total_weight > 0.0 {
140 weighted_quality / total_weight
141 } else {
142 1.0
143 }
144 }
145
146 pub fn transition_count(&self) -> usize {
148 if self.zones.len() <= 1 {
149 return 0;
150 }
151
152 self.zones
153 .windows(2)
154 .filter(|w| w[0].precision != w[1].precision)
155 .count()
156 }
157}
158
/// Built-in precision schedules; turn one into a [`PrecisionProfile`] with
/// [`ProfilePreset::build`].
///
/// The `Default` preset is [`ProfilePreset::Balanced`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum ProfilePreset {
    /// INT4 for the first 40% of steps, INT8 until 70%, FP16 afterwards.
    Performance,
    /// Default preset: INT4 for the first 25% of steps, INT8 until 50%, FP16 afterwards.
    #[default]
    Balanced,
    /// INT8 for the first 15% of steps, FP16 afterwards.
    Quality,
    /// INT4 for the first 50% of steps, INT8 until 85%, FP16 only for the rest.
    LowVram,
    /// INT8 for the first 10% of steps, FP16 afterwards.
    Conservative,
    /// Five zones (INT4, INT8, INT8, FP16, FP16) whose rationales are labelled
    /// with noise-sigma ranges.
    NoiseAdaptive,
}
176
177impl ProfilePreset {
178 pub fn build(self) -> PrecisionProfile {
180 match self {
181 ProfilePreset::Performance => {
182 PrecisionProfile::new("Performance", "Maximum speed with aggressive INT4 usage")
183 .add_zone(
184 0.0,
185 0.4,
186 Precision::INT4,
187 "High noise masks quantization errors",
188 )
189 .add_zone(0.4, 0.7, Precision::INT8, "Medium noise tolerates INT8")
190 .add_zone(0.7, 1.0, Precision::FP16, "Low noise requires precision")
191 }
192
193 ProfilePreset::Balanced => {
194 PrecisionProfile::new("Balanced", "Good balance of speed and quality")
195 .add_zone(0.0, 0.25, Precision::INT4, "Early steps: high noise")
196 .add_zone(
197 0.25,
198 0.5,
199 Precision::INT8,
200 "Mid-early steps: moderate noise",
201 )
202 .add_zone(
203 0.5,
204 1.0,
205 Precision::FP16,
206 "Later steps: detail preservation",
207 )
208 }
209
210 ProfilePreset::Quality => {
211 PrecisionProfile::new("Quality", "Maximum quality with minimal quantization")
212 .add_zone(0.0, 0.15, Precision::INT8, "Only very early INT8")
213 .add_zone(0.15, 1.0, Precision::FP16, "FP16 for most steps")
214 }
215
216 ProfilePreset::LowVram => PrecisionProfile::new("LowVRAM", "Aggressive memory savings")
217 .add_zone(0.0, 0.5, Precision::INT4, "Extended INT4 zone")
218 .add_zone(0.5, 0.85, Precision::INT8, "INT8 for refinement")
219 .add_zone(0.85, 1.0, Precision::FP16, "FP16 only for final details"),
220
221 ProfilePreset::Conservative => {
222 PrecisionProfile::new("Conservative", "Minimal quantization for maximum quality")
223 .add_zone(0.0, 0.1, Precision::INT8, "Brief INT8 at start")
224 .add_zone(0.1, 1.0, Precision::FP16, "FP16 throughout")
225 }
226
227 ProfilePreset::NoiseAdaptive => PrecisionProfile::new(
228 "NoiseAdaptive",
229 "Precision matched to noise level at each step",
230 )
231 .add_zone(0.0, 0.2, Precision::INT4, "Noise sigma > 5.0")
232 .add_zone(0.2, 0.35, Precision::INT8, "Noise sigma 2.0-5.0")
233 .add_zone(0.35, 0.6, Precision::INT8, "Noise sigma 0.5-2.0")
234 .add_zone(0.6, 0.8, Precision::FP16, "Noise sigma 0.1-0.5")
235 .add_zone(0.8, 1.0, Precision::FP16, "Noise sigma < 0.1"),
236 }
237 }
238
239 pub fn description(&self) -> &'static str {
241 match self {
242 ProfilePreset::Performance => "Maximum speed (4x faster, ~8% quality loss)",
243 ProfilePreset::Balanced => "Balanced (2.5x faster, ~3% quality loss)",
244 ProfilePreset::Quality => "Maximum quality (1.5x faster, ~1% quality loss)",
245 ProfilePreset::LowVram => "Low VRAM (30% less memory, ~5% quality loss)",
246 ProfilePreset::Conservative => "Conservative (1.2x faster, minimal quality loss)",
247 ProfilePreset::NoiseAdaptive => "Noise-aware scheduling (optimal quality/speed)",
248 }
249 }
250
251 pub fn all() -> &'static [ProfilePreset] {
253 &[
254 ProfilePreset::Performance,
255 ProfilePreset::Balanced,
256 ProfilePreset::Quality,
257 ProfilePreset::LowVram,
258 ProfilePreset::Conservative,
259 ProfilePreset::NoiseAdaptive,
260 ]
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_balanced_profile() {
        let profile = ProfilePreset::Balanced.build();
        assert_eq!(profile.zones.len(), 3);

        // (step fraction, precision expected at that fraction)
        let expectations = [
            (0.1, Precision::INT4),
            (0.3, Precision::INT8),
            (0.8, Precision::FP16),
        ];
        for (fraction, expected) in expectations {
            assert_eq!(profile.precision_at(fraction), expected);
        }
    }

    #[test]
    fn test_profile_validation() {
        assert!(ProfilePreset::Balanced.build().validate().is_ok());

        // [0.3, 0.5) is not covered by any zone.
        let with_gap = PrecisionProfile::new("Invalid", "Has gaps")
            .add_zone(0.0, 0.3, Precision::INT4, "")
            .add_zone(0.5, 1.0, Precision::FP16, "");
        assert!(with_gap.validate().is_err());
    }

    #[test]
    fn test_estimated_metrics() {
        let performance = ProfilePreset::Performance.build();
        let quality = ProfilePreset::Quality.build();

        let performance_vram = performance.estimated_vram_ratio();
        let quality_vram = quality.estimated_vram_ratio();
        assert!(performance_vram < quality_vram);

        let quality_score = quality.estimated_quality();
        let performance_score = performance.estimated_quality();
        assert!(quality_score > performance_score);
    }

    #[test]
    fn test_transition_count() {
        assert_eq!(ProfilePreset::Balanced.build().transition_count(), 2);
        assert_eq!(ProfilePreset::Conservative.build().transition_count(), 1);
    }
}