1use crate::{Error, Result, coding};
16
17use super::{Group, GroupConsumer, GroupProducer};
18
19use std::{
20 collections::{HashSet, VecDeque},
21 task::{Poll, ready},
22 time::Duration,
23};
24
/// Age after which a group becomes eligible for eviction (30 seconds).
///
/// Eviction is only checked when a new group is created; the group holding the
/// highest sequence number is always retained regardless of its age.
const MAX_GROUP_AGE: Duration = Duration::from_millis(30_000);
28
/// Static description of a track: its name and a priority value.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Track {
    /// Name identifying the track.
    pub name: String,
    /// Priority value for the track; `Track::new` initializes it to 0.
    pub priority: u8,
}
36
37impl Track {
38 pub fn new<T: Into<String>>(name: T) -> Self {
39 Self {
40 name: name.into(),
41 priority: 0,
42 }
43 }
44
45 pub fn produce(self) -> TrackProducer {
46 TrackProducer::new(self)
47 }
48}
49
50#[derive(Default)]
51struct State {
52 groups: VecDeque<Option<(GroupProducer, tokio::time::Instant)>>,
54 duplicates: HashSet<u64>,
55 offset: usize,
56 max_sequence: Option<u64>,
57 final_sequence: Option<u64>,
58 abort: Option<Error>,
59}
60
61impl State {
62 fn poll_next_group(&self, index: usize, min_sequence: u64) -> Poll<Result<Option<(GroupConsumer, usize)>>> {
66 let start = index.saturating_sub(self.offset);
67 for (i, slot) in self.groups.iter().enumerate().skip(start) {
68 if let Some((group, _)) = slot
69 && group.info.sequence >= min_sequence
70 {
71 return Poll::Ready(Ok(Some((group.consume(), self.offset + i))));
72 }
73 }
74
75 if self.final_sequence.is_some() {
77 Poll::Ready(Ok(None))
78 } else if let Some(err) = &self.abort {
79 Poll::Ready(Err(err.clone()))
80 } else {
81 Poll::Pending
82 }
83 }
84
85 fn poll_get_group(&self, sequence: u64) -> Poll<Result<Option<GroupConsumer>>> {
86 for (group, _) in self.groups.iter().flatten() {
88 if group.info.sequence == sequence {
89 return Poll::Ready(Ok(Some(group.consume())));
90 }
91 }
92
93 if let Some(fin) = self.final_sequence
95 && sequence >= fin
96 {
97 return Poll::Ready(Ok(None));
98 }
99
100 if let Some(err) = &self.abort {
101 return Poll::Ready(Err(err.clone()));
102 }
103
104 Poll::Pending
105 }
106
107 fn poll_closed(&self) -> Poll<Result<()>> {
108 if self.final_sequence.is_some() {
109 Poll::Ready(Ok(()))
110 } else if let Some(err) = &self.abort {
111 Poll::Ready(Err(err.clone()))
112 } else {
113 Poll::Pending
114 }
115 }
116
117 fn evict_expired(&mut self, now: tokio::time::Instant) {
124 for slot in self.groups.iter_mut() {
125 let Some((group, created_at)) = slot else { continue };
126
127 if Some(group.info.sequence) == self.max_sequence {
128 continue;
129 }
130
131 if now.duration_since(*created_at) <= MAX_GROUP_AGE {
132 break;
133 }
134
135 self.duplicates.remove(&group.info.sequence);
136 *slot = None;
137 }
138
139 while let Some(None) = self.groups.front() {
141 self.groups.pop_front();
142 self.offset += 1;
143 }
144 }
145
146 fn poll_finished(&self) -> Poll<Result<u64>> {
147 if let Some(fin) = self.final_sequence {
148 Poll::Ready(Ok(fin))
149 } else if let Some(err) = &self.abort {
150 Poll::Ready(Err(err.clone()))
151 } else {
152 Poll::Pending
153 }
154 }
155}
156
157pub struct TrackProducer {
159 pub info: Track,
160 state: conducer::Producer<State>,
161}
162
163impl TrackProducer {
164 pub fn new(info: Track) -> Self {
165 Self {
166 info,
167 state: conducer::Producer::default(),
168 }
169 }
170
171 pub fn create_group(&mut self, info: Group) -> Result<GroupProducer> {
173 let group = info.produce();
174
175 let mut state = self.modify()?;
176 if let Some(fin) = state.final_sequence
177 && group.info.sequence >= fin
178 {
179 return Err(Error::Closed);
180 }
181
182 if !state.duplicates.insert(group.info.sequence) {
183 return Err(Error::Duplicate);
184 }
185
186 let now = tokio::time::Instant::now();
187 state.max_sequence = Some(state.max_sequence.unwrap_or(0).max(group.info.sequence));
188 state.groups.push_back(Some((group.clone(), now)));
189 state.evict_expired(now);
190
191 Ok(group)
192 }
193
194 pub fn append_group(&mut self) -> Result<GroupProducer> {
196 let mut state = self.modify()?;
197 let sequence = match state.max_sequence {
198 Some(s) => s.checked_add(1).ok_or(coding::BoundsExceeded)?,
199 None => 0,
200 };
201 if let Some(fin) = state.final_sequence
202 && sequence >= fin
203 {
204 return Err(Error::Closed);
205 }
206
207 let group = Group { sequence }.produce();
208
209 let now = tokio::time::Instant::now();
210 state.duplicates.insert(sequence);
211 state.max_sequence = Some(sequence);
212 state.groups.push_back(Some((group.clone(), now)));
213 state.evict_expired(now);
214
215 Ok(group)
216 }
217
218 pub fn write_frame<B: Into<bytes::Bytes>>(&mut self, frame: B) -> Result<()> {
220 let mut group = self.append_group()?;
221 group.write_frame(frame.into())?;
222 group.finish()?;
223 Ok(())
224 }
225
226 pub fn finish(&mut self) -> Result<()> {
232 let mut state = self.modify()?;
233 if state.final_sequence.is_some() {
234 return Err(Error::Closed);
235 }
236 state.final_sequence = Some(match state.max_sequence {
237 Some(max) => max.checked_add(1).ok_or(coding::BoundsExceeded)?,
238 None => 0,
239 });
240 Ok(())
241 }
242
243 #[deprecated(note = "use finish() or finish_at(sequence) instead")]
248 pub fn close(&mut self) -> Result<()> {
249 self.finish()
250 }
251
252 pub fn finish_at(&mut self, sequence: u64) -> Result<()> {
259 let mut state = self.modify()?;
260 let max = state.max_sequence.ok_or(Error::Closed)?;
261 if state.final_sequence.is_some() || sequence != max {
262 return Err(Error::Closed);
263 }
264 state.final_sequence = Some(max.checked_add(1).ok_or(coding::BoundsExceeded)?);
265 Ok(())
266 }
267
268 pub fn abort(&mut self, err: Error) -> Result<()> {
270 let mut guard = self.modify()?;
271
272 for (group, _) in guard.groups.iter_mut().flatten() {
274 group.abort(err.clone()).ok();
276 }
277
278 guard.abort = Some(err);
279 guard.close();
280 Ok(())
281 }
282
283 pub fn consume(&self) -> TrackConsumer {
285 TrackConsumer {
286 info: self.info.clone(),
287 state: self.state.consume(),
288 index: 0,
289 min_sequence: 0,
290 }
291 }
292
293 pub async fn unused(&self) -> Result<()> {
295 self.state
296 .unused()
297 .await
298 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
299 }
300
301 pub async fn used(&self) -> Result<()> {
303 self.state
304 .used()
305 .await
306 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
307 }
308
309 pub fn is_closed(&self) -> bool {
311 self.state.read().is_closed()
312 }
313
314 pub fn is_clone(&self, other: &Self) -> bool {
316 self.state.same_channel(&other.state)
317 }
318
319 pub(crate) fn weak(&self) -> TrackWeak {
321 TrackWeak {
322 info: self.info.clone(),
323 state: self.state.weak(),
324 }
325 }
326
327 fn modify(&self) -> Result<conducer::Mut<'_, State>> {
328 self.state
329 .write()
330 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
331 }
332}
333
334impl Clone for TrackProducer {
335 fn clone(&self) -> Self {
336 Self {
337 info: self.info.clone(),
338 state: self.state.clone(),
339 }
340 }
341}
342
343impl From<Track> for TrackProducer {
344 fn from(info: Track) -> Self {
345 TrackProducer::new(info)
346 }
347}
348
349#[derive(Clone)]
351pub(crate) struct TrackWeak {
352 pub info: Track,
353 state: conducer::Weak<State>,
354}
355
356impl TrackWeak {
357 pub fn abort(&self, err: Error) {
358 let Ok(mut guard) = self.state.write() else { return };
359
360 for (group, _) in guard.groups.iter_mut().flatten() {
362 group.abort(err.clone()).ok();
363 }
364
365 guard.abort = Some(err);
366 guard.close();
367 }
368
369 pub fn is_closed(&self) -> bool {
370 self.state.is_closed()
371 }
372
373 pub fn consume(&self) -> TrackConsumer {
374 TrackConsumer {
375 info: self.info.clone(),
376 state: self.state.consume(),
377 index: 0,
378 min_sequence: 0,
379 }
380 }
381
382 pub async fn unused(&self) -> crate::Result<()> {
383 self.state
384 .unused()
385 .await
386 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
387 }
388
389 pub fn is_clone(&self, other: &Self) -> bool {
390 self.state.same_channel(&other.state)
391 }
392}
393
394#[derive(Clone)]
396pub struct TrackConsumer {
397 pub info: Track,
398 state: conducer::Consumer<State>,
399 index: usize,
400
401 min_sequence: u64,
402}
403
404impl TrackConsumer {
405 fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
407 where
408 F: Fn(&conducer::Ref<'_, State>) -> Poll<Result<R>>,
409 {
410 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
411 Ok(res) => res,
412 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
414 })
415 }
416
417 pub fn poll_next_group(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<GroupConsumer>>> {
424 let Some((consumer, found_index)) =
425 ready!(self.poll(waiter, |state| state.poll_next_group(self.index, self.min_sequence))?)
426 else {
427 return Poll::Ready(Ok(None));
428 };
429
430 self.index = found_index + 1;
431 Poll::Ready(Ok(Some(consumer)))
432 }
433
434 pub async fn next_group(&mut self) -> Result<Option<GroupConsumer>> {
438 conducer::wait(|waiter| self.poll_next_group(waiter)).await
439 }
440
441 pub fn poll_get_group(&self, waiter: &conducer::Waiter, sequence: u64) -> Poll<Result<Option<GroupConsumer>>> {
443 self.poll(waiter, |state| state.poll_get_group(sequence))
444 }
445
446 pub async fn get_group(&self, sequence: u64) -> Result<Option<GroupConsumer>> {
450 conducer::wait(|waiter| self.poll_get_group(waiter, sequence)).await
451 }
452
453 pub fn poll_closed(&self, waiter: &conducer::Waiter) -> Poll<Result<()>> {
455 self.poll(waiter, |state| state.poll_closed())
456 }
457
458 pub async fn closed(&self) -> Result<()> {
462 conducer::wait(|waiter| self.poll_closed(waiter)).await
463 }
464
465 pub fn is_clone(&self, other: &Self) -> bool {
466 self.state.same_channel(&other.state)
467 }
468
469 pub fn poll_finished(&mut self, waiter: &conducer::Waiter) -> Poll<Result<u64>> {
471 self.poll(waiter, |state| state.poll_finished())
472 }
473
474 pub async fn finished(&mut self) -> Result<u64> {
476 conducer::wait(|waiter| self.poll_finished(waiter)).await
477 }
478
479 pub fn start_at(&mut self, sequence: u64) {
481 self.min_sequence = sequence;
482 }
483
484 pub fn latest(&self) -> Option<u64> {
486 self.state.read().max_sequence
487 }
488
489 pub fn produce(&self) -> Result<TrackProducer> {
504 let state = self
505 .state
506 .produce()
507 .ok_or_else(|| self.state.read().abort.clone().unwrap_or(Error::Dropped))?;
508 Ok(TrackProducer {
509 info: self.info.clone(),
510 state,
511 })
512 }
513}
514
515#[cfg(test)]
516use futures::FutureExt;
517
#[cfg(test)]
impl TrackConsumer {
    /// Returns the next group. Panics if it is not immediately available, if
    /// the track errored, or if the track has ended.
    pub fn assert_group(&mut self) -> GroupConsumer {
        let polled = self.next_group().now_or_never();
        let outcome = polled.expect("group would have blocked");
        let next = outcome.expect("would have errored");
        next.expect("track was closed")
    }

