1use crate::{Error, Result};
16
17use super::{Group, GroupConsumer, GroupProducer};
18
19use std::{
20 collections::{HashSet, VecDeque},
21 task::{Poll, ready},
22 time::Duration,
23};
24
/// Age limit for a stored group: `State::evict_expired` drops groups that have been
/// held for longer than this, except the group carrying the highest sequence.
const MAX_GROUP_AGE: Duration = Duration::from_secs(30);
28
/// Describes a track: its name plus a priority value.
///
/// Build one with [`Track::new`] and turn it into a producer with [`Track::produce`].
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Track {
    /// Name of the track.
    pub name: String,
    /// Priority value; [`Track::new`] initializes it to 0.
    // NOTE(review): how the priority is interpreted is not visible in this file — confirm.
    pub priority: u8,
}
36
37impl Track {
38 pub fn new<T: Into<String>>(name: T) -> Self {
39 Self {
40 name: name.into(),
41 priority: 0,
42 }
43 }
44
45 pub fn produce(self) -> TrackProducer {
46 TrackProducer::new(self)
47 }
48}
49
/// Mutable bookkeeping for one track, shared through the `conducer` handles.
#[derive(Default)]
struct State {
    /// Group slots in insertion order, each paired with the instant it was pushed.
    /// A `None` slot is a tombstone left behind by `evict_expired`.
    groups: VecDeque<Option<(GroupProducer, tokio::time::Instant)>>,
    /// Sequences of the groups currently held; a sequence is removed again when its
    /// group is evicted. Used to reject a repeated sequence in `create_group`.
    duplicates: HashSet<u64>,
    /// Number of slots already popped off the front of `groups`. Added to a deque
    /// position to form the absolute index handed to consumers.
    offset: usize,
    /// Highest sequence created so far, or `None` before the first group.
    max_sequence: Option<u64>,
    /// Set once the track is finished: the highest sequence plus one (0 when no group
    /// was ever created). Sequences at or above this value are rejected.
    final_sequence: Option<u64>,
    /// Error recorded when the track is aborted.
    abort: Option<Error>,
}
60
61impl State {
62 fn poll_next_group(&self, index: usize, min_sequence: u64) -> Poll<Result<Option<(GroupConsumer, usize)>>> {
66 let start = index.saturating_sub(self.offset);
67 for (i, slot) in self.groups.iter().enumerate().skip(start) {
68 if let Some((group, _)) = slot
69 && group.info.sequence >= min_sequence
70 {
71 return Poll::Ready(Ok(Some((group.consume(), self.offset + i))));
72 }
73 }
74
75 if self.final_sequence.is_some() {
77 Poll::Ready(Ok(None))
78 } else if let Some(err) = &self.abort {
79 Poll::Ready(Err(err.clone()))
80 } else {
81 Poll::Pending
82 }
83 }
84
85 fn poll_get_group(&self, sequence: u64) -> Poll<Result<Option<GroupConsumer>>> {
86 for (group, _) in self.groups.iter().flatten() {
88 if group.info.sequence == sequence {
89 return Poll::Ready(Ok(Some(group.consume())));
90 }
91 }
92
93 if let Some(fin) = self.final_sequence
95 && sequence >= fin
96 {
97 return Poll::Ready(Ok(None));
98 }
99
100 if let Some(err) = &self.abort {
101 return Poll::Ready(Err(err.clone()));
102 }
103
104 Poll::Pending
105 }
106
107 fn poll_closed(&self) -> Poll<Result<()>> {
108 if self.final_sequence.is_some() {
109 Poll::Ready(Ok(()))
110 } else if let Some(err) = &self.abort {
111 Poll::Ready(Err(err.clone()))
112 } else {
113 Poll::Pending
114 }
115 }
116
117 fn evict_expired(&mut self, now: tokio::time::Instant) {
124 for slot in self.groups.iter_mut() {
125 let Some((group, created_at)) = slot else { continue };
126
127 if Some(group.info.sequence) == self.max_sequence {
128 continue;
129 }
130
131 if now.duration_since(*created_at) <= MAX_GROUP_AGE {
132 break;
133 }
134
135 self.duplicates.remove(&group.info.sequence);
136 *slot = None;
137 }
138
139 while let Some(None) = self.groups.front() {
141 self.groups.pop_front();
142 self.offset += 1;
143 }
144 }
145
146 fn poll_finished(&self) -> Poll<Result<u64>> {
147 if let Some(fin) = self.final_sequence {
148 Poll::Ready(Ok(fin))
149 } else if let Some(err) = &self.abort {
150 Poll::Ready(Err(err.clone()))
151 } else {
152 Poll::Pending
153 }
154 }
155}
156
/// Writing half of a track: creates groups and finishes or aborts the track.
pub struct TrackProducer {
    /// Description of the track this producer writes to.
    pub info: Track,
    /// Producer-side handle to the shared `State`.
    state: conducer::Producer<State>,
}
162
163impl TrackProducer {
164 pub fn new(info: Track) -> Self {
165 Self {
166 info,
167 state: conducer::Producer::default(),
168 }
169 }
170
171 pub fn create_group(&mut self, info: Group) -> Result<GroupProducer> {
173 let group = info.produce();
174
175 let mut state = self.modify()?;
176 if let Some(fin) = state.final_sequence
177 && group.info.sequence >= fin
178 {
179 return Err(Error::Closed);
180 }
181
182 if !state.duplicates.insert(group.info.sequence) {
183 return Err(Error::Duplicate);
184 }
185
186 let now = tokio::time::Instant::now();
187 state.max_sequence = Some(state.max_sequence.unwrap_or(0).max(group.info.sequence));
188 state.groups.push_back(Some((group.clone(), now)));
189 state.evict_expired(now);
190
191 Ok(group)
192 }
193
194 pub fn append_group(&mut self) -> Result<GroupProducer> {
196 let mut state = self.modify()?;
197 let sequence = match state.max_sequence {
198 Some(s) => s.checked_add(1).ok_or(Error::BoundsExceeded)?,
199 None => 0,
200 };
201 if let Some(fin) = state.final_sequence
202 && sequence >= fin
203 {
204 return Err(Error::Closed);
205 }
206
207 let group = Group { sequence }.produce();
208
209 let now = tokio::time::Instant::now();
210 state.duplicates.insert(sequence);
211 state.max_sequence = Some(sequence);
212 state.groups.push_back(Some((group.clone(), now)));
213 state.evict_expired(now);
214
215 Ok(group)
216 }
217
218 pub fn write_frame<B: Into<bytes::Bytes>>(&mut self, frame: B) -> Result<()> {
220 let mut group = self.append_group()?;
221 group.write_frame(frame.into())?;
222 group.finish()?;
223 Ok(())
224 }
225
226 pub fn finish(&mut self) -> Result<()> {
232 let mut state = self.modify()?;
233 if state.final_sequence.is_some() {
234 return Err(Error::Closed);
235 }
236 state.final_sequence = Some(match state.max_sequence {
237 Some(max) => max.checked_add(1).ok_or(Error::BoundsExceeded)?,
238 None => 0,
239 });
240 Ok(())
241 }
242
243 #[deprecated(note = "use finish() or finish_at(sequence) instead")]
248 pub fn close(&mut self) -> Result<()> {
249 self.finish()
250 }
251
252 pub fn finish_at(&mut self, sequence: u64) -> Result<()> {
259 let mut state = self.modify()?;
260 let max = state.max_sequence.ok_or(Error::Closed)?;
261 if state.final_sequence.is_some() || sequence != max {
262 return Err(Error::Closed);
263 }
264 state.final_sequence = Some(max.checked_add(1).ok_or(Error::BoundsExceeded)?);
265 Ok(())
266 }
267
268 pub fn abort(&mut self, err: Error) -> Result<()> {
270 let mut guard = self.modify()?;
271
272 for (group, _) in guard.groups.iter_mut().flatten() {
274 group.abort(err.clone()).ok();
276 }
277
278 guard.abort = Some(err);
279 guard.close();
280 Ok(())
281 }
282
283 pub fn consume(&self) -> TrackConsumer {
285 TrackConsumer {
286 info: self.info.clone(),
287 state: self.state.consume(),
288 index: 0,
289 min_sequence: 0,
290 }
291 }
292
293 pub async fn unused(&self) -> Result<()> {
295 self.state
296 .unused()
297 .await
298 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
299 }
300
301 pub async fn used(&self) -> Result<()> {
303 self.state
304 .used()
305 .await
306 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
307 }
308
309 pub fn is_closed(&self) -> bool {
311 self.state.read().is_closed()
312 }
313
314 pub fn is_clone(&self, other: &Self) -> bool {
316 self.state.same_channel(&other.state)
317 }
318
319 pub(crate) fn weak(&self) -> TrackWeak {
321 TrackWeak {
322 info: self.info.clone(),
323 state: self.state.weak(),
324 }
325 }
326
327 fn modify(&self) -> Result<conducer::Mut<'_, State>> {
328 self.state
329 .write()
330 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
331 }
332}
333
334impl Clone for TrackProducer {
335 fn clone(&self) -> Self {
336 Self {
337 info: self.info.clone(),
338 state: self.state.clone(),
339 }
340 }
341}
342
343impl From<Track> for TrackProducer {
344 fn from(info: Track) -> Self {
345 TrackProducer::new(info)
346 }
347}
348
/// Crate-internal track handle built on a `conducer::Weak` reference to the state.
// NOTE(review): presumably does not keep the track alive on its own — confirm against conducer.
#[derive(Clone)]
pub(crate) struct TrackWeak {
    /// Description of the track this handle refers to.
    pub info: Track,
    /// Weak handle to the shared `State`.
    state: conducer::Weak<State>,
}
355
356impl TrackWeak {
357 pub fn abort(&self, err: Error) {
358 let Ok(mut guard) = self.state.write() else { return };
359
360 for (group, _) in guard.groups.iter_mut().flatten() {
362 group.abort(err.clone()).ok();
363 }
364
365 guard.abort = Some(err);
366 guard.close();
367 }
368
369 pub fn is_closed(&self) -> bool {
370 self.state.is_closed()
371 }
372
373 pub fn consume(&self) -> TrackConsumer {
374 TrackConsumer {
375 info: self.info.clone(),
376 state: self.state.consume(),
377 index: 0,
378 min_sequence: 0,
379 }
380 }
381
382 pub async fn unused(&self) -> crate::Result<()> {
383 self.state
384 .unused()
385 .await
386 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
387 }
388
389 pub fn is_clone(&self, other: &Self) -> bool {
390 self.state.same_channel(&other.state)
391 }
392}
393
/// Reading half of a track: yields groups and reports when the track ends.
#[derive(Clone)]
pub struct TrackConsumer {
    /// Description of the track being consumed.
    pub info: Track,
    /// Consumer-side handle to the shared `State`.
    state: conducer::Consumer<State>,
    /// Absolute position of the next slot `poll_next_group` will examine.
    index: usize,

