use std::collections::HashMap;
use std::pin::Pin;
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use bytes::Bytes;
use futures_util::stream::{Stream, unfold};
use futures_util::{StreamExt, stream};
use tokio::sync::{broadcast, mpsc};
use crate::error::MetadataLogError;
/// One event as delivered to subscribers: where it lives in the log plus its
/// payload bytes.
#[derive(Debug, Clone)]
pub struct MetadataEventRecord {
/// Partition the event was published to.
pub partition: i32,
/// Position of the event within its partition. The in-process log in this
/// file assigns offsets densely starting at 0.
pub offset: i64,
/// Event body, passed through unchanged from the publisher.
pub payload: Bytes,
}
/// Boxed, pinned, `Send + 'static` stream of [`MetadataEventRecord`]s, as
/// returned by [`MetadataEventLog::subscribe`].
pub type MetadataEventStream = Pin<Box<dyn Stream<Item = MetadataEventRecord> + Send + 'static>>;
/// Request to consume one partition beginning at a given offset.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PartitionStart {
/// Partition to consume.
pub partition: i32,
/// First offset the consumer wants to receive. The in-process log in this
/// file treats negative values as 0.
pub start_offset: i64,
}
/// Control handle for changing which partitions an existing subscription
/// receives, without re-subscribing.
pub trait AssignmentHandle: Send + Sync {
/// Starts delivering `start.partition` from `start.start_offset`.
/// The in-process implementation in this file ignores the call when the
/// partition is already assigned or does not exist.
fn add(&self, start: PartitionStart);
/// Stops delivering new live records for `partition`.
fn remove(&self, partition: i32);
/// Returns the currently assigned partition ids. The in-process
/// implementation in this file returns them sorted ascending.
fn assigned(&self) -> Vec<i32>;
}
/// A partitioned, append-only log of metadata events with per-partition
/// offsets and streaming subscriptions.
#[async_trait]
pub trait MetadataEventLog: Send + Sync {
/// Number of partitions in the log.
fn partition_count(&self) -> i32;
/// Appends `event` to `partition` and returns the offset it was stored at.
/// The in-process implementation in this file fails with
/// `MetadataLogError::PartitionOutOfRange` for an invalid partition.
async fn publish(&self, partition: i32, event: Bytes) -> Result<i64, MetadataLogError>;
/// Opens a stream over the partitions listed in `assignment`, each starting
/// at its requested offset, and returns it together with a handle for
/// changing the assignment later.
fn subscribe(
&self,
assignment: Vec<PartitionStart>,
) -> (MetadataEventStream, Arc<dyn AssignmentHandle>);
/// Returns one value per partition, indexed by partition id. In the
/// in-process implementation in this file each value is the number of
/// records stored, i.e. the next offset to be assigned.
async fn high_water_marks(&self) -> Result<Vec<i64>, MetadataLogError>;
}
/// In-memory [`MetadataEventLog`] that keeps every published event in process
/// memory and fans new events out over a tokio broadcast channel.
pub struct InProcessMetadataEventLog {
/// Shared state; also held by every assignment handle created by
/// `subscribe`.
inner: Arc<InProcessInner>,
}
/// Per-partition delivery state of one subscription.
#[derive(Debug, Clone, Copy)]
struct PartitionCursor {
/// Lowest offset accepted from the live broadcast; broadcast records with
/// a smaller offset are dropped by `filtered_broadcast`.
next: i64,
/// When `true`, accepted live records are re-sent through the
/// subscription's inject channel instead of being emitted directly.
/// Set for partitions attached via `AssignmentHandle::add`.
via_inject: bool,
}
/// State shared between a subscription's stream and its assignment handle.
struct SubscriptionState {
/// Currently assigned partitions and their cursors.
assigned: Mutex<HashMap<i32, PartitionCursor>>,
/// Sending side of the channel that feeds records into the stream: the
/// backlog of partitions added after subscribing, and live records for
/// cursors marked `via_inject`.
inject: mpsc::UnboundedSender<MetadataEventRecord>,
}
/// Shared core of the in-process log.
struct InProcessInner {
/// Stored events: outer index is the partition, inner index is the offset.
log: Mutex<Vec<Vec<Bytes>>>,
/// Broadcast sender on which every published record is announced.
tx: broadcast::Sender<MetadataEventRecord>,
/// Number of partitions, fixed at construction.
partition_count: i32,
/// Live subscriptions keyed by subscription id; an entry is removed when
/// its assignment handle is dropped.
subscriptions: Mutex<HashMap<u64, Arc<SubscriptionState>>>,
/// Source of subscription ids.
next_sub_id: AtomicU64,
}
impl InProcessMetadataEventLog {
    /// Number of records the broadcast channel buffers for live fan-out.
    const BROADCAST_CAPACITY: usize = 1024;