    /// Panics unless `next_group` would block right now.
    pub fn assert_no_group(&mut self) {
        let would_block = self.next_group().now_or_never().is_none();
        assert!(would_block, "next group would not have blocked");
    }

    /// Panics if `closed` resolves immediately.
    pub fn assert_not_closed(&self) {
        let pending = self.closed().now_or_never().is_none();
        assert!(pending, "should not be closed");
    }

    /// Panics unless `closed` resolves immediately, with either outcome.
    pub fn assert_closed(&self) {
        let resolved = self.closed().now_or_never().is_some();
        assert!(resolved, "should be closed");
    }

    /// Panics unless `closed` resolves immediately with an error.
    pub fn assert_error(&self) {
        let outcome = self.closed().now_or_never().expect("should not block");
        assert!(outcome.is_err(), "should be error");
    }

    /// Panics unless `other` reads from the same channel.
    pub fn assert_is_clone(&self, other: &Self) {
        let same = self.is_clone(other);
        assert!(same, "should be clone");
    }

    /// Panics if `other` reads from the same channel.
    pub fn assert_not_clone(&self, other: &Self) {
        let same = self.is_clone(other);
        assert!(!same, "should not be clone");
    }
}
559
#[cfg(test)]
mod test {
    use super::*;

    /// Counts slots that still hold a group; evicted slots are `None`.
    fn count_live(state: &State) -> usize {
        state.groups.iter().filter(|slot| slot.is_some()).count()
    }

