use std::collections::VecDeque;
use std::sync::Arc;
use epics_base_rs::server::database::PvDatabase;
use epics_ca_rs::server::{CaServer, ServerConnectionEvent};
use tokio::sync::Mutex;
use tokio::sync::broadcast;
use crate::error::BridgeResult;
/// A connection event tagged with the sequence number the forwarder task
/// assigned to it (the first forwarded event is numbered `1`).
pub type SeqConnEvent = (u64, ServerConnectionEvent);
/// Maximum number of sequenced events kept for lag recovery (2^12).
const REPLAY_LOG_CAPACITY: usize = 1 << 12;
/// Capacity of the fan-out broadcast channel subscribers read from (2^10).
const REPLAY_CHANNEL_CAPACITY: usize = 1 << 10;
/// Bounded, sequence-ordered history of forwarded connection events.
///
/// Subscribers that fall behind the broadcast channel consult this log to
/// recover the events they missed.
pub struct ConnEventReplay {
    /// Highest sequence number assigned so far, including numbers skipped
    /// when the forwarder itself lagged behind the raw server stream.
    high_water: std::sync::atomic::AtomicU64,
    /// The most recent `REPLAY_LOG_CAPACITY` events, oldest at the front.
    log: parking_lot::Mutex<VecDeque<SeqConnEvent>>,
}
impl ConnEventReplay {
    /// Creates an empty log with room for `REPLAY_LOG_CAPACITY` events and a
    /// high-water mark of `0` (nothing assigned yet).
    fn new() -> Self {
        Self {
            log: parking_lot::Mutex::new(VecDeque::with_capacity(REPLAY_LOG_CAPACITY)),
            high_water: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Appends `ev`, evicting the oldest entry once the log is full, then
    /// raises the high-water mark to the event's sequence number.
    ///
    /// Events must be recorded in strictly increasing sequence order:
    /// `events_since` relies on the log being sorted by sequence number.
    fn record(&self, ev: SeqConnEvent) {
        let seq = ev.0;
        let mut log = self.log.lock();
        if log.len() == REPLAY_LOG_CAPACITY {
            log.pop_front();
        }
        log.push_back(ev);
        drop(log);
        self.advance_high_water(seq);
    }

    /// Raises the high-water mark to `seq` if that is larger; never lowers it.
    fn advance_high_water(&self, seq: u64) {
        self.high_water
            .fetch_max(seq, std::sync::atomic::Ordering::SeqCst);
    }

    /// Highest sequence number assigned so far (`0` before the first event).
    pub fn high_water(&self) -> u64 {
        self.high_water.load(std::sync::atomic::Ordering::SeqCst)
    }

    /// Returns every retained event with a sequence number greater than
    /// `after`, oldest first, plus a `truncated` flag.
    ///
    /// `truncated` is `true` when the log is non-empty and its oldest entry is
    /// newer than `after + 1`. In that case, the sequence numbers in between
    /// are not in the log: they were evicted, or never recorded because the
    /// forwarder lagged.
    pub fn events_since(&self, after: u64) -> (Vec<SeqConnEvent>, bool) {
        let log = self.log.lock();
        let truncated = matches!(
            log.front(),
            Some((oldest, _)) if *oldest > after.saturating_add(1)
        );
        // The log is sorted by sequence number (see `record`). A binary search
        // therefore finds the first unseen entry without walking the whole
        // ring while holding the lock that `record` also needs.
        let start = log.partition_point(|(seq, _)| *seq <= after);
        let missed = log.range(start..).cloned().collect();
        (missed, truncated)
    }
}
/// Subscriber handle for connection events that recovers from broadcast lag
/// by replaying missed events from the shared [`ConnEventReplay`] log.
pub struct ReplayingReceiver {
    /// Shared replay log, consulted when the broadcast channel reports lag.
    replay: Arc<ConnEventReplay>,
    /// Live feed of sequenced events from the forwarder task.
    rx: broadcast::Receiver<SeqConnEvent>,
    /// Events recovered from the replay log, waiting to be handed out.
    pending: VecDeque<ServerConnectionEvent>,
    /// Sequence number of the newest event already delivered or queued.
    /// Broadcast events at or below it are dropped as duplicates.
    last_seq: u64,
}
/// Outcome of [`ReplayingReceiver::recv`].
#[derive(Debug)]
pub enum ConnEventRecv {
    /// The next connection event, in sequence order.
    Event(ServerConnectionEvent),
    /// The forwarding channel has shut down and nothing is left queued.
    Closed,
    /// Some events could not be recovered from the bounded replay log.
    /// The events that *were* recovered are returned by the following calls.
    GapTruncated {
        /// Number of sequence numbers that could not be recovered.
        missed: u64,
    },
}
impl ReplayingReceiver {
pub async fn recv(&mut self) -> ConnEventRecv {
loop {
if let Some(ev) = self.pending.pop_front() {
return ConnEventRecv::Event(ev);
}
match self.rx.recv().await {
Ok((seq, ev)) => {
if seq <= self.last_seq {
continue;
}
self.last_seq = seq;
return ConnEventRecv::Event(ev);
}
Err(broadcast::error::RecvError::Lagged(_)) => {
let (missed, truncated) = self.replay.events_since(self.last_seq);
if truncated {
let recovered_lo =
missed.first().map(|(s, _)| *s).unwrap_or(self.last_seq + 1);
let lost = recovered_lo.saturating_sub(self.last_seq + 1);
for (seq, ev) in missed {
self.last_seq = seq;
self.pending.push_back(ev);
}
return ConnEventRecv::GapTruncated { missed: lost };
}
for (seq, ev) in missed {
self.last_seq = seq;
self.pending.push_back(ev);
}
continue;
}
Err(broadcast::error::RecvError::Closed) => {
return ConnEventRecv::Closed;
}
}
}
}
}
/// Downstream CA server serving the gateway's shadow database, plus a lazily
/// created, replay-capable fan-out of the server's connection events.
pub struct DownstreamServer {
    /// Database the server serves; shared with whoever constructed us.
    shadow_db: Arc<PvDatabase>,
    /// The installed server. `None` after `run` has taken it, until
    /// `reinstall` puts another one back.
    server: Mutex<Option<CaServer>>,
    /// Connection-event forwarding state. Created by the first
    /// `connection_events` call and torn down by `stop_connection_events`.
    replay_state: Mutex<Option<ReplayState>>,
}
/// Live connection-event forwarding state held by a [`DownstreamServer`].
struct ReplayState {
    /// Task that numbers raw server events and feeds `replay` and `tx`.
    forwarder: tokio::task::JoinHandle<()>,
    /// Replay log shared with every [`ReplayingReceiver`].
    replay: Arc<ConnEventReplay>,
    /// Fan-out sender; each receiver is created with `subscribe`.
    tx: broadcast::Sender<SeqConnEvent>,
}
impl DownstreamServer {
    /// Creates a downstream server for `shadow_db` on `port`.
    ///
    /// The server is built but not started; call [`DownstreamServer::run`].
    pub fn new(shadow_db: Arc<PvDatabase>, port: u16) -> Self {
        let server = CaServer::from_parts(shadow_db.clone(), port, None, None, None, None);
        Self::from_server(shadow_db, server)
    }

    /// Like [`DownstreamServer::new`], but calls `set_tls` with `tls` on the
    /// server before wrapping it.
    #[cfg(feature = "ca-gateway-tls")]
    pub fn new_with_tls(
        shadow_db: Arc<PvDatabase>,
        port: u16,
        tls: std::sync::Arc<epics_ca_rs::tls::ServerConfig>,
    ) -> Self {
        let mut server = CaServer::from_parts(shadow_db.clone(), port, None, None, None, None);
        server.set_tls(tls);
        Self::from_server(shadow_db, server)
    }

    /// Wraps an already-built `server` together with the database it serves.
    fn from_server(shadow_db: Arc<PvDatabase>, server: CaServer) -> Self {
        Self {
            server: Mutex::new(Some(server)),
            shadow_db,
            replay_state: Mutex::new(None),
        }
    }

    /// Returns the shadow database this server serves.
    pub fn database(&self) -> &Arc<PvDatabase> {
        &self.shadow_db
    }

    /// Subscribes to the downstream server's connection events.
    ///
    /// The first call subscribes to the installed server's raw event stream
    /// and spawns a forwarder. The forwarder numbers each event and records it
    /// in a replay log. Every call returns a fresh [`ReplayingReceiver`] that
    /// only yields events numbered above the forwarder's high-water mark at
    /// subscription time.
    ///
    /// If the previous forwarder has exited (its upstream stream closed, e.g.
    /// the server it was bound to has been dropped), its state is discarded.
    /// Receivers bound to it then observe [`ConnEventRecv::Closed`] after
    /// draining what they already had, and forwarding is rebuilt against the
    /// currently installed server.
    ///
    /// Returns `None` when forwarding must be (re)initialised but no server is
    /// installed (e.g. `run` has taken it).
    pub async fn connection_events(&self) -> Option<ReplayingReceiver> {
        let mut replay_guard = self.replay_state.lock().await;
        // A finished forwarder can never deliver again. Keeping its `tx`
        // alive would leave old receivers hanging forever and hand new
        // receivers a dead stream, so drop it and rebind below.
        if matches!(replay_guard.as_ref(), Some(state) if state.forwarder.is_finished()) {
            *replay_guard = None;
        }
        if replay_guard.is_none() {
            let raw_rx = {
                let mut server_guard = self.server.lock().await;
                match server_guard.as_mut() {
                    Some(s) => s.connection_events(),
                    None => return None,
                }
            };
            let (tx, _) = broadcast::channel::<SeqConnEvent>(REPLAY_CHANNEL_CAPACITY);
            let replay = Arc::new(ConnEventReplay::new());
            let forwarder = spawn_conn_event_forwarder(raw_rx, tx.clone(), replay.clone());
            *replay_guard = Some(ReplayState {
                tx,
                replay,
                forwarder,
            });
        }
        let state = replay_guard.as_ref().expect("just initialised");
        // Field order matters: subscribe *before* reading the high-water mark.
        // The forwarder raises the mark before it broadcasts, so every event
        // numbered above the value read here is sent after this subscription
        // exists and cannot slip past the receiver.
        Some(ReplayingReceiver {
            rx: state.tx.subscribe(),
            replay: state.replay.clone(),
            last_seq: state.replay.high_water(),
            pending: VecDeque::new(),
        })
    }

    /// Tears down connection-event forwarding.
    ///
    /// Aborts the forwarder task and drops the fan-out sender, so existing
    /// receivers eventually observe [`ConnEventRecv::Closed`]. A later
    /// `connection_events` call starts a new forwarder with a fresh log.
    pub async fn stop_connection_events(&self) {
        if let Some(state) = self.replay_state.lock().await.take() {
            state.forwarder.abort();
        }
    }

    /// Returns the installed server's beacon-anomaly handle, or `None` when no
    /// server is installed (e.g. while `run` owns it).
    pub async fn beacon_anomaly_handle(&self) -> Option<Arc<tokio::sync::Notify>> {
        let guard = self.server.lock().await;
        guard.as_ref().map(|s| s.beacon_anomaly_handle())
    }

    /// Takes the installed server and runs it to completion.
    ///
    /// The server lock is released before the server starts running.
    ///
    /// # Errors
    ///
    /// Returns an error if no server is installed (already running or
    /// consumed), or if the server's own `run` fails.
    // NOTE(review): both failures are reported as `BridgeError::PutRejected`,
    // which reads like a put-specific variant; confirm whether a more fitting
    // `BridgeError` variant exists.
    pub async fn run(&self) -> BridgeResult<()> {
        let server = {
            let mut guard = self.server.lock().await;
            match guard.take() {
                Some(s) => s,
                None => {
                    return Err(crate::error::BridgeError::PutRejected(
                        "DownstreamServer already running or consumed".into(),
                    ));
                }
            }
        };
        server
            .run()
            .await
            .map_err(|e| crate::error::BridgeError::PutRejected(format!("CaServer run: {e}")))
    }

    /// Installs `server` as the one `run` will start next. Returns the server
    /// that was still installed, if any.
    ///
    /// Running connection-event forwarding stays bound to the server it was
    /// created from. Once that forwarder exits, the next `connection_events`
    /// call rebinds to the server installed here.
    pub async fn reinstall(&self, server: CaServer) -> Option<CaServer> {
        let mut guard = self.server.lock().await;
        let prev = guard.take();
        *guard = Some(server);
        prev
    }
}
fn spawn_conn_event_forwarder(
mut raw_rx: broadcast::Receiver<ServerConnectionEvent>,
tx: broadcast::Sender<SeqConnEvent>,
replay: Arc<ConnEventReplay>,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let mut seq: u64 = 0;
loop {
match raw_rx.recv().await {
Ok(ev) => {
seq += 1;
let sequenced = (seq, ev);
replay.record(sequenced.clone());
let _ = tx.send(sequenced);
}
Err(broadcast::error::RecvError::Lagged(n)) => {
seq += n;
replay.advance_high_water(seq);
tracing::warn!(
missed = n,
"ca-gateway-rs: raw connection-event broadcast lagged at \
the forwarder — these events cannot be replayed"
);
}
Err(broadcast::error::RecvError::Closed) => {
break;
}
}
}
})
}
#[cfg(test)]
mod tests {
    //! Unit tests for the connection-event replay log, the replaying
    //! receiver, the forwarder task, and the `DownstreamServer` wrapper.
    use super::*;