    /// Creates an empty log with `partition_count` partitions.
    ///
    /// # Panics
    ///
    /// Panics if `partition_count` is zero or negative.
    #[must_use]
    pub fn new(partition_count: i32) -> Arc<Self> {
        assert!(partition_count > 0, "partition_count must be positive");
        let partitions =
            usize::try_from(partition_count).expect("partition_count fits in usize");
        // The initial receiver is not needed; each subscriber creates its own.
        let (tx, _unused_rx) = broadcast::channel(Self::BROADCAST_CAPACITY);
        let inner = InProcessInner {
            log: Mutex::new(vec![Vec::new(); partitions]),
            tx,
            partition_count,
            subscriptions: Mutex::new(HashMap::new()),
            next_sub_id: AtomicU64::new(0),
        };
        Arc::new(Self {
            inner: Arc::new(inner),
        })
    }
}
#[async_trait]
impl MetadataEventLog for InProcessMetadataEventLog {
    /// Returns the number of partitions fixed at construction.
    fn partition_count(&self) -> i32 {
        self.inner.partition_count
    }

    /// Appends `event` to `partition` and returns the offset it was stored at.
    ///
    /// The record is announced on the broadcast channel while the log mutex is
    /// still held, so announcements are sent in the same order as offsets are
    /// assigned.
    ///
    /// # Errors
    ///
    /// Returns `MetadataLogError::PartitionOutOfRange` when `partition` is
    /// negative or not below the partition count.
    async fn publish(&self, partition: i32, event: Bytes) -> Result<i64, MetadataLogError> {
        if partition < 0 || partition >= self.inner.partition_count {
            return Err(MetadataLogError::PartitionOutOfRange {
                partition,
                count: self.inner.partition_count,
            });
        }
        let mut guard = self.inner.log.lock().expect("metadata-log mutex poisoned");
        let idx = usize::try_from(partition).expect("partition non-negative");
        let log_for_p = &mut guard[idx];
        let offset = i64::try_from(log_for_p.len()).expect("offset fits in i64");
        log_for_p.push(event.clone());
        let record = MetadataEventRecord {
            partition,
            offset,
            payload: event,
        };
        // `send` only fails when there are no receivers, which is not an error
        // here: the record is already stored and later subscribers replay it.
        let _ = self.inner.tx.send(record);
        Ok(offset)
    }

