use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use atomr_core::actor::{Actor, ActorRef, Context, Props};
use parking_lot::Mutex;
use tokio::sync::oneshot;
#[cfg(feature = "replay")]
use atomr_persistence::{Journal, PersistentRepr};
/// Operating mode of a [`ReplayHarness`].
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare modes with
/// `==` / `assert_eq!` instead of going through `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReplayMode {
    /// Incoming `Record` messages are ignored and replay does nothing.
    Off,
    /// Incoming `Record` messages are appended to the journal.
    Record,
    /// The loaded journal may be replayed into a sink.
    Replay,
}
/// One recorded event in the replay journal.
///
/// `PartialEq`/`Eq` are derived so that snapshots can be compared directly
/// (for example with `assert_eq!`) rather than by matching on every variant.
/// With the `replay` feature enabled the type is also serde-serializable,
/// which is how entries are stored in the persistence journal.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "replay", derive(serde::Serialize, serde::Deserialize))]
pub enum JournalEntry {
    /// A command addressed to a device.
    DeviceCmd {
        /// Microseconds since the harness was created; the harness overwrites
        /// whatever the sender supplied when the entry is recorded.
        ts_micros: u64,
        /// NOTE(review): presumably the command name; the harness never inspects it.
        name: String,
        /// Opaque payload; the harness stores it untouched.
        payload: String,
    },
    /// A command addressed to a kernel.
    KernelCmd {
        /// Microseconds since the harness was created; the harness overwrites
        /// whatever the sender supplied when the entry is recorded.
        ts_micros: u64,
        /// NOTE(review): presumably the command kind; the harness never inspects it.
        kind: String,
        /// Opaque payload; the harness stores it untouched.
        payload: String,
    },
    /// An RNG seed associated with an actor path.
    RngSeed {
        actor_path: String,
        seed: u64,
    },
    /// A batch size associated with an actor path.
    BatchSize {
        actor_path: String,
        size: usize,
    },
}
/// Describes how a sink actor receives replayed journal entries.
///
/// An implementor names the message type its actor understands and explains
/// how a [`JournalEntry`] together with its acknowledgement channel is wrapped
/// into that message.
pub trait ReplaySink: Send + 'static {
    /// Message type delivered to the sink actor.
    type Msg: Send + 'static;

    /// Wraps `entry` and the `reply` acknowledgement sender into a sink message.
    ///
    /// `ReplayHarness::replay_all` waits on the matching receiver before it
    /// hands out the next entry, so the sink should fire (or drop) `reply`
    /// once it is done with `entry`.
    fn make_on_entry(entry: JournalEntry, reply: oneshot::Sender<()>) -> Self::Msg;
}
/// Messages understood by the [`ReplayHarness`] actor.
pub enum ReplayMsg {
    /// Append an entry to the journal; ignored unless the harness is in
    /// [`ReplayMode::Record`].
    Record(JournalEntry),
    /// Reply with a copy of the current in-memory journal.
    Snapshot { reply: oneshot::Sender<Vec<JournalEntry>> },
    /// Switch the harness to another mode.
    SetMode { mode: ReplayMode },
    /// Replace the in-memory journal with `entries`, then acknowledge on `reply`.
    LoadJournal {
        entries: Vec<JournalEntry>,
        reply: oneshot::Sender<()>,
    },
    /// Request a replay of the loaded journal.
    ///
    /// NOTE(review): the actor's handler currently performs no observable work
    /// for this message; replay into a sink goes through
    /// `ReplayHarness::replay_all`.
    ReplayAll,
    /// Replace the in-memory journal with entries read from the attached
    /// persistence journal, starting at `from_sequence_nr`.
    ///
    /// `max` is forwarded to the journal's `replay_messages` call
    /// (presumably an upper bound on the number of messages — confirm against
    /// the `Journal` trait). The reply carries the number of decoded entries
    /// or a description of the failure.
    #[cfg(feature = "replay")]
    LoadFromJournal {
        from_sequence_nr: u64,
        max: u64,
        reply: oneshot::Sender<Result<usize, String>>,
    },
}
/// Actor state for recording and replaying [`JournalEntry`] values.
pub struct ReplayHarness {
    /// Current operating mode; changed through `ReplayMsg::SetMode`.
    mode: ReplayMode,
    /// In-memory journal. The handle is shared, see [`ReplayHarness::journal`].
    journal: Arc<Mutex<Vec<JournalEntry>>>,
    /// Creation time of this harness; recorded timestamps are relative to it.
    started_at: Instant,
    /// Optional durable backend, present only when built via `with_journal`.
    #[cfg(feature = "replay")]
    persistence: Option<PersistenceState>,
}
/// Durable-journal wiring used when the `replay` feature is enabled.
#[cfg(feature = "replay")]
struct PersistenceState {
    /// Backend that entries are written to and read back from.
    journal: Arc<dyn Journal>,
    /// Persistence id under which this harness stores its entries.
    persistence_id: String,
    /// Last sequence number handed out by `write_to_journal`.
    ///
    /// Despite the name this is the most recently *used* number; `0` doubles
    /// as the "not yet initialised from the backend" marker.
    next_seq: Arc<Mutex<u64>>,
}
impl ReplayHarness {
    /// Builds [`Props`] for a harness that keeps its journal in memory only.
    ///
    /// Every actor instance created from the props starts with an empty
    /// journal and its own start-of-recording instant.
    pub fn props(mode: ReplayMode) -> Props<Self> {
        Props::create(move || Self {
            mode: mode.clone(),
            journal: Arc::new(Mutex::new(Vec::new())),
            started_at: Instant::now(),
            #[cfg(feature = "replay")]
            persistence: None,
        })
    }

