1use crate::{Error, Result};
16
17use super::state::{Consumer, Producer, Weak};
18use super::waiter::waiter_fn;
19use super::{Group, GroupConsumer, GroupProducer};
20
21use std::{
22 collections::{HashSet, VecDeque},
23 task::Poll,
24 time::Duration,
25};
26
/// Age after which `State::evict_expired` drops a group (30 seconds).
/// The group holding the track's highest sequence is exempt from this limit.
const MAX_GROUP_AGE: Duration = Duration::from_millis(30_000);
30
/// Static description of a track: its name and priority.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Track {
    /// The track's name.
    pub name: String,
    /// Priority value; `Track::new` sets it to 0.
    /// NOTE(review): whether higher or lower values win is not visible here — confirm.
    pub priority: u8,
}
38
39impl Track {
40 pub fn new<T: Into<String>>(name: T) -> Self {
41 Self {
42 name: name.into(),
43 priority: 0,
44 }
45 }
46
47 pub fn produce(self) -> TrackProducer {
48 TrackProducer::new(self)
49 }
50}
51
/// Mutable bookkeeping for one track, shared through the `Producer` /
/// `Consumer` / `Weak` handles.
#[derive(Default)]
struct State {
    /// Group slots in push order, each paired with its creation instant.
    /// `None` marks a slot whose group was evicted by `evict_expired`.
    groups: VecDeque<Option<(GroupProducer, tokio::time::Instant)>>,
    /// Sequences currently reserved by a live group; used to reject duplicates.
    /// Evicted sequences are removed, so they may be created again.
    duplicates: HashSet<u64>,
    /// Number of slots already trimmed from the front of `groups`; slot `i`
    /// of `groups` has absolute index `offset + i` (consumer cursors use this).
    offset: usize,
    /// Highest sequence created so far, or `None` before the first group.
    max_sequence: Option<u64>,
    /// Exclusive upper bound on sequences, set by `finish`/`finish_at`;
    /// `None` while the track is still open.
    final_sequence: Option<u64>,
}
61
62impl State {
63 fn poll_next_group(&self, index: usize) -> Poll<Option<(GroupProducer, usize)>> {
67 let start = index.saturating_sub(self.offset);
68 for (i, slot) in self.groups.iter().enumerate().skip(start) {
69 if let Some((group, _)) = slot {
70 return Poll::Ready(Some((group.clone(), self.offset + i)));
71 }
72 }
73
74 if self.final_sequence.is_some() {
75 Poll::Ready(None)
76 } else {
77 Poll::Pending
78 }
79 }
80
81 fn poll_get_group(&self, sequence: u64) -> Poll<Option<GroupProducer>> {
82 for (group, _) in self.groups.iter().flatten() {
84 if group.info.sequence == sequence {
85 return Poll::Ready(Some(group.clone()));
86 }
87 }
88
89 if let Some(fin) = self.final_sequence
91 && sequence >= fin
92 {
93 return Poll::Ready(None);
94 }
95
96 if self.final_sequence.is_some() {
97 return Poll::Ready(None);
98 }
99
100 Poll::Pending
101 }
102
103 fn evict_expired(&mut self, now: tokio::time::Instant) {
110 for slot in self.groups.iter_mut() {
111 let Some((group, created_at)) = slot else { continue };
112
113 if Some(group.info.sequence) == self.max_sequence {
114 continue;
115 }
116
117 if now.duration_since(*created_at) <= MAX_GROUP_AGE {
118 break;
119 }
120
121 self.duplicates.remove(&group.info.sequence);
122 *slot = None;
123 }
124
125 while let Some(None) = self.groups.front() {
127 self.groups.pop_front();
128 self.offset += 1;
129 }
130 }
131}
132
/// Write side of a track: creates groups and finishes or aborts the track.
pub struct TrackProducer {
    /// The track's static description.
    pub info: Track,
    state: Producer<State>,
}
138
139impl TrackProducer {
140 pub fn new(info: Track) -> Self {
141 Self {
142 info,
143 state: Producer::default(),
144 }
145 }
146
147 pub fn create_group(&mut self, info: Group) -> Result<GroupProducer> {
149 let group = info.produce();
150
151 let mut state = self.state.modify()?;
152 if let Some(fin) = state.final_sequence
153 && group.info.sequence >= fin
154 {
155 return Err(Error::Closed);
156 }
157
158 if !state.duplicates.insert(group.info.sequence) {
159 return Err(Error::Duplicate);
160 }
161
162 let now = tokio::time::Instant::now();
163 state.max_sequence = Some(state.max_sequence.unwrap_or(0).max(group.info.sequence));
164 state.groups.push_back(Some((group.clone(), now)));
165 state.evict_expired(now);
166
167 Ok(group)
168 }
169
170 pub fn append_group(&mut self) -> Result<GroupProducer> {
172 let mut state = self.state.modify()?;
173 let sequence = match state.max_sequence {
174 Some(s) => s.checked_add(1).ok_or(Error::BoundsExceeded)?,
175 None => 0,
176 };
177 if let Some(fin) = state.final_sequence
178 && sequence >= fin
179 {
180 return Err(Error::Closed);
181 }
182
183 let group = Group { sequence }.produce();
184
185 let now = tokio::time::Instant::now();
186 state.duplicates.insert(sequence);
187 state.max_sequence = Some(sequence);
188 state.groups.push_back(Some((group.clone(), now)));
189 state.evict_expired(now);
190
191 Ok(group)
192 }
193
194 pub fn write_frame<B: Into<bytes::Bytes>>(&mut self, frame: B) -> Result<()> {
196 let mut group = self.append_group()?;
197 group.write_frame(frame.into())?;
198 group.finish()?;
199 Ok(())
200 }
201
202 pub fn finish(&mut self) -> Result<()> {
208 let mut state = self.state.modify()?;
209 if state.final_sequence.is_some() {
210 return Err(Error::Closed);
211 }
212 state.final_sequence = Some(match state.max_sequence {
213 Some(max) => max.checked_add(1).ok_or(Error::BoundsExceeded)?,
214 None => 0,
215 });
216 Ok(())
217 }
218
219 #[deprecated(note = "use finish() or finish_at(sequence) instead")]
224 pub fn close(&mut self) -> Result<()> {
225 self.finish()
226 }
227
228 pub fn finish_at(&mut self, sequence: u64) -> Result<()> {
235 let mut state = self.state.modify()?;
236 let max = state.max_sequence.ok_or(Error::Closed)?;
237 if state.final_sequence.is_some() || sequence != max {
238 return Err(Error::Closed);
239 }
240 state.final_sequence = Some(max.checked_add(1).ok_or(Error::BoundsExceeded)?);
241 Ok(())
242 }
243
244 pub fn abort(&mut self, err: Error) -> Result<()> {
246 let mut state = self.state.modify()?;
247
248 for (group, _) in state.groups.iter_mut().flatten() {
250 group.abort(err.clone()).ok();
252 }
253
254 state.abort(err);
255 Ok(())
256 }
257
258 pub fn consume(&self) -> TrackConsumer {
260 let state = self.state.borrow();
261 let index = state.offset + state.groups.len().saturating_sub(1);
262
263 TrackConsumer {
264 info: self.info.clone(),
265 state: self.state.consume(),
266 index,
267 }
268 }
269
270 pub async fn unused(&self) -> Result<()> {
272 self.state.unused().await
273 }
274
275 pub fn is_closed(&self) -> bool {
277 self.state.borrow().is_closed()
278 }
279
280 pub fn is_clone(&self, other: &Self) -> bool {
282 self.state.is_clone(&other.state)
283 }
284
285 pub(crate) fn weak(&self) -> TrackWeak {
287 TrackWeak {
288 info: self.info.clone(),
289 state: self.state.weak(),
290 }
291 }
292}
293
294impl Clone for TrackProducer {
295 fn clone(&self) -> Self {
296 Self {
297 info: self.info.clone(),
298 state: self.state.clone(),
299 }
300 }
301}
302
303impl From<Track> for TrackProducer {
304 fn from(info: Track) -> Self {
305 TrackProducer::new(info)
306 }
307}
308
/// Track handle holding its state through `Weak<State>`; it can abort the
/// track, create consumers, and query closure.
/// NOTE(review): presumably does not keep producers alive — confirm in `state`.
#[derive(Clone)]
pub(crate) struct TrackWeak {
    /// The track's static description.
    pub info: Track,
    state: Weak<State>,
}
315
316impl TrackWeak {
317 pub fn abort(&self, err: Error) {
318 let Ok(producer) = self.state.produce() else { return };
320 let Ok(mut state) = producer.modify() else { return };
321
322 for (group, _) in state.groups.iter_mut().flatten() {
324 group.abort(err.clone()).ok();
325 }
326
327 state.abort(err);
328 }
329
330 pub fn is_closed(&self) -> bool {
331 self.state.is_closed()
332 }
333
334 pub fn consume(&self) -> TrackConsumer {
335 let state = self.state.borrow();
336 let index = state.offset + state.groups.len().saturating_sub(1);
337
338 TrackConsumer {
339 info: self.info.clone(),
340 state: self.state.consume(),
341 index,
342 }
343 }
344
345 pub async fn unused(&self) -> crate::Result<()> {
346 self.state.unused().await
347 }
348
349 pub fn is_clone(&self, other: &Self) -> bool {
350 self.state.is_clone(&other.state)
351 }
352}
353
/// Read side of a track: walks groups in slot order or looks them up by sequence.
#[derive(Clone)]
pub struct TrackConsumer {
    /// The track's static description.
    pub info: Track,
    state: Consumer<State>,
    /// Absolute slot index (including `State::offset`) where `next_group`
    /// resumes scanning.
    index: usize,
}
361
362impl TrackConsumer {
363 pub async fn next_group(&mut self) -> Result<Option<GroupConsumer>> {
367 let index = self.index;
368 let res = waiter_fn(|waiter| self.state.poll(waiter, |state| state.poll_next_group(index))).await?;
369 let consumer = res.map(|(producer, found_index)| {
370 self.index = found_index + 1;
371 producer.consume()
372 });
373 Ok(consumer)
374 }
375
376 pub async fn get_group(&self, sequence: u64) -> Result<Option<GroupConsumer>> {
380 let res = waiter_fn(|waiter| self.state.poll(waiter, |state| state.poll_get_group(sequence))).await?;
381 Ok(res.map(|producer| producer.consume()))
382 }
383
384 pub async fn closed(&self) -> Result<()> {
386 let err = self.state.closed().await;
387 match err {
388 Error::Closed | Error::Dropped => Ok(()),
389 err => Err(err),
390 }
391 }
392
393 pub fn is_clone(&self, other: &Self) -> bool {
394 self.state.is_clone(&other.state)
395 }
396}
397
398#[cfg(test)]
399use futures::FutureExt;
400
#[cfg(test)]
impl TrackConsumer {
    /// Polls `next_group` once and returns the group, panicking if the call
    /// would block, errors, or reports that no group remains.
    pub fn assert_group(&mut self) -> GroupConsumer {
        let polled = self.next_group().now_or_never();
        let outcome = polled.expect("group would have blocked");
        let next = outcome.expect("would have errored");
        next.expect("track was closed")
    }

