// elara_visual/prediction.rs

1//! Visual Prediction Engine - Continuity under packet loss
2//!
3//! This is the ELARA difference: when packets are lost, we PREDICT visual state
4//! instead of freezing or showing artifacts. Reality continues.
5
6use elara_core::StateTime;
7
8use crate::{VisualState, VisualStateId};
9
/// Tuning parameters for the visual prediction engine.
#[derive(Debug, Clone)]
pub struct PredictionConfig {
    /// Hard cap on how far ahead (in milliseconds) we will extrapolate.
    pub max_horizon_ms: u32,

    /// How much confidence is lost per 100 ms of extrapolation.
    pub confidence_decay: f32,

    /// Extrapolation stops once confidence falls below this floor.
    pub min_confidence: f32,

    /// Whether body/pose motion is extrapolated.
    pub predict_motion: bool,

    /// Whether facial expression is extrapolated.
    pub predict_expression: bool,
}

impl Default for PredictionConfig {
    /// Defaults: predict up to half a second ahead, losing 10% confidence
    /// per 100 ms, stopping once confidence drops under 30%.
    fn default() -> Self {
        Self {
            max_horizon_ms: 500,
            confidence_decay: 0.1,
            min_confidence: 0.3,
            predict_motion: true,
            predict_expression: true,
        }
    }
}
40
/// Visual state predictor
///
/// Extrapolates the most recently received `VisualState` forward in time so
/// rendering can continue while packets are missing. Holds the last two
/// received states so velocities can be estimated from their difference.
#[derive(Debug)]
pub struct VisualPredictor {
    /// Configuration
    config: PredictionConfig,

    /// Last known good state (the most recent state passed to `update`)
    last_state: Option<VisualState>,

    /// Previous state (for velocity estimation against `last_state`)
    prev_state: Option<VisualState>,

    /// Current prediction (if any); cleared whenever a real state arrives
    current_prediction: Option<VisualState>,