    /// Builds [`Props`] for a harness that additionally mirrors recorded
    /// entries into `journal` under `persistence_id`.
    #[cfg(feature = "replay")]
    pub fn with_journal(
        mode: ReplayMode,
        journal: Arc<dyn Journal>,
        persistence_id: impl Into<String>,
    ) -> Props<Self> {
        let persistence_id: String = persistence_id.into();
        Props::create(move || {
            // Each instance gets its own sequence counter, starting at the
            // "uninitialised" marker 0.
            let backend = PersistenceState {
                journal: Arc::clone(&journal),
                persistence_id: persistence_id.clone(),
                next_seq: Arc::new(Mutex::new(0)),
            };
            Self {
                mode: mode.clone(),
                journal: Arc::new(Mutex::new(Vec::new())),
                started_at: Instant::now(),
                persistence: Some(backend),
            }
        })
    }

    /// Returns a shared handle to the in-memory journal.
    pub fn journal(&self) -> Arc<Mutex<Vec<JournalEntry>>> {
        Arc::clone(&self.journal)
    }

    /// Feeds every journal entry, in order, to `sink_fn`.
    ///
    /// Does nothing unless the harness is in [`ReplayMode::Replay`]. Each
    /// entry is handed over together with a one-shot sender; the next entry is
    /// only dispatched after that sender has fired or has been dropped.
    pub async fn replay_all<F>(&self, mut sink_fn: F)
    where
        F: FnMut(JournalEntry, oneshot::Sender<()>),
    {
        if matches!(self.mode, ReplayMode::Replay) {
            // Copy the journal in its own statement so the lock guard is
            // released before the first `.await` below.
            let snapshot = self.journal.lock().clone();
            for entry in snapshot {
                let (ack_tx, ack_rx) = oneshot::channel::<()>();
                sink_fn(entry, ack_tx);
                // A dropped sender counts as an acknowledgement too.
                let _ = ack_rx.await;
            }
        }
    }
}
/// Adapts a sink actor reference into the closure shape expected by
/// [`ReplayHarness::replay_all`].
///
/// For every entry the closure builds the sink's message through
/// [`ReplaySink::make_on_entry`] and sends it to `sink` with `tell`.
pub fn replay_via_sink<S: ReplaySink>(
    sink: ActorRef<S::Msg>,
) -> impl FnMut(JournalEntry, oneshot::Sender<()>) {
    move |entry: JournalEntry, ack: oneshot::Sender<()>| {
        let msg = S::make_on_entry(entry, ack);
        sink.tell(msg);
    }
}
#[async_trait]
impl Actor for ReplayHarness {
    type Msg = ReplayMsg;

