1use crate::{Error, Result};
16
17use super::{Group, GroupConsumer, GroupProducer};
18
19use std::{
20 collections::{HashSet, VecDeque},
21 task::{Poll, ready},
22 time::Duration,
23};
24
/// How long a group stays cached after insertion before it may be evicted.
///
/// Eviction only runs when a new group is inserted, and the group holding the
/// track's highest sequence number is never evicted regardless of its age.
const MAX_GROUP_AGE: Duration = Duration::from_millis(30 * 1_000);
28
/// Static description of a track: a name plus a priority value.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Track {
    /// Name identifying the track.
    pub name: String,
    /// Priority value attached to the track (0 when built with `Track::new`).
    pub priority: u8,
}
36
37impl Track {
38 pub fn new<T: Into<String>>(name: T) -> Self {
39 Self {
40 name: name.into(),
41 priority: 0,
42 }
43 }
44
45 pub fn produce(self) -> TrackProducer {
46 TrackProducer::new(self)
47 }
48}
49
/// Shared track state, written by producers and read by consumers through `conducer`.
#[derive(Default)]
struct State {
    /// Groups in insertion order, each paired with the instant it was inserted.
    /// Evicted groups leave a `None` hole; holes at the front are popped.
    groups: VecDeque<Option<(GroupProducer, tokio::time::Instant)>>,
    /// Sequence numbers inserted and not yet evicted; used to reject duplicates.
    duplicates: HashSet<u64>,
    /// Number of slots popped from the front of `groups`. Consumer indices are
    /// absolute, so `index - offset` is the position inside `groups`.
    offset: usize,
    /// Highest sequence number inserted so far, if any.
    max_sequence: Option<u64>,
    /// One past the last sequence once the track is finished; `None` while open.
    final_sequence: Option<u64>,
    /// Error recorded when the track was aborted.
    abort: Option<Error>,
}
60
61impl State {
62 fn poll_next_group(&self, index: usize, min_sequence: u64) -> Poll<Result<Option<(GroupConsumer, usize)>>> {
66 let start = index.saturating_sub(self.offset);
67 for (i, slot) in self.groups.iter().enumerate().skip(start) {
68 if let Some((group, _)) = slot
69 && group.info.sequence >= min_sequence
70 {
71 return Poll::Ready(Ok(Some((group.consume(), self.offset + i))));
72 }
73 }
74
75 if self.final_sequence.is_some() {
77 Poll::Ready(Ok(None))
78 } else if let Some(err) = &self.abort {
79 Poll::Ready(Err(err.clone()))
80 } else {
81 Poll::Pending
82 }
83 }
84
85 fn poll_get_group(&self, sequence: u64) -> Poll<Result<Option<GroupConsumer>>> {
86 for (group, _) in self.groups.iter().flatten() {
88 if group.info.sequence == sequence {
89 return Poll::Ready(Ok(Some(group.consume())));
90 }
91 }
92
93 if let Some(fin) = self.final_sequence
95 && sequence >= fin
96 {
97 return Poll::Ready(Ok(None));
98 }
99
100 if let Some(err) = &self.abort {
101 return Poll::Ready(Err(err.clone()));
102 }
103
104 Poll::Pending
105 }
106
107 fn poll_closed(&self) -> Poll<Result<()>> {
108 if self.final_sequence.is_some() {
109 Poll::Ready(Ok(()))
110 } else if let Some(err) = &self.abort {
111 Poll::Ready(Err(err.clone()))
112 } else {
113 Poll::Pending
114 }
115 }
116
117 fn evict_expired(&mut self, now: tokio::time::Instant) {
124 for slot in self.groups.iter_mut() {
125 let Some((group, created_at)) = slot else { continue };
126
127 if Some(group.info.sequence) == self.max_sequence {
128 continue;
129 }
130
131 if now.duration_since(*created_at) <= MAX_GROUP_AGE {
132 break;
133 }
134
135 self.duplicates.remove(&group.info.sequence);
136 *slot = None;
137 }
138
139 while let Some(None) = self.groups.front() {
141 self.groups.pop_front();
142 self.offset += 1;
143 }
144 }
145
146 fn poll_finished(&self) -> Poll<Result<u64>> {
147 if let Some(fin) = self.final_sequence {
148 Poll::Ready(Ok(fin))
149 } else if let Some(err) = &self.abort {
150 Poll::Ready(Err(err.clone()))
151 } else {
152 Poll::Pending
153 }
154 }
155}
156
/// Write side of a track: creates groups and finishes or aborts the track.
pub struct TrackProducer {
    /// Static description of the track.
    pub info: Track,
    /// Handle to the shared state; consumers are created from it via `consume`.
    state: conducer::Producer<State>,
}
162
163impl TrackProducer {
164 pub fn new(info: Track) -> Self {
165 Self {
166 info,
167 state: conducer::Producer::default(),
168 }
169 }
170
171 pub fn create_group(&mut self, info: Group) -> Result<GroupProducer> {
173 let group = info.produce();
174
175 let mut state = self.modify()?;
176 if let Some(fin) = state.final_sequence
177 && group.info.sequence >= fin
178 {
179 return Err(Error::Closed);
180 }
181
182 if !state.duplicates.insert(group.info.sequence) {
183 return Err(Error::Duplicate);
184 }
185
186 let now = tokio::time::Instant::now();
187 state.max_sequence = Some(state.max_sequence.unwrap_or(0).max(group.info.sequence));
188 state.groups.push_back(Some((group.clone(), now)));
189 state.evict_expired(now);
190
191 Ok(group)
192 }
193
194 pub fn append_group(&mut self) -> Result<GroupProducer> {
196 let mut state = self.modify()?;
197 let sequence = match state.max_sequence {
198 Some(s) => s.checked_add(1).ok_or(Error::BoundsExceeded)?,
199 None => 0,
200 };
201 if let Some(fin) = state.final_sequence
202 && sequence >= fin
203 {
204 return Err(Error::Closed);
205 }
206
207 let group = Group { sequence }.produce();
208
209 let now = tokio::time::Instant::now();
210 state.duplicates.insert(sequence);
211 state.max_sequence = Some(sequence);
212 state.groups.push_back(Some((group.clone(), now)));
213 state.evict_expired(now);
214
215 Ok(group)
216 }
217
218 pub fn write_frame<B: Into<bytes::Bytes>>(&mut self, frame: B) -> Result<()> {
220 let mut group = self.append_group()?;
221 group.write_frame(frame.into())?;
222 group.finish()?;
223 Ok(())
224 }
225
226 pub fn finish(&mut self) -> Result<()> {
232 let mut state = self.modify()?;
233 if state.final_sequence.is_some() {
234 return Err(Error::Closed);
235 }
236 state.final_sequence = Some(match state.max_sequence {
237 Some(max) => max.checked_add(1).ok_or(Error::BoundsExceeded)?,
238 None => 0,
239 });
240 Ok(())
241 }
242
243 #[deprecated(note = "use finish() or finish_at(sequence) instead")]
248 pub fn close(&mut self) -> Result<()> {
249 self.finish()
250 }
251
252 pub fn finish_at(&mut self, sequence: u64) -> Result<()> {
259 let mut state = self.modify()?;
260 let max = state.max_sequence.ok_or(Error::Closed)?;
261 if state.final_sequence.is_some() || sequence != max {
262 return Err(Error::Closed);
263 }
264 state.final_sequence = Some(max.checked_add(1).ok_or(Error::BoundsExceeded)?);
265 Ok(())
266 }
267
268 pub fn abort(&mut self, err: Error) -> Result<()> {
270 let mut guard = self.modify()?;
271
272 for (group, _) in guard.groups.iter_mut().flatten() {
274 group.abort(err.clone()).ok();
276 }
277
278 guard.abort = Some(err);
279 guard.close();
280 Ok(())
281 }
282
283 pub fn consume(&self) -> TrackConsumer {
285 TrackConsumer {
286 info: self.info.clone(),
287 state: self.state.consume(),
288 index: 0,
289 min_sequence: 0,
290 }
291 }
292
293 pub async fn unused(&self) -> Result<()> {
295 self.state
296 .unused()
297 .await
298 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
299 }
300
301 pub fn is_closed(&self) -> bool {
303 self.state.read().is_closed()
304 }
305
306 pub fn is_clone(&self, other: &Self) -> bool {
308 self.state.same_channel(&other.state)
309 }
310
311 pub(crate) fn weak(&self) -> TrackWeak {
313 TrackWeak {
314 info: self.info.clone(),
315 state: self.state.weak(),
316 }
317 }
318
319 fn modify(&self) -> Result<conducer::Mut<'_, State>> {
320 self.state
321 .write()
322 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
323 }
324}
325
326impl Clone for TrackProducer {
327 fn clone(&self) -> Self {
328 Self {
329 info: self.info.clone(),
330 state: self.state.clone(),
331 }
332 }
333}
334
335impl From<Track> for TrackProducer {
336 fn from(info: Track) -> Self {
337 TrackProducer::new(info)
338 }
339}
340
/// Crate-internal weak handle to a track's shared state (created by `TrackProducer::weak`).
///
/// NOTE(review): presumably does not keep the producer side alive — confirm against
/// `conducer::Weak`.
#[derive(Clone)]
pub(crate) struct TrackWeak {
    /// Static description of the track.
    pub info: Track,
    /// Weak handle to the shared state.
    state: conducer::Weak<State>,
}
347
348impl TrackWeak {
349 pub fn abort(&self, err: Error) {
350 let Ok(mut guard) = self.state.write() else { return };
351
352 for (group, _) in guard.groups.iter_mut().flatten() {
354 group.abort(err.clone()).ok();
355 }
356
357 guard.abort = Some(err);
358 guard.close();
359 }
360
361 pub fn is_closed(&self) -> bool {
362 self.state.is_closed()
363 }
364
365 pub fn consume(&self) -> TrackConsumer {
366 TrackConsumer {
367 info: self.info.clone(),
368 state: self.state.consume(),
369 index: 0,
370 min_sequence: 0,
371 }
372 }
373
374 pub async fn unused(&self) -> crate::Result<()> {
375 self.state
376 .unused()
377 .await
378 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
379 }
380
381 pub fn is_clone(&self, other: &Self) -> bool {
382 self.state.same_channel(&other.state)
383 }
384}
385
/// Read side of a track.
///
/// Cloning copies the current read position (`index`, `min_sequence`); each clone
/// then advances independently.
#[derive(Clone)]
pub struct TrackConsumer {
    /// Static description of the track.
    pub info: Track,
    /// Handle to the shared state.
    state: conducer::Consumer<State>,
    /// Absolute position (popped slots included) of the next slot to inspect.
    index: usize,

