ant_quic/connection/streams/
mod.rs1use std::{
9 collections::{BinaryHeap, hash_map},
10 io,
11};
12
13use bytes::Bytes;
14use thiserror::Error;
15use tracing::{trace, warn};
16
17use super::spaces::{Retransmits, ThinRetransmits};
18use crate::{
19 Dir, StreamId, VarInt,
20 connection::streams::state::{get_or_insert_recv, get_or_insert_send},
21 frame,
22};
23
24mod recv;
25use recv::Recv;
26pub use recv::{Chunks, ReadError, ReadableError};
27
28mod send;
29pub(crate) use send::{ByteSlice, BytesArray};
30use send::{BytesSource, Send, SendState};
31pub use send::{FinishError, WriteError, Written};
32
33mod state;
34pub use state::StreamsState;
35
36pub struct Streams<'a> {
38 pub(super) state: &'a mut StreamsState,
39 pub(super) conn_state: &'a super::State,
40}
41
42impl<'a> Streams<'a> {
43 #[cfg(fuzzing)]
44 pub fn new(state: &'a mut StreamsState, conn_state: &'a super::State) -> Self {
45 Self { state, conn_state }
46 }
47
48 pub fn open(&mut self, dir: Dir) -> Option<StreamId> {
52 if self.conn_state.is_closed() {
53 return None;
54 }
55
56 if self.state.next[dir as usize] >= self.state.max[dir as usize] {
57 return None;
58 }
59
60 self.state.next[dir as usize] += 1;
61 let id = StreamId::new(self.state.side, dir, self.state.next[dir as usize] - 1);
62 self.state.insert(false, id);
63 self.state.send_streams += 1;
64 Some(id)
65 }
66
67 pub fn accept(&mut self, dir: Dir) -> Option<StreamId> {
71 if self.state.next_remote[dir as usize] == self.state.next_reported_remote[dir as usize] {
72 return None;
73 }
74
75 let x = self.state.next_reported_remote[dir as usize];
76 self.state.next_reported_remote[dir as usize] = x + 1;
77 if dir == Dir::Bi {
78 self.state.send_streams += 1;
79 }
80
81 Some(StreamId::new(!self.state.side, dir, x))
82 }
83
84 #[cfg(fuzzing)]
85 pub fn state(&mut self) -> &mut StreamsState {
86 self.state
87 }
88
89 pub fn send_streams(&self) -> usize {
91 self.state.send_streams
92 }
93
94 pub fn remote_open_streams(&self, dir: Dir) -> u64 {
96 self.state.next_remote[dir as usize]
97 - (self.state.max_remote[dir as usize]
98 - self.state.allocated_remote_count[dir as usize])
99 }
100}
101
102pub struct RecvStream<'a> {
104 pub(super) id: StreamId,
105 pub(super) state: &'a mut StreamsState,
106 pub(super) pending: &'a mut Retransmits,
107}
108
109impl RecvStream<'_> {
110 pub fn read(&mut self, ordered: bool) -> Result<Chunks<'_>, ReadableError> {
121 if self.state.conn_closed() {
122 return Err(ReadableError::ConnectionClosed);
123 }
124
125 Chunks::new(self.id, ordered, self.state, self.pending)
126 }
127
128 pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
132 if self.state.conn_closed() {
133 return Err(ClosedStream { _private: () });
134 }
135
136 let mut entry = match self.state.recv.entry(self.id) {
137 hash_map::Entry::Occupied(s) => s,
138 hash_map::Entry::Vacant(_) => return Err(ClosedStream { _private: () }),
139 };
140 let stream = get_or_insert_recv(self.state.stream_receive_window)(entry.get_mut());
141
142 let (read_credits, stop_sending) = stream.stop()?;
143 if stop_sending.should_transmit() {
144 self.pending.stop_sending.push(frame::StopSending {
145 id: self.id,
146 error_code,
147 });
148 }
149
150 if !stream.final_offset_unknown() {
152 let recv = entry.remove().expect("must have recv when stopping");
153 self.state.stream_recv_freed(self.id, recv);
154 }
155
156 if self.state.add_read_credits(read_credits).should_transmit() {
158 self.pending.max_data = true;
159 }
160
161 Ok(())
162 }
163
164 pub fn received_reset(&mut self) -> Result<Option<VarInt>, ClosedStream> {
168 if self.state.conn_closed() {
169 return Err(ClosedStream { _private: () });
170 }
171
172 let hash_map::Entry::Occupied(entry) = self.state.recv.entry(self.id) else {
173 return Err(ClosedStream { _private: () });
174 };
175
176 let Some(s) = entry.get().as_ref().and_then(|s| s.as_open_recv()) else {
177 return Ok(None);
178 };
179
180 if s.stopped {
181 return Err(ClosedStream { _private: () });
182 }
183
184 let Some(code) = s.reset_code() else {
185 return Ok(None);
186 };
187
188 let (_, recv) = entry.remove_entry();
190 self.state
191 .stream_recv_freed(self.id, recv.expect("must have recv on reset"));
192 self.state.queue_max_stream_id(self.pending);
193
194 Ok(Some(code))
195 }
196}
197
198pub struct SendStream<'a> {
200 pub(super) id: StreamId,
201 pub(super) state: &'a mut StreamsState,
202 pub(super) pending: &'a mut Retransmits,
203 pub(super) conn_state: &'a super::State,
204}
205
206#[allow(clippy::needless_lifetimes)] impl<'a> SendStream<'a> {
208 #[cfg(fuzzing)]
209 pub fn new(
210 id: StreamId,
211 state: &'a mut StreamsState,
212 pending: &'a mut Retransmits,
213 conn_state: &'a super::State,
214 ) -> Self {
215 Self {
216 id,
217 state,
218 pending,
219 conn_state,
220 }
221 }
222
223 pub fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
227 Ok(self.write_source(&mut ByteSlice::from_slice(data))?.bytes)
228 }
229
230 pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<Written, WriteError> {
237 self.write_source(&mut BytesArray::from_chunks(data))
238 }
239
240 fn write_source<B: BytesSource>(&mut self, source: &mut B) -> Result<Written, WriteError> {
241 if self.conn_state.is_closed() {
242 trace!(%self.id, "write blocked; connection draining");
243 return Err(WriteError::Blocked);
244 }
245
246 let limit = self.state.write_limit();
247
248 let max_send_data = self.state.max_send_data(self.id);
249
250 let stream = self
251 .state
252 .send
253 .get_mut(&self.id)
254 .map(get_or_insert_send(max_send_data))
255 .ok_or(WriteError::ClosedStream)?;
256
257 if limit == 0 {
258 trace!(
259 stream = %self.id, max_data = self.state.max_data, data_sent = self.state.data_sent,
260 "write blocked by connection-level flow control or send window"
261 );
262 if !stream.connection_blocked {
263 stream.connection_blocked = true;
264 self.state.connection_blocked.push(self.id);
265 }
266 return Err(WriteError::Blocked);
267 }
268
269 let was_pending = stream.is_pending();
270 let written = stream.write(source, limit)?;
271 self.state.data_sent += written.bytes as u64;
272 self.state.unacked_data += written.bytes as u64;
273 trace!(stream = %self.id, "wrote {} bytes", written.bytes);
274 if !was_pending {
275 self.state.pending.push_pending(self.id, stream.priority);
276 }
277 Ok(written)
278 }
279
280 pub fn stopped(&self) -> Result<Option<VarInt>, ClosedStream> {
282 match self.state.send.get(&self.id).as_ref() {
283 Some(Some(s)) => Ok(s.stop_reason),
284 Some(None) => Ok(None),
285 None => Err(ClosedStream { _private: () }),
286 }
287 }
288
289 pub fn finish(&mut self) -> Result<(), FinishError> {
295 let max_send_data = self.state.max_send_data(self.id);
296 let stream = self
297 .state
298 .send
299 .get_mut(&self.id)
300 .map(get_or_insert_send(max_send_data))
301 .ok_or(FinishError::ClosedStream)?;
302
303 let was_pending = stream.is_pending();
304 stream.finish()?;
305 if !was_pending {
306 self.state.pending.push_pending(self.id, stream.priority);
307 }
308
309 Ok(())
310 }
311
312 pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
317 let max_send_data = self.state.max_send_data(self.id);
318 let stream = self
319 .state
320 .send
321 .get_mut(&self.id)
322 .map(get_or_insert_send(max_send_data))
323 .ok_or(ClosedStream { _private: () })?;
324
325 if matches!(stream.state, SendState::ResetSent) {
326 return Err(ClosedStream { _private: () });
328 }
329
330 self.state.unacked_data -= stream.pending.unacked();
334 stream.reset();
335 self.pending.reset_stream.push((self.id, error_code));
336
337 Ok(())
339 }
340
341 pub fn set_priority(&mut self, priority: i32) -> Result<(), ClosedStream> {
346 let max_send_data = self.state.max_send_data(self.id);
347 let stream = self
348 .state
349 .send
350 .get_mut(&self.id)
351 .map(get_or_insert_send(max_send_data))
352 .ok_or(ClosedStream { _private: () })?;
353
354 stream.priority = priority;
355 Ok(())
356 }
357
358 pub fn priority(&self) -> Result<i32, ClosedStream> {
363 let stream = self
364 .state
365 .send
366 .get(&self.id)
367 .ok_or(ClosedStream { _private: () })?;
368
369 Ok(stream.as_ref().map(|s| s.priority).unwrap_or_default())
370 }
371}
372
373struct PendingStreamsQueue {
375 streams: BinaryHeap<PendingStream>,
376 next: Option<PendingStream>,
379 recency: u64,
381}
382
383impl PendingStreamsQueue {
384 fn new() -> Self {
385 Self {
386 streams: BinaryHeap::new(),
387 next: None,
388 recency: u64::MAX,
389 }
390 }
391
392 fn reinsert_pending(&mut self, id: StreamId, priority: i32) {
394 if self.next.is_some() {
395 warn!("Attempting to reinsert a pending stream when next is already set");
396 return;
397 }
398
399 self.next = Some(PendingStream {
400 priority,
401 recency: self.recency,
402 id,
403 });
404 }
405
406 fn push_pending(&mut self, id: StreamId, priority: i32) {
408 self.recency = self.recency.saturating_sub(1);
410 self.streams.push(PendingStream {
411 priority,
412 recency: self.recency,
413 id,
414 });
415 }
416
417 fn pop(&mut self) -> Option<PendingStream> {
419 self.next.take().or_else(|| self.streams.pop())
420 }
421
422 fn clear(&mut self) {
424 self.next = None;
425 self.streams.clear();
426 }
427
428 fn iter(&self) -> impl Iterator<Item = &PendingStream> {
430 self.next.iter().chain(self.streams.iter())
431 }
432
433 #[cfg(test)]
434 fn len(&self) -> usize {
435 self.streams.len() + self.next.is_some() as usize
436 }
437}
438
/// A single entry in the pending-streams queue.
///
/// The derived `Ord` compares fields in declaration order, so the order below is load-bearing.
/// In a max-heap the comparison runs as follows:
/// - higher `priority` wins;
/// - then higher `recency`, i.e. pushed earlier;
/// - then `id` breaks any remaining tie.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
struct PendingStream {
    // Scheduling priority; larger values compare greater and are popped first from a max-heap.
    priority: i32,
    // Snapshot of the queue's recency counter at push time (see `push_pending`).
    recency: u64,
    // The stream this entry schedules; final tie-breaker in the ordering.
    id: StreamId,
}
457
/// Stream-related events surfaced by this module.
///
/// NOTE(review): the sites that emit these variants are outside this view. The per-variant
/// notes below describe only what the fields carry; confirm the triggering conditions with
/// the emitters.
#[derive(Debug, PartialEq, Eq)]
pub enum StreamEvent {
    /// Presumably: peer-initiated stream(s) became available to accept.
    Opened {
        /// Direction of the stream(s) concerned.
        dir: Dir,
    },
    /// Presumably: the stream has data that can be read.
    Readable {
        /// The stream concerned.
        id: StreamId,
    },
    /// Presumably: the stream can accept more written data.
    Writable {
        /// The stream concerned.
        id: StreamId,
    },
    /// Presumably: the stream's sending side has completed.
    Finished {
        /// The stream concerned.
        id: StreamId,
    },
    /// Presumably: the peer asked us to stop sending on this stream.
    Stopped {
        /// The stream concerned.
        id: StreamId,
        /// Error code carried by the event.
        error_code: VarInt,
    },
    /// Presumably: capacity to open another stream in `dir` became available.
    Available {
        /// Direction of the stream capacity concerned.
        dir: Dir,
    },
}
496
/// Boolean signal meaning "something may now need to be sent to the peer".
///
/// Marked `#[must_use]` so the caller cannot silently discard it.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
#[must_use = "A frame might need to be enqueued"]
pub struct ShouldTransmit(bool);

impl ShouldTransmit {
    /// Returns `true` when the caller should enqueue a frame.
    pub fn should_transmit(self) -> bool {
        let Self(needed) = self;
        needed
    }
}
511
512#[derive(Debug, Default, Error, Clone, PartialEq, Eq)]
514#[error("closed stream")]
515pub struct ClosedStream {
516 _private: (),
517}
518
519impl From<ClosedStream> for io::Error {
520 fn from(x: ClosedStream) -> Self {
521 Self::new(io::ErrorKind::NotConnected, x)
522 }
523}
524
/// Selects one half of a stream: its sending side or its receiving side.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum StreamHalf {
    /// The sending half.
    Send,
    /// The receiving half.
    Recv,
}