1use std::{
2 collections::{HashMap, hash_map::Entry},
3 fmt::Debug,
4 ops::{Deref, DerefMut},
5};
6
7use super::ProtocolError;
8
/// Numeric identifier of a stream, wrapping the raw `u32` value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct StreamId(pub u32);

impl StreamId {
    /// The identifier treated as the control stream: raw value `0`.
    pub const CONTROL: Self = Self(0);

    /// Reports whether this identifier equals [`StreamId::CONTROL`].
    #[must_use]
    pub const fn is_control(self) -> bool {
        let Self(raw) = self;
        let Self(control) = Self::CONTROL;
        raw == control
    }

    /// Reports whether this identifier is anything other than [`StreamId::CONTROL`],
    /// i.e. its raw value is `1` or greater.
    #[must_use]
    pub const fn is_application(self) -> bool {
        !self.is_control()
    }
}
32
/// State recorded for a stream in a [`StreamTable`].
///
/// NOTE(review): the variant names suggest a subscribe -> active -> close lifecycle, but
/// `StreamTable::transition` accepts any target state, so no ordering is enforced in this
/// file. Confirm the intended lifecycle against the protocol definition.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StreamState {
    /// Presumably a stream whose subscription has not completed yet — TODO confirm.
    Subscribing,
    /// Presumably a stream that is fully established — TODO confirm.
    Active,
    /// Presumably a stream that is being shut down — TODO confirm.
    Closing,
}
43
44#[derive(Debug, Default)]
46pub struct StreamTable {
47 streams: HashMap<StreamId, StreamState>,
48}
49
50impl StreamTable {
51 #[must_use]
53 pub fn new() -> Self {
54 Self::default()
55 }
56
57 pub fn insert(&mut self, stream_id: StreamId, state: StreamState) -> Result<(), ProtocolError> {
64 match self.streams.entry(stream_id) {
65 Entry::Vacant(entry) => {
66 entry.insert(state);
67 Ok(())
68 }
69 Entry::Occupied(_) => Err(ProtocolError::codec(format!(
70 "stream {stream_id:?} already exists"
71 ))),
72 }
73 }
74
75 pub fn transition(
82 &mut self,
83 stream_id: StreamId,
84 new_state: StreamState,
85 ) -> Result<(), ProtocolError> {
86 let Some(state) = self.streams.get_mut(&stream_id) else {
87 return Err(ProtocolError::codec(format!(
88 "stream {stream_id:?} is not active"
89 )));
90 };
91
92 *state = new_state;
93 Ok(())
94 }
95
96 #[must_use]
98 pub fn remove(&mut self, stream_id: StreamId) -> Option<StreamState> {
99 self.streams.remove(&stream_id)
100 }
101
102 #[must_use]
104 pub fn get(&self, stream_id: StreamId) -> Option<StreamState> {
105 self.streams.get(&stream_id).copied()
106 }
107
108 #[must_use]
110 pub fn active_count(&self) -> usize {
111 self.streams.len()
112 }
113}
114
/// Hands out stream identifiers in steps of two, starting from a fixed first value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StreamAllocator {
    /// Identifier the next allocation will return; `None` once advancing by two would
    /// overflow `u32`, after which allocation fails.
    next_id: Option<u32>,
}
120
/// Source of fresh stream identifiers.
pub trait AllocateStreamId: Debug {
    /// Returns the next stream identifier.
    ///
    /// # Errors
    ///
    /// Returns a [`ProtocolError`] when no identifier can be produced; [`StreamAllocator`]
    /// does so once its identifier space is exhausted.
    fn next(&mut self) -> Result<StreamId, ProtocolError>;
}
131
132impl StreamAllocator {
133 #[must_use]
135 pub const fn client() -> Self {
136 Self { next_id: Some(1) }
137 }
138
139 #[must_use]
141 pub const fn server() -> Self {
142 Self { next_id: Some(2) }
143 }
144
145 fn allocate_next(&mut self) -> Result<StreamId, ProtocolError> {
146 let stream_id = self
147 .next_id
148 .ok_or_else(|| ProtocolError::codec("stream id space exhausted"))?;
149 self.next_id = stream_id.checked_add(2);
150 Ok(StreamId(stream_id))
151 }
152}
153
/// Trait entry point for allocation; forwards to the inherent `allocate_next`.
impl AllocateStreamId for StreamAllocator {
    /// Returns the next identifier in this allocator's sequence.
    ///
    /// # Errors
    ///
    /// Propagates the error from `allocate_next`, returned once the identifier space is
    /// exhausted.
    fn next(&mut self) -> Result<StreamId, ProtocolError> {
        self.allocate_next()
    }
}
159
/// Lets a `StreamAllocator` be viewed as a `dyn AllocateStreamId` trait object via deref.
///
/// NOTE(review): this appears to exist so that `allocator.next()` resolves without importing
/// `AllocateStreamId` (the tests in this file call it that way) — confirm the intent, since
/// `Deref` is normally reserved for smart-pointer-like types.
impl Deref for StreamAllocator {
    type Target = dyn AllocateStreamId;

