yamux/frame/
header.rs

// Copyright (c) 2018-2019 Parity Technologies (UK) Ltd.
//
// Licensed under the Apache License, Version 2.0 or MIT license, at your option.
//
// A copy of the Apache License, Version 2.0 is included in the software as
// LICENSE-APACHE and a copy of the MIT license is included in the software
// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
// at https://opensource.org/licenses/MIT.
10
11use futures::future::Either;
12use std::fmt;
13
14/// The message frame header.
15#[derive(Clone, Debug, PartialEq, Eq)]
16pub struct Header<T> {
17    version: Version,
18    tag: Tag,
19    flags: Flags,
20    stream_id: StreamId,
21    length: Len,
22    _marker: std::marker::PhantomData<T>,
23}
24
25impl<T> fmt::Display for Header<T> {
26    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
27        write!(
28            f,
29            "(Header {:?} {} (len {}) (flags {:?}))",
30            self.tag,
31            self.stream_id,
32            self.length.val(),
33            self.flags.val()
34        )
35    }
36}
37
38impl<T> Header<T> {
39    pub fn tag(&self) -> Tag {
40        self.tag
41    }
42
43    pub fn flags(&self) -> Flags {
44        self.flags
45    }
46
47    pub fn stream_id(&self) -> StreamId {
48        self.stream_id
49    }
50
51    pub fn len(&self) -> Len {
52        self.length
53    }
54
55    #[cfg(test)]
56    pub fn set_len(&mut self, len: u32) {
57        self.length = Len(len)
58    }
59
60    /// Arbitrary type cast, use with caution.
61    fn cast<U>(self) -> Header<U> {
62        Header {
63            version: self.version,
64            tag: self.tag,
65            flags: self.flags,
66            stream_id: self.stream_id,
67            length: self.length,
68            _marker: std::marker::PhantomData,
69        }
70    }
71
72    /// Introduce this header to the right of a binary header type.
73    pub(crate) fn right<U>(self) -> Header<Either<U, T>> {
74        self.cast()
75    }
76
77    /// Introduce this header to the left of a binary header type.
78    pub(crate) fn left<U>(self) -> Header<Either<T, U>> {
79        self.cast()
80    }
81}
82
83impl<A: private::Sealed> From<Header<A>> for Header<()> {
84    fn from(h: Header<A>) -> Header<()> {
85        h.cast()
86    }
87}
88
89impl Header<()> {
90    pub(crate) fn into_data(self) -> Header<Data> {
91        debug_assert_eq!(self.tag, Tag::Data);
92        self.cast()
93    }
94
95    pub(crate) fn into_window_update(self) -> Header<WindowUpdate> {
96        debug_assert_eq!(self.tag, Tag::WindowUpdate);
97        self.cast()
98    }
99
100    pub(crate) fn into_ping(self) -> Header<Ping> {
101        debug_assert_eq!(self.tag, Tag::Ping);
102        self.cast()
103    }
104}
105
106impl<T: HasSyn> Header<T> {
107    /// Set the [`SYN`] flag.
108    pub fn syn(&mut self) {
109        self.flags.0 |= SYN.0
110    }
111}
112
113impl<T: HasAck> Header<T> {
114    /// Set the [`ACK`] flag.
115    pub fn ack(&mut self) {
116        self.flags.0 |= ACK.0
117    }
118}
119
120impl<T: HasFin> Header<T> {
121    /// Set the [`FIN`] flag.
122    pub fn fin(&mut self) {
123        self.flags.0 |= FIN.0
124    }
125}
126
127impl<T: HasRst> Header<T> {
128    /// Set the [`RST`] flag.
129    pub fn rst(&mut self) {
130        self.flags.0 |= RST.0
131    }
132}
133
134impl Header<Data> {
135    /// Create a new data frame header.
136    pub fn data(id: StreamId, len: u32) -> Self {
137        Header {
138            version: Version(0),
139            tag: Tag::Data,
140            flags: Flags(0),
141            stream_id: id,
142            length: Len(len),
143            _marker: std::marker::PhantomData,
144        }
145    }
146}
147
148impl Header<WindowUpdate> {
149    /// Create a new window update frame header.
150    pub fn window_update(id: StreamId, credit: u32) -> Self {
151        Header {
152            version: Version(0),
153            tag: Tag::WindowUpdate,
154            flags: Flags(0),
155            stream_id: id,
156            length: Len(credit),
157            _marker: std::marker::PhantomData,
158        }
159    }
160
161    /// The credit this window update grants to the remote.
162    pub fn credit(&self) -> u32 {
163        self.length.0
164    }
165}
166
167impl Header<Ping> {
168    /// Create a new ping frame header.
169    pub fn ping(nonce: u32) -> Self {
170        Header {
171            version: Version(0),
172            tag: Tag::Ping,
173            flags: Flags(0),
174            stream_id: StreamId(0),
175            length: Len(nonce),
176            _marker: std::marker::PhantomData,
177        }
178    }
179
180    /// The nonce of this ping.
181    pub fn nonce(&self) -> u32 {
182        self.length.0
183    }
184}
185
186impl Header<GoAway> {
187    /// Terminate the session without indicating an error to the remote.
188    pub fn term() -> Self {
189        Self::go_away(0)
190    }
191
192    /// Terminate the session indicating a protocol error to the remote.
193    pub fn protocol_error() -> Self {
194        Self::go_away(1)
195    }
196
197    /// Terminate the session indicating an internal error to the remote.
198    pub fn internal_error() -> Self {
199        Self::go_away(2)
200    }
201
202    fn go_away(code: u32) -> Self {
203        Header {
204            version: Version(0),
205            tag: Tag::GoAway,
206            flags: Flags(0),
207            stream_id: StreamId(0),
208            length: Len(code),
209            _marker: std::marker::PhantomData,
210        }
211    }
212}
213
// The message types below are uninhabited markers: they have no values and are
// only used as the `T` in `Header<T>`. `#[derive]` on `Header<T>` implements a
// trait only when `T` implements it as well, so deriving `PartialEq`/`Eq` here
// (free for empty enums) makes typed headers such as `Header<Data>` comparable,
// not just the type-erased `Header<()>`.

