1use futures::future::Either;
12use std::fmt;
13
/// A frame header.
///
/// The type parameter `T` is a zero-sized marker carried only in
/// `PhantomData`: one of [`Data`], [`WindowUpdate`], [`Ping`], [`GoAway`],
/// an [`Either`] of markers, or `()` for a header whose kind has not been
/// fixed at the type level (as returned by `decode`). The marker controls
/// which constructors, accessors and flag setters are available.
///
/// Note: the derives require `T` to implement the derived traits as well; the
/// marker enums only derive `Clone` and `Debug`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Header<T> {
    // Protocol version byte; every constructor in this module sets 0 and
    // `decode` rejects anything else.
    version: Version,
    // Frame type.
    tag: Tag,
    // Bitset of `SYN`/`ACK`/`FIN`/`RST` flags.
    flags: Flags,
    // Stream the frame refers to; ping and go-away headers use id 0.
    stream_id: StreamId,
    // 32-bit field whose meaning depends on the frame type: the `len` given to
    // `Header::data`, the credit of a window update, the nonce of a ping, or
    // the code of a go-away.
    length: Len,
    // Type-level marker only; no runtime data.
    _marker: std::marker::PhantomData<T>,
}
24
25impl<T> fmt::Display for Header<T> {
26 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
27 write!(
28 f,
29 "(Header {:?} {} (len {}) (flags {:?}))",
30 self.tag,
31 self.stream_id,
32 self.length.val(),
33 self.flags.val()
34 )
35 }
36}
37
38impl<T> Header<T> {
39 pub fn tag(&self) -> Tag {
40 self.tag
41 }
42
43 pub fn flags(&self) -> Flags {
44 self.flags
45 }
46
47 pub fn stream_id(&self) -> StreamId {
48 self.stream_id
49 }
50
51 pub fn len(&self) -> Len {
52 self.length
53 }
54
55 #[cfg(test)]
56 pub fn set_len(&mut self, len: u32) {
57 self.length = Len(len)
58 }
59
60 fn cast<U>(self) -> Header<U> {
62 Header {
63 version: self.version,
64 tag: self.tag,
65 flags: self.flags,
66 stream_id: self.stream_id,
67 length: self.length,
68 _marker: std::marker::PhantomData,
69 }
70 }
71
72 pub(crate) fn right<U>(self) -> Header<Either<U, T>> {
74 self.cast()
75 }
76
77 pub(crate) fn left<U>(self) -> Header<Either<T, U>> {
79 self.cast()
80 }
81}
82
83impl<A: private::Sealed> From<Header<A>> for Header<()> {
84 fn from(h: Header<A>) -> Header<()> {
85 h.cast()
86 }
87}
88
89impl Header<()> {
90 pub(crate) fn into_data(self) -> Header<Data> {
91 debug_assert_eq!(self.tag, Tag::Data);
92 self.cast()
93 }
94
95 pub(crate) fn into_window_update(self) -> Header<WindowUpdate> {
96 debug_assert_eq!(self.tag, Tag::WindowUpdate);
97 self.cast()
98 }
99
100 pub(crate) fn into_ping(self) -> Header<Ping> {
101 debug_assert_eq!(self.tag, Tag::Ping);
102 self.cast()
103 }
104}
105
106impl<T: HasSyn> Header<T> {
107 pub fn syn(&mut self) {
109 self.flags.0 |= SYN.0
110 }
111}
112
113impl<T: HasAck> Header<T> {
114 pub fn ack(&mut self) {
116 self.flags.0 |= ACK.0
117 }
118}
119
120impl<T: HasFin> Header<T> {
121 pub fn fin(&mut self) {
123 self.flags.0 |= FIN.0
124 }
125}
126
127impl<T: HasRst> Header<T> {
128 pub fn rst(&mut self) {
130 self.flags.0 |= RST.0
131 }
132}
133
134impl Header<Data> {
135 pub fn data(id: StreamId, len: u32) -> Self {
137 Header {
138 version: Version(0),
139 tag: Tag::Data,
140 flags: Flags(0),
141 stream_id: id,
142 length: Len(len),
143 _marker: std::marker::PhantomData,
144 }
145 }
146}
147
148impl Header<WindowUpdate> {
149 pub fn window_update(id: StreamId, credit: u32) -> Self {
151 Header {
152 version: Version(0),
153 tag: Tag::WindowUpdate,
154 flags: Flags(0),
155 stream_id: id,
156 length: Len(credit),
157 _marker: std::marker::PhantomData,
158 }
159 }
160
161 pub fn credit(&self) -> u32 {
163 self.length.0
164 }
165}
166
167impl Header<Ping> {
168 pub fn ping(nonce: u32) -> Self {
170 Header {
171 version: Version(0),
172 tag: Tag::Ping,
173 flags: Flags(0),
174 stream_id: StreamId(0),
175 length: Len(nonce),
176 _marker: std::marker::PhantomData,
177 }
178 }
179
180 pub fn nonce(&self) -> u32 {
182 self.length.0
183 }
184}
185
186impl Header<GoAway> {
187 pub fn term() -> Self {
189 Self::go_away(0)
190 }
191
192 pub fn protocol_error() -> Self {
194 Self::go_away(1)
195 }
196
197 pub fn internal_error() -> Self {
199 Self::go_away(2)
200 }
201
202 fn go_away(code: u32) -> Self {
203 Header {
204 version: Version(0),
205 tag: Tag::GoAway,
206 flags: Flags(0),
207 stream_id: StreamId(0),
208 length: Len(code),
209 _marker: std::marker::PhantomData,
210 }
211 }
212}
213
/// Type-level marker for data-frame headers (`Header<Data>`).
///
/// The enum has no variants, so no value of it can exist; it is only used as
/// the type parameter of [`Header`].
#[derive(Clone, Debug)]
pub enum Data {}