    /// Returns `self` coerced to the trait object.
    fn deref(&self) -> &Self::Target {
        self
    }
}
167
/// Mutable counterpart of the `Deref` impl: exposes the allocator as a mutable
/// `dyn AllocateStreamId`, which is what a `next(&mut self)` call through deref requires.
impl DerefMut for StreamAllocator {
    /// Returns `self` coerced to the mutable trait object.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self
    }
}
173
#[cfg(test)]
mod tests {
    use std::{fmt::Debug, hash::Hash};

    use super::{StreamAllocator, StreamId, StreamState, StreamTable};
    use crate::protocol::ProtocolError;

    /// Compile-time check that `StreamId` implements the full derived trait set.
    #[test]
    fn stream_id_trait_bounds_are_available() {
        fn require_traits<T>()
        where
            T: Debug + Clone + Copy + PartialEq + Eq + Hash + PartialOrd + Ord,
        {
        }

        require_traits::<StreamId>();
    }

    /// Id 0 is classified as control only.
    #[test]
    fn stream_zero_is_control_and_not_application() {
        let control = StreamId(0);

        assert!(control.is_control());
        assert!(!control.is_application());
    }

    /// Id 1 is classified as application only.
    #[test]
    fn stream_one_is_application_and_not_control() {
        let first_application = StreamId(1);

        assert!(first_application.is_application());
        assert!(!first_application.is_control());
    }

    /// Pins the variant set of `StreamState`.
    #[test]
    fn stream_state_has_exact_required_variants() {
        // No wildcard arm: adding or removing a variant stops this from compiling.
        fn label(state: StreamState) -> &'static str {
            match state {
                StreamState::Subscribing => "subscribing",
                StreamState::Active => "active",
                StreamState::Closing => "closing",
            }
        }

        let expected = [
            (StreamState::Subscribing, "subscribing"),
            (StreamState::Active, "active"),
            (StreamState::Closing, "closing"),
        ];

