use futures::{Stream, StreamExt};
use parking_lot::Mutex;
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use std::{
collections::VecDeque,
fmt,
path::Path,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant, SystemTime},
};
use tokio::io;
use tokio::time::sleep_until;
/// Errors produced by this module.
#[derive(Debug, Clone, PartialEq)]
pub enum Error {
    /// The recording holds no items (displayed as "The recording is empty").
    EmptyRecording,
    /// A playback speed rejected by [`Speed::new`]; carries the offending value.
    InvalidSpeed(f64),
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::InvalidSpeed(value) => write!(f, "Invalid speed {}: must be positive", value),
            Self::EmptyRecording => f.write_str("The recording is empty"),
        }
    }
}

impl std::error::Error for Error {}

/// Convenience alias for results whose error type is this module's [`Error`].
pub type SturgeonResult<T> = std::result::Result<T, Error>;
/// Stream adapter that forwards every item of the wrapped stream unchanged
/// while appending a timestamped clone of it to a [`Recording`].
///
/// Constructed by [`record`] and [`record_with_capacity`];
/// [`RecordedStream::recording`] returns a handle to the captured items.
#[pin_project]
pub struct RecordedStream<S: Stream> {
    /// The wrapped stream, structurally pinned so it can be polled in place.
    #[pin]
    inner: S,
    /// Where captured items are appended; handles returned by `recording()`
    /// share this buffer.
    recording: Recording<S::Item>,
    /// Sequence number given to the next captured item (starts at 0).
    seq: u64,
    /// Capture time of the previous item; `None` until the first item arrives.
    last_timestamp: Option<Instant>,
    /// Creation time of the adapter; the first item's `delta` is measured from here.
    start_timestamp: Instant,
}
/// Log of items captured from a stream, guarded by a mutex.
///
/// Cloning a `Recording` clones the `Arc`, so every clone is a handle to the
/// *same* buffer; this is how [`RecordedStream::recording`] exposes a live view
/// of an in-progress capture. The struct is persisted through serde/bincode
/// (`save`/`load`), so the field order below is part of the on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Recording<S> {
    /// Captured items, oldest at the front, newest at the back.
    items: Arc<Mutex<VecDeque<RecordedItem<S>>>>,
    /// Maximum number of retained items; `None` means unbounded. When the
    /// limit is reached, the oldest item is evicted to make room.
    capacity: Option<usize>,
}
/// One captured value together with its capture metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecordedItem<T> {
    /// Sequence number; `RecordedStream` numbers items 0, 1, 2, … in arrival order.
    pub seq: u64,
    /// Wall-clock time at which the item was captured.
    pub timestamp: SystemTime,
    /// Gap since the previous capture (for the first item: since the adapter
    /// was created). Timed replays wait for this gap, scaled by the speed.
    pub delta: Duration,
    /// The captured value, reference-counted so replays can share it.
    /// NOTE(review): this field is private and has no visible accessor, so code
    /// outside this module cannot build a `RecordedItem` except by deserializing
    /// one, even though `Recording::from_items` is public. Confirm this is intended.
    data: Arc<T>,
}
impl<S> Stream for RecordedStream<S>
where
    S: Stream,
    S::Item: Clone,
{
    type Item = S::Item;

    /// Forwards the inner stream's poll result untouched.
    ///
    /// Whenever an item is produced, a clone of it is first appended to the
    /// recording. The clone carries the current sequence number, the wall-clock
    /// capture time, and the gap since the previous item (or since the adapter
    /// was created, for the very first item).
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let state = self.project();
        let polled = state.inner.poll_next(cx);
        if let Poll::Ready(Some(value)) = &polled {
            let captured_at = Instant::now();
            let previous = state.last_timestamp.unwrap_or(*state.start_timestamp);
            state.recording.push(RecordedItem {
                seq: *state.seq,
                timestamp: SystemTime::now(),
                delta: captured_at.duration_since(previous),
                data: Arc::new(value.clone()),
            });
            *state.seq += 1;
            *state.last_timestamp = Some(captured_at);
        }
        polled
    }
}
impl<S> Default for Recording<S> {
fn default() -> Self {
Self::new()
}
}
impl<S> Recording<S> {
    /// Creates an empty recording with no capacity limit.
    pub fn new() -> Self {
        Recording {
            items: Arc::new(Mutex::new(VecDeque::new())),
            capacity: None,
        }
    }

