Skip to main content

liminal/protocol/
multiplex.rs

1use std::{
2    collections::{HashMap, hash_map::Entry},
3    fmt::Debug,
4    ops::{Deref, DerefMut},
5};
6
7use super::ProtocolError;
8
/// Identifier of a multiplexed stream, present in the header of every frame.
///
/// The value `0` is set aside for connection-level control frames; every
/// identifier from `1` upward is available to application streams.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct StreamId(pub u32);

impl StreamId {
    /// The connection-level control stream (identifier `0`).
    pub const CONTROL: Self = Self(0);

    /// Returns `true` exactly when this identifier is [`StreamId::CONTROL`].
    #[must_use]
    pub const fn is_control(self) -> bool {
        matches!(self, Self::CONTROL)
    }

    /// Returns `true` when this identifier may carry application traffic,
    /// which is every identifier other than the control stream.
    #[must_use]
    pub const fn is_application(self) -> bool {
        !self.is_control()
    }
}
32
/// Per-stream lifecycle state recorded for a connection's streams.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StreamState {
    /// Opened, with subscription confirmation still outstanding.
    Subscribing,
    /// Fully established; application frames may be carried on it.
    Active,
    /// Shutting down and waiting for final cleanup.
    Closing,
}
43
44/// Tracks active streams and their current state for a single connection.
45#[derive(Debug, Default)]
46pub struct StreamTable {
47    streams: HashMap<StreamId, StreamState>,
48}
49
50impl StreamTable {
51    /// Create an empty stream table.
52    #[must_use]
53    pub fn new() -> Self {
54        Self::default()
55    }
56
57    /// Insert a new active stream in the requested state.
58    ///
59    /// # Errors
60    ///
61    /// Returns [`ProtocolError::CodecError`] when the stream already exists in
62    /// the table.
63    pub fn insert(&mut self, stream_id: StreamId, state: StreamState) -> Result<(), ProtocolError> {
64        match self.streams.entry(stream_id) {
65            Entry::Vacant(entry) => {
66                entry.insert(state);
67                Ok(())
68            }
69            Entry::Occupied(_) => Err(ProtocolError::codec(format!(
70                "stream {stream_id:?} already exists"
71            ))),
72        }
73    }
74
75    /// Transition an existing stream to a new state.
76    ///
77    /// # Errors
78    ///
79    /// Returns [`ProtocolError::CodecError`] when the stream is not present in
80    /// the table.
81    pub fn transition(
82        &mut self,
83        stream_id: StreamId,
84        new_state: StreamState,
85    ) -> Result<(), ProtocolError> {
86        let Some(state) = self.streams.get_mut(&stream_id) else {
87            return Err(ProtocolError::codec(format!(
88                "stream {stream_id:?} is not active"
89            )));
90        };
91
92        *state = new_state;
93        Ok(())
94    }
95
96    /// Remove a stream from the table and return its previous state.
97    #[must_use]
98    pub fn remove(&mut self, stream_id: StreamId) -> Option<StreamState> {
99        self.streams.remove(&stream_id)
100    }
101
102    /// Return the tracked state for a stream.
103    #[must_use]
104    pub fn get(&self, stream_id: StreamId) -> Option<StreamState> {
105        self.streams.get(&stream_id).copied()
106    }
107
108    /// Return the number of streams currently tracked in the table.
109    #[must_use]
110    pub fn active_count(&self) -> usize {
111        self.streams.len()
112    }
113}
114
/// Hands out strictly increasing stream IDs for one side of a connection.
///
/// IDs advance by two from their starting value and are never reused.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StreamAllocator {
    /// The ID the next allocation will return, or `None` once advancing past
    /// the last one would overflow `u32`.
    next_id: Option<u32>,
}
120
121/// Fallible stream ID allocation behavior.
122pub trait AllocateStreamId: Debug {
123    /// Allocate the next stream ID and advance the allocator.
124    ///
125    /// # Errors
126    ///
127    /// Returns [`ProtocolError::CodecError`] when this allocator has exhausted
128    /// the `u32` stream ID space for its parity.
129    fn next(&mut self) -> Result<StreamId, ProtocolError>;
130}
131
132impl StreamAllocator {
133    /// Construct a client-side allocator that produces odd stream IDs.
134    #[must_use]
135    pub const fn client() -> Self {
136        Self { next_id: Some(1) }
137    }
138
139    /// Construct a server-side allocator that produces even stream IDs.
140    #[must_use]
141    pub const fn server() -> Self {
142        Self { next_id: Some(2) }
143    }
144
145    fn allocate_next(&mut self) -> Result<StreamId, ProtocolError> {
146        let stream_id = self
147            .next_id
148            .ok_or_else(|| ProtocolError::codec("stream id space exhausted"))?;
149        self.next_id = stream_id.checked_add(2);
150        Ok(StreamId(stream_id))
151    }
152}
153
154impl AllocateStreamId for StreamAllocator {
155    fn next(&mut self) -> Result<StreamId, ProtocolError> {
156        self.allocate_next()
157    }
158}
159
160impl Deref for StreamAllocator {
161    type Target = dyn AllocateStreamId;
162
163    fn deref(&self) -> &Self::Target {
164        self
165    }
166}
167
168impl DerefMut for StreamAllocator {
169    fn deref_mut(&mut self) -> &mut Self::Target {
170        self
171    }
172}
173
#[cfg(test)]
mod tests {
    //! Unit tests for stream identifiers, the stream table, and the stream ID
    //! allocators.