    /// How many consecutive predictions we've made since the last real state
    prediction_count: u32,
}
59
60impl VisualPredictor {
61    /// Create a new predictor
62    pub fn new(config: PredictionConfig) -> Self {
63        Self {
64            config,
65            last_state: None,
66            prev_state: None,
67            current_prediction: None,
68            prediction_count: 0,
69        }
70    }
71
72    /// Update with a new received state
73    pub fn update(&mut self, state: VisualState) {
74        self.prev_state = self.last_state.take();
75        self.last_state = Some(state);
76        self.current_prediction = None;
77        self.prediction_count = 0;
78    }
79
80    /// Get the current best state (received or predicted)
81    pub fn current_state(&self) -> Option<&VisualState> {
82        self.current_prediction
83            .as_ref()
84            .or(self.last_state.as_ref())
85    }
86
87    /// Predict state at a future time
88    /// Returns None if prediction is not possible or confidence is too low
89    pub fn predict(&mut self, target_time: StateTime) -> Option<VisualState> {
90        let last = self.last_state.as_ref()?;
91
92        let delta_ms = target_time.as_millis() - last.timestamp.as_millis();
93
94        // Don't predict backwards
95        if delta_ms <= 0 {
96            return Some(last.clone());
97        }
98
99        // Don't predict beyond horizon
100        if delta_ms > self.config.max_horizon_ms as i64 {
101            return None;
102        }
103
104        // Calculate confidence decay
105        let decay_steps = delta_ms as f32 / 100.0;
106        let confidence = 1.0 - (decay_steps * self.config.confidence_decay);
107
108        if confidence < self.config.min_confidence {
109            return None;
110        }
111
112        // Create predicted state
113        let mut predicted = last.clone();
114        predicted.timestamp = target_time;
115        predicted.sequence = last.sequence + 1;
116        predicted.id = VisualStateId::new(predicted.sequence);
117
118        // Predict face
119        if self.config.predict_expression {
120            if let (Some(ref prev), Some(ref mut face)) = (&self.prev_state, &mut predicted.face) {
121                if let Some(ref prev_face) = prev.face {
122                    // Predict mouth movement (for speech)
123                    if face.speaking {
124                        // Oscillate mouth openness for natural speech
125                        let phase = (delta_ms as f32 / 150.0).sin();
126                        face.mouth.openness = 0.3 + 0.2 * phase.abs();
127                    }
128
129                    // Predict head movement (smooth continuation)
130                    let dt = delta_ms as f32 / 1000.0;
131                    let head_vel = (
132                        (last.face.as_ref().map(|f| f.head_rotation.0).unwrap_or(0.0)
133                            - prev_face.head_rotation.0)
134                            / 0.1,
135                        (last.face.as_ref().map(|f| f.head_rotation.1).unwrap_or(0.0)
136                            - prev_face.head_rotation.1)
137                            / 0.1,
138                        (last.face.as_ref().map(|f| f.head_rotation.2).unwrap_or(0.0)
139                            - prev_face.head_rotation.2)
140                            / 0.1,
141                    );
142
143                    face.head_rotation.0 += head_vel.0 * dt * 0.5; // Damped
144                    face.head_rotation.1 += head_vel.1 * dt * 0.5;
145                    face.head_rotation.2 += head_vel.2 * dt * 0.5;
146                }
147
148                // Reduce confidence
149                face.confidence *= confidence;
150            }
151        }
152
153        // Predict pose
154        if self.config.predict_motion {
155            if let Some(ref mut pose) = predicted.pose {
156                // Use stored velocity for prediction
157                let dt = delta_ms as f32 / 1000.0;
158
159                for joint in &mut pose.joints {
160                    joint.position.x += pose.velocity.x * dt;
161                    joint.position.y += pose.velocity.y * dt;
162                    joint.position.z += pose.velocity.z * dt;
163                }
164
165                pose.confidence *= confidence;
166            }
167        }
168
169        // Scene doesn't need much prediction (relatively static)
170
171        self.current_prediction = Some(predicted.clone());
172        self.prediction_count += 1;
173
174        Some(predicted)
175    }
176
177    /// Check if we're currently in prediction mode
178    pub fn is_predicting(&self) -> bool {
179        self.prediction_count > 0
180    }
181
182    /// Get prediction count
183    pub fn prediction_count(&self) -> u32 {
184        self.prediction_count
185    }
186
187    /// Get estimated confidence of current state
188    pub fn confidence(&self) -> f32 {
189        if let Some(ref pred) = self.current_prediction {
190            // Use face confidence as proxy
191            pred.face.as_ref().map(|f| f.confidence).unwrap_or(0.5)
192        } else if self.last_state.is_some() {
193            1.0
194        } else {
195            0.0
196        }
197    }
198}
199
200/// Interpolation between two visual states
201pub struct VisualInterpolator;
202
203impl VisualInterpolator {
204    /// Interpolate between two visual states
205    pub fn interpolate(from: &VisualState, to: &VisualState, t: f32) -> VisualState {
206        let t = t.clamp(0.0, 1.0);
207
208        let mut result = to.clone();
209
210        // Interpolate face
211        result.face = match (&from.face, &to.face) {
212            (Some(f1), Some(f2)) => Some(f1.lerp(f2, t)),
213            (None, Some(f)) => Some(f.clone()),
214            (Some(f), None) => {
215                if t < 0.5 {
216                    Some(f.clone())
217                } else {
218                    None
219                }
220            }
221            (None, None) => None,
222        };
223
224        // Interpolate pose
225        result.pose = match (&from.pose, &to.pose) {
226            (Some(p1), Some(p2)) => Some(p1.lerp(p2, t)),
227            (None, Some(p)) => Some(p.clone()),
228            (Some(p), None) => {
229                if t < 0.5 {
230                    Some(p.clone())
231                } else {
232                    None
233                }
234            }
235            (None, None) => None,
236        };
237
238        // Interpolate scene
239        result.scene = match (&from.scene, &to.scene) {
240            (Some(s1), Some(s2)) => Some(s1.lerp(s2, t)),
241            (None, Some(s)) => Some(s.clone()),
242            (Some(s), None) => {
243                if t < 0.5 {
244                    Some(s.clone())
245                } else {
246                    None
247                }
248            }
249            (None, None) => None,
250        };
251
252        result
253    }
254}
255
/// Jitter buffer for visual states (NOT traditional jitter buffer)
/// This is a state buffer that enables smooth interpolation.
///
/// States are kept sorted by timestamp (see `push`), and lookups are shifted
/// back by `target_delay_ms` so late-arriving states can still be used.
#[derive(Debug)]
pub struct VisualStateBuffer {
    /// Buffered states, maintained in ascending timestamp order
    states: Vec<VisualState>,