    /// Opens a subscription over the partitions in `assignment`.
    ///
    /// The returned stream first yields the stored records of each assigned
    /// partition from its `start_offset` (negative values count as 0), then
    /// live records. The history snapshot and the broadcast receiver are both
    /// taken under the log mutex, so no record published concurrently is
    /// missed or delivered twice.
    ///
    /// Entries naming a partition outside the log are skipped. If a partition
    /// appears more than once, only its first entry is used. When
    /// `start_offset` lies beyond the current end of the partition, delivery
    /// begins once the log reaches that offset.
    fn subscribe(
        &self,
        assignment: Vec<PartitionStart>,
    ) -> (MetadataEventStream, Arc<dyn AssignmentHandle>) {
        use std::sync::atomic::Ordering;

        let guard = self.inner.log.lock().expect("metadata-log mutex poisoned");
        let rx = self.inner.tx.subscribe();
        let mut assigned: HashMap<i32, PartitionCursor> = HashMap::new();
        let mut snapshot: Vec<MetadataEventRecord> = Vec::new();
        for ps in &assignment {
            let Ok(idx) = usize::try_from(ps.partition) else {
                continue;
            };
            let Some(records) = guard.get(idx) else {
                continue;
            };
            // A repeated partition must not replay its history a second time;
            // this mirrors `AssignmentHandle::add`, which ignores partitions
            // that are already assigned.
            if assigned.contains_key(&ps.partition) {
                continue;
            }
            let first_wanted = ps.start_offset.max(0);
            let begin = usize::try_from(first_wanted).unwrap_or(usize::MAX);
            snapshot.extend(records.iter().enumerate().skip(begin).map(
                |(offset, payload)| MetadataEventRecord {
                    partition: ps.partition,
                    offset: i64::try_from(offset).expect("offset fits in i64"),
                    payload: payload.clone(),
                },
            ));
            let high_water = i64::try_from(records.len()).expect("len fits in i64");
            assigned.insert(
                ps.partition,
                PartitionCursor {
                    // Everything below the high-water mark is covered by the
                    // snapshot. Never go below the requested start, otherwise
                    // records in `high_water..start_offset` would leak through
                    // the live path.
                    next: high_water.max(first_wanted),
                    via_inject: false,
                },
            );
        }

        let (inject_tx, inject_rx) = mpsc::unbounded_channel::<MetadataEventRecord>();
        let state = Arc::new(SubscriptionState {
            assigned: Mutex::new(assigned),
            inject: inject_tx,
        });
        let sub_id = self.inner.next_sub_id.fetch_add(1, Ordering::Relaxed);
        self.inner
            .subscriptions
            .lock()
            .expect("metadata-log subscriptions mutex poisoned")
            .insert(sub_id, Arc::clone(&state));
        drop(guard);

        // Snapshot first, then the merge of injected records and the filtered
        // live broadcast.
        let snapshot_stream = stream::iter(snapshot);
        let inject_stream = unfold(inject_rx, |mut rx| async move {
            rx.recv().await.map(|r| (r, rx))
        });
        let live = filtered_broadcast(rx, Arc::clone(&state));
        let merged = stream::select(inject_stream, live);
        let stream = snapshot_stream.chain(merged).boxed();

        let handle: Arc<dyn AssignmentHandle> = Arc::new(InProcessAssignmentHandle {
            inner: self.inner.clone(),
            sub_id,
        });
        (stream, handle)
    }