    /// Sequence of the first held group; panics when none are held.
    fn oldest_live_sequence(state: &State) -> u64 {
        let (group, _) = state.groups.iter().find_map(Option::as_ref).unwrap();
        group.info.sequence
    }

    /// Advances the paused tokio clock just past `MAX_GROUP_AGE`.
    async fn age_past_limit() {
        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;
    }

    #[tokio::test]
    async fn evict_expired_groups() {
        tokio::time::pause();
        let mut producer = Track::new("test").produce();

        for _ in 0..3 {
            producer.append_group().unwrap();
        }
        {
            let state = producer.state.read();
            assert_eq!(count_live(&state), 3);
            assert_eq!(state.offset, 0);
        }

        age_past_limit().await;
        producer.append_group().unwrap();

        // Groups 0..3 expired; only the newest (max sequence) group survives.
        let state = producer.state.read();
        assert_eq!(count_live(&state), 1);
        assert_eq!(oldest_live_sequence(&state), 3);
        assert_eq!(state.offset, 3);
        for evicted in 0u64..3 {
            assert!(!state.duplicates.contains(&evicted));
        }
        assert!(state.duplicates.contains(&3));
    }

    #[tokio::test]
    async fn evict_keeps_max_sequence() {
        tokio::time::pause();
        let mut producer = Track::new("test").produce();

        producer.append_group().unwrap();
        age_past_limit().await;
        producer.append_group().unwrap();

        let state = producer.state.read();
        assert_eq!(count_live(&state), 1);
        assert_eq!(oldest_live_sequence(&state), 1);
        assert_eq!(state.offset, 1);
    }