/// Type-level marker for window-update headers (`Header<WindowUpdate>`).
#[derive(Clone, Debug)]
pub enum WindowUpdate {}

/// Type-level marker for ping headers (`Header<Ping>`).
#[derive(Clone, Debug)]
pub enum Ping {}

/// Type-level marker for go-away headers (`Header<GoAway>`).
#[derive(Clone, Debug)]
pub enum GoAway {}
229
/// Header kinds on which [`Header::syn`] is available.
///
/// Sealed through `private::Sealed`, so only the marker types of this module
/// (and `Either`s of them) can implement it.
pub trait HasSyn: private::Sealed {}
impl HasSyn for Data {}
impl HasSyn for WindowUpdate {}
impl HasSyn for Ping {}
// An `Either` of two SYN-capable kinds is itself SYN-capable.
impl<A: HasSyn, B: HasSyn> HasSyn for Either<A, B> {}

/// Header kinds on which [`Header::ack`] is available (sealed, see [`HasSyn`]).
pub trait HasAck: private::Sealed {}
impl HasAck for Data {}
impl HasAck for WindowUpdate {}
impl HasAck for Ping {}
// An `Either` of two ACK-capable kinds is itself ACK-capable.
impl<A: HasAck, B: HasAck> HasAck for Either<A, B> {}

/// Header kinds on which [`Header::fin`] is available (sealed).
///
/// Only data and window-update headers; note there is no `Either` impl here.
pub trait HasFin: private::Sealed {}
impl HasFin for Data {}
impl HasFin for WindowUpdate {}

/// Header kinds on which [`Header::rst`] is available (sealed).
///
/// Only data and window-update headers; note there is no `Either` impl here.
pub trait HasRst: private::Sealed {}
impl HasRst for Data {}
impl HasRst for WindowUpdate {}
253
/// Sealing trait for the header-kind markers.
///
/// The module is `pub(super)`, so `Sealed` cannot be named outside the parent
/// module. Therefore the `Has*` traits and the `From<Header<A>> for Header<()>`
/// impl can only be satisfied by the types listed here.
pub(super) mod private {
    pub trait Sealed {}