    /// Returns, per partition, the number of stored records, which is also
    /// the next offset `publish` will assign.
    async fn high_water_marks(&self) -> Result<Vec<i64>, MetadataLogError> {
        let guard = self.inner.log.lock().expect("metadata-log mutex poisoned");
        Ok(guard
            .iter()
            .map(|v| i64::try_from(v.len()).expect("hwm fits in i64"))
            .collect())
    }
}
/// [`AssignmentHandle`] for a subscription of the in-process log. Dropping it
/// removes the subscription's entry from the log's subscription map.
struct InProcessAssignmentHandle {
/// Shared log state the subscription belongs to.
inner: Arc<InProcessInner>,
/// Key of this subscription in `inner.subscriptions`.
sub_id: u64,
}
impl Drop for InProcessAssignmentHandle {
    /// Removes this subscription's state from the shared map.
    ///
    /// A poisoned map mutex is ignored rather than unwrapped so that dropping
    /// the handle never panics.
    fn drop(&mut self) {
        let Ok(mut registry) = self.inner.subscriptions.lock() else {
            return;
        };
        registry.remove(&self.sub_id);
    }
}
impl AssignmentHandle for InProcessAssignmentHandle {
    /// Attaches `start.partition` to the subscription.
    ///
    /// Stored records from `start.start_offset` (negative values count as 0)
    /// are pushed into the subscription's inject channel, and the partition's
    /// later live records are routed through that same channel, so the
    /// backlog is delivered before any newer record of the partition.
    ///
    /// The call is a no-op when the subscription is gone, the partition is
    /// already assigned, or the partition does not exist. When `start_offset`
    /// lies beyond the current end of the partition, delivery begins once the
    /// log reaches that offset.
    fn add(&self, start: PartitionStart) {
        let subs = self
            .inner
            .subscriptions
            .lock()
            .expect("metadata-log subscriptions mutex poisoned");
        let Some(state) = subs.get(&self.sub_id).cloned() else {
            return;
        };
        // Release the map before taking the log mutex: `subscribe` locks the
        // log first and the map second, so holding both here in the opposite
        // order could deadlock.
        drop(subs);

        // Holding the log mutex keeps publishers out while the backlog is
        // copied and the cursor installed.
        let log = self.inner.log.lock().expect("metadata-log mutex poisoned");
        let mut assigned = state.assigned.lock().expect("assigned mutex poisoned");
        if assigned.contains_key(&start.partition) {
            return;
        }
        let Ok(idx) = usize::try_from(start.partition) else {
            return;
        };
        let Some(records) = log.get(idx) else {
            return;
        };

        let first_wanted = start.start_offset.max(0);
        let begin = usize::try_from(first_wanted).unwrap_or(usize::MAX);
        for (offset, payload) in records.iter().enumerate().skip(begin) {
            // A send error means the stream side is gone; nothing to deliver to.
            let _ = state.inject.send(MetadataEventRecord {
                partition: start.partition,
                offset: i64::try_from(offset).expect("offset fits in i64"),
                payload: payload.clone(),
            });
        }

        let high_water = i64::try_from(records.len()).expect("len fits in i64");
        assigned.insert(
            start.partition,
            PartitionCursor {
                // Everything below the high-water mark was just injected.
                // Never go below the requested start, otherwise records in
                // `high_water..start_offset` would leak through the live path.
                next: high_water.max(first_wanted),
                via_inject: true,
            },
        );
    }

    /// Detaches `partition`, so the live filter drops its later records.
    /// Records already queued for the stream are not recalled.
    fn remove(&self, partition: i32) {
        let subs = self
            .inner
            .subscriptions
            .lock()
            .expect("metadata-log subscriptions mutex poisoned");
        if let Some(state) = subs.get(&self.sub_id) {
            state
                .assigned
                .lock()
                .expect("assigned mutex poisoned")
                .remove(&partition);
        }
    }

