1use bytes::{Buf, Bytes, BytesMut};
2use http::{HeaderName, HeaderValue, Method, Request, Response, StatusCode};
3use std::collections::HashSet;
4
5use crate::{
6 HeaderField, HttpVersion, PackedRequest, PackedResponse, MAX_HEADERS, DecodeError, EncodeError,
7};
8
/// Magic bytes that open every encoded frame.
const STREAM_MAGIC: [u8; 4] = *b"HPKS";
/// Frame-format version written by `encode_frame`; the decoder rejects any other value.
const STREAM_VERSION: u8 = 1;
/// Frame type byte: a request or response head.
const FRAME_HEADERS: u8 = 1;
/// Frame type byte: a chunk of body data.
const FRAME_BODY: u8 = 2;
/// Frame type byte: end of a stream.
const FRAME_END: u8 = 3;
/// The only flags value defined for end frames; the decoder rejects anything else.
const END_FLAGS_NONE: u8 = 0;
15
/// Kind of message head carried by a headers frame.
///
/// On the wire, `Request` is encoded as `1` and `Response` as `2`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamKind {
    /// The headers frame carries a request head.
    Request,
    /// The headers frame carries a response head.
    Response,
}
21
/// Request head transported by a headers frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamRequestHeaders {
    /// Identifier shared by every frame of this stream.
    pub stream_id: u64,
    /// HTTP version of the original message.
    pub version: HttpVersion,
    /// Raw method bytes; the decoder rejects an empty method.
    pub method: Vec<u8>,
    /// URI scheme, if any; encoded as an empty string when `None`.
    pub scheme: Option<Vec<u8>>,
    /// URI authority, if any; encoded as an empty string when `None`.
    pub authority: Option<Vec<u8>>,
    /// Path and query; the decoder substitutes `/` for an empty path.
    pub path: Vec<u8>,
    /// Header fields as raw name/value byte pairs.
    pub headers: Vec<HeaderField>,
}
32
/// Response head transported by a headers frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamResponseHeaders {
    /// Identifier shared by every frame of this stream.
    pub stream_id: u64,
    /// HTTP version of the original message.
    pub version: HttpVersion,
    /// Numeric status code; the decoder rejects values `http::StatusCode` does not accept.
    pub status: u16,
    /// Header fields as raw name/value byte pairs.
    pub headers: Vec<HeaderField>,
}
40
/// Payload of a headers frame: either a request head or a response head.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamHeaders {
    /// A request head.
    Request(StreamRequestHeaders),
    /// A response head.
    Response(StreamResponseHeaders),
}
46
/// A chunk of message body belonging to one stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamBody {
    /// Stream the data belongs to.
    pub stream_id: u64,
    /// Body bytes, encoded on the wire with a varint length prefix.
    pub data: Bytes,
}
52
/// Marks the end of a stream; no further frames are expected for `stream_id`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamEnd {
    /// Stream being terminated.
    pub stream_id: u64,
}
57
/// One unit of the stream protocol, as produced by `decode_frame` / consumed by `encode_frame`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamFrame {
    /// Message head (request or response).
    Headers(StreamHeaders),
    /// A piece of message body.
    Body(StreamBody),
    /// End of the stream.
    End(StreamEnd),
}
64
65impl StreamFrame {
66 pub fn stream_id(&self) -> u64 {
67 match self {
68 StreamFrame::Headers(headers) => match headers {
69 StreamHeaders::Request(req) => req.stream_id,
70 StreamHeaders::Response(resp) => resp.stream_id,
71 },
72 StreamFrame::Body(body) => body.stream_id,
73 StreamFrame::End(end) => end.stream_id,
74 }
75 }
76}
77
78impl StreamHeaders {
79 pub fn from_request<B>(stream_id: u64, req: &Request<B>) -> Result<Self, EncodeError> {
80 let version = HttpVersion::from_http(req.version())?;
81 let method = req.method().as_str().as_bytes().to_vec();
82
83 let uri = req.uri();
84 let scheme = uri.scheme_str().map(|s| s.as_bytes().to_vec());
85 let authority = uri
86 .authority()
87 .map(|a| a.as_str().as_bytes().to_vec())
88 .or_else(|| req.headers().get("host").map(|v| v.as_bytes().to_vec()));
89 let path = uri
90 .path_and_query()
91 .map(|pq| pq.as_str())
92 .unwrap_or("/");
93 let path = if path.is_empty() { "/" } else { path };
94 let headers = collect_headers(req.headers());
95
96 Ok(StreamHeaders::Request(StreamRequestHeaders {
97 stream_id,
98 version,
99 method,
100 scheme,
101 authority,
102 path: path.as_bytes().to_vec(),
103 headers,
104 }))
105 }
106
107 pub fn from_response<B>(stream_id: u64, resp: &Response<B>) -> Result<Self, EncodeError> {
108 let version = HttpVersion::from_http(resp.version())?;
109 let status = resp.status().as_u16();
110 let headers = collect_headers(resp.headers());
111
112 Ok(StreamHeaders::Response(StreamResponseHeaders {
113 stream_id,
114 version,
115 status,
116 headers,
117 }))
118 }
119
120 pub fn from_packed_request(stream_id: u64, req: PackedRequest) -> Self {
121 StreamHeaders::Request(StreamRequestHeaders {
122 stream_id,
123 version: req.version,
124 method: req.method,
125 scheme: req.scheme,
126 authority: req.authority,
127 path: req.path,
128 headers: req.headers,
129 })
130 }
131
132 pub fn from_packed_response(stream_id: u64, resp: PackedResponse) -> Self {
133 StreamHeaders::Response(StreamResponseHeaders {
134 stream_id,
135 version: resp.version,
136 status: resp.status,
137 headers: resp.headers,
138 })
139 }
140}
141
142pub fn encode_frame(frame: &StreamFrame) -> Vec<u8> {
143 let mut buf = Vec::new();
144 buf.extend_from_slice(&STREAM_MAGIC);
145 buf.push(STREAM_VERSION);
146
147 match frame {
148 StreamFrame::Headers(headers) => {
149 buf.push(FRAME_HEADERS);
150 match headers {
151 StreamHeaders::Request(req) => {
152 buf.extend_from_slice(&req.stream_id.to_be_bytes());
153 buf.push(StreamKind::Request.to_byte());
154 buf.push(req.version.to_byte());
155 encode_request_fields(req, &mut buf);
156 }
157 StreamHeaders::Response(resp) => {
158 buf.extend_from_slice(&resp.stream_id.to_be_bytes());
159 buf.push(StreamKind::Response.to_byte());
160 buf.push(resp.version.to_byte());
161 encode_response_fields(resp, &mut buf);
162 }
163 }
164 }
165 StreamFrame::Body(body) => {
166 buf.push(FRAME_BODY);
167 buf.extend_from_slice(&body.stream_id.to_be_bytes());
168 crate::put_varint(&mut buf, body.data.len() as u64);
169 buf.extend_from_slice(&body.data);
170 }
171 StreamFrame::End(end) => {
172 buf.push(FRAME_END);
173 buf.extend_from_slice(&end.stream_id.to_be_bytes());
174 buf.push(END_FLAGS_NONE);
175 }
176 }
177
178 buf
179}
180
/// Reasons a byte sequence could not be decoded into a [`StreamFrame`].
#[derive(Debug)]
pub enum StreamDecodeError {
    /// The frame does not start with the expected magic bytes.
    InvalidMagic,
    /// The frame-format version byte is not supported.
    UnsupportedVersion(u8),
    /// The frame-type byte is unknown.
    InvalidFrameType(u8),
    /// A complete frame was followed by this many unexpected bytes.
    TrailingBytes(usize),
    /// The HTTP-version byte of a headers frame is not supported.
    UnsupportedHttpVersion(u8),
    /// The message-kind byte of a headers frame is unknown.
    InvalidKind(u8),
    /// A varint was malformed.
    InvalidVarint,
    /// A length prefix does not fit in memory addressing.
    LengthOverflow,
    /// A headers frame declared more header fields than allowed.
    TooManyHeaders(u64),
    /// The request method is empty or not a valid method token.
    InvalidMethod,
    /// The request path failed validation.
    InvalidPath,
    /// A header name failed validation.
    InvalidHeaderName,
    /// A header value failed validation.
    InvalidHeaderValue,
    /// The response status code is out of range.
    InvalidStatus,
    /// An end frame carried unknown flags.
    InvalidEndFlags(u8),
}