    /// Groups with a smaller sequence are skipped by `poll_next_group`; set through
    /// `start_at`.
    min_sequence: u64,
}
403
404impl TrackConsumer {
405 fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
407 where
408 F: Fn(&conducer::Ref<'_, State>) -> Poll<Result<R>>,
409 {
410 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
411 Ok(res) => res,
412 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
414 })
415 }
416
417 pub fn poll_next_group(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<GroupConsumer>>> {
424 let Some((consumer, found_index)) =
425 ready!(self.poll(waiter, |state| state.poll_next_group(self.index, self.min_sequence))?)
426 else {
427 return Poll::Ready(Ok(None));
428 };
429
430 self.index = found_index + 1;
431 Poll::Ready(Ok(Some(consumer)))
432 }
433
434 pub async fn next_group(&mut self) -> Result<Option<GroupConsumer>> {
438 conducer::wait(|waiter| self.poll_next_group(waiter)).await
439 }
440
441 pub fn poll_get_group(&self, waiter: &conducer::Waiter, sequence: u64) -> Poll<Result<Option<GroupConsumer>>> {
443 self.poll(waiter, |state| state.poll_get_group(sequence))
444 }
445
446 pub async fn get_group(&self, sequence: u64) -> Result<Option<GroupConsumer>> {
450 conducer::wait(|waiter| self.poll_get_group(waiter, sequence)).await
451 }
452
453 pub fn poll_closed(&self, waiter: &conducer::Waiter) -> Poll<Result<()>> {
455 self.poll(waiter, |state| state.poll_closed())
456 }
457
458 pub async fn closed(&self) -> Result<()> {
462 conducer::wait(|waiter| self.poll_closed(waiter)).await
463 }
464
465 pub fn is_clone(&self, other: &Self) -> bool {
466 self.state.same_channel(&other.state)
467 }
468
469 pub fn poll_finished(&mut self, waiter: &conducer::Waiter) -> Poll<Result<u64>> {
471 self.poll(waiter, |state| state.poll_finished())
472 }
473
474 pub async fn finished(&mut self) -> Result<u64> {
476 conducer::wait(|waiter| self.poll_finished(waiter)).await
477 }
478
479 pub fn start_at(&mut self, sequence: u64) {
481 self.min_sequence = sequence;
482 }
483
484 pub fn latest(&self) -> Option<u64> {
486 self.state.read().max_sequence
487 }
488}
489
490#[cfg(test)]
491use futures::FutureExt;
492
#[cfg(test)]
impl TrackConsumer {
    /// Returns the next group, panicking if the poll would block, fails, or reports
    /// the end of the track.
    pub fn assert_group(&mut self) -> GroupConsumer {
        let polled = self.next_group().now_or_never();
        polled
            .expect("group would have blocked")
            .expect("would have errored")
            .expect("track was closed")
    }

