use crate::{Error, Result};
use super::{Group, GroupConsumer, GroupProducer};
use std::{
collections::{HashSet, VecDeque},
task::{Poll, ready},
time::Duration,
};
/// Age after which a retained group is evicted from track state (30 seconds).
/// The group holding the highest sequence is never evicted, regardless of age.
const MAX_GROUP_AGE: Duration = Duration::from_millis(30_000);
/// Metadata describing a track.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Track {
    /// Name of the track.
    pub name: String,
    /// Priority value; `Track::new` initialises it to 0.
    /// NOTE(review): whether higher or lower values win is decided outside this file — confirm.
    pub priority: u8,
}
impl Track {
    /// Builds track metadata with the given name and a priority of 0.
    pub fn new<T: Into<String>>(name: T) -> Self {
        let name: String = name.into();
        Self { priority: 0, name }
    }

    /// Consumes this metadata and returns a fresh producer for it (same as `TrackProducer::new`).
    pub fn produce(self) -> TrackProducer {
        let info = self;
        TrackProducer::new(info)
    }
}
/// Shared state behind a track's producer and consumer handles.
#[derive(Default)]
struct State {
    // Groups in insertion order, each paired with its creation time. An evicted group
    // leaves a `None` hole until it reaches the front and is trimmed.
    groups: VecDeque<Option<(GroupProducer, tokio::time::Instant)>>,
    // Sequences of groups that have not been evicted; used to reject duplicates.
    duplicates: HashSet<u64>,
    // Number of slots trimmed from the front of `groups`. Adding it to a deque position
    // gives the absolute index that consumers keep as their cursor.
    offset: usize,
    // Highest sequence created so far; eviction never lowers it.
    max_sequence: Option<u64>,
    // Once finished: one past the last acceptable sequence (new groups must be below it).
    final_sequence: Option<u64>,
    // Error recorded by `abort`, reported to consumers.
    abort: Option<Error>,
}
impl State {
    /// Finds the first retained group at absolute position `index` or later whose
    /// sequence is at least `min_sequence`. Returns a consumer for it plus its absolute
    /// position.
    ///
    /// When nothing matches, the result is `Ok(None)` if the track has a final sequence,
    /// the abort error if it was aborted, and `Pending` otherwise.
    fn poll_next_group(&self, index: usize, min_sequence: u64) -> Poll<Result<Option<(GroupConsumer, usize)>>> {
        // `index` is absolute; translate it into a position inside the front-trimmed deque.
        let first = index.saturating_sub(self.offset);
        let hit = self
            .groups
            .iter()
            .enumerate()
            .skip(first)
            .find_map(|(pos, slot)| match slot {
                Some((group, _)) if group.info.sequence >= min_sequence => {
                    Some((group.consume(), self.offset + pos))
                }
                _ => None,
            });

        match hit {
            Some(found) => Poll::Ready(Ok(Some(found))),
            None => self.poll_finished().map_ok(|_| None),
        }
    }

    /// Returns a consumer for the retained group with exactly `sequence`.
    ///
    /// If no such group is retained, the result is `Ok(None)` when `sequence` is at or past
    /// the final sequence, the abort error if aborted, and `Pending` otherwise.
    fn poll_get_group(&self, sequence: u64) -> Poll<Result<Option<GroupConsumer>>> {
        let retained = self
            .groups
            .iter()
            .flatten()
            .find(|(group, _)| group.info.sequence == sequence);
        if let Some((group, _)) = retained {
            return Poll::Ready(Ok(Some(group.consume())));
        }

        match (self.final_sequence, &self.abort) {
            (Some(fin), _) if sequence >= fin => Poll::Ready(Ok(None)),
            (_, Some(err)) => Poll::Ready(Err(err.clone())),
            _ => Poll::Pending,
        }
    }

    /// Ready with `Ok(())` once finished, with the abort error once aborted, else `Pending`.
    fn poll_closed(&self) -> Poll<Result<()>> {
        self.poll_finished().map_ok(|_| ())
    }

