// moq_transport/serve/stream.rs

1use bytes::Bytes;
2use std::{ops::Deref, sync::Arc};
3
4use crate::data::ObjectStatus;
5use crate::watch::State;
6
7use super::{ServeError, Track};
8
9#[derive(Debug, PartialEq, Clone)]
10pub struct Stream {
11    pub track: Arc<Track>,
12    pub priority: u8,
13}
14
15impl Stream {
16    pub fn produce(self) -> (StreamWriter, StreamReader) {
17        let (writer, reader) = State::default().split();
18        let info = Arc::new(self);
19
20        let writer = StreamWriter::new(writer, info.clone());
21        let reader = StreamReader::new(reader, info);
22
23        (writer, reader)
24    }
25}
26
27impl Deref for Stream {
28    type Target = Track;
29
30    fn deref(&self) -> &Self::Target {
31        &self.track
32    }
33}
34
35struct StreamState {
36    // The latest group.
37    latest_group_reader: Option<StreamGroupReader>,
38
39    // Updated each time object changes.
40    epoch: usize,
41
42    // Set when the writer is dropped.
43    closed: Result<(), ServeError>,
44}
45
46impl Default for StreamState {
47    fn default() -> Self {
48        Self {
49            latest_group_reader: None,
50            epoch: 0,
51            closed: Ok(()),
52        }
53    }
54}
55
56/// Used to write data to a stream and notify readers.
57///
58/// This is Clone as a work-around, but be very careful because it's meant to be sequential.
59#[derive(Clone)]
60pub struct StreamWriter {
61    // Mutable stream state.
62    state: State<StreamState>,
63
64    // Immutable stream state.
65    pub info: Arc<Stream>,
66}
67
68impl StreamWriter {
69    fn new(state: State<StreamState>, info: Arc<Stream>) -> Self {
70        Self { state, info }
71    }
72
73    /// Create a new group with the given group_id for the stream
74    pub fn create(&mut self, group_id: u64) -> Result<StreamGroupWriter, ServeError> {
75        let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?;
76
77        // Ensure group_id is larger than (or equal to) the latest
78        if let Some(latest_group_reader) = &state.latest_group_reader {
79            if latest_group_reader.group_id > group_id {
80                return Err(ServeError::Duplicate);
81            }
82        }
83
84        // Create new StreamGroup
85        let group = Arc::new(StreamGroup {
86            stream: self.info.clone(),
87            group_id,
88        });
89
90        let (writer_state, reader_state) = State::default().split();
91
92        // Create StreamGroupWriter/StreamGroupReader pair
93        let stream_group_reader = StreamGroupReader::new(reader_state, group.clone());
94        let stream_group_writer = StreamGroupWriter::new(writer_state, group);
95
96        state.latest_group_reader = Some(stream_group_reader);
97        state.epoch += 1;
98
99        Ok(stream_group_writer)
100    }
101
102    /// Create a new group with the next sequential group_id for the stream.
103    pub fn append(&mut self) -> Result<StreamGroupWriter, ServeError> {
104        let next = self
105            .state
106            .lock()
107            .latest_group_reader
108            .as_ref()
109            .map(|g| g.group_id + 1)
110            .unwrap_or_default();
111        self.create(next)
112    }
113
114    /// Close the stream with an error.
115    pub fn close(self, err: ServeError) -> Result<(), ServeError> {
116        let state = self.state.lock();
117        state.closed.clone()?;
118
119        let mut state = state.into_mut().ok_or(ServeError::Cancel)?;
120        state.closed = Err(err);
121
122        Ok(())
123    }
124}
125
126impl Deref for StreamWriter {
127    type Target = Stream;
128
129    fn deref(&self) -> &Self::Target {
130        &self.info
131    }
132}
133
134/// Notified when a stream has new data available.
135#[derive(Clone)]
136pub struct StreamReader {
137    // Mutable stream state.
138    state: State<StreamState>,
139
140    // Immutable stream state.
141    pub info: Arc<Stream>,
142
143    // The number of chunks that we've read.
144    // NOTE: Cloned readers inherit this index, but then run in parallel.
145    epoch: usize,
146}
147
148impl StreamReader {
149    fn new(state: State<StreamState>, info: Arc<Stream>) -> Self {
150        Self {
151            state,
152            info,
153            epoch: 0,
154        }
155    }
156
157    /// Block until the next group is available.
158    pub async fn next(&mut self) -> Result<Option<StreamGroupReader>, ServeError> {
159        loop {
160            {
161                let state = self.state.lock();
162                if self.epoch != state.epoch {
163                    self.epoch = state.epoch;
164                    let latest = state.latest_group_reader.clone().unwrap();
165                    return Ok(Some(latest));
166                }
167
168                state.closed.clone()?;
169                match state.modified() {
170                    Some(notify) => notify,
171                    None => return Ok(None),
172                }
173            }
174            .await; // Try again when the state changes
175        }
176    }
177
178    /// Returns the largest group_id/object_id
179    pub fn latest(&self) -> Option<(u64, u64)> {
180        let state = self.state.lock();
181        state
182            .latest_group_reader
183            .as_ref()
184            .map(|stream_group_reader| {
185                (
186                    stream_group_reader.group_id,
187                    stream_group_reader.latest_object_id(),
188                )
189            })
190    }
191
192    /// Check if the stream writer has been closed or dropped.
193    pub fn is_closed(&self) -> bool {
194        let state = self.state.lock();
195        state.closed.is_err() || state.modified().is_none()
196    }
197}
198
199impl Deref for StreamReader {
200    type Target = Stream;
201
202    fn deref(&self) -> &Self::Target {
203        &self.info
204    }
205}
206
207#[derive(Clone, PartialEq, Debug)]
208pub struct StreamGroup {
209    pub stream: Arc<Stream>,
210    pub group_id: u64,
211}
212
213impl Deref for StreamGroup {
214    type Target = Stream;
215
216    fn deref(&self) -> &Self::Target {
217        &self.stream
218    }
219}
220
221struct StreamGroupState {
222    // The objects that have been received thus far.
223    objects: Vec<StreamObjectReader>,
224    closed: Result<(), ServeError>,
225}
226
227impl Default for StreamGroupState {
228    fn default() -> Self {
229        Self {
230            objects: Vec::new(),
231            closed: Ok(()),
232        }
233    }
234}
235
236pub struct StreamGroupWriter {
237    state: State<StreamGroupState>,
238    pub info: Arc<StreamGroup>,
239    next_object_id: u64,
240}
241
242impl StreamGroupWriter {
243    fn new(state: State<StreamGroupState>, info: Arc<StreamGroup>) -> Self {
244        Self {
245            state,
246            info,
247            next_object_id: 0,
248        }
249    }
250
251    /// Add a new object to the group.
252    pub fn write(&mut self, payload: Bytes) -> Result<(), ServeError> {
253        let mut writer = self.create(payload.len())?;
254        writer.write(payload)?;
255        Ok(())
256    }
257
258    /// Create a new object in the group with the given size.
259    pub fn create(&mut self, size: usize) -> Result<StreamObjectWriter, ServeError> {
260        let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?;
261
262        let (writer, reader) = StreamObject {
263            group: self.info.clone(),
264            object_id: self.next_object_id,
265            status: ObjectStatus::NormalObject,
266            size,
267        }
268        .produce();
269
270        state.objects.push(reader);
271
272        Ok(writer)
273    }
274
275    /// Close the stream with an error.
276    pub fn close(self, err: ServeError) -> Result<(), ServeError> {
277        let state = self.state.lock();
278        state.closed.clone()?;
279
280        let mut state = state.into_mut().ok_or(ServeError::Cancel)?;
281        state.closed = Err(err);
282
283        Ok(())
284    }
285}
286
287impl Deref for StreamGroupWriter {
288    type Target = StreamGroup;
289
290    fn deref(&self) -> &Self::Target {
291        &self.info
292    }
293}
294
295#[derive(Clone)]
296pub struct StreamGroupReader {
297    pub info: Arc<StreamGroup>,
298    state: State<StreamGroupState>,
299    index: usize,
300}
301
302impl StreamGroupReader {
303    fn new(state: State<StreamGroupState>, info: Arc<StreamGroup>) -> Self {
304        Self {
305            state,
306            info,
307            index: 0,
308        }
309    }
310
311    /// Read all remaining data from the current object, if any.
312    pub async fn read_next(&mut self) -> Result<Option<Bytes>, ServeError> {
313        if let Some(mut reader) = self.next().await? {
314            Ok(Some(reader.read_all().await?))
315        } else {
316            Ok(None)
317        }
318    }
319
320    /// Block until the next object is available.
321    pub async fn next(&mut self) -> Result<Option<StreamObjectReader>, ServeError> {
322        loop {
323            {
324                let state = self.state.lock();
325                if self.index < state.objects.len() {
326                    self.index += 1;
327                    return Ok(Some(state.objects[self.index].clone()));
328                }
329
330                state.closed.clone()?;
331                match state.modified() {
332                    Some(notify) => notify,
333                    None => return Ok(None),
334                }
335            }
336            .await;
337        }
338    }
339
340    /// Returns the largest object_id
341    pub fn latest_object_id(&self) -> u64 {
342        let state = self.state.lock();
343        state
344            .objects
345            .last()
346            .map(|o| o.object_id)
347            .unwrap_or_default()
348    }
349}
350
351impl Deref for StreamGroupReader {
352    type Target = StreamGroup;
353
354    fn deref(&self) -> &Self::Target {
355        &self.info
356    }
357}
358
359/// A subset of Object, since we use the group's info.
360#[derive(Clone, PartialEq, Debug)]
361pub struct StreamObject {
362    // The group this belongs to.
363    pub group: Arc<StreamGroup>,
364
365    pub object_id: u64,
366
367    // The size of the object.
368    pub size: usize,
369
370    // Object status
371    pub status: ObjectStatus,
372}
373
374impl StreamObject {
375    pub fn produce(self) -> (StreamObjectWriter, StreamObjectReader) {
376        let (writer_state, reader_state) = State::default().split();
377        let info = Arc::new(self);
378
379        let writer = StreamObjectWriter::new(writer_state, info.clone());
380        let reader = StreamObjectReader::new(reader_state, info);
381
382        (writer, reader)
383    }
384}
385
386impl Deref for StreamObject {
387    type Target = StreamGroup;
388
389    fn deref(&self) -> &Self::Target {
390        &self.group
391    }
392}
393
394struct StreamObjectState {
395    // The data that has been received thus far.
396    chunks: Vec<Bytes>,
397
398    closed: Result<(), ServeError>,
399}
400
401impl Default for StreamObjectState {
402    fn default() -> Self {
403        Self {
404            chunks: Vec::new(),
405            closed: Ok(()),
406        }
407    }
408}
409
410/// Used to write data to a segment and notify readers.
411pub struct StreamObjectWriter {
412    // Mutable segment state.
413    state: State<StreamObjectState>,
414
415    // Immutable segment state.
416    pub info: Arc<StreamObject>,
417
418    // The amount of promised data that has yet to be written.
419    remaining_write_bytes: usize,
420}
421
422impl StreamObjectWriter {
423    /// Create a new segment with the given info.
424    fn new(state: State<StreamObjectState>, info: Arc<StreamObject>) -> Self {
425        Self {
426            state,
427            remaining_write_bytes: info.size,
428            info,
429        }
430    }
431
432    /// Write a new chunk of bytes.
433    pub fn write(&mut self, chunk: Bytes) -> Result<(), ServeError> {
434        if chunk.len() > self.remaining_write_bytes {
435            return Err(ServeError::Size);
436        }
437        self.remaining_write_bytes -= chunk.len();
438
439        let mut state = self.state.lock_mut().ok_or(ServeError::Cancel)?;
440        state.chunks.push(chunk);
441
442        Ok(())
443    }
444
445    /// Close the stream with an error.
446    pub fn close(self, err: ServeError) -> Result<(), ServeError> {
447        let state = self.state.lock();
448        state.closed.clone()?;
449
450        let mut state = state.into_mut().ok_or(ServeError::Cancel)?;
451        state.closed = Err(err);
452
453        Ok(())
454    }
455}
456
457impl Drop for StreamObjectWriter {
458    // Make sure we fully write the segment, otherwise close it with an error.
459    fn drop(&mut self) {
460        if self.remaining_write_bytes == 0 {
461            return;
462        }
463
464        let state = self.state.lock();
465        if state.closed.is_err() {
466            return;
467        }
468
469        if let Some(mut state) = state.into_mut() {
470            state.closed = Err(ServeError::Size);
471        }
472    }
473}
474
475impl Deref for StreamObjectWriter {
476    type Target = StreamObject;
477
478    fn deref(&self) -> &Self::Target {
479        &self.info
480    }
481}
482
483/// Notified when a segment has new data available.
484#[derive(Clone)]
485pub struct StreamObjectReader {
486    // Modify the segment state.
487    state: State<StreamObjectState>,
488
489    // Immutable segment state.
490    pub info: Arc<StreamObject>,
491
492    // The number of chunks that we've read.
493    // NOTE: Cloned readers inherit this index, but then run in parallel.
494    index: usize,
495}
496
497impl StreamObjectReader {
498    fn new(state: State<StreamObjectState>, info: Arc<StreamObject>) -> Self {
499        Self {
500            state,
501            info,
502            index: 0,
503        }
504    }
505
506    /// Block until the next chunk of bytes is available.
507    pub async fn read(&mut self) -> Result<Option<Bytes>, ServeError> {
508        loop {
509            {
510                let state = self.state.lock();
511
512                if self.index < state.chunks.len() {
513                    let chunk = state.chunks[self.index].clone();
514                    self.index += 1;
515                    return Ok(Some(chunk));
516                }
517
518                state.closed.clone()?;
519                match state.modified() {
520                    Some(notify) => notify,
521                    None => return Ok(None),
522                }
523            }
524            .await; // Try again when the state changes
525        }
526    }
527
528    /// Read all remaining data from the current object, if any.
529    pub async fn read_all(&mut self) -> Result<Bytes, ServeError> {
530        let mut chunks = Vec::new();
531        while let Some(chunk) = self.read().await? {
532            chunks.push(chunk);
533        }
534
535        Ok(Bytes::from(chunks.concat()))
536    }
537}
538
539impl Deref for StreamObjectReader {
540    type Target = StreamObject;
541
542    fn deref(&self) -> &Self::Target {
543        &self.info
544    }
545}