    /// Asserts that asking for the next group would block.
    pub fn assert_no_group(&mut self) {
        let polled = self.next_group().now_or_never();
        assert!(polled.is_none(), "next group would not have blocked");
    }

    /// Asserts that `closed()` is still pending.
    pub fn assert_not_closed(&self) {
        let polled = self.closed().now_or_never();
        assert!(polled.is_none(), "should not be closed");
    }

    /// Asserts that `closed()` resolves immediately, with either outcome.
    pub fn assert_closed(&self) {
        let polled = self.closed().now_or_never();
        assert!(polled.is_some(), "should be closed");
    }

    /// Asserts that `closed()` resolves immediately with an error.
    pub fn assert_error(&self) {
        let outcome = self.closed().now_or_never().expect("should not block");
        assert!(outcome.is_err(), "should be error");
    }

    /// Asserts that `other` shares this consumer's state channel.
    pub fn assert_is_clone(&self, other: &Self) {
        assert!(self.is_clone(other), "should be clone");
    }

    /// Asserts that `other` does not share this consumer's state channel.
    pub fn assert_not_clone(&self, other: &Self) {
        assert!(!self.is_clone(other), "should not be clone");
    }
}
534
#[cfg(test)]
mod test {
    use super::*;

    /// Counts the slots that still hold a group (tombstones excluded).
    fn live_groups(state: &State) -> usize {
        state.groups.iter().filter(|slot| slot.is_some()).count()
    }