    impl Sealed for super::Data {}
    impl Sealed for super::WindowUpdate {}
    impl Sealed for super::Ping {}
    impl Sealed for super::GoAway {}
    // Lets headers produced by `Header::left` / `Header::right` stay sealed.
    impl<A: Sealed, B: Sealed> Sealed for super::Either<A, B> {}
}
263
/// Frame type, encoded as the second byte of a header.
///
/// The discriminants ARE the wire values. `encode` writes `tag as u8`, and
/// `decode` maps bytes 0..=3 back to these variants. They are spelled out
/// explicitly, with `#[repr(u8)]`, so reordering or inserting variants cannot
/// silently change the wire format.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum Tag {
    /// Wire value 0.
    Data = 0,
    /// Wire value 1.
    WindowUpdate = 1,
    /// Wire value 2.
    Ping = 2,
    /// Wire value 3.
    GoAway = 3,
}
272
/// Protocol version byte (first header byte); `decode` rejects anything but 0.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Version(u8);
276
/// The header's 32-bit length field; what it means depends on the frame type.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Len(u32);

impl Len {
    /// Raw field value.
    pub fn val(self) -> u32 {
        let Len(raw) = self;
        raw
    }
}
286
/// Stream id used by headers that address the session as a whole
/// (ping and go-away headers are built with id 0).
pub const CONNECTION_ID: StreamId = StreamId(0);

/// Numeric identifier of a stream.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct StreamId(u32);

impl StreamId {
    /// Wraps a raw id.
    pub(crate) fn new(val: u32) -> Self {
        Self(val)
    }

    /// `true` for even ids, which includes [`CONNECTION_ID`].
    // Clippy would suggest `u32::is_multiple_of`; the modulo form is kept,
    // presumably for compatibility with older toolchains (TODO: confirm MSRV).
    #[allow(clippy::manual_is_multiple_of)]
    pub fn is_server(self) -> bool {
        let Self(raw) = self;
        raw % 2 == 0
    }

    /// `true` for odd ids; the exact complement of [`StreamId::is_server`].
    pub fn is_client(self) -> bool {
        !self.is_server()
    }

    /// `true` only for [`CONNECTION_ID`].
    pub fn is_session(self) -> bool {
        self == CONNECTION_ID
    }

    /// Raw numeric value.
    pub fn val(self) -> u32 {
        let Self(raw) = self;
        raw
    }
}

impl fmt::Display for StreamId {
    /// Prints the bare number. Width/fill options of the outer formatter are
    /// not applied.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", self.val())
    }
}
324
// Marks `StreamId` as usable with the `nohash_hasher` crate's hasher
// (presumably so stream maps can be keyed by id without re-hashing; verify at
// the use sites).
// NOTE(review): that hasher expects a single integer write per key. The derived
// `Hash` on this single-`u32` newtype hashes just the inner `u32`. Re-check
// if `StreamId` ever gains fields.
impl nohash_hasher::IsEnabled for StreamId {}
326
/// 16-bit set of header flag bits.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Flags(u16);

impl Flags {
    /// `true` if every bit set in `other` is also set in `self`.
    ///
    /// An empty `other` is contained in any value.
    pub fn contains(self, other: Flags) -> bool {
        let Flags(mine) = self;
        let Flags(wanted) = other;
        // No bit that `other` requires may be missing from `self`.
        wanted & !mine == 0
    }

