http_pack/
stream.rs

1use bytes::{Buf, Bytes, BytesMut};
2use http::{HeaderName, HeaderValue, Method, Request, Response, StatusCode};
3use std::collections::HashSet;
4
5use crate::{
6    HeaderField, HttpVersion, PackedRequest, PackedResponse, MAX_HEADERS, DecodeError, EncodeError,
7};
8
/// Four-byte marker written at the start of every encoded frame.
const STREAM_MAGIC: [u8; 4] = *b"HPKS";
/// Version of this framing format; the decoder rejects any other value.
const STREAM_VERSION: u8 = 1;
/// Frame type byte for a request or response head.
const FRAME_HEADERS: u8 = 1;
/// Frame type byte for a chunk of message body.
const FRAME_BODY: u8 = 2;
/// Frame type byte marking the end of a stream.
const FRAME_END: u8 = 3;
/// The only flags value the encoder writes and the decoder accepts on an end frame.
const END_FLAGS_NONE: u8 = 0;
15
/// Which side of an HTTP exchange a headers frame describes.
///
/// Encoded on the wire as a single byte directly after the stream id of a
/// headers frame.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamKind {
    /// Request head (wire byte `1`).
    Request,
    /// Response head (wire byte `2`).
    Response,
}
21
/// Request head carried by a headers frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamRequestHeaders {
    /// Identifier linking this head to later body/end frames of the same stream.
    pub stream_id: u64,
    /// HTTP version of the original message.
    pub version: HttpVersion,
    /// Request method as raw bytes (for example `b"GET"`); must be non-empty on the wire.
    pub method: Vec<u8>,
    /// URI scheme, if any. `None` is encoded as a zero-length field, and a
    /// zero-length field decodes back to `None`.
    pub scheme: Option<Vec<u8>>,
    /// URI authority, if any. Encoded like `scheme`.
    pub authority: Option<Vec<u8>>,
    /// Path and query. The decoder substitutes `/` for an empty path.
    pub path: Vec<u8>,
    /// Header fields in encoding order.
    pub headers: Vec<HeaderField>,
}

/// Response head carried by a headers frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamResponseHeaders {
    /// Identifier linking this head to later body/end frames of the same stream.
    pub stream_id: u64,
    /// HTTP version of the original message.
    pub version: HttpVersion,
    /// Numeric status code, encoded as a big-endian `u16`.
    pub status: u16,
    /// Header fields in encoding order.
    pub headers: Vec<HeaderField>,
}

/// Payload of a headers frame: either a request head or a response head.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamHeaders {
    Request(StreamRequestHeaders),
    Response(StreamResponseHeaders),
}
46
/// A chunk of message body belonging to `stream_id`.
///
/// Encoded as a varint length followed by the raw bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamBody {
    pub stream_id: u64,
    pub data: Bytes,
}

/// Marks the end of `stream_id`. Encoded with a single flags byte that must be zero.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamEnd {
    pub stream_id: u64,
}