impl std::fmt::Display for StreamDecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants with a payload format it directly; the rest share one fixed-text path.
        let fixed = match self {
            Self::UnsupportedVersion(version) => {
                return write!(f, "unsupported format version: {}", version)
            }
            Self::InvalidFrameType(frame) => return write!(f, "invalid frame type: {}", frame),
            Self::TrailingBytes(remaining) => return write!(f, "trailing bytes: {}", remaining),
            Self::UnsupportedHttpVersion(version) => {
                return write!(f, "unsupported http version: {}", version)
            }
            Self::InvalidKind(kind) => return write!(f, "invalid message kind: {}", kind),
            Self::TooManyHeaders(count) => return write!(f, "too many headers: {}", count),
            Self::InvalidEndFlags(flags) => return write!(f, "invalid end flags: {}", flags),
            Self::InvalidMagic => "invalid magic",
            Self::InvalidVarint => "invalid varint",
            Self::LengthOverflow => "length overflow",
            Self::InvalidMethod => "invalid method",
            Self::InvalidPath => "invalid path",
            Self::InvalidHeaderName => "invalid header name",
            Self::InvalidHeaderValue => "invalid header value",
            Self::InvalidStatus => "invalid status",
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for StreamDecodeError {}
231
232pub fn decode_frame(bytes: &[u8]) -> Result<StreamFrame, StreamDecodeError> {
233 match decode_frame_from_prefix(bytes)? {
234 Some((frame, consumed)) => {
235 if consumed != bytes.len() {
236 return Err(StreamDecodeError::TrailingBytes(bytes.len() - consumed));
237 }
238 Ok(frame)
239 }
240 None => Err(StreamDecodeError::InvalidMagic),
241 }
242}
243
244pub fn decode_frame_from_prefix(
245 bytes: &[u8],
246) -> Result<Option<(StreamFrame, usize)>, StreamDecodeError> {
247 let mut offset = 0usize;
248
249 if bytes.len() < STREAM_MAGIC.len() {
250 return Ok(None);
251 }
252 if &bytes[..STREAM_MAGIC.len()] != STREAM_MAGIC {
253 return Err(StreamDecodeError::InvalidMagic);
254 }
255 offset += STREAM_MAGIC.len();
256
257 if bytes.len() < offset + 2 {
258 return Ok(None);
259 }
260 let version = bytes[offset];
261 offset += 1;
262 if version != STREAM_VERSION {
263 return Err(StreamDecodeError::UnsupportedVersion(version));
264 }
265
266 let frame_type = bytes[offset];
267 offset += 1;
268
269 if bytes.len() < offset + 8 {
270 return Ok(None);
271 }
272 let stream_id = u64::from_be_bytes([
273 bytes[offset],
274 bytes[offset + 1],
275 bytes[offset + 2],
276 bytes[offset + 3],
277 bytes[offset + 4],
278 bytes[offset + 5],
279 bytes[offset + 6],
280 bytes[offset + 7],
281 ]);
282 offset += 8;
283
284 match frame_type {
285 FRAME_HEADERS => {
286 if bytes.len() < offset + 2 {
287 return Ok(None);
288 }
289 let kind = StreamKind::from_byte(bytes[offset])?;
290 offset += 1;
291 let http_version = HttpVersion::from_byte(bytes[offset])
292 .map_err(|err| match err {
293 DecodeError::UnsupportedHttpVersion(v) => {
294 StreamDecodeError::UnsupportedHttpVersion(v)
295 }
296 _ => StreamDecodeError::UnsupportedHttpVersion(0),
297 })?;
298 offset += 1;
299
300 match kind {
301 StreamKind::Request => {
302 let method = match read_bytes(bytes, &mut offset)? {
303 Some(value) if !value.is_empty() => value,
304 Some(_) => return Err(StreamDecodeError::InvalidMethod),
305 None => return Ok(None),
306 };
307
308 let scheme = match read_bytes(bytes, &mut offset)? {
309 Some(value) if value.is_empty() => None,
310 Some(value) => Some(value),
311 None => return Ok(None),
312 };
313
314 let authority = match read_bytes(bytes, &mut offset)? {
315 Some(value) if value.is_empty() => None,
316 Some(value) => Some(value),
317 None => return Ok(None),
318 };
319
320 let path = match read_bytes(bytes, &mut offset)? {
321 Some(value) if value.is_empty() => b"/".to_vec(),
322 Some(value) => value,
323 None => return Ok(None),
324 };
325
326 validate_method(&method)?;
327 validate_path(&path)?;
328
329 let headers = read_headers(bytes, &mut offset)?;
330 let headers = match headers {
331 Some(value) => value,
332 None => return Ok(None),
333 };
334
335 Ok(Some((
336 StreamFrame::Headers(StreamHeaders::Request(StreamRequestHeaders {
337 stream_id,
338 version: http_version,
339 method,
340 scheme,
341 authority,
342 path,
343 headers,
344 })),
345 offset,
346 )))
347 }
348 StreamKind::Response => {
349 if bytes.len() < offset + 2 {
350 return Ok(None);
351 }
352 let status = u16::from_be_bytes([bytes[offset], bytes[offset + 1]]);
353 offset += 2;
354 if StatusCode::from_u16(status).is_err() {
355 return Err(StreamDecodeError::InvalidStatus);
356 }
357
358 let headers = read_headers(bytes, &mut offset)?;
359 let headers = match headers {
360 Some(value) => value,
361 None => return Ok(None),
362 };
363
364 Ok(Some((
365 StreamFrame::Headers(StreamHeaders::Response(StreamResponseHeaders {
366 stream_id,
367 version: http_version,
368 status,
369 headers,
370 })),
371 offset,
372 )))
373 }
374 }
375 }
376 FRAME_BODY => {
377 let body_len = match read_varint(bytes, &mut offset)? {
378 Some(value) => value,
379 None => return Ok(None),
380 };
381 let len = usize::try_from(body_len).map_err(|_| StreamDecodeError::LengthOverflow)?;
382 if bytes.len() < offset + len {
383 return Ok(None);
384 }
385 let data = Bytes::copy_from_slice(&bytes[offset..offset + len]);
387 offset += len;
388
389 Ok(Some((StreamFrame::Body(StreamBody { stream_id, data }), offset)))
390 }
391 FRAME_END => {
392 if bytes.len() < offset + 1 {
393 return Ok(None);
394 }
395 let flags = bytes[offset];
396 offset += 1;
397 if flags != END_FLAGS_NONE {
398 return Err(StreamDecodeError::InvalidEndFlags(flags));
399 }
400 Ok(Some((StreamFrame::End(StreamEnd { stream_id }), offset)))
401 }
402 other => Err(StreamDecodeError::InvalidFrameType(other)),
403 }
404}
405
/// Incremental frame decoder: buffers raw bytes via `push` and yields complete frames
/// from `try_decode`.
pub struct StreamDecoder {
    // Bytes received but not yet consumed by a decoded frame.
    buf: BytesMut,
}
409
410impl StreamDecoder {
411 pub fn new() -> Self {
412 Self { buf: BytesMut::new() }
413 }
414
415 pub fn push(&mut self, data: &[u8]) {
416 self.buf.extend_from_slice(data);
417 }
418
419 pub fn try_decode(&mut self) -> Result<Option<StreamFrame>, StreamDecodeError> {
420 match decode_frame_from_prefix(&self.buf)? {
421 Some((frame, consumed)) => {
422 self.buf.advance(consumed);
423 Ok(Some(frame))
424 }
425 None => Ok(None),
426 }
427 }
428
429 pub fn buffer_len(&self) -> usize {
430 self.buf.len()
431 }
432}
433
/// Reasons [`Http1StreamRebuilder`] refused a frame.
#[derive(Debug)]
pub enum StreamRebuildError {
    /// A body or end frame arrived for a stream whose headers were not seen.
    MissingHeaders(u64),
    /// A headers frame arrived for a stream that is already open.
    DuplicateHeaders(u64),
    /// Header validation failed for a reason other than the name or value.
    InvalidFrame,
    /// The request method is not a valid method token.
    InvalidMethod,
    /// The request path failed validation.
    InvalidPath,
    /// A header name failed validation.
    InvalidHeaderName,
    /// A header value failed validation.
    InvalidHeaderValue,
    /// The response status code is out of range.
    InvalidStatus,
}

impl std::fmt::Display for StreamRebuildError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fixed = match self {
            Self::MissingHeaders(id) => return write!(f, "missing headers for stream {}", id),
            Self::DuplicateHeaders(id) => {
                return write!(f, "duplicate headers for stream {}", id)
            }
            Self::InvalidFrame => "invalid frame order",
            Self::InvalidMethod => "invalid method",
            Self::InvalidPath => "invalid path",
            Self::InvalidHeaderName => "invalid header name",
            Self::InvalidHeaderValue => "invalid header value",
            Self::InvalidStatus => "invalid status",
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for StreamRebuildError {}
462
/// Turns a sequence of stream frames back into HTTP/1.1 message bytes, re-framing every
/// body with `transfer-encoding: chunked`.
pub struct Http1StreamRebuilder {
    // Ids of streams whose headers frame was accepted and whose end frame has not arrived.
    streams: HashSet<u64>,
}
466
467impl Http1StreamRebuilder {
468 pub fn new() -> Self {
469 Self { streams: HashSet::new() }
470 }
471
472 pub fn push_frame(&mut self, frame: StreamFrame) -> Result<Vec<Bytes>, StreamRebuildError> {
473 match frame {
474 StreamFrame::Headers(headers) => self.handle_headers(headers),
475 StreamFrame::Body(body) => self.handle_body(body),
476 StreamFrame::End(end) => self.handle_end(end),
477 }
478 }
479
480 fn handle_headers(&mut self, headers: StreamHeaders) -> Result<Vec<Bytes>, StreamRebuildError> {
481 let stream_id = match &headers {
482 StreamHeaders::Request(req) => req.stream_id,
483 StreamHeaders::Response(resp) => resp.stream_id,
484 };
485 if self.streams.contains(&stream_id) {
486 return Err(StreamRebuildError::DuplicateHeaders(stream_id));
487 }
488
489 let mut out = Vec::new();
490 let bytes = match headers {
491 StreamHeaders::Request(req) => {
492 self.streams.insert(stream_id);
493 build_http1_request_headers(&req)?
494 }
495 StreamHeaders::Response(resp) => {
496 self.streams.insert(stream_id);
497 build_http1_response_headers(&resp)?
498 }
499 };
500
501 out.push(Bytes::from(bytes));
502 Ok(out)
503 }
504
505 fn handle_body(&mut self, body: StreamBody) -> Result<Vec<Bytes>, StreamRebuildError> {
506 if !self.streams.contains(&body.stream_id) {
507 return Err(StreamRebuildError::MissingHeaders(body.stream_id));
508 }
509 if body.data.is_empty() {
510 return Ok(Vec::new());
511 }
512 let mut chunk = Vec::new();
513 write_chunk_size(body.data.len(), &mut chunk);
514 chunk.extend_from_slice(&body.data);
515 chunk.extend_from_slice(b"\r\n");
516 Ok(vec![Bytes::from(chunk)])
517 }
518
519 fn handle_end(&mut self, end: StreamEnd) -> Result<Vec<Bytes>, StreamRebuildError> {
520 if !self.streams.remove(&end.stream_id) {
521 return Err(StreamRebuildError::MissingHeaders(end.stream_id));
522 }
523 Ok(vec![Bytes::from_static(b"0\r\n\r\n")])
524 }
525}
526
527fn build_http1_request_headers(req: &StreamRequestHeaders) -> Result<Vec<u8>, StreamRebuildError> {
528 validate_method(&req.method).map_err(|_| StreamRebuildError::InvalidMethod)?;
529 validate_path(&req.path).map_err(|_| StreamRebuildError::InvalidPath)?;
530
531 let mut out = Vec::new();
532 out.extend_from_slice(&req.method);
533 out.extend_from_slice(b" ");
534 out.extend_from_slice(&req.path);
535 out.extend_from_slice(b" HTTP/1.1\r\n");
536
537 let mut has_host = false;
538 for header in &req.headers {
539 if crate::eq_ignore_ascii_case(&header.name, b"transfer-encoding") {
540 continue;
541 }
542 if crate::eq_ignore_ascii_case(&header.name, b"content-length") {
543 continue;
544 }
545 if crate::eq_ignore_ascii_case(&header.name, b"host") {
546 has_host = true;
547 }
548 validate_header_field(header).map_err(map_header_error)?;
549 out.extend_from_slice(&header.name);
550 out.extend_from_slice(b": ");
551 out.extend_from_slice(&header.value);
552 out.extend_from_slice(b"\r\n");
553 }
554
555 if !has_host {
556 if let Some(authority) = &req.authority {
557 if crate::has_crlf(authority) {
558 return Err(StreamRebuildError::InvalidHeaderValue);
559 }
560 out.extend_from_slice(b"host: ");
561 out.extend_from_slice(authority);
562 out.extend_from_slice(b"\r\n");
563 }
564 }
565
566 out.extend_from_slice(b"transfer-encoding: chunked\r\n\r\n");
567 Ok(out)
568}
569
570fn build_http1_response_headers(
571 resp: &StreamResponseHeaders,
572) -> Result<Vec<u8>, StreamRebuildError> {
573 let status = StatusCode::from_u16(resp.status).map_err(|_| StreamRebuildError::InvalidStatus)?;
574 let reason = status.canonical_reason().unwrap_or("");
575
576 let mut out = Vec::new();
577 out.extend_from_slice(b"HTTP/1.1 ");
578 out.extend_from_slice(status.as_str().as_bytes());
579 if !reason.is_empty() {
580 out.extend_from_slice(b" ");
581 out.extend_from_slice(reason.as_bytes());
582 }
583 out.extend_from_slice(b"\r\n");
584
585 for header in &resp.headers {
586 if crate::eq_ignore_ascii_case(&header.name, b"transfer-encoding") {
587 continue;
588 }
589 if crate::eq_ignore_ascii_case(&header.name, b"content-length") {
590 continue;
591 }
592 validate_header_field(header).map_err(map_header_error)?;
593 out.extend_from_slice(&header.name);
594 out.extend_from_slice(b": ");
595 out.extend_from_slice(&header.value);
596 out.extend_from_slice(b"\r\n");
597 }
598
599 out.extend_from_slice(b"transfer-encoding: chunked\r\n\r\n");
600 Ok(out)
601}
602
603fn map_header_error(err: DecodeError) -> StreamRebuildError {
604 match err {
605 DecodeError::InvalidHeaderName => StreamRebuildError::InvalidHeaderName,
606 DecodeError::InvalidHeaderValue => StreamRebuildError::InvalidHeaderValue,
607 _ => StreamRebuildError::InvalidFrame,
608 }
609}
610
/// Appends the chunk-size line for a chunk of `len` bytes: upper-case hex without
/// leading zeros, followed by CRLF (`0` yields `0\r\n`).
fn write_chunk_size(len: usize, out: &mut Vec<u8>) {
    let line = format!("{:X}\r\n", len);
    out.extend_from_slice(line.as_bytes());
}
629
630fn encode_request_fields(req: &StreamRequestHeaders, buf: &mut Vec<u8>) {
631 crate::put_varint(buf, req.method.len() as u64);
632 buf.extend_from_slice(&req.method);
633
634 if let Some(scheme) = &req.scheme {
635 crate::put_varint(buf, scheme.len() as u64);
636 buf.extend_from_slice(scheme);
637 } else {
638 crate::put_varint(buf, 0);
639 }
640
641 if let Some(authority) = &req.authority {
642 crate::put_varint(buf, authority.len() as u64);
643 buf.extend_from_slice(authority);
644 } else {
645 crate::put_varint(buf, 0);
646 }
647
648 crate::put_varint(buf, req.path.len() as u64);
649 buf.extend_from_slice(&req.path);
650
651 crate::put_varint(buf, req.headers.len() as u64);
652 for header in &req.headers {
653 crate::put_varint(buf, header.name.len() as u64);
654 buf.extend_from_slice(&header.name);
655 crate::put_varint(buf, header.value.len() as u64);
656 buf.extend_from_slice(&header.value);
657 }
658}
659
660fn encode_response_fields(resp: &StreamResponseHeaders, buf: &mut Vec<u8>) {
661 buf.extend_from_slice(&resp.status.to_be_bytes());
662
663 crate::put_varint(buf, resp.headers.len() as u64);
664 for header in &resp.headers {
665 crate::put_varint(buf, header.name.len() as u64);
666 buf.extend_from_slice(&header.name);
667 crate::put_varint(buf, header.value.len() as u64);
668 buf.extend_from_slice(&header.value);
669 }
670}
671
672fn read_headers(
673 bytes: &[u8],
674 offset: &mut usize,
675) -> Result<Option<Vec<HeaderField>>, StreamDecodeError> {
676 let header_count = match read_varint(bytes, offset)? {
677 Some(value) => value,
678 None => return Ok(None),
679 };
680 if header_count > MAX_HEADERS {
681 return Err(StreamDecodeError::TooManyHeaders(header_count));
682 }
683
684 let mut headers = Vec::with_capacity(header_count as usize);
685 for _ in 0..header_count {
686 let name = match read_bytes(bytes, offset)? {
687 Some(value) => value,
688 None => return Ok(None),
689 };
690 let value = match read_bytes(bytes, offset)? {
691 Some(value) => value,
692 None => return Ok(None),
693 };
694 validate_header_name(&name)?;
695 validate_header_value(&value)?;
696 headers.push(HeaderField { name, value });
697 }
698
699 Ok(Some(headers))
700}
701
702fn read_varint(bytes: &[u8], offset: &mut usize) -> Result<Option<u64>, StreamDecodeError> {
703 let mut value: u64 = 0;
704 let mut shift = 0;
705
706 for _ in 0..10 {
707 if *offset >= bytes.len() {
708 return Ok(None);
709 }
710 let byte = bytes[*offset];
711 *offset += 1;
712 value |= ((byte & 0x7f) as u64) << shift;
713 if (byte & 0x80) == 0 {
714 return Ok(Some(value));
715 }
716 shift += 7;
717 }
718
719 Err(StreamDecodeError::InvalidVarint)
720}
721
722fn read_bytes(bytes: &[u8], offset: &mut usize) -> Result<Option<Vec<u8>>, StreamDecodeError> {
723 let len = match read_varint(bytes, offset)? {
724 Some(value) => value,
725 None => return Ok(None),
726 };
727 read_raw(bytes, offset, len)
728}
729
730fn read_raw(
731 bytes: &[u8],
732 offset: &mut usize,
733 len: u64,
734) -> Result<Option<Vec<u8>>, StreamDecodeError> {
735 let len = usize::try_from(len).map_err(|_| StreamDecodeError::LengthOverflow)?;
736 if bytes.len() < *offset + len {
737 return Ok(None);
738 }
739 let data = bytes[*offset..*offset + len].to_vec();
740 *offset += len;
741 Ok(Some(data))
742}
743
744fn validate_method(method: &[u8]) -> Result<(), StreamDecodeError> {
745 Method::from_bytes(method).map_err(|_| StreamDecodeError::InvalidMethod)?;
746 Ok(())
747}
748
749fn validate_path(path: &[u8]) -> Result<(), StreamDecodeError> {
750 if path.is_empty() || crate::has_crlf(path) {
751 return Err(StreamDecodeError::InvalidPath);
752 }
753 Ok(())
754}
755
756fn validate_header_name(name: &[u8]) -> Result<(), StreamDecodeError> {
757 HeaderName::from_bytes(name).map_err(|_| StreamDecodeError::InvalidHeaderName)?;
758 Ok(())
759}
760
761fn validate_header_value(value: &[u8]) -> Result<(), StreamDecodeError> {
762 HeaderValue::from_bytes(value).map_err(|_| StreamDecodeError::InvalidHeaderValue)?;
763 Ok(())
764}
765
766fn validate_header_field(field: &HeaderField) -> Result<(), DecodeError> {
767 crate::validate_header_name(&field.name)?;
768 crate::validate_header_value(&field.value)?;
769 Ok(())
770}
771
772fn collect_headers(headers: &http::HeaderMap) -> Vec<HeaderField> {
773 headers
774 .iter()
775 .map(|(name, value)| HeaderField {
776 name: name.as_str().as_bytes().to_vec(),
777 value: value.as_bytes().to_vec(),
778 })
779 .collect()
780}
781
782impl StreamKind {
783 fn to_byte(self) -> u8 {
784 match self {
785 StreamKind::Request => 1,
786 StreamKind::Response => 2,
787 }
788 }
789
790 fn from_byte(byte: u8) -> Result<Self, StreamDecodeError> {
791 match byte {
792 1 => Ok(StreamKind::Request),
793 2 => Ok(StreamKind::Response),
794 other => Err(StreamDecodeError::InvalidKind(other)),
795 }
796 }
797}
798
#[cfg(feature = "body")]
pub mod body {
    //! Adapters that turn `http` messages with streaming `http_body` bodies into frames.
    use super::{StreamFrame, StreamHeaders, StreamBody, StreamEnd, StreamEncodeError};
    use bytes::Buf;
    use http::{Request, Response};
    use http_body::Body;
    use http_body_util::BodyExt;

    /// Emits a headers frame for `req`, then one body frame per non-empty data chunk, then
    /// an end frame, all tagged with `stream_id`.
    ///
    /// # Errors
    /// * `Encode` if the request head cannot be encoded.
    /// * `Body` if polling the body fails.
    /// * `Emit` if `emit` rejects a frame.
    ///
    /// No further frames are emitted after an error.
    pub async fn encode_request<B, F, E>(
        req: Request<B>,
        stream_id: u64,
        mut emit: F,
    ) -> Result<(), StreamEncodeError<E>>
    where
        B: Body + Unpin,
        B::Data: Buf,
        B::Error: std::error::Error + Send + Sync + 'static,
        F: FnMut(StreamFrame) -> Result<(), E>,
    {
        let (parts, body) = req.into_parts();
        let head = Request::from_parts(parts, ());
        let headers =
            StreamHeaders::from_request(stream_id, &head).map_err(StreamEncodeError::Encode)?;
        emit(StreamFrame::Headers(headers)).map_err(StreamEncodeError::Emit)?;
        forward_body(body, stream_id, emit).await
    }

    /// Response counterpart of [`encode_request`]: headers, body chunks, then end.
    ///
    /// # Errors
    /// Same as [`encode_request`].
    pub async fn encode_response<B, F, E>(
        resp: Response<B>,
        stream_id: u64,
        mut emit: F,
    ) -> Result<(), StreamEncodeError<E>>
    where
        B: Body + Unpin,
        B::Data: Buf,
        B::Error: std::error::Error + Send + Sync + 'static,
        F: FnMut(StreamFrame) -> Result<(), E>,
    {
        let (parts, body) = resp.into_parts();
        let head = Response::from_parts(parts, ());
        let headers =
            StreamHeaders::from_response(stream_id, &head).map_err(StreamEncodeError::Encode)?;
        emit(StreamFrame::Headers(headers)).map_err(StreamEncodeError::Emit)?;
        forward_body(body, stream_id, emit).await
    }

    /// Drains `body`, emitting each non-empty data frame as a body frame, then emits the
    /// end frame. Non-data frames (e.g. trailers) are skipped.
    async fn forward_body<B, F, E>(
        mut body: B,
        stream_id: u64,
        mut emit: F,
    ) -> Result<(), StreamEncodeError<E>>
    where
        B: Body + Unpin,
        B::Data: Buf,
        B::Error: std::error::Error + Send + Sync + 'static,
        F: FnMut(StreamFrame) -> Result<(), E>,
    {
        while let Some(frame) = body
            .frame()
            .await
            .transpose()
            .map_err(|err| StreamEncodeError::Body(Box::new(err)))?
        {
            let mut data = match frame.into_data() {
                Ok(data) => data,
                Err(_) => continue,
            };
            let remaining = data.remaining();
            if remaining == 0 {
                continue;
            }
            let chunk = data.copy_to_bytes(remaining);
            emit(StreamFrame::Body(StreamBody {
                stream_id,
                data: chunk,
            }))
            .map_err(StreamEncodeError::Emit)?;
        }

        emit(StreamFrame::End(StreamEnd { stream_id })).map_err(StreamEncodeError::Emit)
    }
}
887
#[cfg(feature = "h3")]
pub mod h3 {
    //! Adapters that read message bodies from `h3` request streams and emit frames.
    use super::{StreamFrame, StreamHeaders, StreamBody, StreamEnd, StreamEncodeError};
    use bytes::Buf;
    use h3::quic::RecvStream;

    /// Server side: emits a headers frame for `req`, then one body frame per non-empty
    /// chunk received on `stream`, then an end frame once the stream has no more data.
    ///
    /// # Errors
    /// * `Encode` if the head cannot be encoded.
    /// * `H3Stream` if receiving fails.
    /// * `Emit` if `emit` rejects a frame.
    pub async fn encode_server_request<S, B, F, E>(
        req: http::Request<()>,
        stream_id: u64,
        stream: &mut h3::server::RequestStream<S, B>,
        mut emit: F,
    ) -> Result<(), StreamEncodeError<E>>
    where
        S: RecvStream,
        B: Buf,
        F: FnMut(StreamFrame) -> Result<(), E>,
    {
        let head =
            StreamHeaders::from_request(stream_id, &req).map_err(StreamEncodeError::Encode)?;
        emit(StreamFrame::Headers(head)).map_err(StreamEncodeError::Emit)?;

        while let Some(mut chunk) = stream
            .recv_data()
            .await
            .map_err(StreamEncodeError::H3Stream)?
        {
            let remaining = chunk.remaining();
            if remaining > 0 {
                let data = chunk.copy_to_bytes(remaining);
                emit(StreamFrame::Body(StreamBody { stream_id, data }))
                    .map_err(StreamEncodeError::Emit)?;
            }
        }

        emit(StreamFrame::End(StreamEnd { stream_id })).map_err(StreamEncodeError::Emit)?;
        Ok(())
    }

    /// Client side: emits a headers frame for `resp`, then one body frame per non-empty
    /// chunk received on `stream`, then an end frame once the stream has no more data.
    ///
    /// # Errors
    /// Same as [`encode_server_request`].
    pub async fn encode_client_response<S, B, F, E>(
        resp: http::Response<()>,
        stream_id: u64,
        stream: &mut h3::client::RequestStream<S, B>,
        mut emit: F,
    ) -> Result<(), StreamEncodeError<E>>
    where
        S: RecvStream,
        B: Buf,
        F: FnMut(StreamFrame) -> Result<(), E>,
    {
        let head =
            StreamHeaders::from_response(stream_id, &resp).map_err(StreamEncodeError::Encode)?;
        emit(StreamFrame::Headers(head)).map_err(StreamEncodeError::Emit)?;

        while let Some(mut chunk) = stream
            .recv_data()
            .await
            .map_err(StreamEncodeError::H3Stream)?
        {
            let remaining = chunk.remaining();
            if remaining > 0 {
                let data = chunk.copy_to_bytes(remaining);
                emit(StreamFrame::Body(StreamBody { stream_id, data }))
                    .map_err(StreamEncodeError::Emit)?;
            }
        }

        emit(StreamFrame::End(StreamEnd { stream_id })).map_err(StreamEncodeError::Emit)?;
        Ok(())
    }
}
968
969#[derive(Debug)]
970pub enum StreamEncodeError<E> {
971 Encode(EncodeError),
972 Body(Box<dyn std::error::Error + Send + Sync>),
973 #[cfg(feature = "h3")]
974 H3Stream(::h3::error::StreamError),
975 Emit(E),
976}
977
978impl<E: std::fmt::Display> std::fmt::Display for StreamEncodeError<E> {
979 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
980 match self {
981 StreamEncodeError::Encode(err) => write!(f, "encode error: {}", err),
982 StreamEncodeError::Body(err) => write!(f, "body error: {}", err),
983 #[cfg(feature = "h3")]
984 StreamEncodeError::H3Stream(err) => write!(f, "h3 stream error: {}", err),
985 StreamEncodeError::Emit(err) => write!(f, "emit error: {}", err),
986 }
987 }
988}
989
990impl<E: std::fmt::Debug + std::fmt::Display> std::error::Error for StreamEncodeError<E> {}
991
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn frame_roundtrip_request_headers() {
        let headers = StreamHeaders::Request(StreamRequestHeaders {
            stream_id: 7,
            version: HttpVersion::Http11,
            method: b"GET".to_vec(),
            scheme: None,
            authority: Some(b"example.com".to_vec()),
            path: b"/".to_vec(),
            headers: vec![HeaderField {
                name: b"x-test".to_vec(),
                value: b"ok".to_vec(),
            }],
        });
        let frame = StreamFrame::Headers(headers);
        let encoded = encode_frame(&frame);
        let decoded = decode_frame(&encoded).unwrap();
        assert_eq!(frame, decoded);
    }

    #[test]
    fn frame_roundtrip_response_headers() {
        let frame = StreamFrame::Headers(StreamHeaders::Response(StreamResponseHeaders {
            stream_id: 3,
            version: HttpVersion::Http11,
            status: 404,
            headers: vec![HeaderField {
                name: b"content-type".to_vec(),
                value: b"text/plain".to_vec(),
            }],
        }));
        let decoded = decode_frame(&encode_frame(&frame)).unwrap();
        assert_eq!(frame, decoded);
        assert_eq!(decoded.stream_id(), 3);
    }

    #[test]
    fn frame_roundtrip_body_and_end() {
        let body = StreamFrame::Body(StreamBody {
            stream_id: 9,
            data: Bytes::from_static(b"payload"),
        });
        assert_eq!(decode_frame(&encode_frame(&body)).unwrap(), body);

        let end = StreamFrame::End(StreamEnd { stream_id: 9 });
        assert_eq!(decode_frame(&encode_frame(&end)).unwrap(), end);
        assert_eq!(end.stream_id(), 9);
    }

    #[test]
    fn decoder_waits_for_complete_frame() {
        let frame = StreamFrame::Body(StreamBody {
            stream_id: 2,
            data: Bytes::from_static(b"hello"),
        });
        let encoded = encode_frame(&frame);
        let (head, tail) = encoded.split_at(encoded.len() / 2);

        let mut decoder = StreamDecoder::new();
        decoder.push(head);
        assert!(decoder.try_decode().unwrap().is_none());
        assert_eq!(decoder.buffer_len(), head.len());

        decoder.push(tail);
        assert_eq!(decoder.try_decode().unwrap(), Some(frame));
        assert_eq!(decoder.buffer_len(), 0);
    }

    #[test]
    fn decode_rejects_bad_magic_and_trailing_bytes() {
        assert!(matches!(
            decode_frame(b"NOPE\x01\x03"),
            Err(StreamDecodeError::InvalidMagic)
        ));

        let mut encoded = encode_frame(&StreamFrame::End(StreamEnd { stream_id: 1 }));
        encoded.push(0);
        assert!(matches!(
            decode_frame(&encoded),
            Err(StreamDecodeError::TrailingBytes(1))
        ));
    }

    #[test]
    fn http1_rebuild_request() {
        let headers = StreamHeaders::Request(StreamRequestHeaders {
            stream_id: 1,
            version: HttpVersion::Http11,
            method: b"POST".to_vec(),
            scheme: None,
            authority: Some(b"example.com".to_vec()),
            path: b"/upload".to_vec(),
            headers: vec![HeaderField {
                name: b"x-test".to_vec(),
                value: b"ok".to_vec(),
            }],
        });

        let mut rebuilder = Http1StreamRebuilder::new();
        let head = rebuilder
            .push_frame(StreamFrame::Headers(headers))
            .unwrap();
        let head_str = String::from_utf8(head[0].to_vec()).unwrap();
        assert!(head_str.starts_with("POST /upload HTTP/1.1\r\n"));
        assert!(head_str.contains("transfer-encoding: chunked\r\n"));

        let body = rebuilder
            .push_frame(StreamFrame::Body(StreamBody {
                stream_id: 1,
                data: Bytes::from_static(b"hello"),
            }))
            .unwrap();
        assert_eq!(body[0].as_ref(), b"5\r\nhello\r\n");

        let end = rebuilder
            .push_frame(StreamFrame::End(StreamEnd { stream_id: 1 }))
            .unwrap();
        assert_eq!(end[0].as_ref(), b"0\r\n\r\n");
    }

    #[test]
    fn http1_rebuild_response_drops_framing_headers() {
        let mut rebuilder = Http1StreamRebuilder::new();
        let head = rebuilder
            .push_frame(StreamFrame::Headers(StreamHeaders::Response(
                StreamResponseHeaders {
                    stream_id: 4,
                    version: HttpVersion::Http11,
                    status: 200,
                    headers: vec![HeaderField {
                        name: b"content-length".to_vec(),
                        value: b"5".to_vec(),
                    }],
                },
            )))
            .unwrap();
        let head_str = String::from_utf8(head[0].to_vec()).unwrap();
        assert!(head_str.starts_with("HTTP/1.1 200 OK\r\n"));
        assert!(!head_str.contains("content-length"));
        assert!(head_str.ends_with("transfer-encoding: chunked\r\n\r\n"));
    }

    #[test]
    fn http1_rebuild_rejects_body_without_headers() {
        let mut rebuilder = Http1StreamRebuilder::new();
        let result = rebuilder.push_frame(StreamFrame::Body(StreamBody {
            stream_id: 5,
            data: Bytes::from_static(b"x"),
        }));
        assert!(matches!(result, Err(StreamRebuildError::MissingHeaders(5))));
    }
}