1use std::collections::{HashMap, HashSet};
11
12use bytes::{Bytes, BytesMut};
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
14
/// Stream 0: the connection itself; carries SETTINGS and connection-wide WINDOW_UPDATE.
pub const STREAM_INIT: u32 = 0;
/// Fixed stream ID 1, exposed through the `*_client_server` helpers of `H2Framer`.
pub const STREAM_CLIENT_SERVER: u32 = 1;
/// Fixed stream ID 3, exposed through the `*_server_client` helpers of `H2Framer`.
pub const STREAM_SERVER_CLIENT: u32 = 3;

// Frame type codes (RFC 9113 §6).
const FRAME_DATA: u8 = 0x00;
const FRAME_HEADERS: u8 = 0x01;
const FRAME_SETTINGS: u8 = 0x04;
const FRAME_WINDOW_UPDATE: u8 = 0x08;

/// HEADERS flag: this frame carries the complete header block.
const FLAG_END_HEADERS: u8 = 0x04;
/// SETTINGS flag: this frame acknowledges the peer's SETTINGS.
const FLAG_SETTINGS_ACK: u8 = 0x01;

// SETTINGS parameter identifiers (RFC 9113 §6.5.2).
const SETTINGS_MAX_CONCURRENT_STREAMS: u16 = 0x03;
const SETTINGS_INITIAL_WINDOW_SIZE: u16 = 0x04;