    /// Evicts groups older than `MAX_GROUP_AGE`, scanning from the front.
    ///
    /// The group holding `max_sequence` is always kept. Groups are pushed in creation
    /// order, so the scan stops at the first other group that is still young enough.
    /// An evicted slot becomes `None` and its sequence leaves `duplicates`. Afterwards,
    /// leading `None` slots are dropped and `offset` advances by the same amount.
    fn evict_expired(&mut self, now: tokio::time::Instant) {
        let newest = self.max_sequence;
        for slot in &mut self.groups {
            let Some((group, created_at)) = slot.as_ref() else { continue };
            let sequence = group.info.sequence;
            if Some(sequence) == newest {
                continue;
            }
            if now.duration_since(*created_at) <= MAX_GROUP_AGE {
                break;
            }
            self.duplicates.remove(&sequence);
            *slot = None;
        }

        let leading_holes = self.groups.iter().take_while(|slot| slot.is_none()).count();
        self.groups.drain(..leading_holes);
        self.offset += leading_holes;
    }

    /// Ready with the final sequence once finished, with the abort error once aborted,
    /// else `Pending`. A recorded final sequence takes precedence over an abort.
    fn poll_finished(&self) -> Poll<Result<u64>> {
        match (self.final_sequence, &self.abort) {
            (Some(fin), _) => Poll::Ready(Ok(fin)),
            (None, Some(err)) => Poll::Ready(Err(err.clone())),
            (None, None) => Poll::Pending,
        }
    }
}
/// Write half of a track: creates groups and finishes or aborts the track.
pub struct TrackProducer {
    /// Metadata for this track.
    pub info: Track,
    // Shared state; consumers are derived from it via `consume`.
    state: conducer::Producer<State>,
}
impl TrackProducer {
    /// Creates a producer for `info` with default (empty) state: no groups, not finished,
    /// not aborted.
    pub fn new(info: Track) -> Self {
        Self {
            info,
            state: conducer::Producer::default(),
        }
    }

    /// Inserts a group with a caller-chosen sequence number, which may arrive out of order.
    ///
    /// Raises `max_sequence` if needed, then evicts groups older than `MAX_GROUP_AGE`
    /// (except the one holding the highest sequence).
    ///
    /// # Errors
    /// - [`Error::Closed`] if the track is finished and the sequence is at or past the
    ///   final sequence.
    /// - [`Error::Duplicate`] if a non-evicted group already uses this sequence.
    /// - The recorded abort error, or [`Error::Dropped`], if the state cannot be written.
    pub fn create_group(&mut self, info: Group) -> Result<GroupProducer> {
        let group = info.produce();
        let mut state = self.modify()?;
        if let Some(fin) = state.final_sequence
            && group.info.sequence >= fin
        {
            return Err(Error::Closed);
        }
        if !state.duplicates.insert(group.info.sequence) {
            return Err(Error::Duplicate);
        }
        let now = tokio::time::Instant::now();
        state.max_sequence = Some(state.max_sequence.unwrap_or(0).max(group.info.sequence));
        state.groups.push_back(Some((group.clone(), now)));
        state.evict_expired(now);
        Ok(group)
    }

    /// Creates the next group in order: one past the highest sequence so far, or 0 for the
    /// first group. Then evicts expired groups.
    ///
    /// # Errors
    /// - [`Error::BoundsExceeded`] if the highest sequence is already `u64::MAX`.
    /// - [`Error::Closed`] if the track is finished.
    /// - The recorded abort error, or [`Error::Dropped`], if the state cannot be written.
    pub fn append_group(&mut self) -> Result<GroupProducer> {
        let mut state = self.modify()?;
        let sequence = match state.max_sequence {
            Some(s) => s.checked_add(1).ok_or(Error::BoundsExceeded)?,
            None => 0,
        };
        if let Some(fin) = state.final_sequence
            && sequence >= fin
        {
            return Err(Error::Closed);
        }
        let group = Group { sequence }.produce();
        let now = tokio::time::Instant::now();
        // `sequence` exceeds every sequence inserted so far, so it cannot be a duplicate.
        state.duplicates.insert(sequence);
        state.max_sequence = Some(sequence);
        state.groups.push_back(Some((group.clone(), now)));
        state.evict_expired(now);
        Ok(group)
    }

    /// Appends a new group containing exactly one frame, then finishes that group.
    ///
    /// If writing the frame or finishing the group fails, the group is aborted with the
    /// same error before that error is returned. Without this, the half-written group would
    /// stay in the track as its newest (never-evicted) group, and consumers that already
    /// picked it up would keep waiting on it.
    ///
    /// # Errors
    /// Anything returned by [`Self::append_group`], or by the group's `write_frame`/`finish`.
    pub fn write_frame<B: Into<bytes::Bytes>>(&mut self, frame: B) -> Result<()> {
        let mut group = self.append_group()?;
        if let Err(err) = Self::write_single_frame(&mut group, frame.into()) {
            // Best effort: the group may already be aborted, in which case this fails harmlessly.
            group.abort(err.clone()).ok();
            return Err(err);
        }
        Ok(())
    }

