1use elara_core::StateTime;
7
8use crate::{VisualState, VisualStateId};
9
/// Tuning parameters for [`VisualPredictor`].
#[derive(Debug, Clone, PartialEq)]
pub struct PredictionConfig {
    /// Furthest a prediction may reach past the latest observed state, in milliseconds.
    /// Requests beyond this horizon yield no prediction.
    pub max_horizon_ms: u32,

    /// Confidence lost per 100 ms of extrapolation. Confidence starts at 1.0 and
    /// falls linearly.
    pub confidence_decay: f32,

    /// Predictions whose decayed confidence falls below this value are discarded.
    pub min_confidence: f32,

    /// Extrapolate pose joint positions along the pose velocity.
    pub predict_motion: bool,

    /// Extrapolate facial state (mouth animation while speaking, head rotation).
    pub predict_expression: bool,
}

impl Default for PredictionConfig {
    /// 500 ms horizon, 0.1 confidence lost per 100 ms, 0.3 minimum confidence,
    /// with motion and expression prediction both enabled.
    fn default() -> Self {
        Self {
            max_horizon_ms: 500,
            confidence_decay: 0.1,
            min_confidence: 0.3,
            predict_motion: true,
            predict_expression: true,
        }
    }
}
40
/// Extrapolates the most recently observed [`VisualState`] forward in time.
#[derive(Debug)]
pub struct VisualPredictor {
    /// Horizon, confidence-decay and feature-toggle settings used by `predict`.
    config: PredictionConfig,

    /// Most recent state passed to `update`; the base for every prediction.
    last_state: Option<VisualState>,

    /// The state that preceded `last_state`; `predict` compares against its face
    /// to estimate head-rotation velocity.
    prev_state: Option<VisualState>,

    /// Latest successful prediction; cleared by `update`.
    current_prediction: Option<VisualState>,

    /// Number of predictions produced since the last `update`.
    prediction_count: u32,
}
59
60impl VisualPredictor {
61 pub fn new(config: PredictionConfig) -> Self {
63 Self {
64 config,
65 last_state: None,
66 prev_state: None,
67 current_prediction: None,
68 prediction_count: 0,
69 }
70 }
71
72 pub fn update(&mut self, state: VisualState) {
74 self.prev_state = self.last_state.take();
75 self.last_state = Some(state);
76 self.current_prediction = None;
77 self.prediction_count = 0;
78 }
79
80 pub fn current_state(&self) -> Option<&VisualState> {
82 self.current_prediction
83 .as_ref()
84 .or(self.last_state.as_ref())
85 }
86
87 pub fn predict(&mut self, target_time: StateTime) -> Option<VisualState> {
90 let last = self.last_state.as_ref()?;
91
92 let delta_ms = target_time.as_millis() - last.timestamp.as_millis();
93
94 if delta_ms <= 0 {
96 return Some(last.clone());
97 }
98
99 if delta_ms > self.config.max_horizon_ms as i64 {
101 return None;
102 }
103
104 let decay_steps = delta_ms as f32 / 100.0;
106 let confidence = 1.0 - (decay_steps * self.config.confidence_decay);
107
108 if confidence < self.config.min_confidence {
109 return None;
110 }
111
112 let mut predicted = last.clone();
114 predicted.timestamp = target_time;
115 predicted.sequence = last.sequence + 1;
116 predicted.id = VisualStateId::new(predicted.sequence);
117
118 if self.config.predict_expression {
120 if let (Some(ref prev), Some(ref mut face)) = (&self.prev_state, &mut predicted.face) {
121 if let Some(ref prev_face) = prev.face {
122 if face.speaking {
124 let phase = (delta_ms as f32 / 150.0).sin();
126 face.mouth.openness = 0.3 + 0.2 * phase.abs();
127 }
128
129 let dt = delta_ms as f32 / 1000.0;
131 let head_vel = (
132 (last.face.as_ref().map(|f| f.head_rotation.0).unwrap_or(0.0)
133 - prev_face.head_rotation.0)
134 / 0.1,
135 (last.face.as_ref().map(|f| f.head_rotation.1).unwrap_or(0.0)
136 - prev_face.head_rotation.1)
137 / 0.1,
138 (last.face.as_ref().map(|f| f.head_rotation.2).unwrap_or(0.0)
139 - prev_face.head_rotation.2)
140 / 0.1,
141 );
142
143 face.head_rotation.0 += head_vel.0 * dt * 0.5; face.head_rotation.1 += head_vel.1 * dt * 0.5;
145 face.head_rotation.2 += head_vel.2 * dt * 0.5;
146 }
147
148 face.confidence *= confidence;
150 }
151 }
152
153 if self.config.predict_motion {
155 if let Some(ref mut pose) = predicted.pose {
156 let dt = delta_ms as f32 / 1000.0;
158
159 for joint in &mut pose.joints {
160 joint.position.x += pose.velocity.x * dt;
161 joint.position.y += pose.velocity.y * dt;
162 joint.position.z += pose.velocity.z * dt;
163 }
164
165 pose.confidence *= confidence;
166 }
167 }
168
169 self.current_prediction = Some(predicted.clone());
172 self.prediction_count += 1;
173
174 Some(predicted)
175 }
176
177 pub fn is_predicting(&self) -> bool {
179 self.prediction_count > 0
180 }
181
182 pub fn prediction_count(&self) -> u32 {
184 self.prediction_count
185 }
186
187 pub fn confidence(&self) -> f32 {
189 if let Some(ref pred) = self.current_prediction {
190 pred.face.as_ref().map(|f| f.confidence).unwrap_or(0.5)
192 } else if self.last_state.is_some() {
193 1.0
194 } else {
195 0.0
196 }
197 }
198}
199
/// Stateless helper for blending between two visual states; see its `interpolate` function.
#[derive(Debug, Clone, Copy, Default)]
pub struct VisualInterpolator;
202
203impl VisualInterpolator {
204 pub fn interpolate(from: &VisualState, to: &VisualState, t: f32) -> VisualState {
206 let t = t.clamp(0.0, 1.0);
207
208 let mut result = to.clone();
209
210 result.face = match (&from.face, &to.face) {
212 (Some(f1), Some(f2)) => Some(f1.lerp(f2, t)),
213 (None, Some(f)) => Some(f.clone()),
214 (Some(f), None) => {
215 if t < 0.5 {
216 Some(f.clone())
217 } else {
218 None
219 }
220 }
221 (None, None) => None,
222 };
223
224 result.pose = match (&from.pose, &to.pose) {
226 (Some(p1), Some(p2)) => Some(p1.lerp(p2, t)),
227 (None, Some(p)) => Some(p.clone()),
228 (Some(p), None) => {
229 if t < 0.5 {
230 Some(p.clone())
231 } else {
232 None
233 }
234 }
235 (None, None) => None,
236 };
237
238 result.scene = match (&from.scene, &to.scene) {
240 (Some(s1), Some(s2)) => Some(s1.lerp(s2, t)),
241 (None, Some(s)) => Some(s.clone()),
242 (Some(s), None) => {
243 if t < 0.5 {
244 Some(s.clone())
245 } else {
246 None
247 }
248 }
249 (None, None) => None,
250 };
251
252 result
253 }
254}
255
/// Bounded buffer of visual states, ordered by timestamp.
///
/// States are served at `time - target_delay_ms`; when that instant falls between
/// two buffered states, the result is interpolated between them.
#[derive(Debug)]
pub struct VisualStateBuffer {
    /// Buffered states, kept sorted by ascending timestamp by `push`.
    states: Vec<VisualState>,