    use std::{fmt::Debug, hash::Hash};

    use super::{StreamAllocator, StreamId, StreamState, StreamTable};
    use crate::protocol::ProtocolError;

    #[test]
    fn stream_id_trait_bounds_are_available() {
        fn assert_traits<T: Debug + Clone + Copy + PartialEq + Eq + Hash + PartialOrd + Ord>() {}

        assert_traits::<StreamId>();
    }

    #[test]
    fn stream_zero_is_control_and_not_application() {
        let stream_id = StreamId(0);

        assert!(stream_id.is_control());
        assert!(!stream_id.is_application());
    }

    #[test]
    fn stream_one_is_application_and_not_control() {
        let stream_id = StreamId(1);

        assert!(stream_id.is_application());
        assert!(!stream_id.is_control());
    }

    #[test]
    fn control_constant_is_stream_zero() {
        assert_eq!(StreamId::CONTROL, StreamId(0));
        assert!(StreamId::CONTROL.is_control());
    }

    #[test]
    fn largest_stream_id_is_application() {
        let stream_id = StreamId(u32::MAX);

        assert!(stream_id.is_application());
        assert!(!stream_id.is_control());
    }

    #[test]
    fn stream_state_has_exact_required_variants() {
        fn state_name(state: StreamState) -> &'static str {
            match state {
                StreamState::Subscribing => "subscribing",
                StreamState::Active => "active",
                StreamState::Closing => "closing",
            }
        }

        let variants = [
            StreamState::Subscribing,
            StreamState::Active,
            StreamState::Closing,
        ];

