use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
/// Bit flags stored in `CheckpointBarrier::flags`.
///
/// Flags are combined with bitwise OR; `NONE` means no bit is set.
pub mod flags {
    /// No flags set.
    pub const NONE: u64 = 0;
    /// Bit 0, tested by `CheckpointBarrier::is_full_snapshot`.
    pub const FULL_SNAPSHOT: u64 = 0b001;
    /// Bit 1, tested by `CheckpointBarrier::is_drain`.
    pub const DRAIN: u64 = 0b010;
    /// Bit 2, tested by `CheckpointBarrier::is_cancel`.
    pub const CANCEL: u64 = 0b100;
}
/// A checkpoint barrier: a checkpoint id, an epoch, and a set of `flags::*` bits.
///
/// `#[repr(C)]` with three `u64` fields gives a fixed 24-byte layout; the
/// compile-time assertion below the struct enforces it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(C)]
pub struct CheckpointBarrier {
    /// Identifier of the checkpoint this barrier belongs to.
    pub checkpoint_id: u64,
    /// Epoch value the barrier was created with.
    pub epoch: u64,
    /// Bitwise OR of constants from the `flags` module.
    pub flags: u64,
}

// Breaks the build if the layout ever stops being exactly three `u64`s.
const _: () = assert!(std::mem::size_of::<CheckpointBarrier>() == 24);

impl CheckpointBarrier {
    /// Creates a barrier with an arbitrary combination of `flags::*` bits.
    ///
    /// `new` and `full_snapshot` are shorthands for this constructor; use it
    /// directly to build e.g. `flags::DRAIN` or `flags::CANCEL` barriers.
    #[must_use]
    pub const fn with_flags(checkpoint_id: u64, epoch: u64, flags: u64) -> Self {
        Self {
            checkpoint_id,
            epoch,
            flags,
        }
    }

    /// Creates a barrier with no flags set (`flags::NONE`).
    #[must_use]
    pub const fn new(checkpoint_id: u64, epoch: u64) -> Self {
        Self::with_flags(checkpoint_id, epoch, flags::NONE)
    }

    /// Creates a barrier with only `flags::FULL_SNAPSHOT` set.
    #[must_use]
    pub const fn full_snapshot(checkpoint_id: u64, epoch: u64) -> Self {
        Self::with_flags(checkpoint_id, epoch, flags::FULL_SNAPSHOT)
    }

    /// Returns `true` if `flags::FULL_SNAPSHOT` is set.
    #[must_use]
    pub const fn is_full_snapshot(&self) -> bool {
        self.has_flag(flags::FULL_SNAPSHOT)
    }

    /// Returns `true` if `flags::DRAIN` is set.
    #[must_use]
    pub const fn is_drain(&self) -> bool {
        self.has_flag(flags::DRAIN)
    }

    /// Returns `true` if `flags::CANCEL` is set.
    #[must_use]
    pub const fn is_cancel(&self) -> bool {
        self.has_flag(flags::CANCEL)
    }

    /// Returns `true` if any bit of `flag` is set in `self.flags`.
    const fn has_flag(&self, flag: u64) -> bool {
        self.flags & flag != 0
    }
}
/// An item flowing through a stream: a data event, a watermark, or a checkpoint barrier.
#[derive(Debug, Clone, PartialEq)]
pub enum StreamMessage<T> {
    /// A data event carrying a payload of type `T`.
    Event(T),
    /// A watermark carrying an `i64` value.
    Watermark(i64),
    /// A checkpoint barrier.
    Barrier(CheckpointBarrier),
}

impl<T> StreamMessage<T> {
    /// Returns `true` only for the `Barrier` variant.
    #[must_use]
    pub const fn is_barrier(&self) -> bool {
        match self {
            Self::Barrier(_) => true,
            Self::Event(_) | Self::Watermark(_) => false,
        }
    }

    /// Returns `true` only for the `Watermark` variant.
    #[must_use]
    pub const fn is_watermark(&self) -> bool {
        match self {
            Self::Watermark(_) => true,
            Self::Event(_) | Self::Barrier(_) => false,
        }
    }