    /// Builds an unbounded recording from already-recorded items.
    ///
    /// Items are kept in iteration order, and their metadata (`seq`,
    /// `timestamp`, `delta`) is used as-is, not renumbered.
    pub fn from_items(items: impl IntoIterator<Item = RecordedItem<S>>) -> Self {
        Recording {
            items: Arc::new(Mutex::new(items.into_iter().collect())),
            capacity: None,
        }
    }

    /// Builds an unbounded recording from bare values.
    ///
    /// Items are numbered `0, 1, 2, …` in iteration order and all share one
    /// capture timestamp (the time of this call). Every `delta` is zero, so no
    /// gap is recorded between items.
    pub fn from_raw_items(items: impl IntoIterator<Item = S>) -> Self {
        let now = SystemTime::now();
        let recorded: VecDeque<_> = items
            .into_iter()
            .enumerate()
            .map(|(seq, data)| RecordedItem {
                // usize -> u64 is lossless on every platform Rust supports (<= 64-bit).
                seq: seq as u64,
                timestamp: now,
                delta: Duration::ZERO,
                data: Arc::new(data),
            })
            .collect();
        Recording {
            items: Arc::new(Mutex::new(recorded)),
            capacity: None,
        }
    }

    /// Creates an empty recording that retains at most `capacity` items.
    ///
    /// Once full, each new item evicts the oldest one. A capacity of `0`
    /// retains nothing.
    pub fn with_capacity(capacity: usize) -> Self {
        // Preallocate only a modest prefix of the bound. Eagerly reserving the full
        // capacity makes a large "effectively unbounded" limit such as `usize::MAX`
        // panic with a capacity overflow. The deque still grows on demand.
        const MAX_PREALLOCATION: usize = 1024;
        Recording {
            items: Arc::new(Mutex::new(VecDeque::with_capacity(
                capacity.min(MAX_PREALLOCATION),
            ))),
            capacity: Some(capacity),
        }
    }

    /// Appends `item` as the newest entry, evicting the oldest entries as
    /// needed to stay within `capacity`.
    fn push(&self, item: RecordedItem<S>) {
        let mut items = self.items.lock();
        items.push_back(item);
        // Insert first, then trim. The previous evict-then-insert order always kept
        // the newest item, so a capacity of 0 still stored one. Trimming in a loop
        // also brings a recording that was loaded with more than `cap` items back
        // within bounds.
        if let Some(cap) = self.capacity {
            while items.len() > cap {
                items.pop_front();
            }
        }
    }

    /// Serializes the recording (items and capacity limit) with bincode's
    /// standard configuration and writes the bytes to `path` using
    /// `tokio::fs::write`.
    ///
    /// # Errors
    /// Returns an error if encoding fails or the file cannot be written.
    pub async fn save(&self, path: impl AsRef<Path>) -> io::Result<()>
    where
        S: Serialize,
    {
        let bytes = bincode::serde::encode_to_vec(self, bincode::config::standard())
            .map_err(io::Error::other)?;
        tokio::fs::write(path, bytes).await
    }

    /// Reads a recording previously written by [`Recording::save`].
    ///
    /// The capacity limit stored in the file is restored along with the items.
    /// Any bytes after the encoded recording are ignored.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or does not decode as a
    /// `Recording<S>`.
    pub async fn load(path: impl AsRef<Path>) -> io::Result<Self>
    where
        S: for<'de> Deserialize<'de>,
    {
        let bytes = tokio::fs::read(path).await?;
        let (rec, _) = bincode::serde::decode_from_slice(&bytes, bincode::config::standard())
            .map_err(io::Error::other)?;
        Ok(rec)
    }
}
/// A validated playback-speed multiplier for timed replays.
///
/// Timed replays divide each recorded gap by this factor: `2.0` plays twice as
/// fast and `0.5` at half speed. Outside this module a `Speed` can only be
/// obtained from [`Speed::new`], which guarantees a strictly positive value,
/// or from [`Speed::NORMAL`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Speed(f64);

