rust_http2/
frame.rs

1use bitflags::bitflags;
2use hpack::Encoder;
3
/// HTTP/2 frame types.
///
/// Each discriminant is the one-byte type code written into (and read from)
/// the frame header, so `frame_type as u8` yields the wire value.
///
/// The enum is a plain tag without data, so it derives the usual value-type
/// traits: it can be copied freely, compared with `==`, used as a hash key
/// and printed with `{:?}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameType {
    /// `DATA` frame, type code `0x00`.
    DATA = 0x00,
    /// `HEADERS` frame, type code `0x01`.
    HEADERS = 0x01,
    /// `PRIORITY` frame, type code `0x02`.
    PRIORITY = 0x02,
    /// `RST_STREAM` frame, type code `0x03`.
    RSTSTREAM = 0x03,
    /// `SETTINGS` frame, type code `0x04`.
    SETTINGS = 0x04,
    /// `PUSH_PROMISE` frame, type code `0x05`.
    PUSHPROMISE = 0x05,
    /// `PING` frame, type code `0x06`.
    PING = 0x06,
    /// `GOAWAY` frame, type code `0x07`.
    GOAWAY = 0x07,
    /// `WINDOW_UPDATE` frame, type code `0x08`.
    WINDOWUPDATE = 0x08,
    /// `CONTINUATION` frame, type code `0x09`.
    CONTINUATION = 0x09,
}
16
17impl TryFrom<u8> for FrameType {
18    type Error = &'static str;
19
20    fn try_from(value: u8) -> Result<Self, Self::Error> {
21        match value {
22            0x00 => Ok(FrameType::DATA),
23            0x01 => Ok(FrameType::HEADERS),
24            0x02 => Ok(FrameType::PRIORITY),
25            0x03 => Ok(FrameType::RSTSTREAM),
26            0x04 => Ok(FrameType::SETTINGS),
27            0x05 => Ok(FrameType::PUSHPROMISE),
28            0x06 => Ok(FrameType::PING),
29            0x07 => Ok(FrameType::GOAWAY),
30            0x08 => Ok(FrameType::WINDOWUPDATE),
31            0x09 => Ok(FrameType::CONTINUATION),
32            _ => Err("Invalid frame type"),
33        }
34    }
35}
36
37impl TryInto<u8> for FrameType {
38    type Error = &'static str;
39
40    fn try_into(self) -> Result<u8, Self::Error> {
41        match self {
42            FrameType::DATA => Ok(0x00),
43            FrameType::HEADERS => Ok(0x01),
44            FrameType::PRIORITY => Ok(0x02),
45            FrameType::RSTSTREAM => Ok(0x03),
46            FrameType::SETTINGS => Ok(0x04),
47            FrameType::PUSHPROMISE => Ok(0x05),
48            FrameType::PING => Ok(0x06),
49            FrameType::GOAWAY => Ok(0x07),
50            FrameType::WINDOWUPDATE => Ok(0x08),
51            FrameType::CONTINUATION => Ok(0x09),
52        }
53    }
54}
55
56bitflags! {
57    #[derive(PartialEq)]
58    pub struct FrameFlags: u8 {
59        const ENDSTREAM = 0x01;
60        const ENDHEADERS = 0x04;
61        const PADDED = 0x08;
62        const PRIORITY = 0x20;
63        const NONE = 0x00;
64    }
65}
66
67impl std::fmt::Display for FrameType {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        match self {
70            FrameType::DATA => write!(f, "DATA"),
71            FrameType::HEADERS => write!(f, "HEADERS"),
72            FrameType::PRIORITY => write!(f, "PRIORITY"),
73            FrameType::RSTSTREAM => write!(f, "RSTSTREAM"),
74            FrameType::SETTINGS => write!(f, "SETTINGS"),
75            FrameType::PUSHPROMISE => write!(f, "PUSHPROMISE"),
76            FrameType::PING => write!(f, "PING"),
77            FrameType::GOAWAY => write!(f, "GOAWAY"),
78            FrameType::WINDOWUPDATE => write!(f, "WINDOWUPDATE"),
79            FrameType::CONTINUATION => write!(f, "CONTINUATION"),
80        }
81    }
82}
83
84// impl TryFrom<u8> for FrameFlags {
85//     type Error = &'static str;
86
87//     fn try_from(value: u8) -> Result<Self, Self::Error> {
88//         match value {
89//             0x01 => Ok(FrameFlags::ENDSTREAM),
90//             0x04 => Ok(FrameFlags::ENDHEADERS),
91//             0x08 => Ok(FrameFlags::PADDED),
92//             0x20 => Ok(FrameFlags::PRIORITY),
93//             _ => Ok(FrameFlags::NONE),
94//         }
95//     }
96// }
97
98// impl TryInto<u8> for FrameFlags {
99//     type Error = &'static str;
100
101//     fn try_into(self) -> Result<u8, Self::Error> {
102//         match self {
103//             FrameFlags::ENDSTREAM => Ok(0x01),
104//             FrameFlags::ENDHEADERS => Ok(0x04),
105//             FrameFlags::PADDED => Ok(0x08),
106//             FrameFlags::PRIORITY => Ok(0x20),
107//             FrameFlags::NONE => Ok(0x00),
108//         }
109//     }
110// }
111
112pub struct FrameWriter {
113    pub frame_type: FrameType,
114    pub flags: FrameFlags,
115    pub stream_id: u32,
116    pub payload_len: u32,
117    pub payload: Vec<u8>,
118}
119
120impl FrameWriter {
121    pub fn new(frame_type: FrameType, flags: FrameFlags, stream_id: u32, payload: Vec<u8>) -> Self {
122        FrameWriter {
123            frame_type,
124            flags,
125            stream_id,
126            payload_len: payload.len() as u32,
127            payload,
128        }
129    }
130
131    pub fn new_frame_writer(
132        frame_type: FrameType,
133        flags: FrameFlags,
134        stream_id: u32
135    ) -> Result<Self, Box<dyn std::error::Error>> {
136        let mut headers = FrameHeaders::new();
137        headers.add_header(":method".to_string(), "GET".to_string());
138        headers.add_header(":path".to_string(), "/".to_string());
139        headers.add_header(":scheme".to_string(), "https".to_string());
140        headers.add_header(":authority".to_string(), "example.com".to_string());
141
142        let payload = headers.serialize();
143
144        // println!("[*] Serialized headers: {:?}", payload);
145
146        Ok(FrameWriter {
147            frame_type,
148            flags,
149            stream_id,
150            payload_len: payload.len() as u32,
151            payload: payload,
152        })
153    }
154
155    // pub fn payload_from_frame_type(
156    //     frame_type: FrameType
157    // ) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
158    //     let payload = match frame_type {
159    //         FrameType::HEADERS => {
160    //             let mut headers = FrameHeaders::new();
161    //             headers.add_header(":method".to_string(), "GET".to_string());
162    //             headers.add_header(":path".to_string(), "/".to_string());
163    //             headers.add_header(":scheme".to_string(), "https".to_string());
164    //             headers.add_header(":authority".to_string(), "example.com".to_string());
165
166    //             headers.serialize()
167    //         }
168
169    //         _ => Vec::new(),
170    //     };
171
172    //     Ok(payload)
173    // }
174
175    pub fn serialize(self) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
176        let mut serialized = Vec::with_capacity(9 + (self.payload_len as usize));
177
178        serialized.extend_from_slice(&self.payload_len.to_be_bytes()[1..]);
179        serialized.push(FrameType::try_into(self.frame_type)?);
180        serialized.push(self.flags.bits());
181        serialized.extend_from_slice(
182            &[
183                (self.stream_id >> 24) as u8,
184                (self.stream_id >> 16) as u8,
185                (self.stream_id >> 8) as u8,
186                self.stream_id as u8,
187            ]
188        );
189        // serialized.extend_from_slice(&self.stream_id.to_be_bytes()[1..]);
190        serialized.extend_from_slice(&self.payload);
191
192        Ok(serialized)
193    }
194}
195
196pub struct FrameReader {
197    pub frame_type: FrameType,
198    pub flags: FrameFlags,
199    pub stream_id: u32,
200    pub payload_len: u32,
201    pub payload: Vec<u8>,
202}
203
204impl FrameReader {
205    // pub async fn parse_frame(buf: &[u8]) -> Result<FrameReader, Box<dyn std::error::Error>> {
206    //     FrameReader::deserialize(buf).await
207    // }
208
209    pub async fn read_frame(buf: &[u8]) -> Result<Self, Box<dyn std::error::Error>> {
210        if buf.len() < 9 {
211            return Err("Buffer too small".into());
212        }
213
214        let payload_len = ((buf[0] as u32) << 16) | ((buf[1] as u32) << 8) | (buf[2] as u32);
215
216        let frame_type = FrameType::try_from(buf[3])?;
217        let flags = FrameFlags::from_bits_truncate(buf[4]);
218        let stream_id = u32::from_be_bytes([buf[5], buf[6], buf[7], buf[8]]) & 0x7fff_ffff;
219        let payload = buf[9..9 + (payload_len as usize)].to_vec();
220
221        Ok(FrameReader {
222            frame_type,
223            flags,
224            stream_id,
225            payload_len,
226            payload,
227        })
228    }
229
230    pub async fn parse_settings_payload(&self) -> Vec<Setting> {
231        let payload = &self.payload;
232
233        if self.payload_len == 0 {
234            return Vec::new();
235        }
236
237        let mut i = 0;
238        let mut settings = Vec::new();
239        while i < payload.len() {
240            if i + 6 > payload.len() {
241                break;
242            }
243
244            let id = u16::from_be_bytes([payload[i], payload[i + 1]]);
245            let value = u32::from_be_bytes([
246                payload[i + 2],
247                payload[i + 3],
248                payload[i + 4],
249                payload[i + 5],
250            ]);
251
252            settings.push(Setting::new(id, value));
253
254            i += 6;
255        }
256
257        settings
258    }
259}
260
261pub struct HeadersBuilder {
262    headers: Vec<(String, String)>,
263}
264
265impl HeadersBuilder {
266    pub fn new() -> Self {
267        HeadersBuilder { headers: Vec::new() }
268    }
269
270    pub fn add_header(mut self, name: String, value: String) -> Self {
271        self.headers.push((name, value));
272        self
273    }
274
275    pub fn build(
276        self,
277        stream_id: u32
278        // flags: FrameFlags
279    ) -> Result<FrameWriter, Box<dyn std::error::Error>> {
280        let headers = FrameHeaders::from_pairs(self.headers);
281        let payload = headers.serialize();
282
283        Ok(
284            FrameWriter::new(
285                FrameType::HEADERS,
286                FrameFlags::ENDHEADERS | FrameFlags::ENDSTREAM,
287                stream_id,
288                payload
289            )
290        )
291    }
292}
293
294pub struct FrameHeaders {
295    headers: Vec<(String, String)>,
296}
297
298impl FrameHeaders {
299    pub fn new() -> Self {
300        FrameHeaders { headers: Vec::new() }
301    }
302
303    pub fn from_pairs(headers: Vec<(String, String)>) -> Self {
304        FrameHeaders { headers }
305    }
306
307    pub fn add_header(&mut self, name: String, value: String) {
308        self.headers.push((name, value));
309    }
310
311    pub fn serialize(&self) -> Vec<u8> {
312        let mut encoder = Encoder::new();
313        let headers: Vec<(&[u8], &[u8])> = self.headers
314            .iter()
315            .map(|(name, value)| (name.as_bytes(), value.as_bytes()))
316            .collect();
317        encoder.encode(headers)
318    }
319}
320
/// One settings entry: a 16-bit identifier paired with a 32-bit value.
///
/// It is a small plain-data pair, so it derives the value-type traits:
/// entries can be copied, compared with `==` and printed with `{:?}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Setting {
    /// Setting identifier.
    pub id: u16,
    /// Value assigned to the identifier.
    pub value: u32,
}

