use alloc::vec::Vec;
use core::cmp::Ordering;
use bytes::BufMut;
/// Application hook for handling custom broadcast payloads.
///
/// `T` is the type passed as `sender` / `member` below.
/// NOTE(review): presumably the cluster member identity type — confirm.
pub trait BroadcastHandler<T> {
    /// Key extracted from a received item. It must implement [`Invalidates`]
    /// so that keys can be compared to decide which item supersedes another.
    type Key: Invalidates;

    /// Error returned by [`receive_item`](Self::receive_item); required to be a
    /// `Send + Sync + 'static` error type.
    type Error: core::error::Error + Send + Sync + 'static;

    /// Handles one received broadcast payload.
    ///
    /// `data` is the raw payload bytes. `sender` is optional.
    /// NOTE(review): presumably `None` when the origin is not known — confirm.
    ///
    /// NOTE(review): presumably `Ok(Some(key))` means the item should be
    /// (re)disseminated under `key` and `Ok(None)` means it is stale or
    /// ignored — verify against the caller, which is not in this file.
    fn receive_item(
        &mut self,
        data: &[u8],
        sender: Option<&T>,
    ) -> Result<Option<Self::Key>, Self::Error>;

    /// Whether broadcast data should be added for `member`.
    ///
    /// The default implementation always returns `true`.
    /// NOTE(review): the point where this is consulted is outside this file — confirm.
    fn should_add_broadcast_data(&self, _member: &T) -> bool {
        true
    }
}
/// Decides whether one broadcast key supersedes another.
///
/// When a new broadcast is queued, every pending broadcast whose key is
/// invalidated by the new key is discarded.
pub trait Invalidates {
    /// Returns `true` if `self` supersedes `other`.
    fn invalidates(&self, other: &Self) -> bool;
}

/// Raw byte keys supersede each other exactly when their contents match.
impl Invalidates for &[u8] {
    fn invalidates(&self, other: &Self) -> bool {
        // Peel both reference layers and compare the `[u8]` contents.
        **self == **other
    }
}
/// Priority queue of pending broadcast payloads.
///
/// Entries are popped highest-priority first: the most remaining
/// transmissions, then the largest payload (as ordered by `Entry`'s `Ord`).
// NOTE(review): the reason for this `allow` is not visible here — presumably
// some items go unused in certain build configurations; confirm and document.
#[allow(dead_code)]
pub(crate) struct Broadcasts<V> {
    /// The live queue: holds every pending entry between fill calls.
    flip: alloc::collections::BinaryHeap<Entry<V>>,
    /// Scratch heap used while filling a buffer. Entries with transmissions
    /// left are parked here and then merged back into `flip`. Expected to be
    /// empty outside of a fill.
    flop: alloc::collections::BinaryHeap<Entry<V>>,
}
impl<T> Broadcasts<T>
where
    T: Invalidates,
{
    /// Creates an empty broadcast queue.
    pub(crate) fn new() -> Self {
        Self {
            flip: Default::default(),
            flop: Default::default(),
        }
    }

    /// Number of broadcasts still pending transmission.
    pub(crate) fn len(&self) -> usize {
        // `flop` is only populated while `fill_entries` runs, so between calls
        // `flip` alone holds every pending entry.
        self.flip.len()
    }

    /// Returns `true` when there is nothing left to broadcast.
    pub(crate) fn is_empty(&self) -> bool {
        self.flip.is_empty()
    }

    /// Queues `data` under `item`, to be transmitted at most `max_tx` times.
    ///
    /// Every pending entry whose key is invalidated by `item` (per
    /// [`Invalidates::invalidates`]) is discarded first, so the new broadcast
    /// replaces the ones it supersedes.
    ///
    /// `max_tx` must be non-zero; this is only checked in debug builds.
    pub(crate) fn add_or_replace(&mut self, item: T, data: Vec<u8>, max_tx: usize) {
        debug_assert!(max_tx > 0);
        self.flip.retain(|node| !item.invalidates(&node.item));
        self.flip.push(Entry {
            remaining_tx: max_tx,
            item,
            data,
        });
    }

    /// Writes pending payloads into `buffer` back to back, with no framing.
    ///
    /// Entries are visited in priority order: the most remaining transmissions
    /// first, then the largest payload. An entry that does not fit in the space
    /// left is skipped, and smaller entries later in the queue may still be
    /// written. At most `max_items` payloads are written.
    ///
    /// Each written entry has its transmission budget decremented and is
    /// dropped once that budget reaches zero. Skipped entries stay queued
    /// unchanged.
    ///
    /// Returns the number of payloads written.
    pub(crate) fn fill(&mut self, mut buffer: impl BufMut, max_items: usize) -> usize {
        self.fill_entries(&mut buffer, max_items, |buf, data| {
            if buf.remaining_mut() < data.len() {
                return false;
            }
            buf.put_slice(data);
            true
        })
    }

    /// Like [`fill`](Self::fill), but frames each payload with a big-endian
    /// `u16` length prefix so the receiver can split the items apart.
    ///
    /// The two prefix bytes count against the available space. A payload
    /// longer than `u16::MAX` bytes cannot be framed. Debug builds panic on
    /// such a payload. Release builds skip it (it stays queued) rather than
    /// write a truncated length, which would corrupt every later frame.
    ///
    /// Returns the number of payloads written.
    pub(crate) fn fill_with_len_prefix(
        &mut self,
        mut buffer: impl BufMut,
        max_items: usize,
    ) -> usize {
        self.fill_entries(&mut buffer, max_items, |buf, data| {
            if buf.remaining_mut() < data.len() + 2 {
                return false;
            }
            let len = u16::try_from(data.len());
            debug_assert!(
                len.is_ok(),
                "broadcast payload too large for a u16 length prefix"
            );
            // Never emit a truncated prefix: the receiver would misparse this
            // item and everything after it.
            let Ok(len) = len else {
                return false;
            };
            buf.put_u16(len);
            buf.put_slice(data);
            true
        })
    }

    /// Shared loop behind [`fill`](Self::fill) and
    /// [`fill_with_len_prefix`](Self::fill_with_len_prefix).
    ///
    /// Pops entries in priority order until the buffer is full, `max_items`
    /// have been written, or the queue is exhausted. `write` returns `true`
    /// only if it wrote the payload. Entries with transmissions left are parked
    /// in `flop` and merged back into `flip` at the end.
    fn fill_entries<B: BufMut>(
        &mut self,
        buffer: &mut B,
        max_items: usize,
        mut write: impl FnMut(&mut B, &[u8]) -> bool,
    ) -> usize {
        if self.flip.is_empty() {
            return 0;
        }
        debug_assert!(self.flop.is_empty());

        let mut num_taken = 0;
        while buffer.has_remaining_mut() && num_taken < max_items {
            let Some(mut node) = self.flip.pop() else {
                break;
            };
            debug_assert!(node.remaining_tx > 0);
            if write(&mut *buffer, node.data.as_slice()) {
                num_taken += 1;
                node.remaining_tx -= 1;
            }
            if node.remaining_tx > 0 {
                self.flop.push(node);
            }
        }
        self.flip.append(&mut self.flop);
        num_taken
    }
}
/// A queued broadcast payload with its key and remaining transmission budget.
#[derive(Debug, Clone)]
struct Entry<T> {
    /// How many more times this payload may be written out.
    remaining_tx: usize,
    /// The payload bytes copied into outgoing buffers.
    data: Vec<u8>,
    /// Key used to decide which queued entries a new broadcast replaces.
    item: T,
}