        assert_eq!(variants.len(), 3);
        assert_eq!(state_name(StreamState::Subscribing), "subscribing");
        assert_eq!(state_name(StreamState::Active), "active");
        assert_eq!(state_name(StreamState::Closing), "closing");
    }

    #[test]
    fn insert_adds_stream_and_counts_it() -> Result<(), ProtocolError> {
        let mut table = StreamTable::new();

        table.insert(StreamId(1), StreamState::Subscribing)?;

        assert_eq!(table.get(StreamId(1)), Some(StreamState::Subscribing));
        assert_eq!(table.active_count(), 1);
        Ok(())
    }

    #[test]
    fn duplicate_insert_returns_error_and_preserves_state() -> Result<(), ProtocolError> {
        let mut table = StreamTable::new();
        table.insert(StreamId(1), StreamState::Subscribing)?;

        let result = table.insert(StreamId(1), StreamState::Active);

        assert!(matches!(result, Err(ProtocolError::CodecError { .. })));
        assert_eq!(table.get(StreamId(1)), Some(StreamState::Subscribing));
        assert_eq!(table.active_count(), 1);
        Ok(())
    }

    #[test]
    fn transition_updates_existing_stream_state() -> Result<(), ProtocolError> {
        let mut table = StreamTable::new();
        table.insert(StreamId(1), StreamState::Subscribing)?;

        table.transition(StreamId(1), StreamState::Active)?;

        assert_eq!(table.get(StreamId(1)), Some(StreamState::Active));
        Ok(())
    }

    #[test]
    fn transition_on_missing_stream_returns_protocol_error() {
        let mut table = StreamTable::new();

        let result = table.transition(StreamId(1), StreamState::Active);

        assert!(matches!(result, Err(ProtocolError::CodecError { .. })));
    }

    #[test]
    fn failed_transition_does_not_insert_stream() {
        let mut table = StreamTable::new();

        let result = table.transition(StreamId(5), StreamState::Closing);

        assert!(result.is_err());
        assert_eq!(table.get(StreamId(5)), None);
        assert_eq!(table.active_count(), 0);
    }

    #[test]
    fn remove_deletes_stream_and_updates_count() -> Result<(), ProtocolError> {
        let mut table = StreamTable::new();
        table.insert(StreamId(1), StreamState::Active)?;
        table.insert(StreamId(3), StreamState::Closing)?;

        assert_eq!(table.remove(StreamId(1)), Some(StreamState::Active));

        assert_eq!(table.get(StreamId(1)), None);
        assert_eq!(table.active_count(), 1);
        Ok(())
    }

    #[test]
    fn remove_missing_stream_returns_none() {
        let mut table = StreamTable::new();

        assert_eq!(table.remove(StreamId(7)), None);
        assert_eq!(table.active_count(), 0);
    }

    #[test]
    fn stream_can_be_reinserted_after_removal() -> Result<(), ProtocolError> {
        let mut table = StreamTable::new();
        table.insert(StreamId(1), StreamState::Active)?;
        assert_eq!(table.remove(StreamId(1)), Some(StreamState::Active));

        table.insert(StreamId(1), StreamState::Subscribing)?;

        assert_eq!(table.get(StreamId(1)), Some(StreamState::Subscribing));
        assert_eq!(table.active_count(), 1);
        Ok(())
    }

    #[test]
    fn client_allocator_produces_odd_ids() -> Result<(), ProtocolError> {
        let mut allocator = StreamAllocator::client();

        assert_eq!(allocator.next()?, StreamId(1));
        assert_eq!(allocator.next()?, StreamId(3));
        assert_eq!(allocator.next()?, StreamId(5));
        assert_eq!(allocator.next()?, StreamId(7));
        Ok(())
    }

    #[test]
    fn server_allocator_produces_even_ids() -> Result<(), ProtocolError> {
        let mut allocator = StreamAllocator::server();

        assert_eq!(allocator.next()?, StreamId(2));
        assert_eq!(allocator.next()?, StreamId(4));
        assert_eq!(allocator.next()?, StreamId(6));
        assert_eq!(allocator.next()?, StreamId(8));
        Ok(())
    }

    #[test]
    fn client_and_server_ids_never_collide() -> Result<(), ProtocolError> {
        let mut client = StreamAllocator::client();
        let mut server = StreamAllocator::server();

        for _ in 0..16 {
            let client_id = client.next()?;
            let server_id = server.next()?;

            assert_eq!(client_id.0 % 2, 1);
            assert_eq!(server_id.0 % 2, 0);
            assert_ne!(client_id, server_id);
        }
        Ok(())
    }

    #[test]
    fn client_allocator_errors_after_final_odd_stream_id() -> Result<(), ProtocolError> {
        let mut allocator = StreamAllocator {
            next_id: Some(u32::MAX),
        };

        assert_eq!(allocator.next()?, StreamId(u32::MAX));
        assert!(matches!(
            allocator.next(),
            Err(ProtocolError::CodecError { .. })
        ));
        Ok(())
    }

    #[test]
    fn server_allocator_errors_after_final_even_stream_id() -> Result<(), ProtocolError> {
        let mut allocator = StreamAllocator {
            next_id: Some(u32::MAX - 1),
        };

        assert_eq!(allocator.next()?, StreamId(u32::MAX - 1));
        assert!(matches!(
            allocator.next(),
            Err(ProtocolError::CodecError { .. })
        ));
        Ok(())
    }

    #[test]
    fn exhausted_allocator_keeps_failing() {
        let mut allocator = StreamAllocator { next_id: None };

        assert!(matches!(
            allocator.next(),
            Err(ProtocolError::CodecError { .. })
        ));
        // Exhaustion is sticky: a second attempt must not wrap around.
        assert!(matches!(
            allocator.next(),
            Err(ProtocolError::CodecError { .. })
        ));
    }

    #[test]
    fn allocator_never_recycles_ids() -> Result<(), ProtocolError> {
        let mut allocator = StreamAllocator::client();
        let first = allocator.next()?;
        let second = allocator.next()?;
        let third = allocator.next()?;

        assert!(first < second);
        assert!(second < third);
        Ok(())
    }
}