    /// Panics unless polling `next_group` once would block.
    pub fn assert_no_group(&mut self) {
        let would_block = self.next_group().now_or_never().is_none();
        assert!(would_block, "next group would not have blocked");
    }

    /// Panics if `closed` resolves immediately.
    pub fn assert_not_closed(&self) {
        let still_pending = self.closed().now_or_never().is_none();
        assert!(still_pending, "should not be closed");
    }

    /// Panics unless `closed` resolves immediately (with any result).
    pub fn assert_closed(&self) {
        let resolved = self.closed().now_or_never().is_some();
        assert!(resolved, "should be closed");
    }

    /// Panics unless `closed` resolves immediately with an error.
    pub fn assert_error(&self) {
        let outcome = self.closed().now_or_never().expect("should not block");
        assert!(outcome.is_err(), "should be error");
    }

    /// Panics unless `other` is a clone of this consumer.
    pub fn assert_is_clone(&self, other: &Self) {
        let same = self.is_clone(other);
        assert!(same, "should be clone");
    }

    /// Panics if `other` is a clone of this consumer.
    pub fn assert_not_clone(&self, other: &Self) {
        let same = self.is_clone(other);
        assert!(!same, "should not be clone");
    }
}
442
#[cfg(test)]
mod test {
    use super::*;