/// One unit of the stream wire format.
///
/// Every encoded frame starts with the `HPKS` magic, the format version byte,
/// the frame type byte and the big-endian 8-byte stream id.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamFrame {
    Headers(StreamHeaders),
    Body(StreamBody),
    End(StreamEnd),
}
64
65impl StreamFrame {
66    pub fn stream_id(&self) -> u64 {
67        match self {
68            StreamFrame::Headers(headers) => match headers {
69                StreamHeaders::Request(req) => req.stream_id,
70                StreamHeaders::Response(resp) => resp.stream_id,
71            },
72            StreamFrame::Body(body) => body.stream_id,
73            StreamFrame::End(end) => end.stream_id,
74        }
75    }
76}
77
78impl StreamHeaders {
79    pub fn from_request<B>(stream_id: u64, req: &Request<B>) -> Result<Self, EncodeError> {
80        let version = HttpVersion::from_http(req.version())?;
81        let method = req.method().as_str().as_bytes().to_vec();
82
83        let uri = req.uri();
84        let scheme = uri.scheme_str().map(|s| s.as_bytes().to_vec());
85        let authority = uri
86            .authority()
87            .map(|a| a.as_str().as_bytes().to_vec())
88            .or_else(|| req.headers().get("host").map(|v| v.as_bytes().to_vec()));
89        let path = uri
90            .path_and_query()
91            .map(|pq| pq.as_str())
92            .unwrap_or("/");
93        let path = if path.is_empty() { "/" } else { path };
94        let headers = collect_headers(req.headers());
95
96        Ok(StreamHeaders::Request(StreamRequestHeaders {
97            stream_id,
98            version,
99            method,
100            scheme,
101            authority,
102            path: path.as_bytes().to_vec(),
103            headers,
104        }))
105    }
106
107    pub fn from_response<B>(stream_id: u64, resp: &Response<B>) -> Result<Self, EncodeError> {
108        let version = HttpVersion::from_http(resp.version())?;
109        let status = resp.status().as_u16();
110        let headers = collect_headers(resp.headers());
111
112        Ok(StreamHeaders::Response(StreamResponseHeaders {
113            stream_id,
114            version,
115            status,
116            headers,
117        }))
118    }
119
120    pub fn from_packed_request(stream_id: u64, req: PackedRequest) -> Self {
121        StreamHeaders::Request(StreamRequestHeaders {
122            stream_id,
123            version: req.version,
124            method: req.method,
125            scheme: req.scheme,
126            authority: req.authority,
127            path: req.path,
128            headers: req.headers,
129        })
130    }
131
132    pub fn from_packed_response(stream_id: u64, resp: PackedResponse) -> Self {
133        StreamHeaders::Response(StreamResponseHeaders {
134            stream_id,
135            version: resp.version,
136            status: resp.status,
137            headers: resp.headers,
138        })
139    }
140}
141
142pub fn encode_frame(frame: &StreamFrame) -> Vec<u8> {
143    let mut buf = Vec::new();
144    buf.extend_from_slice(&STREAM_MAGIC);
145    buf.push(STREAM_VERSION);
146
147    match frame {
148        StreamFrame::Headers(headers) => {
149            buf.push(FRAME_HEADERS);
150            match headers {
151                StreamHeaders::Request(req) => {
152                    buf.extend_from_slice(&req.stream_id.to_be_bytes());
153                    buf.push(StreamKind::Request.to_byte());
154                    buf.push(req.version.to_byte());
155                    encode_request_fields(req, &mut buf);
156                }
157                StreamHeaders::Response(resp) => {
158                    buf.extend_from_slice(&resp.stream_id.to_be_bytes());
159                    buf.push(StreamKind::Response.to_byte());
160                    buf.push(resp.version.to_byte());
161                    encode_response_fields(resp, &mut buf);
162                }
163            }
164        }
165        StreamFrame::Body(body) => {
166            buf.push(FRAME_BODY);
167            buf.extend_from_slice(&body.stream_id.to_be_bytes());
168            crate::put_varint(&mut buf, body.data.len() as u64);
169            buf.extend_from_slice(&body.data);
170        }
171        StreamFrame::End(end) => {
172            buf.push(FRAME_END);
173            buf.extend_from_slice(&end.stream_id.to_be_bytes());
174            buf.push(END_FLAGS_NONE);
175        }
176    }
177
178    buf
179}
180
/// Errors produced while decoding stream frames.
#[derive(Debug)]
pub enum StreamDecodeError {
    /// The input does not start with the `HPKS` magic. `decode_frame` also
    /// returns this when the input is too short to hold a complete frame.
    InvalidMagic,
    /// The format version byte is not the supported `STREAM_VERSION`.
    UnsupportedVersion(u8),
    /// The frame type byte is not a known frame type.
    InvalidFrameType(u8),
    /// `decode_frame` decoded a frame but this many bytes were left over.
    TrailingBytes(usize),
    /// The HTTP version byte of a headers frame was rejected.
    UnsupportedHttpVersion(u8),
    /// The kind byte of a headers frame is neither request (1) nor response (2).
    InvalidKind(u8),
    /// A varint is malformed (for example, longer than 10 bytes).
    InvalidVarint,
    /// An encoded length does not fit in memory (`usize`).
    LengthOverflow,
    /// The header count exceeds `MAX_HEADERS`.
    TooManyHeaders(u64),
    /// The request method is empty or not a valid method token.
    InvalidMethod,
    /// The request path failed validation (see `validate_path`).
    InvalidPath,
    /// A header name was rejected by `http::HeaderName`.
    InvalidHeaderName,
    /// A header value was rejected by `http::HeaderValue`.
    InvalidHeaderValue,
    /// The response status is not accepted by `http::StatusCode::from_u16`.
    InvalidStatus,
    /// An end frame carried a non-zero flags byte.
    InvalidEndFlags(u8),
}
199
200impl std::fmt::Display for StreamDecodeError {
201    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202        match self {
203            StreamDecodeError::InvalidMagic => write!(f, "invalid magic"),
204            StreamDecodeError::UnsupportedVersion(version) => {
205                write!(f, "unsupported format version: {}", version)
206            }
207            StreamDecodeError::InvalidFrameType(frame) => {
208                write!(f, "invalid frame type: {}", frame)
209            }
210            StreamDecodeError::TrailingBytes(remaining) => {
211                write!(f, "trailing bytes: {}", remaining)
212            }
213            StreamDecodeError::UnsupportedHttpVersion(version) => {
214                write!(f, "unsupported http version: {}", version)
215            }
216            StreamDecodeError::InvalidKind(kind) => write!(f, "invalid message kind: {}", kind),
217            StreamDecodeError::InvalidVarint => write!(f, "invalid varint"),
218            StreamDecodeError::LengthOverflow => write!(f, "length overflow"),
219            StreamDecodeError::TooManyHeaders(count) => write!(f, "too many headers: {}", count),
220            StreamDecodeError::InvalidMethod => write!(f, "invalid method"),
221            StreamDecodeError::InvalidPath => write!(f, "invalid path"),
222            StreamDecodeError::InvalidHeaderName => write!(f, "invalid header name"),
223            StreamDecodeError::InvalidHeaderValue => write!(f, "invalid header value"),
224            StreamDecodeError::InvalidStatus => write!(f, "invalid status"),
225            StreamDecodeError::InvalidEndFlags(flags) => write!(f, "invalid end flags: {}", flags),
226        }
227    }
228}
229
230impl std::error::Error for StreamDecodeError {}
231
232pub fn decode_frame(bytes: &[u8]) -> Result<StreamFrame, StreamDecodeError> {
233    match decode_frame_from_prefix(bytes)? {
234        Some((frame, consumed)) => {
235            if consumed != bytes.len() {
236                return Err(StreamDecodeError::TrailingBytes(bytes.len() - consumed));
237            }
238            Ok(frame)
239        }
240        None => Err(StreamDecodeError::InvalidMagic),
241    }
242}
243
244pub fn decode_frame_from_prefix(
245    bytes: &[u8],
246) -> Result<Option<(StreamFrame, usize)>, StreamDecodeError> {
247    let mut offset = 0usize;
248
249    if bytes.len() < STREAM_MAGIC.len() {
250        return Ok(None);
251    }
252    if &bytes[..STREAM_MAGIC.len()] != STREAM_MAGIC {
253        return Err(StreamDecodeError::InvalidMagic);
254    }
255    offset += STREAM_MAGIC.len();
256
257    if bytes.len() < offset + 2 {
258        return Ok(None);
259    }
260    let version = bytes[offset];
261    offset += 1;
262    if version != STREAM_VERSION {
263        return Err(StreamDecodeError::UnsupportedVersion(version));
264    }
265
266    let frame_type = bytes[offset];
267    offset += 1;
268
269    if bytes.len() < offset + 8 {
270        return Ok(None);
271    }
272    let stream_id = u64::from_be_bytes([
273        bytes[offset],
274        bytes[offset + 1],
275        bytes[offset + 2],
276        bytes[offset + 3],
277        bytes[offset + 4],
278        bytes[offset + 5],
279        bytes[offset + 6],
280        bytes[offset + 7],
281    ]);
282    offset += 8;
283
284    match frame_type {
285        FRAME_HEADERS => {
286            if bytes.len() < offset + 2 {
287                return Ok(None);
288            }
289            let kind = StreamKind::from_byte(bytes[offset])?;
290            offset += 1;
291            let http_version = HttpVersion::from_byte(bytes[offset])
292                .map_err(|err| match err {
293                    DecodeError::UnsupportedHttpVersion(v) => {
294                        StreamDecodeError::UnsupportedHttpVersion(v)
295                    }
296                    _ => StreamDecodeError::UnsupportedHttpVersion(0),
297                })?;
298            offset += 1;
299
300            match kind {
301                StreamKind::Request => {
302                    let method = match read_bytes(bytes, &mut offset)? {
303                        Some(value) if !value.is_empty() => value,
304                        Some(_) => return Err(StreamDecodeError::InvalidMethod),
305                        None => return Ok(None),
306                    };
307
308                    let scheme = match read_bytes(bytes, &mut offset)? {
309                        Some(value) if value.is_empty() => None,
310                        Some(value) => Some(value),
311                        None => return Ok(None),
312                    };
313
314                    let authority = match read_bytes(bytes, &mut offset)? {
315                        Some(value) if value.is_empty() => None,
316                        Some(value) => Some(value),
317                        None => return Ok(None),
318                    };
319
320                    let path = match read_bytes(bytes, &mut offset)? {
321                        Some(value) if value.is_empty() => b"/".to_vec(),
322                        Some(value) => value,
323                        None => return Ok(None),
324                    };
325
326                    validate_method(&method)?;
327                    validate_path(&path)?;
328
329                    let headers = read_headers(bytes, &mut offset)?;
330                    let headers = match headers {
331                        Some(value) => value,
332                        None => return Ok(None),
333                    };
334
335                    Ok(Some((
336                        StreamFrame::Headers(StreamHeaders::Request(StreamRequestHeaders {
337                            stream_id,
338                            version: http_version,
339                            method,
340                            scheme,
341                            authority,
342                            path,
343                            headers,
344                        })),
345                        offset,
346                    )))
347                }
348                StreamKind::Response => {
349                    if bytes.len() < offset + 2 {
350                        return Ok(None);
351                    }
352                    let status = u16::from_be_bytes([bytes[offset], bytes[offset + 1]]);
353                    offset += 2;
354                    if StatusCode::from_u16(status).is_err() {
355                        return Err(StreamDecodeError::InvalidStatus);
356                    }
357
358                    let headers = read_headers(bytes, &mut offset)?;
359                    let headers = match headers {
360                        Some(value) => value,
361                        None => return Ok(None),
362                    };
363
364                    Ok(Some((
365                        StreamFrame::Headers(StreamHeaders::Response(StreamResponseHeaders {
366                            stream_id,
367                            version: http_version,
368                            status,
369                            headers,
370                        })),
371                        offset,
372                    )))
373                }
374            }
375        }
376        FRAME_BODY => {
377            let body_len = match read_varint(bytes, &mut offset)? {
378                Some(value) => value,
379                None => return Ok(None),
380            };
381            let len = usize::try_from(body_len).map_err(|_| StreamDecodeError::LengthOverflow)?;
382            if bytes.len() < offset + len {
383                return Ok(None);
384            }
385            // Use Bytes::copy_from_slice for body data to enable efficient handling
386            let data = Bytes::copy_from_slice(&bytes[offset..offset + len]);
387            offset += len;
388
389            Ok(Some((StreamFrame::Body(StreamBody { stream_id, data }), offset)))
390        }
391        FRAME_END => {
392            if bytes.len() < offset + 1 {
393                return Ok(None);
394            }
395            let flags = bytes[offset];
396            offset += 1;
397            if flags != END_FLAGS_NONE {
398                return Err(StreamDecodeError::InvalidEndFlags(flags));
399            }
400            Ok(Some((StreamFrame::End(StreamEnd { stream_id }), offset)))
401        }
402        other => Err(StreamDecodeError::InvalidFrameType(other)),
403    }
404}
405
406pub struct StreamDecoder {
407    buf: BytesMut,
408}
409
410impl StreamDecoder {
411    pub fn new() -> Self {
412        Self { buf: BytesMut::new() }
413    }
414
415    pub fn push(&mut self, data: &[u8]) {
416        self.buf.extend_from_slice(data);
417    }
418
419    pub fn try_decode(&mut self) -> Result<Option<StreamFrame>, StreamDecodeError> {
420        match decode_frame_from_prefix(&self.buf)? {
421            Some((frame, consumed)) => {
422                self.buf.advance(consumed);
423                Ok(Some(frame))
424            }
425            None => Ok(None),
426        }
427    }
428
429    pub fn buffer_len(&self) -> usize {
430        self.buf.len()
431    }
432}
433
/// Errors produced while turning stream frames back into HTTP/1.1 bytes.
#[derive(Debug)]
pub enum StreamRebuildError {
    /// A body or end frame arrived for a stream whose head was not seen (or was already ended).
    MissingHeaders(u64),
    /// A second headers frame arrived for a stream that is still open.
    DuplicateHeaders(u64),
    /// NOTE(review): Display says "invalid frame order", but the only use visible
    /// in this file is as the fallback in `map_header_error` for unexpected
    /// header-validation errors — confirm the intended meaning.
    InvalidFrame,
    /// The request method failed validation.
    InvalidMethod,
    /// The request path failed validation.
    InvalidPath,
    /// A header name failed validation.
    InvalidHeaderName,
    /// A header value (or the synthesized host value) failed validation.
    InvalidHeaderValue,
    /// The response status is not accepted by `http::StatusCode::from_u16`.
    InvalidStatus,
}
445
446impl std::fmt::Display for StreamRebuildError {
447    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
448        match self {
449            StreamRebuildError::MissingHeaders(id) => write!(f, "missing headers for stream {}", id),
450            StreamRebuildError::DuplicateHeaders(id) => write!(f, "duplicate headers for stream {}", id),
451            StreamRebuildError::InvalidFrame => write!(f, "invalid frame order"),
452            StreamRebuildError::InvalidMethod => write!(f, "invalid method"),
453            StreamRebuildError::InvalidPath => write!(f, "invalid path"),
454            StreamRebuildError::InvalidHeaderName => write!(f, "invalid header name"),
455            StreamRebuildError::InvalidHeaderValue => write!(f, "invalid header value"),
456            StreamRebuildError::InvalidStatus => write!(f, "invalid status"),
457        }
458    }
459}
460
461impl std::error::Error for StreamRebuildError {}
462
463pub struct Http1StreamRebuilder {
464    streams: HashSet<u64>,
465}
466
467impl Http1StreamRebuilder {
468    pub fn new() -> Self {
469        Self { streams: HashSet::new() }
470    }
471
472    pub fn push_frame(&mut self, frame: StreamFrame) -> Result<Vec<Bytes>, StreamRebuildError> {
473        match frame {
474            StreamFrame::Headers(headers) => self.handle_headers(headers),
475            StreamFrame::Body(body) => self.handle_body(body),
476            StreamFrame::End(end) => self.handle_end(end),
477        }
478    }
479
480    fn handle_headers(&mut self, headers: StreamHeaders) -> Result<Vec<Bytes>, StreamRebuildError> {
481        let stream_id = match &headers {
482            StreamHeaders::Request(req) => req.stream_id,
483            StreamHeaders::Response(resp) => resp.stream_id,
484        };
485        if self.streams.contains(&stream_id) {
486            return Err(StreamRebuildError::DuplicateHeaders(stream_id));
487        }
488
489        let mut out = Vec::new();
490        let bytes = match headers {
491            StreamHeaders::Request(req) => {
492                self.streams.insert(stream_id);
493                build_http1_request_headers(&req)?
494            }
495            StreamHeaders::Response(resp) => {
496                self.streams.insert(stream_id);
497                build_http1_response_headers(&resp)?
498            }
499        };
500
501        out.push(Bytes::from(bytes));
502        Ok(out)
503    }
504
505    fn handle_body(&mut self, body: StreamBody) -> Result<Vec<Bytes>, StreamRebuildError> {
506        if !self.streams.contains(&body.stream_id) {
507            return Err(StreamRebuildError::MissingHeaders(body.stream_id));
508        }
509        if body.data.is_empty() {
510            return Ok(Vec::new());
511        }
512        let mut chunk = Vec::new();
513        write_chunk_size(body.data.len(), &mut chunk);
514        chunk.extend_from_slice(&body.data);
515        chunk.extend_from_slice(b"\r\n");
516        Ok(vec![Bytes::from(chunk)])
517    }
518
519    fn handle_end(&mut self, end: StreamEnd) -> Result<Vec<Bytes>, StreamRebuildError> {
520        if !self.streams.remove(&end.stream_id) {
521            return Err(StreamRebuildError::MissingHeaders(end.stream_id));
522        }
523        Ok(vec![Bytes::from_static(b"0\r\n\r\n")])
524    }
525}
526
527fn build_http1_request_headers(req: &StreamRequestHeaders) -> Result<Vec<u8>, StreamRebuildError> {
528    validate_method(&req.method).map_err(|_| StreamRebuildError::InvalidMethod)?;
529    validate_path(&req.path).map_err(|_| StreamRebuildError::InvalidPath)?;
530
531    let mut out = Vec::new();
532    out.extend_from_slice(&req.method);
533    out.extend_from_slice(b" ");
534    out.extend_from_slice(&req.path);
535    out.extend_from_slice(b" HTTP/1.1\r\n");
536
537    let mut has_host = false;
538    for header in &req.headers {
539        if crate::eq_ignore_ascii_case(&header.name, b"transfer-encoding") {
540            continue;
541        }
542        if crate::eq_ignore_ascii_case(&header.name, b"content-length") {
543            continue;
544        }
545        if crate::eq_ignore_ascii_case(&header.name, b"host") {
546            has_host = true;
547        }
548        validate_header_field(header).map_err(map_header_error)?;
549        out.extend_from_slice(&header.name);
550        out.extend_from_slice(b": ");
551        out.extend_from_slice(&header.value);
552        out.extend_from_slice(b"\r\n");
553    }
554
555    if !has_host {
556        if let Some(authority) = &req.authority {
557            if crate::has_crlf(authority) {
558                return Err(StreamRebuildError::InvalidHeaderValue);
559            }
560            out.extend_from_slice(b"host: ");
561            out.extend_from_slice(authority);
562            out.extend_from_slice(b"\r\n");
563        }
564    }
565
566    out.extend_from_slice(b"transfer-encoding: chunked\r\n\r\n");
567    Ok(out)
568}
569
570fn build_http1_response_headers(
571    resp: &StreamResponseHeaders,
572) -> Result<Vec<u8>, StreamRebuildError> {
573    let status = StatusCode::from_u16(resp.status).map_err(|_| StreamRebuildError::InvalidStatus)?;
574    let reason = status.canonical_reason().unwrap_or("");
575
576    let mut out = Vec::new();
577    out.extend_from_slice(b"HTTP/1.1 ");
578    out.extend_from_slice(status.as_str().as_bytes());
579    if !reason.is_empty() {
580        out.extend_from_slice(b" ");
581        out.extend_from_slice(reason.as_bytes());
582    }
583    out.extend_from_slice(b"\r\n");
584
585    for header in &resp.headers {
586        if crate::eq_ignore_ascii_case(&header.name, b"transfer-encoding") {
587            continue;
588        }
589        if crate::eq_ignore_ascii_case(&header.name, b"content-length") {
590            continue;
591        }
592        validate_header_field(header).map_err(map_header_error)?;
593        out.extend_from_slice(&header.name);
594        out.extend_from_slice(b": ");
595        out.extend_from_slice(&header.value);
596        out.extend_from_slice(b"\r\n");
597    }
598
599    out.extend_from_slice(b"transfer-encoding: chunked\r\n\r\n");
600    Ok(out)
601}
602
603fn map_header_error(err: DecodeError) -> StreamRebuildError {
604    match err {
605        DecodeError::InvalidHeaderName => StreamRebuildError::InvalidHeaderName,
606        DecodeError::InvalidHeaderValue => StreamRebuildError::InvalidHeaderValue,
607        _ => StreamRebuildError::InvalidFrame,
608    }
609}
610
/// Appends an HTTP/1.1 chunk-size line for `len` bytes to `out`.
///
/// The line is `len` in uppercase hexadecimal without leading zeros ("0" for
/// zero), followed by CRLF.
fn write_chunk_size(len: usize, out: &mut Vec<u8>) {
    let line = format!("{:X}\r\n", len);
    out.extend_from_slice(line.as_bytes());
}
629
630fn encode_request_fields(req: &StreamRequestHeaders, buf: &mut Vec<u8>) {
631    crate::put_varint(buf, req.method.len() as u64);
632    buf.extend_from_slice(&req.method);
633
634    if let Some(scheme) = &req.scheme {
635        crate::put_varint(buf, scheme.len() as u64);
636        buf.extend_from_slice(scheme);
637    } else {
638        crate::put_varint(buf, 0);
639    }
640
641    if let Some(authority) = &req.authority {
642        crate::put_varint(buf, authority.len() as u64);
643        buf.extend_from_slice(authority);
644    } else {
645        crate::put_varint(buf, 0);
646    }
647
648    crate::put_varint(buf, req.path.len() as u64);
649    buf.extend_from_slice(&req.path);
650
651    crate::put_varint(buf, req.headers.len() as u64);
652    for header in &req.headers {
653        crate::put_varint(buf, header.name.len() as u64);
654        buf.extend_from_slice(&header.name);
655        crate::put_varint(buf, header.value.len() as u64);
656        buf.extend_from_slice(&header.value);
657    }
658}
659
660fn encode_response_fields(resp: &StreamResponseHeaders, buf: &mut Vec<u8>) {
661    buf.extend_from_slice(&resp.status.to_be_bytes());
662
663    crate::put_varint(buf, resp.headers.len() as u64);
664    for header in &resp.headers {
665        crate::put_varint(buf, header.name.len() as u64);
666        buf.extend_from_slice(&header.name);
667        crate::put_varint(buf, header.value.len() as u64);
668        buf.extend_from_slice(&header.value);
669    }
670}
671
672fn read_headers(
673    bytes: &[u8],
674    offset: &mut usize,
675) -> Result<Option<Vec<HeaderField>>, StreamDecodeError> {
676    let header_count = match read_varint(bytes, offset)? {
677        Some(value) => value,
678        None => return Ok(None),
679    };
680    if header_count > MAX_HEADERS {
681        return Err(StreamDecodeError::TooManyHeaders(header_count));
682    }
683
684    let mut headers = Vec::with_capacity(header_count as usize);
685    for _ in 0..header_count {
686        let name = match read_bytes(bytes, offset)? {
687            Some(value) => value,
688            None => return Ok(None),
689        };
690        let value = match read_bytes(bytes, offset)? {
691            Some(value) => value,
692            None => return Ok(None),
693        };
694        validate_header_name(&name)?;
695        validate_header_value(&value)?;
696        headers.push(HeaderField { name, value });
697    }
698
699    Ok(Some(headers))
700}
701
702fn read_varint(bytes: &[u8], offset: &mut usize) -> Result<Option<u64>, StreamDecodeError> {
703    let mut value: u64 = 0;
704    let mut shift = 0;
705
706    for _ in 0..10 {
707        if *offset >= bytes.len() {
708            return Ok(None);
709        }
710        let byte = bytes[*offset];
711        *offset += 1;
712        value |= ((byte & 0x7f) as u64) << shift;
713        if (byte & 0x80) == 0 {
714            return Ok(Some(value));
715        }
716        shift += 7;
717    }
718
719    Err(StreamDecodeError::InvalidVarint)
720}
721
722fn read_bytes(bytes: &[u8], offset: &mut usize) -> Result<Option<Vec<u8>>, StreamDecodeError> {
723    let len = match read_varint(bytes, offset)? {
724        Some(value) => value,
725        None => return Ok(None),
726    };
727    read_raw(bytes, offset, len)
728}
729
730fn read_raw(
731    bytes: &[u8],
732    offset: &mut usize,
733    len: u64,
734) -> Result<Option<Vec<u8>>, StreamDecodeError> {
735    let len = usize::try_from(len).map_err(|_| StreamDecodeError::LengthOverflow)?;
736    if bytes.len() < *offset + len {
737        return Ok(None);
738    }
739    let data = bytes[*offset..*offset + len].to_vec();
740    *offset += len;
741    Ok(Some(data))
742}
743
744fn validate_method(method: &[u8]) -> Result<(), StreamDecodeError> {
745    Method::from_bytes(method).map_err(|_| StreamDecodeError::InvalidMethod)?;
746    Ok(())
747}
748
749fn validate_path(path: &[u8]) -> Result<(), StreamDecodeError> {
750    if path.is_empty() || crate::has_crlf(path) {
751        return Err(StreamDecodeError::InvalidPath);
752    }
753    Ok(())
754}
755
756fn validate_header_name(name: &[u8]) -> Result<(), StreamDecodeError> {
757    HeaderName::from_bytes(name).map_err(|_| StreamDecodeError::InvalidHeaderName)?;
758    Ok(())
759}
760
761fn validate_header_value(value: &[u8]) -> Result<(), StreamDecodeError> {
762    HeaderValue::from_bytes(value).map_err(|_| StreamDecodeError::InvalidHeaderValue)?;
763    Ok(())
764}
765
766fn validate_header_field(field: &HeaderField) -> Result<(), DecodeError> {
767    crate::validate_header_name(&field.name)?;
768    crate::validate_header_value(&field.value)?;
769    Ok(())
770}
771
772fn collect_headers(headers: &http::HeaderMap) -> Vec<HeaderField> {
773    headers
774        .iter()
775        .map(|(name, value)| HeaderField {
776            name: name.as_str().as_bytes().to_vec(),
777            value: value.as_bytes().to_vec(),
778        })
779        .collect()
780}
781
782impl StreamKind {
783    fn to_byte(self) -> u8 {
784        match self {
785            StreamKind::Request => 1,
786            StreamKind::Response => 2,
787        }
788    }
789
790    fn from_byte(byte: u8) -> Result<Self, StreamDecodeError> {
791        match byte {
792            1 => Ok(StreamKind::Request),
793            2 => Ok(StreamKind::Response),
794            other => Err(StreamDecodeError::InvalidKind(other)),
795        }
796    }
797}
798
#[cfg(feature = "body")]
pub mod body {
    //! Adapters that turn `http::Request` / `http::Response` values carrying an
    //! `http_body::Body` into a sequence of stream frames.