    // The wrapper must hand back the very `Arc` it was constructed with.
    #[tokio::test]
    async fn construct_downstream() {
        let db = Arc::new(PvDatabase::new());
        let downstream = DownstreamServer::new(db.clone(), 0);
        assert!(Arc::ptr_eq(downstream.database(), &db));
    }

    // With the server still installed (not yet run), subscribing succeeds.
    #[tokio::test]
    async fn connection_events_subscribe() {
        let db = Arc::new(PvDatabase::new());
        let downstream = DownstreamServer::new(db, 0);
        let rx = downstream.connection_events().await;
        assert!(rx.is_some(), "expected receiver");
        downstream.stop_connection_events().await;
    }

    // Builds a `ChannelCreated` event. The tests below use `cid` as an
    // ordering tag and `pv_name` as a batch tag.
    fn ev(pv: &str, cid: u32) -> ServerConnectionEvent {
        ServerConnectionEvent::ChannelCreated {
            peer: "127.0.0.1:5064".parse().unwrap(),
            pv_name: pv.to_string(),
            cid,
        }
    }

    // `events_since(n)` returns exactly the entries numbered above `n`, in
    // order, and reports no truncation while entry `n + 1` is still retained.
    #[test]
    fn replay_log_records_and_queries() {
        let log = ConnEventReplay::new();
        for i in 1..=5u64 {
            log.record((i, ev("PV", i as u32)));
        }
        let (all, truncated) = log.events_since(0);
        assert_eq!(all.len(), 5);
        assert!(!truncated);
        let (tail, truncated) = log.events_since(3);
        assert_eq!(tail.len(), 2);
        assert_eq!(tail[0].0, 4);
        assert_eq!(tail[1].0, 5);
        assert!(!truncated);
        let (none, _) = log.events_since(5);
        assert!(none.is_empty());
    }