    // Number of occupied (non-evicted) slots in `state.groups`.
    fn live_groups(state: &State) -> usize {
        state.groups.iter().flatten().count()
    }

    // Sequence of the first occupied slot; panics if every slot is empty.
    fn first_live_sequence(state: &State) -> u64 {
        state.groups.iter().flatten().next().unwrap().0.info.sequence
    }

    // Groups older than MAX_GROUP_AGE are evicted on the next append: their
    // slots are trimmed (advancing `offset`) and their sequences are removed
    // from `duplicates`.
    #[tokio::test]
    async fn evict_expired_groups() {
        // Freeze tokio's clock so only `advance` moves time forward.
        tokio::time::pause();

        let mut producer = Track::new("test").produce();

        // Sequences 0, 1 and 2, all created at the same paused instant.
        producer.append_group().unwrap();
        producer.append_group().unwrap();
        producer.append_group().unwrap();
        {
            let state = producer.state.borrow();
            assert_eq!(live_groups(&state), 3);
            assert_eq!(state.offset, 0);
        }

        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;

        // Appending sequence 3 runs eviction; 0..=2 are now past the age limit.
        producer.append_group().unwrap();
        {
            let state = producer.state.borrow();
            assert_eq!(live_groups(&state), 1);
            assert_eq!(first_live_sequence(&state), 3);
            assert_eq!(state.offset, 3);
            assert!(!state.duplicates.contains(&0));
            assert!(!state.duplicates.contains(&1));
            assert!(!state.duplicates.contains(&2));
            assert!(state.duplicates.contains(&3));
        }
    }