    /// Returns `true` only for the `Event` variant.
    #[must_use]
    pub const fn is_event(&self) -> bool {
        match self {
            Self::Event(_) => true,
            Self::Watermark(_) | Self::Barrier(_) => false,
        }
    }

    /// Borrows the contained barrier, or returns `None` for any other variant.
    #[must_use]
    pub const fn as_barrier(&self) -> Option<&CheckpointBarrier> {
        if let Self::Barrier(barrier) = self {
            Some(barrier)
        } else {
            None
        }
    }
}
/// Bit 31 of the packed command word. `trigger` always sets it so that a pending
/// command is never `0`, which is reserved to mean "nothing pending".
const PENDING_MARKER: u64 = 1 << 31;
/// Mask selecting the flag bits (0..=30) of the packed command word.
const PACKED_FLAGS_MASK: u64 = PENDING_MARKER - 1;

/// Producer side of a single-slot barrier command channel.
///
/// `trigger` publishes a (checkpoint id, flags) pair into a shared `AtomicU64`;
/// any [`BarrierPollHandle`] obtained from `handle` can claim it with
/// `BarrierPollHandle::poll`. Only one command is held at a time: triggering
/// again before a poll overwrites the pending command. Clones share the same
/// command slot and the same trigger counter.
#[derive(Debug, Clone)]
pub struct CheckpointBarrierInjector {
    /// Packed pending command (`0` = none); bit layout is documented on `trigger`.
    cmd: Arc<AtomicU64>,
    /// Number of `trigger` calls made through this injector or its clones.
    epoch: Arc<AtomicU64>,
}