    /// Raw bit pattern.
    pub fn val(self) -> u16 {
        let Flags(bits) = self;
        bits
    }
}
340
341pub const SYN: Flags = Flags(1);
343
344pub const ACK: Flags = Flags(2);
346
347pub const FIN: Flags = Flags(4);
349
350pub const RST: Flags = Flags(8);
352
/// Encoded header size in bytes: version (1) + type (1) + flags (2) +
/// stream id (4) + length (4).
pub const HEADER_SIZE: usize = 1 + 1 + 2 + 4 + 4;
355
356pub fn encode<T>(hdr: &Header<T>) -> [u8; HEADER_SIZE] {
358 let mut buf = [0; HEADER_SIZE];
359 buf[0] = hdr.version.0;
360 buf[1] = hdr.tag as u8;
361 buf[2..4].copy_from_slice(&hdr.flags.0.to_be_bytes());
362 buf[4..8].copy_from_slice(&hdr.stream_id.0.to_be_bytes());
363 buf[8..HEADER_SIZE].copy_from_slice(&hdr.length.0.to_be_bytes());
364 buf
365}
366
367pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result<Header<()>, HeaderDecodeError> {
369 if buf[0] != 0 {
370 return Err(HeaderDecodeError::Version(buf[0]));
371 }
372
373 let hdr = Header {
374 version: Version(buf[0]),
375 tag: match buf[1] {
376 0 => Tag::Data,
377 1 => Tag::WindowUpdate,
378 2 => Tag::Ping,
379 3 => Tag::GoAway,
380 t => return Err(HeaderDecodeError::Type(t)),
381 },
382 flags: Flags(u16::from_be_bytes([buf[2], buf[3]])),
383 stream_id: StreamId(u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]])),
384 length: Len(u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]])),
385 _marker: std::marker::PhantomData,
386 };
387
388 Ok(hdr)
389}
390
/// Reasons `decode` can reject a header.
#[non_exhaustive]
#[derive(Debug)]
pub enum HeaderDecodeError {
    /// The version byte was not 0; carries the byte that was found.
    Version(u8),
    /// The frame-type byte is not a known type; carries the byte that was found.
    Type(u8),
}

impl fmt::Display for HeaderDecodeError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            Self::Version(v) => write!(f, "unknown version: {v}"),
            Self::Type(t) => write!(f, "unknown frame type: {t}"),
        }
    }
}

impl std::error::Error for HeaderDecodeError {}
411
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen, QuickCheck};

    // Generates only headers `decode` accepts: version 0 and a known tag;
    // every other field is arbitrary.
    impl Arbitrary for Header<()> {
        fn arbitrary(g: &mut Gen) -> Self {
            let tag = *g
                .choose(&[Tag::Data, Tag::WindowUpdate, Tag::Ping, Tag::GoAway])
                .expect("choice slice is non-empty");

            Header {
                version: Version(0),
                tag,
                flags: Flags(Arbitrary::arbitrary(g)),
                stream_id: StreamId(Arbitrary::arbitrary(g)),
                length: Len(Arbitrary::arbitrary(g)),
                _marker: std::marker::PhantomData,
            }
        }
    }

    #[test]
    fn encode_decode_identity() {
        fn property(hdr: Header<()>) -> bool {
            match decode(&encode(&hdr)) {
                Ok(x) => x == hdr,
                Err(e) => {
                    eprintln!("decode error: {e}");
                    false
                }
            }
        }
        QuickCheck::new()
            .tests(10_000)
            .quickcheck(property as fn(Header<()>) -> bool)
    }

    // Pins the exact byte layout, which a self-consistent round trip alone
    // cannot catch (e.g. swapped fields or little-endian encoding).
    #[test]
    fn encode_writes_big_endian_fields_in_order() {
        let mut hdr = Header::data(StreamId::new(0x0102_0304), 0x0A0B_0C0D);
        hdr.syn();
        assert_eq!(
            encode(&hdr),
            [0x00, 0x00, 0x00, 0x01, 0x01, 0x02, 0x03, 0x04, 0x0A, 0x0B, 0x0C, 0x0D]
        );
    }

    #[test]
    fn decode_rejects_unknown_version() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0] = 1;
        assert!(matches!(decode(&buf), Err(HeaderDecodeError::Version(1))));
    }

    #[test]
    fn decode_rejects_unknown_type() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[1] = 4;
        assert!(matches!(decode(&buf), Err(HeaderDecodeError::Type(4))));
    }

    #[test]
    fn decode_reports_version_before_type() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0] = 2;
        buf[1] = 0xFF;
        assert!(matches!(decode(&buf), Err(HeaderDecodeError::Version(2))));
    }

    #[test]
    fn typed_header_survives_erasure_and_round_trip() {
        let erased: Header<()> = Header::ping(0xDEAD_BEEF).into();
        let decoded = decode(&encode(&erased))
            .expect("encoded ping must decode")
            .into_ping();
        assert_eq!(decoded.nonce(), 0xDEAD_BEEF);
        assert!(decoded.stream_id().is_session());
    }
}
449}