moq_dir/
listing.rs

1use anyhow::Context;
2use bytes::BytesMut;
3use std::collections::{HashSet, VecDeque};
4
5use moq_transport::serve::{
6    ServeError, SubgroupReader, SubgroupWriter, SubgroupsReader, SubgroupsWriter, TrackReader,
7    TrackReaderMode, TrackWriter,
8};
9
10pub struct ListingWriter {
11    track: Option<TrackWriter>,
12    subgroups: Option<SubgroupsWriter>,
13    subgroup: Option<SubgroupWriter>,
14
15    current: HashSet<String>,
16}
17
18impl ListingWriter {
19    pub fn new(track: TrackWriter) -> Self {
20        Self {
21            track: Some(track),
22            subgroups: None,
23            subgroup: None,
24            current: HashSet::new(),
25        }
26    }
27
28    pub fn insert(&mut self, name: String) -> Result<(), ServeError> {
29        if !self.current.insert(name.clone()) {
30            return Err(ServeError::Duplicate);
31        }
32
33        match self.subgroup {
34            // Create a delta if the current subgroup is small enough.
35            Some(ref mut subgroup) if self.current.len() < 2 * subgroup.len() => {
36                let msg = format!("+{name}");
37                subgroup.write(msg.into())?;
38            }
39            // Otherwise create a snapshot with every element.
40            _ => self.subgroup = Some(self.snapshot()?),
41        }
42
43        Ok(())
44    }
45
46    pub fn remove(&mut self, name: &str) -> Result<(), ServeError> {
47        if !self.current.remove(name) {
48            return Err(ServeError::NotFound);
49        }
50
51        match self.subgroup {
52            // Create a delta if the current subgroup is small enough.
53            Some(ref mut subgroup) if self.current.len() < 2 * subgroup.len() => {
54                let msg = format!("-{name}");
55                subgroup.write(msg.into())?;
56            }
57            // Otherwise create a snapshot with every element.
58            _ => self.subgroup = Some(self.snapshot()?),
59        }
60
61        Ok(())
62    }
63
64    fn snapshot(&mut self) -> Result<SubgroupWriter, ServeError> {
65        let mut subgroups = match self.subgroups.take() {
66            Some(subgroups) => subgroups,
67            None => self.track.take().unwrap().groups()?,
68        };
69
70        let priority = 127;
71        let mut subgroup = subgroups.append(priority)?;
72
73        let mut msg = BytesMut::new();
74        for name in &self.current {
75            msg.extend_from_slice(name.as_bytes());
76            msg.extend_from_slice(b"\n");
77        }
78
79        subgroup.write(msg.freeze())?;
80        self.subgroups = Some(subgroups);
81
82        Ok(subgroup)
83    }
84
85    pub fn len(&self) -> usize {
86        self.current.len()
87    }
88
89    pub fn is_empty(&self) -> bool {
90        self.current.is_empty()
91    }
92}
93
/// A single change to a listing, as reported by `ListingReader::next`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ListingDelta {
    /// A name was added to the listing.
    Add(String),
    /// A name was removed from the listing.
    Rem(String),
}
99
100#[derive(Clone)]
101pub struct ListingReader {
102    track: TrackReader,
103
104    // Keep track of the current subgroup.
105    subgroups: Option<SubgroupsReader>,
106    subgroup: Option<SubgroupReader>,
107
108    // The current state of the listing.
109    current: HashSet<String>,
110
111    // A list of deltas we need to return
112    deltas: VecDeque<ListingDelta>,
113}
114
115impl ListingReader {
116    pub fn new(track: TrackReader) -> Self {
117        Self {
118            track,
119            subgroups: None,
120            subgroup: None,
121
122            current: HashSet::new(),
123            deltas: VecDeque::new(),
124        }
125    }
126
127    pub async fn next(&mut self) -> anyhow::Result<Option<ListingDelta>> {
128        if let Some(delta) = self.deltas.pop_front() {
129            return Ok(Some(delta));
130        }
131
132        if self.subgroups.is_none() {
133            self.subgroups = match self.track.mode().await? {
134                TrackReaderMode::Subgroups(subgroups) => Some(subgroups),
135                _ => anyhow::bail!("expected subgroups mode"),
136            };
137        };
138
139        if self.subgroup.is_none() {
140            self.subgroup = Some(
141                self.subgroups
142                    .as_mut()
143                    .unwrap()
144                    .next()
145                    .await?
146                    .context("empty track")?,
147            );
148        }
149
150        let mut subgroup_done = false;
151        let mut subgroups_done = false;
152
153        loop {
154            tokio::select! {
155                next = self.subgroups.as_mut().unwrap().next(), if !subgroups_done => {
156                    if let Some(next) = next? {
157                        self.subgroup = Some(next);
158                        subgroup_done = false;
159                    } else {
160                        subgroups_done = true;
161                    }
162                },
163                object = self.subgroup.as_mut().unwrap().read_next(), if !subgroup_done => {
164                    let payload = match object? {
165                        Some(object) => object,
166                        None => {
167                            subgroup_done = true;
168                            continue;
169                        }
170                    };
171
172                    if payload.is_empty() {
173                        anyhow::bail!("empty payload");
174                    } else if self.subgroup.as_mut().unwrap().pos() == 1 {
175                        // This is a full snapshot, not a delta
176                        let set = HashSet::from_iter(payload.split(|&b| b == b'\n').map(|s| String::from_utf8_lossy(s).to_string()));
177
178                        for name in set.difference(&self.current) {
179                            self.deltas.push_back(ListingDelta::Add(name.clone()));
180                        }
181
182                        for name in self.current.difference(&set) {
183                            self.deltas.push_back(ListingDelta::Rem(name.clone()));
184                        }
185
186                        self.current = set;
187
188                        if let Some(delta) = self.deltas.pop_front() {
189                            return Ok(Some(delta));
190                        }
191                    } else if payload[0] == b'+' {
192                        return Ok(Some(ListingDelta::Add(String::from_utf8_lossy(&payload[1..]).to_string())));
193                    } else if payload[0] == b'-' {
194                        return Ok(Some(ListingDelta::Rem(String::from_utf8_lossy(&payload[1..]).to_string())));
195                    } else {
196                        anyhow::bail!("invalid delta: {:?}", payload);
197                    }
198                }
199                else => return Ok(None),
200            }
201        }
202    }
203
204    // If you just want to proxy the track
205    pub fn into_inner(self) -> TrackReader {
206        self.track
207    }
208}