    /// Writes `frame` into `group` and finishes it; helper for `write_frame`.
    fn write_single_frame(group: &mut GroupProducer, frame: bytes::Bytes) -> Result<()> {
        group.write_frame(frame)?;
        group.finish()?;
        Ok(())
    }

    /// Marks the track finished. The final sequence becomes one past the highest sequence
    /// created (0 if none). Groups below that value can still be inserted with `create_group`.
    ///
    /// # Errors
    /// - [`Error::Closed`] if the track is already finished.
    /// - [`Error::BoundsExceeded`] if the highest sequence is `u64::MAX`.
    /// - The recorded abort error, or [`Error::Dropped`], if the state cannot be written.
    pub fn finish(&mut self) -> Result<()> {
        let mut state = self.modify()?;
        if state.final_sequence.is_some() {
            return Err(Error::Closed);
        }
        state.final_sequence = Some(match state.max_sequence {
            Some(max) => max.checked_add(1).ok_or(Error::BoundsExceeded)?,
            None => 0,
        });
        Ok(())
    }

    /// Deprecated alias for [`Self::finish`].
    #[deprecated(note = "use finish() or finish_at(sequence) instead")]
    pub fn close(&mut self) -> Result<()> {
        self.finish()
    }

    /// Finishes the track after `sequence`, which must equal the highest sequence created
    /// so far. The final sequence becomes `sequence + 1`.
    ///
    /// # Errors
    /// - [`Error::Closed`] if no group exists yet, the track is already finished, or
    ///   `sequence` is not the highest sequence.
    /// - [`Error::BoundsExceeded`] if `sequence` is `u64::MAX`.
    /// - The recorded abort error, or [`Error::Dropped`], if the state cannot be written.
    pub fn finish_at(&mut self, sequence: u64) -> Result<()> {
        let mut state = self.modify()?;
        let max = state.max_sequence.ok_or(Error::Closed)?;
        if state.final_sequence.is_some() || sequence != max {
            return Err(Error::Closed);
        }
        state.final_sequence = Some(max.checked_add(1).ok_or(Error::BoundsExceeded)?);
        Ok(())
    }

    /// Aborts every retained group with `err`, records `err` as the track's abort error,
    /// and closes the shared state.
    ///
    /// # Errors
    /// The recorded abort error, or [`Error::Dropped`], if the state cannot be written.
    pub fn abort(&mut self, err: Error) -> Result<()> {
        let mut guard = self.modify()?;
        for (group, _) in guard.groups.iter_mut().flatten() {
            // Best effort per group; one failing group must not stop the rest.
            group.abort(err.clone()).ok();
        }
        guard.abort = Some(err);
        guard.close();
        Ok(())
    }

    /// Creates a consumer that starts at the oldest retained group, with no minimum sequence.
    pub fn consume(&self) -> TrackConsumer {
        TrackConsumer {
            info: self.info.clone(),
            state: self.state.consume(),
            index: 0,
            min_sequence: 0,
        }
    }

    /// Awaits the `conducer` channel's `unused` signal.
    /// NOTE(review): presumably resolves once no consumers remain — confirm with `conducer`.
    ///
    /// # Errors
    /// On channel failure: the recorded abort error, or [`Error::Dropped`].
    pub async fn unused(&self) -> Result<()> {
        self.state
            .unused()
            .await
            .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
    }

    /// Whether the shared state reports itself closed (as defined by `conducer`).
    pub fn is_closed(&self) -> bool {
        self.state.read().is_closed()
    }

    /// Whether `other` shares this producer's underlying channel.
    pub fn is_clone(&self, other: &Self) -> bool {
        self.state.same_channel(&other.state)
    }

    /// Returns a weak handle to the same track state.
    pub(crate) fn weak(&self) -> TrackWeak {
        TrackWeak {
            info: self.info.clone(),
            state: self.state.weak(),
        }
    }