    /// Maximum number of states retained; `push` drops the oldest ones first.
    max_size: usize,

    /// Delay, in milliseconds, subtracted from the requested time in `get_at`.
    target_delay_ms: u32,
}
269
270impl VisualStateBuffer {
271 pub fn new(max_size: usize, target_delay_ms: u32) -> Self {
273 Self {
274 states: Vec::with_capacity(max_size),
275 max_size,
276 target_delay_ms,
277 }
278 }
279
280 pub fn push(&mut self, state: VisualState) {
282 let pos = self
284 .states
285 .iter()
286 .position(|s| s.timestamp.as_millis() > state.timestamp.as_millis())
287 .unwrap_or(self.states.len());
288
289 self.states.insert(pos, state);
290
291 while self.states.len() > self.max_size {
293 self.states.remove(0);
294 }
295 }
296
297 pub fn get_at(&self, time: StateTime) -> Option<VisualState> {
299 if self.states.is_empty() {
300 return None;
301 }
302
303 let target_ms = time.as_millis() - self.target_delay_ms as i64;
305
306 let mut before: Option<&VisualState> = None;
307 let mut after: Option<&VisualState> = None;
308
309 for state in &self.states {
310 if state.timestamp.as_millis() <= target_ms {
311 before = Some(state);
312 } else {
313 after = Some(state);
314 break;
315 }
316 }
317
318 match (before, after) {
319 (Some(b), Some(a)) => {
320 let range = a.timestamp.as_millis() - b.timestamp.as_millis();
322 if range <= 0 {
323 return Some(b.clone());
324 }
325 let t = (target_ms - b.timestamp.as_millis()) as f32 / range as f32;
326 Some(VisualInterpolator::interpolate(b, a, t))
327 }
328 (Some(b), None) => Some(b.clone()),
329 (None, Some(a)) => Some(a.clone()),
330 (None, None) => None,
331 }
332 }
333
334 pub fn latest(&self) -> Option<&VisualState> {
336 self.states.last()
337 }
338
339 pub fn clear(&mut self) {
341 self.states.clear();
342 }
343
344 pub fn len(&self) -> usize {
346 self.states.len()
347 }
348
349 pub fn is_empty(&self) -> bool {
351 self.states.is_empty()
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use elara_core::NodeId;

    #[test]
    fn test_predictor_update() {
        let mut predictor = VisualPredictor::new(PredictionConfig::default());

        let node = NodeId::new(1);
        let time = StateTime::from_millis(0);
        let state = VisualState::keyframe(node, time, 1);

        predictor.update(state);

        assert!(predictor.current_state().is_some());
        assert!(!predictor.is_predicting());
    }

    #[test]
    fn test_predictor_predict() {
        let mut predictor = VisualPredictor::new(PredictionConfig::default());

        let node = NodeId::new(1);
        let time1 = StateTime::from_millis(0);
        let time2 = StateTime::from_millis(100);

        predictor.update(VisualState::keyframe(node, time1, 1));
        predictor.update(VisualState::keyframe(node, time2, 2));

        let predicted = predictor.predict(StateTime::from_millis(200));
        assert!(predicted.is_some());
        assert!(predictor.is_predicting());
    }

    #[test]
    fn test_predictor_respects_horizon() {
        let mut predictor = VisualPredictor::new(PredictionConfig::default());
        let node = NodeId::new(1);
        predictor.update(VisualState::keyframe(node, StateTime::from_millis(0), 1));

        // The default horizon is 500 ms, so 1 s ahead must be refused.
        assert!(predictor.predict(StateTime::from_millis(1_000)).is_none());
        assert!(!predictor.is_predicting());
    }

    #[test]
    fn test_predictor_past_target_returns_latest() {
        let mut predictor = VisualPredictor::new(PredictionConfig::default());
        let node = NodeId::new(1);
        predictor.update(VisualState::keyframe(node, StateTime::from_millis(100), 1));

        let state = predictor.predict(StateTime::from_millis(50));
        assert_eq!(state.map(|s| s.timestamp.as_millis()), Some(100));
        // Returning the latest observation is not a prediction.
        assert!(!predictor.is_predicting());
    }

    #[test]
    fn test_update_resets_prediction() {
        let mut predictor = VisualPredictor::new(PredictionConfig::default());
        let node = NodeId::new(1);
        predictor.update(VisualState::keyframe(node, StateTime::from_millis(0), 1));
        predictor.update(VisualState::keyframe(node, StateTime::from_millis(100), 2));
        assert!(predictor.predict(StateTime::from_millis(200)).is_some());
        assert_eq!(predictor.prediction_count(), 1);

        predictor.update(VisualState::keyframe(node, StateTime::from_millis(200), 3));
        assert!(!predictor.is_predicting());
        assert_eq!(predictor.prediction_count(), 0);
    }

    #[test]
    fn test_state_buffer() {
        let mut buffer = VisualStateBuffer::new(10, 50);

        let node = NodeId::new(1);

        buffer.push(VisualState::keyframe(node, StateTime::from_millis(0), 1));
        buffer.push(VisualState::keyframe(node, StateTime::from_millis(100), 2));
        buffer.push(VisualState::keyframe(node, StateTime::from_millis(200), 3));

        assert_eq!(buffer.len(), 3);

        let state = buffer.get_at(StateTime::from_millis(100));
        assert!(state.is_some());
    }

    #[test]
    fn test_state_buffer_evicts_oldest() {
        let mut buffer = VisualStateBuffer::new(2, 0);
        let node = NodeId::new(1);

        buffer.push(VisualState::keyframe(node, StateTime::from_millis(0), 1));
        buffer.push(VisualState::keyframe(node, StateTime::from_millis(100), 2));
        buffer.push(VisualState::keyframe(node, StateTime::from_millis(200), 3));

        assert_eq!(buffer.len(), 2);
        assert_eq!(buffer.latest().map(|s| s.timestamp.as_millis()), Some(200));
        // The 0 ms state was evicted, so the earliest available state is served.
        let earliest = buffer.get_at(StateTime::from_millis(0));
        assert_eq!(earliest.map(|s| s.timestamp.as_millis()), Some(100));
    }

    #[test]
    fn test_state_buffer_orders_out_of_order_pushes() {
        let mut buffer = VisualStateBuffer::new(10, 0);
        let node = NodeId::new(1);

        buffer.push(VisualState::keyframe(node, StateTime::from_millis(200), 3));
        buffer.push(VisualState::keyframe(node, StateTime::from_millis(0), 1));
        buffer.push(VisualState::keyframe(node, StateTime::from_millis(100), 2));

        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.latest().map(|s| s.timestamp.as_millis()), Some(200));
    }

    #[test]
    fn test_empty_buffer() {
        let mut buffer = VisualStateBuffer::new(4, 50);
        assert!(buffer.is_empty());
        assert!(buffer.get_at(StateTime::from_millis(100)).is_none());

        let node = NodeId::new(1);
        buffer.push(VisualState::keyframe(node, StateTime::from_millis(0), 1));
        buffer.clear();
        assert!(buffer.is_empty());
        assert!(buffer.latest().is_none());
    }
}
406}