ant_quic/connection/streams/
mod.rs1use std::{
2 collections::{BinaryHeap, hash_map},
3 io,
4};
5
6use bytes::Bytes;
7use thiserror::Error;
8use tracing::{trace, warn};
9
10use super::spaces::{Retransmits, ThinRetransmits};
11use crate::{
12 Dir, StreamId, VarInt,
13 connection::streams::state::{get_or_insert_recv, get_or_insert_send},
14 frame,
15};
16
17mod recv;
18use recv::Recv;
19pub use recv::{Chunks, ReadError, ReadableError};
20
21mod send;
22pub(crate) use send::{ByteSlice, BytesArray};
23use send::{BytesSource, Send, SendState};
24pub use send::{FinishError, WriteError, Written};
25
26mod state;
27#[allow(unreachable_pub)] pub use state::StreamsState;
29
30pub struct Streams<'a> {
32 pub(super) state: &'a mut StreamsState,
33 pub(super) conn_state: &'a super::State,
34}
35
36impl<'a> Streams<'a> {
37 #[cfg(fuzzing)]
38 pub fn new(state: &'a mut StreamsState, conn_state: &'a super::State) -> Self {
39 Self { state, conn_state }
40 }
41
42 pub fn open(&mut self, dir: Dir) -> Option<StreamId> {
46 if self.conn_state.is_closed() {
47 return None;
48 }
49
50 if self.state.next[dir as usize] >= self.state.max[dir as usize] {
51 return None;
52 }
53
54 self.state.next[dir as usize] += 1;
55 let id = StreamId::new(self.state.side, dir, self.state.next[dir as usize] - 1);
56 self.state.insert(false, id);
57 self.state.send_streams += 1;
58 Some(id)
59 }
60
61 pub fn accept(&mut self, dir: Dir) -> Option<StreamId> {
65 if self.state.next_remote[dir as usize] == self.state.next_reported_remote[dir as usize] {
66 return None;
67 }
68
69 let x = self.state.next_reported_remote[dir as usize];
70 self.state.next_reported_remote[dir as usize] = x + 1;
71 if dir == Dir::Bi {
72 self.state.send_streams += 1;
73 }
74
75 Some(StreamId::new(!self.state.side, dir, x))
76 }
77
78 #[cfg(fuzzing)]
79 pub fn state(&mut self) -> &mut StreamsState {
80 self.state
81 }
82
83 pub fn send_streams(&self) -> usize {
85 self.state.send_streams
86 }
87
88 pub fn remote_open_streams(&self, dir: Dir) -> u64 {
90 self.state.next_remote[dir as usize]
91 - (self.state.max_remote[dir as usize]
92 - self.state.allocated_remote_count[dir as usize])
93 }
94}
95
96pub struct RecvStream<'a> {
98 pub(super) id: StreamId,
99 pub(super) state: &'a mut StreamsState,
100 pub(super) pending: &'a mut Retransmits,
101}
102
103impl RecvStream<'_> {
104 pub fn read(&mut self, ordered: bool) -> Result<Chunks, ReadableError> {
115 if self.state.conn_closed() {
116 return Err(ReadableError::ConnectionClosed);
117 }
118
119 Chunks::new(self.id, ordered, self.state, self.pending)
120 }
121
122 pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
126 if self.state.conn_closed() {
127 return Err(ClosedStream { _private: () });
128 }
129
130 let mut entry = match self.state.recv.entry(self.id) {
131 hash_map::Entry::Occupied(s) => s,
132 hash_map::Entry::Vacant(_) => return Err(ClosedStream { _private: () }),
133 };
134 let stream = get_or_insert_recv(self.state.stream_receive_window)(entry.get_mut());
135
136 let (read_credits, stop_sending) = stream.stop()?;
137 if stop_sending.should_transmit() {
138 self.pending.stop_sending.push(frame::StopSending {
139 id: self.id,
140 error_code,
141 });
142 }
143
144 if !stream.final_offset_unknown() {
146 let recv = entry.remove().expect("must have recv when stopping");
147 self.state.stream_recv_freed(self.id, recv);
148 }
149
150 if self.state.add_read_credits(read_credits).should_transmit() {
152 self.pending.max_data = true;
153 }
154
155 Ok(())
156 }
157
158 pub fn received_reset(&mut self) -> Result<Option<VarInt>, ClosedStream> {
162 if self.state.conn_closed() {
163 return Err(ClosedStream { _private: () });
164 }
165
166 let hash_map::Entry::Occupied(entry) = self.state.recv.entry(self.id) else {
167 return Err(ClosedStream { _private: () });
168 };
169
170 let Some(s) = entry.get().as_ref().and_then(|s| s.as_open_recv()) else {
171 return Ok(None);
172 };
173
174 if s.stopped {
175 return Err(ClosedStream { _private: () });
176 }
177
178 let Some(code) = s.reset_code() else {
179 return Ok(None);
180 };
181
182 let (_, recv) = entry.remove_entry();
184 self.state
185 .stream_recv_freed(self.id, recv.expect("must have recv on reset"));
186 self.state.queue_max_stream_id(self.pending);
187
188 Ok(Some(code))
189 }
190}
191
192pub struct SendStream<'a> {
194 pub(super) id: StreamId,
195 pub(super) state: &'a mut StreamsState,
196 pub(super) pending: &'a mut Retransmits,
197 pub(super) conn_state: &'a super::State,
198}
199
200#[allow(clippy::needless_lifetimes)] impl<'a> SendStream<'a> {
202 #[cfg(fuzzing)]
203 pub fn new(
204 id: StreamId,
205 state: &'a mut StreamsState,
206 pending: &'a mut Retransmits,
207 conn_state: &'a super::State,
208 ) -> Self {
209 Self {
210 id,
211 state,
212 pending,
213 conn_state,
214 }
215 }
216
217 pub fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
221 Ok(self.write_source(&mut ByteSlice::from_slice(data))?.bytes)
222 }
223
224 pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<Written, WriteError> {
231 self.write_source(&mut BytesArray::from_chunks(data))
232 }
233
234 fn write_source<B: BytesSource>(&mut self, source: &mut B) -> Result<Written, WriteError> {
235 if self.conn_state.is_closed() {
236 trace!(%self.id, "write blocked; connection draining");
237 return Err(WriteError::Blocked);
238 }
239
240 let limit = self.state.write_limit();
241
242 let max_send_data = self.state.max_send_data(self.id);
243
244 let stream = self
245 .state
246 .send
247 .get_mut(&self.id)
248 .map(get_or_insert_send(max_send_data))
249 .ok_or(WriteError::ClosedStream)?;
250
251 if limit == 0 {
252 trace!(
253 stream = %self.id, max_data = self.state.max_data, data_sent = self.state.data_sent,
254 "write blocked by connection-level flow control or send window"
255 );
256 if !stream.connection_blocked {
257 stream.connection_blocked = true;
258 self.state.connection_blocked.push(self.id);
259 }
260 return Err(WriteError::Blocked);
261 }
262
263 let was_pending = stream.is_pending();
264 let written = stream.write(source, limit)?;
265 self.state.data_sent += written.bytes as u64;
266 self.state.unacked_data += written.bytes as u64;
267 trace!(stream = %self.id, "wrote {} bytes", written.bytes);
268 if !was_pending {
269 self.state.pending.push_pending(self.id, stream.priority);
270 }
271 Ok(written)
272 }
273
274 pub fn stopped(&self) -> Result<Option<VarInt>, ClosedStream> {
276 match self.state.send.get(&self.id).as_ref() {
277 Some(Some(s)) => Ok(s.stop_reason),
278 Some(None) => Ok(None),
279 None => Err(ClosedStream { _private: () }),
280 }
281 }
282
283 pub fn finish(&mut self) -> Result<(), FinishError> {
289 let max_send_data = self.state.max_send_data(self.id);
290 let stream = self
291 .state
292 .send
293 .get_mut(&self.id)
294 .map(get_or_insert_send(max_send_data))
295 .ok_or(FinishError::ClosedStream)?;
296
297 let was_pending = stream.is_pending();
298 stream.finish()?;
299 if !was_pending {
300 self.state.pending.push_pending(self.id, stream.priority);
301 }
302
303 Ok(())
304 }
305
306 pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
311 let max_send_data = self.state.max_send_data(self.id);
312 let stream = self
313 .state
314 .send
315 .get_mut(&self.id)
316 .map(get_or_insert_send(max_send_data))
317 .ok_or(ClosedStream { _private: () })?;
318
319 if matches!(stream.state, SendState::ResetSent) {
320 return Err(ClosedStream { _private: () });
322 }
323
324 self.state.unacked_data -= stream.pending.unacked();
328 stream.reset();
329 self.pending.reset_stream.push((self.id, error_code));
330
331 Ok(())
333 }
334
335 pub fn set_priority(&mut self, priority: i32) -> Result<(), ClosedStream> {
340 let max_send_data = self.state.max_send_data(self.id);
341 let stream = self
342 .state
343 .send
344 .get_mut(&self.id)
345 .map(get_or_insert_send(max_send_data))
346 .ok_or(ClosedStream { _private: () })?;
347
348 stream.priority = priority;
349 Ok(())
350 }
351
352 pub fn priority(&self) -> Result<i32, ClosedStream> {
357 let stream = self
358 .state
359 .send
360 .get(&self.id)
361 .ok_or(ClosedStream { _private: () })?;
362
363 Ok(stream.as_ref().map(|s| s.priority).unwrap_or_default())
364 }
365}
366
367struct PendingStreamsQueue {
369 streams: BinaryHeap<PendingStream>,
370 next: Option<PendingStream>,
373 recency: u64,
375}
376
377impl PendingStreamsQueue {
378 fn new() -> Self {
379 Self {
380 streams: BinaryHeap::new(),
381 next: None,
382 recency: u64::MAX,
383 }
384 }
385
386 fn reinsert_pending(&mut self, id: StreamId, priority: i32) {
388 if self.next.is_some() {
389 warn!("Attempting to reinsert a pending stream when next is already set");
390 return;
391 }
392
393 self.next = Some(PendingStream {
394 priority,
395 recency: self.recency,
396 id,
397 });
398 }
399
400 fn push_pending(&mut self, id: StreamId, priority: i32) {
402 self.recency = self.recency.saturating_sub(1);
404 self.streams.push(PendingStream {
405 priority,
406 recency: self.recency,
407 id,
408 });
409 }
410
411 fn pop(&mut self) -> Option<PendingStream> {
413 self.next.take().or_else(|| self.streams.pop())
414 }
415
416 fn clear(&mut self) {
418 self.next = None;
419 self.streams.clear();
420 }
421
422 fn iter(&self) -> impl Iterator<Item = &PendingStream> {
424 self.next.iter().chain(self.streams.iter())
425 }
426
427 #[cfg(test)]
428 fn len(&self) -> usize {
429 self.streams.len() + self.next.is_some() as usize
430 }
431}
432
433#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
435struct PendingStream {
436 priority: i32,
440 recency: u64,
446 id: StreamId,
450}
451
452#[derive(Debug, PartialEq, Eq)]
454pub enum StreamEvent {
455 Opened {
457 dir: Dir,
459 },
460 Readable {
462 id: StreamId,
464 },
465 Writable {
469 id: StreamId,
471 },
472 Finished {
474 id: StreamId,
476 },
477 Stopped {
479 id: StreamId,
481 error_code: VarInt,
483 },
484 Available {
486 dir: Dir,
488 },
489}
490
/// Hint returned by state updates telling the caller whether a frame may need to be queued.
///
/// Marked `#[must_use]` so the hint cannot be dropped silently.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
#[must_use = "A frame might need to be enqueued"]
pub struct ShouldTransmit(bool);

impl ShouldTransmit {
    /// `true` when the caller should queue the corresponding frame.
    pub fn should_transmit(self) -> bool {
        let Self(transmit) = self;
        transmit
    }
}
505
/// Error indicating that an operation targeted a stream that cannot service it.
///
/// Within this module it is returned in several cases:
/// - the stream has no entry in the connection's stream maps;
/// - a `RecvStream` operation runs after the connection is closed;
/// - a state conflict occurs, such as a repeated `SendStream::reset` or `received_reset` on
///   a stream that was stopped locally.
///
/// The private field prevents construction by struct literal outside this module and its
/// children. Note that the derived `Default` still allows `ClosedStream::default()`.
#[derive(Debug, Default, Error, Clone, PartialEq, Eq)]
#[error("closed stream")]
pub struct ClosedStream {
    // Private marker field; see the type-level docs.
    _private: (),
}
512
513impl From<ClosedStream> for io::Error {
514 fn from(x: ClosedStream) -> Self {
515 Self::new(io::ErrorKind::NotConnected, x)
516 }
517}
518
/// Selects one half of a stream: the sending side or the receiving side.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum StreamHalf {
    /// The sending half.
    Send,
    /// The receiving half.
    Recv,
}