ant_quic/connection/streams/mod.rs

use std::{
    collections::{BinaryHeap, hash_map},
    io,
};

use bytes::Bytes;
use thiserror::Error;
use tracing::{trace, warn};

use super::spaces::{Retransmits, ThinRetransmits};
use crate::{
    Dir, StreamId, VarInt,
    connection::streams::state::{get_or_insert_recv, get_or_insert_send},
    frame,
};

mod recv;
use recv::Recv;
pub use recv::{Chunks, ReadError, ReadableError};

mod send;
pub(crate) use send::{ByteSlice, BytesArray};
use send::{BytesSource, Send, SendState};
pub use send::{FinishError, WriteError, Written};

mod state;
pub use state::StreamsState;

/// Access to streams
pub struct Streams<'a> {
    pub(super) state: &'a mut StreamsState,
    pub(super) conn_state: &'a super::State,
}

impl<'a> Streams<'a> {
    #[cfg(fuzzing)]
    pub fn new(state: &'a mut StreamsState, conn_state: &'a super::State) -> Self {
        Self { state, conn_state }
    }

    /// Open a single stream if possible
    ///
    /// Returns `None` if the streams in the given direction are currently exhausted.
    pub fn open(&mut self, dir: Dir) -> Option<StreamId> {
        if self.conn_state.is_closed() {
            return None;
        }

        if self.state.next[dir as usize] >= self.state.max[dir as usize] {
            return None;
        }

        self.state.next[dir as usize] += 1;
        let id = StreamId::new(self.state.side, dir, self.state.next[dir as usize] - 1);
        self.state.insert(false, id);
        self.state.send_streams += 1;
        Some(id)
    }

    /// Accept a remotely initiated stream of a certain directionality, if possible
    ///
    /// Returns `None` if there are no new incoming streams for this connection.
    pub fn accept(&mut self, dir: Dir) -> Option<StreamId> {
        if self.state.next_remote[dir as usize] == self.state.next_reported_remote[dir as usize] {
            return None;
        }

        let x = self.state.next_reported_remote[dir as usize];
        self.state.next_reported_remote[dir as usize] = x + 1;
        if dir == Dir::Bi {
            self.state.send_streams += 1;
        }

        Some(StreamId::new(!self.state.side, dir, x))
    }

    #[cfg(fuzzing)]
    pub fn state(&mut self) -> &mut StreamsState {
        self.state
    }

    /// The number of streams that may have unacknowledged data.
    pub fn send_streams(&self) -> usize {
        self.state.send_streams
    }

    /// The number of remotely initiated open streams of a certain directionality.
    pub fn remote_open_streams(&self, dir: Dir) -> u64 {
        self.state.next_remote[dir as usize]
            - (self.state.max_remote[dir as usize]
                - self.state.allocated_remote_count[dir as usize])
    }
}
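
// Open/accept sketch (hedged: `conn.streams()` is a hypothetical accessor; in practice a
// `Streams` borrow is handed out by the connection rather than constructed directly):
//
//     let mut streams: Streams<'_> = conn.streams();
//     // Open a locally initiated bidirectional stream, if the peer's limit allows it.
//     if let Some(id) = streams.open(Dir::Bi) {
//         // `id` identifies the new stream for later SendStream/RecvStream access.
//     }
//     // Drain newly arrived peer-initiated unidirectional streams.
//     while let Some(id) = streams.accept(Dir::Uni) {
//         // Hand `id` to the application for reading.
//     }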

/// Access to a receive stream
pub struct RecvStream<'a> {
    pub(super) id: StreamId,
    pub(super) state: &'a mut StreamsState,
    pub(super) pending: &'a mut Retransmits,
}

impl RecvStream<'_> {
    /// Read from the given recv stream
    ///
    /// `max_length` limits the maximum size of the returned `Bytes` value.
    /// `ordered` ensures each returned chunk starts exactly where the previous one ended.
    ///
    /// Yields `Ok(None)` if the stream was finished. Otherwise, yields a segment of data and its
    /// offset in the stream.
    ///
    /// Unordered reads can improve performance when packet loss occurs, but ordered reads on
    /// streams that have seen previous unordered reads will return
    /// `ReadableError::IllegalOrderedRead`.
    pub fn read(&mut self, ordered: bool) -> Result<Chunks, ReadableError> {
        if self.state.conn_closed() {
            return Err(ReadableError::ConnectionClosed);
        }

        Chunks::new(self.id, ordered, self.state, self.pending)
    }
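
    // Read-side sketch (a hedged illustration, not part of this API: it assumes the `Chunks`
    // type from the `recv` submodule exposes quinn-style `next(max_length)` and `finalize()`
    // methods, and that `recv_stream` is a `RecvStream` handed out by the connection):
    //
    //     let mut chunks = recv_stream.read(true)?;            // ordered read
    //     while let Some(chunk) = chunks.next(usize::MAX)? {
    //         // With `ordered == true`, `chunk.offset` is contiguous with the previous chunk.
    //         process(chunk.bytes);
    //     }
    //     // `finalize()` reports whether flow-control frames now need to be transmitted.
    //     let _ = chunks.finalize();
    //
    // Switching to `read(false)` after loss lets the application drain whatever has arrived,
    // but a later ordered read on the same stream fails with `ReadableError::IllegalOrderedRead`.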

    /// Stop accepting data on the given receive stream
    ///
    /// Discards unread data and notifies the peer to stop transmitting.
    pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
        if self.state.conn_closed() {
            return Err(ClosedStream { _private: () });
        }

        let mut entry = match self.state.recv.entry(self.id) {
            hash_map::Entry::Occupied(s) => s,
            hash_map::Entry::Vacant(_) => return Err(ClosedStream { _private: () }),
        };
        let stream = get_or_insert_recv(self.state.stream_receive_window)(entry.get_mut());

        let (read_credits, stop_sending) = stream.stop()?;
        if stop_sending.should_transmit() {
            self.pending.stop_sending.push(frame::StopSending {
                id: self.id,
                error_code,
            });
        }

        // Clean up stream state if possible
        if !stream.final_offset_unknown() {
            let recv = entry.remove().expect("must have recv when stopping");
            self.state.stream_recv_freed(self.id, recv);
        }

        // Update flow control if needed
        if self.state.add_read_credits(read_credits).should_transmit() {
            self.pending.max_data = true;
        }

        Ok(())
    }
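
    // Stop sketch (hedged: `APP_REJECTED` is an arbitrary application error code chosen for
    // illustration, not a constant defined by this crate):
    //
    //     const APP_REJECTED: u32 = 0x2;
    //     recv_stream.stop(VarInt::from_u32(APP_REJECTED))?;
    //     // Unread data is discarded, a STOP_SENDING frame is queued in `pending`, and any
    //     // freed read credits may queue a MAX_DATA update as well.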

    /// Check whether this stream has been reset by the peer
    ///
    /// Returns the reset error code if the stream was reset.
    pub fn received_reset(&mut self) -> Result<Option<VarInt>, ClosedStream> {
        if self.state.conn_closed() {
            return Err(ClosedStream { _private: () });
        }

        let hash_map::Entry::Occupied(entry) = self.state.recv.entry(self.id) else {
            return Err(ClosedStream { _private: () });
        };

        let Some(s) = entry.get().as_ref().and_then(|s| s.as_open_recv()) else {
            return Ok(None);
        };

        if s.stopped {
            return Err(ClosedStream { _private: () });
        }

        let Some(code) = s.reset_code() else {
            return Ok(None);
        };

        // Clean up state after application observes the reset
        let (_, recv) = entry.remove_entry();
        self.state
            .stream_recv_freed(self.id, recv.expect("must have recv on reset"));
        self.state.queue_max_stream_id(self.pending);

        Ok(Some(code))
    }
}

/// Access to a send stream
pub struct SendStream<'a> {
    pub(super) id: StreamId,
    pub(super) state: &'a mut StreamsState,
    pub(super) pending: &'a mut Retransmits,
    pub(super) conn_state: &'a super::State,
}

#[allow(clippy::needless_lifetimes)] // Needed for cfg(fuzzing)
impl<'a> SendStream<'a> {
    #[cfg(fuzzing)]
    pub fn new(
        id: StreamId,
        state: &'a mut StreamsState,
        pending: &'a mut Retransmits,
        conn_state: &'a super::State,
    ) -> Self {
        Self {
            id,
            state,
            pending,
            conn_state,
        }
    }

    /// Send data on the given stream
    ///
    /// Returns the number of bytes successfully written.
    pub fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
        Ok(self.write_source(&mut ByteSlice::from_slice(data))?.bytes)
    }

    /// Send data on the given stream
    ///
    /// Returns the number of bytes and chunks successfully written.
    /// Note that this method might write only part of the last chunk. In that case
    /// [`Written::chunks`] does not count it as fully written, but the chunk is
    /// advanced in place so that it holds only the unwritten data after the call.
    pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<Written, WriteError> {
        self.write_source(&mut BytesArray::from_chunks(data))
    }
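
    // Write sketch (hedged illustration; `send_stream` is a `SendStream` handed out by the
    // connection and the payloads are arbitrary):
    //
    //     // Slice write: returns how many bytes were accepted before stream or connection
    //     // flow control (or the send window) blocked further progress.
    //     let sent = send_stream.write(b"hello")?;
    //
    //     // Chunked write: `Written { bytes, chunks }` reports both totals. A chunk that was
    //     // only partially accepted is not counted in `chunks`, but it is advanced in place,
    //     // so a retry resumes exactly where this call stopped.
    //     let mut chunks = [Bytes::from_static(b"he"), Bytes::from_static(b"llo")];
    //     let written = send_stream.write_chunks(&mut chunks)?;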

    fn write_source<B: BytesSource>(&mut self, source: &mut B) -> Result<Written, WriteError> {
        if self.conn_state.is_closed() {
            trace!(%self.id, "write blocked; connection draining");
            return Err(WriteError::Blocked);
        }

        let limit = self.state.write_limit();

        let max_send_data = self.state.max_send_data(self.id);

        let stream = self
            .state
            .send
            .get_mut(&self.id)
            .map(get_or_insert_send(max_send_data))
            .ok_or(WriteError::ClosedStream)?;

        if limit == 0 {
            trace!(
                stream = %self.id, max_data = self.state.max_data, data_sent = self.state.data_sent,
                "write blocked by connection-level flow control or send window"
            );
            if !stream.connection_blocked {
                stream.connection_blocked = true;
                self.state.connection_blocked.push(self.id);
            }
            return Err(WriteError::Blocked);
        }

        let was_pending = stream.is_pending();
        let written = stream.write(source, limit)?;
        self.state.data_sent += written.bytes as u64;
        self.state.unacked_data += written.bytes as u64;
        trace!(stream = %self.id, "wrote {} bytes", written.bytes);
        if !was_pending {
            self.state.pending.push_pending(self.id, stream.priority);
        }
        Ok(written)
    }

    /// Check if this stream was stopped, get the reason if it was
    pub fn stopped(&self) -> Result<Option<VarInt>, ClosedStream> {
        match self.state.send.get(&self.id).as_ref() {
            Some(Some(s)) => Ok(s.stop_reason),
            Some(None) => Ok(None),
            None => Err(ClosedStream { _private: () }),
        }
    }

    /// Finish a send stream, signalling that no more data will be sent.
    ///
    /// If this fails, no [`StreamEvent::Finished`] will be generated.
    ///
    /// [`StreamEvent::Finished`]: crate::StreamEvent::Finished
    pub fn finish(&mut self) -> Result<(), FinishError> {
        let max_send_data = self.state.max_send_data(self.id);
        let stream = self
            .state
            .send
            .get_mut(&self.id)
            .map(get_or_insert_send(max_send_data))
            .ok_or(FinishError::ClosedStream)?;

        let was_pending = stream.is_pending();
        stream.finish()?;
        if !was_pending {
            self.state.pending.push_pending(self.id, stream.priority);
        }

        Ok(())
    }

    /// Abandon transmitting data on a stream
    ///
    /// # Panics
    /// - when applied to a receive stream
    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
        let max_send_data = self.state.max_send_data(self.id);
        let stream = self
            .state
            .send
            .get_mut(&self.id)
            .map(get_or_insert_send(max_send_data))
            .ok_or(ClosedStream { _private: () })?;

        if matches!(stream.state, SendState::ResetSent) {
            // Redundant reset call
            return Err(ClosedStream { _private: () });
        }

        // Restore the portion of the send window consumed by the data that we aren't about to
        // send. We leave flow control alone because the peer's responsible for issuing additional
        // credit based on the final offset communicated in the RESET_STREAM frame we send.
        self.state.unacked_data -= stream.pending.unacked();
        stream.reset();
        self.pending.reset_stream.push((self.id, error_code));

        // Don't reopen an already-closed stream we haven't forgotten yet
        Ok(())
    }

    /// Set the priority of a stream
    ///
    /// # Panics
    /// - when applied to a receive stream
    pub fn set_priority(&mut self, priority: i32) -> Result<(), ClosedStream> {
        let max_send_data = self.state.max_send_data(self.id);
        let stream = self
            .state
            .send
            .get_mut(&self.id)
            .map(get_or_insert_send(max_send_data))
            .ok_or(ClosedStream { _private: () })?;

        stream.priority = priority;
        Ok(())
    }

    /// Get the priority of a stream
    ///
    /// # Panics
    /// - when applied to a receive stream
    pub fn priority(&self) -> Result<i32, ClosedStream> {
        let stream = self
            .state
            .send
            .get(&self.id)
            .ok_or(ClosedStream { _private: () })?;

        Ok(stream.as_ref().map(|s| s.priority).unwrap_or_default())
    }
}
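
// Send-side lifecycle sketch (hedged: error handling elided, the error code is illustrative):
//
//     send_stream.set_priority(1)?;      // raise this stream's scheduling priority
//     send_stream.write(b"last bytes")?;
//     send_stream.finish()?;             // queue the FIN; `StreamEvent::Finished` is emitted
//                                        // once all data has been acknowledged
//
//     // Alternatively, abandon the stream: unsent data is dropped, its share of the send
//     // window is returned, and a RESET_STREAM frame carrying the error code is queued.
//     send_stream.reset(VarInt::from_u32(0x1))?;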

/// A queue of streams with pending outgoing data, sorted by priority
struct PendingStreamsQueue {
    streams: BinaryHeap<PendingStream>,
    /// The next stream to write out. This is `Some` when writing a stream is
    /// interrupted while the stream still has some pending data.
    next: Option<PendingStream>,
    /// A monotonically decreasing counter for round-robin scheduling of streams with the same priority
    recency: u64,
}

impl PendingStreamsQueue {
    fn new() -> Self {
        Self {
            streams: BinaryHeap::new(),
            next: None,
            recency: u64::MAX,
        }
    }

    /// Reinsert a stream that was pending and still contains unsent data.
    fn reinsert_pending(&mut self, id: StreamId, priority: i32) {
        if self.next.is_some() {
            warn!("Attempting to reinsert a pending stream when next is already set");
            return;
        }

        self.next = Some(PendingStream {
            priority,
            recency: self.recency,
            id,
        });
    }

    /// Push a pending stream ID with the given priority
    fn push_pending(&mut self, id: StreamId, priority: i32) {
        // Decrement recency to ensure round-robin scheduling for streams of the same priority
        self.recency = self.recency.saturating_sub(1);
        self.streams.push(PendingStream {
            priority,
            recency: self.recency,
            id,
        });
    }

    /// Pop the highest priority stream
    fn pop(&mut self) -> Option<PendingStream> {
        self.next.take().or_else(|| self.streams.pop())
    }

    /// Clear all pending streams
    fn clear(&mut self) {
        self.next = None;
        self.streams.clear();
    }

    /// Iterate over all pending streams
    fn iter(&self) -> impl Iterator<Item = &PendingStream> {
        self.next.iter().chain(self.streams.iter())
    }

    #[cfg(test)]
    fn len(&self) -> usize {
        self.streams.len() + self.next.is_some() as usize
    }
}

/// The [`StreamId`] of a stream with pending data queued, ordered by its priority and recency
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
struct PendingStream {
    /// The priority of the stream
    // Note that this field should be kept above the `recency` field, in order for the `Ord` derive to be correct
    // (See https://doc.rust-lang.org/stable/std/cmp/trait.Ord.html#derivable)
    priority: i32,
    /// A tie-breaker among streams of the same priority, used to improve fairness via
    /// round-robin scheduling: larger values are scheduled first, so the counter starts at
    /// `u64::MAX`. When a stream writes data it necessarily holds the highest recency value,
    /// so it is deprioritized by assigning it one less than the previous lowest recency value;
    /// every other stream at this priority is then processed once before we come back round
    /// to it.
    recency: u64,
    /// The ID of the stream
    // The way this type is used ensures that every instance has a unique `recency` value, so this field should be kept below
    // the `priority` and `recency` fields, so that it does not interfere with the behaviour of the `Ord` derive
    id: StreamId,
}
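
// Scheduling sketch: with the derived `Ord` (priority first, then recency) the max-heap pops
// the highest priority, and among equal priorities the earliest-pushed stream (largest
// recency) comes out first, yielding FIFO round-robin. Hedged example, assuming `Side` is the
// crate's client/server enum:
//
//     let mut queue = PendingStreamsQueue::new();
//     let a = StreamId::new(Side::Client, Dir::Bi, 0);
//     let b = StreamId::new(Side::Client, Dir::Bi, 1);
//     let c = StreamId::new(Side::Client, Dir::Bi, 2);
//     queue.push_pending(a, 0);
//     queue.push_pending(b, 0);
//     queue.push_pending(c, 7);
//     // Pop order: c (priority 7), then a, then b (same priority, round-robin by recency).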

/// Application events about streams
#[derive(Debug, PartialEq, Eq)]
pub enum StreamEvent {
    /// One or more new streams have been opened and might be readable
    Opened {
        /// Directionality for which streams have been opened
        dir: Dir,
    },
    /// A currently open stream likely has data or errors waiting to be read
    Readable {
        /// Which stream is now readable
        id: StreamId,
    },
    /// A formerly write-blocked stream might be ready for a write or have been stopped
    ///
    /// Only generated for streams that are currently open.
    Writable {
        /// Which stream is now writable
        id: StreamId,
    },
    /// A finished stream has been fully acknowledged or stopped
    Finished {
        /// Which stream has been finished
        id: StreamId,
    },
    /// The peer asked us to stop sending on an outgoing stream
    Stopped {
        /// Which stream has been stopped
        id: StreamId,
        /// Error code supplied by the peer
        error_code: VarInt,
    },
    /// At least one new stream of a certain directionality may be opened
    Available {
        /// Directionality for which streams are newly available
        dir: Dir,
    },
}
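
// Event-handling sketch (hedged: `conn.poll()` stands in for however the surrounding
// connection surfaces `StreamEvent`s; it is not defined in this module):
//
//     while let Some(event) = conn.poll() {
//         match event {
//             StreamEvent::Opened { dir } => { /* call accept(dir) until it returns None */ }
//             StreamEvent::Readable { id } => { /* read from stream `id` */ }
//             StreamEvent::Writable { id } => { /* retry a previously blocked write on `id` */ }
//             StreamEvent::Finished { id } => { /* all data on `id` was acknowledged */ }
//             StreamEvent::Stopped { id, error_code } => { /* peer refused stream `id` */ }
//             StreamEvent::Available { dir } => { /* open(dir) may now succeed */ }
//         }
//     }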

/// Indicates whether a frame needs to be transmitted
///
/// This type wraps a `bool` and carries the `#[must_use]` attribute in order
/// to prevent accidental loss of the frame transmission requirement.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
#[must_use = "A frame might need to be enqueued"]
pub struct ShouldTransmit(bool);

impl ShouldTransmit {
    /// Returns whether a frame should be transmitted
    pub fn should_transmit(self) -> bool {
        self.0
    }
}

/// Error indicating that a stream has not been opened or has already been finished or reset
#[derive(Debug, Default, Error, Clone, PartialEq, Eq)]
#[error("closed stream")]
pub struct ClosedStream {
    _private: (),
}

impl From<ClosedStream> for io::Error {
    fn from(x: ClosedStream) -> Self {
        Self::new(io::ErrorKind::NotConnected, x)
    }
}

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum StreamHalf {
    Send,
    Recv,
}