/// Data message type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Data {}

/// Window update message type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WindowUpdate {}

/// Ping message type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Ping {}

/// Go Away message type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum GoAway {}
229
230/// Types which have a `syn` method.
231pub trait HasSyn: private::Sealed {}
232impl HasSyn for Data {}
233impl HasSyn for WindowUpdate {}
234impl HasSyn for Ping {}
235impl<A: HasSyn, B: HasSyn> HasSyn for Either<A, B> {}
236
237/// Types which have an `ack` method.
238pub trait HasAck: private::Sealed {}
239impl HasAck for Data {}
240impl HasAck for WindowUpdate {}
241impl HasAck for Ping {}
242impl<A: HasAck, B: HasAck> HasAck for Either<A, B> {}
243
244/// Types which have a `fin` method.
245pub trait HasFin: private::Sealed {}
246impl HasFin for Data {}
247impl HasFin for WindowUpdate {}
248
249/// Types which have a `rst` method.
250pub trait HasRst: private::Sealed {}
251impl HasRst for Data {}
252impl HasRst for WindowUpdate {}
253
254pub(super) mod private {
255    pub trait Sealed {}
256
257    impl Sealed for super::Data {}
258    impl Sealed for super::WindowUpdate {}
259    impl Sealed for super::Ping {}
260    impl Sealed for super::GoAway {}
261    impl<A: Sealed, B: Sealed> Sealed for super::Either<A, B> {}
262}
263
/// A tag is the runtime representation of a message type.
///
/// The discriminants are the on-the-wire type codes: `encode` serialises a tag
/// with `tag as u8` and `decode` maps the byte back. They are spelled out
/// explicitly so that reordering or inserting variants cannot silently change
/// the wire format.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Tag {
    /// Data frame, type code `0`.
    Data = 0,
    /// Window update frame, type code `1`.
    WindowUpdate = 1,
    /// Ping frame, type code `2`.
    Ping = 2,
    /// Go-away frame, type code `3`.
    GoAway = 3,
}
272
/// The protocol version a message corresponds to.
///
/// Only version `0` is produced and accepted by this module.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Version(u8);