    // Overfilling the ring by 100 entries evicts the oldest ones. A query
    // from 0 gets only the retained window and is flagged as truncated. A
    // query that starts inside the window is not flagged.
    #[test]
    fn replay_log_is_bounded_and_reports_truncation() {
        let log = ConnEventReplay::new();
        let total = (REPLAY_LOG_CAPACITY + 100) as u64;
        for i in 1..=total {
            log.record((i, ev("PV", (i % 1000) as u32)));
        }
        let (recovered, truncated) = log.events_since(0);
        assert_eq!(recovered.len(), REPLAY_LOG_CAPACITY);
        assert!(truncated, "gap past ring capacity must report truncation");
        assert_eq!(recovered[0].0, total - REPLAY_LOG_CAPACITY as u64 + 1);
        let (tail, truncated) = log.events_since(total - 10);
        assert_eq!(tail.len(), 10);
        assert!(!truncated);
    }

    // No lag: events come straight off the broadcast channel in order. Once
    // the sender is dropped, the receiver reports `Closed`.
    #[tokio::test]
    async fn replaying_receiver_delivers_events_in_order() {
        let replay = Arc::new(ConnEventReplay::new());
        let (tx, rx) = broadcast::channel::<SeqConnEvent>(REPLAY_CHANNEL_CAPACITY);
        let mut recv = ReplayingReceiver {
            rx,
            replay: replay.clone(),
            last_seq: 0,
            pending: VecDeque::new(),
        };
        for i in 1..=3u64 {
            let sequenced = (i, ev("PV", i as u32));
            replay.record(sequenced.clone());
            tx.send(sequenced).unwrap();
        }
        for i in 1..=3u32 {
            match recv.recv().await {
                ConnEventRecv::Event(ServerConnectionEvent::ChannelCreated { cid, .. }) => {
                    assert_eq!(cid, i);
                }
                other => panic!("expected event {i}, got {other:?}"),
            }
        }
        drop(tx);
        assert!(matches!(recv.recv().await, ConnEventRecv::Closed));
    }

