use std::fmt::Debug;
use std::future::Future;
use std::marker::PhantomData;
use std::time::Duration;
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::RoplatError;
use crate::rhythm::Rhythm;
use super::frame::FrameRecord;
/// How the feed produced during a [`ReplayMode::Full`] replay is compared
/// with the feed bytes stored in the recorded frame.
#[derive(Clone, Debug)]
pub enum FeedCheck {
    /// Skip the comparison: every replayed frame counts as matching.
    None,
    /// Require the bincode-serialized feed to equal the recorded bytes exactly.
    Exact,
    /// Compare the serialized bytes as little-endian `f64` values, accepting an
    /// absolute difference of up to the given epsilon per value. Payloads of
    /// different lengths never match; payloads whose length is not a multiple
    /// of 8 bytes must match exactly.
    Tolerance(f64),
}
/// Selects which parts of each recorded frame are re-executed and verified
/// when a [`ReplayRhythm`] is driven.
#[derive(Clone, Debug)]
pub enum ReplayMode {
    /// Pass each recorded yield to the operation domain and compare the feed
    /// it returns with the recorded feed using `tolerance`.
    Full {
        /// Comparison applied to every feed produced during the replay.
        tolerance: FeedCheck,
    },
    /// Pass each recorded yield to the operation domain, but do not check the
    /// feed it returns.
    InputOnly,
    /// Decode each recorded yield, but do not invoke the operation domain.
    OutputOnly,
}
/// How replayed frames are paced relative to their recorded timestamps.
///
/// Derives `Clone` like the sibling configuration enums [`FeedCheck`] and
/// [`ReplayMode`], so a complete replay configuration can be duplicated.
#[derive(Debug, Clone)]
pub enum ReplayTiming {
    /// Replay frames back to back without sleeping.
    AsFast,
    /// Sleep for the recorded gap between consecutive frame timestamps.
    Realtime,
    /// Sleep for the recorded gap divided by this factor (`2.0` replays twice
    /// as fast, `0.5` half as fast). The factor should be strictly positive.
    Scaled(f64),
}
/// A [`Rhythm`] that replays previously recorded frames.
///
/// `Y` is the type each frame's yield bytes are decoded into and `D` is the
/// feed type; neither is stored, both are only tracked at the type level.
pub struct ReplayRhythm<Y, D> {
    /// Recorded frames, replayed in order.
    frames: Vec<FrameRecord>,
    /// Index of the next frame to replay.
    cursor: usize,
    /// What to re-execute and verify for each frame.
    mode: ReplayMode,
    /// How replayed frames are paced.
    timing: ReplayTiming,
    /// Timestamp of the most recently paced frame; `None` before the first
    /// frame and after a rewind.
    last_timestamp_ns: Option<u64>,
    /// Ties `Y` and `D` to the struct without storing values of either.
    _phantom: PhantomData<(Y, D)>,
}
impl<Y, D> ReplayRhythm<Y, D> {
    /// Creates a replay over `frames`, positioned at the first frame.
    ///
    /// `mode` selects what is re-executed and checked for each frame;
    /// `timing` selects how replayed frames are paced.
    pub fn new(frames: Vec<FrameRecord>, mode: ReplayMode, timing: ReplayTiming) -> Self {
        Self {
            frames,
            cursor: 0,
            mode,
            timing,
            last_timestamp_ns: None,
            _phantom: PhantomData,
        }
    }

    /// Number of frames that have not been replayed yet.
    pub fn remaining(&self) -> usize {
        self.frames.len().saturating_sub(self.cursor)
    }

    /// Whether every frame has been replayed.
    pub fn is_finished(&self) -> bool {
        self.cursor >= self.frames.len()
    }

    /// Moves back to the first frame and forgets the pacing reference, so the
    /// next replayed frame is not delayed.
    pub fn rewind(&mut self) {
        self.cursor = 0;
        self.last_timestamp_ns = None;
    }