impl CheckpointBarrierInjector {
    /// Creates an injector with no pending command and a trigger count of zero.
    #[must_use]
    pub fn new() -> Self {
        Self {
            cmd: Arc::new(AtomicU64::new(0)),
            epoch: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Returns a poll handle that shares this injector's command slot.
    #[must_use]
    pub fn handle(&self) -> BarrierPollHandle {
        BarrierPollHandle {
            cmd: Arc::clone(&self.cmd),
        }
    }

    /// Publishes a barrier command for pollers to claim.
    ///
    /// The command is packed into one `u64`:
    /// - bits 32..=63: low 32 bits of `checkpoint_id`
    /// - bit 31: pending marker, always set, so that even `trigger(0, flags::NONE)`
    ///   is distinguishable from "no command"
    /// - bits 0..=30: low 31 bits of `barrier_flags`
    ///
    /// A command that has not been polled yet is overwritten. The counter
    /// returned by `epoch` is incremented afterwards.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `checkpoint_id` does not fit in 32 bits or if
    /// `barrier_flags` uses bit 31 or higher. In release builds those high bits
    /// are silently discarded.
    pub fn trigger(&self, checkpoint_id: u64, barrier_flags: u64) {
        debug_assert!(
            u32::try_from(checkpoint_id).is_ok(),
            "checkpoint_id {checkpoint_id} exceeds u32::MAX"
        );
        debug_assert!(
            barrier_flags & !PACKED_FLAGS_MASK == 0,
            "barrier_flags {barrier_flags:#x} exceeds the 31 bits available for flags"
        );
        // `<< 32` discards the id's high half; masking keeps the flags out of the
        // marker bit and the id field.
        let packed = (checkpoint_id << 32) | PENDING_MARKER | (barrier_flags & PACKED_FLAGS_MASK);
        // Release pairs with the Acquire half of the `swap` in `poll`.
        self.cmd.store(packed, Ordering::Release);
        self.epoch.fetch_add(1, Ordering::Relaxed);
    }

    /// Returns how many times `trigger` has been called on this injector or its clones.
    #[must_use]
    pub fn epoch(&self) -> u64 {
        self.epoch.load(Ordering::Relaxed)
    }
}

impl Default for CheckpointBarrierInjector {
    fn default() -> Self {
        Self::new()
    }
}

/// Consumer side of the command slot shared with a [`CheckpointBarrierInjector`].
///
/// Several handles (and their clones) may poll the same slot. Each stored
/// command is handed to at most one caller of `poll`.
#[derive(Debug, Clone)]
pub struct BarrierPollHandle {
    cmd: Arc<AtomicU64>,
}

impl BarrierPollHandle {
    /// Claims the pending command, if any, and returns it as a barrier stamped
    /// with the caller-supplied `epoch`.
    ///
    /// Returns `None` when no command is pending or when another handle claimed
    /// it first. A successful poll clears the slot.
    #[must_use]
    pub fn poll(&self, epoch: u64) -> Option<CheckpointBarrier> {
        // Read-only fast path: avoid a read-modify-write on every call when
        // nothing is pending.
        if self.cmd.load(Ordering::Relaxed) == 0 {
            return None;
        }
        // `swap` claims whatever is stored now, including a newer command written
        // after the load above. It yields 0 if another handle won the race.
        let packed = self.cmd.swap(0, Ordering::AcqRel);
        if packed == 0 {
            return None;
        }
        Some(CheckpointBarrier {
            checkpoint_id: packed >> 32,
            epoch,
            // Strip the pending marker; only the caller's flag bits remain.
            flags: packed & PACKED_FLAGS_MASK,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Runtime mirror of the compile-time size assertion on `CheckpointBarrier`.
    #[test]
    fn test_barrier_size() {
        let expected_bytes = 24;
        assert_eq!(std::mem::size_of::<CheckpointBarrier>(), expected_bytes);
    }

    // Each `is_*` predicate reports only the bits present in `flags`.
    #[test]
    fn test_barrier_flags() {
        let plain = CheckpointBarrier::new(1, 1);
        assert!(!plain.is_full_snapshot());
        assert!(!plain.is_drain());
        assert!(!plain.is_cancel());

        let snapshot = CheckpointBarrier::full_snapshot(1, 1);
        assert!(snapshot.is_full_snapshot());
        assert!(!snapshot.is_drain());

        let draining = CheckpointBarrier {
            flags: flags::DRAIN,
            ..CheckpointBarrier::new(1, 1)
        };
        assert!(draining.is_drain());
    }

    // A triggered command comes back intact exactly once.
    #[test]
    fn test_barrier_roundtrip_via_injector() {
        let injector = CheckpointBarrierInjector::new();
        let poller = injector.handle();
        injector.trigger(42, flags::DRAIN);

        let received = poller.poll(0).expect("barrier should be available");
        assert_eq!(
            (received.checkpoint_id, received.flags),
            (42, flags::DRAIN)
        );
        assert!(poller.poll(1).is_none(), "cleared after one poll");
    }

    // Variant predicates and the barrier accessor agree with the constructed variant.
    #[test]
    fn test_stream_message_variants() {
        type Msg = StreamMessage<String>;

        let event = Msg::Event(String::from("hello"));
        assert!(event.is_event());
        assert!(!event.is_barrier());
        assert!(!event.is_watermark());

        let watermark = Msg::Watermark(1000);
        assert!(watermark.is_watermark());

        let barrier = Msg::Barrier(CheckpointBarrier::new(1, 1));
        assert!(barrier.is_barrier());
        assert_eq!(barrier.as_barrier().map(|b| b.checkpoint_id), Some(1));
    }

    // A fresh injector has nothing to poll.
    #[test]
    fn test_injector_poll_no_barrier() {
        let poller = CheckpointBarrierInjector::new().handle();
        assert!(poller.poll(0).is_none());
    }

    // Triggering bumps the counter and the poller stamps the supplied epoch.
    #[test]
    fn test_injector_trigger_and_poll() {
        let injector = CheckpointBarrierInjector::new();
        let poller = injector.handle();
        injector.trigger(42, flags::FULL_SNAPSHOT);
        assert_eq!(injector.epoch(), 1);

        let received = poller.poll(1).unwrap();
        assert_eq!(received.checkpoint_id, 42);
        assert_eq!(received.epoch, 1);
        assert!(received.is_full_snapshot());
        assert!(poller.poll(1).is_none());
    }

    // Two handles on one injector: exactly one of them claims the barrier.
    #[test]
    fn test_injector_multiple_handles() {
        let injector = CheckpointBarrierInjector::new();
        let (first, second) = (injector.handle(), injector.handle());
        injector.trigger(1, flags::NONE);

        let got_first = first.poll(1).is_some();
        let got_second = second.poll(1).is_some();
        assert_ne!(got_first, got_second);
    }
}