    // All 20 events are sent before any is read, into a capacity-4 channel,
    // which forces `Lagged`. The receiver must backfill from the replay log
    // and still deliver 1..=20 in order without reporting truncation.
    #[tokio::test]
    async fn replaying_receiver_recovers_lagged_events() {
        let replay = Arc::new(ConnEventReplay::new());
        let (tx, rx) = broadcast::channel::<SeqConnEvent>(4);
        let mut recv = ReplayingReceiver {
            rx,
            replay: replay.clone(),
            last_seq: 0,
            pending: VecDeque::new(),
        };
        const N: u32 = 20;
        for i in 1..=N {
            let sequenced = (i as u64, ev("PV", i));
            replay.record(sequenced.clone());
            let _ = tx.send(sequenced);
        }
        let mut seen = Vec::new();
        for _ in 0..N {
            match recv.recv().await {
                ConnEventRecv::Event(ServerConnectionEvent::ChannelCreated { cid, .. }) => {
                    seen.push(cid)
                }
                ConnEventRecv::GapTruncated { missed } => {
                    panic!("unexpected truncation, missed={missed}")
                }
                other => panic!("unexpected recv outcome: {other:?}"),
            }
        }
        assert_eq!(
            seen,
            (1..=N).collect::<Vec<_>>(),
            "lagged events not fully replayed"
        );
        drop(tx);
        assert!(matches!(recv.recv().await, ConnEventRecv::Closed));
    }

