Skip to main content

haagenti_adaptive/
transition.rs

1//! Precision transition strategies
2
3use crate::{AdaptiveError, Precision, Result};
4use serde::{Deserialize, Serialize};
5
6/// Strategy for transitioning between precision levels
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
8pub enum TransitionStrategy {
9    /// Immediate switch (most efficient)
10    #[default]
11    Immediate,
12    /// Gradual blend over multiple steps
13    Gradual {
14        /// Number of steps to blend over
15        steps: u32,
16    },
17    /// Adaptive based on noise level
18    StepAware,
19}
20
21/// A precision transition with blending information
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct PrecisionTransition {
24    /// Source precision
25    pub from: Precision,
26    /// Target precision
27    pub to: Precision,
28    /// Step at which transition starts
29    pub start_step: u32,
30    /// Step at which transition completes
31    pub end_step: u32,
32    /// Strategy used
33    pub strategy: TransitionStrategy,
34    /// Precomputed blend factors per step
35    pub blend_factors: Vec<f32>,
36}
37
38impl PrecisionTransition {
39    /// Create a new immediate transition
40    pub fn immediate(from: Precision, to: Precision, step: u32) -> Self {
41        Self {
42            from,
43            to,
44            start_step: step,
45            end_step: step,
46            strategy: TransitionStrategy::Immediate,
47            blend_factors: vec![1.0],
48        }
49    }
50
51    /// Create a gradual transition
52    pub fn gradual(from: Precision, to: Precision, start_step: u32, duration: u32) -> Self {
53        let blend_factors: Vec<f32> = (0..=duration).map(|i| i as f32 / duration as f32).collect();
54
55        Self {
56            from,
57            to,
58            start_step,
59            end_step: start_step + duration,
60            strategy: TransitionStrategy::Gradual { steps: duration },
61            blend_factors,
62        }
63    }
64
65    /// Check if a step is within this transition
66    pub fn contains_step(&self, step: u32) -> bool {
67        step >= self.start_step && step <= self.end_step
68    }
69
70    /// Get blend factor for a specific step (0.0 = from, 1.0 = to)
71    pub fn blend_at(&self, step: u32) -> Option<f32> {
72        if !self.contains_step(step) {
73            return None;
74        }
75
76        match self.strategy {
77            TransitionStrategy::Immediate => Some(1.0),
78            TransitionStrategy::Gradual { steps } => {
79                let progress = (step - self.start_step) as f32 / steps as f32;
80                Some(progress.clamp(0.0, 1.0))
81            }
82            TransitionStrategy::StepAware => {
83                // Use smooth step for noise-aware transitions
84                let t = (step - self.start_step) as f32 / (self.end_step - self.start_step) as f32;
85                Some(smooth_step(t))
86            }
87        }
88    }
89
90    /// Get the effective precision at a step
91    pub fn effective_precision(&self, step: u32) -> Precision {
92        match self.blend_at(step) {
93            Some(blend) if blend < 0.5 => self.from,
94            Some(_) => self.to,
95            None => self.to,
96        }
97    }
98
99    /// Validate the transition
100    pub fn validate(&self) -> Result<()> {
101        if self.end_step < self.start_step {
102            return Err(AdaptiveError::InvalidTransition {
103                from: self.from,
104                to: self.to,
105                reason: "End step before start step".into(),
106            });
107        }
108
109        // Check for valid precision ordering in most cases
110        // (going from lower to higher is normal, but reverse is allowed for special cases)
111
112        Ok(())
113    }
114
115    /// Compute VRAM requirement during transition
116    pub fn peak_vram_ratio(&self) -> f32 {
117        match self.strategy {
118            TransitionStrategy::Immediate => self.to.vram_ratio(),
119            TransitionStrategy::Gradual { .. } | TransitionStrategy::StepAware => {
120                // During gradual transition, may need both precisions loaded
121                self.from.vram_ratio().max(self.to.vram_ratio()) * 1.2
122            }
123        }
124    }
125}
126
/// Cubic ease curve `3t² − 2t³`, evaluated after clamping `t` into `[0, 1]`.
///
/// Shapes the blend of step-aware transitions.
fn smooth_step(t: f32) -> f32 {
    let x = t.clamp(0.0, 1.0);
    let x_sq = x * x;
    x_sq * (3.0 - 2.0 * x)
}
132
/// Quintic ease curve `6t⁵ − 15t⁴ + 10t³` (Ken Perlin's "smootherstep"),
/// evaluated after clamping `t` into `[0, 1]`.
// No caller in this module yet; kept as an alternative blend curve, hence the lint allowance.
#[allow(dead_code)]
fn smoother_step(t: f32) -> f32 {
    let x = t.clamp(0.0, 1.0);
    let cubed = x * x * x;
    let tail = x * (x * 6.0 - 15.0) + 10.0;
    cubed * tail
}
139
140/// Transition planner for optimizing precision changes
141#[derive(Debug, Clone)]
142pub struct TransitionPlanner {
143    /// Available VRAM
144    vram_mb: u64,
145    /// Preferred strategy
146    preferred_strategy: TransitionStrategy,
147    /// Minimum steps between transitions
148    min_gap: u32,
149}
150
151impl TransitionPlanner {
152    /// Create a new planner
153    pub fn new(vram_mb: u64) -> Self {
154        Self {
155            vram_mb,
156            preferred_strategy: TransitionStrategy::Immediate,
157            min_gap: 3,
158        }
159    }
160
161    /// Set preferred strategy
162    pub fn with_strategy(mut self, strategy: TransitionStrategy) -> Self {
163        self.preferred_strategy = strategy;
164        self
165    }
166
167    /// Plan transitions between precision zones
168    pub fn plan_transitions(
169        &self,
170        zones: &[(u32, u32, Precision)], // (start, end, precision)
171    ) -> Vec<PrecisionTransition> {
172        if zones.len() <= 1 {
173            return Vec::new();
174        }
175
176        let mut transitions = Vec::new();
177        let mut last_transition_end = 0u32;
178
179        for window in zones.windows(2) {
180            let (_, end1, prec1) = window[0];
181            let (start2, _, prec2) = window[1];
182
183            // Enforce minimum gap between transitions
184            if start2 < last_transition_end + self.min_gap {
185                continue;
186            }
187
188            if prec1 == prec2 {
189                continue;
190            }
191
192            let transition = match self.preferred_strategy {
193                TransitionStrategy::Immediate => {
194                    PrecisionTransition::immediate(prec1, prec2, start2)
195                }
196                TransitionStrategy::Gradual { steps } => {
197                    // Ensure transition doesn't exceed zone boundaries
198                    let safe_steps = steps.min(end1.saturating_sub(1));
199                    PrecisionTransition::gradual(
200                        prec1,
201                        prec2,
202                        end1.saturating_sub(safe_steps),
203                        safe_steps,
204                    )
205                }
206                TransitionStrategy::StepAware => {
207                    // Use 2 steps for step-aware transitions
208                    PrecisionTransition::gradual(prec1, prec2, end1.saturating_sub(1), 2)
209                }
210            };
211
212            last_transition_end = transition.end_step;
213            transitions.push(transition);
214        }
215
216        transitions
217    }
218
219    /// Get available VRAM in MB
220    pub fn vram_mb(&self) -> u64 {
221        self.vram_mb
222    }
223
224    /// Optimize transitions to minimize VRAM spikes
225    pub fn optimize_for_vram(&self, transitions: &mut [PrecisionTransition]) {
226        // Calculate VRAM threshold based on available memory
227        let vram_threshold = if self.vram_mb < 8192 {
228            0.85 // More aggressive for low VRAM
229        } else {
230            0.9
231        };
232
233        for transition in transitions {
234            // If gradual transition would cause VRAM spike, make it immediate
235            if transition.peak_vram_ratio() > vram_threshold {
236                *transition = PrecisionTransition::immediate(
237                    transition.from,
238                    transition.to,
239                    transition.start_step,
240                );
241            }
242        }
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_immediate_transition() {
        let t = PrecisionTransition::immediate(Precision::INT4, Precision::FP16, 10);

        assert_eq!((t.start_step, t.end_step), (10, 10));
        assert!(t.contains_step(10));
        assert!(!t.contains_step(9));
        assert_eq!(t.blend_at(10), Some(1.0));
    }

    #[test]
    fn test_gradual_transition() {
        let t = PrecisionTransition::gradual(Precision::INT8, Precision::FP16, 10, 4);

        assert_eq!((t.start_step, t.end_step), (10, 14));
        for (step, expected) in [(10, 0.0), (12, 0.5), (14, 1.0)] {
            assert_eq!(t.blend_at(step), Some(expected));
        }
    }

    #[test]
    fn test_effective_precision() {
        let t = PrecisionTransition::gradual(Precision::INT4, Precision::FP16, 10, 4);

        let expectations = [
            (10, Precision::INT4),
            (11, Precision::INT4),
            (12, Precision::FP16),
            (14, Precision::FP16),
        ];
        for (step, precision) in expectations {
            assert_eq!(t.effective_precision(step), precision);
        }
    }

    #[test]
    fn test_smooth_step() {
        assert_eq!(smooth_step(0.0), 0.0);
        assert_eq!(smooth_step(1.0), 1.0);
        let mid = smooth_step(0.5);
        assert!((mid - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_transition_planner() {
        let planner = TransitionPlanner::new(8192);
        let zones: [(u32, u32, Precision); 3] = [
            (0, 10, Precision::INT4),
            (10, 20, Precision::INT8),
            (20, 30, Precision::FP16),
        ];

        let planned = planner.plan_transitions(&zones);
        assert_eq!(planned.len(), 2);

        let pairs: Vec<_> = planned.iter().map(|t| (t.from, t.to)).collect();
        assert_eq!(
            pairs,
            vec![
                (Precision::INT4, Precision::INT8),
                (Precision::INT8, Precision::FP16),
            ]
        );
    }
}