    /// Groups with a lower sequence are skipped by `next_group`; set via `start_at`.
    min_sequence: u64,
}
395
396impl TrackConsumer {
397 fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
399 where
400 F: Fn(&conducer::Ref<'_, State>) -> Poll<Result<R>>,
401 {
402 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
403 Ok(res) => res,
404 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
406 })
407 }
408
409 pub fn poll_next_group(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<GroupConsumer>>> {
416 let Some((consumer, found_index)) =
417 ready!(self.poll(waiter, |state| state.poll_next_group(self.index, self.min_sequence))?)
418 else {
419 return Poll::Ready(Ok(None));
420 };
421
422 self.index = found_index + 1;
423 Poll::Ready(Ok(Some(consumer)))
424 }
425
426 pub async fn next_group(&mut self) -> Result<Option<GroupConsumer>> {
430 conducer::wait(|waiter| self.poll_next_group(waiter)).await
431 }
432
433 pub fn poll_get_group(&self, waiter: &conducer::Waiter, sequence: u64) -> Poll<Result<Option<GroupConsumer>>> {
435 self.poll(waiter, |state| state.poll_get_group(sequence))
436 }
437
438 pub async fn get_group(&self, sequence: u64) -> Result<Option<GroupConsumer>> {
442 conducer::wait(|waiter| self.poll_get_group(waiter, sequence)).await
443 }
444
445 pub fn poll_closed(&self, waiter: &conducer::Waiter) -> Poll<Result<()>> {
447 self.poll(waiter, |state| state.poll_closed())
448 }
449
450 pub async fn closed(&self) -> Result<()> {
454 conducer::wait(|waiter| self.poll_closed(waiter)).await
455 }
456
457 pub fn is_clone(&self, other: &Self) -> bool {
458 self.state.same_channel(&other.state)
459 }
460
461 pub fn poll_finished(&mut self, waiter: &conducer::Waiter) -> Poll<Result<u64>> {
463 self.poll(waiter, |state| state.poll_finished())
464 }
465
466 pub async fn finished(&mut self) -> Result<u64> {
468 conducer::wait(|waiter| self.poll_finished(waiter)).await
469 }
470
471 pub fn start_at(&mut self, sequence: u64) {
473 self.min_sequence = sequence;
474 }
475
476 pub fn latest(&self) -> Option<u64> {
478 self.state.read().max_sequence
479 }
480}
481
482#[cfg(test)]
483use futures::FutureExt;
484
#[cfg(test)]
impl TrackConsumer {
    /// Returns the next group, panicking if it would block, error, or the track ended.
    pub fn assert_group(&mut self) -> GroupConsumer {
        let polled = self.next_group().now_or_never().expect("group would have blocked");
        let next = polled.expect("would have errored");
        next.expect("track was closed")
    }