impl<T> Entry<T> {
    /// Heap priority: transmission budget first, then payload size.
    /// `item` deliberately takes no part in comparisons.
    fn priority(&self) -> (usize, usize) {
        (self.remaining_tx, self.data.len())
    }
}

impl<T> PartialEq for Entry<T> {
    fn eq(&self, other: &Self) -> bool {
        self.priority() == other.priority()
    }
}

impl<T> Eq for Entry<T> {}

impl<T> PartialOrd for Entry<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<T> Ord for Entry<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        // Tuples compare lexicographically: budget decides, size breaks ties.
        self.priority().cmp(&other.priority())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test key: two keys invalidate each other when their strings are equal.
    struct Key(&'static str);

    impl Invalidates for Key {
        fn invalidates(&self, other: &Self) -> bool {
            self.0 == other.0
        }
    }

    #[test]
    fn piggyback_behaviour() {
        let max_tx = 5;
        let mut piggyback = Broadcasts::new();
        assert!(piggyback.is_empty(), "Piggyback starts empty");
        piggyback.add_or_replace(Key("AA"), b"AAabc".to_vec(), max_tx);
        assert_eq!(1, piggyback.len());
        piggyback.add_or_replace(Key("AA"), b"AAcba".to_vec(), max_tx);
        assert_eq!(
            1,
            piggyback.len(),
            "add_or_replace with same key should replace"
        );
        let mut buf = Vec::new();
        for _i in 0..max_tx {
            buf.clear();
            let num_items = piggyback.fill(&mut buf, usize::MAX);
            assert_eq!(1, num_items);
            assert_eq!(
                b"AAcba",
                &buf[..],
                "Should transmit an item at most max_tx times"
            );
        }
        assert!(
            piggyback.is_empty(),
            "Should remove item after being used max_tx times"
        );
    }

    #[test]
    fn fill_does_nothing_if_buffer_full() {
        let mut piggyback = Broadcasts::new();
        piggyback.add_or_replace(Key("a "), b"a super long value".to_vec(), 1);
        let buf = bytes::BytesMut::new();
        let mut limited = buf.limit(5);
        let num_items = piggyback.fill(&mut limited, usize::MAX);
        assert_eq!(0, num_items);
        assert_eq!(5, limited.remaining_mut());
        assert_eq!(1, piggyback.len());
    }

    #[test]
    fn piggyback_consumes_largest_first() {
        let max_tx = 10;
        let mut piggyback = Broadcasts::new();
        piggyback.add_or_replace(Key("00"), b"00hi".to_vec(), max_tx);
        piggyback.add_or_replace(Key("01"), b"01hello".to_vec(), max_tx);
        piggyback.add_or_replace(Key("02"), b"02hey".to_vec(), max_tx);
        let mut buf = Vec::new();
        let num_items = piggyback.fill(&mut buf, usize::MAX);
        assert_eq!(3, num_items);
        assert_eq!(b"01hello02hey00hi", &buf[..]);
    }

    #[test]
    fn highest_max_tx_is_consumed_first() {
        let mut piggyback = Broadcasts::new();
        piggyback.add_or_replace(Key("10"), b"100".to_vec(), 1);
        piggyback.add_or_replace(Key("20"), b"200".to_vec(), 2);
        piggyback.add_or_replace(Key("30"), b"300".to_vec(), 3);
        let mut buf = Vec::new();
        piggyback.fill(&mut buf, usize::MAX);
        assert_eq!(b"300200100", &buf[..]);
        buf.clear();
        piggyback.fill(&mut buf, usize::MAX);
        assert_eq!(b"300200", &buf[..]);
        buf.clear();
        piggyback.fill(&mut buf, usize::MAX);
        assert_eq!(b"300", &buf[..]);
        assert_eq!(0, piggyback.len());
    }

    #[test]
    fn piggyback_respects_limit() {
        let max_tx = 10;
        let mut piggyback = Broadcasts::new();
        piggyback.add_or_replace(Key("fo"), b"foo".to_vec(), max_tx);
        piggyback.add_or_replace(Key("ba"), b"bar".to_vec(), max_tx);
        piggyback.add_or_replace(Key("ba"), b"baz".to_vec(), max_tx);
        let mut buf = Vec::new();
        let num_items = piggyback.fill(&mut buf, 0);
        assert_eq!(0, num_items);
        assert!(buf.is_empty());
        let num_items = piggyback.fill(&mut buf, 2);
        assert_eq!(2, num_items);
        // Both survivors are 3 bytes with equal budgets, so their relative
        // order is unspecified; what matters is that "bar" was replaced.
        assert!(buf == b"foobaz" || buf == b"bazfoo");
    }

    #[test]
    fn fill_with_len_prefix() {
        let mut bcs = Broadcasts::new();
        bcs.add_or_replace(Key("fo"), b"foo".to_vec(), 10);
        bcs.add_or_replace(Key("ba"), b"barr".to_vec(), 10);
        bcs.add_or_replace(Key("ba"), b"bazz".to_vec(), 10);
        let mut buf = Vec::new();
        let num_items = bcs.fill_with_len_prefix(&mut buf, 0);
        assert_eq!(0, num_items);
        assert!(buf.is_empty());
        let num_items = bcs.fill_with_len_prefix(&mut buf, 2);
        assert_eq!(2, num_items);
        use bytes::Buf;
        let mut buf = &buf[..];
        assert_eq!(4, buf.get_u16());
        assert_eq!(&b"bazz"[..], &buf[..4]);
        buf.advance(4);
        assert_eq!(3, buf.get_u16());
        assert_eq!(&b"foo"[..], &buf[..3]);
        buf.advance(3);
        assert!(buf.is_empty());
    }

    #[test]
    fn fill_with_len_prefix_counts_prefix_bytes() {
        let mut bcs = Broadcasts::new();
        bcs.add_or_replace(Key("fo"), b"foo".to_vec(), 1);

        // 3 payload bytes + 2 prefix bytes do not fit in 4 bytes...
        let mut limited = bytes::BytesMut::new().limit(4);
        assert_eq!(0, bcs.fill_with_len_prefix(&mut limited, usize::MAX));
        assert_eq!(4, limited.remaining_mut());
        assert_eq!(1, bcs.len());

        // ...but fit exactly in 5.
        let mut limited = bytes::BytesMut::new().limit(5);
        assert_eq!(1, bcs.fill_with_len_prefix(&mut limited, usize::MAX));
        assert_eq!(0, limited.remaining_mut());
        assert!(bcs.is_empty());
    }
}