        assert_eq!(expected.len(), 3);
        for (state, name) in expected {
            assert_eq!(label(state), name);
        }
    }

    /// A fresh insert is visible through `get` and reflected in the count.
    #[test]
    fn insert_adds_stream_and_counts_it() -> Result<(), ProtocolError> {
        let mut streams = StreamTable::new();

        streams.insert(StreamId(1), StreamState::Subscribing)?;

        assert_eq!(streams.get(StreamId(1)), Some(StreamState::Subscribing));
        assert_eq!(streams.active_count(), 1);
        Ok(())
    }

    /// Inserting an id twice fails and leaves the first entry intact.
    #[test]
    fn duplicate_insert_returns_error_and_preserves_state() -> Result<(), ProtocolError> {
        let mut streams = StreamTable::new();
        streams.insert(StreamId(1), StreamState::Subscribing)?;

        let second_insert = streams.insert(StreamId(1), StreamState::Active);

        assert!(matches!(
            second_insert,
            Err(ProtocolError::CodecError { .. })
        ));
        assert_eq!(streams.get(StreamId(1)), Some(StreamState::Subscribing));
        assert_eq!(streams.active_count(), 1);
        Ok(())
    }

    /// `transition` replaces the stored state of a known stream.
    #[test]
    fn transition_updates_existing_stream_state() -> Result<(), ProtocolError> {
        let mut streams = StreamTable::new();
        streams.insert(StreamId(1), StreamState::Subscribing)?;

        streams.transition(StreamId(1), StreamState::Active)?;

        assert_eq!(streams.get(StreamId(1)), Some(StreamState::Active));
        Ok(())
    }

    /// `transition` on an unknown stream reports a codec error.
    #[test]
    fn transition_on_missing_stream_returns_protocol_error() {
        let mut streams = StreamTable::new();

        let outcome = streams.transition(StreamId(1), StreamState::Active);

        assert!(matches!(outcome, Err(ProtocolError::CodecError { .. })));
    }

    /// `remove` returns the stored state and drops only the targeted entry.
    #[test]
    fn remove_deletes_stream_and_updates_count() -> Result<(), ProtocolError> {
        let mut streams = StreamTable::new();
        streams.insert(StreamId(1), StreamState::Active)?;
        streams.insert(StreamId(3), StreamState::Closing)?;

        let removed = streams.remove(StreamId(1));

        assert_eq!(removed, Some(StreamState::Active));
        assert_eq!(streams.get(StreamId(1)), None);
        assert_eq!(streams.active_count(), 1);
        Ok(())
    }

    /// The client allocator starts at 1 and steps by two.
    #[test]
    fn client_allocator_produces_odd_ids() -> Result<(), ProtocolError> {
        let mut allocator = StreamAllocator::client();

        for expected in [1, 3, 5, 7] {
            assert_eq!(allocator.next()?, StreamId(expected));
        }
        Ok(())
    }

    /// The server allocator starts at 2 and steps by two.
    #[test]
    fn server_allocator_produces_even_ids() -> Result<(), ProtocolError> {
        let mut allocator = StreamAllocator::server();

        for expected in [2, 4, 6, 8] {
            assert_eq!(allocator.next()?, StreamId(expected));
        }
        Ok(())
    }

    /// `u32::MAX` is still handed out; the allocation after it fails.
    #[test]
    fn client_allocator_errors_after_final_odd_stream_id() -> Result<(), ProtocolError> {
        let last_odd = u32::MAX;
        let mut allocator = StreamAllocator {
            next_id: Some(last_odd),
        };

        assert_eq!(allocator.next()?, StreamId(last_odd));

        let after_last = allocator.next();
        assert!(matches!(after_last, Err(ProtocolError::CodecError { .. })));
        Ok(())
    }

    /// `u32::MAX - 1` is still handed out; the allocation after it fails.
    #[test]
    fn server_allocator_errors_after_final_even_stream_id() -> Result<(), ProtocolError> {
        let last_even = u32::MAX - 1;
        let mut allocator = StreamAllocator {
            next_id: Some(last_even),
        };

        assert_eq!(allocator.next()?, StreamId(last_even));

        let after_last = allocator.next();
        assert!(matches!(after_last, Err(ProtocolError::CodecError { .. })));
        Ok(())
    }

    /// Successive identifiers are strictly increasing, so none can repeat.
    #[test]
    fn allocator_never_recycles_ids() -> Result<(), ProtocolError> {
        let mut allocator = StreamAllocator::client();
        let ids = [allocator.next()?, allocator.next()?, allocator.next()?];

        assert!(ids[0] < ids[1]);
        assert!(ids[1] < ids[2]);
        Ok(())
    }
}