    // The 10 `OLD` events are recorded before the receiver exists, and the
    // receiver is seeded at the high-water mark. Even when lag forces a
    // backfill from the log, only the post-subscription `NEW` events may be
    // delivered.
    #[tokio::test]
    async fn late_subscriber_does_not_replay_pre_subscription_backlog() {
        let replay = Arc::new(ConnEventReplay::new());
        let (tx, _keepalive) = broadcast::channel::<SeqConnEvent>(4);
        const PRE: u32 = 10;
        for i in 1..=PRE {
            replay.record((i as u64, ev("OLD", i)));
        }
        assert_eq!(replay.high_water(), PRE as u64);
        let mut recv = ReplayingReceiver {
            rx: tx.subscribe(),
            replay: replay.clone(),
            last_seq: replay.high_water(),
            pending: VecDeque::new(),
        };
        const NEW: u32 = 20;
        for i in (PRE + 1)..=(PRE + NEW) {
            let sequenced = (i as u64, ev("NEW", i));
            replay.record(sequenced.clone());
            let _ = tx.send(sequenced);
        }
        let mut seen = Vec::new();
        for _ in 0..NEW {
            match recv.recv().await {
                ConnEventRecv::Event(ServerConnectionEvent::ChannelCreated {
                    cid,
                    pv_name,
                    ..
                }) => {
                    assert_eq!(
                        pv_name, "NEW",
                        "must not replay pre-subscription `OLD` events"
                    );
                    seen.push(cid);
                }
                ConnEventRecv::GapTruncated { missed } => {
                    panic!("unexpected truncation, missed={missed}")
                }
                other => panic!("unexpected recv outcome: {other:?}"),
            }
        }
        assert_eq!(
            seen,
            ((PRE + 1)..=(PRE + NEW)).collect::<Vec<_>>(),
            "late subscriber must replay only post-subscription events"
        );
        drop(tx);
        assert!(matches!(recv.recv().await, ConnEventRecv::Closed));
    }