/// The 24-byte client connection preface (RFC 9113 §3.4), sent before any frame.
pub const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
43
44pub struct H2Framer<S> {
48 stream: S,
49 client_server_buf: BytesMut,
51 server_client_buf: BytesMut,
53 stream_bufs: HashMap<u32, BytesMut>,
55 locally_open_streams: HashSet<u32>,
57 client_server_open: bool,
59 server_client_open: bool,
60}
61
62#[derive(Debug, Clone)]
63pub struct DataFrame {
64 pub stream_id: u32,
65 pub flags: u8,
66 pub payload: Bytes,
67}
68
69impl<S: AsyncRead + AsyncWrite + Unpin> H2Framer<S> {
70 pub async fn connect(mut stream: S) -> Result<Self, H2Error> {
72 stream.write_all(H2_PREFACE).await?;
74
75 let settings = build_settings_frame(&[
78 (SETTINGS_MAX_CONCURRENT_STREAMS, 100),
79 (SETTINGS_INITIAL_WINDOW_SIZE, 1_048_576),
80 ]);
81 stream.write_all(&settings).await?;
82
83 let wupdate = build_window_update_frame(STREAM_INIT, 983_041);
87 stream.write_all(&wupdate).await?;
88 stream.flush().await?;
89
90 let mut framer = Self {
91 stream,
92 client_server_buf: BytesMut::new(),
93 server_client_buf: BytesMut::new(),
94 stream_bufs: HashMap::new(),
95 locally_open_streams: HashSet::new(),
96 client_server_open: false,
97 server_client_open: false,
98 };
99
100 framer.read_until_settings_ack_needed().await?;
102
103 Ok(framer)
104 }
105
106 async fn read_until_settings_ack_needed(&mut self) -> Result<(), H2Error> {
107 loop {
108 let frame = self.read_raw_frame().await?;
109 tracing::trace!(
110 "h2: handshake frame type={} flags=0x{:02x} stream={} len={}",
111 frame_type_name(frame.frame_type),
112 frame.flags,
113 frame.stream_id,
114 frame.payload.len()
115 );
116 match frame.frame_type {
117 FRAME_SETTINGS if frame.flags & FLAG_SETTINGS_ACK == 0 => {
118 let ack = build_settings_ack();
120 self.stream.write_all(&ack).await?;
121 self.stream.flush().await?;
122 return Ok(());
123 }
124 FRAME_SETTINGS => {
125 }
127 FRAME_DATA => {
128 match frame.stream_id {
130 STREAM_CLIENT_SERVER => {
131 self.client_server_buf.extend_from_slice(&frame.payload)
132 }
133 STREAM_SERVER_CLIENT => {
134 self.server_client_buf.extend_from_slice(&frame.payload)
135 }
136 _ => {}
137 }
138 }
139 _ => {}
140 }
141 }
142 }
143
144 async fn read_raw_frame(&mut self) -> Result<RawFrame, H2Error> {
146 let mut header = [0u8; 9];
147 self.stream.read_exact(&mut header).await?;
148
149 let length =
150 ((header[0] as usize) << 16) | ((header[1] as usize) << 8) | (header[2] as usize);
151 let frame_type = header[3];
152 let flags = header[4];
153 let stream_id = u32::from_be_bytes([header[5] & 0x7F, header[6], header[7], header[8]]);
154
155 let mut payload = vec![0u8; length];
156 if length > 0 {
157 self.stream.read_exact(&mut payload).await?;
158 }
159
160 Ok(RawFrame {
161 frame_type,
162 flags,
163 stream_id,
164 payload,
165 })
166 }
167
168 pub async fn read_server_client(&mut self, n: usize) -> Result<Bytes, H2Error> {
171 self.read_stream(STREAM_SERVER_CLIENT, n).await
172 }
173
174 pub async fn read_client_server(&mut self, n: usize) -> Result<Bytes, H2Error> {
176 self.read_stream(STREAM_CLIENT_SERVER, n).await
177 }
178
179 pub async fn read_stream(&mut self, stream_id: u32, n: usize) -> Result<Bytes, H2Error> {
181 while self.stream_buffer_len(stream_id) < n {
182 let frame = self.read_raw_frame().await?;
183 self.dispatch_frame(frame).await?;
184 }
185 self.take_stream_bytes(stream_id, n)
186 }
187
188 async fn dispatch_frame(&mut self, frame: RawFrame) -> Result<(), H2Error> {
189 tracing::trace!(
190 "h2: dispatch frame type={} flags=0x{:02x} stream={} len={}",
191 frame_type_name(frame.frame_type),
192 frame.flags,
193 frame.stream_id,
194 frame.payload.len()
195 );
196 match frame.frame_type {
197 FRAME_DATA => match frame.stream_id {
198 STREAM_CLIENT_SERVER => self.client_server_buf.extend_from_slice(&frame.payload),
199 STREAM_SERVER_CLIENT => self.server_client_buf.extend_from_slice(&frame.payload),
200 other => self
201 .stream_bufs
202 .entry(other)
203 .or_default()
204 .extend_from_slice(&frame.payload),
205 },
206 FRAME_SETTINGS if frame.flags & FLAG_SETTINGS_ACK == 0 => {
207 let ack = build_settings_ack();
208 self.stream.write_all(&ack).await?;
209 self.stream.flush().await?;
210 }
211 _ => {}
212 }
213 if frame.frame_type == FRAME_DATA && frame.stream_id % 2 == 0 && !frame.payload.is_empty() {
214 let conn_window = build_window_update_frame(STREAM_INIT, frame.payload.len() as u32);
215 let stream_window =
216 build_window_update_frame(frame.stream_id, frame.payload.len() as u32);
217 self.stream.write_all(&conn_window).await?;
218 self.stream.write_all(&stream_window).await?;
219 self.stream.flush().await?;
220 }
221 Ok(())
222 }
223
224 pub async fn read_next_data_frame(&mut self) -> Result<DataFrame, H2Error> {
226 loop {
227 let frame = self.read_raw_frame().await?;
228 tracing::trace!(
229 "h2: next data frame type={} flags=0x{:02x} stream={} len={}",
230 frame_type_name(frame.frame_type),
231 frame.flags,
232 frame.stream_id,
233 frame.payload.len()
234 );
235 match frame.frame_type {
236 FRAME_DATA => {
237 if frame.stream_id % 2 == 0 && !frame.payload.is_empty() {
238 let conn_window =
239 build_window_update_frame(STREAM_INIT, frame.payload.len() as u32);
240 let stream_window =
241 build_window_update_frame(frame.stream_id, frame.payload.len() as u32);
242 self.stream.write_all(&conn_window).await?;
243 self.stream.write_all(&stream_window).await?;
244 self.stream.flush().await?;
245 }
246 return Ok(DataFrame {
247 stream_id: frame.stream_id,
248 flags: frame.flags,
249 payload: Bytes::from(frame.payload),
250 });
251 }
252 FRAME_SETTINGS if frame.flags & FLAG_SETTINGS_ACK == 0 => {
253 let ack = build_settings_ack();
254 self.stream.write_all(&ack).await?;
255 self.stream.flush().await?;
256 }
257 _ => {}
258 }
259 }
260 }
261
262 pub async fn write_client_server(&mut self, data: &[u8]) -> Result<(), H2Error> {
264 self.write_stream(STREAM_CLIENT_SERVER, data).await
265 }
266
267 pub async fn write_server_client(&mut self, data: &[u8]) -> Result<(), H2Error> {
269 self.write_stream(STREAM_SERVER_CLIENT, data).await
270 }
271
272 pub async fn write_stream(&mut self, stream_id: u32, data: &[u8]) -> Result<(), H2Error> {
274 self.open_stream(stream_id).await?;
275 let data_frame = build_data_frame(stream_id, data);
276 self.stream.write_all(&data_frame).await?;
277 self.stream.flush().await?;
278 Ok(())
279 }
280
281 pub async fn open_client_server(&mut self) -> Result<(), H2Error> {
283 self.open_stream(STREAM_CLIENT_SERVER).await
284 }
285
286 pub async fn open_server_client(&mut self) -> Result<(), H2Error> {
288 self.open_stream(STREAM_SERVER_CLIENT).await
289 }
290
291 pub async fn open_stream(&mut self, stream_id: u32) -> Result<(), H2Error> {
293 let already_open = match stream_id {
294 STREAM_CLIENT_SERVER => self.client_server_open,
295 STREAM_SERVER_CLIENT => self.server_client_open,
296 _ => self.locally_open_streams.contains(&stream_id),
297 };
298 if !already_open {
299 let headers = build_headers_frame(stream_id);
300 self.stream.write_all(&headers).await?;
301 self.stream.flush().await?;
302 match stream_id {
303 STREAM_CLIENT_SERVER => self.client_server_open = true,
304 STREAM_SERVER_CLIENT => self.server_client_open = true,
305 _ => {
306 self.locally_open_streams.insert(stream_id);
307 self.stream_bufs.entry(stream_id).or_default();
308 }
309 }
310 }
311 Ok(())
312 }
313
314 pub async fn poll_frames(&mut self) -> Result<(), H2Error> {
316 Ok(())
320 }
321
322 fn stream_buffer_len(&self, stream_id: u32) -> usize {
323 match stream_id {
324 STREAM_CLIENT_SERVER => self.client_server_buf.len(),
325 STREAM_SERVER_CLIENT => self.server_client_buf.len(),
326 _ => self.stream_bufs.get(&stream_id).map_or(0, BytesMut::len),
327 }
328 }
329
330 fn take_stream_bytes(&mut self, stream_id: u32, n: usize) -> Result<Bytes, H2Error> {
331 match stream_id {
332 STREAM_CLIENT_SERVER => Ok(self.client_server_buf.split_to(n).freeze()),
333 STREAM_SERVER_CLIENT => Ok(self.server_client_buf.split_to(n).freeze()),
334 _ => self
335 .stream_bufs
336 .get_mut(&stream_id)
337 .map(|buf| buf.split_to(n).freeze())
338 .ok_or_else(|| H2Error::Protocol(format!("stream {stream_id} not open"))),
339 }
340 }
341}
342
/// Returns the RFC 9113 §6 name of an HTTP/2 frame type code, for trace output.
///
/// Covers all ten core frame types, so control frames such as GOAWAY, RST_STREAM
/// and PING are identifiable in logs. Extension or unknown codes map to `"OTHER"`.
fn frame_type_name(frame_type: u8) -> &'static str {
    match frame_type {
        0x00 => "DATA",
        0x01 => "HEADERS",
        0x02 => "PRIORITY",
        0x03 => "RST_STREAM",
        0x04 => "SETTINGS",
        0x05 => "PUSH_PROMISE",
        0x06 => "PING",
        0x07 => "GOAWAY",
        0x08 => "WINDOW_UPDATE",
        0x09 => "CONTINUATION",
        _ => "OTHER",
    }
}
352
/// One HTTP/2 frame as read off the wire: the decoded 9-byte header fields plus
/// the raw payload.
struct RawFrame {
    /// Frame type code (e.g. `FRAME_DATA`, `FRAME_SETTINGS`).
    frame_type: u8,
    /// Type-specific flag bits, uninterpreted.
    flags: u8,
    /// Stream identifier with the reserved high bit cleared.
    stream_id: u32,
    /// Payload bytes; as many as the header's length field announced.
    payload: Vec<u8>,
}
361
/// Serializes one HTTP/2 frame: a 9-byte header followed by `payload`.
///
/// Header layout (RFC 9113 §4.1): 24-bit big-endian payload length, type,
/// flags, then the 31-bit stream identifier with its reserved top bit cleared.
///
/// # Panics
///
/// In debug builds, panics if `payload` is longer than the 24-bit length field
/// can express. Previously such a length silently wrapped and corrupted the
/// framing; release builds keep that behavior, so callers must chunk.
fn build_frame(frame_type: u8, flags: u8, stream_id: u32, payload: &[u8]) -> Vec<u8> {
    /// Largest value representable in the 24-bit frame length field.
    const MAX_FRAME_LEN: usize = 0x00FF_FFFF;

    let len = payload.len();
    debug_assert!(
        len <= MAX_FRAME_LEN,
        "HTTP/2 frame payload of {len} bytes exceeds the 24-bit length field"
    );

    let mut out = Vec::with_capacity(9 + len);
    // Low three bytes of the big-endian u32 are the 24-bit length.
    out.extend_from_slice(&(len as u32).to_be_bytes()[1..]);
    out.push(frame_type);
    out.push(flags);
    out.extend_from_slice(&(stream_id & 0x7FFF_FFFF).to_be_bytes());
    out.extend_from_slice(payload);
    out
}
376
377fn build_settings_frame(settings: &[(u16, u32)]) -> Vec<u8> {
378 let mut payload = Vec::new();
379 for (id, val) in settings {
380 payload.extend_from_slice(&id.to_be_bytes());
381 payload.extend_from_slice(&val.to_be_bytes());
382 }
383 build_frame(FRAME_SETTINGS, 0, STREAM_INIT, &payload)
384}
385
386fn build_settings_ack() -> Vec<u8> {
387 build_frame(FRAME_SETTINGS, FLAG_SETTINGS_ACK, STREAM_INIT, &[])
388}
389
390fn build_window_update_frame(stream_id: u32, increment: u32) -> Vec<u8> {
391 build_frame(
392 FRAME_WINDOW_UPDATE,
393 0,
394 stream_id,
395 &(increment & 0x7FFFFFFF).to_be_bytes(),
396 )
397}
398
399fn build_headers_frame(stream_id: u32) -> Vec<u8> {
400 build_frame(FRAME_HEADERS, FLAG_END_HEADERS, stream_id, &[])
402}
403
404fn build_data_frame(stream_id: u32, data: &[u8]) -> Vec<u8> {
405 build_frame(FRAME_DATA, 0, stream_id, data)
406}
407
408#[derive(Debug, thiserror::Error)]
411pub enum H2Error {
412 #[error("IO error: {0}")]
413 Io(#[from] std::io::Error),
414 #[error("H2 protocol error: {0}")]
415 Protocol(String),
416 #[error("GOAWAY received")]
417 GoAway,
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `stream` in a framer with empty buffers and no open streams,
    /// skipping the `connect` handshake so tests can drive frames directly.
    fn bare_framer<T>(stream: T) -> H2Framer<T> {
        H2Framer {
            stream,
            client_server_buf: BytesMut::new(),
            server_client_buf: BytesMut::new(),
            stream_bufs: HashMap::new(),
            locally_open_streams: HashSet::new(),
            client_server_open: false,
            server_client_open: false,
        }
    }

    /// Decodes the 31-bit stream identifier from bytes 5..9 of a serialized frame.
    fn header_stream_id(frame: &[u8]) -> u32 {
        u32::from_be_bytes([frame[5] & 0x7F, frame[6], frame[7], frame[8]])
    }

    #[test]
    fn test_settings_frame_layout() {
        let frame = build_settings_frame(&[
            (SETTINGS_MAX_CONCURRENT_STREAMS, 100),
            (SETTINGS_INITIAL_WINDOW_SIZE, 1_048_576),
        ]);
        // 9-byte header plus two 6-byte settings entries.
        assert_eq!(frame.len(), 9 + 2 * 6);
        assert_eq!(frame[3], FRAME_SETTINGS);
        assert_eq!(frame[4], 0);
    }

    #[test]
    fn test_window_update_frame() {
        let frame = build_window_update_frame(0, 983_041);
        assert_eq!(frame.len(), 9 + 4);
        assert_eq!(frame[3], FRAME_WINDOW_UPDATE);
    }

    #[test]
    fn test_data_frame() {
        let data = b"hello XPC";
        let frame = build_data_frame(STREAM_CLIENT_SERVER, data);
        assert_eq!(frame.len(), 9 + data.len());
        assert_eq!(frame[3], FRAME_DATA);
        assert_eq!(&frame[9..], data);
        assert_eq!(header_stream_id(&frame), STREAM_CLIENT_SERVER);
    }

    #[tokio::test]
    async fn test_dispatch_frame_acknowledges_settings_immediately() {
        let (client, mut server) = tokio::io::duplex(1024);
        let mut framer = bare_framer(client);

        let peer_settings = RawFrame {
            frame_type: FRAME_SETTINGS,
            flags: 0,
            stream_id: STREAM_INIT,
            payload: Vec::new(),
        };
        framer.dispatch_frame(peer_settings).await.unwrap();

        let mut ack = [0u8; 9];
        server.read_exact(&mut ack).await.unwrap();
        assert_eq!(ack[3], FRAME_SETTINGS);
        assert_eq!(ack[4], FLAG_SETTINGS_ACK);
    }

    #[tokio::test]
    async fn test_open_stream_still_sends_headers_after_remote_data_buffered() {
        let (client, mut server) = tokio::io::duplex(1024);
        let mut framer = bare_framer(client);

        let remote_data = RawFrame {
            frame_type: FRAME_DATA,
            flags: 0,
            stream_id: 4,
            payload: vec![1, 2, 3],
        };
        framer.dispatch_frame(remote_data).await.unwrap();
        framer.open_stream(4).await.unwrap();

        // Consuming DATA on stream 4 returns credit: connection window first...
        let mut conn_window = [0u8; 13];
        server.read_exact(&mut conn_window).await.unwrap();
        assert_eq!(conn_window[3], FRAME_WINDOW_UPDATE);
        assert_eq!(header_stream_id(&conn_window), STREAM_INIT);

        // ...then the stream window.
        let mut stream_window = [0u8; 13];
        server.read_exact(&mut stream_window).await.unwrap();
        assert_eq!(stream_window[3], FRAME_WINDOW_UPDATE);
        assert_eq!(header_stream_id(&stream_window), 4);

        // Buffered remote data must not count as the stream being opened locally.
        let mut headers = [0u8; 9];
        server.read_exact(&mut headers).await.unwrap();
        assert_eq!(headers[3], FRAME_HEADERS);
        assert_eq!(headers[4], FLAG_END_HEADERS);
        assert_eq!(header_stream_id(&headers), 4);
    }
}