    /// Processes one [`ReplayMsg`].
    ///
    /// * `Record` — in `Record` mode, stamps timestamped entries with the time
    ///   elapsed since the harness was created, appends the entry to the
    ///   in-memory journal and, when a persistence backend is attached,
    ///   mirrors it there (write failures are logged, never propagated).
    /// * `Snapshot` — replies with a copy of the in-memory journal.
    /// * `SetMode` — switches the operating mode.
    /// * `LoadJournal` — replaces the in-memory journal and acknowledges.
    /// * `ReplayAll` — currently a no-op, see the arm below.
    /// * `LoadFromJournal` — replaces the in-memory journal with entries
    ///   decoded from the persistence backend and replies with their count,
    ///   or with an error string if reading or decoding fails.
    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: ReplayMsg) {
        match msg {
            ReplayMsg::Record(mut entry) => {
                if !matches!(self.mode, ReplayMode::Record) {
                    return;
                }

                // Microseconds since harness creation; a u64 holds several
                // hundred thousand years of them, so the narrowing cast cannot
                // truncate in practice.
                let ts = self.started_at.elapsed().as_micros() as u64;
                if let JournalEntry::DeviceCmd { ts_micros, .. }
                | JournalEntry::KernelCmd { ts_micros, .. } = &mut entry
                {
                    *ts_micros = ts;
                }

                // Only pay for a copy when it will actually be written to the
                // persistence backend; otherwise the entry is moved straight
                // into the in-memory journal.
                #[cfg(feature = "replay")]
                let durable_copy = self.persistence.as_ref().map(|_| entry.clone());

                // The guard is a temporary of this statement, so the lock is
                // released before the `.await` below.
                self.journal.lock().push(entry);

                #[cfg(feature = "replay")]
                if let (Some(p), Some(entry)) = (&self.persistence, &durable_copy) {
                    // Best effort: the in-memory journal stays authoritative.
                    if let Err(e) = write_to_journal(p, entry).await {
                        tracing::warn!(
                            error = %e,
                            persistence_id = %p.persistence_id,
                            "ReplayHarness: persistence write failed"
                        );
                    }
                }
            }
            ReplayMsg::Snapshot { reply } => {
                // The requester may have gone away; nothing to do then.
                let _ = reply.send(self.journal.lock().clone());
            }
            ReplayMsg::SetMode { mode } => {
                self.mode = mode;
            }
            ReplayMsg::LoadJournal { entries, reply } => {
                *self.journal.lock() = entries;
                let _ = reply.send(());
            }
            ReplayMsg::ReplayAll => {
                // Intentionally a no-op. The actor state holds no sink to
                // dispatch entries to, so there is nothing to replay into;
                // the previous implementation copied the whole journal only to
                // iterate over it without doing anything. Replay into a sink
                // is performed through `ReplayHarness::replay_all`.
            }
            #[cfg(feature = "replay")]
            ReplayMsg::LoadFromJournal {
                from_sequence_nr,
                max,
                reply,
            } => {
                let p = match &self.persistence {
                    Some(p) => p,
                    None => {
                        let _ = reply.send(Err("no persistence backend attached".into()));
                        return;
                    }
                };
                match p
                    .journal
                    .replay_messages(&p.persistence_id, from_sequence_nr, u64::MAX, max)
                    .await
                {
                    Ok(reprs) => {
                        let mut decoded = Vec::with_capacity(reprs.len());
                        for r in &reprs {
                            match serde_json::from_slice::<JournalEntry>(&r.payload) {
                                Ok(e) => decoded.push(e),
                                Err(e) => {
                                    // Abort on the first undecodable record and
                                    // leave the in-memory journal untouched.
                                    let _ = reply
                                        .send(Err(format!("decode seq={}: {e}", r.sequence_nr)));
                                    return;
                                }
                            }
                        }
                        let n = decoded.len();
                        *self.journal.lock() = decoded;
                        let _ = reply.send(Ok(n));
                    }
                    Err(e) => {
                        let _ = reply.send(Err(format!("journal replay failed: {e}")));
                    }
                }
            }
        }
    }
}
/// Serializes `entry` as JSON and appends it to the persistence journal.
///
/// On the first call (counter still `0`) the counter is seeded from the
/// backend's highest stored sequence number so that new writes continue after
/// existing ones. The sequence number is reserved *before* the write, so a
/// failed write still consumes its number.
///
/// # Errors
///
/// Returns a prefixed description when serialization (`serde: …`), the
/// highest-sequence lookup (`highest_seq: …`) or the write itself
/// (`write_messages: …`) fails.
#[cfg(feature = "replay")]
async fn write_to_journal(p: &PersistenceState, entry: &JournalEntry) -> Result<(), String> {
    let payload = match serde_json::to_vec(entry) {
        Ok(bytes) => bytes,
        Err(e) => return Err(format!("serde: {e}")),
    };

    // The guard is a temporary of this statement, so no lock is held across
    // the `.await` that follows.
    let uninitialised = *p.next_seq.lock() == 0;
    if uninitialised {
        let lookup = p.journal.highest_sequence_nr(&p.persistence_id, 0).await;
        let highest = match lookup {
            Ok(n) => n,
            Err(e) => return Err(format!("highest_seq: {e}")),
        };
        let mut counter = p.next_seq.lock();
        // Re-check under the lock before seeding the counter.
        if *counter == 0 {
            *counter = highest;
        }
    }

    let sequence_nr = {
        let mut counter = p.next_seq.lock();
        *counter += 1;
        *counter
    };

    let repr = PersistentRepr {
        persistence_id: p.persistence_id.clone(),
        sequence_nr,
        payload,
        manifest: "atomr_accel_cuda::replay::JournalEntry".into(),
        writer_uuid: "atomr-accel-cuda".into(),
        deleted: false,
        tags: Vec::new(),
    };

    let outcome = p.journal.write_messages(vec![repr]).await;
    outcome.map_err(|e| format!("write_messages: {e}"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// An entry recorded in `Record` mode shows up in a later snapshot.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn record_then_snapshot() {
        let system = ActorSystem::create("replay-test", Config::empty())
            .await
            .unwrap();
        let harness = system
            .actor_of(ReplayHarness::props(ReplayMode::Record), "replay")
            .unwrap();

        let seed_entry = JournalEntry::RngSeed {
            actor_path: "test/rng".into(),
            seed: 42,
        };
        harness.tell(ReplayMsg::Record(seed_entry));
        // Give the actor time to process the record before asking for a snapshot.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let (snap_tx, snap_rx) = oneshot::channel();
        harness.tell(ReplayMsg::Snapshot { reply: snap_tx });
        let recorded = tokio::time::timeout(Duration::from_secs(2), snap_rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(recorded.len(), 1);

        system.terminate().await;
    }

    /// In `Off` mode a `Record` message leaves the journal empty.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn off_mode_drops_records() {
        let system = ActorSystem::create("replay-off", Config::empty())
            .await
            .unwrap();
        let harness = system
            .actor_of(ReplayHarness::props(ReplayMode::Off), "replay")
            .unwrap();

        let seed_entry = JournalEntry::RngSeed {
            actor_path: "test".into(),
            seed: 1,
        };
        harness.tell(ReplayMsg::Record(seed_entry));
        tokio::time::sleep(Duration::from_millis(50)).await;

        let (snap_tx, snap_rx) = oneshot::channel();
        harness.tell(ReplayMsg::Snapshot { reply: snap_tx });
        let recorded = tokio::time::timeout(Duration::from_secs(2), snap_rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(recorded.len(), 0);

        system.terminate().await;
    }

    /// A journal loaded via `LoadJournal` is returned unchanged, in order,
    /// by a subsequent snapshot.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn load_then_replay_via_sink() {
        let system = ActorSystem::create("replay-load", Config::empty())
            .await
            .unwrap();
        let harness = system
            .actor_of(ReplayHarness::props(ReplayMode::Replay), "replay")
            .unwrap();

        let seeded = vec![
            JournalEntry::RngSeed {
                actor_path: "a".into(),
                seed: 1,
            },
            JournalEntry::RngSeed {
                actor_path: "b".into(),
                seed: 2,
            },
        ];
        let (load_tx, load_rx) = oneshot::channel();
        harness.tell(ReplayMsg::LoadJournal {
            entries: seeded.clone(),
            reply: load_tx,
        });
        tokio::time::timeout(Duration::from_secs(2), load_rx)
            .await
            .unwrap()
            .unwrap();

        let (snap_tx, snap_rx) = oneshot::channel::<Vec<JournalEntry>>();
        harness.tell(ReplayMsg::Snapshot { reply: snap_tx });
        let loaded = tokio::time::timeout(Duration::from_secs(2), snap_rx)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(loaded.len(), 2);

        // Both entries must be RNG seeds, in the order they were loaded.
        let seeds: Vec<u64> = loaded
            .iter()
            .map(|entry| match entry {
                JournalEntry::RngSeed { seed, .. } => *seed,
                _ => panic!("unexpected journal contents"),
            })
            .collect();
        assert_eq!(seeds, vec![1, 2]);

        system.terminate().await;
    }
}