    /// Maximum buffer size; oldest states are evicted beyond this
    max_size: usize,

    /// Target delay in milliseconds (for smoothing); lookups read this far
    /// behind the requested time
    target_delay_ms: u32,
}
269
270impl VisualStateBuffer {
271    /// Create a new buffer
272    pub fn new(max_size: usize, target_delay_ms: u32) -> Self {
273        Self {
274            states: Vec::with_capacity(max_size),
275            max_size,
276            target_delay_ms,
277        }
278    }
279
280    /// Add a state to the buffer
281    pub fn push(&mut self, state: VisualState) {
282        // Insert in order by timestamp
283        let pos = self
284            .states
285            .iter()
286            .position(|s| s.timestamp.as_millis() > state.timestamp.as_millis())
287            .unwrap_or(self.states.len());
288
289        self.states.insert(pos, state);
290
291        // Remove old states if buffer is full
292        while self.states.len() > self.max_size {
293            self.states.remove(0);
294        }
295    }
296
297    /// Get interpolated state at a specific time
298    pub fn get_at(&self, time: StateTime) -> Option<VisualState> {
299        if self.states.is_empty() {
300            return None;
301        }
302
303        // Find surrounding states
304        let target_ms = time.as_millis() - self.target_delay_ms as i64;
305
306        let mut before: Option<&VisualState> = None;
307        let mut after: Option<&VisualState> = None;
308
309        for state in &self.states {
310            if state.timestamp.as_millis() <= target_ms {
311                before = Some(state);
312            } else {
313                after = Some(state);
314                break;
315            }
316        }
317
318        match (before, after) {
319            (Some(b), Some(a)) => {
320                // Interpolate
321                let range = a.timestamp.as_millis() - b.timestamp.as_millis();
322                if range <= 0 {
323                    return Some(b.clone());
324                }
325                let t = (target_ms - b.timestamp.as_millis()) as f32 / range as f32;
326                Some(VisualInterpolator::interpolate(b, a, t))
327            }
328            (Some(b), None) => Some(b.clone()),
329            (None, Some(a)) => Some(a.clone()),
330            (None, None) => None,
331        }
332    }
333
334    /// Get the latest state
335    pub fn latest(&self) -> Option<&VisualState> {
336        self.states.last()
337    }
338
339    /// Clear the buffer
340    pub fn clear(&mut self) {
341        self.states.clear();
342    }
343
344    /// Number of buffered states
345    pub fn len(&self) -> usize {
346        self.states.len()
347    }
348
349    /// Is buffer empty?
350    pub fn is_empty(&self) -> bool {
351        self.states.is_empty()
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use elara_core::NodeId;

    /// A freshly received state is served as-is, with no prediction active.
    #[test]
    fn test_predictor_update() {
        let mut p = VisualPredictor::new(PredictionConfig::default());
        let state = VisualState::keyframe(NodeId::new(1), StateTime::from_millis(0), 1);

        p.update(state);

        assert!(p.current_state().is_some());
        assert!(!p.is_predicting());
    }

    /// With two received states, extrapolating 100 ms past the newest one
    /// succeeds and flips the predictor into prediction mode.
    #[test]
    fn test_predictor_predict() {
        let mut p = VisualPredictor::new(PredictionConfig::default());
        let node = NodeId::new(1);

        p.update(VisualState::keyframe(node, StateTime::from_millis(0), 1));
        p.update(VisualState::keyframe(node, StateTime::from_millis(100), 2));

        assert!(p.predict(StateTime::from_millis(200)).is_some());
        assert!(p.is_predicting());
    }

    /// Buffered states are retained and can be sampled at arbitrary times.
    #[test]
    fn test_state_buffer() {
        let mut buf = VisualStateBuffer::new(10, 50);
        let node = NodeId::new(1);

        buf.push(VisualState::keyframe(node, StateTime::from_millis(0), 1));
        buf.push(VisualState::keyframe(node, StateTime::from_millis(100), 2));
        buf.push(VisualState::keyframe(node, StateTime::from_millis(200), 3));

        assert_eq!(buf.len(), 3);

        // Sampling t=100 with the 50 ms target delay reads at t=50, which
        // falls between the first two states and interpolates.
        assert!(buf.get_at(StateTime::from_millis(100)).is_some());
    }
}