    #[tokio::test]
    async fn no_eviction_when_fresh() {
        tokio::time::pause();
        let mut producer = Track::new("test").produce();

        for _ in 0..3 {
            producer.append_group().unwrap();
        }

        let state = producer.state.read();
        assert_eq!(count_live(&state), 3);
        assert_eq!(state.offset, 0);
    }

    #[tokio::test]
    async fn consumer_skips_evicted_groups() {
        tokio::time::pause();
        let mut producer = Track::new("test").produce();

        producer.append_group().unwrap();
        let mut consumer = producer.consume();

        age_past_limit().await;
        producer.append_group().unwrap();

        let group = consumer.assert_group();
        assert_eq!(group.info.sequence, 1);
    }

    #[tokio::test]
    async fn out_of_order_max_sequence_at_front() {
        tokio::time::pause();
        let mut producer = Track::new("test").produce();

        // The highest sequence arrives first.
        for sequence in [5, 3, 4] {
            producer.create_group(Group { sequence }).unwrap();
        }
        assert_eq!(producer.state.read().max_sequence, Some(5));

        age_past_limit().await;
        producer.append_group().unwrap();

        let state = producer.state.read();
        assert_eq!(count_live(&state), 1);
        assert_eq!(oldest_live_sequence(&state), 6);
        for evicted in [3u64, 4, 5] {
            assert!(!state.duplicates.contains(&evicted));
        }
        assert!(state.duplicates.contains(&6));
    }

    #[tokio::test]
    async fn max_sequence_at_front_blocks_trim() {
        tokio::time::pause();
        let mut producer = Track::new("test").produce();

        producer.create_group(Group { sequence: 5 }).unwrap();
        age_past_limit().await;
        producer.create_group(Group { sequence: 3 }).unwrap();
        {
            // Group 5 is expired but holds the max sequence, so it is kept.
            let state = producer.state.read();
            assert_eq!(count_live(&state), 2);
            assert_eq!(state.offset, 0);
        }

        age_past_limit().await;
        producer.create_group(Group { sequence: 2 }).unwrap();
        {
            // Group 3 was evicted, but group 5 still occupies the front slot,
            // so nothing can be trimmed.
            let state = producer.state.read();
            assert_eq!(count_live(&state), 2);
            assert_eq!(state.offset, 0);
            assert!(state.duplicates.contains(&5));
            assert!(!state.duplicates.contains(&3));
            assert!(state.duplicates.contains(&2));
        }

        let mut consumer = producer.consume();
        assert_eq!(consumer.assert_group().info.sequence, 5);
    }

