// h3/proto/stream.rs

1use bytes::{Buf, BufMut};
2use std::{
3    convert::TryFrom,
4    fmt::{self, Display},
5    ops::Add,
6};
7
8use crate::webtransport::SessionId;
9
10use super::{
11    coding::{BufExt, BufMutExt, Decode, Encode, UnexpectedEnd},
12    varint::VarInt,
13};
14
/// Stream type identifier, encoded on the wire as a variable-length integer.
///
/// Known values are exposed as associated constants (e.g. `StreamType::CONTROL`);
/// any other `u64`, such as a randomly generated grease value, is representable too.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StreamType(u64);
17
// Declares named `StreamType` constants: every `NAME = value,` entry expands to
// `pub const NAME: StreamType = StreamType(value);` inside one inherent impl.
macro_rules! stream_types {
    ( $( $ident:ident = $code:expr, )* ) => {
        impl StreamType {
            $(
                pub const $ident: StreamType = StreamType($code);
            )*
        }
    };
}
25
26stream_types! {
27    CONTROL = 0x00,
28    PUSH = 0x01,
29    ENCODER = 0x02,
30    DECODER = 0x03,
31    WEBTRANSPORT_BIDI = 0x41,
32    WEBTRANSPORT_UNI = 0x54,
33}
34
35impl StreamType {
36    pub const MAX_ENCODED_SIZE: usize = VarInt::MAX_SIZE;
37
38    pub fn value(&self) -> u64 {
39        self.0
40    }
41    /// returns a StreamType type with random number of the 0x1f * N + 0x21
42    /// format within the range of the Varint implementation
43    pub fn grease() -> Self {
44        StreamType(fastrand::u64(0..0x210842108421083) * 0x1f + 0x21)
45    }
46
47    pub fn from_value(value: u64) -> Self {
48        StreamType(value)
49    }
50}
51
52impl Decode for StreamType {
53    fn decode<B: Buf>(buf: &mut B) -> Result<Self, UnexpectedEnd> {
54        Ok(StreamType(buf.get_var()?))
55    }
56}
57
58impl Encode for StreamType {
59    fn encode<W: BufMut>(&self, buf: &mut W) {
60        buf.write_var(self.0);
61    }
62}
63
64impl fmt::Display for StreamType {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        match self {
67            &StreamType::CONTROL => write!(f, "Control"),
68            &StreamType::ENCODER => write!(f, "Encoder"),
69            &StreamType::DECODER => write!(f, "Decoder"),
70            &StreamType::WEBTRANSPORT_UNI => write!(f, "WebTransportUni"),
71            x => write!(f, "StreamType({})", x.0),
72        }
73    }
74}
75
/// Identifier for a stream.
///
/// The raw value stores the initiator in bit 0, the directionality in bit 1 and
/// the per-kind index in the remaining upper bits.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct StreamId(
    // Module-private normally; test builds widen it to `pub(crate)`.
    #[cfg(not(test))] u64,
    #[cfg(test)] pub(crate) u64,
);
79
80impl fmt::Display for StreamId {
81    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82        let initiator = match self.initiator() {
83            Side::Client => "client",
84            Side::Server => "server",
85        };
86        let dir = match self.dir() {
87            Dir::Uni => "uni",
88            Dir::Bi => "bi",
89        };
90        write!(
91            f,
92            "{} {}directional stream {}",
93            initiator,
94            dir,
95            self.index()
96        )
97    }
98}
99
100impl StreamId {
101    pub(crate) const FIRST_REQUEST: Self = Self::new(0, Dir::Bi, Side::Client);
102
103    /// Is this a client-initiated request?
104    pub fn is_request(&self) -> bool {
105        self.dir() == Dir::Bi && self.initiator() == Side::Client
106    }
107
108    /// Is this a server push?
109    pub fn is_push(&self) -> bool {
110        self.dir() == Dir::Uni && self.initiator() == Side::Server
111    }
112
113    /// Which side of a connection initiated the stream
114    pub(crate) fn initiator(self) -> Side {
115        if self.0 & 0x1 == 0 {
116            Side::Client
117        } else {
118            Side::Server
119        }
120    }
121
122    /// Create a new StreamId
123    const fn new(index: u64, dir: Dir, initiator: Side) -> Self {
124        StreamId((index) << 2 | (dir as u64) << 1 | initiator as u64)
125    }
126
127    /// Distinguishes streams of the same initiator and directionality
128    pub fn index(self) -> u64 {
129        self.0 >> 2
130    }
131
132    /// Which directions data flows in
133    fn dir(self) -> Dir {
134        if self.0 & 0x2 == 0 {
135            Dir::Bi
136        } else {
137            Dir::Uni
138        }
139    }
140
141    #[allow(missing_docs)]
142    pub fn into_inner(self) -> u64 {
143        self.0
144    }
145}
146
147impl TryFrom<u64> for StreamId {
148    type Error = InvalidStreamId;
149    fn try_from(v: u64) -> Result<Self, Self::Error> {
150        if v > VarInt::MAX.0 {
151            return Err(InvalidStreamId(v));
152        }
153        Ok(Self(v))
154    }
155}
156
157impl From<VarInt> for StreamId {
158    fn from(v: VarInt) -> Self {
159        Self(v.0)
160    }
161}
162
163impl From<StreamId> for VarInt {
164    fn from(v: StreamId) -> Self {
165        Self(v.0)
166    }
167}
168
/// Invalid StreamId, for example because it's too large to fit in a varint.
#[derive(Debug, PartialEq)]
pub struct InvalidStreamId(pub(crate) u64);

impl Display for InvalidStreamId {
    /// Formats the rejected value in hexadecimal with a `0x` prefix.
    ///
    /// The prefix prevents the number from being read as decimal; without it,
    /// `16` was printed as the ambiguous `10`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "invalid stream id: {:#x}", self.0)
    }
}
178
179impl Encode for StreamId {
180    fn encode<B: bytes::BufMut>(&self, buf: &mut B) {
181        VarInt::from_u64(self.0).unwrap().encode(buf);
182    }
183}
184
185impl Add<usize> for StreamId {
186    type Output = StreamId;
187
188    #[allow(clippy::suspicious_arithmetic_impl)]
189    fn add(self, rhs: usize) -> Self::Output {
190        let index = u64::min(
191            u64::saturating_add(self.index(), rhs as u64),
192            VarInt::MAX.0 >> 2,
193        );
194        Self::new(index, self.dir(), self.initiator())
195    }
196}
197
198impl From<SessionId> for StreamId {
199    fn from(value: SessionId) -> Self {
200        Self(value.into_inner())
201    }
202}
203
/// Which endpoint of a connection a stream originates from.
///
/// The discriminant equals bit 0 of a raw stream id.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Side {
    /// The initiator of a connection
    Client = 0,
    /// The acceptor of a connection
    Server = 1,
}
211
/// Whether a stream carries data both ways or only from its initiator.
///
/// The discriminant equals bit 1 of a raw stream id.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Dir {
    /// Data flows in both directions
    Bi = 0,
    /// Data flows only from the stream's initiator
    Uni = 1,
}