    /// Acquires write access to the shared state. Failure is mapped to the recorded
    /// abort error, or [`Error::Dropped`] when none was recorded.
    fn modify(&self) -> Result<conducer::Mut<'_, State>> {
        self.state
            .write()
            .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
    }
}
impl Clone for TrackProducer {
    /// Clones the metadata and the `conducer` producer handle. The result reports
    /// `is_clone` with respect to `self` when `conducer` keeps clones on the same channel.
    fn clone(&self) -> Self {
        let info = self.info.clone();
        let state = self.state.clone();
        Self { info, state }
    }
}
impl From<Track> for TrackProducer {
fn from(info: Track) -> Self {
TrackProducer::new(info)
}
}
/// Crate-internal handle that reaches a track's state through `conducer::Weak`.
/// It can abort the track or create consumers.
/// NOTE(review): presumably does not keep the producer side alive — confirm against `conducer::Weak`.
#[derive(Clone)]
pub(crate) struct TrackWeak {
    /// Metadata for this track.
    pub info: Track,
    state: conducer::Weak<State>,
}
impl TrackWeak {
pub fn abort(&self, err: Error) {
let Ok(mut guard) = self.state.write() else { return };
for (group, _) in guard.groups.iter_mut().flatten() {
group.abort(err.clone()).ok();
}
guard.abort = Some(err);
guard.close();
}
pub fn is_closed(&self) -> bool {
self.state.is_closed()
}
pub fn consume(&self) -> TrackConsumer {
TrackConsumer {
info: self.info.clone(),
state: self.state.consume(),
index: 0,
min_sequence: 0,
}
}
pub async fn unused(&self) -> crate::Result<()> {
self.state
.unused()
.await
.map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
}
pub fn is_clone(&self, other: &Self) -> bool {
self.state.same_channel(&other.state)
}
}
/// Read half of a track. Each clone carries its own copy of the read cursor.
#[derive(Clone)]
pub struct TrackConsumer {
    /// Metadata for this track.
    pub info: Track,
    state: conducer::Consumer<State>,
    // Absolute index (deque position + `State::offset`) of the next slot `next_group` examines.
    index: usize,
    // `next_group` skips groups whose sequence is below this; set via `start_at`.
    min_sequence: u64,
}
impl TrackConsumer {
    /// Polls the shared state with `f`. A channel-level failure is translated into the
    /// track's recorded abort error, or [`Error::Dropped`] when none was recorded.
    fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
    where
        F: Fn(&conducer::Ref<'_, State>) -> Poll<Result<R>>,
    {
        self.state.poll(waiter, f).map(|outcome| match outcome {
            Ok(res) => res,
            Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
        })
    }

    /// Polls for the next group at or after this consumer's cursor whose sequence is at
    /// least the `start_at` minimum. On success, the cursor moves past the returned group.
    /// Yields `Ok(None)` once the track is finished and nothing further matches.
    pub fn poll_next_group(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<GroupConsumer>>> {
        let (cursor, floor) = (self.index, self.min_sequence);
        let next = ready!(self.poll(waiter, move |state| state.poll_next_group(cursor, floor)))?;
        let Some((group, position)) = next else {
            return Poll::Ready(Ok(None));
        };
        // Resume just after this group's absolute position next time.
        self.index = position + 1;
        Poll::Ready(Ok(Some(group)))
    }

    /// Async wrapper around [`Self::poll_next_group`].
    pub async fn next_group(&mut self) -> Result<Option<GroupConsumer>> {
        conducer::wait(|waiter| self.poll_next_group(waiter)).await
    }

    /// Polls for the retained group with exactly `sequence`. Yields `Ok(None)` when
    /// `sequence` is at or past the final sequence. The cursor is not moved.
    pub fn poll_get_group(&self, waiter: &conducer::Waiter, sequence: u64) -> Poll<Result<Option<GroupConsumer>>> {
        self.poll(waiter, move |state| state.poll_get_group(sequence))
    }

    /// Async wrapper around [`Self::poll_get_group`].
    pub async fn get_group(&self, sequence: u64) -> Result<Option<GroupConsumer>> {
        conducer::wait(|waiter| self.poll_get_group(waiter, sequence)).await
    }

    /// Ready with `Ok(())` once the track is finished, or with the abort error once aborted.
    pub fn poll_closed(&self, waiter: &conducer::Waiter) -> Poll<Result<()>> {
        self.poll(waiter, |shared| shared.poll_closed())
    }

    /// Async wrapper around [`Self::poll_closed`].
    pub async fn closed(&self) -> Result<()> {
        conducer::wait(|waiter| self.poll_closed(waiter)).await
    }

    /// Whether `other` reads from the same underlying channel.
    pub fn is_clone(&self, other: &Self) -> bool {
        let same = self.state.same_channel(&other.state);
        same
    }

    /// Ready with the final sequence once the track is finished, or with the abort error.
    pub fn poll_finished(&mut self, waiter: &conducer::Waiter) -> Poll<Result<u64>> {
        self.poll(waiter, |shared| shared.poll_finished())
    }

    /// Async wrapper around [`Self::poll_finished`].
    pub async fn finished(&mut self) -> Result<u64> {
        conducer::wait(|waiter| self.poll_finished(waiter)).await
    }

    /// Makes `next_group` skip any group whose sequence is below `sequence`.
    /// The cursor position is left unchanged.
    pub fn start_at(&mut self, sequence: u64) {
        self.min_sequence = sequence;
    }

    /// Highest sequence created on the track so far, if any.
    pub fn latest(&self) -> Option<u64> {
        let shared = self.state.read();
        shared.max_sequence
    }
}
#[cfg(test)]
use futures::FutureExt;
#[cfg(test)]
impl TrackConsumer {
    /// Returns the next group, panicking if it would block, errors, or the track is done.
    pub fn assert_group(&mut self) -> GroupConsumer {
        let polled = self.next_group().now_or_never();
        polled
            .expect("group would have blocked")
            .expect("would have errored")
            .expect("track was closed")
    }

    /// Panics unless fetching the next group would block.
    pub fn assert_no_group(&mut self) {
        let blocked = self.next_group().now_or_never().is_none();
        assert!(blocked, "next group would not have blocked");
    }

    /// Panics if `closed()` resolves immediately.
    pub fn assert_not_closed(&self) {
        let blocked = self.closed().now_or_never().is_none();
        assert!(blocked, "should not be closed");
    }

    /// Panics unless `closed()` resolves immediately (successfully or with an error).
    pub fn assert_closed(&self) {
        let resolved = self.closed().now_or_never().is_some();
        assert!(resolved, "should be closed");
    }

    /// Panics unless `closed()` resolves immediately with an error.
    pub fn assert_error(&self) {
        let outcome = self.closed().now_or_never().expect("should not block");
        assert!(outcome.is_err(), "should be error");
    }

    /// Panics unless `other` shares this consumer's channel.
    pub fn assert_is_clone(&self, other: &Self) {
        let same = self.is_clone(other);
        assert!(same, "should be clone");
    }

    /// Panics if `other` shares this consumer's channel.
    pub fn assert_not_clone(&self, other: &Self) {
        let same = self.is_clone(other);
        assert!(!same, "should not be clone");
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Counts slots that still hold a group (evicted slots are `None`).
    fn live_groups(state: &State) -> usize {
        state.groups.iter().filter(|slot| slot.is_some()).count()
    }

    /// Sequence of the oldest group still retained; panics if every slot is empty.
    fn first_live_sequence(state: &State) -> u64 {
        let (group, _) = state.groups.iter().flatten().next().unwrap();
        group.info.sequence
    }

    /// Advances paused tokio time just past `MAX_GROUP_AGE`, so existing groups expire.
    async fn age_out() {
        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;
    }

    /// Appends `count` consecutive groups, panicking on any failure.
    fn append_groups(producer: &mut TrackProducer, count: usize) {
        for _ in 0..count {
            producer.append_group().unwrap();
        }
    }

    #[tokio::test]
    async fn evict_expired_groups() {
        tokio::time::pause();
        let mut producer = Track::new("test").produce();
        append_groups(&mut producer, 3);
        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 3);
            assert_eq!(state.offset, 0);
        }

        age_out().await;
        producer.append_group().unwrap();
        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 1);
            assert_eq!(first_live_sequence(&state), 3);
            assert_eq!(state.offset, 3);
            // Evicted sequences are forgotten; the survivor is still tracked.
            for evicted in 0..3u64 {
                assert!(!state.duplicates.contains(&evicted));
            }
            assert!(state.duplicates.contains(&3));
        }
    }

    #[tokio::test]
    async fn evict_keeps_max_sequence() {
        tokio::time::pause();
        let mut producer = Track::new("test").produce();
        producer.append_group().unwrap();
        age_out().await;
        producer.append_group().unwrap();

        let state = producer.state.read();
        assert_eq!(live_groups(&state), 1);
        assert_eq!(first_live_sequence(&state), 1);
        assert_eq!(state.offset, 1);
    }

    #[tokio::test]
    async fn no_eviction_when_fresh() {
        tokio::time::pause();
        let mut producer = Track::new("test").produce();
        append_groups(&mut producer, 3);

        let state = producer.state.read();
        assert_eq!(live_groups(&state), 3);
        assert_eq!(state.offset, 0);
    }

    #[tokio::test]
    async fn consumer_skips_evicted_groups() {
        tokio::time::pause();
        let mut producer = Track::new("test").produce();
        producer.append_group().unwrap();
        let mut consumer = producer.consume();

        age_out().await;
        producer.append_group().unwrap();
        assert_eq!(consumer.assert_group().info.sequence, 1);
    }

    #[tokio::test]
    async fn out_of_order_max_sequence_at_front() {
        tokio::time::pause();
        let mut producer = Track::new("test").produce();
        for sequence in [5, 3, 4] {
            producer.create_group(Group { sequence }).unwrap();
        }
        assert_eq!(producer.state.read().max_sequence, Some(5));

        age_out().await;
        producer.append_group().unwrap();
        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 1);
            assert_eq!(first_live_sequence(&state), 6);
            for evicted in [3u64, 4, 5] {
                assert!(!state.duplicates.contains(&evicted));
            }
            assert!(state.duplicates.contains(&6));
        }
    }

    #[tokio::test]
    async fn max_sequence_at_front_blocks_trim() {
        tokio::time::pause();
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 5 }).unwrap();

        age_out().await;
        producer.create_group(Group { sequence: 3 }).unwrap();
        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 2);
            assert_eq!(state.offset, 0);
        }

        age_out().await;
        producer.create_group(Group { sequence: 2 }).unwrap();
        {
            // Group 5 is the max sequence, so it pins the front and blocks trimming.
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 2);
            assert_eq!(state.offset, 0);
            assert!(state.duplicates.contains(&5));
            assert!(!state.duplicates.contains(&3));
            assert!(state.duplicates.contains(&2));
        }

        let mut consumer = producer.consume();
        assert_eq!(consumer.assert_group().info.sequence, 5);
    }

    #[test]
    fn append_finish_cannot_be_rewritten() {
        let mut producer = Track::new("test").produce();
        assert!(producer.finish().is_ok());
        assert!(producer.finish().is_err());
        assert!(producer.append_group().is_err());
    }

    #[test]
    fn finish_after_groups() {
        let mut producer = Track::new("test").produce();
        producer.append_group().unwrap();
        assert!(producer.finish().is_ok());
        assert!(producer.finish().is_err());
        assert!(producer.append_group().is_err());
    }

    #[test]
    fn insert_finish_validates_sequence_and_freezes_to_max() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 5 }).unwrap();
        assert!(producer.finish_at(4).is_err());
        assert!(producer.finish_at(10).is_err());
        assert!(producer.finish_at(5).is_ok());
        assert_eq!(producer.state.read().final_sequence, Some(6));

        assert!(producer.finish_at(5).is_err());
        assert!(producer.create_group(Group { sequence: 4 }).is_ok());
        assert!(producer.create_group(Group { sequence: 5 }).is_err());
    }

    #[tokio::test]
    async fn next_group_finishes_without_waiting_for_gaps() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 1 }).unwrap();
        producer.finish_at(1).unwrap();

        let mut consumer = producer.consume();
        assert_eq!(consumer.assert_group().info.sequence, 1);
        let done = consumer
            .next_group()
            .now_or_never()
            .expect("should not block")
            .expect("would have errored");
        assert!(done.is_none(), "track should finish without waiting for gaps");
    }

    #[tokio::test]
    async fn get_group_finishes_without_waiting_for_gaps() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 1 }).unwrap();
        producer.finish_at(1).unwrap();

        let consumer = producer.consume();
        let below_fin = consumer.get_group(0).now_or_never();
        assert!(
            below_fin.is_none(),
            "sequence below fin should block (group could still arrive)"
        );
        let past_fin = consumer
            .get_group(2)
            .now_or_never()
            .expect("sequence at-or-after fin should resolve")
            .expect("should not error");
        assert!(past_fin.is_none(), "sequence at-or-after fin should not exist");
    }

    #[test]
    fn append_group_returns_bounds_exceeded_on_sequence_overflow() {
        let mut producer = Track::new("test").produce();
        producer.state.write().ok().unwrap().max_sequence = Some(u64::MAX);
        assert!(matches!(producer.append_group(), Err(Error::BoundsExceeded)));
    }
}