    /// Sequence of the first slot that still holds a group; panics when none does.
    fn first_live_sequence(state: &State) -> u64 {
        let (group, _) = state.groups.iter().flatten().next().unwrap();
        group.info.sequence
    }

    /// Moves the paused clock just past the eviction threshold.
    async fn expire_groups() {
        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;
    }

    #[tokio::test]
    async fn evict_expired_groups() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();

        // Sequences 0, 1 and 2, all created at the same paused instant.
        for _ in 0..3 {
            producer.append_group().unwrap();
        }

        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 3);
            assert_eq!(state.offset, 0);
        }

        expire_groups().await;

        // Creating sequence 3 runs eviction over the three expired groups.
        producer.append_group().unwrap();

        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 1);
            assert_eq!(first_live_sequence(&state), 3);
            assert_eq!(state.offset, 3);
            for evicted in 0..3 {
                assert!(!state.duplicates.contains(&evicted));
            }
            assert!(state.duplicates.contains(&3));
        }
    }

    #[tokio::test]
    async fn evict_keeps_max_sequence() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();
        producer.append_group().unwrap();

        expire_groups().await;

        // Sequence 1 becomes the newest group, so sequence 0 can be evicted.
        producer.append_group().unwrap();

        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 1);
            assert_eq!(first_live_sequence(&state), 1);
            assert_eq!(state.offset, 1);
        }
    }

    #[tokio::test]
    async fn no_eviction_when_fresh() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();
        for _ in 0..3 {
            producer.append_group().unwrap();
        }

        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 3);
            assert_eq!(state.offset, 0);
        }
    }

    #[tokio::test]
    async fn consumer_skips_evicted_groups() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();
        producer.append_group().unwrap();

        // The consumer is created while sequence 0 is still stored.
        let mut consumer = producer.consume();

        expire_groups().await;
        producer.append_group().unwrap();

        // Sequence 0 was evicted, so the first group seen is sequence 1.
        let group = consumer.assert_group();
        assert_eq!(group.info.sequence, 1);
    }

    #[tokio::test]
    async fn out_of_order_max_sequence_at_front() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();

        // The highest sequence arrives first, followed by lower ones.
        for sequence in [5, 3, 4] {
            producer.create_group(Group { sequence }).unwrap();
        }

        {
            let state = producer.state.read();
            assert_eq!(state.max_sequence, Some(5));
        }

        expire_groups().await;

        // Sequence 6 takes over as the newest group; 5, 3 and 4 are all expired.
        producer.append_group().unwrap();

        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 1);
            assert_eq!(first_live_sequence(&state), 6);
            for evicted in [3, 4, 5] {
                assert!(!state.duplicates.contains(&evicted));
            }
            assert!(state.duplicates.contains(&6));
        }
    }

    #[tokio::test]
    async fn max_sequence_at_front_blocks_trim() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();

        producer.create_group(Group { sequence: 5 }).unwrap();

        expire_groups().await;

        producer.create_group(Group { sequence: 3 }).unwrap();

        {
            // Sequence 5 is expired but protected as the highest sequence.
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 2);
            assert_eq!(state.offset, 0);
        }

        expire_groups().await;

        producer.create_group(Group { sequence: 2 }).unwrap();

        {
            // Sequence 3 becomes a tombstone behind sequence 5, which still sits at
            // the front, so nothing is trimmed and the offset stays at 0.
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 2);
            assert_eq!(state.offset, 0);
            assert!(state.duplicates.contains(&5));
            assert!(!state.duplicates.contains(&3));
            assert!(state.duplicates.contains(&2));
        }

        let mut consumer = producer.consume();
        let group = consumer.assert_group();
        assert_eq!(group.info.sequence, 5);
    }

    #[test]
    fn append_finish_cannot_be_rewritten() {
        let mut producer = Track::new("test").produce();

        assert!(producer.finish().is_ok());
        assert!(producer.finish().is_err());
        assert!(producer.append_group().is_err());
    }

    #[test]
    fn finish_after_groups() {
        let mut producer = Track::new("test").produce();

        producer.append_group().unwrap();
        assert!(producer.finish().is_ok());
        assert!(producer.finish().is_err());
        assert!(producer.append_group().is_err());
    }

    #[test]
    fn insert_finish_validates_sequence_and_freezes_to_max() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 5 }).unwrap();

        // Only the current highest sequence is accepted.
        assert!(producer.finish_at(4).is_err());
        assert!(producer.finish_at(10).is_err());
        assert!(producer.finish_at(5).is_ok());

        {
            let state = producer.state.read();
            assert_eq!(state.final_sequence, Some(6));
        }

        assert!(producer.finish_at(5).is_err());
        // Below the final sequence a gap may still be filled; a repeat is rejected.
        assert!(producer.create_group(Group { sequence: 4 }).is_ok());
        assert!(producer.create_group(Group { sequence: 5 }).is_err());
    }

    #[tokio::test]
    async fn next_group_finishes_without_waiting_for_gaps() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 1 }).unwrap();
        producer.finish_at(1).unwrap();

        let mut consumer = producer.consume();
        assert_eq!(consumer.assert_group().info.sequence, 1);

        let polled = consumer.next_group().now_or_never();
        let done = polled.expect("should not block").expect("would have errored");
        assert!(done.is_none(), "track should finish without waiting for gaps");
    }

    #[tokio::test]
    async fn get_group_finishes_without_waiting_for_gaps() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 1 }).unwrap();
        producer.finish_at(1).unwrap();

        let consumer = producer.consume();

        let below_fin = consumer.get_group(0).now_or_never();
        assert!(
            below_fin.is_none(),
            "sequence below fin should block (group could still arrive)"
        );

        let at_or_after_fin = consumer
            .get_group(2)
            .now_or_never()
            .expect("sequence at-or-after fin should resolve")
            .expect("should not error");
        assert!(
            at_or_after_fin.is_none(),
            "sequence at-or-after fin should not exist"
        );
    }

    #[test]
    fn append_group_returns_bounds_exceeded_on_sequence_overflow() {
        let mut producer = Track::new("test").produce();
        {
            // Force the highest sequence to the maximum so the next one overflows.
            let mut state = producer.state.write().ok().unwrap();
            state.max_sequence = Some(u64::MAX);
        }

        let appended = producer.append_group();
        assert!(matches!(appended, Err(Error::BoundsExceeded)));
    }
}