impl Speed {
    /// Real-time playback (factor `1.0`).
    pub const NORMAL: Speed = Speed(1.0);

    /// Validates `value` as a playback speed.
    ///
    /// # Errors
    /// Returns [`Error::InvalidSpeed`] unless `value` is strictly greater than
    /// zero. NaN is rejected as well.
    pub fn new(value: f64) -> SturgeonResult<Self> {
        // State the check positively. Every comparison involving NaN is false, so a
        // `value <= 0.0` rejection test lets NaN through, and a NaN factor would
        // poison every replay delay computed from it.
        if value > 0.0 {
            Ok(Speed(value))
        } else {
            Err(Error::InvalidSpeed(value))
        }
    }

    /// Returns the raw multiplier.
    pub fn as_f64(&self) -> f64 {
        self.0
    }
}
impl<S: Clone> Recording<S> {
    /// Returns clones of every recorded payload, oldest first.
    pub fn items(&self) -> Vec<S> {
        self.items
            .lock()
            .iter()
            .map(|i| (*i.data).clone())
            .collect()
    }

    /// Turns a snapshot of recorded items into a stream that re-emits their
    /// payloads, waiting between items for the recorded `delta` divided by
    /// `speed`.
    ///
    /// The schedule is anchored at the moment the stream is *first polled*,
    /// not when it is built. A replay that is created now and consumed later
    /// therefore still reproduces the recorded spacing, instead of bursting out
    /// items whose deadlines have already passed. Each item sleeps until an
    /// absolute deadline (anchor + cumulative scaled delta), so per-item
    /// scheduling jitter does not accumulate.
    #[must_use = "streams do nothing unless polled"]
    fn replay_items(items: Vec<RecordedItem<S>>, speed: Speed) -> impl Stream<Item = Arc<S>> {
        let factor = speed.as_f64();
        let mut offset = Duration::ZERO;
        let schedule: Vec<(Duration, Arc<S>)> = items
            .into_iter()
            .map(|item| {
                // Saturate instead of panicking when a very small speed scales a
                // gap beyond what `Duration` can represent.
                let scaled = Duration::try_from_secs_f64(item.delta.as_secs_f64() / factor)
                    .unwrap_or(Duration::MAX);
                offset = offset.saturating_add(scaled);
                (offset, item.data)
            })
            .collect();

        let mut anchor: Option<tokio::time::Instant> = None;
        futures::stream::iter(schedule).then(move |(offset, data)| {
            // `then` first invokes this closure during the stream's first poll,
            // which is the instant the whole schedule gets pinned to.
            let start = *anchor.get_or_insert_with(tokio::time::Instant::now);
            async move {
                match start.checked_add(offset) {
                    Some(deadline) => sleep_until(deadline).await,
                    // The deadline lies beyond any representable instant, so it never arrives.
                    None => std::future::pending::<()>().await,
                }
                data
            }
        })
    }

    /// Replays every recorded item at the recorded pace.
    #[must_use = "streams do nothing unless polled"]
    pub fn replay(&self) -> impl Stream<Item = S> {
        // Cloning a `RecordedItem` only bumps the payload's `Arc` refcount.
        let items: Vec<_> = self.items.lock().iter().cloned().collect();
        Self::replay_items(items, Speed::NORMAL).map(|item| (*item).clone())
    }

    /// Replays at the recorded pace, starting with the first item whose
    /// `seq` is at least `start_seq`.
    ///
    /// Every item after that one is included, whatever its `seq`.
    #[must_use = "streams do nothing unless polled"]
    pub fn replay_from(&self, start_seq: u64) -> impl Stream<Item = S> {
        let items: Vec<_> = self
            .items
            .lock()
            .iter()
            .skip_while(|i| i.seq < start_seq)
            .cloned()
            .collect();
        Self::replay_items(items, Speed::NORMAL).map(|item| (*item).clone())
    }