    use super::{StreamFrame, StreamHeaders, StreamBody, StreamEnd, StreamEncodeError};
    use bytes::Buf;
    use http::{Request, Response};
    use http_body::Body;
    use http_body_util::BodyExt;

    /// Emits a request as a headers frame, one body frame per non-empty data chunk,
    /// and a final end frame.
    ///
    /// Non-data body frames (such as trailers) are dropped.
    ///
    /// # Errors
    /// - `Encode` when the request head cannot be represented.
    /// - `Body` when reading the body fails.
    /// - `Emit` when `emit` rejects a frame.
    ///
    /// Emission stops at the first error.
    pub async fn encode_request<B, F, E>(
        req: Request<B>,
        stream_id: u64,
        mut emit: F,
    ) -> Result<(), StreamEncodeError<E>>
    where
        B: Body + Unpin,
        B::Data: Buf,
        B::Error: std::error::Error + Send + Sync + 'static,
        F: FnMut(StreamFrame) -> Result<(), E>,
    {
        let (parts, body) = req.into_parts();
        let headers = StreamHeaders::from_request(stream_id, &Request::from_parts(parts, ()))
            .map_err(StreamEncodeError::Encode)?;
        emit(StreamFrame::Headers(headers)).map_err(StreamEncodeError::Emit)?;
        forward_body(body, stream_id, &mut emit).await
    }