    // The test records events directly into the live replay log, standing in
    // for the forwarder. A second `connection_events` call must seed its
    // receiver from the resulting high-water mark.
    #[tokio::test]
    async fn connection_events_seeds_late_subscriber_from_high_water() {
        let db = Arc::new(PvDatabase::new());
        let downstream = DownstreamServer::new(db, 0);
        let _first = downstream
            .connection_events()
            .await
            .expect("first receiver");
        {
            let guard = downstream.replay_state.lock().await;
            let state = guard.as_ref().expect("replay state installed");
            for i in 1..=7u64 {
                state.replay.record((i, ev("PV", i as u32)));
            }
        }
        let late = downstream.connection_events().await.expect("late receiver");
        assert_eq!(
            late.last_seq, 7,
            "late subscriber must be seeded from the forwarder high-water mark"
        );
        downstream.stop_connection_events().await;
    }

    // End-to-end through the real forwarder. The forwarder must number the
    // raw events 1..=N, and the receiver must deliver all of them even though
    // the capacity-4 fan-out channel cannot hold them.
    #[tokio::test]
    async fn forwarder_sequences_and_replays() {
        const N: u32 = 30;
        let (raw_tx, raw_rx) = broadcast::channel::<ServerConnectionEvent>(N as usize);
        let (tx, _keepalive) = broadcast::channel::<SeqConnEvent>(4);
        let replay = Arc::new(ConnEventReplay::new());
        let forwarder = spawn_conn_event_forwarder(raw_rx, tx.clone(), replay.clone());
        let mut recv = ReplayingReceiver {
            rx: tx.subscribe(),
            replay: replay.clone(),
            last_seq: 0,
            pending: VecDeque::new(),
        };
        for i in 1..=N {
            raw_tx.send(ev("PV", i)).unwrap();
        }
        // Yield long enough for the spawned forwarder to push all N events
        // through, so the receiver's first `recv` should start out lagged.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        let mut seen = Vec::new();
        for _ in 0..N {
            match recv.recv().await {
                ConnEventRecv::Event(ServerConnectionEvent::ChannelCreated { cid, .. }) => {
                    seen.push(cid)
                }
                other => panic!("unexpected recv outcome: {other:?}"),
            }
        }
        assert_eq!(seen, (1..=N).collect::<Vec<_>>());
        forwarder.abort();
    }
}