ant_quic/connection/streams/
mod.rs1use std::{
2 collections::{BinaryHeap, hash_map},
3 io,
4};
5
6use bytes::Bytes;
7use thiserror::Error;
8use tracing::{trace, warn};
9
10use super::spaces::{Retransmits, ThinRetransmits};
11use crate::{
12 Dir, StreamId, VarInt,
13 connection::streams::state::{get_or_insert_recv, get_or_insert_send},
14 frame,
15};
16
17mod recv;
18use recv::Recv;
19pub use recv::{Chunks, ReadError, ReadableError};
20
21mod send;
22pub(crate) use send::{ByteSlice, BytesArray};
23use send::{BytesSource, Send, SendState};
24pub use send::{FinishError, WriteError, Written};
25
26mod state;
27pub use state::StreamsState;
28
/// Handle for opening and accepting streams on a connection.
///
/// Borrows the connection's stream bookkeeping mutably and the connection state read-only
/// (the latter is consulted so `open` can refuse to open streams on a closed connection).
pub struct Streams<'a> {
    /// Per-connection stream bookkeeping (counters, limits, per-stream state).
    pub(super) state: &'a mut StreamsState,
    /// Connection state; only read here (e.g. `is_closed()`).
    pub(super) conn_state: &'a super::State,
}
34
35impl<'a> Streams<'a> {
36 #[cfg(fuzzing)]
37 pub fn new(state: &'a mut StreamsState, conn_state: &'a super::State) -> Self {
38 Self { state, conn_state }
39 }
40
41 pub fn open(&mut self, dir: Dir) -> Option<StreamId> {
45 if self.conn_state.is_closed() {
46 return None;
47 }
48
49 if self.state.next[dir as usize] >= self.state.max[dir as usize] {
50 return None;
51 }
52
53 self.state.next[dir as usize] += 1;
54 let id = StreamId::new(self.state.side, dir, self.state.next[dir as usize] - 1);
55 self.state.insert(false, id);
56 self.state.send_streams += 1;
57 Some(id)
58 }
59
60 pub fn accept(&mut self, dir: Dir) -> Option<StreamId> {
64 if self.state.next_remote[dir as usize] == self.state.next_reported_remote[dir as usize] {
65 return None;
66 }
67
68 let x = self.state.next_reported_remote[dir as usize];
69 self.state.next_reported_remote[dir as usize] = x + 1;
70 if dir == Dir::Bi {
71 self.state.send_streams += 1;
72 }
73
74 Some(StreamId::new(!self.state.side, dir, x))
75 }
76
77 #[cfg(fuzzing)]
78 pub fn state(&mut self) -> &mut StreamsState {
79 self.state
80 }
81
82 pub fn send_streams(&self) -> usize {
84 self.state.send_streams
85 }
86
87 pub fn remote_open_streams(&self, dir: Dir) -> u64 {
89 self.state.next_remote[dir as usize]
90 - (self.state.max_remote[dir as usize]
91 - self.state.allocated_remote_count[dir as usize])
92 }
93}
94
/// Handle for receive-side operations (read, stop, reset inspection) on a single stream.
pub struct RecvStream<'a> {
    /// Stream this handle operates on.
    pub(super) id: StreamId,
    /// Per-connection stream bookkeeping.
    pub(super) state: &'a mut StreamsState,
    /// Outgoing control frames queued by this handle's operations (e.g. STOP_SENDING,
    /// the MAX_DATA flag).
    pub(super) pending: &'a mut Retransmits,
}
101
impl RecvStream<'_> {
    /// Starts reading from the stream, returning a [`Chunks`] handle.
    ///
    /// `ordered` is forwarded to [`Chunks::new`]; presumably it selects in-order versus
    /// unordered delivery — confirm against `Chunks`.
    ///
    /// # Errors
    ///
    /// Returns [`ReadableError::ConnectionClosed`] if the connection is closed; any other
    /// error comes from [`Chunks::new`].
    pub fn read(&mut self, ordered: bool) -> Result<Chunks<'_>, ReadableError> {
        if self.state.conn_closed() {
            return Err(ReadableError::ConnectionClosed);
        }

        Chunks::new(self.id, ordered, self.state, self.pending)
    }

    /// Stops receiving on this stream, using `error_code` in the STOP_SENDING frame.
    ///
    /// The method does four things:
    /// - queues a STOP_SENDING frame if the receive state says one should be sent;
    /// - frees the receive state immediately if the stream's final offset is already known;
    /// - returns the read credits reported by the stream's `stop()` to connection-level flow
    ///   control;
    /// - flags a MAX_DATA update when needed.
    ///
    /// # Errors
    ///
    /// Returns [`ClosedStream`] if the connection is closed or this stream has no receive
    /// entry. Errors from the stream's own `stop()` are also propagated as `ClosedStream`.
    pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
        if self.state.conn_closed() {
            return Err(ClosedStream { _private: () });
        }

        let mut entry = match self.state.recv.entry(self.id) {
            hash_map::Entry::Occupied(s) => s,
            hash_map::Entry::Vacant(_) => return Err(ClosedStream { _private: () }),
        };
        // Materialise the receive state if the slot exists but is still empty.
        let stream = get_or_insert_recv(self.state.stream_receive_window)(entry.get_mut());

        let (read_credits, stop_sending) = stream.stop()?;
        if stop_sending.should_transmit() {
            self.pending.stop_sending.push(frame::StopSending {
                id: self.id,
                error_code,
            });
        }

        // Free the receive state only once the final offset is known.
        // NOTE(review): presumably a stopped stream with an unknown final offset is kept so that
        // later data or a reset can still be accounted for in connection-level flow control —
        // confirm.
        if !stream.final_offset_unknown() {
            let recv = entry.remove().expect("must have recv when stopping");
            self.state.stream_recv_freed(self.id, recv);
        }

        if self.state.add_read_credits(read_credits).should_transmit() {
            self.pending.max_data = true;
        }

        Ok(())
    }

    /// Checks whether the peer has reset this stream.
    ///
    /// Returns `Ok(Some(code))` with the reset error code when a reset has been received. In
    /// that case the receive state is freed and new stream-ID credit may be queued, so later
    /// calls find no entry and return `Err(ClosedStream)`. Returns `Ok(None)` when the stream's
    /// state is not an open receive state or no reset code is present.
    ///
    /// # Errors
    ///
    /// Returns [`ClosedStream`] if the connection is closed, this stream has no receive entry,
    /// or the stream was stopped locally.
    pub fn received_reset(&mut self) -> Result<Option<VarInt>, ClosedStream> {
        if self.state.conn_closed() {
            return Err(ClosedStream { _private: () });
        }

        let hash_map::Entry::Occupied(entry) = self.state.recv.entry(self.id) else {
            return Err(ClosedStream { _private: () });
        };

        let Some(s) = entry.get().as_ref().and_then(|s| s.as_open_recv()) else {
            return Ok(None);
        };

        if s.stopped {
            return Err(ClosedStream { _private: () });
        }

        let Some(code) = s.reset_code() else {
            return Ok(None);
        };

        // The stream's receive state is released here, so the connection may be able to grant
        // the peer additional stream IDs.
        let (_, recv) = entry.remove_entry();
        self.state
            .stream_recv_freed(self.id, recv.expect("must have recv on reset"));
        self.state.queue_max_stream_id(self.pending);

        Ok(Some(code))
    }
}
190
/// Handle for send-side operations (write, finish, reset, priority) on a single stream.
pub struct SendStream<'a> {
    /// Stream this handle operates on.
    pub(super) id: StreamId,
    /// Per-connection stream bookkeeping, including flow-control counters.
    pub(super) state: &'a mut StreamsState,
    /// Outgoing control frames queued by this handle's operations (e.g. RESET_STREAM).
    pub(super) pending: &'a mut Retransmits,
    /// Connection state; read so writes can be refused once the connection is closed.
    pub(super) conn_state: &'a super::State,
}
198
199#[allow(clippy::needless_lifetimes)] impl<'a> SendStream<'a> {
201 #[cfg(fuzzing)]
202 pub fn new(
203 id: StreamId,
204 state: &'a mut StreamsState,
205 pending: &'a mut Retransmits,
206 conn_state: &'a super::State,
207 ) -> Self {
208 Self {
209 id,
210 state,
211 pending,
212 conn_state,
213 }
214 }
215
216 pub fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
220 Ok(self.write_source(&mut ByteSlice::from_slice(data))?.bytes)
221 }
222
223 pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<Written, WriteError> {
230 self.write_source(&mut BytesArray::from_chunks(data))
231 }
232
233 fn write_source<B: BytesSource>(&mut self, source: &mut B) -> Result<Written, WriteError> {
234 if self.conn_state.is_closed() {
235 trace!(%self.id, "write blocked; connection draining");
236 return Err(WriteError::Blocked);
237 }
238
239 let limit = self.state.write_limit();
240
241 let max_send_data = self.state.max_send_data(self.id);
242
243 let stream = self
244 .state
245 .send
246 .get_mut(&self.id)
247 .map(get_or_insert_send(max_send_data))
248 .ok_or(WriteError::ClosedStream)?;
249
250 if limit == 0 {
251 trace!(
252 stream = %self.id, max_data = self.state.max_data, data_sent = self.state.data_sent,
253 "write blocked by connection-level flow control or send window"
254 );
255 if !stream.connection_blocked {
256 stream.connection_blocked = true;
257 self.state.connection_blocked.push(self.id);
258 }
259 return Err(WriteError::Blocked);
260 }
261
262 let was_pending = stream.is_pending();
263 let written = stream.write(source, limit)?;
264 self.state.data_sent += written.bytes as u64;
265 self.state.unacked_data += written.bytes as u64;
266 trace!(stream = %self.id, "wrote {} bytes", written.bytes);
267 if !was_pending {
268 self.state.pending.push_pending(self.id, stream.priority);
269 }
270 Ok(written)
271 }
272
273 pub fn stopped(&self) -> Result<Option<VarInt>, ClosedStream> {
275 match self.state.send.get(&self.id).as_ref() {
276 Some(Some(s)) => Ok(s.stop_reason),
277 Some(None) => Ok(None),
278 None => Err(ClosedStream { _private: () }),
279 }
280 }
281
282 pub fn finish(&mut self) -> Result<(), FinishError> {
288 let max_send_data = self.state.max_send_data(self.id);
289 let stream = self
290 .state
291 .send
292 .get_mut(&self.id)
293 .map(get_or_insert_send(max_send_data))
294 .ok_or(FinishError::ClosedStream)?;
295
296 let was_pending = stream.is_pending();
297 stream.finish()?;
298 if !was_pending {
299 self.state.pending.push_pending(self.id, stream.priority);
300 }
301
302 Ok(())
303 }
304
305 pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
310 let max_send_data = self.state.max_send_data(self.id);
311 let stream = self
312 .state
313 .send
314 .get_mut(&self.id)
315 .map(get_or_insert_send(max_send_data))
316 .ok_or(ClosedStream { _private: () })?;
317
318 if matches!(stream.state, SendState::ResetSent) {
319 return Err(ClosedStream { _private: () });
321 }
322
323 self.state.unacked_data -= stream.pending.unacked();
327 stream.reset();
328 self.pending.reset_stream.push((self.id, error_code));
329
330 Ok(())
332 }
333
334 pub fn set_priority(&mut self, priority: i32) -> Result<(), ClosedStream> {
339 let max_send_data = self.state.max_send_data(self.id);
340 let stream = self
341 .state
342 .send
343 .get_mut(&self.id)
344 .map(get_or_insert_send(max_send_data))
345 .ok_or(ClosedStream { _private: () })?;
346
347 stream.priority = priority;
348 Ok(())
349 }
350
351 pub fn priority(&self) -> Result<i32, ClosedStream> {
356 let stream = self
357 .state
358 .send
359 .get(&self.id)
360 .ok_or(ClosedStream { _private: () })?;
361
362 Ok(stream.as_ref().map(|s| s.priority).unwrap_or_default())
363 }
364}
365
/// Queue of streams with data waiting to be transmitted, ordered by priority.
struct PendingStreamsQueue {
    /// Pending streams; `BinaryHeap` pops the greatest `PendingStream` first.
    streams: BinaryHeap<PendingStream>,
    /// A stream that `pop` returns ahead of anything in `streams`; set by `reinsert_pending`.
    next: Option<PendingStream>,
    /// Counter stamped onto entries by `push_pending`. It starts at `u64::MAX` (see `new`)
    /// and is decremented (saturating) on every push, so earlier pushes carry larger values.
    recency: u64,
}
375
376impl PendingStreamsQueue {
377 fn new() -> Self {
378 Self {
379 streams: BinaryHeap::new(),
380 next: None,
381 recency: u64::MAX,
382 }
383 }
384
385 fn reinsert_pending(&mut self, id: StreamId, priority: i32) {
387 if self.next.is_some() {
388 warn!("Attempting to reinsert a pending stream when next is already set");
389 return;
390 }
391
392 self.next = Some(PendingStream {
393 priority,
394 recency: self.recency,
395 id,
396 });
397 }
398
399 fn push_pending(&mut self, id: StreamId, priority: i32) {
401 self.recency = self.recency.saturating_sub(1);
403 self.streams.push(PendingStream {
404 priority,
405 recency: self.recency,
406 id,
407 });
408 }
409
410 fn pop(&mut self) -> Option<PendingStream> {
412 self.next.take().or_else(|| self.streams.pop())
413 }
414
415 fn clear(&mut self) {
417 self.next = None;
418 self.streams.clear();
419 }
420
421 fn iter(&self) -> impl Iterator<Item = &PendingStream> {
423 self.next.iter().chain(self.streams.iter())
424 }
425
426 #[cfg(test)]
427 fn len(&self) -> usize {
428 self.streams.len() + self.next.is_some() as usize
429 }
430}
431
/// Entry in the pending-streams heap.
///
/// The derived `Ord` compares fields in declaration order (priority, then recency, then id).
/// Do not reorder the fields: that order is the scheduling policy.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
struct PendingStream {
    /// Scheduling priority; larger values are popped first from the max-heap.
    priority: i32,
    /// Tie-breaker among equal priorities. Larger values are popped first, and earlier pushes
    /// get larger values because the queue's counter counts down.
    recency: u64,
    /// Stream identifier; final tie-breaker and the payload of the entry.
    id: StreamId,
}
450
/// Stream-related events surfaced by the connection.
///
/// NOTE(review): the variant descriptions below are inferred from names and field types;
/// confirm against the code that emits these events.
#[derive(Debug, PartialEq, Eq)]
pub enum StreamEvent {
    /// New peer-initiated stream(s) in direction `dir`; presumably retrievable via
    /// `Streams::accept`.
    Opened {
        /// Direction of the newly opened stream(s).
        dir: Dir,
    },
    /// Stream `id` presumably has something to read.
    Readable {
        /// Stream concerned.
        id: StreamId,
    },
    /// Stream `id` presumably can accept more written data.
    Writable {
        /// Stream concerned.
        id: StreamId,
    },
    /// Stream `id` presumably finished sending.
    Finished {
        /// Stream concerned.
        id: StreamId,
    },
    /// Stream `id` was stopped with `error_code` (presumably by a peer STOP_SENDING).
    Stopped {
        /// Stream concerned.
        id: StreamId,
        /// Application error code carried with the stop.
        error_code: VarInt,
    },
    /// Capacity to open more streams in direction `dir` presumably became available.
    Available {
        /// Direction in which streams may now be opened.
        dir: Dir,
    },
}
489
/// Boolean signal from stream bookkeeping operations (e.g. `add_read_credits`) telling the
/// caller whether a corresponding frame should be queued.
///
/// Marked `#[must_use]` so the signal cannot be silently ignored.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
#[must_use = "A frame might need to be enqueued"]
pub struct ShouldTransmit(bool);

impl ShouldTransmit {
    /// Returns `true` when the caller should enqueue the associated frame.
    pub fn should_transmit(self) -> bool {
        let Self(transmit) = self;
        transmit
    }
}
504
/// Error returned when an operation targets a stream that has no state here (closed or
/// unknown), was already reset or stopped, or whose connection is closed.
///
/// The private field prevents struct-literal construction outside this module. Note that
/// `Default` is derived, so `ClosedStream::default()` remains available to other code.
#[derive(Debug, Default, Error, Clone, PartialEq, Eq)]
#[error("closed stream")]
pub struct ClosedStream {
    // Private marker field; see the note above about `Default`.
    _private: (),
}
511
512impl From<ClosedStream> for io::Error {
513 fn from(x: ClosedStream) -> Self {
514 Self::new(io::ErrorKind::NotConnected, x)
515 }
516}
517
/// Identifies one half of a stream: its sending side or its receiving side.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum StreamHalf {
    /// The sending half.
    Send,
    /// The receiving half.
    Recv,
}