use crate::{AdaptiveError, Precision, Result};
use serde::{Deserialize, Serialize};
/// A named schedule that maps ranges of step fractions to a [`Precision`].
///
/// Built with `PrecisionProfile::new` followed by chained `add_zone` calls;
/// the type also derives serde `Serialize`/`Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrecisionProfile {
/// Short name of the profile (the presets in this file use e.g. `"Balanced"`).
pub name: String,
/// Human-readable summary of what the profile is for.
pub description: String,
/// Zones making up the schedule. `add_zone` re-sorts this list by
/// `start_step` after every insertion.
pub zones: Vec<PrecisionZone>,
/// Quality target; `new` initialises it to `0.95`.
/// NOTE(review): not read anywhere in this file — exact meaning to be confirmed.
pub quality_target: f32,
/// VRAM target; `new` initialises it to `0.8`.
/// NOTE(review): not read anywhere in this file — exact meaning to be confirmed.
pub vram_target: f32,
}
/// One contiguous range of step fractions that runs at a single [`Precision`].
///
/// `PrecisionProfile::precision_at` treats the range as half-open:
/// `start_step` is inclusive and `end_step` is exclusive.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrecisionZone {
/// Inclusive lower bound of the zone, as a step fraction.
pub start_step: f32,
/// Exclusive upper bound of the zone, as a step fraction.
pub end_step: f32,
/// Precision used for step fractions that fall inside this zone.
pub precision: Precision,
/// Free-text explanation of why this precision was chosen for the zone.
pub rationale: String,
}
impl PrecisionProfile {
    /// Creates an empty profile with the given name and description.
    ///
    /// The new profile has no zones, a `quality_target` of `0.95` and a
    /// `vram_target` of `0.8`. Zones are added with [`add_zone`](Self::add_zone).
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            zones: Vec::new(),
            quality_target: 0.95,
            vram_target: 0.8,
        }
    }

    /// Adds a zone covering `[start, end)` and returns the updated profile.
    ///
    /// The zone list is re-sorted by `start_step` after the insertion, so
    /// zones may be added in any order. No validation happens here; call
    /// [`validate`](Self::validate) once the profile is complete.
    pub fn add_zone(
        mut self,
        start: f32,
        end: f32,
        precision: Precision,
        rationale: impl Into<String>,
    ) -> Self {
        self.zones.push(PrecisionZone {
            start_step: start,
            end_step: end,
            precision,
            rationale: rationale.into(),
        });
        // `total_cmp` is a genuine total order even when a bound is NaN.
        // `partial_cmp(..).unwrap_or(Equal)` is not, and `sort_by` is allowed
        // to panic when the comparator is inconsistent.
        self.zones
            .sort_by(|a, b| a.start_step.total_cmp(&b.start_step));
        self
    }

    /// Returns the precision of the first zone whose half-open range
    /// `[start_step, end_step)` contains `step_fraction`.
    ///
    /// Falls back to `Precision::FP16` when no zone matches. That includes
    /// `step_fraction == 1.0` on a profile that covers `[0.0, 1.0)`, any value
    /// outside the covered range, and NaN.
    pub fn precision_at(&self, step_fraction: f32) -> Precision {
        self.zones
            .iter()
            .find(|zone| zone.start_step <= step_fraction && step_fraction < zone.end_step)
            .map_or(Precision::FP16, |zone| zone.precision)
    }

    /// Checks that the zones, in their stored order, tile `[0.0, 1.0)`.
    ///
    /// Boundaries are compared with an absolute tolerance of `0.001`.
    ///
    /// # Errors
    ///
    /// Returns `AdaptiveError::ProfileError` when:
    /// - the profile has no zones,
    /// - a zone bound is NaN or infinite,
    /// - a zone ends before it starts,
    /// - a zone starts after the previous one ended (gap) or before it ended
    ///   (overlap); the first zone must start at `0.0`,
    /// - the last zone does not end at `1.0`.
    pub fn validate(&self) -> Result<()> {
        const TOLERANCE: f32 = 0.001;

        if self.zones.is_empty() {
            return Err(AdaptiveError::ProfileError(
                "Profile must have at least one zone".into(),
            ));
        }

        let mut expected_start = 0.0f32;
        for zone in &self.zones {
            // Every comparison with NaN is false, so without this check a NaN
            // bound would slip through the tolerance tests below.
            if !zone.start_step.is_finite() || !zone.end_step.is_finite() {
                return Err(AdaptiveError::ProfileError(format!(
                    "Zone bounds must be finite, got [{}, {})",
                    zone.start_step, zone.end_step
                )));
            }
            if zone.end_step < zone.start_step {
                return Err(AdaptiveError::ProfileError(format!(
                    "Zone [{}, {}) ends before it starts",
                    zone.start_step, zone.end_step
                )));
            }

            let offset = zone.start_step - expected_start;
            if offset > TOLERANCE {
                return Err(AdaptiveError::ProfileError(format!(
                    "Gap in zones at step fraction {}",
                    expected_start
                )));
            }
            if offset < -TOLERANCE {
                return Err(AdaptiveError::ProfileError(format!(
                    "Overlap in zones at step fraction {}",
                    zone.start_step
                )));
            }
            expected_start = zone.end_step;
        }

        if (expected_start - 1.0).abs() > TOLERANCE {
            return Err(AdaptiveError::ProfileError(
                "Zones must cover the full range [0.0, 1.0)".into(),
            ));
        }
        Ok(())
    }

    /// Returns the average of each zone's `Precision::vram_ratio`, weighted by
    /// the zone's width (`end_step - start_step`).
    ///
    /// Returns `1.0` when the total width is not positive, e.g. for a profile
    /// without zones.
    pub fn estimated_vram_ratio(&self) -> f32 {
        let (total_weight, weighted_ratio) =
            self.zones
                .iter()
                .fold((0.0f32, 0.0f32), |(total, ratio), zone| {
                    let width = zone.end_step - zone.start_step;
                    (total + width, ratio + width * zone.precision.vram_ratio())
                });

        if total_weight > 0.0 {
            weighted_ratio / total_weight
        } else {
            1.0
        }
    }

    /// Returns the weighted average of each zone's `Precision::quality_factor`.
    ///
    /// A zone's weight is `width * (0.5 + start_step * 0.5)`, so zones that
    /// start later count for more than zones that start early. Returns `1.0`
    /// when the total weight is not positive.
    pub fn estimated_quality(&self) -> f32 {
        let (total_weight, weighted_quality) =
            self.zones
                .iter()
                .fold((0.0f32, 0.0f32), |(total, quality), zone| {
                    let width = zone.end_step - zone.start_step;
                    let quality_weight = width * (0.5 + zone.start_step * 0.5);
                    (
                        total + quality_weight,
                        quality + quality_weight * zone.precision.quality_factor(),
                    )
                });

        if total_weight > 0.0 {
            weighted_quality / total_weight
        } else {
            1.0
        }
    }

    /// Counts adjacent zone pairs whose precisions differ.
    ///
    /// Profiles with fewer than two zones have no adjacent pairs and
    /// therefore yield `0`.
    pub fn transition_count(&self) -> usize {
        self.zones
            .windows(2)
            .filter(|pair| pair[0].precision != pair[1].precision)
            .count()
    }
}
/// Built-in precision schedules; `ProfilePreset::build` turns a preset into a
/// [`PrecisionProfile`].
///
/// `Balanced` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum ProfilePreset {
/// INT4 for `[0.0, 0.4)`, INT8 for `[0.4, 0.7)`, FP16 for `[0.7, 1.0)`.
Performance,
/// INT4 for `[0.0, 0.25)`, INT8 for `[0.25, 0.5)`, FP16 for `[0.5, 1.0)`.
#[default]
Balanced,
/// INT8 for `[0.0, 0.15)`, FP16 for `[0.15, 1.0)`.
Quality,
/// INT4 for `[0.0, 0.5)`, INT8 for `[0.5, 0.85)`, FP16 for `[0.85, 1.0)`.
LowVram,
/// INT8 for `[0.0, 0.1)`, FP16 for `[0.1, 1.0)`.
Conservative,
/// Five zones whose rationales refer to noise sigma ranges: INT4 up to
/// `0.2`, INT8 up to `0.6`, FP16 afterwards.
NoiseAdaptive,
}
impl ProfilePreset {
pub fn build(self) -> PrecisionProfile {
match self {
ProfilePreset::Performance => {
PrecisionProfile::new("Performance", "Maximum speed with aggressive INT4 usage")
.add_zone(
0.0,
0.4,
Precision::INT4,
"High noise masks quantization errors",
)
.add_zone(0.4, 0.7, Precision::INT8, "Medium noise tolerates INT8")
.add_zone(0.7, 1.0, Precision::FP16, "Low noise requires precision")
}
ProfilePreset::Balanced => {
PrecisionProfile::new("Balanced", "Good balance of speed and quality")
.add_zone(0.0, 0.25, Precision::INT4, "Early steps: high noise")
.add_zone(
0.25,
0.5,
Precision::INT8,
"Mid-early steps: moderate noise",
)
.add_zone(
0.5,
1.0,
Precision::FP16,
"Later steps: detail preservation",
)
}
ProfilePreset::Quality => {
PrecisionProfile::new("Quality", "Maximum quality with minimal quantization")
.add_zone(0.0, 0.15, Precision::INT8, "Only very early INT8")
.add_zone(0.15, 1.0, Precision::FP16, "FP16 for most steps")
}
ProfilePreset::LowVram => PrecisionProfile::new("LowVRAM", "Aggressive memory savings")
.add_zone(0.0, 0.5, Precision::INT4, "Extended INT4 zone")
.add_zone(0.5, 0.85, Precision::INT8, "INT8 for refinement")
.add_zone(0.85, 1.0, Precision::FP16, "FP16 only for final details"),
ProfilePreset::Conservative => {
PrecisionProfile::new("Conservative", "Minimal quantization for maximum quality")
.add_zone(0.0, 0.1, Precision::INT8, "Brief INT8 at start")
.add_zone(0.1, 1.0, Precision::FP16, "FP16 throughout")
}
ProfilePreset::NoiseAdaptive => PrecisionProfile::new(
"NoiseAdaptive",
"Precision matched to noise level at each step",
)
.add_zone(0.0, 0.2, Precision::INT4, "Noise sigma > 5.0")
.add_zone(0.2, 0.35, Precision::INT8, "Noise sigma 2.0-5.0")
.add_zone(0.35, 0.6, Precision::INT8, "Noise sigma 0.5-2.0")
.add_zone(0.6, 0.8, Precision::FP16, "Noise sigma 0.1-0.5")
.add_zone(0.8, 1.0, Precision::FP16, "Noise sigma < 0.1"),
}
}
pub fn description(&self) -> &'static str {
match self {
ProfilePreset::Performance => "Maximum speed (4x faster, ~8% quality loss)",
ProfilePreset::Balanced => "Balanced (2.5x faster, ~3% quality loss)",
ProfilePreset::Quality => "Maximum quality (1.5x faster, ~1% quality loss)",
ProfilePreset::LowVram => "Low VRAM (30% less memory, ~5% quality loss)",
ProfilePreset::Conservative => "Conservative (1.2x faster, minimal quality loss)",
ProfilePreset::NoiseAdaptive => "Noise-aware scheduling (optimal quality/speed)",
}
}
pub fn all() -> &'static [ProfilePreset] {
&[
ProfilePreset::Performance,
ProfilePreset::Balanced,
ProfilePreset::Quality,
ProfilePreset::LowVram,
ProfilePreset::Conservative,
ProfilePreset::NoiseAdaptive,
]
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The balanced preset has three zones and resolves sample fractions to
    /// the precision of the zone containing them.
    #[test]
    fn test_balanced_profile() {
        let profile = ProfilePreset::Balanced.build();
        assert_eq!(profile.zones.len(), 3);

        let samples = [
            (0.1, Precision::INT4),
            (0.3, Precision::INT8),
            (0.8, Precision::FP16),
        ];
        for (fraction, expected) in samples {
            assert_eq!(profile.precision_at(fraction), expected);
        }
    }

    /// A preset validates cleanly; a profile with a hole between zones does not.
    #[test]
    fn test_profile_validation() {
        let valid = ProfilePreset::Balanced.build();
        assert!(valid.validate().is_ok());

        let invalid = PrecisionProfile::new("Invalid", "Has gaps")
            .add_zone(0.0, 0.3, Precision::INT4, "")
            .add_zone(0.5, 1.0, Precision::FP16, "");
        assert!(invalid.validate().is_err());
    }

    /// The performance preset estimates lower VRAM use and lower quality than
    /// the quality preset.
    #[test]
    fn test_estimated_metrics() {
        let performance = ProfilePreset::Performance.build();
        let quality = ProfilePreset::Quality.build();

        assert!(performance.estimated_vram_ratio() < quality.estimated_vram_ratio());
        assert!(quality.estimated_quality() > performance.estimated_quality());
    }

    /// Transition counts match the number of precision changes between zones.
    #[test]
    fn test_transition_count() {
        let balanced = ProfilePreset::Balanced.build();
        assert_eq!(balanced.transition_count(), 2);

        let conservative = ProfilePreset::Conservative.build();
        assert_eq!(conservative.transition_count(), 1);
    }
}