/// The message length.
///
/// Depending on the frame type, this field may instead carry a window credit,
/// a ping nonce or a go-away code.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Len(u32);

impl Len {
    /// The raw 32-bit value of this field.
    pub fn val(self) -> u32 {
        let Len(raw) = self;
        raw
    }
}
286
/// The stream ID that addresses the session as a whole rather than a stream.
pub const CONNECTION_ID: StreamId = StreamId(0);

/// The ID of a stream.
///
/// The value 0 denotes no particular stream but the whole session.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct StreamId(u32);

impl StreamId {
    /// Wrap a raw stream ID.
    pub(crate) fn new(val: u32) -> Self {
        Self(val)
    }

    /// Whether the ID is even. This includes the session ID `0`.
    ///
    /// Parity is read from the low bit, which avoids clippy's
    /// `manual_is_multiple_of` lint without needing `u32::is_multiple_of`.
    /// TODO: switch to `is_multiple_of(2)` on the next minor release.
    pub fn is_server(self) -> bool {
        (self.0 & 1) == 0
    }

    /// Whether the ID is odd.
    pub fn is_client(self) -> bool {
        !self.is_server()
    }

    /// Whether this is [`CONNECTION_ID`], the session-wide ID.
    pub fn is_session(self) -> bool {
        self.0 == CONNECTION_ID.0
    }

    /// The raw numeric ID.
    pub fn val(self) -> u32 {
        let StreamId(raw) = self;
        raw
    }
}

impl fmt::Display for StreamId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.val())
    }
}
324
325impl nohash_hasher::IsEnabled for StreamId {}
326
/// Possible flags set on a message.
///
/// A 16-bit set whose individual bits are [`SYN`], [`ACK`], [`FIN`] and [`RST`].
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Flags(u16);

impl Flags {
    /// Whether every bit set in `other` is also set in `self`.
    ///
    /// An empty `other` is therefore contained in every flag set.
    pub fn contains(self, other: Flags) -> bool {
        (other.0 & !self.0) == 0
    }

    /// The raw bit pattern.
    pub fn val(self) -> u16 {
        let Flags(bits) = self;
        bits
    }
}

/// Indicates the start of a new stream.
pub const SYN: Flags = Flags(0b0001);

/// Acknowledges the start of a new stream.
pub const ACK: Flags = Flags(0b0010);

/// Indicates the half-closing of a stream.
pub const FIN: Flags = Flags(0b0100);

/// Indicates an immediate stream reset.
pub const RST: Flags = Flags(0b1000);
352
/// The serialised header size in bytes.
///
/// Field widths: version (1), type (1), flags (2), stream ID (4), length (4).
pub const HEADER_SIZE: usize = 1 + 1 + 2 + 4 + 4;
355
356/// Encode a [`Header`] value.
357pub fn encode<T>(hdr: &Header<T>) -> [u8; HEADER_SIZE] {
358    let mut buf = [0; HEADER_SIZE];
359    buf[0] = hdr.version.0;
360    buf[1] = hdr.tag as u8;
361    buf[2..4].copy_from_slice(&hdr.flags.0.to_be_bytes());
362    buf[4..8].copy_from_slice(&hdr.stream_id.0.to_be_bytes());
363    buf[8..HEADER_SIZE].copy_from_slice(&hdr.length.0.to_be_bytes());
364    buf
365}
366
367/// Decode a [`Header`] value.
368pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result<Header<()>, HeaderDecodeError> {
369    if buf[0] != 0 {
370        return Err(HeaderDecodeError::Version(buf[0]));
371    }
372
373    let hdr = Header {
374        version: Version(buf[0]),
375        tag: match buf[1] {
376            0 => Tag::Data,
377            1 => Tag::WindowUpdate,
378            2 => Tag::Ping,
379            3 => Tag::GoAway,
380            t => return Err(HeaderDecodeError::Type(t)),
381        },
382        flags: Flags(u16::from_be_bytes([buf[2], buf[3]])),
383        stream_id: StreamId(u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]])),
384        length: Len(u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]])),
385        _marker: std::marker::PhantomData,
386    };
387
388    Ok(hdr)
389}
390
/// Possible errors while decoding a message frame header.
#[non_exhaustive]
#[derive(Debug)]
pub enum HeaderDecodeError {
    /// Unknown version.
    Version(u8),
    /// An unknown frame type.
    Type(u8),
}