    /// Records `timestamp_ns` as the latest paced timestamp and returns how
    /// long to wait before replaying the frame that carries it.
    ///
    /// Returns `None` for the first frame after construction or
    /// [`Self::rewind`], and always for [`ReplayTiming::AsFast`]. Otherwise the
    /// delay is the gap since the previous timestamp (saturating at zero if
    /// timestamps go backwards). For [`ReplayTiming::Scaled`] that gap is
    /// divided by the factor.
    ///
    /// A `Scaled` factor that is not strictly positive (zero, negative or NaN)
    /// yields a zero delay. Without that guard, a zero factor would divide to
    /// `+inf`, which saturates to `u64::MAX` nanoseconds: an effectively
    /// endless sleep.
    fn frame_delay(&mut self, timestamp_ns: u64) -> Option<Duration> {
        // Always advance the reference point, even when no delay is produced.
        let previous = self.last_timestamp_ns.replace(timestamp_ns);
        let elapsed_ns = timestamp_ns.saturating_sub(previous?);
        match self.timing {
            ReplayTiming::AsFast => None,
            // Integer path: no f64 precision loss for very long gaps.
            ReplayTiming::Realtime => Some(Duration::from_nanos(elapsed_ns)),
            ReplayTiming::Scaled(factor) if factor > 0.0 => {
                // `as u64` saturates, so huge quotients clamp instead of wrapping.
                Some(Duration::from_nanos((elapsed_ns as f64 / factor) as u64))
            }
            ReplayTiming::Scaled(_) => Some(Duration::ZERO),
        }
    }
}
/// A replayed frame whose produced feed did not match the recorded feed.
///
/// The frame is identified by the `global_seq` and `rhythm_id` values stored
/// in its recorded [`FrameRecord`]. It is a small plain-data record, so it
/// derives the common value traits to let callers compare, copy and dedupe
/// mismatches.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FeedMismatch {
    /// `global_seq` of the mismatching recorded frame.
    pub global_seq: u64,
    /// `rhythm_id` of the mismatching recorded frame.
    pub rhythm_id: u16,
}
/// Drives a [`ReplayRhythm`] by feeding recorded yields back through the
/// caller's operation domain.
impl<Y, D> Rhythm for ReplayRhythm<Y, D>
where
    Y: DeserializeOwned + Send,
    D: DeserializeOwned + Serialize + Send,
{
    type Yield = Y;
    type Feed = D;
    type Output = Vec<FeedMismatch>;
    type Error = RoplatError;

    /// Replays every remaining frame, starting at the current cursor.
    ///
    /// For each frame, it does the following:
    ///
    /// - Waits for the pacing delay, if any.
    /// - Decodes the recorded yield bytes as `Y`.
    /// - Passes the decoded yield to `op_domain`, unless the mode is
    ///   [`ReplayMode::OutputOnly`].
    /// - In [`ReplayMode::Full`], serializes the returned feed with bincode and
    ///   compares it to the recorded feed bytes using the configured
    ///   [`FeedCheck`]. Each failed comparison is recorded as a
    ///   [`FeedMismatch`].
    ///
    /// Returns the mismatches (always empty outside `Full` mode) together with
    /// the `nodes` value threaded through `op_domain`.
    ///
    /// # Panics
    ///
    /// Panics if a recorded yield cannot be decoded as `Y`, or if a produced
    /// feed cannot be serialized.
    async fn drive<N, F, Fut>(
        &mut self,
        mut nodes: N,
        mut op_domain: F,
    ) -> (Self::Output, N)
    where
        N: Send,
        F: FnMut(N, Self::Yield) -> Fut + Send,
        Fut: Future<Output = (Self::Feed, N)> + Send,
    {
        // Serialized lazily: only the Exact/Tolerance checks need the bytes.
        let encode_feed =
            |feed: &D| bincode::serialize(feed).expect("ReplayRhythm: feed serialize failed");
        let mut mismatches = Vec::new();
        while self.cursor < self.frames.len() {
            let idx = self.cursor;
            self.cursor += 1;

            let timestamp_ns = self.frames[idx].timestamp_ns;
            if let Some(delay) = self.frame_delay(timestamp_ns)
                && !delay.is_zero()
            {
                tokio::time::sleep(delay).await;
            }

            // Decode straight from the stored bytes instead of cloning them;
            // the borrow ends before the next await.
            let yield_val: Y = bincode::deserialize(&self.frames[idx].yield_data)
                .expect("ReplayRhythm: failed to deserialize yield data; type mismatch?");
            if matches!(self.mode, ReplayMode::OutputOnly) {
                continue;
            }

            let (feed_val, returned) = op_domain(nodes, yield_val).await;
            nodes = returned;

            let ReplayMode::Full { tolerance } = &self.mode else {
                continue;
            };
            // Re-index after the await rather than holding a frame borrow (or a
            // per-frame clone of the feed bytes) across it.
            let frame = &self.frames[idx];
            let expected = frame.feed_data.as_slice();
            let matched = match tolerance {
                FeedCheck::None => true,
                FeedCheck::Exact => encode_feed(&feed_val) == expected,
                // `tolerance_check` rejects length mismatches itself.
                FeedCheck::Tolerance(eps) => {
                    tolerance_check(&encode_feed(&feed_val), expected, *eps)
                }
            };
            if !matched {
                mismatches.push(FeedMismatch {
                    global_seq: frame.global_seq,
                    rhythm_id: frame.rhythm_id,
                });
            }
        }
        (mismatches, nodes)
    }
}
/// Compares two serialized feed payloads, allowing an absolute difference of
/// up to `eps` between corresponding values.
///
/// The payloads are read as consecutive little-endian `f64`s, one per 8-byte
/// chunk. This is a type-blind heuristic: any payload whose length is a
/// multiple of 8 is reinterpreted this way, including integer fields and
/// length prefixes.
/// NOTE(review): small integer differences reinterpret as subnormal floats
/// and will usually fall inside `eps`; confirm this is acceptable for the feed
/// types in use.
///
/// Returns `false` when the lengths differ. When the length is not a multiple
/// of 8, it falls back to exact byte equality. Chunks whose bytes are identical
/// always match, so a bit-identical NaN is accepted. Other chunks match only if
/// `|a - b| <= eps`, which rejects a NaN on either side.
fn tolerance_check(a: &[u8], b: &[u8], eps: f64) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let chunks_a = a.chunks_exact(8);
    let chunks_b = b.chunks_exact(8);
    if !chunks_a.remainder().is_empty() {
        // Not a whole number of f64s: no sensible float interpretation.
        return a == b;
    }
    chunks_a.zip(chunks_b).all(|(ca, cb)| {
        if ca == cb {
            return true;
        }
        let va = f64::from_le_bytes(ca.try_into().expect("chunks_exact yields 8-byte chunks"));
        let vb = f64::from_le_bytes(cb.try_into().expect("chunks_exact yields 8-byte chunks"));
        // Use `<=` rather than `!(diff > eps)`: any comparison with NaN is
        // false, so a NaN difference must count as a mismatch.
        (va - vb).abs() <= eps
    })
}