    /// Replays at the recorded pace, starting with the first item captured at
    /// or after `since`.
    ///
    /// Items after that one are kept even if their timestamp is earlier, which
    /// can happen because wall-clock time is not monotonic.
    #[must_use = "streams do nothing unless polled"]
    pub fn replay_since(&self, since: SystemTime) -> impl Stream<Item = S> {
        let items: Vec<_> = self
            .items
            .lock()
            .iter()
            .skip_while(|i| i.timestamp < since)
            .cloned()
            .collect();
        Self::replay_items(items, Speed::NORMAL).map(|item| (*item).clone())
    }

    /// Replays, at the recorded pace, the items whose `seq` lies in
    /// `start..=end` (both ends inclusive).
    #[must_use = "streams do nothing unless polled"]
    pub fn replay_range(&self, start: u64, end: u64) -> impl Stream<Item = S> {
        let items: Vec<_> = self
            .items
            .lock()
            .iter()
            .filter(|i| (start..=end).contains(&i.seq))
            .cloned()
            .collect();
        Self::replay_items(items, Speed::NORMAL).map(|item| (*item).clone())
    }

    /// Replays every item with the recorded gaps divided by `speed`.
    #[must_use = "streams do nothing unless polled"]
    pub fn replay_with_speed(&self, speed: Speed) -> impl Stream<Item = S> {
        let items: Vec<_> = self.items.lock().iter().cloned().collect();
        Self::replay_items(items, speed).map(|item| (*item).clone())
    }

    /// Replays every item back-to-back, without any timed waits.
    #[must_use = "streams do nothing unless polled"]
    pub fn replay_immediate(&self) -> impl Stream<Item = S> {
        let items: Vec<_> = self.items.lock().iter().cloned().collect();
        tokio_stream::iter(items).map(|item| (*item.data).clone())
    }
}
#[cfg(test)]
impl<S: PartialEq + std::fmt::Debug> Recording<S> {
    /// Asserts that the recording currently holds exactly `expected` items.
    ///
    /// # Panics
    /// Panics with "sequence length mismatch" if the item count differs.
    pub fn assert_count(&self, expected: usize) {
        assert_eq!(
            self.items.lock().len(),
            expected,
            "sequence length mismatch"
        );
    }

    /// Asserts that the recorded payloads equal `expected`, position by position.
    ///
    /// # Panics
    /// Panics if the number of recorded items differs from `expected.len()`.
    /// Also panics if any position holds a different value; that message lists
    /// every `(index, recorded, expected)` mismatch.
    pub fn assert_sequence(&self, expected: &[S]) {
        let recording = self.items.lock();
        // `zip` stops at the shorter side. Without this check, a truncated or
        // over-long recording (or an empty `expected`) would pass whenever the
        // common prefix matched.
        assert_eq!(
            recording.len(),
            expected.len(),
            "sequence length mismatch"
        );
        let mismatches: Vec<(usize, &S, &S)> = recording
            .iter()
            .zip(expected)
            .enumerate()
            .filter_map(|(i, (rec, exp))| (*rec.data != *exp).then_some((i, &*rec.data, exp)))
            .collect();
        assert!(
            mismatches.is_empty(),
            "sequence mismatch at indices {mismatches:?}",
        );
    }