impl fmt::Display for HeaderDecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Version(v) => write!(f, "unknown version: {v}"),
            Self::Type(t) => write!(f, "unknown frame type: {t}"),
        }
    }
}

// No underlying cause is recorded, so `source()` keeps its default of `None`.
impl std::error::Error for HeaderDecodeError {}
411
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen, QuickCheck};

    impl Arbitrary for Header<()> {
        fn arbitrary(g: &mut Gen) -> Self {
            let tag = *g
                .choose(&[Tag::Data, Tag::WindowUpdate, Tag::Ping, Tag::GoAway])
                .unwrap();

            Header {
                version: Version(0),
                tag,
                flags: Flags(Arbitrary::arbitrary(g)),
                stream_id: StreamId(Arbitrary::arbitrary(g)),
                length: Len(Arbitrary::arbitrary(g)),
                _marker: std::marker::PhantomData,
            }
        }
    }

    /// Encoding then decoding any well-formed header yields the same header.
    #[test]
    fn encode_decode_identity() {
        fn property(hdr: Header<()>) -> bool {
            match decode(&encode(&hdr)) {
                Ok(x) => x == hdr,
                Err(e) => {
                    eprintln!("decode error: {e}");
                    false
                }
            }
        }
        QuickCheck::new()
            .tests(10_000)
            .quickcheck(property as fn(Header<()>) -> bool)
    }

    /// A non-zero version byte is rejected.
    #[test]
    fn decode_rejects_unknown_version() {
        let mut buf = encode(&Header::ping(7));
        buf[0] = 1;
        assert!(matches!(decode(&buf), Err(HeaderDecodeError::Version(1))));
    }

    /// A type byte outside the four known frame types is rejected.
    #[test]
    fn decode_rejects_unknown_frame_type() {
        let mut buf = encode(&Header::ping(7));
        buf[1] = 4;
        assert!(matches!(decode(&buf), Err(HeaderDecodeError::Type(4))));
    }

    /// When both the version and the type byte are invalid, the version error wins.
    #[test]
    fn decode_checks_version_before_type() {
        let mut buf = encode(&Header::ping(7));
        buf[0] = 0xff;
        buf[1] = 0xff;
        assert!(matches!(decode(&buf), Err(HeaderDecodeError::Version(0xff))));
    }

    /// Flags set through the typed setters survive an encode/decode round trip.
    #[test]
    fn flag_setters_round_trip() {
        let mut hdr = Header::data(StreamId::new(1), 0);
        hdr.syn();
        hdr.fin();
        let decoded = decode(&encode(&hdr)).expect("valid header");
        assert!(decoded.flags().contains(SYN));
        assert!(decoded.flags().contains(FIN));
        assert!(!decoded.flags().contains(ACK));
        assert!(!decoded.flags().contains(RST));
    }

    /// Typed constructors and accessors agree after a round trip.
    #[test]
    fn typed_headers_round_trip() {
        let hdr = Header::window_update(StreamId::new(3), 1024);
        let decoded = decode(&encode(&hdr)).expect("valid header");
        assert_eq!(decoded.tag(), Tag::WindowUpdate);
        assert_eq!(decoded.stream_id(), StreamId::new(3));
        assert_eq!(decoded.into_window_update().credit(), 1024);

        let ping = decode(&encode(&Header::ping(42))).expect("valid header");
        assert!(ping.stream_id().is_session());
        assert_eq!(ping.into_ping().nonce(), 42);
    }
}