    /// Panics unless `next_group` would block.
    pub fn assert_no_group(&mut self) {
        let resolved = self.next_group().now_or_never().is_some();
        assert!(!resolved, "next group would not have blocked");
    }

    /// Panics if `closed` resolves immediately.
    pub fn assert_not_closed(&self) {
        let resolved = self.closed().now_or_never();
        assert!(resolved.is_none(), "should not be closed");
    }

    /// Panics unless `closed` resolves immediately (with either outcome).
    pub fn assert_closed(&self) {
        let resolved = self.closed().now_or_never();
        assert!(resolved.is_some(), "should be closed");
    }

    /// Panics unless `closed` resolves immediately with an error.
    pub fn assert_error(&self) {
        let outcome = self.closed().now_or_never().expect("should not block");
        assert!(outcome.is_err(), "should be error");
    }

    /// Panics unless `other` reads from the same channel.
    pub fn assert_is_clone(&self, other: &Self) {
        assert!(self.is_clone(other), "should be clone");
    }

    /// Panics if `other` reads from the same channel.
    pub fn assert_not_clone(&self, other: &Self) {
        let same = self.is_clone(other);
        assert!(!same, "should not be clone");
    }
}
526
#[cfg(test)]
mod test {
    use super::*;

    /// Number of non-evicted slots in the queue.
    fn live_groups(state: &State) -> usize {
        state.groups.iter().flatten().count()
    }