    #[test]
    fn append_finish_cannot_be_rewritten() {
        let mut producer = Track::new("test").produce();

        assert!(producer.finish().is_ok());
        // A finished track can be neither finished again nor extended.
        assert!(producer.finish().is_err());
        assert!(producer.append_group().is_err());
    }

    #[test]
    fn finish_after_groups() {
        let mut producer = Track::new("test").produce();
        producer.append_group().unwrap();

        assert!(producer.finish().is_ok());
        assert!(producer.finish().is_err());
        assert!(producer.append_group().is_err());
    }

    #[test]
    fn insert_finish_validates_sequence_and_freezes_to_max() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 5 }).unwrap();

        for wrong in [4, 10] {
            assert!(producer.finish_at(wrong).is_err());
        }
        assert!(producer.finish_at(5).is_ok());
        assert_eq!(producer.state.read().final_sequence, Some(6));

        assert!(producer.finish_at(5).is_err());
        // A missing sequence below the final one may still be created;
        // the already-held sequence 5 is rejected.
        assert!(producer.create_group(Group { sequence: 4 }).is_ok());
        assert!(producer.create_group(Group { sequence: 5 }).is_err());
    }

    #[tokio::test]
    async fn next_group_finishes_without_waiting_for_gaps() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 1 }).unwrap();
        producer.finish_at(1).unwrap();

        let mut consumer = producer.consume();
        assert_eq!(consumer.assert_group().info.sequence, 1);

        // Sequence 0 never arrived, yet the finished track ends immediately.
        let polled = consumer.next_group().now_or_never();
        let done = polled.expect("should not block").expect("would have errored");
        assert!(done.is_none(), "track should finish without waiting for gaps");
    }

    #[tokio::test]
    async fn get_group_finishes_without_waiting_for_gaps() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 1 }).unwrap();
        producer.finish_at(1).unwrap();

        let consumer = producer.consume();

        let below_final = consumer.get_group(0).now_or_never();
        assert!(
            below_final.is_none(),
            "sequence below fin should block (group could still arrive)"
        );

        let at_or_after_final = consumer
            .get_group(2)
            .now_or_never()
            .expect("sequence at-or-after fin should resolve")
            .expect("should not error");
        assert!(
            at_or_after_final.is_none(),
            "sequence at-or-after fin should not exist"
        );
    }

    #[test]
    fn append_group_returns_bounds_exceeded_on_sequence_overflow() {
        let mut producer = Track::new("test").produce();
        producer.state.write().ok().unwrap().max_sequence = Some(u64::MAX);

        let outcome = producer.append_group();
        assert!(matches!(outcome, Err(Error::BoundsExceeded(_))));
    }

    #[tokio::test]
    async fn consumer_produce() {
        let mut producer = Track::new("test").produce();
        producer.append_group().unwrap();

        let consumer = producer.consume();
        let upgraded = consumer.produce().expect("should produce");
        assert!(upgraded.is_clone(&producer), "should be the same track");

        // A group appended through the upgraded handle lands on the same track.
        upgraded.clone().append_group().unwrap();

        let mut reader = producer.consume();
        reader.assert_group();
        reader.assert_group();
    }

    #[tokio::test]
    async fn consumer_produce_after_drop() {
        let producer = Track::new("test").produce();
        let consumer = producer.consume();
        drop(producer);

        let upgraded = consumer.produce();
        assert!(matches!(upgraded, Err(Error::Dropped)), "expected Dropped");
    }

    #[tokio::test]
    async fn consumer_produce_after_abort() {
        let mut producer = Track::new("test").produce();
        let consumer = producer.consume();
        producer.abort(Error::Cancel).unwrap();
        drop(producer);

        let upgraded = consumer.produce();
        assert!(matches!(upgraded, Err(Error::Cancel)), "expected Cancel");
    }

    #[tokio::test]
    async fn consumer_produce_keeps_alive() {
        let producer = Track::new("test").produce();
        let consumer = producer.consume();
        let upgraded = consumer.produce().expect("should produce");
        drop(producer);

        // The upgraded producer alone keeps the track open.
        consumer.assert_not_closed();

        drop(upgraded);
        consumer.assert_closed();
    }
}