impl Setting {
    /// Creates an entry from its identifier and value.
    pub fn new(id: u16, value: u32) -> Self {
        Self { id, value }
    }
}
331
332pub struct Settings {
333    pub settings: Vec<Setting>,
334}
335
336impl Default for Settings {
337    fn default() -> Self {
338        Settings {
339            settings: vec![
340                Setting::new(0x03, 100), // Max Concurrent Streams
341                Setting::new(0x04, 65535), // Initial Window Size
342                Setting::new(0x05, 16384), // Max Frame Size
343                Setting::new(0x06, 0) // Max Header List Size
344            ],
345        }
346    }
347}
348
349impl Settings {
350    pub async fn new() -> Self {
351        Settings {
352            settings: Vec::new(),
353        }
354    }
355
356    pub async fn write_settings_ack() -> Vec<u8> {
357        vec![0, 0, 0, 4, 1, 0, 0, 0, 0]
358    }
359
360    pub async fn from_pairs(pairs: Vec<(u16, u32)>) -> Self {
361        let mut settings = Settings::new().await;
362
363        for (id, value) in pairs {
364            settings.settings.push(Setting::new(id, value));
365        }
366
367        settings
368    }
369
370    pub async fn serialize(self) -> Vec<u8> {
371        let mut payload = Vec::new();
372
373        for setting in self.settings {
374            payload.extend_from_slice(&setting.id.to_be_bytes());
375            payload.extend_from_slice(&setting.value.to_be_bytes());
376        }
377
378        payload
379    }
380}
381
382pub struct SettingsBuilder {
383    settings: Vec<(u16, u32)>,
384}
385impl SettingsBuilder {
386    pub fn new() -> Self {
387        SettingsBuilder { settings: Vec::new() }
388    }
389
390    pub fn add_setting(mut self, id: u16, value: u32) -> Self {
391        self.settings.push((id, value));
392        self
393    }
394
395    pub async fn build(self) -> FrameWriter {
396        let settings = Settings::from_pairs(self.settings).await;
397        let payload = settings.serialize().await;
398
399        FrameWriter::new(FrameType::SETTINGS, FrameFlags::NONE, 0, payload)
400    }
401}