    /// Asserts that every recorded gap (`delta`) lies within `min..=max`.
    ///
    /// # Panics
    /// Panics if any delta falls outside the range. The message shows the
    /// allowed range and each offending `(index, delta)`; previously it always
    /// reported `min`, even when `max` was the bound exceeded.
    pub fn assert_timing(&self, min: Duration, max: Duration) {
        let allowed = min..=max;
        let violations: Vec<(usize, Duration)> = self
            .items
            .lock()
            .iter()
            .enumerate()
            .filter(|(_, rec)| !allowed.contains(&rec.delta))
            .map(|(i, rec)| (i, rec.delta))
            .collect();
        assert!(
            violations.is_empty(),
            "timing violations outside {min:?}..={max:?} at indices {violations:?}",
        );
    }
}
impl<S: Stream> RecordedStream<S>
where
S::Item: Clone,
{
pub fn recording(&self) -> Recording<S::Item> {
self.recording.clone()
}
}
/// Wraps `s` so that every item it yields is also captured in an unbounded
/// [`Recording`], reachable through [`RecordedStream::recording`].
pub fn record<S: Stream<Item = T>, T: Clone>(s: S) -> RecordedStream<S> {
    RecordedStream {
        start_timestamp: Instant::now(),
        last_timestamp: None,
        seq: 0,
        recording: Recording::new(),
        inner: s,
    }
}
/// Like [`record`], but the recording keeps at most `capacity` items and
/// evicts the oldest ones as new items arrive.
pub fn record_with_capacity<S: Stream<Item = T>, T: Clone>(
    s: S,
    capacity: usize,
) -> RecordedStream<S> {
    let created = Instant::now();
    RecordedStream {
        start_timestamp: created,
        last_timestamp: None,
        seq: 0,
        recording: Recording::with_capacity(capacity),
        inner: s,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream::{self, StreamExt};
    use std::time::Duration;

    #[tokio::test]
    async fn passthrough_unchanged() {
        let input = vec![1, 2, 3, 4, 5];
        let recorded = record(stream::iter(input.clone()));
        let output: Vec<_> = recorded.collect().await;
        assert_eq!(output, input);
    }

    #[tokio::test]
    async fn replay_order() {
        for input in [vec![1, 2, 3], vec![5, 4, 3, 2, 1], vec![42], vec![]] {
            let mut recorded = record(stream::iter(input.clone()));
            while recorded.next().await.is_some() {}
            let result: Vec<_> = recorded.recording().replay().collect().await;
            assert_eq!(result, input);
        }
    }

    #[tokio::test]
    async fn replay_from_seq() {
        let mut recorded = record(stream::iter(vec![10, 20, 30, 40, 50]));
        while recorded.next().await.is_some() {}
        let result: Vec<_> = recorded.recording().replay_from(2).collect().await;
        assert_eq!(result, vec![30, 40, 50]);
    }

    #[tokio::test]
    async fn replay_range_bounds() {
        let mut recorded = record(stream::iter(vec![10, 20, 30, 40, 50]));
        while recorded.next().await.is_some() {}
        let result: Vec<_> = recorded.recording().replay_range(1, 3).collect().await;
        assert_eq!(result, vec![20, 30, 40]);
    }

    #[tokio::test]
    async fn capacity_bounds_storage() {
        let mut recorded = record_with_capacity(stream::iter(1..=5), 3);
        while recorded.next().await.is_some() {}
        let result: Vec<_> = recorded.recording().replay_immediate().collect().await;
        assert_eq!(result, vec![3, 4, 5]);
    }

    #[tokio::test]
    async fn immediate_replay_no_delays() {
        let start = tokio::time::Instant::now();
        let mut recorded = record(stream::iter(1..=100));
        while recorded.next().await.is_some() {}
        let _: Vec<_> = recorded.recording().replay_immediate().collect().await;
        assert!(start.elapsed() < Duration::from_millis(100));
    }

    #[tokio::test]
    async fn persistence_roundtrip() {
        let input = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let mut recorded = record(stream::iter(input.clone()));
        while recorded.next().await.is_some() {}
        // Use the platform temp dir instead of a hard-coded `/tmp`, and make the file
        // name per-process so concurrent test runs cannot clobber each other.
        let path =
            std::env::temp_dir().join(format!("sturgeon_test_{}.bin", std::process::id()));
        recorded.recording().save(&path).await.unwrap();
        let loaded = Recording::<String>::load(&path).await;
        // Remove the file before asserting so a failure does not leave it behind.
        let _ = tokio::fs::remove_file(&path).await;
        let loaded = loaded.unwrap();
        let result: Vec<_> = loaded.replay_immediate().collect().await;
        assert_eq!(result, input);
    }

    #[test]
    fn speed_validates() {
        assert!(Speed::new(1.0).is_ok());
        assert!(Speed::new(0.5).is_ok());
        assert!(Speed::new(2.0).is_ok());
        assert!(Speed::new(0.0).is_err());
        assert!(Speed::new(-1.0).is_err());
        assert!(Speed::new(-0.1).is_err());
    }

    #[test]
    fn speed_const_normal() {
        assert_eq!(Speed::NORMAL.as_f64(), 1.0);
    }
}