    /// Sequence of the first non-evicted slot; panics if none exists.
    fn first_live_sequence(state: &State) -> u64 {
        state.groups.iter().flatten().next().unwrap().0.info.sequence
    }

    #[tokio::test]
    async fn evict_expired_groups() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();

        producer.append_group().unwrap();
        producer.append_group().unwrap();
        producer.append_group().unwrap();
        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 3);
            assert_eq!(state.offset, 0);
        }

        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;

        producer.append_group().unwrap();
        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 1);
            assert_eq!(first_live_sequence(&state), 3);
            assert_eq!(state.offset, 3);
            assert!(!state.duplicates.contains(&0));
            assert!(!state.duplicates.contains(&1));
            assert!(!state.duplicates.contains(&2));
            assert!(state.duplicates.contains(&3));
        }
    }

    #[tokio::test]
    async fn evict_keeps_max_sequence() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();
        producer.append_group().unwrap();
        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;

        producer.append_group().unwrap();
        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 1);
            assert_eq!(first_live_sequence(&state), 1);
            assert_eq!(state.offset, 1);
        }
    }

    #[tokio::test]
    async fn no_eviction_when_fresh() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();
        producer.append_group().unwrap();
        producer.append_group().unwrap();
        producer.append_group().unwrap();
        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 3);
            assert_eq!(state.offset, 0);
        }
    }

    #[tokio::test]
    async fn consumer_skips_evicted_groups() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();
        producer.append_group().unwrap();
        let mut consumer = producer.consume();

        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;
        producer.append_group().unwrap();
        let group = consumer.assert_group();
        assert_eq!(group.info.sequence, 1);
    }

    #[tokio::test]
    async fn out_of_order_max_sequence_at_front() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();

        producer.create_group(Group { sequence: 5 }).unwrap();
        producer.create_group(Group { sequence: 3 }).unwrap();
        producer.create_group(Group { sequence: 4 }).unwrap();

        {
            let state = producer.state.read();
            assert_eq!(state.max_sequence, Some(5));
        }

        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;

        producer.append_group().unwrap();
        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 1);
            assert_eq!(first_live_sequence(&state), 6);
            assert!(!state.duplicates.contains(&3));
            assert!(!state.duplicates.contains(&4));
            assert!(!state.duplicates.contains(&5));
            assert!(state.duplicates.contains(&6));
        }
    }

    #[tokio::test]
    async fn max_sequence_at_front_blocks_trim() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();

        producer.create_group(Group { sequence: 5 }).unwrap();

        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;

        producer.create_group(Group { sequence: 3 }).unwrap();

        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 2);
            assert_eq!(state.offset, 0);
        }

        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;

        producer.create_group(Group { sequence: 2 }).unwrap();

        {
            let state = producer.state.read();
            assert_eq!(live_groups(&state), 2);
            assert_eq!(state.offset, 0);
            assert!(state.duplicates.contains(&5));
            assert!(!state.duplicates.contains(&3));
            assert!(state.duplicates.contains(&2));
        }

        let mut consumer = producer.consume();
        let group = consumer.assert_group();
        assert_eq!(group.info.sequence, 5);
    }

    #[test]
    fn append_finish_cannot_be_rewritten() {
        let mut producer = Track::new("test").produce();

        assert!(producer.finish().is_ok());
        assert!(producer.finish().is_err());
        assert!(producer.append_group().is_err());
    }

    #[test]
    fn finish_after_groups() {
        let mut producer = Track::new("test").produce();

        producer.append_group().unwrap();
        assert!(producer.finish().is_ok());
        assert!(producer.finish().is_err());
        assert!(producer.append_group().is_err());
    }

    #[test]
    fn insert_finish_validates_sequence_and_freezes_to_max() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 5 }).unwrap();

        assert!(producer.finish_at(4).is_err());
        assert!(producer.finish_at(10).is_err());
        assert!(producer.finish_at(5).is_ok());

        {
            let state = producer.state.read();
            assert_eq!(state.final_sequence, Some(6));
        }

        assert!(producer.finish_at(5).is_err());
        assert!(producer.create_group(Group { sequence: 4 }).is_ok());
        assert!(producer.create_group(Group { sequence: 5 }).is_err());
    }

    #[tokio::test]
    async fn next_group_finishes_without_waiting_for_gaps() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 1 }).unwrap();
        producer.finish_at(1).unwrap();

        let mut consumer = producer.consume();
        assert_eq!(consumer.assert_group().info.sequence, 1);

        let done = consumer
            .next_group()
            .now_or_never()
            .expect("should not block")
            .expect("would have errored");
        assert!(done.is_none(), "track should finish without waiting for gaps");
    }

    #[tokio::test]
    async fn get_group_finishes_without_waiting_for_gaps() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 1 }).unwrap();
        producer.finish_at(1).unwrap();

        let consumer = producer.consume();
        assert!(
            consumer.get_group(0).now_or_never().is_none(),
            "sequence below fin should block (group could still arrive)"
        );
        assert!(
            consumer
                .get_group(2)
                .now_or_never()
                .expect("sequence at-or-after fin should resolve")
                .expect("should not error")
                .is_none(),
            "sequence at-or-after fin should not exist"
        );
    }

    #[test]
    fn append_group_returns_bounds_exceeded_on_sequence_overflow() {
        let mut producer = Track::new("test").produce();
        {
            let mut state = producer.state.write().ok().unwrap();
            state.max_sequence = Some(u64::MAX);
        }

        assert!(matches!(producer.append_group(), Err(Error::BoundsExceeded)));
    }

    /// `create_group` must reject a sequence that is still retained.
    #[test]
    fn create_group_rejects_retained_duplicate() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 1 }).unwrap();

        assert!(matches!(
            producer.create_group(Group { sequence: 1 }),
            Err(Error::Duplicate)
        ));
        // A different sequence is still accepted after the rejection.
        assert!(producer.create_group(Group { sequence: 2 }).is_ok());
    }

    /// Aborting surfaces the error to consumers and rejects further writes.
    #[test]
    fn abort_propagates_to_consumers_and_blocks_writes() {
        let mut producer = Track::new("test").produce();
        let consumer = producer.consume();
        consumer.assert_not_closed();

        producer.abort(Error::Closed).unwrap();

        consumer.assert_error();
        assert!(producer.append_group().is_err());
    }

    /// `start_at` skips lower sequences; `latest` reports the highest sequence.
    #[test]
    fn start_at_skips_lower_sequences() {
        let mut producer = Track::new("test").produce();
        for _ in 0..3 {
            producer.append_group().unwrap();
        }

        let mut consumer = producer.consume();
        assert_eq!(consumer.latest(), Some(2));

        consumer.start_at(2);
        assert_eq!(consumer.assert_group().info.sequence, 2);
        consumer.assert_no_group();
    }
}