    // The freshly appended group is the new max and survives; the expired
    // previous group is evicted and its slot trimmed.
    #[tokio::test]
    async fn evict_keeps_max_sequence() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();
        producer.append_group().unwrap();
        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;

        producer.append_group().unwrap();
        {
            let state = producer.state.borrow();
            assert_eq!(live_groups(&state), 1);
            assert_eq!(first_live_sequence(&state), 1);
            assert_eq!(state.offset, 1);
        }
    }

    // Without advancing the clock nothing is old enough to evict.
    #[tokio::test]
    async fn no_eviction_when_fresh() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();
        producer.append_group().unwrap();
        producer.append_group().unwrap();
        producer.append_group().unwrap();
        {
            let state = producer.state.borrow();
            assert_eq!(live_groups(&state), 3);
            assert_eq!(state.offset, 0);
        }
    }

    // A consumer whose cursor points at an evicted (and trimmed) slot resumes
    // at the next live group.
    #[tokio::test]
    async fn consumer_skips_evicted_groups() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();
        producer.append_group().unwrap();
        // The cursor starts at the newest slot, i.e. group 0.
        let mut consumer = producer.consume();

        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;
        producer.append_group().unwrap();
        let group = consumer.assert_group();
        assert_eq!(group.info.sequence, 1);
    }

    // Groups created out of order: once a higher sequence is appended, the
    // old max (5) loses its exemption and every expired group is evicted.
    #[tokio::test]
    async fn out_of_order_max_sequence_at_front() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();

        // The max sequence (5) is created first, so it sits at the front.
        producer.create_group(Group { sequence: 5 }).unwrap();
        producer.create_group(Group { sequence: 3 }).unwrap();
        producer.create_group(Group { sequence: 4 }).unwrap();

        {
            let state = producer.state.borrow();
            assert_eq!(state.max_sequence, Some(5));
        }

        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;

        // append_group uses max + 1 = 6, which becomes the new protected max.
        producer.append_group().unwrap();
        {
            let state = producer.state.borrow();
            assert_eq!(live_groups(&state), 1);
            assert_eq!(first_live_sequence(&state), 6);
            assert!(!state.duplicates.contains(&3));
            assert!(!state.duplicates.contains(&4));
            assert!(!state.duplicates.contains(&5));
            assert!(state.duplicates.contains(&6));
        }
    }

    // While the max-sequence group is live at the front, expired groups behind
    // it are emptied but the deque cannot be trimmed.
    #[tokio::test]
    async fn max_sequence_at_front_blocks_trim() {
        tokio::time::pause();

        let mut producer = Track::new("test").produce();

        producer.create_group(Group { sequence: 5 }).unwrap();

        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;

        // 5 is expired but exempt as the max; 3 is fresh.
        producer.create_group(Group { sequence: 3 }).unwrap();

        {
            let state = producer.state.borrow();
            assert_eq!(live_groups(&state), 2);
            assert_eq!(state.offset, 0);
        }

        tokio::time::advance(MAX_GROUP_AGE + Duration::from_secs(1)).await;

        // 3 is now expired and evicted. 5 stays (still the max) and blocks
        // trimming of the emptied slot, so `offset` is unchanged.
        producer.create_group(Group { sequence: 2 }).unwrap();

        {
            let state = producer.state.borrow();
            assert_eq!(live_groups(&state), 2);
            assert_eq!(state.offset, 0);
            assert!(state.duplicates.contains(&5));
            assert!(!state.duplicates.contains(&3));
            assert!(state.duplicates.contains(&2));
        }

        // A new consumer starts at the newest slot (sequence 2), not at 5.
        let mut consumer = producer.consume();
        let group = consumer.assert_group();
        assert_eq!(group.info.sequence, 2);
    }

    // finish() on an empty track sets the final sequence to 0: it cannot be
    // finished again and no group can be appended.
    #[test]
    fn append_finish_cannot_be_rewritten() {
        let mut producer = Track::new("test").produce();

        assert!(producer.finish().is_ok());
        assert!(producer.finish().is_err());
        assert!(producer.append_group().is_err());
    }

    // finish() after sequence 0 sets the final sequence to 1, so appending 1 fails.
    #[test]
    fn finish_after_groups() {
        let mut producer = Track::new("test").produce();

        producer.append_group().unwrap();
        assert!(producer.finish().is_ok());
        assert!(producer.finish().is_err());
        assert!(producer.append_group().is_err());
    }

    // finish_at only accepts the current max. Afterwards the final sequence is
    // max + 1, an unused lower sequence may still be created, and the max
    // itself is rejected (already reserved).
    #[test]
    fn insert_finish_validates_sequence_and_freezes_to_max() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 5 }).unwrap();

        assert!(producer.finish_at(4).is_err());
        assert!(producer.finish_at(10).is_err());
        assert!(producer.finish_at(5).is_ok());

        {
            let state = producer.state.borrow();
            assert_eq!(state.final_sequence, Some(6));
        }

        assert!(producer.finish_at(5).is_err());
        assert!(producer.create_group(Group { sequence: 4 }).is_ok());
        assert!(producer.create_group(Group { sequence: 5 }).is_err());
    }

    // Once finished, next_group returns None as soon as no live group remains
    // past the cursor, even though sequence 0 was never created.
    #[tokio::test]
    async fn next_group_finishes_without_waiting_for_gaps() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 1 }).unwrap();
        producer.finish_at(1).unwrap();

        let mut consumer = producer.consume();
        assert_eq!(consumer.assert_group().info.sequence, 1);

        let done = consumer
            .next_group()
            .now_or_never()
            .expect("should not block")
            .expect("would have errored");
        assert!(done.is_none(), "track should finish without waiting for gaps");
    }

    // Once finished, get_group resolves to None for a gap below the final
    // sequence and for a sequence at/after it.
    #[tokio::test]
    async fn get_group_finishes_without_waiting_for_gaps() {
        let mut producer = Track::new("test").produce();
        producer.create_group(Group { sequence: 1 }).unwrap();
        producer.finish_at(1).unwrap();

        let consumer = producer.consume();
        assert!(
            consumer
                .get_group(0)
                .now_or_never()
                .expect("should not block")
                .expect("would have errored")
                .is_none(),
            "sequence below fin should not block forever"
        );
        assert!(
            consumer
                .get_group(2)
                .now_or_never()
                .expect("sequence at-or-after fin should resolve")
                .expect("should not error")
                .is_none(),
            "sequence at-or-after fin should not exist"
        );
    }

    // append_group must report overflow of max + 1 instead of wrapping.
    #[test]
    fn append_group_returns_bounds_exceeded_on_sequence_overflow() {
        let mut producer = Track::new("test").produce();
        {
            let mut state = producer.state.modify().unwrap();
            state.max_sequence = Some(u64::MAX);
        }

        assert!(matches!(producer.append_group(), Err(Error::BoundsExceeded)));
    }
}