    /// Returns the assigned partition ids in ascending order, or an empty
    /// vector when the subscription is gone.
    fn assigned(&self) -> Vec<i32> {
        let subs = self
            .inner
            .subscriptions
            .lock()
            .expect("metadata-log subscriptions mutex poisoned");
        let Some(state) = subs.get(&self.sub_id) else {
            return Vec::new();
        };
        let mut partitions: Vec<i32> = state
            .assigned
            .lock()
            .expect("assigned mutex poisoned")
            .keys()
            .copied()
            .collect();
        partitions.sort_unstable();
        partitions
    }
}
/// Routing decision made by `filtered_broadcast` for one live record.
enum Forward {
/// Yield the record directly from the live stream.
Emit,
/// Re-send the record through the subscription's inject channel.
Inject,
/// Discard the record (partition not assigned, or offset below the cursor).
Drop,
}
fn filtered_broadcast(
rx: broadcast::Receiver<MetadataEventRecord>,
state: Arc<SubscriptionState>,
) -> MetadataEventStream {
unfold((rx, state), |(mut rx, state)| async move {
loop {
match rx.recv().await {
Ok(record) => {
let action = {
let assigned = state.assigned.lock().expect("assigned mutex poisoned");
match assigned.get(&record.partition) {
Some(cur) if record.offset >= cur.next => {
if cur.via_inject {
Forward::Inject
} else {
Forward::Emit
}
}
_ => Forward::Drop,
}
};
match action {
Forward::Emit => return Some((record, (rx, state))),
Forward::Inject => {
let _ = state.inject.send(record);
}
Forward::Drop => {}
}
}
Err(broadcast::error::RecvError::Lagged(_)) => {
}
Err(broadcast::error::RecvError::Closed) => return None,
}
}
})
.boxed()
}
// Unit tests for the in-process metadata event log: offset assignment,
// history replay, live forwarding, and assignment changes via the handle.
#[cfg(test)]
mod tests {
use super::*;
use assert2::assert;
use futures_util::StreamExt;
// Offsets are assigned per partition starting at 0, and the high-water
// marks report the number of records in each partition.
#[tokio::test]
async fn publish_assigns_monotonic_offsets() {
let log = InProcessMetadataEventLog::new(2);
assert!(log.publish(0, Bytes::from_static(b"a")).await.unwrap() == 0);
assert!(log.publish(0, Bytes::from_static(b"b")).await.unwrap() == 1);
assert!(log.publish(1, Bytes::from_static(b"c")).await.unwrap() == 0);
let hwms = log.high_water_marks().await.unwrap();
assert!(hwms == vec![2, 1]);
}
// A subscriber first receives the stored records, then records published
// after it subscribed, with the offset continuing from the history.
#[tokio::test]
async fn subscribe_replays_history_then_forwards_new_writes() {
let log = InProcessMetadataEventLog::new(1);
log.publish(0, Bytes::from_static(b"a")).await.unwrap();
log.publish(0, Bytes::from_static(b"b")).await.unwrap();
let (mut stream, _h) = log.subscribe(vec![PartitionStart {
partition: 0,
start_offset: 0,
}]);
let a = stream.next().await.unwrap();
let b = stream.next().await.unwrap();
assert!(a.payload.as_ref() == b"a");
assert!(b.payload.as_ref() == b"b");
log.publish(0, Bytes::from_static(b"c")).await.unwrap();
let c = stream.next().await.unwrap();
assert!(c.payload.as_ref() == b"c");
assert!((c.partition, c.offset) == (0, 2));
}
// Subscribing after several publishes still yields the full history in
// offset order.
#[tokio::test]
async fn subscribe_attached_after_history_still_sees_history() {
let log = InProcessMetadataEventLog::new(1);
for i in 0..5 {
log.publish(0, Bytes::copy_from_slice(&[i])).await.unwrap();
}
let (mut stream, _h) = log.subscribe(vec![PartitionStart {
partition: 0,
start_offset: 0,
}]);
for i in 0..5 {
let r = stream.next().await.unwrap();
assert!(r.payload.as_ref() == &[i]);
assert!(r.offset == i64::from(i));
}
}
// Publishing to a partition id at or above the partition count fails.
#[tokio::test]
async fn publish_out_of_range_is_rejected() {
let log = InProcessMetadataEventLog::new(2);
let err = log.publish(5, Bytes::from_static(b"x")).await.unwrap_err();
assert!(matches!(err, MetadataLogError::PartitionOutOfRange { .. }));
}
// Independent subscribers each see the same history followed by the same
// live record.
#[tokio::test]
async fn two_subscribers_see_the_same_history() {
let log = InProcessMetadataEventLog::new(1);
log.publish(0, Bytes::from_static(b"a")).await.unwrap();
let (mut s1, _h1) = log.subscribe(vec![PartitionStart {
partition: 0,
start_offset: 0,
}]);
let (mut s2, _h2) = log.subscribe(vec![PartitionStart {
partition: 0,
start_offset: 0,
}]);
log.publish(0, Bytes::from_static(b"b")).await.unwrap();
for s in [&mut s1, &mut s2] {
assert!(s.next().await.unwrap().payload.as_ref() == b"a");
assert!(s.next().await.unwrap().payload.as_ref() == b"b");
}
}
// Only assigned partitions are replayed, each from its own start offset;
// the unassigned partition 2 never appears.
#[tokio::test]
async fn subscribe_delivers_only_assigned_partitions_from_start_offset() {
let log = InProcessMetadataEventLog::new(3);
for p0 in [b"a".as_slice(), b"b", b"c"] {
log.publish(0, Bytes::copy_from_slice(p0)).await.unwrap();
}
for p1 in [b"x".as_slice(), b"y"] {
log.publish(1, Bytes::copy_from_slice(p1)).await.unwrap();
}
log.publish(2, Bytes::from_static(b"z")).await.unwrap();
let (mut stream, _h) = log.subscribe(vec![
PartitionStart {
partition: 0,
start_offset: 1,
},
PartitionStart {
partition: 1,
start_offset: 0,
},
]);
let mut got: Vec<(i32, i64, Vec<u8>)> = Vec::new();
for _ in 0..3 {
let r = stream.next().await.unwrap();
got.push((r.partition, r.offset, r.payload.to_vec()));
}
got.sort();
assert!(
got == vec![
(0, 1, b"b".to_vec()),
(0, 2, b"c".to_vec()),
(1, 0, b"x".to_vec()),
]
);
let r = stream.next().await.unwrap();
assert!((r.partition, r.offset, r.payload.as_ref()) == (1, 1, b"y".as_ref()));
}
// Live records for partitions that are not assigned are filtered out.
#[tokio::test]
async fn live_appends_only_for_assigned_partitions() {
let log = InProcessMetadataEventLog::new(2);
let (mut stream, _h) = log.subscribe(vec![PartitionStart {
partition: 0,
start_offset: 0,
}]);
log.publish(1, Bytes::from_static(b"skip")).await.unwrap();
log.publish(0, Bytes::from_static(b"keep")).await.unwrap();
let r = stream.next().await.unwrap();
assert!((r.partition, r.payload.as_ref()) == (0, b"keep".as_ref()));
}
// Adding a partition through the handle delivers its stored backlog in
// offset order before the record published after the add.
#[tokio::test]
async fn add_mid_stream_delivers_backlog_then_live() {
let log = InProcessMetadataEventLog::new(2);
for v in [b"old0".as_slice(), b"old1", b"old2"] {
log.publish(1, Bytes::copy_from_slice(v)).await.unwrap();
}
let (mut stream, handle) = log.subscribe(vec![PartitionStart {
partition: 0,
start_offset: 0,
}]);
handle.add(PartitionStart {
partition: 1,
start_offset: 0,
});
log.publish(1, Bytes::from_static(b"new")).await.unwrap();
let mut got = Vec::new();
for _ in 0..4 {
let r = stream.next().await.unwrap();
got.push((r.partition, r.offset, r.payload.to_vec()));
}
assert!(
got == vec![
(1, 0, b"old0".to_vec()),
(1, 1, b"old1".to_vec()),
(1, 2, b"old2".to_vec()),
(1, 3, b"new".to_vec()),
],
"backlog must drain fully (in offset order) before the live append"
);
assert!(handle.assigned().contains(&1));
}
// Dropping the assignment handle removes the subscription's entry from
// the log's subscription map.
#[tokio::test]
async fn dropping_handle_evicts_subscription_state() {
let log = InProcessMetadataEventLog::new(1);
let (_stream, handle) = log.subscribe(vec![PartitionStart {
partition: 0,
start_offset: 0,
}]);
assert!(log.inner.subscriptions.lock().unwrap().len() == 1);
drop(handle);
assert!(
log.inner.subscriptions.lock().unwrap().len() == 0,
"subscription state must be evicted when the handle drops"
);
}
// After a partition is removed, its later records are no longer delivered
// while the remaining partition keeps flowing.
#[tokio::test]
async fn remove_stops_delivery() {
let log = InProcessMetadataEventLog::new(2);
let (mut stream, handle) = log.subscribe(vec![
PartitionStart {
partition: 0,
start_offset: 0,
},
PartitionStart {
partition: 1,
start_offset: 0,
},
]);
handle.remove(1);
assert!(handle.assigned() == vec![0]);
log.publish(1, Bytes::from_static(b"gone")).await.unwrap();
log.publish(0, Bytes::from_static(b"here")).await.unwrap();
let r = stream.next().await.unwrap();
assert!((r.partition, r.payload.as_ref()) == (0, b"here".as_ref()));
}
}