h3_datagram/
datagram.rs

1use bytes::Buf;
2use h3::{
3    error::{internal_error::InternalConnectionError, Code},
4    proto::varint::VarInt,
5    quic::StreamId,
6};
7
8/// HTTP datagram frames
9/// See: <https://www.rfc-editor.org/rfc/rfc9297#section-2.1>
10#[derive(Debug, Clone)]
11pub struct Datagram<B> {
12    /// Stream id divided by 4
13    stream_id: StreamId,
14    /// The data contained in the datagram
15    payload: B,
16}
17
18impl<B> Datagram<B>
19where
20    B: Buf,
21{
22    /// Creates a new datagram frame
23    pub fn new(stream_id: StreamId, payload: B) -> Self {
24        assert!(
25            stream_id.into_inner() % 4 == 0,
26            "StreamId is not divisible by 4"
27        );
28        // StreamId will be divided by 4 when encoding the Datagram header
29        Self { stream_id, payload }
30    }
31
32    /// Decodes a datagram frame from the QUIC datagram
33    pub fn decode(mut buf: B) -> Result<Self, InternalConnectionError> {
34        let q_stream_id = VarInt::decode(&mut buf).map_err(|_| {
35            InternalConnectionError::new(Code::H3_DATAGRAM_ERROR, "invalid stream id".to_string())
36        })?;
37
38        //= https://www.rfc-editor.org/rfc/rfc9297#section-2.1
39        // Quarter Stream ID: A variable-length integer that contains the value of the client-initiated bidirectional
40        // stream that this datagram is associated with divided by four (the division by four stems
41        // from the fact that HTTP requests are sent on client-initiated bidirectional streams,
42        // which have stream IDs that are divisible by four). The largest legal QUIC stream ID
43        // value is 262-1, so the largest legal value of the Quarter Stream ID field is 260-1.
44        // Receipt of an HTTP/3 Datagram that includes a larger value MUST be treated as an HTTP/3
45        // connection error of type H3_DATAGRAM_ERROR (0x33).
46        let stream_id = StreamId::try_from(u64::from(q_stream_id) * 4).map_err(|_| {
47            InternalConnectionError::new(Code::H3_DATAGRAM_ERROR, "invalid stream id".to_string())
48        })?;
49
50        let payload = buf;
51
52        Ok(Self { stream_id, payload })
53    }
54
55    #[inline]
56    /// Returns the associated stream id of the datagram
57    pub fn stream_id(&self) -> StreamId {
58        self.stream_id
59    }
60
61    #[inline]
62    /// Returns the datagram payload
63    pub fn payload(&self) -> &B {
64        &self.payload
65    }
66
67    /// Encode the datagram to wire format
68    pub fn encode(self) -> EncodedDatagram<B> {
69        let mut buffer = [0; VarInt::MAX_SIZE];
70        let varint = VarInt::from(self.stream_id) / 4;
71        varint.encode(&mut buffer.as_mut_slice());
72        EncodedDatagram {
73            stream_id: [0; VarInt::MAX_SIZE],
74            len: varint.size(),
75            pos: 0,
76            payload: self.payload,
77        }
78    }
79
80    /// Returns the datagram payload
81    pub fn into_payload(self) -> B {
82        self.payload
83    }
84}
85
86#[derive(Debug)]
87pub struct EncodedDatagram<B: Buf> {
88    /// Encoded datagram stream ID as Varint
89    stream_id: [u8; VarInt::MAX_SIZE],
90    /// Length of the varint
91    len: usize,
92    /// Position of the stream_id buffer
93    pos: usize,
94    /// Datagram Payload
95    payload: B,
96}
97
98/// Implementation of [`Buf`] for [`Datagram`]
99impl<B> Buf for EncodedDatagram<B>
100where
101    B: Buf,
102{
103    fn remaining(&self) -> usize {
104        self.len - self.pos + self.payload.remaining()
105    }
106
107    fn chunk(&self) -> &[u8] {
108        if self.len - self.pos > 0 {
109            return &self.stream_id[self.pos..self.len];
110        } else {
111            self.payload.chunk()
112        }
113    }
114
115    fn advance(&mut self, mut cnt: usize) {
116        let remaining_header = self.len - self.pos;
117        if remaining_header > 0 {
118            let advanced = usize::min(cnt, remaining_header);
119            self.pos += advanced;
120            cnt -= advanced;
121        }
122        self.payload.advance(cnt);
123    }
124}