use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::Notify;
use super::frame::FrameRecord;
use super::rlog::RlogReader;
/// Groups the frames of a recorded log by `rhythm_id` and owns the shared
/// sequence counter that per-domain tasks use (through [`TurnToken`]s) to take
/// turns.
pub struct ReplayOrchestrator {
    /// Sequence number whose turn it currently is; shared with every `TurnToken`.
    current_seq: Arc<AtomicU64>,
    /// Notifier shared with every `TurnToken`, used to wake waiters on advance.
    advance_notify: Arc<Notify>,
    /// Frames grouped by `rhythm_id`, each group sorted by `global_seq`.
    domain_frames: HashMap<u16, Vec<FrameRecord>>,
}

impl ReplayOrchestrator {
    /// Builds an orchestrator from every frame in `rlog`.
    ///
    /// Frames are cloned, grouped by `rhythm_id`, and sorted by `global_seq`
    /// within each group. The sort is stable, so frames with equal
    /// `global_seq` keep their log order. The shared counter starts at the
    /// smallest `global_seq` in the log, or at 0 if the log is empty.
    ///
    /// NOTE(review): tokens advance the counter by exactly one per call, so
    /// this presumably assumes `global_seq` values are contiguous. A gap would
    /// leave every waiter for a later sequence parked. Confirm against the log
    /// writer.
    pub fn from_rlog(rlog: &RlogReader) -> Self {
        let mut domain_frames: HashMap<u16, Vec<FrameRecord>> = HashMap::new();
        for frame in &rlog.frames {
            domain_frames
                .entry(frame.rhythm_id)
                .or_default()
                .push(frame.clone());
        }
        for frames in domain_frames.values_mut() {
            frames.sort_by_key(|f| f.global_seq);
        }
        let start_seq = rlog.frames.iter().map(|f| f.global_seq).min().unwrap_or(0);
        Self {
            current_seq: Arc::new(AtomicU64::new(start_seq)),
            advance_notify: Arc::new(Notify::new()),
            domain_frames,
        }
    }

    /// Issues a token for `rhythm_id`.
    ///
    /// The token shares this orchestrator's counter and notifier. No check is
    /// made that `rhythm_id` actually has frames.
    pub fn turn_token(&self, rhythm_id: u16) -> TurnToken {
        TurnToken {
            current_seq: self.current_seq.clone(),
            advance_notify: self.advance_notify.clone(),
            rhythm_id,
        }
    }

    /// Removes and returns the frames for `rhythm_id`, sorted by `global_seq`.
    ///
    /// Returns an empty `Vec` if the id is unknown or its frames were already
    /// taken.
    pub fn take_frames(&mut self, rhythm_id: u16) -> Vec<FrameRecord> {
        self.domain_frames.remove(&rhythm_id).unwrap_or_default()
    }

    /// Returns the `rhythm_id`s whose frames have not been taken yet, in
    /// ascending order.
    ///
    /// The list is sorted so that callers iterating over it, for example to
    /// spawn one task per domain, see the same order on every run.
    /// `HashMap` iteration order is randomized per process.
    pub fn rhythm_ids(&self) -> Vec<u16> {
        let mut ids: Vec<u16> = self.domain_frames.keys().copied().collect();
        ids.sort_unstable();
        ids
    }
}
/// Handle a per-domain task uses to wait for its turn in the shared replay
/// sequence and to hand the turn on.
///
/// Clones share the same counter and notifier as the orchestrator that issued
/// the token.
#[derive(Clone)]
pub struct TurnToken {
    current_seq: Arc<AtomicU64>,
    advance_notify: Arc<Notify>,
    rhythm_id: u16,
}

impl TurnToken {
    /// Returns the `rhythm_id` this token was issued for.
    pub fn rhythm_id(&self) -> u16 {
        self.rhythm_id
    }

    /// Waits until the shared sequence counter equals `expected_seq`.
    ///
    /// Returns immediately if it already does. The counter is only ever
    /// incremented, so if it has already moved past `expected_seq` this
    /// future never completes.
    pub async fn wait_turn(&self, expected_seq: u64) {
        loop {
            // Register for wakeups *before* reading the counter.
            // `notify_waiters` only wakes `Notified` futures that already
            // exist. Creating the future after the check would miss an
            // `advance()` that lands between the load and the registration,
            // and the task could then sleep forever.
            let notified = self.advance_notify.notified();
            if self.current_seq.load(Ordering::SeqCst) == expected_seq {
                return;
            }
            notified.await;
        }
    }

    /// Ends the current turn.
    ///
    /// Increments the shared counter by one, then wakes every task currently
    /// waiting in `wait_turn` so each one re-checks the counter.
    pub fn advance(&self) {
        self.current_seq.fetch_add(1, Ordering::SeqCst);
        self.advance_notify.notify_waiters();
    }
}