    /// Emits a response as a headers frame, one body frame per non-empty data
    /// chunk, and a final end frame.
    ///
    /// Non-data body frames (such as trailers) are dropped.
    ///
    /// # Errors
    /// Same as [`encode_request`].
    pub async fn encode_response<B, F, E>(
        resp: Response<B>,
        stream_id: u64,
        mut emit: F,
    ) -> Result<(), StreamEncodeError<E>>
    where
        B: Body + Unpin,
        B::Data: Buf,
        B::Error: std::error::Error + Send + Sync + 'static,
        F: FnMut(StreamFrame) -> Result<(), E>,
    {
        let (parts, body) = resp.into_parts();
        let headers = StreamHeaders::from_response(stream_id, &Response::from_parts(parts, ()))
            .map_err(StreamEncodeError::Encode)?;
        emit(StreamFrame::Headers(headers)).map_err(StreamEncodeError::Emit)?;
        forward_body(body, stream_id, &mut emit).await
    }

    /// Drains `body`, emitting each non-empty data frame as a body frame, then
    /// emits the end frame.
    async fn forward_body<B, F, E>(
        mut body: B,
        stream_id: u64,
        emit: &mut F,
    ) -> Result<(), StreamEncodeError<E>>
    where
        B: Body + Unpin,
        B::Data: Buf,
        B::Error: std::error::Error + Send + Sync + 'static,
        F: FnMut(StreamFrame) -> Result<(), E>,
    {
        while let Some(frame) = body
            .frame()
            .await
            .transpose()
            .map_err(|err| StreamEncodeError::Body(Box::new(err)))?
        {
            let mut data = match frame.into_data() {
                Ok(data) => data,
                Err(_) => continue,
            };
            let len = data.remaining();
            if len == 0 {
                continue;
            }
            emit(StreamFrame::Body(StreamBody {
                stream_id,
                data: data.copy_to_bytes(len),
            }))
            .map_err(StreamEncodeError::Emit)?;
        }

        emit(StreamFrame::End(StreamEnd { stream_id })).map_err(StreamEncodeError::Emit)
    }
}
887
888#[cfg(feature = "h3")]
889pub mod h3 {
890    use super::{StreamFrame, StreamHeaders, StreamBody, StreamEnd, StreamEncodeError};
891    use bytes::Buf;
892    use h3::quic::RecvStream;
893
894    pub async fn encode_server_request<S, B, F, E>(
895        req: http::Request<()>,
896        stream_id: u64,
897        stream: &mut h3::server::RequestStream<S, B>,
898        mut emit: F,
899    ) -> Result<(), StreamEncodeError<E>>
900    where
901        S: RecvStream,
902        B: Buf,
903        F: FnMut(StreamFrame) -> Result<(), E>,
904    {
905        let headers = StreamHeaders::from_request(stream_id, &req)
906            .map_err(StreamEncodeError::Encode)?;
907        emit(StreamFrame::Headers(headers)).map_err(StreamEncodeError::Emit)?;
908
909        loop {
910            match stream.recv_data().await.map_err(StreamEncodeError::H3Stream)? {
911                Some(mut chunk) => {
912                    let remaining = chunk.remaining();
913                    if remaining == 0 {
914                        continue;
915                    }
916                    let bytes = chunk.copy_to_bytes(remaining);
917                    emit(StreamFrame::Body(StreamBody {
918                        stream_id,
919                        data: bytes,
920                    }))
921                    .map_err(StreamEncodeError::Emit)?;
922                }
923                None => break,
924            }
925        }
926
927        emit(StreamFrame::End(StreamEnd { stream_id })).map_err(StreamEncodeError::Emit)?;
928        Ok(())
929    }
930
931    pub async fn encode_client_response<S, B, F, E>(
932        resp: http::Response<()>,
933        stream_id: u64,
934        stream: &mut h3::client::RequestStream<S, B>,
935        mut emit: F,
936    ) -> Result<(), StreamEncodeError<E>>
937    where
938        S: RecvStream,
939        B: Buf,
940        F: FnMut(StreamFrame) -> Result<(), E>,
941    {
942        let headers = StreamHeaders::from_response(stream_id, &resp)
943            .map_err(StreamEncodeError::Encode)?;
944        emit(StreamFrame::Headers(headers)).map_err(StreamEncodeError::Emit)?;
945
946        loop {
947            match stream.recv_data().await.map_err(StreamEncodeError::H3Stream)? {
948                Some(mut chunk) => {
949                    let remaining = chunk.remaining();
950                    if remaining == 0 {
951                        continue;
952                    }
953                    let bytes = chunk.copy_to_bytes(remaining);
954                    emit(StreamFrame::Body(StreamBody {
955                        stream_id,
956                        data: bytes,
957                    }))
958                    .map_err(StreamEncodeError::Emit)?;
959                }
960                None => break,
961            }
962        }
963
964        emit(StreamFrame::End(StreamEnd { stream_id })).map_err(StreamEncodeError::Emit)?;
965        Ok(())
966    }
967}
968
969#[derive(Debug)]
970pub enum StreamEncodeError<E> {
971    Encode(EncodeError),
972    Body(Box<dyn std::error::Error + Send + Sync>),
973    #[cfg(feature = "h3")]
974    H3Stream(::h3::error::StreamError),
975    Emit(E),
976}
977
978impl<E: std::fmt::Display> std::fmt::Display for StreamEncodeError<E> {
979    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
980        match self {
981            StreamEncodeError::Encode(err) => write!(f, "encode error: {}", err),
982            StreamEncodeError::Body(err) => write!(f, "body error: {}", err),
983            #[cfg(feature = "h3")]
984            StreamEncodeError::H3Stream(err) => write!(f, "h3 stream error: {}", err),
985            StreamEncodeError::Emit(err) => write!(f, "emit error: {}", err),
986        }
987    }
988}
989
990impl<E: std::fmt::Debug + std::fmt::Display> std::error::Error for StreamEncodeError<E> {}
991
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `frame`, decodes the bytes back, and asserts the frame survived unchanged.
    fn assert_roundtrip(frame: StreamFrame) {
        let encoded = encode_frame(&frame);
        let decoded = decode_frame(&encoded).unwrap();
        assert_eq!(frame, decoded);
    }

    #[test]
    fn frame_roundtrip_request_headers() {
        let headers = StreamHeaders::Request(StreamRequestHeaders {
            stream_id: 7,
            version: HttpVersion::Http11,
            method: b"GET".to_vec(),
            scheme: None,
            authority: Some(b"example.com".to_vec()),
            path: b"/".to_vec(),
            headers: vec![HeaderField {
                name: b"x-test".to_vec(),
                value: b"ok".to_vec(),
            }],
        });
        assert_roundtrip(StreamFrame::Headers(headers));
    }

    #[test]
    fn frame_roundtrip_response_headers() {
        let headers = StreamHeaders::Response(StreamResponseHeaders {
            stream_id: 7,
            version: HttpVersion::Http11,
            status: 404,
            headers: vec![HeaderField {
                name: b"x-test".to_vec(),
                value: b"missing".to_vec(),
            }],
        });
        assert_roundtrip(StreamFrame::Headers(headers));
    }

    #[test]
    fn frame_roundtrip_body_and_end() {
        // Use a stream id above 255 so a one-byte id encoding would not suffice.
        assert_roundtrip(StreamFrame::Body(StreamBody {
            stream_id: 300,
            data: Bytes::from_static(b"hello world"),
        }));
        assert_roundtrip(StreamFrame::End(StreamEnd { stream_id: 300 }));
    }

    #[test]
    fn http1_rebuild_request() {
        let headers = StreamHeaders::Request(StreamRequestHeaders {
            stream_id: 1,
            version: HttpVersion::Http11,
            method: b"POST".to_vec(),
            scheme: None,
            authority: Some(b"example.com".to_vec()),
            path: b"/upload".to_vec(),
            headers: vec![HeaderField {
                name: b"x-test".to_vec(),
                value: b"ok".to_vec(),
            }],
        });

        let mut rebuilder = Http1StreamRebuilder::new();
        let head = rebuilder
            .push_frame(StreamFrame::Headers(headers))
            .unwrap();
        let head_str = String::from_utf8(head[0].to_vec()).unwrap();
        assert!(head_str.starts_with("POST /upload HTTP/1.1\r\n"));
        assert!(head_str.contains("transfer-encoding: chunked\r\n"));

        let body = rebuilder
            .push_frame(StreamFrame::Body(StreamBody {
                stream_id: 1,
                data: Bytes::from_static(b"hello"),
            }))
            .unwrap();
        assert_eq!(body[0].as_ref(), b"5\r\nhello\r\n");

        let end = rebuilder
            .push_frame(StreamFrame::End(StreamEnd { stream_id: 1 }))
            .unwrap();
        assert_eq!(end[0].as_ref(), b"0\r\n\r\n");
    }
}